MLIR  17.0.0git
DialectConversion.cpp
Go to the documentation of this file.
1 //===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "mlir/IR/Block.h"
11 #include "mlir/IR/Builders.h"
12 #include "mlir/IR/BuiltinOps.h"
14 #include "mlir/IR/IRMapping.h"
15 #include "mlir/IR/Iterators.h"
17 #include "llvm/ADT/ScopeExit.h"
18 #include "llvm/ADT/SetVector.h"
19 #include "llvm/ADT/SmallPtrSet.h"
20 #include "llvm/Support/Debug.h"
21 #include "llvm/Support/FormatVariadic.h"
22 #include "llvm/Support/SaveAndRestore.h"
23 #include "llvm/Support/ScopedPrinter.h"
24 #include <optional>
25 
26 using namespace mlir;
27 using namespace mlir::detail;
28 
29 #define DEBUG_TYPE "dialect-conversion"
30 
31 /// A utility function to log a successful result for the given reason.
32 template <typename... Args>
33 static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
34  LLVM_DEBUG({
35  os.unindent();
36  os.startLine() << "} -> SUCCESS";
37  if (!fmt.empty())
38  os.getOStream() << " : "
39  << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
40  os.getOStream() << "\n";
41  });
42 }
43 
44 /// A utility function to log a failure result for the given reason.
45 template <typename... Args>
46 static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
47  LLVM_DEBUG({
48  os.unindent();
49  os.startLine() << "} -> FAILURE : "
50  << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
51  << "\n";
52  });
53 }
54 
55 //===----------------------------------------------------------------------===//
56 // ConversionValueMapping
57 //===----------------------------------------------------------------------===//
58 
59 namespace {
60 /// This class wraps a IRMapping to provide recursive lookup
61 /// functionality, i.e. we will traverse if the mapped value also has a mapping.
62 struct ConversionValueMapping {
63  /// Lookup a mapped value within the map. If a mapping for the provided value
64  /// does not exist then return the provided value. If `desiredType` is
65  /// non-null, returns the most recently mapped value with that type. If an
66  /// operand of that type does not exist, defaults to normal behavior.
67  Value lookupOrDefault(Value from, Type desiredType = nullptr) const;
68 
69  /// Lookup a mapped value within the map, or return null if a mapping does not
70  /// exist. If a mapping exists, this follows the same behavior of
71  /// `lookupOrDefault`.
72  Value lookupOrNull(Value from, Type desiredType = nullptr) const;
73 
74  /// Map a value to the one provided.
75  void map(Value oldVal, Value newVal) {
76  LLVM_DEBUG({
77  for (Value it = newVal; it; it = mapping.lookupOrNull(it))
78  assert(it != oldVal && "inserting cyclic mapping");
79  });
80  mapping.map(oldVal, newVal);
81  }
82 
83  /// Try to map a value to the one provided. Returns false if a transitive
84  /// mapping from the new value to the old value already exists, true if the
85  /// map was updated.
86  bool tryMap(Value oldVal, Value newVal);
87 
88  /// Drop the last mapping for the given value.
89  void erase(Value value) { mapping.erase(value); }
90 
91  /// Returns the inverse raw value mapping (without recursive query support).
92  DenseMap<Value, SmallVector<Value>> getInverse() const {
94  for (auto &it : mapping.getValueMap())
95  inverse[it.second].push_back(it.first);
96  return inverse;
97  }
98 
99 private:
100  /// Current value mappings.
101  IRMapping mapping;
102 };
103 } // namespace
104 
105 Value ConversionValueMapping::lookupOrDefault(Value from,
106  Type desiredType) const {
107  // If there was no desired type, simply find the leaf value.
108  if (!desiredType) {
109  // If this value had a valid mapping, unmap that value as well in the case
110  // that it was also replaced.
111  while (auto mappedValue = mapping.lookupOrNull(from))
112  from = mappedValue;
113  return from;
114  }
115 
116  // Otherwise, try to find the deepest value that has the desired type.
117  Value desiredValue;
118  do {
119  if (from.getType() == desiredType)
120  desiredValue = from;
121 
122  Value mappedValue = mapping.lookupOrNull(from);
123  if (!mappedValue)
124  break;
125  from = mappedValue;
126  } while (true);
127 
128  // If the desired value was found use it, otherwise default to the leaf value.
129  return desiredValue ? desiredValue : from;
130 }
131 
132 Value ConversionValueMapping::lookupOrNull(Value from, Type desiredType) const {
133  Value result = lookupOrDefault(from, desiredType);
134  if (result == from || (desiredType && result.getType() != desiredType))
135  return nullptr;
136  return result;
137 }
138 
139 bool ConversionValueMapping::tryMap(Value oldVal, Value newVal) {
140  for (Value it = newVal; it; it = mapping.lookupOrNull(it))
141  if (it == oldVal)
142  return false;
143  map(oldVal, newVal);
144  return true;
145 }
146 
147 //===----------------------------------------------------------------------===//
148 // Rewriter and Translation State
149 //===----------------------------------------------------------------------===//
150 namespace {
/// A snapshot of the conversion rewriter's bookkeeping, stored as the number
/// of entries in each of its tracked lists at the moment the snapshot was
/// taken. Saving a snapshot and later resetting to it is how a set of rewrites
/// is undone.
struct RewriterState {
  RewriterState(unsigned numCreatedOps, unsigned numUnresolvedMaterializations,
                unsigned numReplacements, unsigned numArgReplacements,
                unsigned numBlockActions, unsigned numIgnoredOperations,
                unsigned numRootUpdates)
      : numCreatedOps(numCreatedOps),
        numUnresolvedMaterializations(numUnresolvedMaterializations),
        numReplacements(numReplacements),
        numArgReplacements(numArgReplacements),
        numBlockActions(numBlockActions),
        numIgnoredOperations(numIgnoredOperations),
        numRootUpdates(numRootUpdates) {}

  unsigned numCreatedOps;                 ///< Operations created so far.
  unsigned numUnresolvedMaterializations; ///< Unresolved materializations.
  unsigned numReplacements;               ///< Queued operation replacements.
  unsigned numArgReplacements;            ///< Queued argument replacements.
  unsigned numBlockActions;               ///< Block actions performed.
  unsigned numIgnoredOperations;          ///< Operations marked as ignored.
  unsigned numRootUpdates;                ///< Operations updated in place.
};
187 
188 //===----------------------------------------------------------------------===//
189 // OperationTransactionState
190 
191 /// The state of an operation that was updated by a pattern in-place. This
192 /// contains all of the necessary information to reconstruct an operation that
193 /// was updated in place.
194 class OperationTransactionState {
195 public:
196  OperationTransactionState() = default;
197  OperationTransactionState(Operation *op)
198  : op(op), loc(op->getLoc()), attrs(op->getAttrDictionary()),
199  operands(op->operand_begin(), op->operand_end()),
200  successors(op->successor_begin(), op->successor_end()) {}
201 
202  /// Discard the transaction state and reset the state of the original
203  /// operation.
204  void resetOperation() const {
205  op->setLoc(loc);
206  op->setAttrs(attrs);
207  op->setOperands(operands);
208  for (const auto &it : llvm::enumerate(successors))
209  op->setSuccessor(it.value(), it.index());
210  }
211 
212  /// Return the original operation of this state.
213  Operation *getOperation() const { return op; }
214 
215 private:
216  Operation *op;
217  LocationAttr loc;
218  DictionaryAttr attrs;
219  SmallVector<Value, 8> operands;
221 };
222 
223 //===----------------------------------------------------------------------===//
224 // OpReplacement
225 
226 /// This class represents one requested operation replacement via 'replaceOp' or
227 /// 'eraseOp`.
228 struct OpReplacement {
229  OpReplacement(TypeConverter *converter = nullptr) : converter(converter) {}
230 
231  /// An optional type converter that can be used to materialize conversions
232  /// between the new and old values if necessary.
233  TypeConverter *converter;
234 };
235 
236 //===----------------------------------------------------------------------===//
237 // BlockAction
238 
/// Which kind of block-level rewrite a `BlockAction` records. Every kind can
/// be rolled back if the conversion fails.
enum class BlockActionKind {
  Create,        ///< A block was created.
  Erase,         ///< A block was erased.
  Inline,        ///< A block's operations were inlined.
  Move,          ///< A block was moved.
  Split,         ///< A block was split in two.
  TypeConversion ///< A block's signature was type-converted.
};
249 
250 /// Original position of the given block in its parent region. During undo
251 /// actions, the block needs to be placed after `insertAfterBlock`.
252 struct BlockPosition {
253  Region *region;
254  Block *insertAfterBlock;
255 };
256 
257 /// Information needed to undo inlining actions.
258 /// - the source block
259 /// - the first inlined operation (could be null if the source block was empty)
260 /// - the last inlined operation (could be null if the source block was empty)
261 struct InlineInfo {
262  Block *sourceBlock;
263  Operation *firstInlinedInst;
264  Operation *lastInlinedInst;
265 };
266 
267 /// The storage class for an undoable block action (one of BlockActionKind),
268 /// contains the information necessary to undo this action.
269 struct BlockAction {
270  static BlockAction getCreate(Block *block) {
271  return {BlockActionKind::Create, block, {}};
272  }
273  static BlockAction getErase(Block *block, BlockPosition originalPosition) {
274  return {BlockActionKind::Erase, block, {originalPosition}};
275  }
276  static BlockAction getInline(Block *block, Block *srcBlock,
277  Block::iterator before) {
278  BlockAction action{BlockActionKind::Inline, block, {}};
279  action.inlineInfo = {srcBlock,
280  srcBlock->empty() ? nullptr : &srcBlock->front(),
281  srcBlock->empty() ? nullptr : &srcBlock->back()};
282  return action;
283  }
284  static BlockAction getMove(Block *block, BlockPosition originalPosition) {
285  return {BlockActionKind::Move, block, {originalPosition}};
286  }
287  static BlockAction getSplit(Block *block, Block *originalBlock) {
288  BlockAction action{BlockActionKind::Split, block, {}};
289  action.originalBlock = originalBlock;
290  return action;
291  }
292  static BlockAction getTypeConversion(Block *block) {
293  return BlockAction{BlockActionKind::TypeConversion, block, {}};
294  }
295 
296  // The action kind.
297  BlockActionKind kind;
298 
299  // A pointer to the block that was created by the action.
300  Block *block;
301 
302  union {
303  // In use if kind == BlockActionKind::Inline or BlockActionKind::Erase, and
304  // contains a pointer to the region that originally contained the block as
305  // well as the position of the block in that region.
306  BlockPosition originalPosition;
307  // In use if kind == BlockActionKind::Split and contains a pointer to the
308  // block that was split into two parts.
309  Block *originalBlock;
310  // In use if kind == BlockActionKind::Inline, and contains the information
311  // needed to undo the inlining.
312  InlineInfo inlineInfo;
313  };
314 };
315 
316 //===----------------------------------------------------------------------===//
317 // UnresolvedMaterialization
318 
319 /// This class represents an unresolved materialization, i.e. a materialization
320 /// that was inserted during conversion that needs to be legalized at the end of
321 /// the conversion process.
322 class UnresolvedMaterialization {
323 public:
324  /// The type of materialization.
325  enum Kind {
326  /// This materialization materializes a conversion for an illegal block
327  /// argument type, to a legal one.
328  Argument,
329 
330  /// This materialization materializes a conversion from an illegal type to a
331  /// legal one.
332  Target
333  };
334 
335  UnresolvedMaterialization(UnrealizedConversionCastOp op = nullptr,
336  TypeConverter *converter = nullptr,
337  Kind kind = Target, Type origOutputType = nullptr)
338  : op(op), converterAndKind(converter, kind),
339  origOutputType(origOutputType) {}
340 
341  /// Return the temporary conversion operation inserted for this
342  /// materialization.
343  UnrealizedConversionCastOp getOp() const { return op; }
344 
345  /// Return the type converter of this materialization (which may be null).
346  TypeConverter *getConverter() const { return converterAndKind.getPointer(); }
347 
348  /// Return the kind of this materialization.
349  Kind getKind() const { return converterAndKind.getInt(); }
350 
351  /// Set the kind of this materialization.
352  void setKind(Kind kind) { converterAndKind.setInt(kind); }
353 
354  /// Return the original illegal output type of the input values.
355  Type getOrigOutputType() const { return origOutputType; }
356 
357 private:
358  /// The unresolved materialization operation created during conversion.
359  UnrealizedConversionCastOp op;
360 
361  /// The corresponding type converter to use when resolving this
362  /// materialization, and the kind of this materialization.
363  llvm::PointerIntPair<TypeConverter *, 1, Kind> converterAndKind;
364 
365  /// The original output type. This is only used for argument conversions.
366  Type origOutputType;
367 };
368 } // namespace
369 
370 /// Build an unresolved materialization operation given an output type and set
371 /// of input operands.
373  UnresolvedMaterialization::Kind kind, Block *insertBlock,
374  Block::iterator insertPt, Location loc, ValueRange inputs, Type outputType,
375  Type origOutputType, TypeConverter *converter,
376  SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations) {
377  // Avoid materializing an unnecessary cast.
378  if (inputs.size() == 1 && inputs.front().getType() == outputType)
379  return inputs.front();
380 
381  // Create an unresolved materialization. We use a new OpBuilder to avoid
382  // tracking the materialization like we do for other operations.
383  OpBuilder builder(insertBlock, insertPt);
384  auto convertOp =
385  builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
386  unresolvedMaterializations.emplace_back(convertOp, converter, kind,
387  origOutputType);
388  return convertOp.getResult(0);
389 }
391  PatternRewriter &rewriter, Location loc, ValueRange inputs,
392  Type origOutputType, Type outputType, TypeConverter *converter,
393  SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations) {
396  rewriter.getInsertionPoint(), loc, inputs, outputType, origOutputType,
397  converter, unresolvedMaterializations);
398 }
400  Location loc, Value input, Type outputType, TypeConverter *converter,
401  SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations) {
402  Block *insertBlock = input.getParentBlock();
403  Block::iterator insertPt = insertBlock->begin();
404  if (OpResult inputRes = input.dyn_cast<OpResult>())
405  insertPt = ++inputRes.getOwner()->getIterator();
406 
408  UnresolvedMaterialization::Target, insertBlock, insertPt, loc, input,
409  outputType, outputType, converter, unresolvedMaterializations);
410 }
411 
412 //===----------------------------------------------------------------------===//
413 // ArgConverter
414 //===----------------------------------------------------------------------===//
415 namespace {
416 /// This class provides a simple interface for converting the types of block
417 /// arguments. This is done by creating a new block that contains the new legal
418 /// types and extracting the block that contains the old illegal types to allow
419 /// for undoing pending rewrites in the case of failure.
420 struct ArgConverter {
421  ArgConverter(
422  PatternRewriter &rewriter,
423  SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations)
424  : rewriter(rewriter),
425  unresolvedMaterializations(unresolvedMaterializations) {}
426 
427  /// This structure contains the information pertaining to an argument that has
428  /// been converted.
429  struct ConvertedArgInfo {
430  ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize,
431  Value castValue = nullptr)
432  : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {}
433 
434  /// The start index of in the new argument list that contains arguments that
435  /// replace the original.
436  unsigned newArgIdx;
437 
438  /// The number of arguments that replaced the original argument.
439  unsigned newArgSize;
440 
441  /// The cast value that was created to cast from the new arguments to the
442  /// old. This only used if 'newArgSize' > 1.
443  Value castValue;
444  };
445 
446  /// This structure contains information pertaining to a block that has had its
447  /// signature converted.
448  struct ConvertedBlockInfo {
449  ConvertedBlockInfo(Block *origBlock, TypeConverter *converter)
450  : origBlock(origBlock), converter(converter) {}
451 
452  /// The original block that was requested to have its signature converted.
453  Block *origBlock;
454 
455  /// The conversion information for each of the arguments. The information is
456  /// std::nullopt if the argument was dropped during conversion.
458 
459  /// The type converter used to convert the arguments.
460  TypeConverter *converter;
461  };
462 
463  /// Return if the signature of the given block has already been converted.
464  bool hasBeenConverted(Block *block) const {
465  return conversionInfo.count(block) || convertedBlocks.count(block);
466  }
467 
468  /// Set the type converter to use for the given region.
469  void setConverter(Region *region, TypeConverter *typeConverter) {
470  assert(typeConverter && "expected valid type converter");
471  regionToConverter[region] = typeConverter;
472  }
473 
474  /// Return the type converter to use for the given region, or null if there
475  /// isn't one.
476  TypeConverter *getConverter(Region *region) {
477  return regionToConverter.lookup(region);
478  }
479 
480  //===--------------------------------------------------------------------===//
481  // Rewrite Application
482  //===--------------------------------------------------------------------===//
483 
484  /// Erase any rewrites registered for the blocks within the given operation
485  /// which is about to be removed. This merely drops the rewrites without
486  /// undoing them.
487  void notifyOpRemoved(Operation *op);
488 
489  /// Cleanup and undo any generated conversions for the arguments of block.
490  /// This method replaces the new block with the original, reverting the IR to
491  /// its original state.
492  void discardRewrites(Block *block);
493 
494  /// Fully replace uses of the old arguments with the new.
495  void applyRewrites(ConversionValueMapping &mapping);
496 
497  /// Materialize any necessary conversions for converted arguments that have
498  /// live users, using the provided `findLiveUser` to search for a user that
499  /// survives the conversion process.
501  materializeLiveConversions(ConversionValueMapping &mapping,
502  OpBuilder &builder,
503  function_ref<Operation *(Value)> findLiveUser);
504 
505  //===--------------------------------------------------------------------===//
506  // Conversion
507  //===--------------------------------------------------------------------===//
508 
509  /// Attempt to convert the signature of the given block, if successful a new
510  /// block is returned containing the new arguments. Returns `block` if it did
511  /// not require conversion.
513  convertSignature(Block *block, TypeConverter *converter,
514  ConversionValueMapping &mapping,
515  SmallVectorImpl<BlockArgument> &argReplacements);
516 
517  /// Apply the given signature conversion on the given block. The new block
518  /// containing the updated signature is returned. If no conversions were
519  /// necessary, e.g. if the block has no arguments, `block` is returned.
520  /// `converter` is used to generate any necessary cast operations that
521  /// translate between the origin argument types and those specified in the
522  /// signature conversion.
523  Block *applySignatureConversion(
524  Block *block, TypeConverter *converter,
525  TypeConverter::SignatureConversion &signatureConversion,
526  ConversionValueMapping &mapping,
527  SmallVectorImpl<BlockArgument> &argReplacements);
528 
529  /// Insert a new conversion into the cache.
530  void insertConversion(Block *newBlock, ConvertedBlockInfo &&info);
531 
532  /// A collection of blocks that have had their arguments converted. This is a
533  /// map from the new replacement block, back to the original block.
534  llvm::MapVector<Block *, ConvertedBlockInfo> conversionInfo;
535 
536  /// The set of original blocks that were converted.
537  DenseSet<Block *> convertedBlocks;
538 
539  /// A mapping from valid regions, to those containing the original blocks of a
540  /// conversion.
542 
543  /// A mapping of regions to type converters that should be used when
544  /// converting the arguments of blocks within that region.
545  DenseMap<Region *, TypeConverter *> regionToConverter;
546 
547  /// The pattern rewriter to use when materializing conversions.
548  PatternRewriter &rewriter;
549 
550  /// An ordered set of unresolved materializations during conversion.
551  SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations;
552 };
553 } // namespace
554 
555 //===----------------------------------------------------------------------===//
556 // Rewrite Application
557 
558 void ArgConverter::notifyOpRemoved(Operation *op) {
559  if (conversionInfo.empty())
560  return;
561 
562  for (Region &region : op->getRegions()) {
563  for (Block &block : region) {
564  // Drop any rewrites from within.
565  for (Operation &nestedOp : block)
566  if (nestedOp.getNumRegions())
567  notifyOpRemoved(&nestedOp);
568 
569  // Check if this block was converted.
570  auto it = conversionInfo.find(&block);
571  if (it == conversionInfo.end())
572  continue;
573 
574  // Drop all uses of the original arguments and delete the original block.
575  Block *origBlock = it->second.origBlock;
576  for (BlockArgument arg : origBlock->getArguments())
577  arg.dropAllUses();
578  conversionInfo.erase(it);
579  }
580  }
581 }
582 
583 void ArgConverter::discardRewrites(Block *block) {
584  auto it = conversionInfo.find(block);
585  if (it == conversionInfo.end())
586  return;
587  Block *origBlock = it->second.origBlock;
588 
589  // Drop all uses of the new block arguments and replace uses of the new block.
590  for (int i = block->getNumArguments() - 1; i >= 0; --i)
591  block->getArgument(i).dropAllUses();
592  block->replaceAllUsesWith(origBlock);
593 
594  // Move the operations back the original block and the delete the new block.
595  origBlock->getOperations().splice(origBlock->end(), block->getOperations());
596  origBlock->moveBefore(block);
597  block->erase();
598 
599  convertedBlocks.erase(origBlock);
600  conversionInfo.erase(it);
601 }
602 
603 void ArgConverter::applyRewrites(ConversionValueMapping &mapping) {
604  for (auto &info : conversionInfo) {
605  ConvertedBlockInfo &blockInfo = info.second;
606  Block *origBlock = blockInfo.origBlock;
607 
608  // Process the remapping for each of the original arguments.
609  for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) {
610  std::optional<ConvertedArgInfo> &argInfo = blockInfo.argInfo[i];
611  BlockArgument origArg = origBlock->getArgument(i);
612 
613  // Handle the case of a 1->0 value mapping.
614  if (!argInfo) {
615  if (Value newArg = mapping.lookupOrNull(origArg, origArg.getType()))
616  origArg.replaceAllUsesWith(newArg);
617  continue;
618  }
619 
620  // Otherwise this is a 1->1+ value mapping.
621  Value castValue = argInfo->castValue;
622  assert(argInfo->newArgSize >= 1 && castValue && "expected 1->1+ mapping");
623 
624  // If the argument is still used, replace it with the generated cast.
625  if (!origArg.use_empty()) {
626  origArg.replaceAllUsesWith(
627  mapping.lookupOrDefault(castValue, origArg.getType()));
628  }
629  }
630  }
631 }
632 
633 LogicalResult ArgConverter::materializeLiveConversions(
634  ConversionValueMapping &mapping, OpBuilder &builder,
635  function_ref<Operation *(Value)> findLiveUser) {
636  for (auto &info : conversionInfo) {
637  Block *newBlock = info.first;
638  ConvertedBlockInfo &blockInfo = info.second;
639  Block *origBlock = blockInfo.origBlock;
640 
641  // Process the remapping for each of the original arguments.
642  for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) {
643  // If the type of this argument changed and the argument is still live, we
644  // need to materialize a conversion.
645  BlockArgument origArg = origBlock->getArgument(i);
646  if (mapping.lookupOrNull(origArg, origArg.getType()))
647  continue;
648  Operation *liveUser = findLiveUser(origArg);
649  if (!liveUser)
650  continue;
651 
652  Value replacementValue = mapping.lookupOrDefault(origArg);
653  bool isDroppedArg = replacementValue == origArg;
654  if (isDroppedArg)
655  rewriter.setInsertionPointToStart(newBlock);
656  else
657  rewriter.setInsertionPointAfterValue(replacementValue);
658  Value newArg;
659  if (blockInfo.converter) {
660  newArg = blockInfo.converter->materializeSourceConversion(
661  rewriter, origArg.getLoc(), origArg.getType(),
662  isDroppedArg ? ValueRange() : ValueRange(replacementValue));
663  assert((!newArg || newArg.getType() == origArg.getType()) &&
664  "materialization hook did not provide a value of the expected "
665  "type");
666  }
667  if (!newArg) {
669  emitError(origArg.getLoc())
670  << "failed to materialize conversion for block argument #" << i
671  << " that remained live after conversion, type was "
672  << origArg.getType();
673  if (!isDroppedArg)
674  diag << ", with target type " << replacementValue.getType();
675  diag.attachNote(liveUser->getLoc())
676  << "see existing live user here: " << *liveUser;
677  return failure();
678  }
679  mapping.map(origArg, newArg);
680  }
681  }
682  return success();
683 }
684 
685 //===----------------------------------------------------------------------===//
686 // Conversion
687 
688 FailureOr<Block *> ArgConverter::convertSignature(
689  Block *block, TypeConverter *converter, ConversionValueMapping &mapping,
690  SmallVectorImpl<BlockArgument> &argReplacements) {
691  // Check if the block was already converted. If the block is detached,
692  // conservatively assume it is going to be deleted.
693  if (hasBeenConverted(block) || !block->getParent())
694  return block;
695  // If a converter wasn't provided, and the block wasn't already converted,
696  // there is nothing we can do.
697  if (!converter)
698  return failure();
699 
700  // Try to convert the signature for the block with the provided converter.
701  if (auto conversion = converter->convertBlockSignature(block))
702  return applySignatureConversion(block, converter, *conversion, mapping,
703  argReplacements);
704  return failure();
705 }
706 
707 Block *ArgConverter::applySignatureConversion(
708  Block *block, TypeConverter *converter,
709  TypeConverter::SignatureConversion &signatureConversion,
710  ConversionValueMapping &mapping,
711  SmallVectorImpl<BlockArgument> &argReplacements) {
712  // If no arguments are being changed or added, there is nothing to do.
713  unsigned origArgCount = block->getNumArguments();
714  auto convertedTypes = signatureConversion.getConvertedTypes();
715  if (origArgCount == 0 && convertedTypes.empty())
716  return block;
717 
718  // Split the block at the beginning to get a new block to use for the updated
719  // signature.
720  Block *newBlock = block->splitBlock(block->begin());
721  block->replaceAllUsesWith(newBlock);
722 
723  // FIXME: We should map the new arguments to proper locations.
724  SmallVector<Location> newLocs(convertedTypes.size(),
725  rewriter.getUnknownLoc());
726  SmallVector<Value, 4> newArgRange(
727  newBlock->addArguments(convertedTypes, newLocs));
728  ArrayRef<Value> newArgs(newArgRange);
729 
730  // Remap each of the original arguments as determined by the signature
731  // conversion.
732  ConvertedBlockInfo info(block, converter);
733  info.argInfo.resize(origArgCount);
734 
735  OpBuilder::InsertionGuard guard(rewriter);
736  rewriter.setInsertionPointToStart(newBlock);
737  for (unsigned i = 0; i != origArgCount; ++i) {
738  auto inputMap = signatureConversion.getInputMapping(i);
739  if (!inputMap)
740  continue;
741  BlockArgument origArg = block->getArgument(i);
742 
743  // If inputMap->replacementValue is not nullptr, then the argument is
744  // dropped and a replacement value is provided to be the remappedValue.
745  if (inputMap->replacementValue) {
746  assert(inputMap->size == 0 &&
747  "invalid to provide a replacement value when the argument isn't "
748  "dropped");
749  mapping.map(origArg, inputMap->replacementValue);
750  argReplacements.push_back(origArg);
751  continue;
752  }
753 
754  // Otherwise, this is a 1->1+ mapping.
755  auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size);
756  Value newArg;
757 
758  // If this is a 1->1 mapping and the types of new and replacement arguments
759  // match (i.e. it's an identity map), then the argument is mapped to its
760  // original type.
761  // FIXME: We simply pass through the replacement argument if there wasn't a
762  // converter, which isn't great as it allows implicit type conversions to
763  // appear. We should properly restructure this code to handle cases where a
764  // converter isn't provided and also to properly handle the case where an
765  // argument materialization is actually a temporary source materialization
766  // (e.g. in the case of 1->N).
767  if (replArgs.size() == 1 &&
768  (!converter || replArgs[0].getType() == origArg.getType())) {
769  newArg = replArgs.front();
770  } else {
771  Type origOutputType = origArg.getType();
772 
773  // Legalize the argument output type.
774  Type outputType = origOutputType;
775  if (Type legalOutputType = converter->convertType(outputType))
776  outputType = legalOutputType;
777 
779  rewriter, origArg.getLoc(), replArgs, origOutputType, outputType,
780  converter, unresolvedMaterializations);
781  }
782 
783  mapping.map(origArg, newArg);
784  argReplacements.push_back(origArg);
785  info.argInfo[i] =
786  ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
787  }
788 
789  // Remove the original block from the region and return the new one.
790  insertConversion(newBlock, std::move(info));
791  return newBlock;
792 }
793 
794 void ArgConverter::insertConversion(Block *newBlock,
795  ConvertedBlockInfo &&info) {
796  // Get a region to insert the old block.
797  Region *region = newBlock->getParent();
798  std::unique_ptr<Region> &mappedRegion = regionMapping[region];
799  if (!mappedRegion)
800  mappedRegion = std::make_unique<Region>(region->getParentOp());
801 
802  // Move the original block to the mapped region and emplace the conversion.
803  mappedRegion->getBlocks().splice(mappedRegion->end(), region->getBlocks(),
804  info.origBlock->getIterator());
805  convertedBlocks.insert(info.origBlock);
806  conversionInfo.insert({newBlock, std::move(info)});
807 }
808 
809 //===----------------------------------------------------------------------===//
810 // ConversionPatternRewriterImpl
811 //===----------------------------------------------------------------------===//
812 namespace mlir {
813 namespace detail {
// NOTE(review): source lines 814-815 are missing from this extract. Judging by
// the initializer list below and the later
// `new detail::ConversionPatternRewriterImpl(*this)`, they presumably open the
// `ConversionPatternRewriterImpl` definition and begin its constructor, which
// takes a `rewriter` argument that is forwarded to `argConverter`.
816  : argConverter(rewriter, unresolvedMaterializations),
817  notifyCallback(nullptr) {}
818 
819  /// Cleanup and destroy any generated rewrite operations. This method is
820  /// invoked when the conversion process fails.
821  void discardRewrites();
822 
823  /// Apply all requested operation rewrites. This method is invoked when the
824  /// conversion process succeeds.
825  void applyRewrites();
826 
827  //===--------------------------------------------------------------------===//
828  // State Management
829  //===--------------------------------------------------------------------===//
830 
831  /// Return the current state of the rewriter: a snapshot of the sizes of the
  /// tracking containers below, which `resetState` can roll back to.
832  RewriterState getCurrentState();
833 
834  /// Reset the state of the rewriter to a previously saved point.
835  void resetState(RewriterState state);
836 
837  /// Erase any blocks that were unlinked from their regions and stored in block
838  /// actions.
839  void eraseDanglingBlocks();
840 
841  /// Undo the block actions (creations, erasures, inlines, moves, splits and
  /// type conversions) one by one in reverse order until
842  /// "numActionsToKeep" actions remain.
843  void undoBlockActions(unsigned numActionsToKeep = 0);
844 
845  /// Remap the given values to those with potentially different types. Returns
846  /// success if the values could be remapped, failure otherwise. `valueDiagTag`
847  /// is the tag used when describing a value within a diagnostic, e.g.
848  /// "operand". On failure, `remapped` may already hold entries for the values
  /// processed before the one that failed.
849  LogicalResult remapValues(StringRef valueDiagTag,
850  std::optional<Location> inputLoc,
851  PatternRewriter &rewriter, ValueRange values,
852  SmallVectorImpl<Value> &remapped);
853 
854  /// Returns true if the given operation is ignored, and does not need to be
855  /// converted.
856  bool isOpIgnored(Operation *op) const;
857 
858  /// Marks the operations nested under 'op' that define non-empty regions as
859  /// ignored. Because isOpIgnored checks an operation's parent, this removes
  /// the whole nested tree from being considered for legalization.
860  void markNestedOpsIgnored(Operation *op);
861 
862  //===--------------------------------------------------------------------===//
863  // Type Conversion
864  //===--------------------------------------------------------------------===//
865 
866  /// Convert the signature of the given block.
867  FailureOr<Block *> convertBlockSignature(
868  Block *block, TypeConverter *converter,
869  TypeConverter::SignatureConversion *conversion = nullptr);
870 
871  /// Apply a signature conversion on the given region, using `converter` for
872  /// materializations if not null.
873  Block *
874  applySignatureConversion(Region *region,
  // NOTE(review): source line 875 is missing; per the out-of-line definition
  // it is presumably the `TypeConverter::SignatureConversion &conversion,`
  // parameter.
876  TypeConverter *converter);
877 
878  /// Convert the types of block arguments within the given region.
  // NOTE(review): source line 879 (the return type) is missing; the
  // definition returns `nullptr`, `failure()` or a `FailureOr<Block *>`, so it
  // is presumably `FailureOr<Block *>`.
880  convertRegionTypes(Region *region, TypeConverter &converter,
881  TypeConverter::SignatureConversion *entryConversion);
882 
883  /// Convert the types of non-entry block arguments within the given region.
884  LogicalResult convertNonEntryRegionTypes(
885  Region *region, TypeConverter &converter,
886  ArrayRef<TypeConverter::SignatureConversion> blockConversions = {});
887 
888  //===--------------------------------------------------------------------===//
889  // Rewriter Notification Hooks
890  //===--------------------------------------------------------------------===//
891 
892  /// PatternRewriter hook for replacing the results of an operation.
893  void notifyOpReplaced(Operation *op, ValueRange newValues);
894 
895  /// Notifies that a block is about to be erased.
896  void notifyBlockIsBeingErased(Block *block);
897 
898  /// Notifies that a block was created.
899  void notifyCreatedBlock(Block *block);
900 
901  /// Notifies that a block was split.
902  void notifySplitBlock(Block *block, Block *continuation);
903 
904  /// Notifies that a block is being inlined into another block.
905  void notifyBlockBeingInlined(Block *block, Block *srcBlock,
906  Block::iterator before);
907 
908  /// Notifies that the blocks of a region are about to be moved.
909  void notifyRegionIsBeingInlinedBefore(Region &region, Region &parent,
910  Region::iterator before);
911 
912  /// Notifies that a pattern match failed for the given reason.
  // NOTE(review): source line 913 (the return type) is missing; the definition
  // returns `failure()`, so it is presumably `LogicalResult`.
914  notifyMatchFailure(Location loc,
915  function_ref<void(Diagnostic &)> reasonCallback);
916 
917  //===--------------------------------------------------------------------===//
918  // State
919  //===--------------------------------------------------------------------===//
920 
921  /// Mapping from replaced values (operation results and block arguments) to
922  /// their replacements, which may have a different type.
923  ConversionValueMapping mapping;
924 
925  /// Utility used to convert block arguments.
926  ArgConverter argConverter;
927 
928  /// Ordered vector of all of the newly created operations during conversion.
  // NOTE(review): declaration (source line 929) missing from this extract;
  // presumably `createdOps`, which the rest of the file uses.
930 
931  /// Ordered vector of all unresolved type conversion materializations during
932  /// conversion.
  // NOTE(review): declaration (source line 933) missing from this extract;
  // presumably `unresolvedMaterializations`, which the rest of the file uses.
934 
935  /// Ordered map of requested operation replacements.
936  llvm::MapVector<Operation *, OpReplacement> replacements;
937 
938  /// Ordered vector of any requested block argument replacements.
  // NOTE(review): declaration (source line 939) missing from this extract;
  // presumably `argReplacements`, which the rest of the file uses.
940 
941  /// Ordered list of block actions (creations, erasures, inlines, moves,
  /// splits and type conversions).
  // NOTE(review): declaration (source line 942) missing from this extract;
  // presumably `blockActions`, which the rest of the file uses.
943 
944  /// A set of operations that should no longer be considered for legalization,
945  /// but were not directly replaced/erased/etc. by a pattern. These are
946  /// generally child operations of other operations that were
947  /// replaced/erased/etc. This is not meant to be an exhaustive list of all
948  /// operations, but the minimal set that can be used to detect if a given
949  /// operation should be `ignored`. For example, we may add the operations that
950  /// define non-empty regions to the set, but not any of the others. This
951  /// simplifies the amount of memory needed as we can query if the parent
952  /// operation was ignored.
  // NOTE(review): declaration (source line 953) missing from this extract;
  // presumably `ignoredOps`, which the rest of the file uses.
954 
955  /// A transaction state for each of the operations that were updated in-place.
  // NOTE(review): declaration (source line 956) missing from this extract;
  // presumably `rootUpdates`, which the rest of the file uses.
957 
958  /// A vector of indices into `replacements` of operations that were replaced
959  /// with values with different result types than the original operation, e.g.
960  /// 1->N conversion of some kind.
  // NOTE(review): declaration (source line 961) missing from this extract;
  // presumably `operationsWithChangedResults`, which the rest of the file uses.
962 
963  /// The current type converter, or nullptr if no type converter is currently
964  /// active.
965  TypeConverter *currentTypeConverter = nullptr;
966 
967  /// This allows the user to collect the match failure message.
  // NOTE(review): declaration (source line 968) missing from this extract;
  // presumably `notifyCallback`, initialized to nullptr by the constructor.
969 
970 #ifndef NDEBUG
971  /// A set of operations that have pending updates. This tracking isn't
972  /// strictly necessary, and is thus only active during debug builds for extra
973  /// verification.
  // NOTE(review): declaration (source line 974) missing from this extract;
  // presumably `pendingRootUpdates`, which the rest of the file uses.
975 
976  /// A logger used to emit diagnostics during the conversion process.
977  llvm::ScopedPrinter logger{llvm::dbgs()};
978 #endif
979 };
980 } // namespace detail
981 } // namespace mlir
982 
983 /// Detach any operations nested in the given operation from their parent
984 /// blocks, and erase the given operation. This can be used when the nested
985 /// operations are scheduled for erasure themselves, so deleting the regions of
986 /// the given operation together with their content would result in double-free.
987 /// This happens, for example, when rolling back op creation in the reverse
988 /// order and if the nested ops were created before the parent op. This function
989 /// does not need to collect nested ops recursively because it is expected to
990 /// also be called for each nested op when it is about to be deleted.
991 static void detachNestedAndErase(Operation *op) {
992  for (Region &region : op->getRegions()) {
993  for (Block &block : region.getBlocks()) {
994  while (!block.getOperations().empty())
995  block.getOperations().remove(block.getOperations().begin());
996  block.dropAllDefinedValueUses();
997  }
998  }
999  op->dropAllUses();
1000  op->erase();
1001 }
1002 
// Failure path of the conversion: reset every in-place root update, undo all
// block actions, then detach-and-erase the unresolved materialization ops and
// walk the created ops newest-first.
// NOTE(review): source line 1003 (the signature, presumably
// `void ConversionPatternRewriterImpl::discardRewrites() {`) is missing from
// this extract.
1004  // Reset any operations that were updated in place.
// NOTE(review): updates are reset oldest-first. If one op has several entries
// in `rootUpdates`, the newest entry's reset runs last — confirm this restores
// the op's original state (resetState iterates the same way).
1005  for (auto &state : rootUpdates)
1006  state.resetOperation();
1007 
// numActionsToKeep defaults to 0, so every recorded block action is undone.
1008  undoBlockActions();
1009 
1010  // Remove any newly created ops.
// detachNestedAndErase is used instead of a plain erase because nested ops may
// themselves be erased separately, and a plain erase would free them twice.
1011  for (UnresolvedMaterialization &materialization : unresolvedMaterializations)
1012  detachNestedAndErase(materialization.getOp());
1013  for (auto *op : llvm::reverse(createdOps))
// NOTE(review): the body of this loop (source line 1014) is missing from this
// extract.
1015 }
1016 
// Success path of the conversion. In order:
//  1. replace uses of each replaced op's results with their mapped values and
//     tell the argument converter about replaced ops that own regions;
//  2. apply recorded block-argument replacements;
//  3. erase the unresolved materialization ops;
//  4. erase the replaced ops, newest-first;
//  5. let the argument converter apply its own rewrites.
// NOTE(review): source line 1017 (the signature, presumably
// `void ConversionPatternRewriterImpl::applyRewrites() {`) is missing from this
// extract.
1018  // Apply all of the operation replacements requested during conversion.
1019  for (auto &repl : replacements) {
1020  for (OpResult result : repl.first->getResults())
1021  if (Value newValue = mapping.lookupOrNull(result, result.getType()))
1022  result.replaceAllUsesWith(newValue);
1023 
1024  // If this operation defines any regions, drop any pending argument
1025  // rewrites.
1026  if (repl.first->getNumRegions())
1027  argConverter.notifyOpRemoved(repl.first);
1028  }
1029 
1030  // Apply all of the requested argument replacements.
1031  for (BlockArgument arg : argReplacements) {
1032  Value repl = mapping.lookupOrNull(arg, arg.getType());
1033  if (!repl)
1034  continue;
1035 
1036  if (repl.isa<BlockArgument>()) {
1037  arg.replaceAllUsesWith(repl);
1038  continue;
1039  }
1040 
1041  // If the replacement value is an operation result, only replace uses that
1042  // are in a different block than the defining op, or after it in the same
1043  // block. Uses at or before `replOp` in its own block (including `replOp`).
  // are left untouched.
1044  Operation *replOp = repl.cast<OpResult>().getOwner();
1045  Block *replBlock = replOp->getBlock();
1046  arg.replaceUsesWithIf(repl, [&](OpOperand &operand) {
1047  Operation *user = operand.getOwner();
1048  return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
1049  });
1050  }
1051 
1052  // Drop all of the unresolved materialization operations created during
1053  // conversion.
1054  for (auto &mat : unresolvedMaterializations) {
1055  mat.getOp()->dropAllUses();
1056  mat.getOp()->erase();
1057  }
1058 
1059  // In a second pass, erase all of the replaced operations in reverse. This
1060  // allows processing nested operations before their parent region is
1061  // destroyed. Because we process in reverse order, producers may be deleted
1062  // before their users (a pattern deleting a producer and then the consumer)
1063  // so we first drop all uses explicitly.
1064  for (auto &repl : llvm::reverse(replacements)) {
1065  repl.first->dropAllUses();
1066  repl.first->erase();
1067  }
1068 
1069  argConverter.applyRewrites(mapping);
1070 
1071  // Now that the ops have been erased, also erase dangling blocks.
// NOTE(review): source line 1072 is missing from this extract; given the
// comment above it presumably calls eraseDanglingBlocks().
1073 }
1074 
1075 //===----------------------------------------------------------------------===//
1076 // State Management
1077 
// Snapshot the current size of each tracking container so that resetState()
// can later truncate back to this point. operationsWithChangedResults is not
// captured; resetState trims it using the saved replacement count.
// NOTE(review): source line 1078 (the signature) is missing from this extract.
1079  return RewriterState(createdOps.size(), unresolvedMaterializations.size(),
1080  replacements.size(), argReplacements.size(),
1081  blockActions.size(), ignoredOps.size(),
1082  rootUpdates.size());
1083 }
1084 
// Roll back everything recorded after `state` was captured by
// getCurrentState(). The steps run in this order: in-place updates, argument
// replacement mappings, block actions, op replacement mappings,
// materializations, created ops, ignored ops, and changed-result indices.
// NOTE(review): source line 1085 (the signature) is missing from this extract.
1086  // Reset any operations that were updated in place.
// NOTE(review): updates are reset oldest-first. If the same operation was
// updated in place more than once after `state`, the last resetOperation()
// call wins. Assuming each entry restores the snapshot taken when its update
// started, that leaves the op in an intermediate state instead of its state at
// `state`. Iterating in reverse would avoid this; confirm against
// OperationTransactionState before changing.
1087  for (unsigned i = state.numRootUpdates, e = rootUpdates.size(); i != e; ++i)
1088  rootUpdates[i].resetOperation();
1089  rootUpdates.resize(state.numRootUpdates);
1090 
1091  // Reset any replaced arguments.
1092  for (BlockArgument replacedArg :
1093  llvm::drop_begin(argReplacements, state.numArgReplacements))
1094  mapping.erase(replacedArg);
1095  argReplacements.resize(state.numArgReplacements);
1096 
1097  // Undo any block actions.
1098  undoBlockActions(state.numBlockActions);
1099 
1100  // Reset any replaced operations and undo any saved mappings.
1101  for (auto &repl : llvm::drop_begin(replacements, state.numReplacements))
1102  for (auto result : repl.first->getResults())
1103  mapping.erase(result);
1104  while (replacements.size() != state.numReplacements)
1105  replacements.pop_back();
1106 
1107  // Pop all of the newly inserted materializations.
1108  while (unresolvedMaterializations.size() !=
1109  state.numUnresolvedMaterializations) {
1110  UnresolvedMaterialization mat = unresolvedMaterializations.pop_back_val();
1111  UnrealizedConversionCastOp op = mat.getOp();
1112 
1113  // If this was a target materialization, drop the mapping that was inserted.
1114  if (mat.getKind() == UnresolvedMaterialization::Target) {
1115  for (Value input : op->getOperands())
1116  mapping.erase(input);
1117  }
// NOTE(review): source line 1118 is missing from this extract; it presumably
// erases the materialization `op`.
1119  }
1120 
1121  // Pop all of the newly created operations.
1122  while (createdOps.size() != state.numCreatedOps) {
// NOTE(review): source line 1123 is missing from this extract; it presumably
// erases createdOps.back() before it is popped.
1124  createdOps.pop_back();
1125  }
1126 
1127  // Pop all of the recorded ignored operations that are no longer valid.
1128  while (ignoredOps.size() != state.numIgnoredOperations)
1129  ignoredOps.pop_back();
1130 
1131  // Reset operations with changed results.
// Entries are indices into `replacements`, so every index at or past the
// restored replacement count is now stale.
1132  while (!operationsWithChangedResults.empty() &&
1133  operationsWithChangedResults.back() >= state.numReplacements)
1134  operationsWithChangedResults.pop_back();
1135 }
1136 
// Delete the blocks owned by Erase block actions. eraseBlock() only unlinks a
// block from its region and leaves it in the action so the removal can be
// undone. Once rewrites are committed, those blocks are freed here.
// NOTE(review): source line 1137 (the signature) is missing from this extract.
1138  for (auto &action : blockActions)
1139  if (action.kind == BlockActionKind::Erase)
1140  delete action.block;
1141 }
1142 
// Revert every block action recorded after the first `numActionsToKeep`,
// newest first, then truncate `blockActions` to `numActionsToKeep` entries.
// NOTE(review): source line 1143 (the first line of the signature) is missing
// from this extract.
1144  unsigned numActionsToKeep) {
1145  for (auto &action :
1146  llvm::reverse(llvm::drop_begin(blockActions, numActionsToKeep))) {
1147  switch (action.kind) {
1148  // Delete the created block.
1149  case BlockActionKind::Create: {
1150  // Unlink all of the operations within this block, they will be deleted
1151  // separately.
1152  auto &blockOps = action.block->getOperations();
1153  while (!blockOps.empty())
1154  blockOps.remove(blockOps.begin());
1155  action.block->dropAllDefinedValueUses();
1156  action.block->erase();
1157  break;
1158  }
1159  // Put the block (owned by action) back into its original position.
  // A null insertAfterBlock means the block was first in its region.
1160  case BlockActionKind::Erase: {
1161  auto &blockList = action.originalPosition.region->getBlocks();
1162  Block *insertAfterBlock = action.originalPosition.insertAfterBlock;
1163  blockList.insert((insertAfterBlock
1164  ? std::next(Region::iterator(insertAfterBlock))
1165  : blockList.begin()),
1166  action.block);
1167  break;
1168  }
1169  // Put the operations from the destination block (owned by the action)
1170  // back into the source block.
1171  case BlockActionKind::Inline: {
1172  Block *sourceBlock = action.inlineInfo.sourceBlock;
1173  if (action.inlineInfo.firstInlinedInst) {
1174  assert(action.inlineInfo.lastInlinedInst && "expected operation");
1175  sourceBlock->getOperations().splice(
1176  sourceBlock->begin(), action.block->getOperations(),
1177  Block::iterator(action.inlineInfo.firstInlinedInst),
1178  ++Block::iterator(action.inlineInfo.lastInlinedInst));
1179  }
1180  break;
1181  }
1182  // Move the block back to its original position.
  // A null insertAfterBlock is recorded by notifyRegionIsBeingInlinedBefore
  // for the front block of an inlined region. That action was recorded last,
  // so it is undone first, when the region is presumably still empty and
  // end() is therefore also its start.
1183  case BlockActionKind::Move: {
1184  Region *originalRegion = action.originalPosition.region;
1185  Block *insertAfterBlock = action.originalPosition.insertAfterBlock;
1186  originalRegion->getBlocks().splice(
1187  (insertAfterBlock ? std::next(Region::iterator(insertAfterBlock))
1188  : originalRegion->end()),
1189  action.block->getParent()->getBlocks(), action.block);
1190  break;
1191  }
1192  // Merge back the block that was split out.
1193  case BlockActionKind::Split: {
1194  action.originalBlock->getOperations().splice(
1195  action.originalBlock->end(), action.block->getOperations());
1196  action.block->dropAllDefinedValueUses();
1197  action.block->erase();
1198  break;
1199  }
1200  // Undo the type conversion.
1201  case BlockActionKind::TypeConversion: {
1202  argConverter.discardRewrites(action.block);
1203  break;
1204  }
1205  }
1206  }
1207  blockActions.resize(numActionsToKeep);
1208 }
1209 
// Append the current replacement of each value in `values` to `remapped`.
// When a type converter is active and converts a value's type 1:1, an
// unresolved target materialization is inserted whenever the mapped value does
// not already have the converted type. Returns a match failure (diagnostic at
// `inputLoc`, or at the value's own location) if a type cannot be converted.
// In that case `remapped` keeps the entries pushed for earlier values.
// NOTE(review): source line 1210 (the first line of the signature) is missing
// from this extract.
1211  StringRef valueDiagTag, std::optional<Location> inputLoc,
1212  PatternRewriter &rewriter, ValueRange values,
1213  SmallVectorImpl<Value> &remapped) {
1214  remapped.reserve(llvm::size(values));
1215 
1216  SmallVector<Type, 1> legalTypes;
1217  for (const auto &it : llvm::enumerate(values)) {
1218  Value operand = it.value();
1219  Type origType = operand.getType();
1220 
1221  // If a converter was provided, get the desired legal types for this
1222  // operand.
1223  Type desiredType;
1224  if (currentTypeConverter) {
1225  // If there is no legal conversion, fail to match this pattern.
1226  legalTypes.clear();
1227  if (failed(currentTypeConverter->convertType(origType, legalTypes))) {
1228  Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
1229  return notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
1230  diag << "unable to convert type for " << valueDiagTag << " #"
1231  << it.index() << ", type was " << origType;
1232  });
1233  }
1234  // TODO: There currently isn't any mechanism to do 1->N type conversion
1235  // via the PatternRewriter replacement API, so for now we just ignore it.
1236  if (legalTypes.size() == 1)
1237  desiredType = legalTypes.front();
1238  } else {
1239  // TODO: What we should do here is just set `desiredType` to `origType`
1240  // and then handle the necessary type conversions after the conversion
1241  // process has finished. Unfortunately a lot of patterns currently rely on
1242  // receiving the new operands even if the types change, so we keep the
1243  // original behavior here for now until all of the patterns relying on
1244  // this get updated.
1245  }
// desiredType stays null when there is no converter or the conversion is not
// 1:1; lookupOrDefault is then called with a null type.
1246  Value newOperand = mapping.lookupOrDefault(operand, desiredType);
1247 
1248  // Handle the case where the conversion was 1->1 and the new operand type
1249  // isn't legal.
1250  Type newOperandType = newOperand.getType();
1251  if (currentTypeConverter && desiredType && newOperandType != desiredType) {
1252  Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
  // NOTE(review): source lines 1253 and 1255 are missing from this extract.
  // They presumably begin and end the definition of `castValue` (used below)
  // as a call that builds an unresolved target materialization.
1254  operandLoc, newOperand, desiredType, currentTypeConverter,
1256  mapping.map(mapping.lookupOrDefault(newOperand), castValue);
1257  newOperand = castValue;
1258  }
1259  remapped.push_back(newOperand);
1260  }
1261  return success();
1262 }
1263 
// An operation is ignored if it was itself scheduled for replacement/erasure,
// or if its direct parent operation is in `ignoredOps`. Checking only the
// direct parent is enough: markNestedOpsIgnored records every nested op that
// owns a non-empty region, and any op that can be a parent owns such a region.
// NOTE(review): source line 1264 (the signature) is missing from this extract.
1265  // Check to see if this operation was replaced or its parent ignored.
1266  return replacements.count(op) || ignoredOps.count(op->getParentOp());
1267 }
1268 
// NOTE(review): source line 1269 (the signature) is missing from this extract.
1270  // Walk this operation and collect nested operations that define non-empty
1271  // regions. We mark such operations as 'ignored' so that we know we don't have
1272  // to convert them, or their nested ops.
1273  if (op->getNumRegions() == 0)
1274  return;
// The walk callback's parameter shadows the outer `op`; inside the lambda `op`
// is the visited operation.
1275  op->walk([&](Operation *op) {
1276  if (llvm::any_of(op->getRegions(),
1277  [](Region &region) { return !region.empty(); }))
1278  ignoredOps.insert(op);
1279  });
1280 }
1281 
1282 //===----------------------------------------------------------------------===//
1283 // Type Conversion
1284 
// Convert `block`'s signature. An explicit `conversion` is used if given;
// otherwise the argument converter derives the signature with `converter`. If
// a new, different block results, a TypeConversion block action is recorded
// so undoBlockActions can discard it. The result may hold a null block; that
// case records no action.
// NOTE(review): source line 1285 (the first line of the signature) is missing
// from this extract.
1286  Block *block, TypeConverter *converter,
1287  TypeConverter::SignatureConversion *conversion) {
1288  FailureOr<Block *> result =
1289  conversion ? argConverter.applySignatureConversion(
1290  block, converter, *conversion, mapping, argReplacements)
1291  : argConverter.convertSignature(block, converter, mapping,
1292  argReplacements);
1293  if (failed(result))
1294  return failure();
1295  if (Block *newBlock = *result) {
1296  if (newBlock != block)
1297  blockActions.push_back(BlockAction::getTypeConversion(newBlock));
1298  }
1299  return result;
1300 }
1301 
// Apply `conversion` to the entry block of `region`. Returns the resulting
// entry block, or nullptr if the region is empty.
// NOTE(review): source line 1302 (the first line of the signature) is missing
// from this extract.
1303  Region *region, TypeConverter::SignatureConversion &conversion,
1304  TypeConverter *converter) {
1305  if (!region->empty())
  // NOTE(review): the FailureOr is dereferenced without a failed() check. This
  // assumes the explicit-conversion path of convertBlockSignature never
  // fails; confirm against ArgConverter::applySignatureConversion.
1306  return *convertBlockSignature(&region->front(), converter, &conversion);
1307  return nullptr;
1308 }
1309 
// Register `converter` for `region` and convert the argument types of all of
// its blocks. Non-entry blocks are converted first; the entry block is then
// converted with `entryConversion` if one is given. Returns nullptr for an
// empty region, failure() if any block fails to convert, and otherwise the
// (possibly new) entry block.
// NOTE(review): source line 1310 (the first line of the signature) is missing
// from this extract.
1311  Region *region, TypeConverter &converter,
1312  TypeConverter::SignatureConversion *entryConversion) {
1313  argConverter.setConverter(region, &converter);
1314  if (region->empty())
1315  return nullptr;
1316 
1317  if (failed(convertNonEntryRegionTypes(region, converter)))
1318  return failure();
1319 
1320  FailureOr<Block *> newEntry =
1321  convertBlockSignature(&region->front(), &converter, entryConversion);
1322  return newEntry;
1323 }
1324 
// Register `converter` for `region` and convert the argument types of every
// block except the entry block. `blockConversions` is either empty (derive each
// signature from `converter`) or holds exactly one conversion per non-entry
// block, in block order.
// NOTE(review): source lines 1325 (the first line of the signature) and 1327
// (presumably the `ArrayRef<TypeConverter::SignatureConversion>
// blockConversions) {` parameter) are missing from this extract.
1326  Region *region, TypeConverter &converter,
1328  argConverter.setConverter(region, &converter);
1329  if (region->empty())
1330  return success();
1331 
1332  // Convert the arguments of each block within the region.
1333  int blockIdx = 0;
1334  assert((blockConversions.empty() ||
1335  blockConversions.size() == region->getBlocks().size() - 1) &&
1336  "expected either to provide no SignatureConversions at all or to "
1337  "provide a SignatureConversion for each non-entry block");
1338 
// An early-increment range is needed because converting a block can move it out
// of this region (ArgConverter::insertConversion splices the original block
// into a side region).
1339  for (Block &block :
1340  llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
  // `blockConversions` is an ArrayRef (read-only view), while
  // convertBlockSignature takes a non-const pointer, hence the const_cast.
1341  TypeConverter::SignatureConversion *blockConversion =
1342  blockConversions.empty()
1343  ? nullptr
1344  : const_cast<TypeConverter::SignatureConversion *>(
1345  &blockConversions[blockIdx++]);
1346 
1347  if (failed(convertBlockSignature(&block, &converter, blockConversion)))
1348  return failure();
1349  }
1350  return success();
1351 }
1352 
1353 //===----------------------------------------------------------------------===//
1354 // Rewriter Notification Hooks
1355 
// Record that `op` is replaced by `newValues`, one value per result; a null
// value means the result is dropped. Results are mapped to their replacements.
// The actual use replacement and erasure happen later, in applyRewrites.
// NOTE(review): source line 1356 (the first line of the signature) is missing
// from this extract.
1357  ValueRange newValues) {
1358  assert(newValues.size() == op->getNumResults());
1359  assert(!replacements.count(op) && "operation was already replaced");
1360 
1361  // Track if any of the results changed, e.g. erased and replaced with null.
1362  bool resultChanged = false;
1363 
1364  // Create mappings for each of the new result values.
1365  for (auto [newValue, result] : llvm::zip(newValues, op->getResults())) {
1366  if (!newValue) {
1367  resultChanged = true;
1368  continue;
1369  }
1370  // Remap, and check for any result type changes.
1371  mapping.map(result, newValue);
1372  resultChanged |= (newValue.getType() != result.getType());
1373  }
// replacements.size() is taken before the insert below, so it is the index
// `op` is about to occupy in `replacements`.
1374  if (resultChanged)
1375  operationsWithChangedResults.push_back(replacements.size());
1376 
1377  // Record the requested operation replacement.
1378  replacements.insert(std::make_pair(op, OpReplacement(currentTypeConverter)));
1379 
1380  // Mark this operation as recursively ignored so that we don't need to
1381  // convert any nested operations.
// NOTE(review): source line 1382 is missing from this extract; given the
// comment above it presumably calls markNestedOpsIgnored(op).
1383 }
1384 
// Record an Erase action that remembers `block`'s region and preceding sibling
// (null if it is the first block), so undoBlockActions can reinsert the block
// at the same position.
// NOTE(review): source line 1385 (the signature) is missing from this extract.
1386  Region *region = block->getParent();
1387  Block *origPrevBlock = block->getPrevNode();
1388  blockActions.push_back(BlockAction::getErase(block, {region, origPrevBlock}));
1389 }
1390 
// Record a Create action so that the block is deleted if the conversion is
// rolled back.
// NOTE(review): source line 1391 (the signature) is missing from this extract.
1392  blockActions.push_back(BlockAction::getCreate(block));
1393 }
1394 
// Record a Split action keyed on `continuation`; undoing it splices the
// continuation's operations back onto the end of `block`.
// NOTE(review): source line 1395 (the first line of the signature) is missing
// from this extract.
1396  Block *continuation) {
1397  blockActions.push_back(BlockAction::getSplit(continuation, block));
1398 }
1399 
// Record an Inline action for moving `srcBlock`'s operations into `block`
// before `before`. It must be recorded before the splice happens.
// NOTE(review): source line 1400 (the first line of the signature) is missing
// from this extract.
1401  Block *block, Block *srcBlock, Block::iterator before) {
1402  blockActions.push_back(BlockAction::getInline(block, srcBlock, before));
1403 }
1404 
// Record one Move action per block of `region`, from the last block to the
// first. Each action remembers the block that preceded it in `region`; the
// front block gets nullptr. Undoing them newest-first therefore rebuilds
// the region from the front. `parent` and `before` are not needed for this.
// NOTE(review): source line 1405 (the first line of the signature) is missing
// from this extract.
1406  Region &region, Region &parent, Region::iterator before) {
1407  if (region.empty())
1408  return;
1409  Block *laterBlock = &region.back();
1410  for (auto &earlierBlock : llvm::drop_begin(llvm::reverse(region), 1)) {
1411  blockActions.push_back(
1412  BlockAction::getMove(laterBlock, {&region, &earlierBlock}));
1413  laterBlock = &earlierBlock;
1414  }
1415  blockActions.push_back(BlockAction::getMove(laterBlock, {&region, nullptr}));
1416 }
1417 
// Report a pattern match failure at `loc` and always return failure(). The
// whole reporting body, including `reasonCallback` and `notifyCallback`, is
// inside LLVM_DEBUG. It therefore runs only in builds where LLVM_DEBUG output
// is enabled; otherwise the function just returns failure().
// NOTE(review): if users rely on notifyCallback to collect failure messages in
// release builds, they will receive nothing — confirm this is intended.
// NOTE(review): source line 1418 (the first line of the signature) is missing
// from this extract.
1419  Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
1420  LLVM_DEBUG({
  // NOTE(review): source line 1421 is missing; it presumably declares the
  // Diagnostic `diag` at `loc` that is used below.
1422  reasonCallback(diag);
1423  logger.startLine() << "** Failure : " << diag.str() << "\n";
1424  if (notifyCallback)
  // NOTE(review): source line 1425 (the body of the `if` above, presumably
  // `notifyCallback(diag);`) is missing from this extract.
1426  });
1427  return failure();
1428 }
1429 
1430 //===----------------------------------------------------------------------===//
1431 // ConversionPatternRewriter
1432 //===----------------------------------------------------------------------===//
1433 
// Create the private implementation, which keeps a reference to this rewriter,
// and register this rewriter as its own listener. Presumably this is so that
// builder notifications such as notifyOperationInserted reach this class.
// NOTE(review): source line 1434 (the first line of the signature) is missing
// from this extract.
1435  : PatternRewriter(ctx),
1436  impl(new detail::ConversionPatternRewriterImpl(*this)) {
1437  setListener(this);
1438 }
1439 
1441 
// Partial replacement is not supported during dialect conversion; any call
// reaches llvm_unreachable.
// NOTE(review): source line 1442 (the first line of the signature) is missing
// from this extract.
1443  Operation *op, ValueRange newValues, bool *allUsesReplaced,
1444  llvm::unique_function<bool(OpOperand &) const> functor) {
1445  // TODO: To support this we will need to rework a bit of how replacements are
1446  // tracked, given that this isn't guaranteed to replace all of the uses of an
1447  // operation. The main change is that now an operation can be replaced
1448  // multiple times, in parts. The current "set" based tracking is mainly useful
1449  // for tracking if a replaced operation should be ignored, i.e. if all of the
1450  // uses will be replaced.
1451  llvm_unreachable(
1452  "replaceOpWithIf is currently not supported by DialectConversion");
1453 }
1454 
// Log the request and record `op`'s replacement with `newValues`. Uses are
// rewritten and `op` is erased only when rewrites are applied.
// NOTE(review): source line 1455 (the signature) is missing from this extract.
1456  LLVM_DEBUG({
1457  impl->logger.startLine()
1458  << "** Replace : '" << op->getName() << "'(" << op << ")\n";
1459  });
1460  impl->notifyOpReplaced(op, newValues);
1461 }
1462 
// Log the request and record `op` as replaced by one null value per result,
// which marks every result as dropped. The operation is erased only when
// rewrites are applied.
// NOTE(review): source line 1463 (the signature) is missing from this extract.
1464  LLVM_DEBUG({
1465  impl->logger.startLine()
1466  << "** Erase : '" << op->getName() << "'(" << op << ")\n";
1467  });
1468  SmallVector<Value, 1> nullRepls(op->getNumResults(), nullptr);
1469  impl->notifyOpReplaced(op, nullRepls);
1470 }
1471 
// Record the block's erasure (position first, so it can be restored), schedule
// each of its operations for erasure, and unlink the block from its region
// without deleting it.
// NOTE(review): source line 1472 (the signature) is missing from this extract.
1473  impl->notifyBlockIsBeingErased(block);
1474 
1475  // Mark all ops for erasure.
1476  for (Operation &op : *block)
1477  eraseOp(&op);
1478 
1479  // Unlink the block from its parent region. The block is kept in the block
1480  // action and will be actually destroyed when rewrites are applied. This
1481  // allows us to keep the operations in the block live and undo the removal by
1482  // re-inserting the block.
1483  block->getParent()->getBlocks().remove(block);
1484 }
1485 
// Forward to the implementation: apply `conversion` to the entry block of
// `region`, using `converter` for materializations if it is not null.
// NOTE(review): source line 1486 (the first line of the signature) is missing
// from this extract.
1487  Region *region, TypeConverter::SignatureConversion &conversion,
1488  TypeConverter *converter) {
1489  return impl->applySignatureConversion(region, conversion, converter);
1490 }
1491 
// Forward to the implementation: convert the block argument types of every
// block in `region`, optionally with an explicit entry-block conversion.
// NOTE(review): source line 1492 (the first line of the signature) is missing
// from this extract.
1493  Region *region, TypeConverter &converter,
1494  TypeConverter::SignatureConversion *entryConversion) {
1495  return impl->convertRegionTypes(region, converter, entryConversion);
1496 }
1497 
// Forward to the implementation: convert the block argument types of every
// non-entry block in `region`.
// NOTE(review): source lines 1498 and 1500 (start and end of the signature) are
// missing from this extract.
1499  Region *region, TypeConverter &converter,
1501  return impl->convertNonEntryRegionTypes(region, converter, blockConversions);
1502 }
1503 
// Record that uses of block argument `from` should be replaced by `to`. Nothing
// is rewritten here. The argument is queued in argReplacements, and
// whatever `from` currently maps to (via lookupOrDefault) is mapped to `to`.
// applyRewrites performs the replacement later.
// NOTE(review): source line 1504 (the first line of the signature) is missing
// from this extract.
1505  Value to) {
1506  LLVM_DEBUG({
  // NOTE(review): the message's parentheses do not balance, and the parent op
  // is computed twice (`parentOp` could be reused on the last line).
1507  Operation *parentOp = from.getOwner()->getParentOp();
1508  impl->logger.startLine() << "** Replace Argument : '" << from
1509  << "'(in region of '" << parentOp->getName()
1510  << "'(" << from.getOwner()->getParentOp() << ")\n";
1511  });
1512  impl->argReplacements.push_back(from);
1513  impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
1514 }
1515 
// Return the current replacement for the single value `key`, or nullptr if it
// cannot be remapped. On success remapValues pushes one entry per input value,
// so `remappedValues` holds exactly one element.
// NOTE(review): source line 1516 (the signature) is missing from this extract.
1517  SmallVector<Value> remappedValues;
1518  if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key,
1519  remappedValues)))
1520  return nullptr;
1521  return remappedValues.front();
1522 }
1523 
// Append the current replacements for `keys` to `results`. An empty `keys`
// succeeds immediately without consulting the implementation.
// NOTE(review): source lines 1524-1525 (the start of the signature) are missing
// from this extract.
1526  SmallVectorImpl<Value> &results) {
1527  if (keys.empty())
1528  return success();
1529  return impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys,
1530  results);
1531 }
1532 
// Record the new block so it can be removed on rollback.
// NOTE(review): source line 1533 (the signature) is missing from this extract.
1534  impl->notifyCreatedBlock(block);
1535 }
1536 
// Split `block` at `before` using the base PatternRewriter, then record the
// split so it can be merged back on rollback. Returns the continuation block.
// NOTE(review): source line 1537 (the first line of the signature) is missing
// from this extract.
1538  Block::iterator before) {
1539  auto *continuation = PatternRewriter::splitBlock(block, before);
1540  impl->notifySplitBlock(block, continuation);
1541  return continuation;
1542 }
1543 
// Move all operations of `source` into `dest` before `before`, then erase
// `source`. Uses of `source`'s arguments are redirected to `argValues`
// through replaceUsesOfBlockArgument, so the replacement is recorded and
// can be rolled back.
// NOTE(review): source line 1544 (the first line of the signature) is missing
// from this extract.
1545  Block::iterator before,
1546  ValueRange argValues) {
1547  assert(argValues.size() == source->getNumArguments() &&
1548  "incorrect # of argument replacement values");
// `opIgnored` is only referenced inside the assert below, so it is only defined
// when asserts are enabled.
1549 #ifndef NDEBUG
1550  auto opIgnored = [&](Operation *op) { return impl->isOpIgnored(op); };
1551 #endif // NDEBUG
1552  // The source block will be deleted, so it should not have any live users
1553  // (i.e., no predecessors). Users that are themselves ignored (already
  // replaced/erased, or nested in such an op) are tolerated.
1554  assert(llvm::all_of(source->getUsers(), opIgnored) &&
1555  "expected 'source' to have no predecessors");
1556 
1557  impl->notifyBlockBeingInlined(dest, source, before);
1558  for (auto it : llvm::zip(source->getArguments(), argValues))
1559  replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it));
1560  dest->getOperations().splice(before, source->getOperations());
// `source` is empty after the splice, so eraseBlock only records and unlinks
// the block itself.
1561  eraseBlock(source);
1562 }
1563 
// Record the block moves first, while the blocks are still in `region`, and
// only then move them into `parent` before `before`.
// NOTE(review): source line 1564 (the first line of the signature) is missing
// from this extract.
1565  Region &parent,
1566  Region::iterator before) {
1567  impl->notifyRegionIsBeingInlinedBefore(region, parent, before);
1568  PatternRewriter::inlineRegionBefore(region, parent, before);
1569 }
1570 
// Clone `region` into `parent` before `before` using the base PatternRewriter.
// Then visit the source blocks in forward-dominance order, look up each clone
// in `mapping`, and report the cloned block as created and its operations as
// inserted, so they are tracked for rollback.
// NOTE(review): source line 1571 (the first line of the signature) is missing
// from this extract.
1572  Region &parent,
1573  Region::iterator before,
1574  IRMapping &mapping) {
1575  if (region.empty())
1576  return;
1577 
1578  PatternRewriter::cloneRegionBefore(region, parent, before, mapping);
1579 
1580  for (Block &b : ForwardDominanceIterator<>::makeIterable(region)) {
1581  Block *cloned = mapping.lookup(&b);
1582  impl->notifyCreatedBlock(cloned);
  // NOTE(review): source line 1583 is missing; it presumably starts a walk over
  // `cloned` whose callback is the lambda below.
1584  [&](Operation *op) { notifyOperationInserted(op); });
1585  }
1586 }
1587 
// Log the insertion and append `op` to createdOps, so that it is removed again
// if the conversion or the current pattern is rolled back.
// NOTE(review): source line 1588 (the signature) is missing from this extract.
1589  LLVM_DEBUG({
1590  impl->logger.startLine()
1591  << "** Insert : '" << op->getName() << "'(" << op << ")\n";
1592  });
1593  impl->createdOps.push_back(op);
1594 }
1595 
// Begin an in-place update of `op` by appending a transaction state for it to
// rootUpdates, which is used for rollback. Debug builds also track `op` as
// having a pending update.
// NOTE(review): source line 1596 (the signature) is missing from this extract.
1597 #ifndef NDEBUG
1598  impl->pendingRootUpdates.insert(op);
1599 #endif
1600  impl->rootUpdates.emplace_back(op);
1601 }
1602 
// Finish an in-place update of `op`. Only the debug-build pending-update
// bookkeeping changes; the transaction state stays in rootUpdates.
// NOTE(review): source lines 1603-1604 (the signature) are missing from this
// extract.
1605  // There is nothing to do here, we only need to track the operation at the
1606  // start of the update.
1607 #ifndef NDEBUG
1608  assert(impl->pendingRootUpdates.erase(op) &&
1609  "operation did not have a pending in-place update");
1610 #endif
1611 }
1612 
// Abort the most recent in-place update of `op`: reset the op via its newest
// transaction state in rootUpdates, then remove that entry.
// NOTE(review): source line 1613 (the signature) is missing from this extract.
1614 #ifndef NDEBUG
1615  assert(impl->pendingRootUpdates.erase(op) &&
1616  "operation did not have a pending in-place update");
1617 #endif
1618  // Erase the last update for this operation.
1619  auto stateHasOp = [op](const auto &it) { return it.getOperation() == op; };
1620  auto &rootUpdates = impl->rootUpdates;
1621  auto it = llvm::find_if(llvm::reverse(rootUpdates), stateHasOp);
1622  assert(it != rootUpdates.rend() && "no root update started on op");
1623  (*it).resetOperation();
// Convert the reverse iterator into a forward index: the distance from the
// reverse iterator of element 0 (prev(rend())) to `it` is the element's index.
1624  int updateIdx = std::prev(rootUpdates.rend()) - it;
1625  rootUpdates.erase(rootUpdates.begin() + updateIdx);
1626 }
1627 
// Forward the match failure to the implementation, which always returns
// failure().
// NOTE(review): source line 1628 (the first line of the signature) is missing
// from this extract.
1629  Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
1630  return impl->notifyMatchFailure(loc, reasonCallback);
1631 }
1632 
// Return the private implementation that holds all conversion bookkeeping.
// NOTE(review): source line 1633 (the signature) is missing from this extract.
1634  return *impl;
1635 }
1636 
1637 //===----------------------------------------------------------------------===//
1638 // ConversionPattern
1639 //===----------------------------------------------------------------------===//
1640 
// Generic PatternRewriter entry point for conversion patterns. It installs this
// pattern's type converter for the duration of the call, remaps `op`'s operands
// to their current (converted) values, and then calls the operand-aware
// matchAndRewrite overload. Fails without invoking the pattern if the operands
// cannot be remapped.
// NOTE(review): `rewriter` is static_cast to ConversionPatternRewriter, so it is
// assumed to always be one here; confirm at the call sites.
// NOTE(review): source lines 1641-1642 (the signature) are missing from this
// extract.
1643  PatternRewriter &rewriter) const {
1644  auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter);
1645  auto &rewriterImpl = dialectRewriter.getImpl();
1646 
1647  // Track the current conversion pattern type converter in the rewriter.
// SaveAndRestore puts the previous converter back when this scope exits.
1648  llvm::SaveAndRestore currentConverterGuard(rewriterImpl.currentTypeConverter,
1649  getTypeConverter());
1650 
1651  // Remap the operands of the operation.
1652  SmallVector<Value, 4> operands;
1653  if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter,
1654  op->getOperands(), operands))) {
1655  return failure();
1656  }
1657  return matchAndRewrite(op, operands, dialectRewriter);
1658 }
1659 
1660 //===----------------------------------------------------------------------===//
1661 // OperationLegalizer
1662 //===----------------------------------------------------------------------===//
1663 
1664 namespace {
1665 /// A set of rewrite patterns that can be used to legalize a given operation.
1666 using LegalizationPatterns = SmallVector<const Pattern *, 1>;
1667 
1668 /// This class defines a recursive operation legalizer.
/// An operation is first checked against the ConversionTarget; if it is not
/// legal the legalizer tries to fold it and then to apply one of the rewrite
/// patterns. Operations created or updated in place by a successful pattern,
/// and parent operations affected by its block actions, are in turn legalized
/// recursively.
1669 class OperationLegalizer {
1670 public:
1671  using LegalizationAction = ConversionTarget::LegalizationAction;
1672 
  /// Builds the legalization graph for `patterns` and orders them with the
  /// depth/benefit cost model; this happens once, at construction.
1673  OperationLegalizer(ConversionTarget &targetInfo,
1674  const FrozenRewritePatternSet &patterns);
1675 
1676  /// Returns true if the given operation is known to be illegal on the target.
1677  bool isIllegal(Operation *op) const;
1678 
1679  /// Attempt to legalize the given operation. Returns success if the operation
1680  /// was legalized, failure otherwise.
1681  LogicalResult legalize(Operation *op, ConversionPatternRewriter &rewriter);
1682 
1683  /// Returns the conversion target in use by the legalizer.
1684  ConversionTarget &getTarget() { return target; }
1685 
1686 private:
1687  /// Attempt to legalize the given operation by folding it.
1688  LogicalResult legalizeWithFold(Operation *op,
1689  ConversionPatternRewriter &rewriter);
1690 
1691  /// Attempt to legalize the given operation by applying a pattern. Returns
1692  /// success if the operation was legalized, failure otherwise.
1693  LogicalResult legalizeWithPattern(Operation *op,
1694  ConversionPatternRewriter &rewriter);
1695 
1696  /// Return true if the given pattern may be applied to the given operation,
1697  /// false otherwise.
  /// A pattern that does not declare bounded rewrite recursion is recorded in
  /// `appliedPatterns` and rejected if it is already recorded there, which
  /// prevents it from re-entering itself within one recursion stack.
1698  bool canApplyPattern(Operation *op, const Pattern &pattern,
1699  ConversionPatternRewriter &rewriter);
1700 
1701  /// Legalize the resultant IR after successfully applying the given pattern.
1702  LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
1703  ConversionPatternRewriter &rewriter,
1704  RewriterState &curState);
1705 
1706  /// Legalizes the actions registered during the execution of a pattern.
1707  LogicalResult legalizePatternBlockActions(Operation *op,
1708  ConversionPatternRewriter &rewriter,
1710  RewriterState &state,
1711  RewriterState &newState);
  /// Legalizes each operation created by a pattern, i.e. the created ops
  /// recorded between `state` and `newState`.
1712  LogicalResult legalizePatternCreatedOperations(
1714  RewriterState &state, RewriterState &newState);
  /// Legalizes each operation a pattern updated in place, i.e. the root
  /// updates recorded between `state` and `newState`.
1715  LogicalResult legalizePatternRootUpdates(ConversionPatternRewriter &rewriter,
1717  RewriterState &state,
1718  RewriterState &newState);
1719 
1720  //===--------------------------------------------------------------------===//
1721  // Cost Model
1722  //===--------------------------------------------------------------------===//
1723 
1724  /// Build an optimistic legalization graph given the provided patterns. This
1725  /// function populates 'anyOpLegalizerPatterns' and 'legalizerPatterns' with
1726  /// patterns for operations that are not directly legal, but may be
1727  /// transitively legal for the current target given the provided patterns.
1728  void buildLegalizationGraph(
1729  LegalizationPatterns &anyOpLegalizerPatterns,
1731 
1732  /// Compute the benefit of each node within the computed legalization graph.
1733  /// This orders the patterns within 'legalizerPatterns' based upon two
1734  /// criteria:
1735  /// 1) Prefer patterns that have the lowest legalization depth, i.e.
1736  /// represent the more direct mapping to the target.
1737  /// 2) When comparing patterns with the same legalization depth, prefer the
1738  /// pattern with the highest PatternBenefit. This allows for users to
1739  /// prefer specific legalizations over others.
1740  void computeLegalizationGraphBenefit(
1741  LegalizationPatterns &anyOpLegalizerPatterns,
1743 
1744  /// Compute the legalization depth when legalizing an operation of the given
1745  /// type.
  /// Results are memoized in `minOpPatternDepth`; an operation with no
  /// legalization patterns has depth 0.
1746  unsigned computeOpLegalizationDepth(
1747  OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
1749 
1750  /// Apply the conversion cost model to the given set of patterns, and return
1751  /// the smallest legalization depth of any of the patterns. See
1752  /// `computeLegalizationGraphBenefit` for the breakdown of the cost model.
1753  unsigned applyCostModelToPatterns(
1754  LegalizationPatterns &patterns,
1755  DenseMap<OperationName, unsigned> &minOpPatternDepth,
1757 
1758  /// The current set of patterns that have been applied.
  /// Entries are added by canApplyPattern and removed again once the pattern
  /// has either failed or finished legalizing its results.
1759  SmallPtrSet<const Pattern *, 8> appliedPatterns;
1760 
1761  /// The legalization information provided by the target.
1762  ConversionTarget &target;
1763 
1764  /// The pattern applicator to use for conversions.
1765  PatternApplicator applicator;
1766 };
1767 } // namespace
1768 
/// Constructs the legalizer for `targetInfo` and precomputes, once, which of
/// `patterns` may transitively produce legal IR and in what order they should
/// be attempted.
1769 OperationLegalizer::OperationLegalizer(ConversionTarget &targetInfo,
1770  const FrozenRewritePatternSet &patterns)
1771  : target(targetInfo), applicator(patterns) {
1772  // The set of patterns that can be applied to illegal operations to transform
1773  // them into legal ones.
  // Patterns without a specific root operation; they may match any operation.
1775  LegalizationPatterns anyOpLegalizerPatterns;
1776 
  // Determine the usable patterns first, then order them (and the applicator)
  // by legalization depth and pattern benefit.
1777  buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
1778  computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
1779 }
1780 
1781 bool OperationLegalizer::isIllegal(Operation *op) const {
1782  return target.isIllegal(op);
1783 }
1784 
/// Attempts to legalize `op`. The following checks run in order, and the first
/// one that succeeds ends the attempt with success:
///   1. the target marks `op` legal (if it is recursively legal, the ops
///      nested inside it are marked ignored so they are not legalized);
///   2. `op` was marked ignored earlier in this conversion;
///   3. `op` can be folded (legalizeWithFold);
///   4. a legalization pattern applies to `op` (legalizeWithPattern).
/// Returns failure if none of them succeed.
1786 OperationLegalizer::legalize(Operation *op,
1787  ConversionPatternRewriter &rewriter) {
// Debug-only state: both names are referenced only from LLVM_DEBUG bodies,
// which are compiled out when NDEBUG is defined.
1788 #ifndef NDEBUG
1789  const char *logLineComment =
1790  "//===-------------------------------------------===//\n";
1791 
1792  auto &logger = rewriter.getImpl().logger;
1793 #endif
  // Open a log scope for this op; every return path below closes it through
  // logSuccess/logFailure, which unindent the logger.
1794  LLVM_DEBUG({
1795  logger.getOStream() << "\n";
1796  logger.startLine() << logLineComment;
1797  logger.startLine() << "Legalizing operation : '" << op->getName() << "'("
1798  << op << ") {\n";
1799  logger.indent();
1800 
1801  // If the operation has no regions, just print it here.
1802  if (op->getNumRegions() == 0) {
1803  op->print(logger.startLine(), OpPrintingFlags().printGenericOpForm());
1804  logger.getOStream() << "\n\n";
1805  }
1806  });
1807 
1808  // Check if this operation is legal on the target.
1809  if (auto legalityInfo = target.isLegal(op)) {
1810  LLVM_DEBUG({
1811  logSuccess(
1812  logger, "operation marked legal by the target{0}",
1813  legalityInfo->isRecursivelyLegal
1814  ? "; NOTE: operation is recursively legal; skipping internals"
1815  : "");
1816  logger.startLine() << logLineComment;
1817  });
1818 
1819  // If this operation is recursively legal, mark its children as ignored so
1820  // that we don't consider them for legalization.
1821  if (legalityInfo->isRecursivelyLegal)
1822  rewriter.getImpl().markNestedOpsIgnored(op);
1823  return success();
1824  }
1825 
1826  // Check to see if the operation is ignored and doesn't need to be converted.
1827  if (rewriter.getImpl().isOpIgnored(op)) {
1828  LLVM_DEBUG({
1829  logSuccess(logger, "operation marked 'ignored' during conversion");
1830  logger.startLine() << logLineComment;
1831  });
1832  return success();
1833  }
1834 
1835  // If the operation isn't legal, try to fold it in-place.
1836  // TODO: Should we always try to do this, even if the op is
1837  // already legal?
1838  if (succeeded(legalizeWithFold(op, rewriter))) {
1839  LLVM_DEBUG({
1840  logSuccess(logger, "operation was folded");
1841  logger.startLine() << logLineComment;
1842  });
1843  return success();
1844  }
1845 
1846  // Otherwise, we need to apply a legalization pattern to this operation.
1847  if (succeeded(legalizeWithPattern(op, rewriter))) {
1848  LLVM_DEBUG({
1849  logSuccess(logger, "");
1850  logger.startLine() << logLineComment;
1851  });
1852  return success();
1853  }
1854 
1855  LLVM_DEBUG({
1856  logFailure(logger, "no matched legalization pattern");
1857  logger.startLine() << logLineComment;
1858  });
1859  return failure();
1860 }
1861 
/// Attempts to legalize `op` by folding it. When the fold succeeds, `op` is
/// replaced with the fold results and every operation created since entry
/// (e.g. materialized constants) is legalized recursively. If one of those
/// cannot be legalized, the rewriter is reset to the state captured on entry
/// and failure is returned.
1863 OperationLegalizer::legalizeWithFold(Operation *op,
1864  ConversionPatternRewriter &rewriter) {
1865  auto &rewriterImpl = rewriter.getImpl();
1866  RewriterState curState = rewriterImpl.getCurrentState();
1867 
  // The "{" scope opened here is closed by the logFailure/logSuccess calls
  // below, which unindent the logger.
1868  LLVM_DEBUG({
1869  rewriterImpl.logger.startLine() << "* Fold {\n";
1870  rewriterImpl.logger.indent();
1871  });
1872 
1873  // Try to fold the operation.
1874  SmallVector<Value, 2> replacementValues;
1875  rewriter.setInsertionPoint(op);
1876  if (failed(rewriter.tryFold(op, replacementValues))) {
1877  LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
1878  return failure();
1879  }
1880 
1881  // Insert a replacement for 'op' with the folded replacement values.
1882  rewriter.replaceOp(op, replacementValues);
1883 
1884  // Recursively legalize any new constant operations.
  // `e` is captured before the loop, so operations created by the nested
  // legalize() calls are not revisited here.
1885  for (unsigned i = curState.numCreatedOps, e = rewriterImpl.createdOps.size();
1886  i != e; ++i) {
1887  Operation *cstOp = rewriterImpl.createdOps[i];
1888  if (failed(legalize(cstOp, rewriter))) {
1889  LLVM_DEBUG(logFailure(rewriterImpl.logger,
1890  "failed to legalize generated constant '{0}'",
1891  cstOp->getName()));
1892  rewriterImpl.resetState(curState);
1893  return failure();
1894  }
1895  }
1896 
1897  LLVM_DEBUG(logSuccess(rewriterImpl.logger, ""));
1898  return success();
1899 }
1900 
/// Attempts to legalize `op` by letting the pattern applicator try the ordered
/// legalization patterns. Three callbacks steer it:
///  - `canApply` rejects patterns that would re-enter themselves unboundedly;
///  - `onFailure` rolls the rewriter back to the state captured before
///    matching started;
///  - `onSuccess` legalizes what the pattern produced and rolls back if that
///    fails.
/// Both `onFailure` and `onSuccess` remove the pattern from `appliedPatterns`.
1902 OperationLegalizer::legalizeWithPattern(Operation *op,
1903  ConversionPatternRewriter &rewriter) {
1904  auto &rewriterImpl = rewriter.getImpl();
1905 
1906  // Functor that returns if the given pattern may be applied.
1907  auto canApply = [&](const Pattern &pattern) {
1908  return canApplyPattern(op, pattern, rewriter);
1909  };
1910 
1911  // Functor that cleans up the rewriter state after a pattern failed to match.
1912  RewriterState curState = rewriterImpl.getCurrentState();
1913  auto onFailure = [&](const Pattern &pattern) {
    // NOTE(review): notifyCallback is only reached from inside LLVM_DEBUG, so
    // it never fires in NDEBUG builds, nor in debug builds unless this
    // DEBUG_TYPE is enabled at runtime. Confirm this is intended.
1914  LLVM_DEBUG({
1915  logFailure(rewriterImpl.logger, "pattern failed to match");
1916  if (rewriterImpl.notifyCallback) {
1918  diag << "Failed to apply pattern \"" << pattern.getDebugName()
1919  << "\" on op:\n"
1920  << *op;
1921  rewriterImpl.notifyCallback(diag);
1922  }
1923  });
1924  rewriterImpl.resetState(curState);
1925  appliedPatterns.erase(&pattern);
1926  };
1927 
1928  // Functor that performs additional legalization when a pattern is
1929  // successfully applied.
1930  auto onSuccess = [&](const Pattern &pattern) {
1931  auto result = legalizePatternResult(op, pattern, rewriter, curState);
1932  appliedPatterns.erase(&pattern);
1933  if (failed(result))
1934  rewriterImpl.resetState(curState);
1935  return result;
1936  };
1937 
1938  // Try to match and rewrite a pattern on this operation.
1939  return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
1940  onSuccess);
1941 }
1942 
1943 bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
1944  ConversionPatternRewriter &rewriter) {
1945  LLVM_DEBUG({
1946  auto &os = rewriter.getImpl().logger;
1947  os.getOStream() << "\n";
1948  os.startLine() << "* Pattern : '" << op->getName() << " -> (";
1949  llvm::interleaveComma(pattern.getGeneratedOps(), os.getOStream());
1950  os.getOStream() << ")' {\n";
1951  os.indent();
1952  });
1953 
1954  // Ensure that we don't cycle by not allowing the same pattern to be
1955  // applied twice in the same recursion stack if it is not known to be safe.
1956  if (!pattern.hasBoundedRewriteRecursion() &&
1957  !appliedPatterns.insert(&pattern).second) {
1958  LLVM_DEBUG(
1959  logFailure(rewriter.getImpl().logger, "pattern was already applied"));
1960  return false;
1961  }
1962  return true;
1963 }
1964 
/// Called after `pattern` successfully rewrote `op`. In asserts-enabled builds
/// it first checks that the pattern replaced `op` or updated it in place. It
/// then legalizes, in this order, the block actions, in-place root updates and
/// created operations recorded since `curState`. Rolling back on failure is
/// left to the caller.
1966 OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
1967  ConversionPatternRewriter &rewriter,
1968  RewriterState &curState) {
1969  auto &impl = rewriter.getImpl();
1970 
  // NOTE(review): the NDEBUG guard suggests `pendingRootUpdates` only exists in
  // asserts-enabled builds. Confirm against the impl declaration.
1971 #ifndef NDEBUG
1972  assert(impl.pendingRootUpdates.empty() && "dangling root updates");
1973 #endif
1974 
1975  // Check that the root was either replaced or updated in place.
1976  auto replacedRoot = [&] {
1977  return llvm::any_of(
1978  llvm::drop_begin(impl.replacements, curState.numReplacements),
1979  [op](auto &it) { return it.first == op; });
1980  };
1981  auto updatedRootInPlace = [&] {
1982  return llvm::any_of(
1983  llvm::drop_begin(impl.rootUpdates, curState.numRootUpdates),
1984  [op](auto &state) { return state.getOperation() == op; });
1985  };
  // These casts keep the lambdas "used" when assert() compiles away (NDEBUG).
1986  (void)replacedRoot;
1987  (void)updatedRootInPlace;
1988  assert((replacedRoot() || updatedRootInPlace()) &&
1989  "expected pattern to replace the root operation");
1990 
1991  // Legalize each of the actions registered during application.
  // `||` short-circuits, so later phases are skipped once one of them fails.
1992  RewriterState newState = impl.getCurrentState();
1993  if (failed(legalizePatternBlockActions(op, rewriter, impl, curState,
1994  newState)) ||
1995  failed(legalizePatternRootUpdates(rewriter, impl, curState, newState)) ||
1996  failed(legalizePatternCreatedOperations(rewriter, impl, curState,
1997  newState))) {
1998  return failure();
1999  }
2000 
2001  LLVM_DEBUG(logSuccess(impl.logger, "pattern applied successfully"));
2002  return success();
2003 }
2004 
/// Re-legalizes IR affected by the block actions recorded between `state` and
/// `newState`. An action is skipped when any of the following holds:
///  - it is a type-conversion or erase action;
///  - its block has no parent op, or the parent op is `op` itself;
///  - its block has no arguments.
/// For each remaining block:
///  - if a type converter is registered for the block's region, the block
///    signature is converted directly;
///  - otherwise the parent operation is re-legalized, unless it was created by
///    this pattern (it is legalized with the created ops) or was already
///    handled in this call.
2005 LogicalResult OperationLegalizer::legalizePatternBlockActions(
2006  Operation *op, ConversionPatternRewriter &rewriter,
2007  ConversionPatternRewriterImpl &impl, RewriterState &state,
2008  RewriterState &newState) {
  // Lazily built: ops created by this pattern, plus parent ops already visited.
2009  SmallPtrSet<Operation *, 16> operationsToIgnore;
2010 
2011  // If the pattern moved or created any blocks, make sure the types of block
2012  // arguments get legalized.
2013  for (int i = state.numBlockActions, e = newState.numBlockActions; i != e;
2014  ++i) {
2015  auto &action = impl.blockActions[i];
2016  if (action.kind == BlockActionKind::TypeConversion ||
2017  action.kind == BlockActionKind::Erase)
2018  continue;
2019  // Only check blocks outside of the current operation.
    // (Blocks whose parent is `op` itself are skipped; argument-less blocks
    // have no types to legalize.)
2020  Operation *parentOp = action.block->getParentOp();
2021  if (!parentOp || parentOp == op || action.block->getNumArguments() == 0)
2022  continue;
2023 
2024  // If the region of the block has a type converter, try to convert the block
2025  // directly.
2026  if (auto *converter =
2027  impl.argConverter.getConverter(action.block->getParent())) {
2028  if (failed(impl.convertBlockSignature(action.block, converter))) {
2029  LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
2030  "block"));
2031  return failure();
2032  }
2033  continue;
2034  }
2035 
2036  // Otherwise, check that this operation isn't one generated by this pattern.
2037  // This is because we will attempt to legalize the parent operation, and
2038  // blocks in regions created by this pattern will already be legalized later
2039  // on. If we haven't built the set yet, build it now.
    // An empty set means "not built yet": after the first build the parent op
    // is always inserted below, so the set never becomes empty again.
