MLIR 23.0.0git
DialectConversion.cpp
Go to the documentation of this file.
1//===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "mlir/Config/mlir-config.h"
11#include "mlir/IR/Block.h"
12#include "mlir/IR/Builders.h"
13#include "mlir/IR/BuiltinOps.h"
14#include "mlir/IR/Dominance.h"
15#include "mlir/IR/IRMapping.h"
16#include "mlir/IR/Iterators.h"
17#include "mlir/IR/Operation.h"
20#include "llvm/ADT/ScopeExit.h"
21#include "llvm/ADT/SmallPtrSet.h"
22#include "llvm/Support/Debug.h"
23#include "llvm/Support/DebugLog.h"
24#include "llvm/Support/FormatVariadic.h"
25#include "llvm/Support/SaveAndRestore.h"
26#include "llvm/Support/ScopedPrinter.h"
27#include <optional>
28#include <utility>
29
30using namespace mlir;
31using namespace mlir::detail;
32
33#define DEBUG_TYPE "dialect-conversion"
34
35/// A utility function to log a successful result for the given reason.
36template <typename... Args>
37static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
38 LLVM_DEBUG({
39 os.unindent();
40 os.startLine() << "} -> SUCCESS";
41 if (!fmt.empty())
42 os.getOStream() << " : "
43 << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
44 os.getOStream() << "\n";
45 });
46}
47
48/// A utility function to log a failure result for the given reason.
49template <typename... Args>
50static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
51 LLVM_DEBUG({
52 os.unindent();
53 os.startLine() << "} -> FAILURE : "
54 << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
55 << "\n";
56 });
57}
58
59/// Helper function that computes an insertion point where the given value is
60/// defined and can be used without a dominance violation.
// NOTE(review): the signature line of this function is not visible in this
// listing; the body below reads a `value` parameter.
// Default insertion point: the beginning of the block that contains `value`.
62 Block *insertBlock = value.getParentBlock();
63 Block::iterator insertPt = insertBlock->begin();
// If `value` is an op result, insert right after the op that owns the result.
64 if (OpResult inputRes = dyn_cast<OpResult>(value))
65 insertPt = ++inputRes.getOwner()->getIterator();
66 return OpBuilder::InsertPoint(insertBlock, insertPt);
67}
68
69/// Helper function that computes an insertion point where the given values are
70/// defined and can be used without a dominance violation.
// NOTE(review): the signature line and the declarations of `pt` and `nextPt`
// are not visible in this listing; the body below reads a `vals` parameter.
72 assert(!vals.empty() && "expected at least one value");
73 DominanceInfo domInfo;
75 for (Value v : vals.drop_front()) {
76 // Choose the "later" insertion point.
78 if (domInfo.dominates(pt.getBlock(), pt.getPoint(), nextPt.getBlock(),
79 nextPt.getPoint())) {
80 // pt is before nextPt => choose nextPt.
81 pt = nextPt;
82 } else {
83#ifndef NDEBUG
84 // nextPt should be before pt => choose pt.
85 // If pt and nextPt have no dominance relationship, then there is no valid
86 // insertion point at which all given values are defined.
87 bool dom = domInfo.dominates(nextPt.getBlock(), nextPt.getPoint(),
88 pt.getBlock(), pt.getPoint());
89 assert(dom && "unable to find valid insertion point");
90#endif // NDEBUG
91 }
92 }
93 return pt;
94}
95
namespace {
/// Selects how an operation conversion treats operations that fail to convert.
enum OpConversionMode {
  /// Failed conversions are ignored, so illegal operations are allowed to
  /// co-exist in the IR.
  Partial,

  /// The conversion only succeeds if every operation is legal for the given
  /// target.
  Full,

  /// Operations are only analyzed for legality; on success, no actual rewrites
  /// are applied to them.
  Analysis,
};
} // namespace
111
112//===----------------------------------------------------------------------===//
113// ConversionValueMapping
114//===----------------------------------------------------------------------===//
115
116/// A vector of SSA values, optimized for the most common case of a single
117/// value.
119
120namespace {
121
122/// Helper class to make it possible to use `ValueVector` as a key in DenseMap.
123struct ValueVectorMapInfo {
124 static ValueVector getEmptyKey() { return ValueVector{Value()}; }
125 static ValueVector getTombstoneKey() { return ValueVector{Value(), Value()}; }
126 static ::llvm::hash_code getHashValue(const ValueVector &val) {
127 return ::llvm::hash_combine_range(val);
128 }
129 static bool isEqual(const ValueVector &LHS, const ValueVector &RHS) {
130 return LHS == RHS;
131 }
132};
133
134/// This class wraps an IRMapping to provide recursive lookup
135/// functionality, i.e. we will traverse if the mapped value also has a mapping.
// NOTE(review): the out-of-line `lookup` defined in this file performs a
// single `mapping.find` and does not traverse further; confirm whether the
// "recursive lookup" wording above is still accurate.
136struct ConversionValueMapping {
137 /// Return "true" if an SSA value is mapped to the given value. May return
138 /// false positives.
139 bool isMappedTo(Value value) const { return mappedTo.contains(value); }
140
141 /// Lookup a value in the mapping.
142 ValueVector lookup(const ValueVector &from) const;
143
// Type trait: true if `T`, after decaying, is exactly `ValueVector`.
144 template <typename T>
145 struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {};
146
147 /// Map a value vector to the one provided.
// This overload is selected when both arguments are `ValueVector`s.
148 template <typename OldVal, typename NewVal>
149 std::enable_if_t<IsValueVector<OldVal>::value && IsValueVector<NewVal>::value>
150 map(OldVal &&oldVal, NewVal &&newVal) {
// Debug-only check: follow the existing chain of mappings starting at
// `newVal` and assert that it never reaches `oldVal` (i.e., no cycle).
151 LLVM_DEBUG({
152 ValueVector next(newVal);
153 while (true) {
154 assert(next != oldVal && "inserting cyclic mapping");
155 auto it = mapping.find(next);
156 if (it == mapping.end())
157 break;
158 next = it->second;
159 }
160 });
// Record every new value as a mapping target (queried by `isMappedTo`).
161 mappedTo.insert_range(newVal);
162
163 mapping[std::forward<OldVal>(oldVal)] = std::forward<NewVal>(newVal);
164 }
165
166 /// Map a value vector or single value to the one provided.
// Any argument that is not a `ValueVector` is wrapped into a one-element
// `ValueVector` before forwarding to the overload above.
167 template <typename OldVal, typename NewVal>
168 std::enable_if_t<!IsValueVector<OldVal>::value ||
169 !IsValueVector<NewVal>::value>
170 map(OldVal &&oldVal, NewVal &&newVal) {
171 if constexpr (IsValueVector<OldVal>{}) {
172 map(std::forward<OldVal>(oldVal), ValueVector{newVal});
173 } else if constexpr (IsValueVector<NewVal>{}) {
174 map(ValueVector{oldVal}, std::forward<NewVal>(newVal));
175 } else {
176 map(ValueVector{oldVal}, ValueVector{newVal});
177 }
178 }
179
// Map a single value to the values of the given `SmallVector`, which is moved
// into a `ValueVector`.
180 void map(Value oldVal, SmallVector<Value> &&newVal) {
181 map(ValueVector{oldVal}, ValueVector(std::move(newVal)));
182 }
183
184 /// Drop the last mapping for the given values.
185 void erase(const ValueVector &value) { mapping.erase(value); }
186
187private:
188 /// Current value mappings.
// NOTE(review): the declaration of the `mapping` member is not visible in this
// listing.
190
191 /// All SSA values that are mapped to. May contain false positives.
192 DenseSet<Value> mappedTo;
193};
194} // namespace
195
196/// Marker attribute for pure type conversions. I.e., mappings whose only
197/// purpose is to resolve a type mismatch. (In contrast, mappings that point to
198/// the replacement values of a "replaceOp" call, etc., are not pure type
199/// conversions.)
200static const StringRef kPureTypeConversionMarker = "__pure_type_conversion__";
201
202/// Return the operation that defines all values in the vector. Return nullptr
203/// if the values are not defined by the same operation.
// NOTE(review): the signature line of this function is not visible in this
// listing; the body below reads a `values` parameter.
205 assert(!values.empty() && "expected non-empty value vector");
// The defining op of the first value is the candidate; every remaining value
// must report the same defining op.
206 Operation *op = values.front().getDefiningOp();
207 for (Value v : llvm::drop_begin(values)) {
208 if (v.getDefiningOp() != op)
209 return nullptr;
210 }
211 return op;
212}
213
214/// A vector of values is a pure type conversion if all values are defined by
215/// the same operation and the operation has the `kPureTypeConversionMarker`
216/// attribute.
217static bool isPureTypeConversion(const ValueVector &values) {
218 assert(!values.empty() && "expected non-empty value vector");
219 Operation *op = getCommonDefiningOp(values);
220 return op && op->hasAttr(kPureTypeConversionMarker);
221}
222
223ValueVector ConversionValueMapping::lookup(const ValueVector &from) const {
224 auto it = mapping.find(from);
225 if (it == mapping.end()) {
226 // No mapping found: The lookup stops here.
227 return {};
228 }
229 return it->second;
230}
231
232//===----------------------------------------------------------------------===//
233// Rewriter and Translation State
234//===----------------------------------------------------------------------===//
235namespace {
/// Snapshot of the conversion rewriter state, expressed as three counters. A
/// snapshot is useful for saving a point to which a set of rewrites can later
/// be undone.
struct RewriterState {
  RewriterState(unsigned numRewrites, unsigned numIgnoredOperations,
                unsigned numReplacedOps)
      : numRewrites{numRewrites}, numIgnoredOperations{numIgnoredOperations},
        numReplacedOps{numReplacedOps} {}

  /// Number of rewrites performed at the time of the snapshot.
  unsigned numRewrites;

  /// Number of ignored operations at the time of the snapshot.
  unsigned numIgnoredOperations;

  /// Number of replaced ops scheduled for erasure at the time of the snapshot.
  unsigned numReplacedOps;
};
253
254//===----------------------------------------------------------------------===//
255// IR rewrites
256//===----------------------------------------------------------------------===//
257
258static void notifyIRErased(RewriterBase::Listener *listener, Operation &op);
259
260/// Notify the listener that the given block and its contents are being erased.
261static void notifyIRErased(RewriterBase::Listener *listener, Block &b) {
262 for (Operation &op : b)
263 notifyIRErased(listener, op);
264 listener->notifyBlockErased(&b);
265}
266
267/// Notify the listener that the given operation and its contents are being
268/// erased.
269static void notifyIRErased(RewriterBase::Listener *listener, Operation &op) {
270 for (Region &r : op.getRegions()) {
271 for (Block &b : r) {
272 notifyIRErased(listener, b);
273 }
274 }
275 listener->notifyOperationErased(&op);
276}
277
278/// An IR rewrite that can be committed (upon success) or rolled back (upon
279/// failure).
280///
281/// The dialect conversion keeps track of IR modifications (requested by the
282/// user through the rewriter API) in `IRRewrite` objects. Some kind of rewrites
283/// are directly applied to the IR as the rewriter API is used, some are applied
284/// partially, and some are delayed until the `IRRewrite` objects are committed.
285class IRRewrite {
286public:
287 /// The kind of the rewrite. Rewrites can be undone if the conversion fails.
288 /// Enum values are ordered, so that they can be used in `classof`: first all
289 /// block rewrites, then all operation rewrites.
290 enum class Kind {
291 // Block rewrites
292 CreateBlock,
293 EraseBlock,
294 InlineBlock,
295 MoveBlock,
296 BlockTypeConversion,
297 // Operation rewrites
298 MoveOperation,
299 ModifyOperation,
300 ReplaceOperation,
301 CreateOperation,
302 UnresolvedMaterialization,
303 // Value rewrites
304 ReplaceValue
305 };
306
307 virtual ~IRRewrite() = default;
308
309 /// Roll back the rewrite. Operations may be erased during rollback.
310 virtual void rollback() = 0;
311
312 /// Commit the rewrite. At this point, it is certain that the dialect
313 /// conversion will succeed. All IR modifications, except for operation/block
314 /// erasure, must be performed through the given rewriter.
315 ///
316 /// Instead of erasing operations/blocks, they should merely be unlinked
317 /// commit phase and finally be erased during the cleanup phase. This is
318 /// because internal dialect conversion state (such as `mapping`) may still
319 /// be using them.
320 ///
321 /// Any IR modification that was already performed before the commit phase
322 /// (e.g., insertion of an op) must be communicated to the listener that may
323 /// be attached to the given rewriter.
324 virtual void commit(RewriterBase &rewriter) {}
325
326 /// Cleanup operations/blocks. Cleanup is called after commit.
327 virtual void cleanup(RewriterBase &rewriter) {}
328
329 Kind getKind() const { return kind; }
330
331 static bool classof(const IRRewrite *rewrite) { return true; }
332
333protected:
334 IRRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl)
335 : kind(kind), rewriterImpl(rewriterImpl) {}
336
337 const ConversionConfig &getConfig() const;
338
339 const Kind kind;
340 ConversionPatternRewriterImpl &rewriterImpl;
341};
342
343/// A block rewrite.
344class BlockRewrite : public IRRewrite {
345public:
346 /// Return the block that this rewrite operates on.
347 Block *getBlock() const { return block; }
348
349 static bool classof(const IRRewrite *rewrite) {
350 return rewrite->getKind() >= Kind::CreateBlock &&
351 rewrite->getKind() <= Kind::BlockTypeConversion;
352 }
353
354protected:
355 BlockRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
356 Block *block)
357 : IRRewrite(kind, rewriterImpl), block(block) {}
358
359 // The block that this rewrite operates on.
360 Block *block;
361};
362
363/// A value rewrite.
364class ValueRewrite : public IRRewrite {
365public:
366 /// Return the value that this rewrite operates on.
367 Value getValue() const { return value; }
368
369 static bool classof(const IRRewrite *rewrite) {
370 return rewrite->getKind() == Kind::ReplaceValue;
371 }
372
373protected:
374 ValueRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
375 Value value)
376 : IRRewrite(kind, rewriterImpl), value(value) {}
377
378 // The value that this rewrite operates on.
379 Value value;
380};
381
382/// Creation of a block. Block creations are immediately reflected in the IR.
383/// There is no extra work to commit the rewrite. During rollback, the newly
384/// created block is erased.
385class CreateBlockRewrite : public BlockRewrite {
386public:
387 CreateBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
388 : BlockRewrite(Kind::CreateBlock, rewriterImpl, block) {}
389
390 static bool classof(const IRRewrite *rewrite) {
391 return rewrite->getKind() == Kind::CreateBlock;
392 }
393
394 void commit(RewriterBase &rewriter) override {
395 // The block was already created and inserted. Just inform the listener.
396 if (auto *listener = rewriter.getListener())
397 listener->notifyBlockInserted(block, /*previous=*/{}, /*previousIt=*/{});
398 }
399
400 void rollback() override {
401 // Unlink all of the operations within this block, they will be deleted
402 // separately.
403 auto &blockOps = block->getOperations();
404 while (!blockOps.empty())
405 blockOps.remove(blockOps.begin());
406 block->dropAllUses();
407 if (block->getParent())
408 block->erase();
409 else
410 delete block;
411 }
412};
413
414/// Erasure of a block. Block erasures are partially reflected in the IR. Erased
415/// blocks are immediately unlinked, but only erased during cleanup. This makes
416/// it easier to rollback a block erasure: the block is simply inserted into its
417/// original location.
418class EraseBlockRewrite : public BlockRewrite {
419public:
420 EraseBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
421 : BlockRewrite(Kind::EraseBlock, rewriterImpl, block),
422 region(block->getParent()), insertBeforeBlock(block->getNextNode()) {}
423
424 static bool classof(const IRRewrite *rewrite) {
425 return rewrite->getKind() == Kind::EraseBlock;
426 }
427
428 ~EraseBlockRewrite() override {
429 assert(!block &&
430 "rewrite was neither rolled back nor committed/cleaned up");
431 }
432
433 void rollback() override {
434 // The block (owned by this rewrite) was not actually erased yet. It was
435 // just unlinked. Put it back into its original position.
436 assert(block && "expected block");
437 auto &blockList = region->getBlocks();
438 Region::iterator before = insertBeforeBlock
439 ? Region::iterator(insertBeforeBlock)
440 : blockList.end();
441 blockList.insert(before, block);
442 block = nullptr;
443 }
444
445 void commit(RewriterBase &rewriter) override {
446 assert(block && "expected block");
447
448 // Notify the listener that the block and its contents are being erased.
449 if (auto *listener =
450 dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
451 notifyIRErased(listener, *block);
452 }
453
454 void cleanup(RewriterBase &rewriter) override {
455 // Erase the contents of the block.
456 for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block)))
457 rewriter.eraseOp(&op);
458 assert(block->empty() && "expected empty block");
459
460 // Erase the block.
461 block->dropAllDefinedValueUses();
462 delete block;
463 block = nullptr;
464 }
465
466private:
467 // The region in which this block was previously contained.
468 Region *region;
469
470 // The original successor of this block before it was unlinked. "nullptr" if
471 // this block was the only block in the region.
472 Block *insertBeforeBlock;
473};
474
475/// Inlining of a block. This rewrite is immediately reflected in the IR.
476/// Note: This rewrite represents only the inlining of the operations. The
477/// erasure of the inlined block is a separate rewrite.
478class InlineBlockRewrite : public BlockRewrite {
479public:
480 InlineBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
481 Block *sourceBlock, Block::iterator before)
482 : BlockRewrite(Kind::InlineBlock, rewriterImpl, block),
483 sourceBlock(sourceBlock),
484 firstInlinedInst(sourceBlock->empty() ? nullptr
485 : &sourceBlock->front()),
486 lastInlinedInst(sourceBlock->empty() ? nullptr : &sourceBlock->back()) {
487 // If a listener is attached to the dialect conversion, ops must be moved
488 // one-by-one. When they are moved in bulk, notifications cannot be sent
489 // because the ops that used to be in the source block at the time of the
490 // inlining (before the "commit" phase) are unknown at the time when
491 // notifications are sent (which is during the "commit" phase).
492 assert(!getConfig().listener &&
493 "InlineBlockRewrite not supported if listener is attached");
494 }
495
496 static bool classof(const IRRewrite *rewrite) {
497 return rewrite->getKind() == Kind::InlineBlock;
498 }
499
500 void rollback() override {
501 // Put the operations from the destination block (owned by the rewrite)
502 // back into the source block.
503 if (firstInlinedInst) {
504 assert(lastInlinedInst && "expected operation");
505 sourceBlock->getOperations().splice(sourceBlock->begin(),
506 block->getOperations(),
507 Block::iterator(firstInlinedInst),
508 ++Block::iterator(lastInlinedInst));
509 }
510 }
511
512private:
513 // The block that originally contained the operations.
514 Block *sourceBlock;
515
516 // The first inlined operation.
517 Operation *firstInlinedInst;
518
519 // The last inlined operation.
520 Operation *lastInlinedInst;
521};
522
523/// Moving of a block. This rewrite is immediately reflected in the IR.
524class MoveBlockRewrite : public BlockRewrite {
525public:
526 MoveBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
527 Region *previousRegion, Region::iterator previousIt)
528 : BlockRewrite(Kind::MoveBlock, rewriterImpl, block),
529 region(previousRegion),
530 insertBeforeBlock(previousIt == previousRegion->end() ? nullptr
531 : &*previousIt) {}
532
533 static bool classof(const IRRewrite *rewrite) {
534 return rewrite->getKind() == Kind::MoveBlock;
535 }
536
537 void commit(RewriterBase &rewriter) override {
538 // The block was already moved. Just inform the listener.
539 if (auto *listener = rewriter.getListener()) {
540 // Note: `previousIt` cannot be passed because this is a delayed
541 // notification and iterators into past IR state cannot be represented.
542 listener->notifyBlockInserted(block, /*previous=*/region,
543 /*previousIt=*/{});
544 }
545 }
546
547 void rollback() override {
548 // Move the block back to its original position.
549 Region::iterator before =
550 insertBeforeBlock ? Region::iterator(insertBeforeBlock) : region->end();
551 if (Region *currentParent = block->getParent()) {
552 // Block is still in a region, use cheap splice to move it back.
553 region->getBlocks().splice(before, currentParent->getBlocks(), block);
554 return;
555 }
556 // Block was orphaned by a prior rollback, can't splice.
557 region->getBlocks().insert(before, block);
558 }
559
560private:
561 // The region in which this block was previously contained.
562 Region *region;
563
564 // The original successor of this block before it was moved. "nullptr" if
565 // this block was the only block in the region.
566 Block *insertBeforeBlock;
567};
568
569/// Block type conversion. This rewrite is partially reflected in the IR.
570class BlockTypeConversionRewrite : public BlockRewrite {
571public:
572 BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
573 Block *origBlock, Block *newBlock)
574 : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, origBlock),
575 newBlock(newBlock) {}
576
577 static bool classof(const IRRewrite *rewrite) {
578 return rewrite->getKind() == Kind::BlockTypeConversion;
579 }
580
581 Block *getOrigBlock() const { return block; }
582
583 Block *getNewBlock() const { return newBlock; }
584
585 void commit(RewriterBase &rewriter) override;
586
587 void rollback() override;
588
589private:
590 /// The new block that was created as part of this signature conversion.
591 Block *newBlock;
592};
593
594/// Replacing a value. This rewrite is not immediately reflected in the
595/// IR. An internal IR mapping is updated, but the actual replacement is delayed
596/// until the rewrite is committed.
597class ReplaceValueRewrite : public ValueRewrite {
598public:
599 ReplaceValueRewrite(ConversionPatternRewriterImpl &rewriterImpl, Value value,
600 const TypeConverter *converter)
601 : ValueRewrite(Kind::ReplaceValue, rewriterImpl, value),
602 converter(converter) {}
603
604 static bool classof(const IRRewrite *rewrite) {
605 return rewrite->getKind() == Kind::ReplaceValue;
606 }
607
608 void commit(RewriterBase &rewriter) override;
609
610 void rollback() override;
611
612private:
613 /// The current type converter when the value was replaced.
614 const TypeConverter *converter;
615};
616
617/// An operation rewrite.
618class OperationRewrite : public IRRewrite {
619public:
620 /// Return the operation that this rewrite operates on.
621 Operation *getOperation() const { return op; }
622
623 static bool classof(const IRRewrite *rewrite) {
624 return rewrite->getKind() >= Kind::MoveOperation &&
625 rewrite->getKind() <= Kind::UnresolvedMaterialization;
626 }
627
628protected:
629 OperationRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
630 Operation *op)
631 : IRRewrite(kind, rewriterImpl), op(op) {}
632
633 // The operation that this rewrite operates on.
634 Operation *op;
635};
636
637/// Moving of an operation. This rewrite is immediately reflected in the IR.
638class MoveOperationRewrite : public OperationRewrite {
639public:
640 MoveOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
641 Operation *op, OpBuilder::InsertPoint previous)
642 : OperationRewrite(Kind::MoveOperation, rewriterImpl, op),
643 block(previous.getBlock()),
644 insertBeforeOp(previous.getPoint() == previous.getBlock()->end()
645 ? nullptr
646 : &*previous.getPoint()) {}
647
648 static bool classof(const IRRewrite *rewrite) {
649 return rewrite->getKind() == Kind::MoveOperation;
650 }
651
652 void commit(RewriterBase &rewriter) override {
653 // The operation was already moved. Just inform the listener.
654 if (auto *listener = rewriter.getListener()) {
655 // Note: `previousIt` cannot be passed because this is a delayed
656 // notification and iterators into past IR state cannot be represented.
657 listener->notifyOperationInserted(
658 op, /*previous=*/OpBuilder::InsertPoint(/*insertBlock=*/block,
659 /*insertPt=*/{}));
660 }
661 }
662
663 void rollback() override {
664 // Move the operation back to its original position.
665 Block::iterator before =
666 insertBeforeOp ? Block::iterator(insertBeforeOp) : block->end();
667 block->getOperations().splice(before, op->getBlock()->getOperations(), op);
668 }
669
670private:
671 // The block in which this operation was previously contained.
672 Block *block;
673
674 // The original successor of this operation before it was moved. "nullptr"
675 // if this operation was the only operation in the region.
676 Operation *insertBeforeOp;
677};
678
679/// In-place modification of an op. This rewrite is immediately reflected in
680/// the IR. The previous state of the operation is stored in this object.
681class ModifyOperationRewrite : public OperationRewrite {
682public:
683 ModifyOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
684 Operation *op)
685 : OperationRewrite(Kind::ModifyOperation, rewriterImpl, op),
686 name(op->getName()), loc(op->getLoc()), attrs(op->getAttrDictionary()),
687 operands(op->operand_begin(), op->operand_end()),
688 successors(op->successor_begin(), op->successor_end()) {
689 if (OpaqueProperties prop = op->getPropertiesStorage()) {
690 // Make a copy of the properties.
691 propertiesStorage = operator new(op->getPropertiesStorageSize());
692 OpaqueProperties propCopy(propertiesStorage);
693 name.initOpProperties(propCopy, /*init=*/prop);
694 }
695 }
696
697 static bool classof(const IRRewrite *rewrite) {
698 return rewrite->getKind() == Kind::ModifyOperation;
699 }
700
701 ~ModifyOperationRewrite() override {
702 assert(!propertiesStorage &&
703 "rewrite was neither committed nor rolled back");
704 }
705
706 void commit(RewriterBase &rewriter) override {
707 // Notify the listener that the operation was modified in-place.
708 if (auto *listener =
709 dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
710 listener->notifyOperationModified(op);
711
712 if (propertiesStorage) {
713 OpaqueProperties propCopy(propertiesStorage);
714 // Note: The operation may have been erased in the mean time, so
715 // OperationName must be stored in this object.
716 name.destroyOpProperties(propCopy);
717 operator delete(propertiesStorage);
718 propertiesStorage = nullptr;
719 }
720 }
721
722 void rollback() override {
723 op->setLoc(loc);
724 op->setAttrs(attrs);
725 op->setOperands(operands);
726 for (const auto &it : llvm::enumerate(successors))
727 op->setSuccessor(it.value(), it.index());
728 if (propertiesStorage) {
729 OpaqueProperties propCopy(propertiesStorage);
730 op->copyProperties(propCopy);
731 name.destroyOpProperties(propCopy);
732 operator delete(propertiesStorage);
733 propertiesStorage = nullptr;
734 }
735 }
736
737private:
738 OperationName name;
739 LocationAttr loc;
740 DictionaryAttr attrs;
741 SmallVector<Value, 8> operands;
742 SmallVector<Block *, 2> successors;
743 void *propertiesStorage = nullptr;
744};
745
746/// Replacing an operation. Erasing an operation is treated as a special case
747/// with "null" replacements. This rewrite is not immediately reflected in the
748/// IR. An internal IR mapping is updated, but values are not replaced and the
749/// original op is not erased until the rewrite is committed.
750class ReplaceOperationRewrite : public OperationRewrite {
751public:
752 ReplaceOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
753 Operation *op, const TypeConverter *converter)
754 : OperationRewrite(Kind::ReplaceOperation, rewriterImpl, op),
755 converter(converter) {}
756
757 static bool classof(const IRRewrite *rewrite) {
758 return rewrite->getKind() == Kind::ReplaceOperation;
759 }
760
761 void commit(RewriterBase &rewriter) override;
762
763 void rollback() override;
764
765 void cleanup(RewriterBase &rewriter) override;
766
767private:
768 /// An optional type converter that can be used to materialize conversions
769 /// between the new and old values if necessary.
770 const TypeConverter *converter;
771};
772
773class CreateOperationRewrite : public OperationRewrite {
774public:
775 CreateOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
776 Operation *op)
777 : OperationRewrite(Kind::CreateOperation, rewriterImpl, op) {}
778
779 static bool classof(const IRRewrite *rewrite) {
780 return rewrite->getKind() == Kind::CreateOperation;
781 }
782
783 void commit(RewriterBase &rewriter) override {
784 // The operation was already created and inserted. Just inform the listener.
785 if (auto *listener = rewriter.getListener())
786 listener->notifyOperationInserted(op, /*previous=*/{});
787 }
788
789 void rollback() override;
790};
791
/// The direction of a materialization.
enum MaterializationKind {
  /// A materialization that converts from an illegal type to a legal one.
  Target,