2040  if (operationsToIgnore.empty()) {
2041  auto createdOps = ArrayRef<Operation *>(impl.createdOps)
2042  .drop_front(state.numCreatedOps);
2043  operationsToIgnore.insert(createdOps.begin(), createdOps.end());
2044  }
2045 
2046  // If this operation should be considered for re-legalization, try it.
    // insert().second is false for created ops and already-visited parents.
2047  if (operationsToIgnore.insert(parentOp).second &&
2048  failed(legalize(parentOp, rewriter))) {
2049  LLVM_DEBUG(logFailure(
2050  impl.logger, "operation '{0}'({1}) became illegal after block action",
2051  parentOp->getName(), parentOp));
2052  return failure();
2053  }
2054  }
2055  return success();
2056 }
2057 
/// Legalizes every operation the pattern created, i.e. `impl.createdOps` in
/// the range [`state.numCreatedOps`, `newState.numCreatedOps`). The range is
/// fixed by `newState`, so operations created by the nested legalize() calls
/// are not visited here. Stops at the first failure.
2058 LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
2060  RewriterState &state, RewriterState &newState) {
2061  for (int i = state.numCreatedOps, e = newState.numCreatedOps; i != e; ++i) {
2062  Operation *op = impl.createdOps[i];
2063  if (failed(legalize(op, rewriter))) {
2064  LLVM_DEBUG(logFailure(impl.logger,
2065  "failed to legalize generated operation '{0}'({1})",
2066  op->getName(), op));
2067  return failure();
2068  }
2069  }
2070  return success();
2071 }
2072 
/// Legalizes every operation the pattern modified in place, i.e. the root
/// updates in the range [`state.numRootUpdates`, `newState.numRootUpdates`).
/// Stops at the first failure.
2073 LogicalResult OperationLegalizer::legalizePatternRootUpdates(
2075  RewriterState &state, RewriterState &newState) {
2076  for (int i = state.numRootUpdates, e = newState.numRootUpdates; i != e; ++i) {
2077  Operation *op = impl.rootUpdates[i].getOperation();
2078  if (failed(legalize(op, rewriter))) {
2079  LLVM_DEBUG(logFailure(
2080  impl.logger, "failed to legalize operation updated in-place '{0}'",
2081  op->getName()));
2082  return failure();
2083  }
2084  }
2085  return success();
2086 }
2087 
2088 //===----------------------------------------------------------------------===//
2089 // Cost Model
2090 
/// Computes which patterns may transitively produce legal IR.
///
/// Patterns with no root kind go to `anyOpLegalizerPatterns`. Patterns whose
/// root is always Legal are dropped. All other patterns start out "invalid".
/// If any root-less pattern exists, every rooted pattern is accepted as-is.
/// Otherwise a worklist runs to a fixpoint: a pattern is accepted into
/// `legalizerPatterns` once none of the ops it generates is both without
/// accepted patterns and Illegal (or unknown) to the target. Accepting a
/// pattern re-queues the still-invalid patterns of every op that may generate
/// its root.
2091 void OperationLegalizer::buildLegalizationGraph(
2092  LegalizationPatterns &anyOpLegalizerPatterns,
2093  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2094  // A mapping between an operation and a set of operations that can be used to
2095  // generate it.
2097  // A mapping between an operation and any currently invalid patterns it has.
2099  // A worklist of patterns to consider for legality.
2100  SetVector<const Pattern *> patternWorklist;
2101 
2102  // Build the mapping from operations to the parent ops that may generate them.
2103  applicator.walkAllPatterns([&](const Pattern &pattern) {
2104  std::optional<OperationName> root = pattern.getRootKind();
2105 
2106  // If the pattern has no specific root, we can't analyze the relationship
2107  // between the root op and generated operations. Given that, add all such
2108  // patterns to the legalization set.
2109  if (!root) {
2110  anyOpLegalizerPatterns.push_back(&pattern);
2111  return;
2112  }
2113 
2114  // Skip operations that are always known to be legal.
2115  if (target.getOpAction(*root) == LegalizationAction::Legal)
2116  return;
2117 
2118  // Add this pattern to the invalid set for the root op and record this root
2119  // as a parent for any generated operations.
2120  invalidPatterns[*root].insert(&pattern);
2121  for (auto op : pattern.getGeneratedOps())
2122  parentOps[op].insert(*root);
2123 
2124  // Add this pattern to the worklist.
2125  patternWorklist.insert(&pattern);
2126  });
2127 
2128  // If there are any patterns that don't have a specific root kind, we can't
2129  // make direct assumptions about what operations will never be legalized.
2130  // Note: Technically we could, but it would require an analysis that may
2131  // recurse into itself. It would be better to perform this kind of filtering
2132  // at a higher level than here anyways.
2133  if (!anyOpLegalizerPatterns.empty()) {
2134  for (const Pattern *pattern : patternWorklist)
2135  legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
2136  return;
2137  }
2138 
2139  while (!patternWorklist.empty()) {
2140  auto *pattern = patternWorklist.pop_back_val();
2141 
2142  // Check to see if any of the generated operations are invalid.
2143  if (llvm::any_of(pattern->getGeneratedOps(), [&](OperationName op) {
2144  std::optional<LegalizationAction> action = target.getOpAction(op);
2145  return !legalizerPatterns.count(op) &&
2146  (!action || action == LegalizationAction::Illegal);
2147  }))
2148  continue;
2149 
2150  // Otherwise, all of the generated operations are valid, so this pattern
2151  // can legalize its root op: accept it and drop it from the invalid set.
2152  legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
2153  invalidPatterns[*pattern->getRootKind()].erase(pattern);
2154 
2155  // Add any invalid patterns of the parent operations to see if they have now
2156  // become legal.
2157  for (auto op : parentOps[*pattern->getRootKind()])
2158  patternWorklist.set_union(invalidPatterns[op]);
2159  }
2160 }
2161 
/// Orders every accepted pattern list by (legalization depth ascending,
/// PatternBenefit descending). It then installs a cost model on the applicator
/// that gives each pattern a benefit from its position in that order.
/// Patterns that are not in any accepted list get an impossible-to-match
/// benefit.
2162 void OperationLegalizer::computeLegalizationGraphBenefit(
2163  LegalizationPatterns &anyOpLegalizerPatterns,
2164  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2165  // The smallest pattern depth, when legalizing an operation.
2166  DenseMap<OperationName, unsigned> minOpPatternDepth;
2167 
2168  // For each operation that is transitively legal, compute a cost for it.
  // computeOpLegalizationDepth only looks entries up in `legalizerPatterns`
  // (it never inserts), so iterating the map here stays valid.
2169  for (auto &opIt : legalizerPatterns)
2170  if (!minOpPatternDepth.count(opIt.first))
2171  computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
2172  legalizerPatterns);
2173 
2174  // Apply the cost model to the patterns that can match any operation. Those
2175  // with a specific operation type are already resolved when computing the op
2176  // legalization depth.
2177  if (!anyOpLegalizerPatterns.empty())
2178  applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
2179  legalizerPatterns);
2180 
2181  // Apply a cost model to the pattern applicator. We order patterns first by
2182  // depth then benefit. `legalizerPatterns` contains per-op patterns by
2183  // decreasing benefit.
2184  applicator.applyCostModel([&](const Pattern &pattern) {
2185  ArrayRef<const Pattern *> orderedPatternList;
    // operator[] inserts an empty list for a root with no accepted patterns;
    // such a pattern is then simply not found below.
2186  if (std::optional<OperationName> rootName = pattern.getRootKind())
2187  orderedPatternList = legalizerPatterns[*rootName];
2188  else
2189  orderedPatternList = anyOpLegalizerPatterns;
2190 
2191  // If the pattern is not found, then it was removed and cannot be matched.
2192  auto *it = llvm::find(orderedPatternList, &pattern);
2193  if (it == orderedPatternList.end())
2195 
2196  // Patterns found earlier in the list have higher benefit.
2197  return PatternBenefit(std::distance(it, orderedPatternList.end()));
2198  });
2199 }
2200 
2201 unsigned OperationLegalizer::computeOpLegalizationDepth(
2202  OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
2203  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2204  // Check for existing depth.
2205  auto depthIt = minOpPatternDepth.find(op);
2206  if (depthIt != minOpPatternDepth.end())
2207  return depthIt->second;
2208 
2209  // If a mapping for this operation does not exist, then this operation
2210  // is always legal. Return 0 as the depth for a directly legal operation.
2211  auto opPatternsIt = legalizerPatterns.find(op);
2212  if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
2213  return 0u;
2214 
2215  // Record this initial depth in case we encounter this op again when
2216  // recursively computing the depth.
2217  minOpPatternDepth.try_emplace(op, std::numeric_limits<unsigned>::max());
2218 
2219  // Apply the cost model to the operation patterns, and update the minimum
2220  // depth.
2221  unsigned minDepth = applyCostModelToPatterns(
2222  opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
2223  minOpPatternDepth[op] = minDepth;
2224  return minDepth;
2225 }
2226 
/// Orders `patterns` in place and returns the smallest depth among them.
///
/// A pattern's depth is 1 plus the maximum legalization depth of the ops it
/// generates (just 1 if it generates none). Patterns are ordered by depth
/// ascending, then by PatternBenefit descending; stable_sort keeps the original
/// order among exact ties. With a single pattern no reordering is needed.
2227 unsigned OperationLegalizer::applyCostModelToPatterns(
2228  LegalizationPatterns &patterns,
2229  DenseMap<OperationName, unsigned> &minOpPatternDepth,
2230  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2231  unsigned minDepth = std::numeric_limits<unsigned>::max();
2232 
2233  // Compute the depth for each pattern within the set.
2235  patternsByDepth.reserve(patterns.size());
2236  for (const Pattern *pattern : patterns) {
2237  unsigned depth = 1;
2238  for (auto generatedOp : pattern->getGeneratedOps()) {
2239  unsigned generatedOpDepth = computeOpLegalizationDepth(
2240  generatedOp, minOpPatternDepth, legalizerPatterns);
      // NOTE(review): an op whose depth is still being computed is seeded with
      // numeric_limits<unsigned>::max(), so `generatedOpDepth + 1` wraps to 0
      // here. Cyclic dependencies therefore add no depth instead of acting as
      // "infinitely deep". Confirm this is intended.
2241  depth = std::max(depth, generatedOpDepth + 1);
2242  }
2243  patternsByDepth.emplace_back(pattern, depth);
2244 
2245  // Update the minimum depth of the pattern list.
2246  minDepth = std::min(minDepth, depth);
2247  }
2248 
2249  // If the operation only has one legalization pattern, there is no need to
2250  // sort them.
2251  if (patternsByDepth.size() == 1)
2252  return minDepth;
2253 
2254  // Sort the patterns by those likely to be the most beneficial.
2255  std::stable_sort(patternsByDepth.begin(), patternsByDepth.end(),
2256  [](const std::pair<const Pattern *, unsigned> &lhs,
2257  const std::pair<const Pattern *, unsigned> &rhs) {
2258  // First sort by the smaller pattern legalization
2259  // depth.
2260  if (lhs.second != rhs.second)
2261  return lhs.second < rhs.second;
2262 
2263  // Then sort by the larger pattern benefit.
2264  auto lhsBenefit = lhs.first->getBenefit();
2265  auto rhsBenefit = rhs.first->getBenefit();
2266  return lhsBenefit > rhsBenefit;
2267  });
2268 
2269  // Update the legalization pattern to use the new sorted list.
2270  patterns.clear();
2271  for (auto &patternIt : patternsByDepth)
2272  patterns.push_back(patternIt.first);
2273  return minDepth;
2274 }
2275 
2276 //===----------------------------------------------------------------------===//
2277 // OperationConverter
2278 //===----------------------------------------------------------------------===//
2279 namespace {
/// Selects how OperationConverter reacts to operations that fail to legalize
/// and whether the computed rewrites are applied.
2280 enum OpConversionMode {
2281  /// In this mode, the conversion will ignore failed conversions to allow
2282  /// illegal operations to co-exist in the IR.
2283  Partial,
2284 
2285  /// In this mode, all operations must be legal for the given target for the
2286  /// conversion to succeed.
2287  Full,
2288 
2289  /// In this mode, operations are analyzed for legality. No actual rewrites are
2290  /// applied to the operations on success.
2291  Analysis,
2292 };
2293 
2294 // This class converts operations to a given conversion target via a set of
2295 // rewrite patterns. The conversion behaves differently depending on the
2296 // conversion mode.
2297 struct OperationConverter {
  /// `trackedOps` is optional and records per-op outcomes (see the member's
  /// documentation). NOTE(review): Analysis mode is only meaningful with a
  /// non-null set.
2298  explicit OperationConverter(ConversionTarget &target,
2299  const FrozenRewritePatternSet &patterns,
2300  OpConversionMode mode,
2301  DenseSet<Operation *> *trackedOps = nullptr)
2302  : opLegalizer(target, patterns), mode(mode), trackedOps(trackedOps) {}
2303 
2304  /// Converts the given operations to the conversion target.
  /// `notifyCallback` (optional) is handed to the rewriter for diagnostics
  /// about failed pattern applications.
2306  convertOperations(ArrayRef<Operation *> ops,
2307  function_ref<void(Diagnostic &)> notifyCallback = nullptr);
2308 
2309 private:
2310  /// Converts an operation with the given rewriter.
2311  LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op);
2312 
2313  /// This method is called after the conversion process to legalize any
2314  /// remaining artifacts and complete the conversion.
2315  LogicalResult finalize(ConversionPatternRewriter &rewriter);
2316 
2317  /// Legalize the types of converted block arguments.
2319  legalizeConvertedArgumentTypes(ConversionPatternRewriter &rewriter,
2320  ConversionPatternRewriterImpl &rewriterImpl);
2321 
2322  /// Legalize any unresolved type materializations.
2323  LogicalResult legalizeUnresolvedMaterializations(
2324  ConversionPatternRewriter &rewriter,
2325  ConversionPatternRewriterImpl &rewriterImpl,
2326  std::optional<DenseMap<Value, SmallVector<Value>>> &inverseMapping);
2327 
2328  /// Legalize an operation result that was marked as "erased".
2330  legalizeErasedResult(Operation *op, OpResult result,
2331  ConversionPatternRewriterImpl &rewriterImpl);
2332 
2333  /// Legalize an operation result that was replaced with a value of a different
2334  /// type.
2335  LogicalResult legalizeChangedResultType(
2336  Operation *op, OpResult result, Value newValue,
2337  TypeConverter *replConverter, ConversionPatternRewriter &rewriter,
2338  ConversionPatternRewriterImpl &rewriterImpl,
2339  const DenseMap<Value, SmallVector<Value>> &inverseMapping);
2340 
2341  /// The legalizer to use when converting operations.
2342  OperationLegalizer opLegalizer;
2343 
2344  /// The conversion mode to use when legalizing operations.
2345  OpConversionMode mode;
2346 
2347  /// A set of pre-existing operations. When mode == OpConversionMode::Analysis,
2348  /// this is populated with ops found to be legalizable to the target.
2349  /// When mode == OpConversionMode::Partial, this is populated with ops found
2350  /// *not* to be legalizable to the target.
2351  DenseSet<Operation *> *trackedOps;
2352 };
2353 } // namespace
2354 
2355 LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
2356  Operation *op) {
2357  // Legalize the given operation.
2358  if (failed(opLegalizer.legalize(op, rewriter))) {
2359  // Handle the case of a failed conversion for each of the different modes.
2360  // Full conversions expect all operations to be converted.
2361  if (mode == OpConversionMode::Full)
2362  return op->emitError()
2363  << "failed to legalize operation '" << op->getName() << "'";
2364  // Partial conversions allow conversions to fail iff the operation was not
2365  // explicitly marked as illegal. If the user provided a nonlegalizableOps
2366  // set, non-legalizable ops are included.
2367  if (mode == OpConversionMode::Partial) {
2368  if (opLegalizer.isIllegal(op))
2369  return op->emitError()
2370  << "failed to legalize operation '" << op->getName()
2371  << "' that was explicitly marked illegal";
2372  if (trackedOps)
2373  trackedOps->insert(op);
2374  }
2375  } else if (mode == OpConversionMode::Analysis) {
2376  // Analysis conversions don't fail if any operations fail to legalize,
2377  // they are only interested in the operations that were successfully
2378  // legalized.
2379  trackedOps->insert(op);
2380  }
2381  return success();
2382 }
2383 
/// Converts `ops` together with their nested operations. The walk does not
/// descend into an operation the target marks recursively legal. Every
/// collected operation is passed to convert(), and then finalize() cleans up
/// conversion artifacts.
/// - Any failure discards all recorded rewrites and returns failure.
/// - On success, Analysis mode still discards the rewrites. The other modes
///   apply them and drop replaced ops from `trackedOps`.
/// An empty `ops` is trivially successful.
2384 LogicalResult OperationConverter::convertOperations(
2386  function_ref<void(Diagnostic &)> notifyCallback) {
2387  if (ops.empty())
2388  return success();
2389  ConversionTarget &target = opLegalizer.getTarget();
2390 
2391  // Compute the set of operations and blocks to convert.
  // The full list is collected before any conversion starts.
2392  SmallVector<Operation *> toConvert;
2393  for (auto *op : ops) {
2395  [&](Operation *op) {
2396  toConvert.push_back(op);
2397  // Don't check this operation's children for conversion if the
2398  // operation is recursively legal.
2399  auto legalityInfo = target.isLegal(op);
2400  if (legalityInfo && legalityInfo->isRecursivelyLegal)
2401  return WalkResult::skip();
2402  return WalkResult::advance();
2403  });
2404  }
2405 
2406  // Convert each operation and discard rewrites on failure.
2407  ConversionPatternRewriter rewriter(ops.front()->getContext());
2408  ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2409  rewriterImpl.notifyCallback = notifyCallback;
2410 
  // `return a, failure();` is a comma expression: discard, then fail.
2411  for (auto *op : toConvert)
2412  if (failed(convert(rewriter, op)))
2413  return rewriterImpl.discardRewrites(), failure();
2414 
2415  // Now that all of the operations have been converted, finalize the conversion
2416  // process to ensure any lingering conversion artifacts are cleaned up and
2417  // legalized.
2418  if (failed(finalize(rewriter)))
2419  return rewriterImpl.discardRewrites(), failure();
2420 
2421  // After a successful conversion, apply rewrites if this is not an analysis
2422  // conversion.
2423  if (mode == OpConversionMode::Analysis) {
2424  rewriterImpl.discardRewrites();
2425  } else {
2426  rewriterImpl.applyRewrites();
2427 
2428  // It is possible for a later pattern to erase an op that was originally
2429  // identified as illegal and added to the trackedOps, remove it now after
2430  // replacements have been computed.
2431  if (trackedOps)
2432  for (auto &repl : rewriterImpl.replacements)
2433  trackedOps->erase(repl.first);
2434  }
2435  return success();
2436 }
2437 
/// Completes a conversion after all operations were visited.
///
/// It first legalizes unresolved materializations and converted block
/// argument types. It then walks the operations whose results were replaced.
/// A result mapped to nothing is handled as erased. A result replaced by a
/// value of a different type is legalized with the replacement's type
/// converter. The inverse value mapping is computed lazily, only when it is
/// first needed. Returns failure on the first step that fails.
2439 OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
2440  std::optional<DenseMap<Value, SmallVector<Value>>> inverseMapping;
2441  ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2442  if (failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl,
2443  inverseMapping)) ||
2444  failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)))
2445  return failure();
2446 
2447  if (rewriterImpl.operationsWithChangedResults.empty())
2448  return success();
2449 
2450  // Process requested operation replacements.
2451  for (unsigned i = 0, e = rewriterImpl.operationsWithChangedResults.size();
2452  i != e; ++i) {
2453  unsigned replIdx = rewriterImpl.operationsWithChangedResults[i];
    // NOTE(review): `repl` refers into `rewriterImpl.replacements` and is read
    // again on later inner-loop iterations. If legalizeChangedResultType could
    // append to `replacements`, this reference might dangle. The refreshed `e`
    // below suggests growth was anticipated. TODO confirm.
2454  auto &repl = *(rewriterImpl.replacements.begin() + replIdx);
2455  for (OpResult result : repl.first->getResults()) {
2456  Value newValue = rewriterImpl.mapping.lookupOrNull(result);
2457 
2458  // If the operation result was replaced with null, all of the uses of this
2459  // value should be replaced.
2460  if (!newValue) {
2461  if (failed(legalizeErasedResult(repl.first, result, rewriterImpl)))
2462  return failure();
2463  continue;
2464  }
2465 
2466  // Otherwise, check to see if the type of the result changed.
2467  if (result.getType() == newValue.getType())
2468  continue;
2469 
2470  // Compute the inverse mapping only if it is really needed.
2471  if (!inverseMapping)
2472  inverseMapping = rewriterImpl.mapping.getInverse();
2473 
2474  // Legalize this result.
2475  rewriter.setInsertionPoint(repl.first);
2476  if (failed(legalizeChangedResultType(repl.first, result, newValue,
2477  repl.second.converter, rewriter,
2478  rewriterImpl, *inverseMapping)))
2479  return failure();
2480 
2481  // Update the end iterator for this loop in the case it was updated
2482  // when legalizing generated conversion operations.
2483  e = rewriterImpl.operationsWithChangedResults.size();
2484  }
2485  }
2486  return success();
2487 }
2488 
2489 LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
2490  ConversionPatternRewriter &rewriter,
2491  ConversionPatternRewriterImpl &rewriterImpl) {
2492  // Functor used to check if all users of a value will be dead after
2493  // conversion.
2494  auto findLiveUser = [&](Value val) {
2495  auto liveUserIt = llvm::find_if_not(val.getUsers(), [&](Operation *user) {
2496  return rewriterImpl.isOpIgnored(user);
2497  });
2498  return liveUserIt == val.user_end() ? nullptr : *liveUserIt;
2499  };
2500  return rewriterImpl.argConverter.materializeLiveConversions(
2501  rewriterImpl.mapping, rewriter, findLiveUser);
2502 }
2503 
2504 /// Replace the results of a materialization operation with the given values.
/// Besides rewriting all uses of `matResults`, every value that
/// `inverseMapping` lists as mapped to a materialization result is remapped to
/// the corresponding replacement value. If the remapping is refused, that
/// value's mapping is erased instead. `inverseMapping` itself is only read
/// here; its keys are not updated.
2505 static void
2507  ResultRange matResults, ValueRange values,
2508  DenseMap<Value, SmallVector<Value>> &inverseMapping) {
2509  matResults.replaceAllUsesWith(values);
2510 
2511  // For each of the materialization results, update the inverse mappings to
2512  // point to the replacement values.
2513  for (auto [matResult, newValue] : llvm::zip(matResults, values)) {
2514  auto inverseMapIt = inverseMapping.find(matResult);
2515  if (inverseMapIt == inverseMapping.end())
2516  continue;
2517 
2518  // Update the reverse mapping, or remove the mapping if we couldn't update
2519  // it. Not being able to update signals that the mapping would have become
2520  // circular (i.e. %foo -> newValue -> %foo), which may occur as values are
2521  // propagated through temporary materializations. We simply drop the
2522  // mapping, and let the post-conversion replacement logic handle updating
2523  // uses.
2524  for (Value inverseMapVal : inverseMapIt->second)
2525  if (!rewriterImpl.mapping.tryMap(inverseMapVal, newValue))
2526  rewriterImpl.mapping.erase(inverseMapVal);
2527  }
2528 }
2529 
2530 /// Compute all of the unresolved materializations that will persist beyond the
2531 /// conversion process, and require inserting a proper user materialization for.
2534  ConversionPatternRewriter &rewriter,
2535  ConversionPatternRewriterImpl &rewriterImpl,
2536  DenseMap<Value, SmallVector<Value>> &inverseMapping,
2537  SetVector<UnresolvedMaterialization *> &necessaryMaterializations) {
2538  auto isLive = [&](Value value) {
2539  auto findFn = [&](Operation *user) {
2540  auto matIt = materializationOps.find(user);
2541  if (matIt != materializationOps.end())
2542  return !necessaryMaterializations.count(matIt->second);
2543  return rewriterImpl.isOpIgnored(user);
2544  };
2545  // This value may be replacing another value that has a live user.
2546  for (Value inv : inverseMapping.lookup(value))
2547  if (llvm::find_if_not(inv.getUsers(), findFn) != inv.user_end())
2548  return true;
2549  // Or have live users itself.
2550  return llvm::find_if_not(value.getUsers(), findFn) != value.user_end();
2551  };
2552 
2553  llvm::unique_function<Value(Value, Value, Type)> lookupRemappedValue =
2554  [&](Value invalidRoot, Value value, Type type) {
2555  // Check to see if the input operation was remapped to a variant of the
2556  // output.
2557  Value remappedValue = rewriterImpl.mapping.lookupOrDefault(value, type);
2558  if (remappedValue.getType() == type && remappedValue != invalidRoot)
2559  return remappedValue;
2560 
2561  // Check to see if the input is a materialization operation that
2562  // provides an inverse conversion. We just check blindly for
2563  // UnrealizedConversionCastOp here, but it has no effect on correctness.
2564  auto inputCastOp = value.getDefiningOp<UnrealizedConversionCastOp>();
2565  if (inputCastOp && inputCastOp->getNumOperands() == 1)
2566  return lookupRemappedValue(invalidRoot, inputCastOp->getOperand(0),
2567  type);
2568 
2569  return Value();
2570  };
2571 
2573  for (auto &mat : rewriterImpl.unresolvedMaterializations) {
2574  materializationOps.try_emplace(mat.getOp(), &mat);
2575  worklist.insert(&mat);
2576  }
2577  while (!worklist.empty()) {
2578  UnresolvedMaterialization *mat = worklist.pop_back_val();
2579  UnrealizedConversionCastOp op = mat->getOp();
2580 
2581  // We currently only handle target materializations here.
2582  assert(op->getNumResults() == 1 && "unexpected materialization type");
2583  OpResult opResult = op->getOpResult(0);
2584  Type outputType = opResult.getType();
2585  Operation::operand_range inputOperands = op.getOperands();
2586 
2587  // Try to forward propagate operands for user conversion casts that result
2588  // in the input types of the current cast.
2589  for (Operation *user : llvm::make_early_inc_range(opResult.getUsers())) {
2590  auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
2591  if (!castOp)
2592  continue;
2593  if (castOp->getResultTypes() == inputOperands.getTypes()) {
2594  replaceMaterialization(rewriterImpl, opResult, inputOperands,
2595  inverseMapping);
2596  necessaryMaterializations.remove(materializationOps.lookup(user));
2597  }
2598  }
2599 
2600  // Try to avoid materializing a resolved materialization if possible.
2601  // Handle the case of a 1-1 materialization.
2602  if (inputOperands.size() == 1) {
2603  // Check to see if the input operation was remapped to a variant of the
2604  // output.
2605  Value remappedValue =
2606  lookupRemappedValue(opResult, inputOperands[0], outputType);
2607  if (remappedValue && remappedValue != opResult) {
2608  replaceMaterialization(rewriterImpl, opResult, remappedValue,
2609  inverseMapping);
2610  necessaryMaterializations.remove(mat);
2611  continue;
2612  }
2613  } else {
2614  // TODO: Avoid materializing other types of conversions here.
2615  }
2616 
2617  // Check to see if this is an argument materialization.
2618  auto isBlockArg = [](Value v) { return v.isa<BlockArgument>(); };
2619  if (llvm::any_of(op->getOperands(), isBlockArg) ||
2620  llvm::any_of(inverseMapping[op->getResult(0)], isBlockArg)) {
2622  }
2623 
2624  // If the materialization does not have any live users, we don't need to
2625  // generate a user materialization for it.
2626  // FIXME: For argument materializations, we currently need to check if any
2627  // of the inverse mapped values are used because some patterns expect blind
2628  // value replacement even if the types differ in some cases. When those
2629  // patterns are fixed, we can drop the argument special case here.
2630  bool isMaterializationLive = isLive(opResult);
2631  if (mat->getKind() == UnresolvedMaterialization::Argument)
2632  isMaterializationLive |= llvm::any_of(inverseMapping[opResult], isLive);
2633  if (!isMaterializationLive)
2634  continue;
2635  if (!necessaryMaterializations.insert(mat))
2636  continue;
2637 
2638  // Reprocess input materializations to see if they have an updated status.
2639  for (Value input : inputOperands) {
2640  if (auto parentOp = input.getDefiningOp<UnrealizedConversionCastOp>()) {
2641  if (auto *mat = materializationOps.lookup(parentOp))
2642  worklist.insert(mat);
2643  }
2644  }
2645  }
2646 }
2647 
2648 /// Legalize the given unresolved materialization. Returns success if the
2649 /// materialization was legalized, failure otherise.
2651  UnresolvedMaterialization &mat,
2653  ConversionPatternRewriter &rewriter,
2654  ConversionPatternRewriterImpl &rewriterImpl,
2655  DenseMap<Value, SmallVector<Value>> &inverseMapping) {
2656  auto findLiveUser = [&](auto &&users) {
2657  auto liveUserIt = llvm::find_if_not(
2658  users, [&](Operation *user) { return rewriterImpl.isOpIgnored(user); });
2659  return liveUserIt == users.end() ? nullptr : *liveUserIt;
2660  };
2661 
2662  llvm::unique_function<Value(Value, Type)> lookupRemappedValue =
2663  [&](Value value, Type type) {
2664  // Check to see if the input operation was remapped to a variant of the
2665  // output.
2666  Value remappedValue = rewriterImpl.mapping.lookupOrDefault(value, type);
2667  if (remappedValue.getType() == type)
2668  return remappedValue;
2669  return Value();
2670  };
2671 
2672  UnrealizedConversionCastOp op = mat.getOp();
2673  if (!rewriterImpl.ignoredOps.insert(op))
2674  return success();
2675 
2676  // We currently only handle target materializations here.
2677  OpResult opResult = op->getOpResult(0);
2678  Operation::operand_range inputOperands = op.getOperands();
2679  Type outputType = opResult.getType();
2680 
2681  // If any input to this materialization is another materialization, resolve
2682  // the input first.
2683  for (Value value : op->getOperands()) {
2684  auto valueCast = value.getDefiningOp<UnrealizedConversionCastOp>();
2685  if (!valueCast)
2686  continue;
2687 
2688  auto matIt = materializationOps.find(valueCast);
2689  if (matIt != materializationOps.end())
2691  *matIt->second, materializationOps, rewriter, rewriterImpl,
2692  inverseMapping)))
2693  return failure();
2694  }
2695 
2696  // Perform a last ditch attempt to avoid materializing a resolved
2697  // materialization if possible.
2698  // Handle the case of a 1-1 materialization.
2699  if (inputOperands.size() == 1) {
2700  // Check to see if the input operation was remapped to a variant of the
2701  // output.
2702  Value remappedValue = lookupRemappedValue(inputOperands[0], outputType);
2703  if (remappedValue && remappedValue != opResult) {
2704  replaceMaterialization(rewriterImpl, opResult, remappedValue,
2705  inverseMapping);
2706  return success();
2707  }
2708  } else {
2709  // TODO: Avoid materializing other types of conversions here.
2710  }
2711 
2712  // Try to materialize the conversion.
2713  if (TypeConverter *converter = mat.getConverter()) {
2714  // FIXME: Determine a suitable insertion location when there are multiple
2715  // inputs.
2716  if (inputOperands.size() == 1)
2717  rewriter.setInsertionPointAfterValue(inputOperands.front());
2718  else
2719  rewriter.setInsertionPoint(op);
2720 
2721  Value newMaterialization;
2722  switch (mat.getKind()) {
2724  // Try to materialize an argument conversion.
2725  // FIXME: The current argument materialization hook expects the original
2726  // output type, even though it doesn't use that as the actual output type
2727  // of the generated IR. The output type is just used as an indicator of
2728  // the type of materialization to do. This behavior is really awkward in
2729  // that it diverges from the behavior of the other hooks, and can be
2730  // easily misunderstood. We should clean up the argument hooks to better
2731  // represent the desired invariants we actually care about.
2732  newMaterialization = converter->materializeArgumentConversion(
2733  rewriter, op->getLoc(), mat.getOrigOutputType(), inputOperands);
2734  if (newMaterialization)
2735  break;
2736 
2737  // If an argument materialization failed, fallback to trying a target
2738  // materialization.
2739  [[fallthrough]];
2740  case UnresolvedMaterialization::Target:
2741  newMaterialization = converter->materializeTargetConversion(
2742  rewriter, op->getLoc(), outputType, inputOperands);
2743  break;
2744  }
2745  if (newMaterialization) {
2746  replaceMaterialization(rewriterImpl, opResult, newMaterialization,
2747  inverseMapping);
2748  return success();
2749  }
2750  }
2751 
2752  InFlightDiagnostic diag = op->emitError()
2753  << "failed to legalize unresolved materialization "
2754  "from "
2755  << inputOperands.getTypes() << " to " << outputType
2756  << " that remained live after conversion";
2757  if (Operation *liveUser = findLiveUser(op->getUsers())) {
2758  diag.attachNote(liveUser->getLoc())
2759  << "see existing live user here: " << *liveUser;
2760  }
2761  return failure();
2762 }
2763 
2764 LogicalResult OperationConverter::legalizeUnresolvedMaterializations(
2765  ConversionPatternRewriter &rewriter,
2766  ConversionPatternRewriterImpl &rewriterImpl,
2767  std::optional<DenseMap<Value, SmallVector<Value>>> &inverseMapping) {
2768  if (rewriterImpl.unresolvedMaterializations.empty())
2769  return success();
2770  inverseMapping = rewriterImpl.mapping.getInverse();
2771 
2772  // As an initial step, compute all of the inserted materializations that we
2773  // expect to persist beyond the conversion process.
2775  SetVector<UnresolvedMaterialization *> necessaryMaterializations;
2776  computeNecessaryMaterializations(materializationOps, rewriter, rewriterImpl,
2777  *inverseMapping, necessaryMaterializations);
2778 
2779  // Once computed, legalize any necessary materializations.
2780  for (auto *mat : necessaryMaterializations) {
2782  *mat, materializationOps, rewriter, rewriterImpl, *inverseMapping)))
2783  return failure();
2784  }
2785  return success();
2786 }
2787 
2788 LogicalResult OperationConverter::legalizeErasedResult(
2789  Operation *op, OpResult result,
2790  ConversionPatternRewriterImpl &rewriterImpl) {
2791  // If the operation result was replaced with null, all of the uses of this
2792  // value should be replaced.
2793  auto liveUserIt = llvm::find_if_not(result.getUsers(), [&](Operation *user) {
2794  return rewriterImpl.isOpIgnored(user);
2795  });
2796  if (liveUserIt != result.user_end()) {
2797  InFlightDiagnostic diag = op->emitError("failed to legalize operation '")
2798  << op->getName() << "' marked as erased";
2799  diag.attachNote(liveUserIt->getLoc())
2800  << "found live user of result #" << result.getResultNumber() << ": "
2801  << *liveUserIt;
2802  return failure();
2803  }
2804  return success();
2805 }
2806 
2807 /// Finds a user of the given value, or of any other value that the given value
2808 /// replaced, that was not replaced in the conversion process.
2810  Value initialValue, ConversionPatternRewriterImpl &rewriterImpl,
2811  const DenseMap<Value, SmallVector<Value>> &inverseMapping) {
2812  SmallVector<Value> worklist(1, initialValue);
2813  while (!worklist.empty()) {
2814  Value value = worklist.pop_back_val();
2815 
2816  // Walk the users of this value to see if there are any live users that
2817  // weren't replaced during conversion.
2818  auto liveUserIt = llvm::find_if_not(value.getUsers(), [&](Operation *user) {
2819  return rewriterImpl.isOpIgnored(user);
2820  });
2821  if (liveUserIt != value.user_end())
2822  return *liveUserIt;
2823  auto mapIt = inverseMapping.find(value);
2824  if (mapIt != inverseMapping.end())
2825  worklist.append(mapIt->second);
2826  }
2827  return nullptr;
2828 }
2829 
2830 LogicalResult OperationConverter::legalizeChangedResultType(
2831  Operation *op, OpResult result, Value newValue,
2832  TypeConverter *replConverter, ConversionPatternRewriter &rewriter,
2833  ConversionPatternRewriterImpl &rewriterImpl,
2834  const DenseMap<Value, SmallVector<Value>> &inverseMapping) {
2835  Operation *liveUser =
2836  findLiveUserOfReplaced(result, rewriterImpl, inverseMapping);
2837  if (!liveUser)
2838  return success();
2839 
2840  // Functor used to emit a conversion error for a failed materialization.
2841  auto emitConversionError = [&] {
2843  << "failed to materialize conversion for result #"
2844  << result.getResultNumber() << " of operation '"
2845  << op->getName()
2846  << "' that remained live after conversion";
2847  diag.attachNote(liveUser->getLoc())
2848  << "see existing live user here: " << *liveUser;
2849  return failure();
2850  };
2851 
2852  // If the replacement has a type converter, attempt to materialize a
2853  // conversion back to the original type.
2854  if (!replConverter)
2855  return emitConversionError();
2856 
2857  // Materialize a conversion for this live result value.
2858  Type resultType = result.getType();
2859  Value convertedValue = replConverter->materializeSourceConversion(
2860  rewriter, op->getLoc(), resultType, newValue);
2861  if (!convertedValue)
2862  return emitConversionError();
2863 
2864  rewriterImpl.mapping.map(result, convertedValue);
2865  return success();
2866 }
2867 
2868 //===----------------------------------------------------------------------===//
2869 // Type Conversion
2870 //===----------------------------------------------------------------------===//
2871 
2873  ArrayRef<Type> types) {
2874  assert(!types.empty() && "expected valid types");
2875  remapInput(origInputNo, /*newInputNo=*/argTypes.size(), types.size());
2876  addInputs(types);
2877 }
2878 
2880  assert(!types.empty() &&
2881  "1->0 type remappings don't need to be added explicitly");
2882  argTypes.append(types.begin(), types.end());
2883 }
2884 
2885 void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
2886  unsigned newInputNo,
2887  unsigned newInputCount) {
2888  assert(!remappedInputs[origInputNo] && "input has already been remapped");
2889  assert(newInputCount != 0 && "expected valid input count");
2890  remappedInputs[origInputNo] =
2891  InputMapping{newInputNo, newInputCount, /*replacementValue=*/nullptr};
2892 }
2893 
2895  Value replacementValue) {
2896  assert(!remappedInputs[origInputNo] && "input has already been remapped");
2897  remappedInputs[origInputNo] =
2898  InputMapping{origInputNo, /*size=*/0, replacementValue};
2899 }
2900 
2902  SmallVectorImpl<Type> &results) {
2903  auto existingIt = cachedDirectConversions.find(t);
2904  if (existingIt != cachedDirectConversions.end()) {
2905  if (existingIt->second)
2906  results.push_back(existingIt->second);
2907  return success(existingIt->second != nullptr);
2908  }
2909  auto multiIt = cachedMultiConversions.find(t);
2910  if (multiIt != cachedMultiConversions.end()) {
2911  results.append(multiIt->second.begin(), multiIt->second.end());
2912  return success();
2913  }
2914 
2915  // Walk the added converters in reverse order to apply the most recently
2916  // registered first.
2917  size_t currentCount = results.size();
2918  conversionCallStack.push_back(t);
2919  auto popConversionCallStack =
2920  llvm::make_scope_exit([this]() { conversionCallStack.pop_back(); });
2921  for (ConversionCallbackFn &converter : llvm::reverse(conversions)) {
2922  if (std::optional<LogicalResult> result =
2923  converter(t, results, conversionCallStack)) {
2924  if (!succeeded(*result)) {
2925  cachedDirectConversions.try_emplace(t, nullptr);
2926  return failure();
2927  }
2928  auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
2929  if (newTypes.size() == 1)
2930  cachedDirectConversions.try_emplace(t, newTypes.front());
2931  else
2932  cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
2933  return success();
2934  }
2935  }
2936  return failure();
2937 }
2938 
2940  // Use the multi-type result version to convert the type.
2941  SmallVector<Type, 1> results;
2942  if (failed(convertType(t, results)))
2943  return nullptr;
2944 
2945  // Check to ensure that only one type was produced.
2946  return results.size() == 1 ? results.front() : nullptr;
2947 }
2948 
2950  SmallVectorImpl<Type> &results) {
2951  for (Type type : types)
2952  if (failed(convertType(type, results)))
2953  return failure();
2954  return success();
2955 }
2956 
2957 bool TypeConverter::isLegal(Type type) { return convertType(type) == type; }
2959  return isLegal(op->getOperandTypes()) && isLegal(op->getResultTypes());
2960 }
2961 
2963  return llvm::all_of(*region, [this](Block &block) {
2964  return isLegal(block.getArgumentTypes());
2965  });
2966 }
2967 
2968 bool TypeConverter::isSignatureLegal(FunctionType ty) {
2969  return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults()));
2970 }
2971 
2973  SignatureConversion &result) {
2974  // Try to convert the given input type.
2975  SmallVector<Type, 1> convertedTypes;
2976  if (failed(convertType(type, convertedTypes)))
2977  return failure();
2978 
2979  // If this argument is being dropped, there is nothing left to do.
2980  if (convertedTypes.empty())
2981  return success();
2982 
2983  // Otherwise, add the new inputs.
2984  result.addInputs(inputNo, convertedTypes);
2985  return success();
2986 }
2988  SignatureConversion &result,
2989  unsigned origInputOffset) {
2990  for (unsigned i = 0, e = types.size(); i != e; ++i)
2991  if (failed(convertSignatureArg(origInputOffset + i, types[i], result)))
2992  return failure();
2993  return success();
2994 }
2995 
2996 Value TypeConverter::materializeConversion(
2998  OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) {
2999  for (MaterializationCallbackFn &fn : llvm::reverse(materializations))
3000  if (std::optional<Value> result = fn(builder, resultType, inputs, loc))
3001  return *result;
3002  return nullptr;
3003 }
3004 
3006  -> std::optional<SignatureConversion> {
3007  SignatureConversion conversion(block->getNumArguments());
3008  if (failed(convertSignatureArgs(block->getArgumentTypes(), conversion)))
3009  return std::nullopt;
3010  return conversion;
3011 }
3012 
3013 //===----------------------------------------------------------------------===//
3014 // Type attribute conversion
3015 //===----------------------------------------------------------------------===//
3018  return AttributeConversionResult(attr, resultTag);
3019 }
3020 
3023  return AttributeConversionResult(nullptr, naTag);
3024 }
3025 
3028  return AttributeConversionResult(nullptr, abortTag);
3029 }
3030 
3032  return impl.getInt() == resultTag;
3033 }
3034 
3036  return impl.getInt() == naTag;
3037 }
3038 
3040  return impl.getInt() == abortTag;
3041 }
3042 
3044  assert(hasResult() && "Cannot get result from N/A or abort");
3045  return impl.getPointer();
3046 }
3047 
3048 std::optional<Attribute> TypeConverter::convertTypeAttribute(Type type,
3049  Attribute attr) {
3050  for (TypeAttributeConversionCallbackFn &fn :
3051  llvm::reverse(typeAttributeConversions)) {
3052  AttributeConversionResult res = fn(type, attr);
3053  if (res.hasResult())
3054  return res.getResult();
3055  if (res.isAbort())
3056  return std::nullopt;
3057  }
3058  return std::nullopt;
3059 }
3060 
3061 //===----------------------------------------------------------------------===//
3062 // FunctionOpInterfaceSignatureConversion
3063 //===----------------------------------------------------------------------===//
3064 
3065 static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
3066  TypeConverter &typeConverter,
3067  ConversionPatternRewriter &rewriter) {
3068  FunctionType type = funcOp.getFunctionType().cast<FunctionType>();
3069 
3070  // Convert the original function types.
3071  TypeConverter::SignatureConversion result(type.getNumInputs());
3072  SmallVector<Type, 1> newResults;
3073  if (failed(typeConverter.convertSignatureArgs(type.getInputs(), result)) ||
3074  failed(typeConverter.convertTypes(type.getResults(), newResults)) ||
3075  failed(rewriter.convertRegionTypes(&funcOp.getFunctionBody(),
3076  typeConverter, &result)))
3077  return failure();
3078 
3079  // Update the function signature in-place.
3080  auto newType = FunctionType::get(rewriter.getContext(),
3081  result.getConvertedTypes(), newResults);
3082 
3083  rewriter.updateRootInPlace(funcOp, [&] { funcOp.setType(newType); });
3084 
3085  return success();
3086 }
3087 
3088 /// Create a default conversion pattern that rewrites the type signature of a
3089 /// FunctionOpInterface op. This only supports ops which use FunctionType to
3090 /// represent their type.
3091 namespace {
3092 struct FunctionOpInterfaceSignatureConversion : public ConversionPattern {
3093  FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
3094  MLIRContext *ctx,
3095  TypeConverter &converter)
3096  : ConversionPattern(converter, functionLikeOpName, /*benefit=*/1, ctx) {}
3097 
3099  matchAndRewrite(Operation *op, ArrayRef<Value> /*operands*/,
3100  ConversionPatternRewriter &rewriter) const override {
3101  FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
3102  return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
3103  }
3104 };
3105 
3106 struct AnyFunctionOpInterfaceSignatureConversion
3107  : public OpInterfaceConversionPattern<FunctionOpInterface> {
3109 
3111  matchAndRewrite(FunctionOpInterface funcOp, ArrayRef<Value> /*operands*/,
3112  ConversionPatternRewriter &rewriter) const override {
3113  return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
3114  }
3115 };
3116 } // namespace
3117 
3119  StringRef functionLikeOpName, RewritePatternSet &patterns,
3120  TypeConverter &converter) {
3121  patterns.add<FunctionOpInterfaceSignatureConversion>(
3122  functionLikeOpName, patterns.getContext(), converter);
3123 }
3124 
3126  RewritePatternSet &patterns, TypeConverter &converter) {
3127  patterns.add<AnyFunctionOpInterfaceSignatureConversion>(
3128  converter, patterns.getContext());
3129 }
3130 
3131 //===----------------------------------------------------------------------===//
3132 // ConversionTarget
3133 //===----------------------------------------------------------------------===//
3134 
3136  LegalizationAction action) {
3137  legalOperations[op].action = action;
3138 }
3139 
3141  LegalizationAction action) {
3142  for (StringRef dialect : dialectNames)
3143  legalDialects[dialect] = action;
3144 }
3145 
3147  -> std::optional<LegalizationAction> {
3148  std::optional<LegalizationInfo> info = getOpInfo(op);
3149  return info ? info->action : std::optional<LegalizationAction>();
3150 }
3151 
3153  -> std::optional<LegalOpDetails> {
3154  std::optional<LegalizationInfo> info = getOpInfo(op->getName());
3155  if (!info)
3156  return std::nullopt;
3157 
3158  // Returns true if this operation instance is known to be legal.
3159  auto isOpLegal = [&] {
3160  // Handle dynamic legality either with the provided legality function.
3161  if (info->action == LegalizationAction::Dynamic) {
3162  std::optional<bool> result = info->legalityFn(op);
3163  if (result)
3164  return *result;
3165  }
3166 
3167  // Otherwise, the operation is only legal if it was marked 'Legal'.
3168  return info->action == LegalizationAction::Legal;
3169  };
3170  if (!isOpLegal())
3171  return std::nullopt;
3172 
3173  // This operation is legal, compute any additional legality information.
3174  LegalOpDetails legalityDetails;
3175  if (info->isRecursivelyLegal) {
3176  auto legalityFnIt = opRecursiveLegalityFns.find(op->getName());
3177  if (legalityFnIt != opRecursiveLegalityFns.end()) {
3178  legalityDetails.isRecursivelyLegal =
3179  legalityFnIt->second(op).value_or(true);
3180  } else {
3181  legalityDetails.isRecursivelyLegal = true;
3182  }
3183  }
3184  return legalityDetails;
3185 }
3186 
3188  std::optional<LegalizationInfo> info = getOpInfo(op->getName());
3189  if (!info)
3190  return false;
3191 
3192  if (info->action == LegalizationAction::Dynamic) {
3193  std::optional<bool> result = info->legalityFn(op);
3194  if (!result)
3195  return false;
3196 
3197  return !(*result);
3198  }
3199 
3200  return info->action == LegalizationAction::Illegal;
3201 }
3202 
3206  if (!oldCallback)
3207  return newCallback;
3208 
3209  auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)](
3210  Operation *op) -> std::optional<bool> {
3211  if (std::optional<bool> result = newCl(op))
3212  return *result;
3213 
3214  return oldCl(op);
3215  };
3216  return chain;
3217 }
3218 
3219 void ConversionTarget::setLegalityCallback(
3220  OperationName name, const DynamicLegalityCallbackFn &callback) {
3221  assert(callback && "expected valid legality callback");
3222  auto infoIt = legalOperations.find(name);
3223  assert(infoIt != legalOperations.end() &&
3224  infoIt->second.action == LegalizationAction::Dynamic &&
3225  "expected operation to already be marked as dynamically legal");
3226  infoIt->second.legalityFn =
3227  composeLegalityCallbacks(std::move(infoIt->second.legalityFn), callback);
3228 }
3229 
3231  OperationName name, const DynamicLegalityCallbackFn &callback) {
3232  auto infoIt = legalOperations.find(name);
3233  assert(infoIt != legalOperations.end() &&
3234  infoIt->second.action != LegalizationAction::Illegal &&
3235  "expected operation to already be marked as legal");
3236  infoIt->second.isRecursivelyLegal = true;
3237  if (callback)
3238  opRecursiveLegalityFns[name] = composeLegalityCallbacks(
3239  std::move(opRecursiveLegalityFns[name]), callback);
3240  else
3241  opRecursiveLegalityFns.erase(name);
3242 }
3243 
3244 void ConversionTarget::setLegalityCallback(
3245  ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) {
3246  assert(callback && "expected valid legality callback");
3247  for (StringRef dialect : dialects)
3248  dialectLegalityFns[dialect] = composeLegalityCallbacks(
3249  std::move(dialectLegalityFns[dialect]), callback);
3250 }
3251 
3252 void ConversionTarget::setLegalityCallback(
3253  const DynamicLegalityCallbackFn &callback) {
3254  assert(callback && "expected valid legality callback");
3255  unknownLegalityFn = composeLegalityCallbacks(unknownLegalityFn, callback);
3256 }
3257 
3258 auto ConversionTarget::getOpInfo(OperationName op) const
3259  -> std::optional<LegalizationInfo> {
3260  // Check for info for this specific operation.
3261  auto it = legalOperations.find(op);
3262  if (it != legalOperations.end())
3263  return it->second;
3264  // Check for info for the parent dialect.
3265  auto dialectIt = legalDialects.find(op.getDialectNamespace());
3266  if (dialectIt != legalDialects.end()) {
3267  DynamicLegalityCallbackFn callback;
3268  auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace());
3269  if (dialectFn != dialectLegalityFns.end())
3270  callback = dialectFn->second;
3271  return LegalizationInfo{dialectIt->second, /*isRecursivelyLegal=*/false,
3272  callback};
3273  }
3274  // Otherwise, check if we mark unknown operations as dynamic.
3275  if (unknownLegalityFn)
3276  return LegalizationInfo{LegalizationAction::Dynamic,
3277  /*isRecursivelyLegal=*/false, unknownLegalityFn};
3278  return std::nullopt;
3279 }
3280 
3281 //===----------------------------------------------------------------------===//
3282 // PDL Configuration
3283 //===----------------------------------------------------------------------===//
3284 
3286  auto &rewriterImpl =
3287  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3288  rewriterImpl.currentTypeConverter = getTypeConverter();
3289 }
3290 
3292  auto &rewriterImpl =
3293  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3294  rewriterImpl.currentTypeConverter = nullptr;
3295 }
3296 
3297 /// Remap the given value using the rewriter and the type converter in the
3298 /// provided config.
3301  SmallVector<Value> mappedValues;
3302  if (failed(rewriter.getRemappedValues(values, mappedValues)))
3303  return failure();
3304  return std::move(mappedValues);
3305 }
3306 
3309  "convertValue",
3310  [](PatternRewriter &rewriter, Value value) -> FailureOr<Value> {
3311  auto results = pdllConvertValues(
3312  static_cast<ConversionPatternRewriter &>(rewriter), value);
3313  if (failed(results))
3314  return failure();
3315  return results->front();
3316  });
3318  "convertValues", [](PatternRewriter &rewriter, ValueRange values) {
3319  return pdllConvertValues(
3320  static_cast<ConversionPatternRewriter &>(rewriter), values);
3321  });
3323  "convertType",
3324  [](PatternRewriter &rewriter, Type type) -> FailureOr<Type> {
3325  auto &rewriterImpl =
3326  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3327  if (TypeConverter *converter = rewriterImpl.currentTypeConverter) {
3328  if (Type newType = converter->convertType(type))
3329  return newType;
3330  return failure();
3331  }
3332  return type;
3333  });
3335  "convertTypes",
3336  [](PatternRewriter &rewriter,
3338  auto &rewriterImpl =
3339  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3340  TypeConverter *converter = rewriterImpl.currentTypeConverter;
3341  if (!converter)
3342  return SmallVector<Type>(types);
3343 
3344  SmallVector<Type> remappedTypes;
3345  if (failed(converter->convertTypes(types, remappedTypes)))
3346  return failure();
3347  return std::move(remappedTypes);
3348  });
3349 }
3350 
3351 //===----------------------------------------------------------------------===//
3352 // Op Conversion Entry Points
3353 //===----------------------------------------------------------------------===//
3354 
3355 //===----------------------------------------------------------------------===//
3356 // Partial Conversion
3357 
3360  ConversionTarget &target,
3361  const FrozenRewritePatternSet &patterns,
3362  DenseSet<Operation *> *unconvertedOps) {
3363  OperationConverter opConverter(target, patterns, OpConversionMode::Partial,
3364  unconvertedOps);
3365  return opConverter.convertOperations(ops);
3366 }
3369  const FrozenRewritePatternSet &patterns,
3370  DenseSet<Operation *> *unconvertedOps) {
3371  return applyPartialConversion(llvm::ArrayRef(op), target, patterns,
3372  unconvertedOps);
3373 }
3374 
3375 //===----------------------------------------------------------------------===//
3376 // Full Conversion
3377 
3380  const FrozenRewritePatternSet &patterns) {
3381  OperationConverter opConverter(target, patterns, OpConversionMode::Full);
3382  return opConverter.convertOperations(ops);
3383 }
3386  const FrozenRewritePatternSet &patterns) {
3387  return applyFullConversion(llvm::ArrayRef(op), target, patterns);
3388 }
3389 
3390 //===----------------------------------------------------------------------===//
3391 // Analysis Conversion
3392 
3395  ConversionTarget &target,
3396  const FrozenRewritePatternSet &patterns,
3397  DenseSet<Operation *> &convertedOps,
3398  function_ref<void(Diagnostic &)> notifyCallback) {
3399  OperationConverter opConverter(target, patterns, OpConversionMode::Analysis,
3400  &convertedOps);
3401  return opConverter.convertOperations(ops, notifyCallback);
3402 }
3405  const FrozenRewritePatternSet &patterns,
3406  DenseSet<Operation *> &convertedOps,
3407  function_ref<void(Diagnostic &)> notifyCallback) {
3408  return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns,
3409  convertedOps, notifyCallback);
3410 }
static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(ConversionTarget::DynamicLegalityCallbackFn oldCallback, ConversionTarget::DynamicLegalityCallbackFn newCallback)
static FailureOr< SmallVector< Value > > pdllConvertValues(ConversionPatternRewriter &rewriter, ValueRange values)
Remap the given value using the rewriter and the type converter in the provided config.
static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a failure result for the given reason.
static void detachNestedAndErase(Operation *op)
Detach any operations nested in the given operation from their parent blocks, and erase the given ope...
static Operation * findLiveUserOfReplaced(Value initialValue, ConversionPatternRewriterImpl &rewriterImpl, const DenseMap< Value, SmallVector< Value >> &inverseMapping)
Finds a user of the given value, or of any other value that the given value replaced,...
static Value buildUnresolvedTargetMaterialization(Location loc, Value input, Type outputType, TypeConverter *converter, SmallVectorImpl< UnresolvedMaterialization > &unresolvedMaterializations)
static void computeNecessaryMaterializations(DenseMap< Operation *, UnresolvedMaterialization * > &materializationOps, ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &rewriterImpl, DenseMap< Value, SmallVector< Value >> &inverseMapping, SetVector< UnresolvedMaterialization * > &necessaryMaterializations)
Compute all of the unresolved materializations that will persist beyond the conversion process,...
static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a successful result for the given reason.
static void replaceMaterialization(ConversionPatternRewriterImpl &rewriterImpl, ResultRange matResults, ValueRange values, DenseMap< Value, SmallVector< Value >> &inverseMapping)
Replace the results of a materialization operation with the given values.
static Value buildUnresolvedMaterialization(UnresolvedMaterialization::Kind kind, Block *insertBlock, Block::iterator insertPt, Location loc, ValueRange inputs, Type outputType, Type origOutputType, TypeConverter *converter, SmallVectorImpl< UnresolvedMaterialization > &unresolvedMaterializations)
Build an unresolved materialization operation given an output type and set of input operands.
static LogicalResult legalizeUnresolvedMaterialization(UnresolvedMaterialization &mat, DenseMap< Operation *, UnresolvedMaterialization * > &materializationOps, ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &rewriterImpl, DenseMap< Value, SmallVector< Value >> &inverseMapping)
Legalize the given unresolved materialization.
static Value buildUnresolvedArgumentMaterialization(PatternRewriter &rewriter, Location loc, ValueRange inputs, Type origOutputType, Type outputType, TypeConverter *converter, SmallVectorImpl< UnresolvedMaterialization > &unresolvedMaterializations)
static std::string diag(const llvm::Value &value)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:304
Block * getOwner() const
Returns the block that owns this argument.
Definition: Value.h:313
Location getLoc() const
Return the location for this argument.
Definition: Value.h:319
Block represents an ordered list of Operations.
Definition: Block.h:30
OpListType::iterator iterator
Definition: Block.h:129
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:137
bool empty()
Definition: Block.h:137
BlockArgument getArgument(unsigned i)
Definition: Block.h:118
unsigned getNumArguments()
Definition: Block.h:117
Operation & back()
Definition: Block.h:141
void erase()
Unlink this Block from its parent region and delete it.
Definition: Block.cpp:54
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Definition: Block.cpp:148
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Definition: Block.cpp:291
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:26
void dropAllDefinedValueUses()
This drops all uses of values defined in this block or in the blocks of nested regions wherever the u...
Definition: Block.cpp:82
RetT walk(FnT &&callback)
Walk the operations in this block.
Definition: Block.h:273
OpListType & getOperations()
Definition: Block.h:126
BlockArgListType getArguments()
Definition: Block.h:76
Operation & front()
Definition: Block.h:142
iterator end()
Definition: Block.h:133
iterator begin()
Definition: Block.h:132
void moveBefore(Block *block)
Unlink this block from its current region and insert it right before the specific block.
Definition: Block.cpp:47
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
MLIRContext * getContext() const
Definition: Builders.h:55
This class implements a pattern rewriter for use with ConversionPatterns.
Block * applySignatureConversion(Region *region, TypeConverter::SignatureConversion &conversion, TypeConverter *converter=nullptr)
Apply a signature conversion to the entry block of the given region.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing the results of an operation.
LogicalResult convertNonEntryRegionTypes(Region *region, TypeConverter &converter, ArrayRef< TypeConverter::SignatureConversion > blockConversions)
Convert the types of block arguments within the given region, except for the entry block.
LogicalResult getRemappedValues(ValueRange keys, SmallVectorImpl< Value > &results)
Return the converted values that replace 'keys' with types defined by the type converter of the curre...
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt) override
PatternRewriter hook for inlining the ops of a block into another block.
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before) override
PatternRewriter hook for moving blocks out of a region.
void finalizeRootUpdate(Operation *op) override
PatternRewriter hook for updating the root operation in-place.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void replaceOpWithIf(Operation *op, ValueRange newValues, bool *allUsesReplaced, llvm::unique_function< bool(OpOperand &) const > functor) override
PatternRewriter hook for replacing the results of an operation when the given functor returns true.
void notifyBlockCreated(Block *block) override
PatternRewriter hook for creating a new block.
FailureOr< Block * > convertRegionTypes(Region *region, TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Convert the types of block arguments within the given region.
Block * splitBlock(Block *block, Block::iterator before) override
PatternRewriter hook for splitting a block into two parts.
detail::ConversionPatternRewriterImpl & getImpl()
Return a reference to the internal implementation.
void notifyOperationInserted(Operation *op) override
PatternRewriter hook for inserting a new operation.
void cancelRootUpdate(Operation *op) override
PatternRewriter hook for updating the root operation in-place.
void eraseBlock(Block *block) override
PatternRewriter hook for erasing all operations in a block.
Value getRemappedValue(Value key)
Return the converted value of 'key' with a type defined by the type converter of the currently execut...
void replaceUsesOfBlockArgument(BlockArgument from, Value to)
Replace all the uses of the block argument from with value to.
void startRootUpdate(Operation *op) override
PatternRewriter hook for updating the root operation in-place.
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping) override
PatternRewriter hook for cloning blocks of one region into another.
Base class for the conversion patterns.
TypeConverter * getTypeConverter() const
Return the type converter held by this pattern, or nullptr if the pattern does not require type conve...
virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement combined matching and rewriting.
This class describes a specific conversion target.
void setDialectAction(ArrayRef< StringRef > dialectNames, LegalizationAction action)
Register a legality action for the given dialects.
void setOpAction(OperationName op, LegalizationAction action)
Register a legality action for the given operation.
std::optional< LegalOpDetails > isLegal(Operation *op) const
If the given operation instance is legal on this target, a structure containing legality information ...
std::optional< LegalizationAction > getOpAction(OperationName op) const
Get the legality action for the given operation.
LegalizationAction
This enumeration corresponds to the specific action to take when considering an operation legal for t...
void markOpRecursivelyLegal(OperationName name, const DynamicLegalityCallbackFn &callback)
Mark an operation, that must have either been set as Legal or DynamicallyLegal, as being recursively ...
std::function< std::optional< bool >(Operation *)> DynamicLegalityCallbackFn
The signature of the callback used to determine if an operation is dynamically legal on the target.
bool isIllegal(Operation *op) const
Returns true if the operation instance is illegal on this target.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:156
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookup(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:72
user_range getUsers() const
Returns a range of all users.
Definition: UseDefLists.h:227
void dropAllUses()
Drop all uses of this object from their respective owners.
Definition: UseDefLists.h:179
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition: UseDefLists.h:188
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:308
Location objects represent source locations information in MLIR.
Definition: Location.h:31
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:329
This class helps build Operations.
Definition: Builders.h:202
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:426
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:379
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Definition: Builders.h:402
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:432
LogicalResult tryFold(Operation *op, SmallVectorImpl< Value > &results)
Attempts to fold the given operation and places new results within 'results'.
Definition: Builders.cpp:450
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:423
OpInterfaceConversionPattern is a wrapper around ConversionPattern that allows for matching and rewri...
OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class represents an operand of an operation.
Definition: Value.h:255
This is a value defined by a result of an operation.
Definition: Value.h:442
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:454
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
type_range getTypes() const
Definition: ValueRange.cpp:26
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:75
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
Definition: Operation.h:210
void dropAllUses()
Drop all uses of results of this operation.
Definition: Operation.h:677
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:274
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:640
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:537
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:207
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:218
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:234
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:197
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:540
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:103
operand_type_range getOperandTypes()
Definition: Operation.h:376
result_type_range getResultTypes()
Definition: Operation.h:407
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:357
result_range getResults()
Definition: Operation.h:394
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:427
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:383
void notifyRewriteEnd(PatternRewriter &rewriter) final
void notifyRewriteBegin(PatternRewriter &rewriter) final
Hooks that are invoked at the beginning and end of a rewrite of a matched pattern.
void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn)
Register a rewrite function with PDL.
This class manages the application of a group of rewrite patterns, with a user-provided cost model.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:33
static PatternBenefit impossibleToMatch()
Definition: PatternMatch.h:42
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:668
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:72
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e.
Definition: PatternMatch.h:128
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
Definition: PatternMatch.h:93
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
Definition: PatternMatch.h:89
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
bool empty()
Definition: Region.h:60
iterator end()
Definition: Region.h:56
Block & back()
Definition: Region.h:64
BlockListType & getBlocks()
Definition: Region.h:45
Block & front()
Definition: Region.h:65
BlockListType::iterator iterator
Definition: Region.h:52
This class implements the result iterators for the Operation class.
Definition: ValueRange.h:231
std::enable_if_t<!std::is_convertible< ValuesT, Operation * >::value > replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this range with the provided 'values'.
Definition: ValueRange.h:270
MLIRContext * getContext() const
PDLPatternModule & getPDLPatterns()
Return the PDL patterns held in this list.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
virtual Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
Definition: PatternMatch.h:549
virtual void finalizeRootUpdate(Operation *op)
This method is used to signal the end of a root update on the given operation.
virtual void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
The general result of a type attribute conversion callback, allowing for early termination.
static AttributeConversionResult result(Attribute attr)
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
std::optional< InputMapping > getInputMapping(unsigned input) const
Get the input mapping for the given argument.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
void remapInput(unsigned origInputNo, Value replacement)
Remap an input of the original signature to another replacement value.
Type conversion class.
bool isSignatureLegal(FunctionType ty)
Return true if the inputs and outputs of the given function type are legal.
LogicalResult convertSignatureArg(unsigned inputNo, Type type, SignatureConversion &result)
This method allows for converting a specific argument of a signature.
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results)
Convert the given set of types, filling 'results' as necessary.
LogicalResult convertSignatureArgs(TypeRange types, SignatureConversion &result, unsigned origInputOffset=0)
std::optional< Attribute > convertTypeAttribute(Type type, Attribute attr)
Convert an attribute present attr from within the type type using the registered conversion functions...
bool isLegal(Type type)
Return true if the given type is legal for this type converter, i.e.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results)
Convert the given type.
std::optional< SignatureConversion > convertBlockSignature(Block *block)
This function converts the type signature of the given block, by invoking 'convertSignatureArg' for e...
Value materializeArgumentConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs)
Materialize a conversion from a set of types into one result type by generating a cast sequence of so...
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs)
Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs)
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:370
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:93
bool use_empty() const
Returns true if this value has no uses.
Definition: Value.h:207
bool isa() const
Definition: Value.h:98
void dropAllUses() const
Drop all uses of this object from their respective owners.
Definition: Value.h:161
Type getType() const
Return the type of this value.
Definition: Value.h:122
U dyn_cast() const
Definition: Value.h:103
user_iterator user_end() const
Definition: Value.h:216
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:48
U cast() const
Definition: Value.h:113
void replaceAllUsesWith(Value newValue) const
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition: Value.h:166
user_range getUsers() const
Definition: Value.h:217
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static WalkResult skip()
Definition: Visitors.h:53
static WalkResult advance()
Definition: Visitors.h:52
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:40
Detect if any of the given parameter types has a sub-element handler.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:223
@ Full
Documents are synced by always sending the full content of the document.
Kind
An enumeration of the kinds of predicates.
Definition: Predicate.h:44
llvm::PointerUnion< NamedAttribute *, NamedTypeConstraint * > Argument
Definition: Argument.h:62
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
void populateFunctionOpInterfaceTypeConversionPattern(StringRef functionLikeOpName, RewritePatternSet &patterns, TypeConverter &converter)
Add a pattern to the given pattern list to convert the signature of a FunctionOpInterface op with the...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation * > *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns)
Apply a complete conversion on the given operations, and all nested operations.
void populateAnyFunctionOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, TypeConverter &converter)
void registerConversionPDLFunctions(RewritePatternSet &patterns)
Register the dialect conversion PDL functions with the given pattern set.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
LogicalResult applyAnalysisConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation * > &convertedOps, function_ref< void(Diagnostic &)> notifyCallback=nullptr)
Apply an analysis conversion on the given operations, and all nested operations.
A structure containing additional information describing a specific legal operation instance.
bool isRecursivelyLegal
A flag that indicates if this operation is 'recursively' legal.
This iterator enumerates elements according to their dominance relationship.
Definition: Iterators.h:48
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
This struct represents a range of new types or a single value that remaps an existing signature input...
void eraseDanglingBlocks()
Erase any blocks that were unlinked from their regions and stored in block actions.
TypeConverter * currentTypeConverter
The current type converter, or nullptr if no type converter is currently active.
void undoBlockActions(unsigned numActionsToKeep=0)
Undo the block actions (motions, splits) one by one in reverse order until "numActionsToKeep" actions...
void discardRewrites()
Cleanup and destroy any generated rewrite operations.
ConversionPatternRewriterImpl(PatternRewriter &rewriter)
function_ref< void(Diagnostic &)> notifyCallback
This allows the user to collect the match failure message.
SmallVector< BlockAction, 4 > blockActions
Ordered list of block operations (creations, splits, motions).
llvm::MapVector< Operation *, OpReplacement > replacements
Ordered map of requested operation replacements.
LogicalResult convertNonEntryRegionTypes(Region *region, TypeConverter &converter, ArrayRef< TypeConverter::SignatureConversion > blockConversions={})
Convert the types of non-entry block arguments within the given region.
void resetState(RewriterState state)
Reset the state of the rewriter to a previously saved point.
void notifyRegionIsBeingInlinedBefore(Region &region, Region &parent, Region::iterator before)
Notifies that the blocks of a region are about to be moved.
void notifySplitBlock(Block *block, Block *continuation)
Notifies that a block was split.
void applyRewrites()
Apply all requested operation rewrites.
SmallVector< OperationTransactionState, 4 > rootUpdates
A transaction state for each of operations that were updated in-place.
RewriterState getCurrentState()
Return the current state of the rewriter.
Block * applySignatureConversion(Region *region, TypeConverter::SignatureConversion &conversion, TypeConverter *converter)
Apply a signature conversion on the given region, using converter for materializations if not null.
void notifyOpReplaced(Operation *op, ValueRange newValues)
PatternRewriter hook for replacing the results of an operation.
llvm::ScopedPrinter logger
A logger used to emit diagnostics during the conversion process.
SmallVector< UnresolvedMaterialization > unresolvedMaterializations
Ordered vector of all unresolved type conversion materializations during conversion.
void notifyBlockBeingInlined(Block *block, Block *srcBlock, Block::iterator before)
Notifies that a block is being inlined into another block.
ArgConverter argConverter
Utility used to convert block arguments.
SmallPtrSet< Operation *, 1 > pendingRootUpdates
A set of operations that have pending updates.
LogicalResult remapValues(StringRef valueDiagTag, std::optional< Location > inputLoc, PatternRewriter &rewriter, ValueRange values, SmallVectorImpl< Value > &remapped)
Remap the given values to those with potentially different types.
FailureOr< Block * > convertRegionTypes(Region *region, TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion)
Convert the types of block arguments within the given region.
bool isOpIgnored(Operation *op) const
Returns true if the given operation is ignored, and does not need to be converted.
FailureOr< Block * > convertBlockSignature(Block *block, TypeConverter *converter, TypeConverter::SignatureConversion *conversion=nullptr)
Convert the signature of the given block.
void markNestedOpsIgnored(Operation *op)
Recursively marks the nested operations under 'op' as ignored.
SmallVector< Operation * > createdOps
Ordered vector of all of the newly created operations during conversion.
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback)
Notifies that a pattern match failed for the given reason.
SetVector< Operation * > ignoredOps
A set of operations that should no longer be considered for legalization, but were not directly repla...
SmallVector< BlockArgument, 4 > argReplacements
Ordered vector of any requested block argument replacements.
void notifyBlockIsBeingErased(Block *block)
Notifies that a block is about to be erased.
void notifyCreatedBlock(Block *block)
Notifies that a block was created.
SmallVector< unsigned, 4 > operationsWithChangedResults
A vector of indices into replacements of operations that were replaced with values with different res...