  /// A materialization that converts from a legal type back to an illegal
  /// one.
  Source
};
802
803/// Helper class that stores metadata about an unresolved materialization.
804class UnresolvedMaterializationInfo {
805public:
806 UnresolvedMaterializationInfo() = default;
807 UnresolvedMaterializationInfo(const TypeConverter *converter,
808 MaterializationKind kind, Type originalType)
809 : converterAndKind(converter, kind), originalType(originalType) {}
810
811 /// Return the type converter of this materialization (which may be null).
812 const TypeConverter *getConverter() const {
813 return converterAndKind.getPointer();
814 }
815
816 /// Return the kind of this materialization.
817 MaterializationKind getMaterializationKind() const {
818 return converterAndKind.getInt();
819 }
820
821 /// Return the original type of the SSA value.
822 Type getOriginalType() const { return originalType; }
823
824private:
825 /// The corresponding type converter to use when resolving this
826 /// materialization, and the kind of this materialization.
827 llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
828 converterAndKind;
829
830 /// The original type of the SSA value. Only used for target
831 /// materializations.
832 Type originalType;
833};
834
835/// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
836/// op. Unresolved materializations fold away or are replaced with
837/// source/target materializations at the end of the dialect conversion.
838class UnresolvedMaterializationRewrite : public OperationRewrite {
839public:
840 UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
841 UnrealizedConversionCastOp op,
842 ValueVector mappedValues)
843 : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
844 mappedValues(std::move(mappedValues)) {}
845
846 static bool classof(const IRRewrite *rewrite) {
847 return rewrite->getKind() == Kind::UnresolvedMaterialization;
848 }
849
850 void rollback() override;
851
852 UnrealizedConversionCastOp getOperation() const {
853 return cast<UnrealizedConversionCastOp>(op);
854 }
855
856private:
857 /// The values in the conversion value mapping that are being replaced by the
858 /// results of this unresolved materialization.
859 ValueVector mappedValues;
860};
861} // namespace
862
863#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
864/// Return "true" if there is an operation rewrite that matches the specified
865/// rewrite type and operation among the given rewrites.
866template <typename RewriteTy, typename R>
867static bool hasRewrite(R &&rewrites, Operation *op) {
868 return any_of(std::forward<R>(rewrites), [&](auto &rewrite) {
869 auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get());
870 return rewriteTy && rewriteTy->getOperation() == op;
871 });
872}
873
874/// Return "true" if there is a block rewrite that matches the specified
875/// rewrite type and block among the given rewrites.
876template <typename RewriteTy, typename R>
877static bool hasRewrite(R &&rewrites, Block *block) {
878 return any_of(std::forward<R>(rewrites), [&](auto &rewrite) {
879 auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get());
880 return rewriteTy && rewriteTy->getBlock() == block;
881 });
882}
883#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
884
885//===----------------------------------------------------------------------===//
886// ConversionPatternRewriterImpl
887//===----------------------------------------------------------------------===//
888namespace mlir {
889namespace detail {
891 explicit ConversionPatternRewriterImpl(ConversionPatternRewriter &rewriter,
892 const ConversionConfig &config,
896
897 //===--------------------------------------------------------------------===//
898 // State Management
899 //===--------------------------------------------------------------------===//
900
901 /// Return the current state of the rewriter.
902 RewriterState getCurrentState();
903
904 /// Apply all requested operation rewrites. This method is invoked when the
905 /// conversion process succeeds.
906 void applyRewrites();
907
908 /// Reset the state of the rewriter to a previously saved point. Optionally,
909 /// the name of the pattern that triggered the rollback can be specified for
910 /// debugging purposes.
911 void resetState(RewriterState state, StringRef patternName = "");
912
913 /// Append a rewrite. Rewrites are committed upon success and rolled back upon
914 /// failure.
915 template <typename RewriteTy, typename... Args>
916 void appendRewrite(Args &&...args) {
917 assert(config.allowPatternRollback && "appending rewrites is not allowed");
918 rewrites.push_back(
919 std::make_unique<RewriteTy>(*this, std::forward<Args>(args)...));
920 }
921
922 /// Undo the rewrites (motions, splits) one by one in reverse order until
923 /// "numRewritesToKeep" rewrites remain. Optionally, the name of the pattern
924 /// that triggered the rollback can be specified for debugging purposes.
925 void undoRewrites(unsigned numRewritesToKeep = 0, StringRef patternName = "");
926
927 /// Remap the given values to those with potentially different types. Returns
928 /// success if the values could be remapped, failure otherwise. `valueDiagTag`
929 /// is the tag used when describing a value within a diagnostic, e.g.
930 /// "operand".
931 LogicalResult remapValues(StringRef valueDiagTag,
932 std::optional<Location> inputLoc, ValueRange values,
933 SmallVector<ValueVector> &remapped);
934
935 /// Return "true" if the given operation is ignored, and does not need to be
936 /// converted.
937 bool isOpIgnored(Operation *op) const;
938
939 /// Return "true" if the given operation was replaced or erased.
940 bool wasOpReplaced(Operation *op) const;
941
942 /// Lookup the most recently mapped values with the desired types in the
943 /// mapping, taking into account only replacements. Perform a best-effort
944 /// search for existing materializations with the desired types.
945 ///
946 /// If `skipPureTypeConversions` is "true", materializations that are pure
947 /// type conversions are not considered.
948 ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {},
949 bool skipPureTypeConversions = false) const;
950
951 /// Lookup the given value within the map, or return an empty vector if the
952 /// value is not mapped. If it is mapped, this follows the same behavior
953 /// as `lookupOrDefault`.
954 ValueVector lookupOrNull(Value from, TypeRange desiredTypes = {}) const;
955
956 //===--------------------------------------------------------------------===//
957 // IR Rewrites / Type Conversion
958 //===--------------------------------------------------------------------===//
959
960 /// Convert the types of block arguments within the given region.
961 FailureOr<Block *>
962 convertRegionTypes(Region *region, const TypeConverter &converter,
963 TypeConverter::SignatureConversion *entryConversion);
964
965 /// Apply the given signature conversion on the given block. The new block
966 /// containing the updated signature is returned. If no conversions were
967 /// necessary, e.g. if the block has no arguments, `block` is returned.
968 /// `converter` is used to generate any necessary cast operations that
969 /// translate between the origin argument types and those specified in the
970 /// signature conversion.
971 Block *applySignatureConversion(
972 Block *block, const TypeConverter *converter,
973 TypeConverter::SignatureConversion &signatureConversion);
974
975 /// Replace the results of the given operation with the given values and
976 /// erase the operation.
977 ///
978 /// There can be multiple replacement values for each result (1:N
979 /// replacement). If the replacement values are empty, the respective result
980 /// is dropped and a source materialization is built if the result still has
981 /// uses.
982 void replaceOp(Operation *op, SmallVector<SmallVector<Value>> &&newValues);
983
984 /// Replace the uses of the given value with the given values. The specified
985 /// converter is used to build materializations (if necessary). If `functor`
986 /// is specified, only the uses that the functor returns "true" for are
987 /// replaced.
988 void replaceValueUses(Value from, ValueRange to,
989 const TypeConverter *converter,
990 function_ref<bool(OpOperand &)> functor = nullptr);
991
992 /// Erase the given block and its contents.
993 void eraseBlock(Block *block);
994
995 /// Inline the source block into the destination block before the given
996 /// iterator.
997 void inlineBlockBefore(Block *source, Block *dest, Block::iterator before);
998
999 //===--------------------------------------------------------------------===//
1000 // Materializations
1001 //===--------------------------------------------------------------------===//
1002
1003 /// Build an unresolved materialization operation given a range of output
1004 /// types and a list of input operands. Returns the inputs if their
1005 /// types match the output types.
1006 ///
1007 /// If a cast op was built, it can optionally be returned with the `castOp`
1008 /// output argument.
1009 ///
1010 /// If `valuesToMap` is non-empty, then those values are mapped to
1011 /// the results of the unresolved materialization in the conversion value
1012 /// mapping.
1013 ///
1014 /// If `isPureTypeConversion` is "true", the materialization is created only
1015 /// to resolve a type mismatch. That means it is not a regular value
1016 /// replacement issued by the user. (Replacement values that are created
1017 /// "out of thin air" appear like unresolved materializations because they are
1018 /// unrealized_conversion_cast ops. However, they must be treated like
1019 /// regular value replacements.)
1020 ValueRange buildUnresolvedMaterialization(
1021 MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
1022 ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
1023 Type originalType, const TypeConverter *converter,
1024 bool isPureTypeConversion = true);
1025
1026 /// Find a replacement value for the given SSA value in the conversion value
1027 /// mapping. The replacement value must have the same type as the given SSA
1028 /// value. If there is no replacement value with the correct type, find the
1029 /// latest replacement value (regardless of the type) and build a source
1030 /// materialization.
1031 Value findOrBuildReplacementValue(Value value,
1032 const TypeConverter *converter);
1033
1034 //===--------------------------------------------------------------------===//
1035 // Rewriter Notification Hooks
1036 //===--------------------------------------------------------------------===//
1037
1038 /// Notifies that an op was inserted.
1039 void notifyOperationInserted(Operation *op,
1040 OpBuilder::InsertPoint previous) override;
1041
1042 /// Notifies that a block was inserted.
1043 void notifyBlockInserted(Block *block, Region *previous,
1044 Region::iterator previousIt) override;
1045
1046 /// Notifies that a pattern match failed for the given reason.
1047 void
1048 notifyMatchFailure(Location loc,
1049 function_ref<void(Diagnostic &)> reasonCallback) override;
1050
1051 //===--------------------------------------------------------------------===//
1052 // IR Erasure
1053 //===--------------------------------------------------------------------===//
1054
1055 /// A rewriter that keeps track of erased ops and blocks. It ensures that no
1056 /// operation or block is erased multiple times. This rewriter assumes that
1057 /// no new IR is created between calls to `eraseOp`/`eraseBlock`.
1059 public:
1062 std::function<void(Operation *)> opErasedCallback = nullptr)
1063 : RewriterBase(context, /*listener=*/this),
1064 opErasedCallback(std::move(opErasedCallback)) {}
1065
1066 /// Erase the given op (unless it was already erased).
1067 void eraseOp(Operation *op) override {
1068 if (wasErased(op))
1069 return;
1070 op->dropAllUses();
1072 }
1073
1074 /// Erase the given block (unless it was already erased).
1075 void eraseBlock(Block *block) override {
1076 if (wasErased(block))
1077 return;
1078 assert(block->empty() && "expected empty block");
1079 block->dropAllDefinedValueUses();
1081 }
1082
1083 bool wasErased(void *ptr) const { return erased.contains(ptr); }
1084
1086 erased.insert(op);
1087 if (opErasedCallback)
1088 opErasedCallback(op);
1089 }
1090
1091 void notifyBlockErased(Block *block) override { erased.insert(block); }
1092
1093 private:
1094 /// Pointers to all erased operations and blocks.
1095 DenseSet<void *> erased;
1096
1097 /// A callback that is invoked when an operation is erased.
1098 std::function<void(Operation *)> opErasedCallback;
1099 };
1100
1101 //===--------------------------------------------------------------------===//
1102 // State
1103 //===--------------------------------------------------------------------===//
1104
1105 /// The rewriter that is used to perform the conversion.
1106 ConversionPatternRewriter &rewriter;
1107
1108 // Mapping between replaced values that differ in type. This happens when
1109 // replacing a value with one of a different type.
1110 ConversionValueMapping mapping;
1111
1112 /// Ordered list of block operations (creations, splits, motions).
1113 /// This vector is maintained only if `allowPatternRollback` is set to
1114 /// "true". Otherwise, all IR rewrites are materialized immediately and no
1115 /// bookkeeping is needed.
1117
1118 /// A set of operations that should no longer be considered for legalization.
1119 /// E.g., ops that are recursively legal. Ops that were replaced/erased are
1120 /// tracked separately.
1122
1123 /// A set of operations that were replaced/erased. Such ops are not erased
1124 /// immediately but only when the dialect conversion succeeds. In the
1125 /// meantime, they should no longer be considered for legalization and any
1126 /// attempt to modify/access them is invalid rewriter API usage.
1128
1129 /// A set of operations that were created by the current pattern.
1131
1132 /// A set of operations that were modified by the current pattern.
1134
1135 /// A list of unresolved materializations that were created by the current
1136 /// pattern.
1138
1139 /// A mapping for looking up metadata of unresolved materializations.
1142
1143 /// The current type converter, or nullptr if no type converter is currently
1144 /// active.
1146
1147 /// A mapping of regions to type converters that should be used when
1148 /// converting the arguments of blocks within that region.
1150
1151 /// Dialect conversion configuration.
1152 const ConversionConfig &config;
1153
1154 /// The operation converter to use for recursive legalization.
1156
1157 /// A set of erased operations. This set is utilized only if
1158 /// `allowPatternRollback` is set to "false". Conceptually, this set is
1159 /// similar to `replacedOps` (which is maintained when the flag is set to
1160 /// "true"). However, erasing from a DenseSet is more efficient than erasing
1161 /// from a SetVector.
1163
1164 /// A set of erased blocks. This set is utilized only if
1165 /// `allowPatternRollback` is set to "false".
1167
1168 /// A rewriter that notifies the listener (if any) about all IR
1169 /// modifications. This rewriter is utilized only if `allowPatternRollback`
1170 /// is set to "false". If the flag is set to "true", the listener is notified
1171 /// with a separate mechanism (e.g., in `IRRewrite::commit`).
1173
1174#ifndef NDEBUG
1175 /// A set of replaced values. This set is for debugging purposes only and it
1176 /// is maintained only if `allowPatternRollback` is set to "true".
1178
1179 /// A set of operations that have pending updates. This tracking isn't
1180 /// strictly necessary, and is thus only active during debug builds for extra
1181 /// verification.
1183
1184 /// A raw output stream used to prefix the debug log.
1185 llvm::impl::raw_ldbg_ostream os{(Twine("[") + DEBUG_TYPE + ":1] ").str(),
1186 llvm::dbgs()};
1187
1188 /// A logger used to emit diagnostics during the conversion process.
1189 llvm::ScopedPrinter logger{os};
1190 std::string logPrefix;
1191#endif
1192};
1193} // namespace detail
1194} // namespace mlir
1195
1196const ConversionConfig &IRRewrite::getConfig() const {
1197 return rewriterImpl.config;
1198}
1199
1200void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
1201 // Inform the listener about all IR modifications that have already taken
1202 // place: References to the original block have been replaced with the new
1203 // block.
1204 if (auto *listener =
1205 dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
1206 for (Operation *op : getNewBlock()->getUsers())
1207 listener->notifyOperationModified(op);
1208}
1209
1210void BlockTypeConversionRewrite::rollback() {
1211 getNewBlock()->replaceAllUsesWith(getOrigBlock());
1212}
1213
1214/// Replace all uses of `from` with `repl`.
1215static void
1217 function_ref<bool(OpOperand &)> functor = nullptr) {
1218 if (isa<BlockArgument>(repl)) {
1219 // `repl` is a block argument. Directly replace all uses.
1220 if (functor) {
1221 rewriter.replaceUsesWithIf(from, repl, functor);
1222 } else {
1223 rewriter.replaceAllUsesWith(from, repl);
1224 }
1225 return;
1226 }
1227
1228 // If the replacement value is an operation, only replace those uses that:
1229 // - are in a different block than the replacement operation, or
1230 // - are in the same block but after the replacement operation.
1231 //
1232 // Example:
1233 // ^bb0(%arg0: i32):
1234 // %0 = "consumer"(%arg0) : (i32) -> (i32)
1235 // "another_consumer"(%arg0) : (i32) -> ()
1236 //
1237 // In the above example, replaceAllUsesWith(%arg0, %0) will replace the
1238 // use in "another_consumer" but not the use in "consumer". When using the
1239 // normal RewriterBase API, this would typically be done with
1240 // `replaceUsesWithIf` / `replaceAllUsesExcept`. However, that API is not
1241 // supported by the `ConversionPatternRewriter`. Due to the mapping mechanism
1242 // it cannot be supported efficiently with `allowPatternRollback` set to
1243 // "true". Therefore, the conversion driver is trying to be smart and replaces
1244 // only those uses that do not lead to a dominance violation. E.g., the
1245 // FuncToLLVM lowering (`restoreByValRefArgumentType`) relies on this
1246 // behavior.
1247 //
1248 // TODO: As we move more and more towards `allowPatternRollback` set to
1249 // "false", we should remove this special handling, in order to align the
1250 // `ConversionPatternRewriter` API with the normal `RewriterBase` API.
1251 Operation *replOp = repl.getDefiningOp();
1252 Block *replBlock = replOp->getBlock();
1253 rewriter.replaceUsesWithIf(from, repl, [&](OpOperand &operand) {
1254 Operation *user = operand.getOwner();
1255 bool result =
1256 user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
1257 if (result && functor)
1258 result &= functor(operand);
1259 return result;
1260 });
1261}
1262
1263void ReplaceValueRewrite::commit(RewriterBase &rewriter) {
1264 Value repl = rewriterImpl.findOrBuildReplacementValue(value, converter);
1265 if (!repl)
1266 return;
1267 performReplaceValue(rewriter, value, repl);
1268}
1269
1270void ReplaceValueRewrite::rollback() {
1271 rewriterImpl.mapping.erase({value});
1272#ifndef NDEBUG
1273 rewriterImpl.replacedValues.erase(value);
1274#endif // NDEBUG
1275}
1276
1277void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
1278 auto *listener =
1279 dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener());
1280
1281 // Compute replacement values.
1282 SmallVector<Value> replacements =
1283 llvm::map_to_vector(op->getResults(), [&](OpResult result) {
1284 return rewriterImpl.findOrBuildReplacementValue(result, converter);
1285 });
1286
1287 // Notify the listener that the operation is about to be replaced.
1288 if (listener)
1289 listener->notifyOperationReplaced(op, replacements);
1290
1291 // Replace all uses with the new values.
1292 for (auto [result, newValue] :
1293 llvm::zip_equal(op->getResults(), replacements))
1294 if (newValue)
1295 rewriter.replaceAllUsesWith(result, newValue);
1296
1297 // The original op will be erased, so remove it from the set of unlegalized
1298 // ops.
1299 if (getConfig().unlegalizedOps)
1300 getConfig().unlegalizedOps->erase(op);
1301
1302 // Notify the listener that the operation and its contents are being erased.
1303 if (listener)
1304 notifyIRErased(listener, *op);
1305
1306 // Do not erase the operation yet. It may still be referenced in `mapping`.
1307 // Just unlink it for now and erase it during cleanup.
1308 op->getBlock()->getOperations().remove(op);
1309}
1310
1311void ReplaceOperationRewrite::rollback() {
1312 for (auto result : op->getResults())
1313 rewriterImpl.mapping.erase({result});
1314}
1315
1316void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) {
1317 rewriter.eraseOp(op);
1318}
1319
1320void CreateOperationRewrite::rollback() {
1321 for (Region &region : op->getRegions()) {
1322 while (!region.getBlocks().empty())
1323 region.getBlocks().remove(region.getBlocks().begin());
1324 }
1325 op->dropAllUses();
1326 op->erase();
1327}
1328
1329void UnresolvedMaterializationRewrite::rollback() {
1330 if (!mappedValues.empty())
1331 rewriterImpl.mapping.erase(mappedValues);
1332 rewriterImpl.unresolvedMaterializations.erase(getOperation());
1333 op->erase();
1334}
1335
1337 // Commit all rewrites. Use a new rewriter, so the modifications are not
1338 // tracked for rollback purposes etc.
1339 IRRewriter irRewriter(rewriter.getContext(), config.listener);
1340 // Note: New rewrites may be added during the "commit" phase and the
1341 // `rewrites` vector may reallocate.
1342 for (size_t i = 0; i < rewrites.size(); ++i)
1343 rewrites[i]->commit(irRewriter);
1344
1345 // Clean up all rewrites.
1346 SingleEraseRewriter eraseRewriter(
1347 rewriter.getContext(), /*opErasedCallback=*/[&](Operation *op) {
1348 if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
1349 unresolvedMaterializations.erase(castOp);
1350 });
1351 for (auto &rewrite : rewrites)
1352 rewrite->cleanup(eraseRewriter);
1353}
1354
1355//===----------------------------------------------------------------------===//
1356// State Management
1357//===----------------------------------------------------------------------===//
1358
1360 Value from, TypeRange desiredTypes, bool skipPureTypeConversions) const {
1361 // Helper function that looks up a single value.
1362 auto lookup = [&](const ValueVector &values) -> ValueVector {
1363 assert(!values.empty() && "expected non-empty value vector");
1364
1365 // If the pattern rollback is enabled, use the mapping to look up the
1366 // values.
1367 if (config.allowPatternRollback)
1368 return mapping.lookup(values);
1369
1370 // Otherwise, look up values by examining the IR. All replacements have
1371 // already been materialized in IR.
1372 Operation *op = getCommonDefiningOp(values);
1373 if (!op)
1374 return {};
1375 auto castOp = dyn_cast<UnrealizedConversionCastOp>(op);
1376 if (!castOp)
1377 return {};
1378 if (!this->unresolvedMaterializations.contains(castOp))
1379 return {};
1380 if (castOp.getOutputs() != values)
1381 return {};
1382 return castOp.getInputs();
1383 };
1384
1385 // Helper function that looks up each value in `values` individually and then
1386 // composes the results. If that fails, it tries to look up the entire vector
1387 // at once.
1388 auto composedLookup = [&](const ValueVector &values) -> ValueVector {
1389 // If possible, replace each value with (one or multiple) mapped values.
1390 ValueVector next;
1391 for (Value v : values) {
1392 ValueVector r = lookup({v});
1393 if (!r.empty()) {
1394 llvm::append_range(next, r);
1395 } else {
1396 next.push_back(v);
1397 }
1398 }
1399 if (next != values) {
1400 // At least one value was replaced.
1401 return next;
1402 }
1403
1404 // Otherwise: Check if there is a mapping for the entire vector. Such
1405 // mappings are materializations. (N:M mappings are not supported for value
1406 // replacements.)
1407 //
1408 // Note: From a correctness point of view, materializations do not have to
1409 // be stored (and looked up) in the mapping. But for performance reasons,
1410 // we choose to reuse existing IR (when possible) instead of creating it
1411 // multiple times.
1412 ValueVector r = lookup(values);
1413 if (r.empty()) {
1414 // No mapping found: The lookup stops here.
1415 return {};
1416 }
1417 return r;
1418 };
1419
1420 // Try to find the deepest values that have the desired types. If there is no
1421 // such mapping, simply return the deepest values.
1422 ValueVector desiredValue;
1423 ValueVector current{from};
1424 ValueVector lastNonMaterialization{from};
1425 do {
1426 // Store the current value if the types match.
1427 bool match = TypeRange(ValueRange(current)) == desiredTypes;
1428 if (skipPureTypeConversions) {
1429 // Skip pure type conversions, if requested.
1430 bool pureConversion = isPureTypeConversion(current);
1431 match &= !pureConversion;
1432 // Keep track of the last mapped value that was not a pure type
1433 // conversion.
1434 if (!pureConversion)
1435 lastNonMaterialization = current;
1436 }
1437 if (match)
1438 desiredValue = current;
1439
1440 // Lookup next value in the mapping.
1441 ValueVector next = composedLookup(current);
1442 if (next.empty())
1443 break;
1444 current = std::move(next);
1445 } while (true);
1446
1447 // If the desired values were found use them, otherwise default to the leaf
1448 // values. (Skip pure type conversions, if requested.)
1449 if (!desiredTypes.empty())
1450 return desiredValue;
1451 if (skipPureTypeConversions)
1452 return lastNonMaterialization;
1453 return current;
1454}
1455
1458 TypeRange desiredTypes) const {
1459 ValueVector result = lookupOrDefault(from, desiredTypes);
1460 if (result == ValueVector{from} ||
1461 (!desiredTypes.empty() && TypeRange(ValueRange(result)) != desiredTypes))
1462 return {};
1463 return result;
1464}
1465
1467 return RewriterState(rewrites.size(), ignoredOps.size(), replacedOps.size());
1468}
1469
1471 StringRef patternName) {
1472 // Undo any rewrites.
1473 undoRewrites(state.numRewrites, patternName);
1474
1475 // Pop all of the recorded ignored operations that are no longer valid.
1476 while (ignoredOps.size() != state.numIgnoredOperations)
1477 ignoredOps.pop_back();
1478
1479 while (replacedOps.size() != state.numReplacedOps)
1480 replacedOps.pop_back();
1481}
1482
1483void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep,
1484 StringRef patternName) {
1485 for (auto &rewrite :
1486 llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep)))
1487 rewrite->rollback();
1488 rewrites.resize(numRewritesToKeep);
1489}
1490
1492 StringRef valueDiagTag, std::optional<Location> inputLoc, ValueRange values,
1493 SmallVector<ValueVector> &remapped) {
1494 remapped.reserve(llvm::size(values));
1495
1496 for (const auto &it : llvm::enumerate(values)) {
1497 Value operand = it.value();
1498 Type origType = operand.getType();
1499 Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
1500
1501 if (!currentTypeConverter) {
1502 // The current pattern does not have a type converter. Pass the most
1503 // recently mapped values, excluding materializations. Materializations
1504 // are intentionally excluded because their presence may depend on other
1505 // patterns. Including materializations would make the lookup fragile
1506 // and unpredictable.
1507 remapped.push_back(lookupOrDefault(operand, /*desiredTypes=*/{},
1508 /*skipPureTypeConversions=*/true));
1509 continue;
1510 }
1511
1512 // If there is no legal conversion, fail to match this pattern.
1513 SmallVector<Type, 1> legalTypes;
1514 if (failed(currentTypeConverter->convertType(operand, legalTypes))) {
1515 notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
1516 diag << "unable to convert type for " << valueDiagTag << " #"
1517 << it.index() << ", type was " << origType;
1518 });
1519 return failure();
1520 }
1521 // If a type is converted to 0 types, there is nothing to do.
1522 if (legalTypes.empty()) {
1523 remapped.push_back({});
1524 continue;
1525 }
1526
1527 ValueVector repl = lookupOrDefault(operand, legalTypes);
1528 if (!repl.empty() && TypeRange(ValueRange(repl)) == legalTypes) {
1529 // Mapped values have the correct type or there is an existing
1530 // materialization. Or the operand is not mapped at all and has the
1531 // correct type.
1532 remapped.push_back(std::move(repl));
1533 continue;
1534 }
1535
1536 // Create a materialization for the most recently mapped values.
1537 repl = lookupOrDefault(operand, /*desiredTypes=*/{},
1538 /*skipPureTypeConversions=*/true);
1540 MaterializationKind::Target, computeInsertPoint(repl), operandLoc,
1541 /*valuesToMap=*/repl, /*inputs=*/repl, /*outputTypes=*/legalTypes,
1542 /*originalType=*/origType, currentTypeConverter);
1543 remapped.push_back(castValues);
1544 }
1545 return success();
1546}
1547
1549 // Check to see if this operation is ignored or was replaced.
1550 return wasOpReplaced(op) || ignoredOps.count(op);
1551}
1552
1554 // Check to see if this operation was replaced.
1555 return replacedOps.count(op) || erasedOps.count(op);
1556}
1557
1558//===----------------------------------------------------------------------===//
1559// Type Conversion
1560//===----------------------------------------------------------------------===//
1561
1563 Region *region, const TypeConverter &converter,
1564 TypeConverter::SignatureConversion *entryConversion) {
1565 regionToConverter[region] = &converter;
1566 if (region->empty())
1567 return nullptr;
1568
1569 // Convert the arguments of each non-entry block within the region.
1570 for (Block &block :
1571 llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
1572 // Compute the signature for the block with the provided converter.
1573 std::optional<TypeConverter::SignatureConversion> conversion =
1574 converter.convertBlockSignature(&block);
1575 if (!conversion)
1576 return failure();
1577 // Convert the block with the computed signature.
1578 applySignatureConversion(&block, &converter, *conversion);
1579 }
1580
1581 // Convert the entry block. If an entry signature conversion was provided,
1582 // use that one. Otherwise, compute the signature with the type converter.
1583 if (entryConversion)
1584 return applySignatureConversion(&region->front(), &converter,
1585 *entryConversion);
1586 std::optional<TypeConverter::SignatureConversion> conversion =
1587 converter.convertBlockSignature(&region->front());
1588 if (!conversion)
1589 return failure();
1590 return applySignatureConversion(&region->front(), &converter, *conversion);
1591}
1592
1594 Block *block, const TypeConverter *converter,
1595 TypeConverter::SignatureConversion &signatureConversion) {
1596#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1597 // A block cannot be converted multiple times.
1598 if (hasRewrite<BlockTypeConversionRewrite>(rewrites, block))
1599 llvm::report_fatal_error("block was already converted");
1600#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1601
1603
1604 // If no arguments are being changed or added, there is nothing to do.
1605 unsigned origArgCount = block->getNumArguments();
1606 auto convertedTypes = signatureConversion.getConvertedTypes();
1607 if (llvm::equal(block->getArgumentTypes(), convertedTypes))
1608 return block;
1609
1610 // Compute the locations of all block arguments in the new block.
1611 SmallVector<Location> newLocs(convertedTypes.size(),
1612 rewriter.getUnknownLoc());
1613 for (unsigned i = 0; i < origArgCount; ++i) {
1614 auto inputMap = signatureConversion.getInputMapping(i);
1615 if (!inputMap || inputMap->replacedWithValues())
1616 continue;
1617 Location origLoc = block->getArgument(i).getLoc();
1618 for (unsigned j = 0; j < inputMap->size; ++j)
1619 newLocs[inputMap->inputNo + j] = origLoc;
1620 }
1621
1622 // Insert a new block with the converted block argument types and move all ops
1623 // from the old block to the new block.
1624 Block *newBlock =
1625 rewriter.createBlock(block->getParent(), std::next(block->getIterator()),
1626 convertedTypes, newLocs);
1627
1628 // If a listener is attached to the dialect conversion, ops cannot be moved
1629 // to the destination block in bulk ("fast path"). This is because at the time
1630 // the notifications are sent, it is unknown which ops were moved. Instead,
1631 // ops should be moved one-by-one ("slow path"), so that a separate
1632 // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
1633 // a bit more efficient, so we try to do that when possible.
1634 bool fastPath = !config.listener;
1635 if (fastPath) {
1636 if (config.allowPatternRollback)
1637 appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end());
1638 newBlock->getOperations().splice(newBlock->end(), block->getOperations());
1639 } else {
1640 while (!block->empty())
1641 rewriter.moveOpBefore(&block->front(), newBlock, newBlock->end());
1642 }
1643
1644 // Replace all uses of the old block with the new block.
1645 block->replaceAllUsesWith(newBlock);
1646
1647 for (unsigned i = 0; i != origArgCount; ++i) {
1648 BlockArgument origArg = block->getArgument(i);
1649 Type origArgType = origArg.getType();
1650
1651 std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
1652 signatureConversion.getInputMapping(i);
1653 if (!inputMap) {
1654 // This block argument was dropped and no replacement value was provided.
1655 // Materialize a replacement value "out of thin air".
1656 // Note: Materialization must be built here because we cannot find a
1657 // valid insertion point in the new block. (Will point to the old block.)
1658 Value mat =
1660 MaterializationKind::Source,
1661 OpBuilder::InsertPoint(newBlock, newBlock->begin()),
1662 origArg.getLoc(),
1663 /*valuesToMap=*/{}, /*inputs=*/ValueRange(),
1664 /*outputTypes=*/origArgType, /*originalType=*/Type(), converter,
1665 /*isPureTypeConversion=*/false)
1666 .front();
1667 replaceValueUses(origArg, mat, converter);
1668 continue;
1669 }
1670
1671 if (inputMap->replacedWithValues()) {
1672 // This block argument was dropped and replacement values were provided.
1673 assert(inputMap->size == 0 &&
1674 "invalid to provide a replacement value when the argument isn't "
1675 "dropped");
1676 replaceValueUses(origArg, inputMap->replacementValues, converter);
1677 continue;
1678 }
1679
1680 // This is a 1->1+ mapping.
1681 auto replArgs =
1682 newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
1683 replaceValueUses(origArg, replArgs, converter);
1684 }
1685
1686 if (config.allowPatternRollback)
1687 appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/block, newBlock);
1688
1689 // Erase the old block. (It is just unlinked for now and will be erased during
1690 // cleanup.)
1691 rewriter.eraseBlock(block);
1692
1693 return newBlock;
1694}
1695
1696//===----------------------------------------------------------------------===//
1697// Materializations
1698//===----------------------------------------------------------------------===//
1699
1700/// Build an unresolved materialization operation given an output type and set
1701/// of input operands.
1703 MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
1704 ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
1705 Type originalType, const TypeConverter *converter,
1706 bool isPureTypeConversion) {
1707 assert((!originalType || kind == MaterializationKind::Target) &&
1708 "original type is valid only for target materializations");
1709 assert(TypeRange(inputs) != outputTypes &&
1710 "materialization is not necessary");
1711
1712 // Create an unresolved materialization. We use a new OpBuilder to avoid
1713 // tracking the materialization like we do for other operations.
1714 OpBuilder builder(outputTypes.front().getContext());
1715 builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
1716 UnrealizedConversionCastOp convertOp =
1717 UnrealizedConversionCastOp::create(builder, loc, outputTypes, inputs);
1718 if (config.attachDebugMaterializationKind) {
1719 StringRef kindStr =
1720 kind == MaterializationKind::Source ? "source" : "target";
1721 convertOp->setAttr("__kind__", builder.getStringAttr(kindStr));
1722 }
1724 convertOp->setAttr(kPureTypeConversionMarker, builder.getUnitAttr());
1725
1726 // Register the materialization.
1727 unresolvedMaterializations[convertOp] =
1728 UnresolvedMaterializationInfo(converter, kind, originalType);
1729 if (config.allowPatternRollback) {
1730 if (!valuesToMap.empty())
1731 mapping.map(valuesToMap, convertOp.getResults());
1733 std::move(valuesToMap));
1734 } else {
1735 patternMaterializations.insert(convertOp);
1736 }
1737 return convertOp.getResults();
1738}
1739
1741 Value value, const TypeConverter *converter) {
1742 assert(config.allowPatternRollback &&
1743 "this code path is valid only in rollback mode");
1744
1745 // Try to find a replacement value with the same type in the conversion value
1746 // mapping. This includes cached materializations. We try to reuse those
1747 // instead of generating duplicate IR.
1748 ValueVector repl = lookupOrNull(value, value.getType());
1749 if (!repl.empty())
1750 return repl.front();
1751
1752 // Check if the value is dead. No replacement value is needed in that case.
1753 // This is an approximate check that may have false negatives but does not
1754 // require computing and traversing an inverse mapping. (We may end up
1755 // building source materializations that are never used and that fold away.)
1756 if (llvm::all_of(value.getUsers(),
1757 [&](Operation *op) { return replacedOps.contains(op); }) &&
1758 !mapping.isMappedTo(value))
1759 return Value();
1760
1761 // No replacement value was found. Get the latest replacement value
1762 // (regardless of the type) and build a source materialization to the
1763 // original type.
1764 repl = lookupOrNull(value);
1765
1766 // Compute the insertion point of the materialization.
1768 if (repl.empty()) {
1769 // The source materialization has no inputs. Insert it right before the
1770 // value that it is replacing.
1771 ip = computeInsertPoint(value);
1772 } else {
1773 // Compute the "earliest" insertion point at which all values in `repl` are
1774 // defined. It is important to emit the materialization at that location
1775 // because the same materialization may be reused in a different context.
1776 // (That's because materializations are cached in the conversion value
1777 // mapping.) The insertion point of the materialization must be valid for
1778 // all future users that may be created later in the conversion process.
1779 ip = computeInsertPoint(repl);
1780 }
1782 MaterializationKind::Source, ip, value.getLoc(),
1783 /*valuesToMap=*/repl, /*inputs=*/repl,
1784 /*outputTypes=*/value.getType(),
1785 /*originalType=*/Type(), converter,
1786 /*isPureTypeConversion=*/!repl.empty())
1787 .front();
1788 return castValue;
1789}
1790
1791//===----------------------------------------------------------------------===//
1792// Rewriter Notification Hooks
1793//===----------------------------------------------------------------------===//
1794
1796 Operation *op, OpBuilder::InsertPoint previous) {
1797 // If no previous insertion point is provided, the op used to be detached.
1798 bool wasDetached = !previous.isSet();
1799 LLVM_DEBUG({
1800 logger.startLine() << "** Insert : '" << op->getName() << "' (" << op
1801 << ")";
1802 if (wasDetached)
1803 logger.getOStream() << " (was detached)";
1804 logger.getOStream() << "\n";
1805 });
1806
1807 // In rollback mode, it is easier to misuse the API, so perform extra error
1808 // checking.
1809 assert(!(config.allowPatternRollback && wasOpReplaced(op->getParentOp())) &&
1810 "attempting to insert into a block within a replaced/erased op");
1811
1812 // In "no rollback" mode, the listener is always notified immediately.
1813 if (!config.allowPatternRollback && config.listener)
1814 config.listener->notifyOperationInserted(op, previous);
1815
1816 if (wasDetached) {
1817 // If the op was detached, it is most likely a newly created op. Add it the
1818 // set of newly created ops, so that it will be legalized. If this op is
1819 // not a newly created op, it will be legalized a second time, which is
1820 // inefficient but harmless.
1821 patternNewOps.insert(op);
1822
1823 if (config.allowPatternRollback) {
1824 // TODO: If the same op is inserted multiple times from a detached
1825 // state, the rollback mechanism may erase the same op multiple times.
1826 // This is a bug in the rollback-based dialect conversion driver.
1828 } else {
1829 // In "no rollback" mode, there is an extra data structure for tracking
1830 // erased operations that must be kept up to date.
1831 erasedOps.erase(op);
1832 }
1833 return;
1834 }
1835
1836 // The op was moved from one place to another.
1837 if (config.allowPatternRollback)
1839}
1840
1841/// Given that `fromRange` is about to be replaced with `toRange`, compute
1842/// replacement values with the types of `fromRange`.
1843static SmallVector<Value>
1845 const SmallVector<SmallVector<Value>> &toRange,
1846 const TypeConverter *converter) {
1847 assert(!impl.config.allowPatternRollback &&
1848 "this code path is valid only in 'no rollback' mode");
1849 SmallVector<Value> repls;
1850 for (auto [from, to] : llvm::zip_equal(fromRange, toRange)) {
1851 if (from.use_empty()) {
1852 // The replaced value is dead. No replacement value is needed.
1853 repls.push_back(Value());
1854 continue;
1855 }
1856
1857 if (to.empty()) {
1858 // The replaced value is dropped. Materialize a replacement value "out of
1859 // thin air".
1860 Value srcMat = impl.buildUnresolvedMaterialization(
1861 MaterializationKind::Source, computeInsertPoint(from), from.getLoc(),
1862 /*valuesToMap=*/{}, /*inputs=*/ValueRange(),
1863 /*outputTypes=*/from.getType(), /*originalType=*/Type(),
1864 converter)[0];
1865 repls.push_back(srcMat);
1866 continue;
1867 }
1868
1869 if (TypeRange(ValueRange(to)) == TypeRange(from.getType())) {
1870 // The replacement value already has the correct type. Use it directly.
1871 repls.push_back(to[0]);
1872 continue;
1873 }
1874
1875 // The replacement value has the wrong type. Build a source materialization
1876 // to the original type.
1877 // TODO: This is a bit inefficient. We should try to reuse existing
1878 // materializations if possible. This would require an extension of the
1879 // `lookupOrDefault` API.
1880 Value srcMat = impl.buildUnresolvedMaterialization(
1881 MaterializationKind::Source, computeInsertPoint(to), from.getLoc(),
1882 /*valuesToMap=*/{}, /*inputs=*/to, /*outputTypes=*/from.getType(),
1883 /*originalType=*/Type(), converter)[0];
1884 repls.push_back(srcMat);
1885 }
1886
1887 return repls;
1888}
1889
1891 Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
1892 assert(newValues.size() == op->getNumResults() &&
1893 "incorrect number of replacement values");
1894 LLVM_DEBUG({
1895 logger.startLine() << "** Replace : '" << op->getName() << "'(" << op
1896 << ")\n";
1898 // If the user-provided replacement types are different from the
1899 // legalized types, as per the current type converter, print a note.
1900 // In most cases, the replacement types are expected to match the types
1901 // produced by the type converter, so this could indicate a bug in the
1902 // user code.
1903 for (auto [result, repls] :
1904 llvm::zip_equal(op->getResults(), newValues)) {
1905 Type resultType = result.getType();
1906 auto logProlog = [&, repls = repls]() {
1907 logger.startLine() << " Note: Replacing op result of type "
1908 << resultType << " with value(s) of type (";
1909 llvm::interleaveComma(repls, logger.getOStream(), [&](Value v) {
1910 logger.getOStream() << v.getType();
1911 });
1912 logger.getOStream() << ")";
1913 };
1914 SmallVector<Type> convertedTypes;
1915 if (failed(currentTypeConverter->convertTypes(resultType,
1916 convertedTypes))) {
1917 logProlog();
1918 logger.getOStream() << ", but the type converter failed to legalize "
1919 "the original type.\n";
1920 continue;
1921 }
1922 if (TypeRange(convertedTypes) != TypeRange(ValueRange(repls))) {
1923 logProlog();
1924 logger.getOStream() << ", but the legalized type(s) is/are (";
1925 llvm::interleaveComma(convertedTypes, logger.getOStream(),
1926 [&](Type t) { logger.getOStream() << t; });
1927 logger.getOStream() << ")\n";
1928 }
1929 }
1930 }
1931 });
1932
1933 if (!config.allowPatternRollback) {
1934 // Pattern rollback is not allowed: materialize all IR changes immediately.
1936 *this, op->getResults(), newValues, currentTypeConverter);
1937 // Update internal data structures, so that there are no dangling pointers
1938 // to erased IR.
1939 op->walk([&](Operation *op) {
1940 erasedOps.insert(op);
1941 ignoredOps.remove(op);
1942 if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1943 unresolvedMaterializations.erase(castOp);
1944 patternMaterializations.erase(castOp);
1945 }
1946 // The original op will be erased, so remove it from the set of
1947 // unlegalized ops.
1948 if (config.unlegalizedOps)
1949 config.unlegalizedOps->erase(op);
1950 });
1951 op->walk([&](Block *block) { erasedBlocks.insert(block); });
1952 // Replace the op with the replacement values and notify the listener.
1953 notifyingRewriter.replaceOp(op, repls);
1954 return;
1955 }
1956
1957 assert(!ignoredOps.contains(op) && "operation was already replaced");
1958#ifndef NDEBUG
1959 for (Value v : op->getResults())
1960 assert(!replacedValues.contains(v) &&
1961 "attempting to replace a value that was already replaced");
1962#endif // NDEBUG
1963
1964 // Check if replaced op is an unresolved materialization, i.e., an
1965 // unrealized_conversion_cast op that was created by the conversion driver.
1966 if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1967 // Make sure that the user does not mess with unresolved materializations
1968 // that were inserted by the conversion driver. We keep track of these
1969 // ops in internal data structures.
1970 assert(!unresolvedMaterializations.contains(castOp) &&
1971 "attempting to replace/erase an unresolved materialization");
1972 }
1973
1974 // Create mappings for each of the new result values.
1975 for (auto [repl, result] : llvm::zip_equal(newValues, op->getResults()))
1976 mapping.map(static_cast<Value>(result), std::move(repl));
1977
1979 // Mark this operation and all nested ops as replaced.
1980 op->walk([&](Operation *op) { replacedOps.insert(op); });
1981}
1982
1984 Value from, ValueRange to, const TypeConverter *converter,
1985 function_ref<bool(OpOperand &)> functor) {
1986 LLVM_DEBUG({
1987 logger.startLine() << "** Replace Value : '" << from << "'";
1988 if (auto blockArg = dyn_cast<BlockArgument>(from)) {
1989 if (Operation *parentOp = blockArg.getOwner()->getParentOp()) {
1990 logger.getOStream() << " (in region of '" << parentOp->getName()
1991 << "' (" << parentOp << ")";
1992 } else {
1993 logger.getOStream() << " (unlinked block)";
1994 }
1995 }
1996 if (functor) {
1997 logger.getOStream() << ", conditional replacement";
1998 }
1999 });
2000
2001 if (!config.allowPatternRollback) {
2002 SmallVector<Value> toConv = llvm::to_vector(to);
2003 SmallVector<Value> repls =
2004 getReplacementValues(*this, from, {toConv}, converter);
2005 IRRewriter r(from.getContext());
2006 Value repl = repls.front();
2007 if (!repl)
2008 return;
2009
2010 performReplaceValue(r, from, repl, functor);
2011 return;
2012 }
2013
2014#ifndef NDEBUG
2015 // Make sure that a value is not replaced multiple times. In rollback mode,
2016 // `replaceAllUsesWith` replaces not only all current uses of the given value,
2017 // but also all future uses that may be introduced by future pattern
2018 // applications. Therefore, it does not make sense to call
2019 // `replaceAllUsesWith` multiple times with the same value. Doing so would
2020 // overwrite the mapping and mess with the internal state of the dialect
2021 // conversion driver.
2022 assert(!replacedValues.contains(from) &&
2023 "attempting to replace a value that was already replaced");
2024 assert(!wasOpReplaced(from.getDefiningOp()) &&
2025 "attempting to replace a op result that was already replaced");
2026 replacedValues.insert(from);
2027#endif // NDEBUG
2028
2029 if (functor)
2030 llvm::report_fatal_error(
2031 "conditional value replacement is not supported in rollback mode");
2032 mapping.map(from, to);
2033 appendRewrite<ReplaceValueRewrite>(from, converter);
2034}
2035
2037 if (!config.allowPatternRollback) {
2038 // Pattern rollback is not allowed: materialize all IR changes immediately.
2039 // Update internal data structures, so that there are no dangling pointers
2040 // to erased IR.
2041 block->walk([&](Operation *op) {
2042 erasedOps.insert(op);
2043 ignoredOps.remove(op);
2044 if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
2045 unresolvedMaterializations.erase(castOp);
2046 patternMaterializations.erase(castOp);
2047 }
2048 // The original op will be erased, so remove it from the set of
2049 // unlegalized ops.
2050 if (config.unlegalizedOps)
2051 config.unlegalizedOps->erase(op);
2052 });
2053 block->walk([&](Block *block) { erasedBlocks.insert(block); });
2054 // Erase the block and notify the listener.
2055 notifyingRewriter.eraseBlock(block);
2056 return;
2057 }
2058
2059 assert(!wasOpReplaced(block->getParentOp()) &&
2060 "attempting to erase a block within a replaced/erased op");
2062
2063 // Unlink the block from its parent region. The block is kept in the rewrite
2064 // object and will be actually destroyed when rewrites are applied. This
2065 // allows us to keep the operations in the block live and undo the removal by
2066 // re-inserting the block.
2067 block->getParent()->getBlocks().remove(block);
2068
2069 // Mark all nested ops as erased.
2070 block->walk([&](Operation *op) { replacedOps.insert(op); });
2071}
2072
2074 Block *block, Region *previous, Region::iterator previousIt) {
2075 // If no previous insertion point is provided, the block used to be detached.
2076 bool wasDetached = !previous;
2077 Operation *newParentOp = block->getParentOp();
2078 LLVM_DEBUG(
2079 {
2080 Operation *parent = newParentOp;
2081 if (parent) {
2082 logger.startLine() << "** Insert Block into : '" << parent->getName()
2083 << "' (" << parent << ")";
2084 } else {
2085 logger.startLine()
2086 << "** Insert Block into detached Region (nullptr parent op)";
2087 }
2088 if (wasDetached)
2089 logger.getOStream() << " (was detached)";
2090 logger.getOStream() << "\n";
2091 });
2092
2093 // In rollback mode, it is easier to misuse the API, so perform extra error
2094 // checking.
2095 assert(!(config.allowPatternRollback && wasOpReplaced(newParentOp)) &&
2096 "attempting to insert into a region within a replaced/erased op");
2097 (void)newParentOp;
2098
2099 // In "no rollback" mode, the listener is always notified immediately.
2100 if (!config.allowPatternRollback && config.listener)
2101 config.listener->notifyBlockInserted(block, previous, previousIt);
2102
2103 if (wasDetached) {
2104 // If the block was detached, it is most likely a newly created block.
2105 if (config.allowPatternRollback) {
2106 // TODO: If the same block is inserted multiple times from a detached
2107 // state, the rollback mechanism may erase the same block multiple times.
2108 // This is a bug in the rollback-based dialect conversion driver.
2110 } else {
2111 // In "no rollback" mode, there is an extra data structure for tracking
2112 // erased blocks that must be kept up to date.
2113 erasedBlocks.erase(block);
2114 }
2115 return;
2116 }
2117
2118 // The block was moved from one place to another.
2119 if (config.allowPatternRollback)
2120 appendRewrite<MoveBlockRewrite>(block, previous, previousIt);
2121}
2122
2124 Block *dest,
2125 Block::iterator before) {
2126 appendRewrite<InlineBlockRewrite>(dest, source, before);
2127}
2128
2130 Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
2131 LLVM_DEBUG({
2133 reasonCallback(diag);
2134 logger.startLine() << "** Failure : " << diag.str() << "\n";
2135 if (config.notifyCallback)
2136 config.notifyCallback(diag);
2137 });
2138}
2139
2140//===----------------------------------------------------------------------===//
2141// ConversionPatternRewriter
2142//===----------------------------------------------------------------------===//
2143
2144ConversionPatternRewriter::ConversionPatternRewriter(
2145 MLIRContext *ctx, const ConversionConfig &config,
2146 OperationConverter &opConverter)
2148 *this, config, opConverter)) {
2149 setListener(impl.get());
2150}
2151
2152ConversionPatternRewriter::~ConversionPatternRewriter() = default;
2153
2154const ConversionConfig &ConversionPatternRewriter::getConfig() const {
2155 return impl->config;
2156}
2157
2158void ConversionPatternRewriter::replaceOp(Operation *op, Operation *newOp) {
2159 assert(op && newOp && "expected non-null op");
2160 replaceOp(op, newOp->getResults());
2161}
2162
2163void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
2164 assert(op->getNumResults() == newValues.size() &&
2165 "incorrect # of replacement values");
2166
2167 // If the current insertion point is before the erased operation, we adjust
2168 // the insertion point to be after the operation.
2169 if (getInsertionPoint() == op->getIterator())
2171
2172 SmallVector<SmallVector<Value>> newVals =
2173 llvm::map_to_vector(newValues, [](Value v) -> SmallVector<Value> {
2174 return v ? SmallVector<Value>{v} : SmallVector<Value>();
2175 });
2176 impl->replaceOp(op, std::move(newVals));
2177}
2178
2179void ConversionPatternRewriter::replaceOpWithMultiple(
2180 Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
2181 assert(op->getNumResults() == newValues.size() &&
2182 "incorrect # of replacement values");
2183
2184 // If the current insertion point is before the erased operation, we adjust
2185 // the insertion point to be after the operation.
2186 if (getInsertionPoint() == op->getIterator())
2188
2189 impl->replaceOp(op, std::move(newValues));
2190}
2191
2192void ConversionPatternRewriter::eraseOp(Operation *op) {
2193 LLVM_DEBUG({
2194 impl->logger.startLine()
2195 << "** Erase : '" << op->getName() << "'(" << op << ")\n";
2196 });
2197
2198 // If the current insertion point is before the erased operation, we adjust
2199 // the insertion point to be after the operation.
2200 if (getInsertionPoint() == op->getIterator())
2202
2203 SmallVector<SmallVector<Value>> nullRepls(op->getNumResults(), {});
2204 impl->replaceOp(op, std::move(nullRepls));
2205}
2206
2207void ConversionPatternRewriter::eraseBlock(Block *block) {
2208 impl->eraseBlock(block);
2209}
2210
2211Block *ConversionPatternRewriter::applySignatureConversion(
2212 Block *block, TypeConverter::SignatureConversion &conversion,
2213 const TypeConverter *converter) {
2214 assert(!impl->wasOpReplaced(block->getParentOp()) &&
2215 "attempting to apply a signature conversion to a block within a "
2216 "replaced/erased op");
2217 return impl->applySignatureConversion(block, converter, conversion);
2218}
2219
2220FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
2221 Region *region, const TypeConverter &converter,
2222 TypeConverter::SignatureConversion *entryConversion) {
2223 assert(!impl->wasOpReplaced(region->getParentOp()) &&
2224 "attempting to apply a signature conversion to a block within a "
2225 "replaced/erased op");
2226 return impl->convertRegionTypes(region, converter, entryConversion);
2227}
2228
2229void ConversionPatternRewriter::replaceAllUsesWith(Value from, ValueRange to) {
2230 impl->replaceValueUses(from, to, impl->currentTypeConverter);
2231}
2232
2233void ConversionPatternRewriter::replaceUsesWithIf(
2234 Value from, ValueRange to, function_ref<bool(OpOperand &)> functor,
2235 bool *allUsesReplaced) {
2236 assert(!allUsesReplaced &&
2237 "allUsesReplaced is not supported in a dialect conversion");
2238 impl->replaceValueUses(from, to, impl->currentTypeConverter, functor);
2239}
2240
2241Value ConversionPatternRewriter::getRemappedValue(Value key) {
2242 SmallVector<ValueVector> remappedValues;
2243 if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, key,
2244 remappedValues)))
2245 return nullptr;
2246 assert(remappedValues.front().size() == 1 && "1:N conversion not supported");
2247 return remappedValues.front().front();
2248}
2249
2250LogicalResult
2251ConversionPatternRewriter::getRemappedValues(ValueRange keys,
2252 SmallVectorImpl<Value> &results) {
2253 if (keys.empty())
2254 return success();
2255 SmallVector<ValueVector> remapped;
2256 if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, keys,
2257 remapped)))
2258 return failure();
2259 for (const auto &values : remapped) {
2260 assert(values.size() == 1 && "1:N conversion not supported");
2261 results.push_back(values.front());
2262 }
2263 return success();
2264}
2265
2266void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
2267 Block::iterator before,
2268 ValueRange argValues) {
2269#ifndef NDEBUG
2270 assert(argValues.size() == source->getNumArguments() &&
2271 "incorrect # of argument replacement values");
2272 assert(!impl->wasOpReplaced(source->getParentOp()) &&
2273 "attempting to inline a block from a replaced/erased op");
2274 assert(!impl->wasOpReplaced(dest->getParentOp()) &&
2275 "attempting to inline a block into a replaced/erased op");
2276 auto opIgnored = [&](Operation *op) { return impl->isOpIgnored(op); };
2277 // The source block will be deleted, so it should not have any users (i.e.,
2278 // there should be no predecessors).
2279 assert(llvm::all_of(source->getUsers(), opIgnored) &&
2280 "expected 'source' to have no predecessors");
2281#endif // NDEBUG
2282
2283 // If a listener is attached to the dialect conversion, ops cannot be moved
2284 // to the destination block in bulk ("fast path"). This is because at the time
2285 // the notifications are sent, it is unknown which ops were moved. Instead,
2286 // ops should be moved one-by-one ("slow path"), so that a separate
2287 // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
2288 // a bit more efficient, so we try to do that when possible.
2289 bool fastPath = !getConfig().listener;
2290
2291 if (fastPath && impl->config.allowPatternRollback)
2292 impl->inlineBlockBefore(source, dest, before);
2293
2294 // Replace all uses of block arguments.
2295 for (auto it : llvm::zip(source->getArguments(), argValues))
2296 replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
2297
2298 if (fastPath) {
2299 // Move all ops at once.
2300 dest->getOperations().splice(before, source->getOperations());
2301 } else {
2302 // Move op by op.
2303 while (!source->empty())
2304 moveOpBefore(&source->front(), dest, before);
2305 }
2306
2307 // If the current insertion point is within the source block, adjust the
2308 // insertion point to the destination block.
2309 if (getInsertionBlock() == source)
2310 setInsertionPoint(dest, getInsertionPoint());
2311
2312 // Erase the source block.
2313 eraseBlock(source);
2314}
2315
2316void ConversionPatternRewriter::startOpModification(Operation *op) {
2317 if (!impl->config.allowPatternRollback) {
2318 // Pattern rollback is not allowed: no extra bookkeeping is needed.
2320 return;
2321 }
2322 assert(!impl->wasOpReplaced(op) &&
2323 "attempting to modify a replaced/erased op");
2324#ifndef NDEBUG
2325 impl->pendingRootUpdates.insert(op);
2326#endif
2327 impl->appendRewrite<ModifyOperationRewrite>(op);
2328}
2329
2330void ConversionPatternRewriter::finalizeOpModification(Operation *op) {
2331 impl->patternModifiedOps.insert(op);
2332 if (!impl->config.allowPatternRollback) {
2334 if (getConfig().listener)
2335 getConfig().listener->notifyOperationModified(op);
2336 return;
2337 }
2338
2339 // There is nothing to do here, we only need to track the operation at the
2340 // start of the update.
2341#ifndef NDEBUG
2342 assert(!impl->wasOpReplaced(op) &&
2343 "attempting to modify a replaced/erased op");
2344 assert(impl->pendingRootUpdates.erase(op) &&
2345 "operation did not have a pending in-place update");
2346#endif
2347}
2348
2349void ConversionPatternRewriter::cancelOpModification(Operation *op) {
2350 if (!impl->config.allowPatternRollback) {
2352 return;
2353 }
2354#ifndef NDEBUG
2355 assert(impl->pendingRootUpdates.erase(op) &&
2356 "operation did not have a pending in-place update");
2357#endif
2358 // Erase the last update for this operation.
2359 auto it = llvm::find_if(
2360 llvm::reverse(impl->rewrites), [&](std::unique_ptr<IRRewrite> &rewrite) {
2361 auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(rewrite.get());
2362 return modifyRewrite && modifyRewrite->getOperation() == op;
2363 });
2364 assert(it != impl->rewrites.rend() && "no root update started on op");
2365 (*it)->rollback();
2366 int updateIdx = std::prev(impl->rewrites.rend()) - it;
2367 impl->rewrites.erase(impl->rewrites.begin() + updateIdx);
2368}
2369
2370detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
2371 return *impl;
2372}
2373
2374//===----------------------------------------------------------------------===//
2375// ConversionPattern
2376//===----------------------------------------------------------------------===//
2377
2378FailureOr<SmallVector<Value>> ConversionPattern::getOneToOneAdaptorOperands(
2379 ArrayRef<ValueRange> operands) const {
2380 SmallVector<Value> oneToOneOperands;
2381 oneToOneOperands.reserve(operands.size());
2382 for (ValueRange operand : operands) {
2383 if (operand.size() != 1)
2384 return failure();
2385
2386 oneToOneOperands.push_back(operand.front());
2387 }
2388 return std::move(oneToOneOperands);
2389}
2390
2391LogicalResult
2392ConversionPattern::matchAndRewrite(Operation *op,
2393 PatternRewriter &rewriter) const {
2394 auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter);
2395 auto &rewriterImpl = dialectRewriter.getImpl();
2396
2397 // Track the current conversion pattern type converter in the rewriter.
2398 llvm::SaveAndRestore currentConverterGuard(rewriterImpl.currentTypeConverter,
2399 getTypeConverter());
2400
2401 // Remap the operands of the operation.
2402 SmallVector<ValueVector> remapped;
2403 if (failed(rewriterImpl.remapValues("operand", op->getLoc(),
2404 op->getOperands(), remapped))) {
2405 return failure();
2406 }
2407 SmallVector<ValueRange> remappedAsRange =
2408 llvm::to_vector_of<ValueRange>(remapped);
2409 return matchAndRewrite(op, remappedAsRange, dialectRewriter);
2410}
2411
2412//===----------------------------------------------------------------------===//
2413// OperationLegalizer
2414//===----------------------------------------------------------------------===//
2415
2416namespace {
2417/// A set of rewrite patterns that can be used to legalize a given operation.
2418using LegalizationPatterns = SmallVector<const Pattern *, 1>;
2419
2420/// This class defines a recursive operation legalizer.
2421class OperationLegalizer {
2422public:
2423 using LegalizationAction = ConversionTarget::LegalizationAction;
2424
2425 OperationLegalizer(ConversionPatternRewriter &rewriter,
2426 const ConversionTarget &targetInfo,
2427 const FrozenRewritePatternSet &patterns);
2428
2429 /// Returns true if the given operation is known to be illegal on the target.
2430 bool isIllegal(Operation *op) const;
2431
2432 /// Attempt to legalize the given operation. Returns success if the operation
2433 /// was legalized, failure otherwise.
2434 LogicalResult legalize(Operation *op);
2435
2436 /// Returns the conversion target in use by the legalizer.
2437 const ConversionTarget &getTarget() { return target; }
2438
2439private:
2440 /// Attempt to legalize the given operation by folding it.
2441 LogicalResult legalizeWithFold(Operation *op);
2442
2443 /// Attempt to legalize the given operation by applying a pattern. Returns
2444 /// success if the operation was legalized, failure otherwise.
2445 LogicalResult legalizeWithPattern(Operation *op);
2446
2447 /// Return true if the given pattern may be applied to the given operation,
2448 /// false otherwise.
2449 bool canApplyPattern(Operation *op, const Pattern &pattern);
2450
2451 /// Legalize the resultant IR after successfully applying the given pattern.
2452 LogicalResult
2453 legalizePatternResult(Operation *op, const Pattern &pattern,
2454 const RewriterState &curState,
2455 const SetVector<Operation *> &newOps,
2456 const SetVector<Operation *> &modifiedOps);
2457
2458 LogicalResult
2459 legalizePatternCreatedOperations(const SetVector<Operation *> &newOps);
2460 LogicalResult
2461 legalizePatternRootUpdates(const SetVector<Operation *> &modifiedOps);
2462
2463 //===--------------------------------------------------------------------===//
2464 // Cost Model
2465 //===--------------------------------------------------------------------===//
2466
2467 /// Build an optimistic legalization graph given the provided patterns. This
2468 /// function populates 'anyOpLegalizerPatterns' and 'legalizerPatterns' with
2469 /// patterns for operations that are not directly legal, but may be
2470 /// transitively legal for the current target given the provided patterns.
2471 void buildLegalizationGraph(
2472 LegalizationPatterns &anyOpLegalizerPatterns,
2474
2475 /// Compute the benefit of each node within the computed legalization graph.
2476 /// This orders the patterns within 'legalizerPatterns' based upon two
2477 /// criteria:
2478 /// 1) Prefer patterns that have the lowest legalization depth, i.e.
2479 /// represent the more direct mapping to the target.
2480 /// 2) When comparing patterns with the same legalization depth, prefer the
2481 /// pattern with the highest PatternBenefit. This allows for users to
2482 /// prefer specific legalizations over others.
2483 void computeLegalizationGraphBenefit(
2484 LegalizationPatterns &anyOpLegalizerPatterns,
2486
2487 /// Compute the legalization depth when legalizing an operation of the given
2488 /// type.
2489 unsigned computeOpLegalizationDepth(
2490 OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
2492
2493 /// Apply the conversion cost model to the given set of patterns, and return
2494 /// the smallest legalization depth of any of the patterns. See
2495 /// `computeLegalizationGraphBenefit` for the breakdown of the cost model.
2496 unsigned applyCostModelToPatterns(
2497 LegalizationPatterns &patterns,
2498 DenseMap<OperationName, unsigned> &minOpPatternDepth,
2500
2501 /// The current set of patterns that have been applied.
2502 SmallPtrSet<const Pattern *, 8> appliedPatterns;
2503
2504 /// The rewriter to use when converting operations.
2505 ConversionPatternRewriter &rewriter;
2506
2507 /// The legalization information provided by the target.
2508 const ConversionTarget &target;
2509
2510 /// The pattern applicator to use for conversions.
2511 PatternApplicator applicator;
2512};
2513} // namespace
2514
2515OperationLegalizer::OperationLegalizer(ConversionPatternRewriter &rewriter,
2516 const ConversionTarget &targetInfo,
2517 const FrozenRewritePatternSet &patterns)
2518 : rewriter(rewriter), target(targetInfo), applicator(patterns) {
2519 // The set of patterns that can be applied to illegal operations to transform
2520 // them into legal ones.
2522 LegalizationPatterns anyOpLegalizerPatterns;
2523
2524 buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
2525 computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
2526}
2527
2528bool OperationLegalizer::isIllegal(Operation *op) const {
2529 return target.isIllegal(op);
2530}
2531
2532LogicalResult OperationLegalizer::legalize(Operation *op) {
2533#ifndef NDEBUG
2534 const char *logLineComment =
2535 "//===-------------------------------------------===//\n";
2536
2537 auto &logger = rewriter.getImpl().logger;
2538#endif
2539
2540 // Check to see if the operation is ignored and doesn't need to be converted.
2541 bool isIgnored = rewriter.getImpl().isOpIgnored(op);
2542
2543 LLVM_DEBUG({
2544 logger.getOStream() << "\n";
2545 logger.startLine() << logLineComment;
2546 logger.startLine() << "Legalizing operation : ";
2547 // Do not print the operation name if the operation is ignored. Ignored ops
2548 // may have been erased and should not be accessed. The pointer can be
2549 // printed safely.
2550 if (!isIgnored)
2551 logger.getOStream() << "'" << op->getName() << "' ";
2552 logger.getOStream() << "(" << op << ") {\n";
2553 logger.indent();
2554
2555 // If the operation has no regions, just print it here.
2556 if (!isIgnored && op->getNumRegions() == 0) {
2557 logger.startLine() << OpWithFlags(op,
2558 OpPrintingFlags().printGenericOpForm())
2559 << "\n";
2560 }
2561 });
2562
2563 if (isIgnored) {
2564 LLVM_DEBUG({
2565 logSuccess(logger, "operation marked 'ignored' during conversion");
2566 logger.startLine() << logLineComment;
2567 });
2568 return success();
2569 }
2570
2571 // Check if this operation is legal on the target.
2572 if (auto legalityInfo = target.isLegal(op)) {
2573 LLVM_DEBUG({
2574 logSuccess(
2575 logger, "operation marked legal by the target{0}",
2576 legalityInfo->isRecursivelyLegal
2577 ? "; NOTE: operation is recursively legal; skipping internals"
2578 : "");
2579 logger.startLine() << logLineComment;
2580 });
2581
2582 // If this operation is recursively legal, mark its children as ignored so
2583 // that we don't consider them for legalization.
2584 if (legalityInfo->isRecursivelyLegal) {
2585 op->walk([&](Operation *nested) {
2586 if (op != nested)
2587 rewriter.getImpl().ignoredOps.insert(nested);
2588 });
2589 }
2590
2591 return success();
2592 }
2593
2594 // If the operation is not legal, try to fold it in-place if the folding mode
2595 // is 'BeforePatterns'. 'Never' will skip this.
2596 const ConversionConfig &config = rewriter.getConfig();
2597 if (config.foldingMode == DialectConversionFoldingMode::BeforePatterns) {
2598 if (succeeded(legalizeWithFold(op))) {
2599 LLVM_DEBUG({
2600 logSuccess(logger, "operation was folded");
2601 logger.startLine() << logLineComment;
2602 });
2603 return success();
2604 }
2605 }
2606
2607 // Otherwise, we need to apply a legalization pattern to this operation.
2608 if (succeeded(legalizeWithPattern(op))) {
2609 LLVM_DEBUG({
2610 logSuccess(logger, "");
2611 logger.startLine() << logLineComment;
2612 });
2613 return success();
2614 }
2615
2616 // If the operation can't be legalized via patterns, try to fold it in-place
2617 // if the folding mode is 'AfterPatterns'.
2618 if (config.foldingMode == DialectConversionFoldingMode::AfterPatterns) {
2619 if (succeeded(legalizeWithFold(op))) {
2620 LLVM_DEBUG({
2621 logSuccess(logger, "operation was folded");
2622 logger.startLine() << logLineComment;
2623 });
2624 return success();
2625 }
2626 }
2627
2628 LLVM_DEBUG({
2629 logFailure(logger, "no matched legalization pattern");
2630 logger.startLine() << logLineComment;
2631 });
2632 return failure();
2633}
2634
/// Moves the contents out of `obj` and returns them. `obj` is then assigned a
/// default-constructed `T`, so that it is in a valid, empty state again.
template <typename T>
static T moveAndReset(T &obj) {
  return std::exchange(obj, T());
}
2643
2644LogicalResult OperationLegalizer::legalizeWithFold(Operation *op) {
2645 auto &rewriterImpl = rewriter.getImpl();
2646 LLVM_DEBUG({
2647 rewriterImpl.logger.startLine() << "* Fold {\n";
2648 rewriterImpl.logger.indent();
2649 });
2650
2651 // Clear pattern state, so that the next pattern application starts with a
2652 // clean slate. (The op/block sets are populated by listener notifications.)
2653 llvm::scope_exit cleanup([&]() {
2654 rewriterImpl.patternNewOps.clear();
2655 rewriterImpl.patternModifiedOps.clear();
2656 });
2657
2658 // Upon failure, undo all changes made by the folder.
2659 RewriterState curState = rewriterImpl.getCurrentState();
2660
2661 // Try to fold the operation.
2662 StringRef opName = op->getName().getStringRef();
2663 SmallVector<Value, 2> replacementValues;
2664 SmallVector<Operation *, 2> newOps;
2665 rewriter.setInsertionPoint(op);
2666 rewriter.startOpModification(op);
2667 if (failed(rewriter.tryFold(op, replacementValues, &newOps))) {
2668 LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
2669 rewriter.cancelOpModification(op);
2670 return failure();
2671 }
2672 rewriter.finalizeOpModification(op);
2673
2674 // An empty list of replacement values indicates that the fold was in-place.
2675 // As the operation changed, a new legalization needs to be attempted.
2676 if (replacementValues.empty())
2677 return legalize(op);
2678
2679 // Insert a replacement for 'op' with the folded replacement values.
2680 rewriter.replaceOp(op, replacementValues);
2681
2682 // Recursively legalize any new constant operations.
2683 for (Operation *newOp : newOps) {
2684 if (failed(legalize(newOp))) {
2685 LLVM_DEBUG(logFailure(rewriterImpl.logger,
2686 "failed to legalize generated constant '{0}'",
2687 newOp->getName()));
2688 if (!rewriter.getConfig().allowPatternRollback) {
2689 // Rolling back a folder is like rolling back a pattern.
2690 llvm::report_fatal_error(
2691 "op '" + opName +
2692 "' folder rollback of IR modifications requested");
2693 }
2694 rewriterImpl.resetState(
2695 curState, std::string(op->getName().getStringRef()) + " folder");
2696 return failure();
2697 }
2698 }
2699
2700 LLVM_DEBUG(logSuccess(rewriterImpl.logger, ""));
2701 return success();
2702}
2703
2704/// Report a fatal error indicating that newly produced or modified IR could
2705/// not be legalized.
2706static void
2708 const SetVector<Operation *> &newOps,
2709 const SetVector<Operation *> &modifiedOps) {
2710 auto newOpNames = llvm::map_range(
2711 newOps, [](Operation *op) { return op->getName().getStringRef(); });
2712 auto modifiedOpNames = llvm::map_range(
2713 modifiedOps, [](Operation *op) { return op->getName().getStringRef(); });
2714 llvm::report_fatal_error("pattern '" + pattern.getDebugName() +
2715 "' produced IR that could not be legalized. " +
2716 "new ops: {" + llvm::join(newOpNames, ", ") + "}, " +
2717 "modified ops: {" +
2718 llvm::join(modifiedOpNames, ", ") + "}");
2719}
2720
2721LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
2722 auto &rewriterImpl = rewriter.getImpl();
2723 const ConversionConfig &config = rewriter.getConfig();
2724
2725#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2726 Operation *checkOp;
2727 std::optional<OperationFingerPrint> topLevelFingerPrint;
2728 if (!rewriterImpl.config.allowPatternRollback) {
2729 // The op may be getting erased, so we have to check the parent op.
2730 // (In rare cases, a pattern may even erase the parent op, which will cause
2731 // a crash here. Expensive checks are "best effort".) Skip the check if the
2732 // op does not have a parent op.
2733 if ((checkOp = op->getParentOp())) {
2734 if (!op->getContext()->isMultithreadingEnabled()) {
2735 topLevelFingerPrint = OperationFingerPrint(checkOp);
2736 } else {
2737 // Another thread may be modifying a sibling operation. Therefore, the
2738 // fingerprinting mechanism of the parent op works only in
2739 // single-threaded mode.
2740 LLVM_DEBUG({
2741 rewriterImpl.logger.startLine()
2742 << "WARNING: Multi-threadeding is enabled. Some dialect "
2743 "conversion expensive checks are skipped in multithreading "
2744 "mode!\n";
2745 });
2746 }
2747 }
2748 }
2749#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2750
2751 // Functor that returns if the given pattern may be applied.
2752 auto canApply = [&](const Pattern &pattern) {
2753 bool canApply = canApplyPattern(op, pattern);
2754 if (canApply && config.listener)
2755 config.listener->notifyPatternBegin(pattern, op);
2756 return canApply;
2757 };
2758
2759 // Functor that cleans up the rewriter state after a pattern failed to match.
2760 RewriterState curState = rewriterImpl.getCurrentState();
2761 auto onFailure = [&](const Pattern &pattern) {
2762 assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
2763 if (!rewriterImpl.config.allowPatternRollback) {
2764 // Erase all unresolved materializations.
2765 for (auto op : rewriterImpl.patternMaterializations) {
2766 rewriterImpl.unresolvedMaterializations.erase(op);
2767 op.erase();
2768 }
2769 rewriterImpl.patternMaterializations.clear();
2770#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2771 // Expensive pattern check that can detect API violations.
2772 if (checkOp && topLevelFingerPrint) {
2773 OperationFingerPrint fingerPrintAfterPattern(checkOp);
2774 if (fingerPrintAfterPattern != *topLevelFingerPrint)
2775 llvm::report_fatal_error("pattern '" + pattern.getDebugName() +
2776 "' returned failure but IR did change");
2777 }
2778#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2779 }
2780 rewriterImpl.patternNewOps.clear();
2781 rewriterImpl.patternModifiedOps.clear();
2782 LLVM_DEBUG({
2783 logFailure(rewriterImpl.logger, "pattern failed to match");
2784 if (rewriterImpl.config.notifyCallback) {
2785 Diagnostic diag(op->getLoc(), DiagnosticSeverity::Remark);
2786 diag << "Failed to apply pattern \"" << pattern.getDebugName()
2787 << "\" on op:\n"
2788 << *op;
2789 rewriterImpl.config.notifyCallback(diag);
2790 }
2791 });
2792 if (config.listener)
2793 config.listener->notifyPatternEnd(pattern, failure());
2794 rewriterImpl.resetState(curState, pattern.getDebugName());
2795 appliedPatterns.erase(&pattern);
2796 };
2797
2798 // Functor that performs additional legalization when a pattern is
2799 // successfully applied.
2800 auto onSuccess = [&](const Pattern &pattern) {
2801 assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
2802 if (!rewriterImpl.config.allowPatternRollback) {
2803 // Eagerly erase unused materializations.
2804 for (auto op : rewriterImpl.patternMaterializations) {
2805 if (op->use_empty()) {
2806 rewriterImpl.unresolvedMaterializations.erase(op);
2807 op.erase();
2808 }
2809 }
2810 rewriterImpl.patternMaterializations.clear();
2811 }
2812 SetVector<Operation *> newOps = moveAndReset(rewriterImpl.patternNewOps);
2813 SetVector<Operation *> modifiedOps =
2814 moveAndReset(rewriterImpl.patternModifiedOps);
2815 auto result =
2816 legalizePatternResult(op, pattern, curState, newOps, modifiedOps);
2817 appliedPatterns.erase(&pattern);
2818 if (failed(result)) {
2819 if (!rewriterImpl.config.allowPatternRollback)
2820 reportNewIrLegalizationFatalError(pattern, newOps, modifiedOps);
2821 rewriterImpl.resetState(curState, pattern.getDebugName());
2822 }
2823 if (config.listener)
2824 config.listener->notifyPatternEnd(pattern, result);
2825 return result;
2826 };
2827
2828 // Try to match and rewrite a pattern on this operation.
2829 return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
2830 onSuccess);
2831}
2832
2833bool OperationLegalizer::canApplyPattern(Operation *op,
2834 const Pattern &pattern) {
2835 LLVM_DEBUG({
2836 auto &os = rewriter.getImpl().logger;
2837 os.getOStream() << "\n";
2838 os.startLine() << "* Pattern : '" << op->getName() << " -> (";
2839 llvm::interleaveComma(pattern.getGeneratedOps(), os.getOStream());
2840 os.getOStream() << ")' {\n";
2841 os.indent();
2842 });
2843
2844 // Ensure that we don't cycle by not allowing the same pattern to be
2845 // applied twice in the same recursion stack if it is not known to be safe.
2846 if (!pattern.hasBoundedRewriteRecursion() &&
2847 !appliedPatterns.insert(&pattern).second) {
2848 LLVM_DEBUG(
2849 logFailure(rewriter.getImpl().logger, "pattern was already applied"));
2850 return false;
2851 }
2852 return true;
2853}
2854
2855LogicalResult OperationLegalizer::legalizePatternResult(
2856 Operation *op, const Pattern &pattern, const RewriterState &curState,
2857 const SetVector<Operation *> &newOps,
2858 const SetVector<Operation *> &modifiedOps) {
2859 [[maybe_unused]] auto &impl = rewriter.getImpl();
2860 assert(impl.pendingRootUpdates.empty() && "dangling root updates");
2861
2862#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2863 if (impl.config.allowPatternRollback) {
2864 // Check that the root was either replaced or updated in place.
2865 auto newRewrites = llvm::drop_begin(impl.rewrites, curState.numRewrites);
2866 auto replacedRoot = [&] {
2867 return hasRewrite<ReplaceOperationRewrite>(newRewrites, op);
2868 };
2869 auto updatedRootInPlace = [&] {
2870 return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
2871 };
2872 if (!replacedRoot() && !updatedRootInPlace())
2873 llvm::report_fatal_error("expected pattern to replace the root operation "
2874 "or modify it in place");
2875 }
2876#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2877
2878 // Legalize each of the actions registered during application.
2879 if (failed(legalizePatternRootUpdates(modifiedOps)) ||
2880 failed(legalizePatternCreatedOperations(newOps))) {
2881 return failure();
2882 }
2883
2884 LLVM_DEBUG(logSuccess(impl.logger, "pattern applied successfully"));
2885 return success();
2886}
2887
2888LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
2889 const SetVector<Operation *> &newOps) {
2890 for (Operation *op : newOps) {
2891 if (failed(legalize(op))) {
2892 LLVM_DEBUG(logFailure(rewriter.getImpl().logger,
2893 "failed to legalize generated operation '{0}'({1})",
2894 op->getName(), op));
2895 return failure();
2896 }
2897 }
2898 return success();
2899}
2900
2901LogicalResult OperationLegalizer::legalizePatternRootUpdates(
2902 const SetVector<Operation *> &modifiedOps) {
2903 for (Operation *op : modifiedOps) {
2904 if (failed(legalize(op))) {
2905 LLVM_DEBUG(
2906 logFailure(rewriter.getImpl().logger,
2907 "failed to legalize operation updated in-place '{0}'",
2908 op->getName()));
2909 return failure();
2910 }
2911 }
2912 return success();
2913}
2914
2915//===----------------------------------------------------------------------===//
2916// Cost Model
2917//===----------------------------------------------------------------------===//
2918
2919void OperationLegalizer::buildLegalizationGraph(
2920 LegalizationPatterns &anyOpLegalizerPatterns,
2922 // A mapping between an operation and a set of operations that can be used to
2923 // generate it.
2925 // A mapping between an operation and any currently invalid patterns it has.
2927 // A worklist of patterns to consider for legality.
2928 SetVector<const Pattern *> patternWorklist;
2929
2930 // Build the mapping from operations to the parent ops that may generate them.
2931 applicator.walkAllPatterns([&](const Pattern &pattern) {
2932 std::optional<OperationName> root = pattern.getRootKind();
2933
2934 // If the pattern has no specific root, we can't analyze the relationship
2935 // between the root op and generated operations. Given that, add all such
2936 // patterns to the legalization set.
2937 if (!root) {
2938 anyOpLegalizerPatterns.push_back(&pattern);
2939 return;
2940 }
2941
2942 // Skip operations that are always known to be legal.
2943 if (target.getOpAction(*root) == LegalizationAction::Legal)
2944 return;
2945
2946 // Add this pattern to the invalid set for the root op and record this root
2947 // as a parent for any generated operations.
2948 invalidPatterns[*root].insert(&pattern);
2949 for (auto op : pattern.getGeneratedOps())
2950 parentOps[op].insert(*root);
2951
2952 // Add this pattern to the worklist.
2953 patternWorklist.insert(&pattern);
2954 });
2955
2956 // If there are any patterns that don't have a specific root kind, we can't
2957 // make direct assumptions about what operations will never be legalized.
2958 // Note: Technically we could, but it would require an analysis that may
2959 // recurse into itself. It would be better to perform this kind of filtering
2960 // at a higher level than here anyways.
2961 if (!anyOpLegalizerPatterns.empty()) {
2962 for (const Pattern *pattern : patternWorklist)
2963 legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
2964 return;
2965 }
2966
2967 while (!patternWorklist.empty()) {
2968 auto *pattern = patternWorklist.pop_back_val();
2969
2970 // Check to see if any of the generated operations are invalid.
2971 if (llvm::any_of(pattern->getGeneratedOps(), [&](OperationName op) {
2972 std::optional<LegalizationAction> action = target.getOpAction(op);
2973 return !legalizerPatterns.count(op) &&
2974 (!action || action == LegalizationAction::Illegal);
2975 }))
2976 continue;
2977
2978 // Otherwise, if all of the generated operations are valid, this op is now
2979 // legal so add all of the child patterns to the worklist.
2980 legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
2981 invalidPatterns[*pattern->getRootKind()].erase(pattern);
2982
2983 // Add any invalid patterns of the parent operations to see if they have now
2984 // become legal.
2985 for (auto op : parentOps[*pattern->getRootKind()])
2986 patternWorklist.set_union(invalidPatterns[op]);
2987 }
2988}
2989
2990void OperationLegalizer::computeLegalizationGraphBenefit(
2991 LegalizationPatterns &anyOpLegalizerPatterns,
2993 // The smallest pattern depth, when legalizing an operation.
2994 DenseMap<OperationName, unsigned> minOpPatternDepth;
2995
2996 // For each operation that is transitively legal, compute a cost for it.
2997 for (auto &opIt : legalizerPatterns)
2998 if (!minOpPatternDepth.count(opIt.first))
2999 computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
3000 legalizerPatterns);
3001
3002 // Apply the cost model to the patterns that can match any operation. Those
3003 // with a specific operation type are already resolved when computing the op
3004 // legalization depth.
3005 if (!anyOpLegalizerPatterns.empty())
3006 applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
3007 legalizerPatterns);
3008
3009 // Apply a cost model to the pattern applicator. We order patterns first by
3010 // depth then benefit. `legalizerPatterns` contains per-op patterns by
3011 // decreasing benefit.
3012 applicator.applyCostModel([&](const Pattern &pattern) {
3013 ArrayRef<const Pattern *> orderedPatternList;
3014 if (std::optional<OperationName> rootName = pattern.getRootKind())
3015 orderedPatternList = legalizerPatterns[*rootName];
3016 else
3017 orderedPatternList = anyOpLegalizerPatterns;
3018
3019 // If the pattern is not found, then it was removed and cannot be matched.
3020 auto *it = llvm::find(orderedPatternList, &pattern);
3021 if (it == orderedPatternList.end())
3023
3024 // Patterns found earlier in the list have higher benefit.
3025 return PatternBenefit(std::distance(it, orderedPatternList.end()));
3026 });
3027}
3028
3029unsigned OperationLegalizer::computeOpLegalizationDepth(
3030 OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
3032 // Check for existing depth.
3033 auto depthIt = minOpPatternDepth.find(op);
3034 if (depthIt != minOpPatternDepth.end())
3035 return depthIt->second;
3036
3037 // If a mapping for this operation does not exist, then this operation
3038 // is always legal. Return 0 as the depth for a directly legal operation.
3039 auto opPatternsIt = legalizerPatterns.find(op);
3040 if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
3041 return 0u;
3042
3043 // Record this initial depth in case we encounter this op again when
3044 // recursively computing the depth.
3045 minOpPatternDepth.try_emplace(op, std::numeric_limits<unsigned>::max());
3046
3047 // Apply the cost model to the operation patterns, and update the minimum
3048 // depth.
3049 unsigned minDepth = applyCostModelToPatterns(
3050 opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
3051 minOpPatternDepth[op] = minDepth;
3052 return minDepth;
3053}
3054
3055unsigned OperationLegalizer::applyCostModelToPatterns(
3056 LegalizationPatterns &patterns,
3057 DenseMap<OperationName, unsigned> &minOpPatternDepth,
3059 unsigned minDepth = std::numeric_limits<unsigned>::max();
3060
3061 // Compute the depth for each pattern within the set.
3062 SmallVector<std::pair<const Pattern *, unsigned>, 4> patternsByDepth;
3063 patternsByDepth.reserve(patterns.size());
3064 for (const Pattern *pattern : patterns) {
3065 unsigned depth = 1;
3066 for (auto generatedOp : pattern->getGeneratedOps()) {
3067 unsigned generatedOpDepth = computeOpLegalizationDepth(
3068 generatedOp, minOpPatternDepth, legalizerPatterns);
3069 depth = std::max(depth, generatedOpDepth + 1);
3070 }
3071 patternsByDepth.emplace_back(pattern, depth);
3072
3073 // Update the minimum depth of the pattern list.
3074 minDepth = std::min(minDepth, depth);
3075 }
3076
3077 // If the operation only has one legalization pattern, there is no need to
3078 // sort them.
3079 if (patternsByDepth.size() == 1)
3080 return minDepth;
3081
3082 // Sort the patterns by those likely to be the most beneficial.
3083 llvm::stable_sort(patternsByDepth,
3084 [](const std::pair<const Pattern *, unsigned> &lhs,
3085 const std::pair<const Pattern *, unsigned> &rhs) {
3086 // First sort by the smaller pattern legalization
3087 // depth.
3088 if (lhs.second != rhs.second)
3089 return lhs.second < rhs.second;
3090
3091 // Then sort by the larger pattern benefit.
3092 auto lhsBenefit = lhs.first->getBenefit();
3093 auto rhsBenefit = rhs.first->getBenefit();
3094 return lhsBenefit > rhsBenefit;
3095 });
3096
3097 // Update the legalization pattern to use the new sorted list.
3098 patterns.clear();
3099 for (auto &patternIt : patternsByDepth)
3100 patterns.push_back(patternIt.first);
3101 return minDepth;
3102}
3103
3104//===----------------------------------------------------------------------===//
3105// Reconcile Unrealized Casts
3106//===----------------------------------------------------------------------===//
3107
3108/// Try to reconcile all given UnrealizedConversionCastOps and store the
3109/// left-over ops in `remainingCastOps` (if provided). See documentation in
3110/// DialectConversion.h for more details.
3111/// The `isCastOpOfInterestFn` is used to filter the cast ops to proceed: the
3112/// algorithm may visit an operand (or user) which is a cast op, but will not
3113/// try to reconcile it if not in the filtered set.
3114template <typename RangeT>
3116 RangeT castOps,
3117 function_ref<bool(UnrealizedConversionCastOp)> isCastOpOfInterestFn,
3119 // A worklist of cast ops to process.
3120 SetVector<UnrealizedConversionCastOp> worklist(llvm::from_range, castOps);
3121
3122 // Helper function that returns the unrealized_conversion_cast op that
3123 // defines all inputs of the given op (in the same order). Return "nullptr"
3124 // if there is no such op.
3125 auto getInputCast =
3126 [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
3127 if (castOp.getInputs().empty())
3128 return {};
3129 auto inputCastOp =
3130 castOp.getInputs().front().getDefiningOp<UnrealizedConversionCastOp>();
3131 if (!inputCastOp)
3132 return {};
3133 if (inputCastOp.getOutputs() != castOp.getInputs())
3134 return {};
3135 return inputCastOp;
3136 };
3137
3138 // Process ops in the worklist bottom-to-top.
3139 while (!worklist.empty()) {
3140 UnrealizedConversionCastOp castOp = worklist.pop_back_val();
3141
3142 // Traverse the chain of input cast ops to see if an op with the same
3143 // input types can be found.
3144 UnrealizedConversionCastOp nextCast = castOp;
3145 while (nextCast) {
3146 if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
3147 if (llvm::any_of(nextCast.getInputs(), [&](Value v) {
3148 return v.getDefiningOp() == castOp;
3149 })) {
3150 // Ran into a cycle.
3151 break;
3152 }
3153
3154 // Found a cast where the input types match the output types of the
3155 // matched op. We can directly use those inputs.
3156 castOp.replaceAllUsesWith(nextCast.getInputs());
3157 break;
3158 }
3159 nextCast = getInputCast(nextCast);
3160 }
3161 }
3162
3163 // A set of all alive cast ops. I.e., ops whose results are (transitively)
3164 // used by an op that is not a cast op.
3165 DenseSet<Operation *> liveOps;
3166
3167 // Helper function that marks the given op and transitively reachable input
3168 // cast ops as alive.
3169 auto markOpLive = [&](Operation *rootOp) {
3170 SmallVector<Operation *> worklist;
3171 worklist.push_back(rootOp);
3172 while (!worklist.empty()) {
3173 Operation *op = worklist.pop_back_val();
3174 if (liveOps.insert(op).second) {
3175 // Successfully inserted: process reachable input cast ops.
3176 for (Value v : op->getOperands())
3177 if (auto castOp = v.getDefiningOp<UnrealizedConversionCastOp>())
3178 if (isCastOpOfInterestFn(castOp))
3179 worklist.push_back(castOp);
3180 }
3181 }
3182 };
3183
3184 // Find all alive cast ops.
3185 for (UnrealizedConversionCastOp op : castOps) {
3186 // The op may have been marked live already as being an operand of another
3187 // live cast op.
3188 if (liveOps.contains(op.getOperation()))
3189 continue;
3190 // If any of the users is not a cast op, mark the current op (and its
3191 // input ops) as live.
3192 if (llvm::any_of(op->getUsers(), [&](Operation *user) {
3193 auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
3194 return !castOp || !isCastOpOfInterestFn(castOp);
3195 }))
3196 markOpLive(op);
3197 }
3198
3199 // Erase all dead cast ops.
3200 for (UnrealizedConversionCastOp op : castOps) {
3201 if (liveOps.contains(op)) {
3202 // Op is alive and was not erased. Add it to the remaining cast ops.
3203 if (remainingCastOps)
3204 remainingCastOps->push_back(op);
3205 continue;
3206 }
3207
3208 // Op is dead. Erase it.
3209 op->dropAllUses();
3210 op->erase();
3211 }
3212}
3213
3215 ArrayRef<UnrealizedConversionCastOp> castOps,
3216 SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3217 // Set of all cast ops for faster lookups.
3218 DenseSet<UnrealizedConversionCastOp> castOpSet;
3219 for (UnrealizedConversionCastOp op : castOps)
3220 castOpSet.insert(op);
3221 reconcileUnrealizedCasts(castOpSet, remainingCastOps);
3222}
3223
3225 const DenseSet<UnrealizedConversionCastOp> &castOps,
3226 SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3228 llvm::make_range(castOps.begin(), castOps.end()),
3229 [&](UnrealizedConversionCastOp castOp) {
3230 return castOps.contains(castOp);
3231 },
3232 remainingCastOps);
3233}
3234
3235namespace mlir {
3238 &castOps,
3241 castOps.keys(),
3242 [&](UnrealizedConversionCastOp castOp) {
3243 return castOps.contains(castOp);
3244 },
3245 remainingCastOps);
3246}
3247} // namespace mlir
3248
3249//===----------------------------------------------------------------------===//
3250// OperationConverter
3251//===----------------------------------------------------------------------===//
3252
3253namespace mlir {
3254// This class converts operations to a given conversion target via a set of
3255// rewrite patterns. The conversion behaves differently depending on the
3256// conversion mode.
3260 const ConversionConfig &config,
3261 OpConversionMode mode)
3262 : rewriter(ctx, config, *this), opLegalizer(rewriter, target, patterns),
3263 mode(mode) {}
3264
3265 /// Applies the conversion to the given operations (and their nested
3266 /// operations).
3267 LogicalResult applyConversion(ArrayRef<Operation *> ops);
3268
3269 /// Legalizes the given operations (and their nested operations) to the
3270 /// conversion target.
3271 template <typename Fn>
3272 LogicalResult legalizeOperations(ArrayRef<Operation *> ops, Fn onFailure,
3273 bool isRecursiveLegalization = false);
3275 bool isRecursiveLegalization = false) {
3276 return legalizeOperations(
3277 ops, /*onFailure=*/[&]() {}, isRecursiveLegalization);
3278 }
3279
3280 /// Converts a single operation. If `isRecursiveLegalization` is "true", the
3281 /// conversion is a recursive legalization request, triggered from within a
3282 /// pattern. In that case, do not emit errors because there will be another
3283 /// attempt at legalizing the operation later (via the regular pre-order
3284 /// legalization mechanism).
3285 LogicalResult convert(Operation *op, bool isRecursiveLegalization = false);
3286
3287 const ConversionTarget &getTarget() { return opLegalizer.getTarget(); }
3288
3289private:
3290 /// The rewriter to use when converting operations.
3291 ConversionPatternRewriter rewriter;
3292
3293 /// The legalizer to use when converting operations.
3294 OperationLegalizer opLegalizer;
3295
3296 /// The conversion mode to use when legalizing operations.
3297 OpConversionMode mode;
3298};
3299} // namespace mlir
3300
3302 bool isRecursiveLegalization) {
3303 const ConversionConfig &config = rewriter.getConfig();
3304
3305 // Legalize the given operation.
3306 if (failed(opLegalizer.legalize(op))) {
3307 // Handle the case of a failed conversion for each of the different modes.
3308 // Full conversions expect all operations to be converted.
3309 if (mode == OpConversionMode::Full) {
3310 if (!isRecursiveLegalization)
3311 op->emitError() << "failed to legalize operation '" << op->getName()
3312 << "'";
3313 return failure();
3314 }
3315 // Partial conversions allow conversions to fail iff the operation was not
3316 // explicitly marked as illegal. If the user provided a `unlegalizedOps`
3317 // set, non-legalizable ops are added to that set.
3318 if (mode == OpConversionMode::Partial) {
3319 if (opLegalizer.isIllegal(op)) {
3320 if (!isRecursiveLegalization)
3321 op->emitError() << "failed to legalize operation '" << op->getName()
3322 << "' that was explicitly marked illegal";
3323 return failure();
3324 }
3325 if (config.unlegalizedOps && !isRecursiveLegalization)
3326 config.unlegalizedOps->insert(op);
3327 }
3328 } else if (mode == OpConversionMode::Analysis) {
3329 // Analysis conversions don't fail if any operations fail to legalize,
3330 // they are only interested in the operations that were successfully
3331 // legalized.
3332 if (config.legalizableOps && !isRecursiveLegalization)
3333 config.legalizableOps->insert(op);
3334 }
3335 return success();
3336}
3337
3338static LogicalResult
3340 UnrealizedConversionCastOp op,
3341 const UnresolvedMaterializationInfo &info) {
3342 assert(!op.use_empty() &&
3343 "expected that dead materializations have already been DCE'd");
3344 Operation::operand_range inputOperands = op.getOperands();
3345
3346 // Try to materialize the conversion.
3347 if (const TypeConverter *converter = info.getConverter()) {
3348 rewriter.setInsertionPoint(op);
3349 SmallVector<Value> newMaterialization;
3350 switch (info.getMaterializationKind()) {
3351 case MaterializationKind::Target:
3352 newMaterialization = converter->materializeTargetConversion(
3353 rewriter, op->getLoc(), op.getResultTypes(), inputOperands,
3354 info.getOriginalType());
3355 break;
3356 case MaterializationKind::Source:
3357 assert(op->getNumResults() == 1 && "expected single result");
3358 Value sourceMat = converter->materializeSourceConversion(
3359 rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands);
3360 if (sourceMat)
3361 newMaterialization.push_back(sourceMat);
3362 break;
3363 }
3364 if (!newMaterialization.empty()) {
3365#ifndef NDEBUG
3366 ValueRange newMaterializationRange(newMaterialization);
3367 assert(TypeRange(newMaterializationRange) == op.getResultTypes() &&
3368 "materialization callback produced value of incorrect type");
3369#endif // NDEBUG
3370 rewriter.replaceOp(op, newMaterialization);
3371 return success();
3372 }
3373 }
3374
3375 InFlightDiagnostic diag = op->emitError()
3376 << "failed to legalize unresolved materialization "
3377 "from ("
3378 << inputOperands.getTypes() << ") to ("
3379 << op.getResultTypes()
3380 << ") that remained live after conversion";
3381 diag.attachNote(op->getUsers().begin()->getLoc())
3382 << "see existing live user here: " << *op->getUsers().begin();
3383 return failure();
3384}
3385
/// Collect the operations to convert and convert them one by one. On the first
/// operation that fails to convert, `onFailure` is invoked and failure is
/// returned; success is returned only when every collected operation converts.
///
/// NOTE(review): the listing numbers embedded in this dump jump over 3388 and
/// 3395. The line that names this function and declares its leading parameters
/// (presumably `ops` and `onFailure`, which the body uses) and the line that
/// opens the walk whose callback follows are therefore not visible here --
/// TODO confirm against the complete file before editing.
3386template <typename Fn>
3387LogicalResult
3389 bool isRecursiveLegalization) {
3390 const ConversionTarget &target = opLegalizer.getTarget();
3391
3392 // Compute the set of operations and blocks to convert.
3393 SmallVector<Operation *> toConvert;
3394 for (Operation *op : ops) {
3396 [&](Operation *op) {
// Every visited operation is recorded, including recursively legal ones;
// only their children are excluded (see below).
3397 toConvert.push_back(op);
3398 // Don't check this operation's children for conversion if the
3399 // operation is recursively legal.
3400 auto legalityInfo = target.isLegal(op);
3401 if (legalityInfo && legalityInfo->isRecursivelyLegal)
3402 return WalkResult::skip();
3403 return WalkResult::advance();
3404 });
3405 }
// Convert in collection order and stop at the first failure.
3406 for (Operation *op : toConvert) {
3407 if (failed(convert(op, isRecursiveLegalization))) {
3408 // Failed to convert an operation.
3409 onFailure();
3410 return failure();
3411 }
3412 }
3413 return success();
3414}
3415
3416LogicalResult ConversionPatternRewriter::legalize(Operation *op) {
3417 return impl->opConverter.legalizeOperations(op,
3418 /*isRecursiveLegalization=*/true);
3419}
3420
3421LogicalResult ConversionPatternRewriter::legalize(Region *r) {
3422 // Fast path: If the region is empty, there is nothing to legalize.
3423 if (r->empty())
3424 return success();
3425
3426 // Gather a list of all operations to legalize. This is done before
3427 // converting the entry block signature because unrealized_conversion_cast
3428 // ops should not be included.
3430 for (Block &b : *r)
3431 for (Operation &op : b)
3432 ops.push_back(&op);
3433
3434 // If the current pattern runs with a type converter, convert the entry block
3435 // signature.
3436 if (const TypeConverter *converter = impl->currentTypeConverter) {
3437 std::optional<TypeConverter::SignatureConversion> conversion =
3438 converter->convertBlockSignature(&r->front());
3439 if (!conversion)
3440 return failure();
3441 applySignatureConversion(&r->front(), *conversion, converter);
3442 }
3443
3444 // Legalize all operations in the region. This includes all nested
3445 // operations.
3446 return impl->opConverter.legalizeOperations(ops,
3447 /*isRecursiveLegalization=*/true);
3448}
3449
// NOTE(review): the listing numbers embedded in this dump jump over 3450, 3473
// and 3475. The signature of the enclosing function and the lines that declare
// the type of `materializations` and the `remainingCastOps` container are
// therefore not visible here; the body below uses `ops`, `rewriter` and
// `remainingCastOps` without a visible declaration -- TODO confirm against the
// complete file before editing.
//
// Overall flow visible below: legalize `ops` (undoing or applying the pending
// rewrites on failure, depending on `allowPatternRollback`), apply the
// rewrites, reconcile the unresolved materializations, strip the marker
// attribute from the casts that remain, and finally try to legalize those
// remaining casts if materializations are to be built.
3451 // Convert each operation and discard rewrites on failure.
3452 ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
3453 LogicalResult status = legalizeOperations(ops, /*onFailure=*/[&]() {
3454 // Dialect conversion failed.
3455 if (rewriterImpl.config.allowPatternRollback) {
3456 // Rollback is allowed: restore the original IR.
3457 rewriterImpl.undoRewrites();
3458 } else {
3459 // Rollback is not allowed: apply all modifications that have been
3460 // performed so far.
3461 rewriterImpl.applyRewrites();
3462 }
3463 });
3464 if (failed(status))
3465 return failure();
3466
3467 // After a successful conversion, apply rewrites.
3468 rewriterImpl.applyRewrites();
3469
3470 // Reconcile all UnrealizedConversionCastOps that were inserted by the
3471 // dialect conversion framework. (Not the ones that were inserted by
3472 // patterns.)
3474 &materializations = rewriterImpl.unresolvedMaterializations;
3476 reconcileUnrealizedCasts(materializations, &remainingCastOps);
3477
3478 // Drop markers.
3479 for (UnrealizedConversionCastOp castOp : remainingCastOps)
3480 castOp->removeAttr(kPureTypeConversionMarker);
3481
3482 // Try to legalize all unresolved materializations.
3483 if (rewriter.getConfig().buildMaterializations) {
3484 // Use a new rewriter, so the modifications are not tracked for rollback
3485 // purposes etc.
3486 IRRewriter irRewriter(rewriterImpl.rewriter.getContext(),
3487 rewriter.getConfig().listener);
3488 for (UnrealizedConversionCastOp castOp : remainingCastOps) {
// Every remaining cast is expected to have an entry in `materializations`;
// the assert below guards that invariant.
3489 auto it = materializations.find(castOp);
3490 assert(it != materializations.end() && "inconsistent state");
3491 if (failed(legalizeUnresolvedMaterialization(irRewriter, castOp,
3492 it->second)))
3493 return failure();
3494 }
3495 }
3496
3497 return success();
3498}
3499
3500//===----------------------------------------------------------------------===//
3501// Type Conversion
3502//===----------------------------------------------------------------------===//
3503
3504void TypeConverter::SignatureConversion::addInputs(unsigned origInputNo,
3505 ArrayRef<Type> types) {
3506 assert(!types.empty() && "expected valid types");
3507 remapInput(origInputNo, /*newInputNo=*/argTypes.size(), types.size());
3508 addInputs(types);
3509}
3510
3511void TypeConverter::SignatureConversion::addInputs(ArrayRef<Type> types) {
3512 assert(!types.empty() &&
3513 "1->0 type remappings don't need to be added explicitly");
3514 argTypes.append(types.begin(), types.end());
3515}
3516
3517void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
3518 unsigned newInputNo,
3519 unsigned newInputCount) {
3520 assert(!remappedInputs[origInputNo] && "input has already been remapped");
3521 assert(newInputCount != 0 && "expected valid input count");
3522 remappedInputs[origInputNo] =
3523 InputMapping{newInputNo, newInputCount, /*replacementValues=*/{}};
3524}
3525
3526void TypeConverter::SignatureConversion::remapInput(
3527 unsigned origInputNo, ArrayRef<Value> replacements) {
3528 assert(!remappedInputs[origInputNo] && "input has already been remapped");
3529 remappedInputs[origInputNo] = InputMapping{
3530 origInputNo, /*size=*/0,
3531 SmallVector<Value, 1>(replacements.begin(), replacements.end())};
3532}
3533
3534/// Internal implementation of the type conversion.
3535/// This is used with either a Type or a Value as the first argument.
3536/// - we can cache the context-free conversions until the last registered
3537/// context-aware conversion.
3538/// - we can't cache the result of type conversion happening after context-aware
3539/// conversions, because the type converter may return different results for the
3540/// same input type.
///
/// Converted types are appended to `results`. A cached null entry in
/// `cachedDirectConversions` stands for a cached conversion failure.
///
/// NOTE(review): the listing numbers embedded in this dump jump over 3550 and
/// 3589, i.e. one line directly before each `lock()` call below is not visible
/// here. Both locks are constructed with `std::defer_lock`, so the missing
/// lines presumably make the locking conditional -- TODO confirm against the
/// complete file before editing.
3541LogicalResult
3542TypeConverter::convertTypeImpl(PointerUnion<Type, Value> typeOrValue,
3543 SmallVectorImpl<Type> &results) const {
3544 assert(typeOrValue && "expected non-null type");
// The caches are keyed by type, so a Value is reduced to its type here.
3545 Type t = (isa<Value>(typeOrValue)) ? cast<Value>(typeOrValue).getType()
3546 : cast<Type>(typeOrValue);
// Cache lookup; the read lock is released at the end of this scope.
3547 {
3548 std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
3549 std::defer_lock);
3551 cacheReadLock.lock();
3552 auto existingIt = cachedDirectConversions.find(t);
3553 if (existingIt != cachedDirectConversions.end()) {
3554 if (existingIt->second)
3555 results.push_back(existingIt->second);
3556 return success(existingIt->second != nullptr);
3557 }
3558 auto multiIt = cachedMultiConversions.find(t);
3559 if (multiIt != cachedMultiConversions.end()) {
3560 results.append(multiIt->second.begin(), multiIt->second.end());
3561 return success();
3562 }
3563 }
3564 // Walk the added converters in reverse order to apply the most recently
3565 // registered first.
3566 size_t currentCount = results.size();
3567
3568 // We can cache the context-free conversions until the last registered
3569 // context-aware conversion. But only if we're processing a Value right now.
// NOTE(review): the predicate below only compares indices; no Type-vs-Value
// check is visible in it -- confirm that the comment above is still accurate.
// `index` counts positions in the reversed conversion list.
3570 auto isCacheable = [&](int index) {
3571 int numberOfConversionsUntilContextAware =
3572 conversions.size() - 1 - contextAwareTypeConversionsIndex;
3573 return index < numberOfConversionsUntilContextAware;
3574 };
3575
3576 std::unique_lock<decltype(cacheMutex)> cacheWriteLock(cacheMutex,
3577 std::defer_lock);
3578
3579 for (auto indexedConverter : llvm::enumerate(llvm::reverse(conversions))) {
3580 const ConversionCallbackFn &converter = indexedConverter.value();
3581 std::optional<LogicalResult> result = converter(typeOrValue, results);
// No value: this callback does not handle the input; try the next one.
3582 if (!result) {
3583 assert(results.size() == currentCount &&
3584 "failed type conversion should not change results");
3585 continue;
3586 }
// NOTE(review): on the non-cacheable path success() is returned without
// inspecting `*result`, so a callback that reported failure() here is
// surfaced as success -- confirm whether `*result` should be returned.
3587 if (!isCacheable(indexedConverter.index()))
3588 return success();
3590 cacheWriteLock.lock();
3591 if (!succeeded(*result)) {
3592 assert(results.size() == currentCount &&
3593 "failed type conversion should not change results");
// Cache the failure as a null direct conversion.
3594 cachedDirectConversions.try_emplace(t, nullptr);
3595 return failure();
3596 }
// Only the types appended by this callback are cached: exactly one type
// goes into the direct cache, any other count into the multi cache.
3597 auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
3598 if (newTypes.size() == 1)
3599 cachedDirectConversions.try_emplace(t, newTypes.front());
3600 else
3601 cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
3602 return success();
3603 }
// No registered conversion handled the input.
3604 return failure();
3605}
3606
3607LogicalResult TypeConverter::convertType(Type t,
3608 SmallVectorImpl<Type> &results) const {
3609 return convertTypeImpl(t, results);
3610}
3611
3612LogicalResult TypeConverter::convertType(Value v,
3613 SmallVectorImpl<Type> &results) const {
3614 return convertTypeImpl(v, results);
3615}
3616
3617Type TypeConverter::convertType(Type t) const {
3618 // Use the multi-type result version to convert the type.
3619 SmallVector<Type, 1> results;
3620 if (failed(convertType(t, results)))
3621 return nullptr;
3622
3623 // Check to ensure that only one type was produced.
3624 return results.size() == 1 ? results.front() : nullptr;
3625}
3626
3627Type TypeConverter::convertType(Value v) const {
3628 // Use the multi-type result version to convert the type.
3629 SmallVector<Type, 1> results;
3630 if (failed(convertType(v, results)))
3631 return nullptr;
3632
3633 // Check to ensure that only one type was produced.
3634 return results.size() == 1 ? results.front() : nullptr;
3635}
3636
3637LogicalResult
3638TypeConverter::convertTypes(TypeRange types,
3639 SmallVectorImpl<Type> &results) const {
3640 for (Type type : types)
3641 if (failed(convertType(type, results)))
3642 return failure();
3643 return success();
3644}
3645
3646LogicalResult
3647TypeConverter::convertTypes(ValueRange values,
3648 SmallVectorImpl<Type> &results) const {
3649 for (Value value : values)
3650 if (failed(convertType(value, results)))
3651 return failure();
3652 return success();
3653}
3654
3655bool TypeConverter::isLegal(Type type) const {
3656 return convertType(type) == type;
3657}
3658
3659bool TypeConverter::isLegal(Value value) const {
3660 return convertType(value) == value.getType();
3661}
3662
3663bool TypeConverter::isLegal(Operation *op) const {
3664 return isLegal(op->getOperands()) && isLegal(op->getResults());
3665}
3666
3667bool TypeConverter::isLegal(Region *region) const {
3668 return llvm::all_of(
3669 *region, [this](Block &block) { return isLegal(block.getArguments()); });
3670}
3671
3672bool TypeConverter::isSignatureLegal(FunctionType ty) const {
3673 if (!isLegal(ty.getInputs()))
3674 return false;
3675 if (!isLegal(ty.getResults()))
3676 return false;
3677 return true;
3678}
3679
3680LogicalResult
3681TypeConverter::convertSignatureArg(unsigned inputNo, Type type,
3682 SignatureConversion &result) const {
3683 // Try to convert the given input type.
3684 SmallVector<Type, 1> convertedTypes;
3685 if (failed(convertType(type, convertedTypes)))
3686 return failure();
3687
3688 // If this argument is being dropped, there is nothing left to do.
3689 if (convertedTypes.empty())
3690 return success();
3691
3692 // Otherwise, add the new inputs.
3693 result.addInputs(inputNo, convertedTypes);
3694 return success();
3695}
3696LogicalResult
3697TypeConverter::convertSignatureArgs(TypeRange types,
3698 SignatureConversion &result,
3699 unsigned origInputOffset) const {
3700 for (unsigned i = 0, e = types.size(); i != e; ++i)
3701 if (failed(convertSignatureArg(origInputOffset + i, types[i], result)))
3702 return failure();
3703 return success();
3704}
3705LogicalResult
3706TypeConverter::convertSignatureArg(unsigned inputNo, Value value,
3707 SignatureConversion &result) const {
3708 // Try to convert the given input type.
3709 SmallVector<Type, 1> convertedTypes;
3710 if (failed(convertType(value, convertedTypes)))
3711 return failure();
3712
3713 // If this argument is being dropped, there is nothing left to do.
3714 if (convertedTypes.empty())
3715 return success();
3716
3717 // Otherwise, add the new inputs.
3718 result.addInputs(inputNo, convertedTypes);
3719 return success();
3720}
3721LogicalResult
3722TypeConverter::convertSignatureArgs(ValueRange values,
3723 SignatureConversion &result,
3724 unsigned origInputOffset) const {
3725 for (unsigned i = 0, e = values.size(); i != e; ++i)
3726 if (failed(convertSignatureArg(origInputOffset + i, values[i], result)))
3727 return failure();
3728 return success();
3729}
3730
3731Value TypeConverter::materializeSourceConversion(OpBuilder &builder,
3732 Location loc, Type resultType,
3733 ValueRange inputs) const {
3734 for (const SourceMaterializationCallbackFn &fn :
3735 llvm::reverse(sourceMaterializations))
3736 if (Value result = fn(builder, resultType, inputs, loc))
3737 return result;
3738 return nullptr;
3739}
3740
3741Value TypeConverter::materializeTargetConversion(OpBuilder &builder,
3742 Location loc, Type resultType,
3743 ValueRange inputs,
3744 Type originalType) const {
3745 SmallVector<Value> result = materializeTargetConversion(
3746 builder, loc, TypeRange(resultType), inputs, originalType);
3747 if (result.empty())
3748 return nullptr;
3749 assert(result.size() == 1 && "expected single result");
3750 return result.front();
3751}
3752
3753SmallVector<Value> TypeConverter::materializeTargetConversion(
3754 OpBuilder &builder, Location loc, TypeRange resultTypes, ValueRange inputs,
3755 Type originalType) const {
3756 for (const TargetMaterializationCallbackFn &fn :
3757 llvm::reverse(targetMaterializations)) {
3758 SmallVector<Value> result =
3759 fn(builder, resultTypes, inputs, loc, originalType);
3760 if (result.empty())
3761 continue;
3762 assert(TypeRange(ValueRange(result)) == resultTypes &&
3763 "callback produced incorrect number of values or values with "
3764 "incorrect types");
3765 return result;
3766 }
3767 return {};
3768}
3769
3770std::optional<TypeConverter::SignatureConversion>
3771TypeConverter::convertBlockSignature(Block *block) const {
3772 SignatureConversion conversion(block->getNumArguments());
3773 if (failed(convertSignatureArgs(block->getArguments(), conversion)))
3774 return std::nullopt;
3775 return conversion;
3776}
3777
3778//===----------------------------------------------------------------------===//
3779// Type attribute conversion
3780//===----------------------------------------------------------------------===//
3781TypeConverter::AttributeConversionResult
3782TypeConverter::AttributeConversionResult::result(Attribute attr) {
3783 return AttributeConversionResult(attr, resultTag);
3784}
3785
3786TypeConverter::AttributeConversionResult
3787TypeConverter::AttributeConversionResult::na() {
3788 return AttributeConversionResult(nullptr, naTag);
3789}
3790
3791TypeConverter::AttributeConversionResult
3792TypeConverter::AttributeConversionResult::abort() {
3793 return AttributeConversionResult(nullptr, abortTag);
3794}
3795
3796bool TypeConverter::AttributeConversionResult::hasResult() const {
3797 return impl.getInt() == resultTag;
3798}
3799
3800bool TypeConverter::AttributeConversionResult::isNa() const {
3801 return impl.getInt() == naTag;
3802}
3803
3804bool TypeConverter::AttributeConversionResult::isAbort() const {
3805 return impl.getInt() == abortTag;
3806}
3807
3808Attribute TypeConverter::AttributeConversionResult::getResult() const {
3809 assert(hasResult() && "Cannot get result from N/A or abort");
3810 return impl.getPointer();
3811}
3812
3813std::optional<Attribute>
3814TypeConverter::convertTypeAttribute(Type type, Attribute attr) const {
3815 for (const TypeAttributeConversionCallbackFn &fn :
3816 llvm::reverse(typeAttributeConversions)) {
3817 AttributeConversionResult res = fn(type, attr);
3818 if (res.hasResult())
3819 return res.getResult();
3820 if (res.isAbort())
3821 return std::nullopt;
3822 }
3823 return std::nullopt;
3824}
3825
3826//===----------------------------------------------------------------------===//
3827// FunctionOpInterfaceSignatureConversion
3828//===----------------------------------------------------------------------===//
3829
3830static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
3831 const TypeConverter &typeConverter,
3832 ConversionPatternRewriter &rewriter) {
3833 FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType());
3834 if (!type)
3835 return failure();
3836
3837 // Convert the original function types.
3838 TypeConverter::SignatureConversion result(type.getNumInputs());
3839 SmallVector<Type, 1> newResults;
3840 if (failed(typeConverter.convertSignatureArgs(type.getInputs(), result)) ||
3841 failed(typeConverter.convertTypes(type.getResults(), newResults)))
3842 return failure();
3843 if (!funcOp.getFunctionBody().empty())
3844 rewriter.applySignatureConversion(&funcOp.getFunctionBody().front(), result,
3845 &typeConverter);
3846
3847 // Update the function signature in-place.
3848 auto newType = FunctionType::get(rewriter.getContext(),
3849 result.getConvertedTypes(), newResults);
3850
3851 rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setType(newType); });
3852
3853 return success();
3854}
3855
3856/// Create a default conversion pattern that rewrites the type signature of a
3857/// FunctionOpInterface op. This only supports ops which use FunctionType to
3858/// represent their type.
3859namespace {
3860struct FunctionOpInterfaceSignatureConversion : public ConversionPattern {
3861 FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
3862 MLIRContext *ctx,
3863 const TypeConverter &converter,
3864 PatternBenefit benefit)
3865 : ConversionPattern(converter, functionLikeOpName, benefit, ctx) {}
3866
3867 LogicalResult
3868 matchAndRewrite(Operation *op, ArrayRef<Value> /*operands*/,
3869 ConversionPatternRewriter &rewriter) const override {
3870 FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
3871 return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
3872 }
3873};
3874
3875struct AnyFunctionOpInterfaceSignatureConversion
3876 : public OpInterfaceConversionPattern<FunctionOpInterface> {
3877 using OpInterfaceConversionPattern::OpInterfaceConversionPattern;
3878
3879 LogicalResult
3880 matchAndRewrite(FunctionOpInterface funcOp, ArrayRef<Value> /*operands*/,
3881 ConversionPatternRewriter &rewriter) const override {
3882 return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
3883 }
3884};
3885} // namespace
3886
3887FailureOr<Operation *>
3888mlir::convertOpResultTypes(Operation *op, ValueRange operands,
3889 const TypeConverter &converter,
3890 ConversionPatternRewriter &rewriter) {
3891 assert(op && "Invalid op");
3892 Location loc = op->getLoc();
3893 if (converter.isLegal(op))
3894 return rewriter.notifyMatchFailure(loc, "op already legal");
3895
3896 OperationState newOp(loc, op->getName());
3897 newOp.addOperands(operands);
3898
3899 SmallVector<Type> newResultTypes;
3900 if (failed(converter.convertTypes(op->getResults(), newResultTypes)))
3901 return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
3902
3903 newOp.addTypes(newResultTypes);
3904 newOp.addAttributes(op->getAttrs());
3905 return rewriter.create(newOp);
3906}
3907
3908void mlir::populateFunctionOpInterfaceTypeConversionPattern(
3909 StringRef functionLikeOpName, RewritePatternSet &patterns,
3910 const TypeConverter &converter, PatternBenefit benefit) {
3911 patterns.add<FunctionOpInterfaceSignatureConversion>(
3912 functionLikeOpName, patterns.getContext(), converter, benefit);
3913}
3914
3915void mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(
3916 RewritePatternSet &patterns, const TypeConverter &converter,
3917 PatternBenefit benefit) {
3918 patterns.add<AnyFunctionOpInterfaceSignatureConversion>(
3919 converter, patterns.getContext(), benefit);
3920}
3921
3922//===----------------------------------------------------------------------===//
3923// ConversionTarget
3924//===----------------------------------------------------------------------===//
3925
3926void ConversionTarget::setOpAction(OperationName op,
3927 LegalizationAction action) {
3928 legalOperations[op].action = action;
3929}
3930
3931void ConversionTarget::setDialectAction(ArrayRef<StringRef> dialectNames,
3932 LegalizationAction action) {
3933 for (StringRef dialect : dialectNames)
3934 legalDialects[dialect] = action;
3935}
3936
3937auto ConversionTarget::getOpAction(OperationName op) const
3938 -> std::optional<LegalizationAction> {
3939 std::optional<LegalizationInfo> info = getOpInfo(op);
3940 return info ? info->action : std::optional<LegalizationAction>();
3941}
3942
3943auto ConversionTarget::isLegal(Operation *op) const
3944 -> std::optional<LegalOpDetails> {
3945 std::optional<LegalizationInfo> info = getOpInfo(op->getName());
3946 if (!info)
3947 return std::nullopt;
3948
3949 // Returns true if this operation instance is known to be legal.
3950 auto isOpLegal = [&] {
3951 // Handle dynamic legality either with the provided legality function.
3952 if (info->action == LegalizationAction::Dynamic) {
3953 std::optional<bool> result = info->legalityFn(op);
3954 if (result)
3955 return *result;
3956 }
3957
3958 // Otherwise, the operation is only legal if it was marked 'Legal'.
3959 return info->action == LegalizationAction::Legal;
3960 };
3961 if (!isOpLegal())
3962 return std::nullopt;
3963
3964 // This operation is legal, compute any additional legality information.
3965 LegalOpDetails legalityDetails;
3966 if (info->isRecursivelyLegal) {
3967 auto legalityFnIt = opRecursiveLegalityFns.find(op->getName());
3968 if (legalityFnIt != opRecursiveLegalityFns.end()) {
3969 legalityDetails.isRecursivelyLegal =
3970 legalityFnIt->second(op).value_or(true);
3971 } else {
3972 legalityDetails.isRecursivelyLegal = true;
3973 }
3974 }
3975 return legalityDetails;
3976}
3977
3978bool ConversionTarget::isIllegal(Operation *op) const {
3979 std::optional<LegalizationInfo> info = getOpInfo(op->getName());
3980 if (!info)
3981 return false;
3982
3983 if (info->action == LegalizationAction::Dynamic) {
3984 std::optional<bool> result = info->legalityFn(op);
3985 if (!result)
3986 return false;
3987
3988 return !(*result);
3989 }
3990
3991 return info->action == LegalizationAction::Illegal;
3992}
3993
3994static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(
3995 ConversionTarget::DynamicLegalityCallbackFn oldCallback,
3996 ConversionTarget::DynamicLegalityCallbackFn newCallback) {
3997 if (!oldCallback)
3998 return newCallback;
3999
4000 auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)](
4001 Operation *op) -> std::optional<bool> {
4002 if (std::optional<bool> result = newCl(op))
4003 return *result;
4004
4005 return oldCl(op);
4006 };
4007 return chain;
4008}
4009
4010void ConversionTarget::setLegalityCallback(
4011 OperationName name, const DynamicLegalityCallbackFn &callback) {
4012 assert(callback && "expected valid legality callback");
4013 auto *infoIt = legalOperations.find(name);
4014 assert(infoIt != legalOperations.end() &&
4015 infoIt->second.action == LegalizationAction::Dynamic &&
4016 "expected operation to already be marked as dynamically legal");
4017 infoIt->second.legalityFn =
4018 composeLegalityCallbacks(std::move(infoIt->second.legalityFn), callback);
4019}
4020
4021void ConversionTarget::markOpRecursivelyLegal(
4022 OperationName name, const DynamicLegalityCallbackFn &callback) {
4023 auto *infoIt = legalOperations.find(name);
4024 assert(infoIt != legalOperations.end() &&
4025 infoIt->second.action != LegalizationAction::Illegal &&
4026 "expected operation to already be marked as legal");
4027 infoIt->second.isRecursivelyLegal = true;
4028 if (callback)
4029 opRecursiveLegalityFns[name] = composeLegalityCallbacks(
4030 std::move(opRecursiveLegalityFns[name]), callback);
4031 else
4032 opRecursiveLegalityFns.erase(name);
4033}
4034
4035void ConversionTarget::setLegalityCallback(
4036 ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) {
4037 assert(callback && "expected valid legality callback");
4038 for (StringRef dialect : dialects)
4039 dialectLegalityFns[dialect] = composeLegalityCallbacks(
4040 std::move(dialectLegalityFns[dialect]), callback);
4041}
4042
4043void ConversionTarget::setLegalityCallback(
4044 const DynamicLegalityCallbackFn &callback) {
4045 assert(callback && "expected valid legality callback");
4046 unknownLegalityFn = composeLegalityCallbacks(unknownLegalityFn, callback);
4047}
4048
4049auto ConversionTarget::getOpInfo(OperationName op) const
4050 -> std::optional<LegalizationInfo> {
4051 // Check for info for this specific operation.
4052 const auto *it = legalOperations.find(op);
4053 if (it != legalOperations.end())
4054 return it->second;
4055 // Check for info for the parent dialect.
4056 auto dialectIt = legalDialects.find(op.getDialectNamespace());
4057 if (dialectIt != legalDialects.end()) {
4058 DynamicLegalityCallbackFn callback;
4059 auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace());
4060 if (dialectFn != dialectLegalityFns.end())
4061 callback = dialectFn->second;
4062 return LegalizationInfo{dialectIt->second, /*isRecursivelyLegal=*/false,
4063 callback};
4064 }
4065 // Otherwise, check if we mark unknown operations as dynamic.
4066 if (unknownLegalityFn)
4067 return LegalizationInfo{LegalizationAction::Dynamic,
4068 /*isRecursivelyLegal=*/false, unknownLegalityFn};
4069 return std::nullopt;
4070}
4071
4072#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
4073//===----------------------------------------------------------------------===//
4074// PDL Configuration
4075//===----------------------------------------------------------------------===//
4076
4077void PDLConversionConfig::notifyRewriteBegin(PatternRewriter &rewriter) {
4078 auto &rewriterImpl =
4079 static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
4080 rewriterImpl.currentTypeConverter = getTypeConverter();
4081}
4082
4083void PDLConversionConfig::notifyRewriteEnd(PatternRewriter &rewriter) {
4084 auto &rewriterImpl =
4085 static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
4086 rewriterImpl.currentTypeConverter = nullptr;
4087}
4088
4089/// Remap the given value using the rewriter and the type converter in the
4090/// provided config.
4091static FailureOr<SmallVector<Value>>
4092pdllConvertValues(ConversionPatternRewriter &rewriter, ValueRange values) {
4093 SmallVector<Value> mappedValues;
4094 if (failed(rewriter.getRemappedValues(values, mappedValues)))
4095 return failure();
4096 return std::move(mappedValues);
4097}
4098
4099void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) {
4100 patterns.getPDLPatterns().registerRewriteFunction(
4101 "convertValue",
4102 [](PatternRewriter &rewriter, Value value) -> FailureOr<Value> {
4103 auto results = pdllConvertValues(
4104 static_cast<ConversionPatternRewriter &>(rewriter), value);
4105 if (failed(results))
4106 return failure();
4107 return results->front();
4108 });
4109 patterns.getPDLPatterns().registerRewriteFunction(
4110 "convertValues", [](PatternRewriter &rewriter, ValueRange values) {
4111 return pdllConvertValues(
4112 static_cast<ConversionPatternRewriter &>(rewriter), values);
4113 });
4114 patterns.getPDLPatterns().registerRewriteFunction(
4115 "convertType",
4116 [](PatternRewriter &rewriter, Type type) -> FailureOr<Type> {
4117 auto &rewriterImpl =
4118 static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
4119 if (const TypeConverter *converter =
4120 rewriterImpl.currentTypeConverter) {
4121 if (Type newType = converter->convertType(type))
4122 return newType;
4123 return failure();
4124 }
4125 return type;
4126 });
4127 patterns.getPDLPatterns().registerRewriteFunction(
4128 "convertTypes",
4129 [](PatternRewriter &rewriter,
4130 TypeRange types) -> FailureOr<SmallVector<Type>> {
4131 auto &rewriterImpl =
4132 static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
4133 const TypeConverter *converter = rewriterImpl.currentTypeConverter;
4134 if (!converter)
4135 return SmallVector<Type>(types);
4136
4137 SmallVector<Type> remappedTypes;
4138 if (failed(converter->convertTypes(types, remappedTypes)))
4139 return failure();
4140 return std::move(remappedTypes);
4141 });
4142}
4143#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
4144
4145//===----------------------------------------------------------------------===//
4146// Op Conversion Entry Points
4147//===----------------------------------------------------------------------===//
4148
4149/// This is the type of Action that is dispatched when a conversion is applied.
4151 : public tracing::ActionImpl<ApplyConversionAction> {
4152public:
4155 static constexpr StringLiteral tag = "apply-conversion";
4156 static constexpr StringLiteral desc =
4157 "Encapsulate the application of a dialect conversion";
4158
4159 void print(raw_ostream &os) const override { os << tag; }
4160};
4161
4163 const ConversionTarget &target,
4165 ConversionConfig config,
4166 OpConversionMode mode) {
4167 if (ops.empty())
4168 return success();
4169 MLIRContext *ctx = ops.front()->getContext();
4170 LogicalResult status = success();
4171 SmallVector<IRUnit> irUnits(ops.begin(), ops.end());
4173 [&] {
4174 OperationConverter opConverter(ops.front()->getContext(), target,
4175 patterns, config, mode);
4176 status = opConverter.applyConversion(ops);
4177 },
4178 irUnits);
4179 return status;
4180}
4181
4182//===----------------------------------------------------------------------===//
4183// Partial Conversion
4184//===----------------------------------------------------------------------===//
4185
4186LogicalResult mlir::applyPartialConversion(
4187 ArrayRef<Operation *> ops, const ConversionTarget &target,
4188 const FrozenRewritePatternSet &patterns, ConversionConfig config) {
4189 return applyConversion(ops, target, patterns, config,
4190 OpConversionMode::Partial);
4191}
4192LogicalResult
4193mlir::applyPartialConversion(Operation *op, const ConversionTarget &target,
4194 const FrozenRewritePatternSet &patterns,
4195 ConversionConfig config) {
4196 return applyPartialConversion(llvm::ArrayRef(op), target, patterns, config);
4197}
4198
4199//===----------------------------------------------------------------------===//
4200// Full Conversion
4201//===----------------------------------------------------------------------===//
4202
4203LogicalResult mlir::applyFullConversion(ArrayRef<Operation *> ops,
4204 const ConversionTarget &target,
4205 const FrozenRewritePatternSet &patterns,
4206 ConversionConfig config) {
4207 return applyConversion(ops, target, patterns, config, OpConversionMode::Full);
4208}
4209LogicalResult mlir::applyFullConversion(Operation *op,
4210 const ConversionTarget &target,
4211 const FrozenRewritePatternSet &patterns,
4212 ConversionConfig config) {
4213 return applyFullConversion(llvm::ArrayRef(op), target, patterns, config);
4214}
4215
4216//===----------------------------------------------------------------------===//
4217// Analysis Conversion
4218//===----------------------------------------------------------------------===//
4219
4220/// Find a common IsolatedFromAbove ancestor of the given ops. If at least one
4221/// op is a top-level module op (which is expected to be isolated from above),
4222/// return that op.
4224 // Check if there is a top-level operation within `ops`. If so, return that
4225 // op.
4226 for (Operation *op : ops) {
4227 if (!op->getParentOp()) {
4228#ifndef NDEBUG
4229 assert(op->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
4230 "expected top-level op to be isolated from above");
4231 for (Operation *other : ops)
4232 assert(op->isAncestor(other) &&
4233 "expected ops to have a common ancestor");
4234#endif // NDEBUG
4235 return op;
4236 }
4237 }
4238
4239 // No top-level op. Find a common ancestor.
4240 Operation *commonAncestor =
4241 ops.front()->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
4242 for (Operation *op : ops.drop_front()) {
4243 while (!commonAncestor->isProperAncestor(op)) {
4244 commonAncestor =
4246 assert(commonAncestor &&
4247 "expected to find a common isolated from above ancestor");
4248 }
4249 }
4250
4251 return commonAncestor;
4252}
4253
4254LogicalResult mlir::applyAnalysisConversion(
4255 ArrayRef<Operation *> ops, ConversionTarget &target,
4256 const FrozenRewritePatternSet &patterns, ConversionConfig config) {
4257#ifndef NDEBUG
4258 if (config.legalizableOps)
4259 assert(config.legalizableOps->empty() && "expected empty set");
4260#endif // NDEBUG
4261
4262 // Clone closted common ancestor that is isolated from above.
4263 Operation *commonAncestor = findCommonAncestor(ops);
4264 IRMapping mapping;
4265 Operation *clonedAncestor = commonAncestor->clone(mapping);
4266 // Compute inverse IR mapping.
4267 DenseMap<Operation *, Operation *> inverseOperationMap;
4268 for (auto &it : mapping.getOperationMap())
4269 inverseOperationMap[it.second] = it.first;
4270
4271 // Convert the cloned operations. The original IR will remain unchanged.
4272 SmallVector<Operation *> opsToConvert = llvm::map_to_vector(
4273 ops, [&](Operation *op) { return mapping.lookup(op); });
4274 LogicalResult status = applyConversion(opsToConvert, target, patterns, config,
4275 OpConversionMode::Analysis);
4276
4277 // Remap `legalizableOps`, so that they point to the original ops and not the
4278 // cloned ops.
4279 if (config.legalizableOps) {
4280 DenseSet<Operation *> originalLegalizableOps;
4281 for (Operation *op : *config.legalizableOps)
4282 originalLegalizableOps.insert(inverseOperationMap[op]);
4283 *config.legalizableOps = std::move(originalLegalizableOps);
4284 }
4285
4286 // Erase the cloned IR.
4287 clonedAncestor->erase();
4288 return status;
4289}
4290
4291LogicalResult
4292mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target,
4293 const FrozenRewritePatternSet &patterns,
4294 ConversionConfig config) {
4295 return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns, config);
4296}
return success()
static void setInsertionPointAfter(OpBuilder &b, Value value)
static Operation * getCommonDefiningOp(const ValueVector &values)
Return the operation that defines all values in the vector.
static LogicalResult applyConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config, OpConversionMode mode)
static T moveAndReset(T &obj)
Helper function that moves and returns the given object.
static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(ConversionTarget::DynamicLegalityCallbackFn oldCallback, ConversionTarget::DynamicLegalityCallbackFn newCallback)
static bool isPureTypeConversion(const ValueVector &values)
A vector of values is a pure type conversion if all values are defined by the same operation and the ...
static LogicalResult legalizeUnresolvedMaterialization(RewriterBase &rewriter, UnrealizedConversionCastOp op, const UnresolvedMaterializationInfo &info)
static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a failure result for the given reason.
static void reconcileUnrealizedCastsImpl(RangeT castOps, function_ref< bool(UnrealizedConversionCastOp)> isCastOpOfInterestFn, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps)
Try to reconcile all given UnrealizedConversionCastOps and store the left-over ops in remainingCastOp...
static Operation * findCommonAncestor(ArrayRef< Operation * > ops)
Find a common IsolatedFromAbove ancestor of the given ops.
static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a successful result for the given reason.
SmallVector< Value, 1 > ValueVector
A vector of SSA values, optimized for the most common case of a single value.
static void performReplaceValue(RewriterBase &rewriter, Value from, Value repl, function_ref< bool(OpOperand &)> functor=nullptr)
Replace all uses of from with repl.
static void reportNewIrLegalizationFatalError(const Pattern &pattern, const SetVector< Operation * > &newOps, const SetVector< Operation * > &modifiedOps)
Report a fatal error indicating that newly produced or modified IR could not be legalized.
static OpBuilder::InsertPoint computeInsertPoint(Value value)
Helper function that computes an insertion point where the given value is defined and can be used wit...
static const StringRef kPureTypeConversionMarker
Marker attribute for pure type conversions.
static SmallVector< Value > getReplacementValues(ConversionPatternRewriterImpl &impl, ValueRange fromRange, const SmallVector< SmallVector< Value > > &toRange, const TypeConverter *converter)
Given that fromRange is about to be replaced with toRange, compute replacement values with the types ...
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
#define DEBUG_TYPE
b getContext())
static std::string diag(const llvm::Value &value)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
Definition SCCP.cpp:67
This is the type of Action that is dispatched when a conversion is applied.
tracing::ActionImpl< ApplyConversionAction > Base
static constexpr StringLiteral desc
ApplyConversionAction(ArrayRef< IRUnit > irUnits)
void print(raw_ostream &os) const override
static constexpr StringLiteral tag
This class represents an argument of a Block.
Definition Value.h:309
Location getLoc() const
Return the location for this argument.
Definition Value.h:324
Block represents an ordered list of Operations.
Definition Block.h:33
OpListType::iterator iterator
Definition Block.h:150
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition Block.cpp:154
bool empty()
Definition Block.h:158
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition Block.cpp:27
OpListType & getOperations()
Definition Block.h:147
Operation & front()
Definition Block.h:163
RetT walk(FnT &&callback)
Walk all nested operations, blocks (including this block) or regions, depending on the type of callba...
Definition Block.h:318
BlockArgListType getArguments()
Definition Block.h:97
iterator end()
Definition Block.h:154
iterator begin()
Definition Block.h:153
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
UnitAttr getUnitAttr()
Definition Builders.cpp:98
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:262
MLIRContext * context
Definition Builders.h:202
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
A class for computing basic dominance information.
Definition Dominance.h:140
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
Definition Dominance.h:158
This class represents a frozen set of patterns that can be processed by a pattern applicator.
const DenseMap< Operation *, Operation * > & getOperationMap() const
Return the held operation mapping.
Definition IRMapping.h:88
auto lookup(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:72
user_range getUsers() const
Returns a range of all users.
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
void executeAction(function_ref< void()> actionFn, const tracing::Action &action)
Dispatch the provided action to the handler if any, or just execute it.
bool isMultithreadingEnabled()
Return true if multi-threading is enabled by the context.
This class represents a saved insertion point.
Definition Builders.h:327
Block::iterator getPoint() const
Definition Builders.h:340
bool isSet() const
Returns true if this insert point is set.
Definition Builders.h:337
Block * getBlock() const
Definition Builders.h:339
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition Builders.h:320
LogicalResult tryFold(Operation *op, SmallVectorImpl< Value > &results, SmallVectorImpl< Operation * > *materializedConstants=nullptr)
Attempts to fold the given operation and places new results within results.
Definition Builders.cpp:473
Operation * insert(Operation *op)
Insert the given operation at the current insertion point and return it.
Definition Builders.cpp:421
This class represents an operand of an operation.
Definition Value.h:257
This is a value defined by a result of an operation.
Definition Value.h:457
This class provides the API for ops that are known to be isolated from above.
type_range getTypes() const
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
void destroyOpProperties(OpaqueProperties properties) const
This hooks destroy the op properties.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
Definition Operation.h:226
bool use_empty()
Returns true if this operation has no uses.
Definition Operation.h:852
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:749
void dropAllUses()
Drop all uses of results of this operation.
Definition Operation.h:834
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Definition Operation.h:560
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:512
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:213
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition Operation.h:248
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
OperandRange operand_range
Definition Operation.h:371
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:677
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
void setSuccessor(Block *block, unsigned index)
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition Operation.h:263
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:797
result_range getResults()
Definition Operation.h:415
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
Definition Operation.h:900
void erase()
Remove this operation from its parent block and delete it.
void copyProperties(OpaqueProperties rhs)
Copy properties from an existing other properties object.
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
static PatternBenefit impossibleToMatch()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e.
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
StringRef getDebugName() const
Return a readable name for this pattern.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
bool empty()
Definition Region.h:60
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition Region.h:200
BlockListType & getBlocks()
Definition Region.h:45
BlockListType::iterator iterator
Definition Region.h:52
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void cancelOpModification(Operation *op)
This method cancels a pending in-place modification.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
RewriterBase(MLIRContext *ctx, OpBuilder::Listener *listener=nullptr)
Initialize the builder.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition Value.h:208
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition Value.h:108
Type getType() const
Return the type of this value.
Definition Value.h:105
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition Value.cpp:46
user_range getUsers() const
Definition Value.h:218
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static WalkResult skip()
Definition WalkResult.h:48
static WalkResult advance()
Definition WalkResult.h:47
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
CRTP Implementation of an action.
Definition Action.h:76
ArrayRef< IRUnit > irUnits
Set of IR units (operations, regions, blocks, values) that are associated with this action.
Definition Action.h:66
AttrTypeReplacer.
Kind
An enumeration of the kinds of predicates.
Definition Predicate.h:44
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:120
static void reconcileUnrealizedCasts(const DenseMap< UnrealizedConversionCastOp, UnresolvedMaterializationInfo > &castOps, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps)
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:123
const FrozenRewritePatternSet & patterns
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:118
llvm::function_ref< Fn > function_ref
Definition LLVM.h:144
This iterator enumerates elements according to their dominance relationship.
Definition Iterators.h:48
virtual void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt)
Notify the listener that the specified block was inserted.
Definition Builders.h:308
virtual void notifyOperationInserted(Operation *op, InsertPoint previous)
Notify the listener that the specified operation was inserted.
Definition Builders.h:298
OperationConverter(MLIRContext *ctx, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, const ConversionConfig &config, OpConversionMode mode)
const ConversionTarget & getTarget()
LogicalResult legalizeOperations(ArrayRef< Operation * > ops, bool isRecursiveLegalization=false)
LogicalResult convert(Operation *op, bool isRecursiveLegalization=false)
Converts a single operation.
LogicalResult legalizeOperations(ArrayRef< Operation * > ops, Fn onFailure, bool isRecursiveLegalization=false)
Legalizes the given operations (and their nested operations) to the conversion target.
LogicalResult applyConversion(ArrayRef< Operation * > ops)
Applies the conversion to the given operations (and their nested operations).
virtual void notifyOperationModified(Operation *op)
Notify the listener that the specified operation was modified in-place.
virtual void notifyOperationErased(Operation *op)
Notify the listener that the specified operation is about to be erased.
virtual void notifyOperationReplaced(Operation *op, Operation *replacement)
Notify the listener that all uses of the specified operation's results are about to be replaced with ...
virtual void notifyBlockErased(Block *block)
Notify the listener that the specified block is about to be erased.
A rewriter that keeps track of erased ops and blocks.
SingleEraseRewriter(MLIRContext *context, std::function< void(Operation *)> opErasedCallback=nullptr)
void eraseOp(Operation *op) override
Erase the given op (unless it was already erased).
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
void eraseBlock(Block *block) override
Erase the given block (unless it was already erased).
llvm::impl::raw_ldbg_ostream os
A raw output stream used to prefix the debug log.
void notifyOperationInserted(Operation *op, OpBuilder::InsertPoint previous) override
Notify the listener that the specified operation was inserted.
DenseMap< UnrealizedConversionCastOp, UnresolvedMaterializationInfo > unresolvedMaterializations
A mapping for looking up metadata of unresolved materializations.
Value findOrBuildReplacementValue(Value value, const TypeConverter *converter)
Find a replacement value for the given SSA value in the conversion value mapping.
SetVector< Operation * > patternNewOps
A set of operations that were created by the current pattern.
void replaceValueUses(Value from, ValueRange to, const TypeConverter *converter, function_ref< bool(OpOperand &)> functor=nullptr)
Replace the uses of the given value with the given values.
DenseSet< Block * > erasedBlocks
A set of erased blocks.
DenseMap< Region *, const TypeConverter * > regionToConverter
A mapping of regions to type converters that should be used when converting the arguments of blocks w...
bool wasOpReplaced(Operation *op) const
Return "true" if the given operation was replaced or erased.
ConversionPatternRewriterImpl(ConversionPatternRewriter &rewriter, const ConversionConfig &config, OperationConverter &opConverter)
void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt) override
Notifies that a block was inserted.
void undoRewrites(unsigned numRewritesToKeep=0, StringRef patternName="")
Undo the rewrites (motions, splits) one by one in reverse order until "numRewritesToKeep" rewrites re...
LogicalResult remapValues(StringRef valueDiagTag, std::optional< Location > inputLoc, ValueRange values, SmallVector< ValueVector > &remapped)
Remap the given values to those with potentially different types.
ValueRange buildUnresolvedMaterialization(MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes, Type originalType, const TypeConverter *converter, bool isPureTypeConversion=true)
Build an unresolved materialization operation given a range of output types and a list of input opera...
DenseSet< UnrealizedConversionCastOp > patternMaterializations
A list of unresolved materializations that were created by the current pattern.
void resetState(RewriterState state, StringRef patternName="")
Reset the state of the rewriter to a previously saved point.
void applyRewrites()
Apply all requested operation rewrites.
Block * applySignatureConversion(Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion &signatureConversion)
Apply the given signature conversion on the given block.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before)
Inline the source block into the destination block before the given iterator.
void replaceOp(Operation *op, SmallVector< SmallVector< Value > > &&newValues)
Replace the results of the given operation with the given values and erase the operation.
RewriterState getCurrentState()
Return the current state of the rewriter.
llvm::ScopedPrinter logger
A logger used to emit diagnostics during the conversion process.
ValueVector lookupOrNull(Value from, TypeRange desiredTypes={}) const
Lookup the given value within the map, or return an empty vector if the value is not mapped.
void appendRewrite(Args &&...args)
Append a rewrite.
SmallPtrSet< Operation *, 1 > pendingRootUpdates
A set of operations that have pending updates.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notifies that a pattern match failed for the given reason.
ValueVector lookupOrDefault(Value from, TypeRange desiredTypes={}, bool skipPureTypeConversions=false) const
Lookup the most recently mapped values with the desired types in the mapping, taking into account onl...
bool isOpIgnored(Operation *op) const
Return "true" if the given operation is ignored, and does not need to be converted.
IRRewriter notifyingRewriter
A rewriter that notifies the listener (if any) about all IR modifications.
OperationConverter & opConverter
The operation converter to use for recursive legalization.
DenseSet< Value > replacedValues
A set of replaced values.
DenseSet< Operation * > erasedOps
A set of erased operations.
void eraseBlock(Block *block)
Erase the given block and its contents.
SetVector< Operation * > ignoredOps
A set of operations that should no longer be considered for legalization.
SmallVector< std::unique_ptr< IRRewrite > > rewrites
Ordered list of block operations (creations, splits, motions).
SetVector< Operation * > patternModifiedOps
A set of operations that were modified by the current pattern.
const ConversionConfig & config
Dialect conversion configuration.
SetVector< Operation * > replacedOps
A set of operations that were replaced/erased.
ConversionPatternRewriter & rewriter
The rewriter that is used to perform the conversion.
const TypeConverter * currentTypeConverter
The current type converter, or nullptr if no type converter is currently active.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion)
Convert the types of block arguments within the given region.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.