MLIR  16.0.0git
DialectConversion.cpp
Go to the documentation of this file.
1 //===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "mlir/IR/Block.h"
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/BuiltinOps.h"
16 #include "llvm/ADT/ScopeExit.h"
17 #include "llvm/ADT/SetVector.h"
18 #include "llvm/ADT/SmallPtrSet.h"
19 #include "llvm/Support/Debug.h"
20 #include "llvm/Support/FormatVariadic.h"
21 #include "llvm/Support/SaveAndRestore.h"
22 #include "llvm/Support/ScopedPrinter.h"
23 
24 using namespace mlir;
25 using namespace mlir::detail;
26 
27 #define DEBUG_TYPE "dialect-conversion"
28 
29 /// Recursively collect all of the operations to convert from within 'region'.
30 /// If 'target' is nonnull, operations that are recursively legal have their
31 /// regions pre-filtered to avoid considering them for legalization.
32 static LogicalResult
34  Location regionLoc,
36  ConversionTarget *target = nullptr) {
37  if (region.empty())
38  return success();
39 
40  // Traverse starting from the entry block.
41  SmallVector<Block *, 16> worklist(1, &*region.begin());
42  DenseSet<Block *> visitedBlocks;
43  visitedBlocks.insert(worklist.front());
44  while (!worklist.empty()) {
45  Block *block = worklist.pop_back_val();
46 
47  // Compute the conversion set of each of the nested operations.
48  for (Operation &op : *block) {
49  toConvert.emplace_back(&op);
50 
51  // Don't check this operation's children for conversion if the operation
52  // is recursively legal.
53  auto legalityInfo = target ? target->isLegal(&op)
55  if (legalityInfo && legalityInfo->isRecursivelyLegal)
56  continue;
57  for (auto &region : op.getRegions()) {
58  if (failed(computeConversionSet(region.getBlocks(), region.getLoc(),
59  toConvert, target)))
60  return failure();
61  }
62  }
63 
64  // Recurse to children that haven't been visited.
65  for (Block *succ : block->getSuccessors())
66  if (visitedBlocks.insert(succ).second)
67  worklist.push_back(succ);
68  }
69 
70  // Check that all blocks in the region were visited.
71  if (llvm::any_of(llvm::drop_begin(region, 1),
72  [&](Block &block) { return !visitedBlocks.count(&block); }))
73  return emitError(regionLoc, "unreachable blocks were not converted");
74  return success();
75 }
76 
77 /// A utility function to log a successful result for the given reason.
78 template <typename... Args>
79 static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
80  LLVM_DEBUG({
81  os.unindent();
82  os.startLine() << "} -> SUCCESS";
83  if (!fmt.empty())
84  os.getOStream() << " : "
85  << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
86  os.getOStream() << "\n";
87  });
88 }
89 
90 /// A utility function to log a failure result for the given reason.
91 template <typename... Args>
92 static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
93  LLVM_DEBUG({
94  os.unindent();
95  os.startLine() << "} -> FAILURE : "
96  << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
97  << "\n";
98  });
99 }
100 
101 //===----------------------------------------------------------------------===//
102 // ConversionValueMapping
103 //===----------------------------------------------------------------------===//
104 
105 namespace {
106 /// This class wraps a BlockAndValueMapping to provide recursive lookup
107 /// functionality, i.e. we will traverse if the mapped value also has a mapping.
108 struct ConversionValueMapping {
109  /// Lookup a mapped value within the map. If a mapping for the provided value
110  /// does not exist then return the provided value. If `desiredType` is
111  /// non-null, returns the most recently mapped value with that type. If an
112  /// operand of that type does not exist, defaults to normal behavior.
113  Value lookupOrDefault(Value from, Type desiredType = nullptr) const;
114 
115  /// Lookup a mapped value within the map, or return null if a mapping does not
116  /// exist. If a mapping exists, this follows the same behavior of
117  /// `lookupOrDefault`.
118  Value lookupOrNull(Value from, Type desiredType = nullptr) const;
119 
120  /// Map a value to the one provided.
121  void map(Value oldVal, Value newVal) {
122  LLVM_DEBUG({
123  for (Value it = newVal; it; it = mapping.lookupOrNull(it))
124  assert(it != oldVal && "inserting cyclic mapping");
125  });
126  mapping.map(oldVal, newVal);
127  }
128 
129  /// Try to map a value to the one provided. Returns false if a transitive
130  /// mapping from the new value to the old value already exists, true if the
131  /// map was updated.
132  bool tryMap(Value oldVal, Value newVal);
133 
134  /// Drop the last mapping for the given value.
135  void erase(Value value) { mapping.erase(value); }
136 
137  /// Returns the inverse raw value mapping (without recursive query support).
138  DenseMap<Value, SmallVector<Value>> getInverse() const {
140  for (auto &it : mapping.getValueMap())
141  inverse[it.second].push_back(it.first);
142  return inverse;
143  }
144 
145 private:
146  /// Current value mappings.
147  BlockAndValueMapping mapping;
148 };
149 } // namespace
150 
151 Value ConversionValueMapping::lookupOrDefault(Value from,
152  Type desiredType) const {
153  // If there was no desired type, simply find the leaf value.
154  if (!desiredType) {
155  // If this value had a valid mapping, unmap that value as well in the case
156  // that it was also replaced.
157  while (auto mappedValue = mapping.lookupOrNull(from))
158  from = mappedValue;
159  return from;
160  }
161 
162  // Otherwise, try to find the deepest value that has the desired type.
163  Value desiredValue;
164  do {
165  if (from.getType() == desiredType)
166  desiredValue = from;
167 
168  Value mappedValue = mapping.lookupOrNull(from);
169  if (!mappedValue)
170  break;
171  from = mappedValue;
172  } while (true);
173 
174  // If the desired value was found use it, otherwise default to the leaf value.
175  return desiredValue ? desiredValue : from;
176 }
177 
178 Value ConversionValueMapping::lookupOrNull(Value from, Type desiredType) const {
179  Value result = lookupOrDefault(from, desiredType);
180  if (result == from || (desiredType && result.getType() != desiredType))
181  return nullptr;
182  return result;
183 }
184 
185 bool ConversionValueMapping::tryMap(Value oldVal, Value newVal) {
186  for (Value it = newVal; it; it = mapping.lookupOrNull(it))
187  if (it == oldVal)
188  return false;
189  map(oldVal, newVal);
190  return true;
191 }
192 
193 //===----------------------------------------------------------------------===//
194 // Rewriter and Translation State
195 //===----------------------------------------------------------------------===//
196 namespace {
/// This class contains a snapshot of the current conversion rewriter state.
/// This is useful when saving and undoing a set of rewrites.
///
/// Each field is a count recorded when the snapshot is taken.
struct RewriterState {
  RewriterState(unsigned numCreatedOps, unsigned numUnresolvedMaterializations,
                unsigned numReplacements, unsigned numArgReplacements,
                unsigned numBlockActions, unsigned numIgnoredOperations,
                unsigned numRootUpdates)
      : numCreatedOps{numCreatedOps},
        numUnresolvedMaterializations{numUnresolvedMaterializations},
        numReplacements{numReplacements},
        numArgReplacements{numArgReplacements},
        numBlockActions{numBlockActions},
        numIgnoredOperations{numIgnoredOperations},
        numRootUpdates{numRootUpdates} {}

  unsigned numCreatedOps;                 ///< Operations created so far.
  unsigned numUnresolvedMaterializations; ///< Pending unresolved casts.
  unsigned numReplacements;               ///< Queued op replacements.
  unsigned numArgReplacements;            ///< Queued argument replacements.
  unsigned numBlockActions;               ///< Block actions performed.
  unsigned numIgnoredOperations;          ///< Operations being ignored.
  unsigned numRootUpdates;                ///< Ops updated in place.
};
233 
234 //===----------------------------------------------------------------------===//
235 // OperationTransactionState
236 
237 /// The state of an operation that was updated by a pattern in-place. This
238 /// contains all of the necessary information to reconstruct an operation that
239 /// was updated in place.
240 class OperationTransactionState {
241 public:
242  OperationTransactionState() = default;
243  OperationTransactionState(Operation *op)
244  : op(op), loc(op->getLoc()), attrs(op->getAttrDictionary()),
245  operands(op->operand_begin(), op->operand_end()),
246  successors(op->successor_begin(), op->successor_end()) {}
247 
248  /// Discard the transaction state and reset the state of the original
249  /// operation.
250  void resetOperation() const {
251  op->setLoc(loc);
252  op->setAttrs(attrs);
253  op->setOperands(operands);
254  for (const auto &it : llvm::enumerate(successors))
255  op->setSuccessor(it.value(), it.index());
256  }
257 
258  /// Return the original operation of this state.
259  Operation *getOperation() const { return op; }
260 
261 private:
262  Operation *op;
263  LocationAttr loc;
264  DictionaryAttr attrs;
265  SmallVector<Value, 8> operands;
266  SmallVector<Block *, 2> successors;
267 };
268 
269 //===----------------------------------------------------------------------===//
270 // OpReplacement
271 
272 /// This class represents one requested operation replacement via 'replaceOp' or
273 /// 'eraseOp`.
274 struct OpReplacement {
275  OpReplacement(TypeConverter *converter = nullptr) : converter(converter) {}
276 
277  /// An optional type converter that can be used to materialize conversions
278  /// between the new and old values if necessary.
279  TypeConverter *converter;
280 };
281 
282 //===----------------------------------------------------------------------===//
283 // BlockAction
284 
/// The kind of the block action performed during the rewrite. Actions can be
/// undone if the conversion fails. Each kind is produced by the matching
/// `BlockAction::get*` factory.
enum class BlockActionKind {
  Create = 0,         ///< Recorded by `BlockAction::getCreate`.
  Erase = 1,          ///< Recorded by `BlockAction::getErase`.
  Merge = 2,          ///< Recorded by `BlockAction::getMerge`.
  Move = 3,           ///< Recorded by `BlockAction::getMove`.
  Split = 4,          ///< Recorded by `BlockAction::getSplit`.
  TypeConversion = 5  ///< Recorded by `BlockAction::getTypeConversion`.
};
295 
296 /// Original position of the given block in its parent region. During undo
297 /// actions, the block needs to be placed after `insertAfterBlock`.
298 struct BlockPosition {
299  Region *region;
300  Block *insertAfterBlock;
301 };
302 
303 /// Information needed to undo the merge actions.
304 /// - the source block, and
305 /// - the Operation that was the last operation in the dest block before the
306 /// merge (could be null if the dest block was empty).
307 struct MergeInfo {
308  Block *sourceBlock;
309  Operation *destBlockLastInst;
310 };
311 
312 /// The storage class for an undoable block action (one of BlockActionKind),
313 /// contains the information necessary to undo this action.
314 struct BlockAction {
315  static BlockAction getCreate(Block *block) {
316  return {BlockActionKind::Create, block, {}};
317  }
318  static BlockAction getErase(Block *block, BlockPosition originalPosition) {
319  return {BlockActionKind::Erase, block, {originalPosition}};
320  }
321  static BlockAction getMerge(Block *block, Block *sourceBlock) {
322  BlockAction action{BlockActionKind::Merge, block, {}};
323  action.mergeInfo = {sourceBlock, block->empty() ? nullptr : &block->back()};
324  return action;
325  }
326  static BlockAction getMove(Block *block, BlockPosition originalPosition) {
327  return {BlockActionKind::Move, block, {originalPosition}};
328  }
329  static BlockAction getSplit(Block *block, Block *originalBlock) {
330  BlockAction action{BlockActionKind::Split, block, {}};
331  action.originalBlock = originalBlock;
332  return action;
333  }
334  static BlockAction getTypeConversion(Block *block) {
335  return BlockAction{BlockActionKind::TypeConversion, block, {}};
336  }
337 
338  // The action kind.
339  BlockActionKind kind;
340 
341  // A pointer to the block that was created by the action.
342  Block *block;
343 
344  union {
345  // In use if kind == BlockActionKind::Move or BlockActionKind::Erase, and
346  // contains a pointer to the region that originally contained the block as
347  // well as the position of the block in that region.
348  BlockPosition originalPosition;
349  // In use if kind == BlockActionKind::Split and contains a pointer to the
350  // block that was split into two parts.
351  Block *originalBlock;
352  // In use if kind == BlockActionKind::Merge, and contains the information
353  // needed to undo the merge.
354  MergeInfo mergeInfo;
355  };
356 };
357 
358 //===----------------------------------------------------------------------===//
359 // UnresolvedMaterialization
360 
361 /// This class represents an unresolved materialization, i.e. a materialization
362 /// that was inserted during conversion that needs to be legalized at the end of
363 /// the conversion process.
364 class UnresolvedMaterialization {
365 public:
366  /// The type of materialization.
367  enum Kind {
368  /// This materialization materializes a conversion for an illegal block
369  /// argument type, to a legal one.
370  Argument,
371 
372  /// This materialization materializes a conversion from an illegal type to a
373  /// legal one.
374  Target
375  };
376 
377  UnresolvedMaterialization(UnrealizedConversionCastOp op = nullptr,
378  TypeConverter *converter = nullptr,
379  Kind kind = Target, Type origOutputType = nullptr)
380  : op(op), converterAndKind(converter, kind),
381  origOutputType(origOutputType) {}
382 
383  /// Return the temporary conversion operation inserted for this
384  /// materialization.
385  UnrealizedConversionCastOp getOp() const { return op; }
386 
387  /// Return the type converter of this materialization (which may be null).
388  TypeConverter *getConverter() const { return converterAndKind.getPointer(); }
389 
390  /// Return the kind of this materialization.
391  Kind getKind() const { return converterAndKind.getInt(); }
392 
393  /// Set the kind of this materialization.
394  void setKind(Kind kind) { converterAndKind.setInt(kind); }
395 
396  /// Return the original illegal output type of the input values.
397  Type getOrigOutputType() const { return origOutputType; }
398 
399 private:
400  /// The unresolved materialization operation created during conversion.
401  UnrealizedConversionCastOp op;
402 
403  /// The corresponding type converter to use when resolving this
404  /// materialization, and the kind of this materialization.
405  llvm::PointerIntPair<TypeConverter *, 1, Kind> converterAndKind;
406 
407  /// The original output type. This is only used for argument conversions.
408  Type origOutputType;
409 };
410 } // namespace
411 
412 /// Build an unresolved materialization operation given an output type and set
413 /// of input operands.
415  UnresolvedMaterialization::Kind kind, Block *insertBlock,
416  Block::iterator insertPt, Location loc, ValueRange inputs, Type outputType,
417  Type origOutputType, TypeConverter *converter,
418  SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations) {
419  // Avoid materializing an unnecessary cast.
420  if (inputs.size() == 1 && inputs.front().getType() == outputType)
421  return inputs.front();
422 
423  // Create an unresolved materialization. We use a new OpBuilder to avoid
424  // tracking the materialization like we do for other operations.
425  OpBuilder builder(insertBlock, insertPt);
426  auto convertOp =
427  builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
428  unresolvedMaterializations.emplace_back(convertOp, converter, kind,
429  origOutputType);
430  return convertOp.getResult(0);
431 }
433  PatternRewriter &rewriter, Location loc, ValueRange inputs,
434  Type origOutputType, Type outputType, TypeConverter *converter,
435  SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations) {
438  rewriter.getInsertionPoint(), loc, inputs, outputType, origOutputType,
439  converter, unresolvedMaterializations);
440 }
442  Location loc, Value input, Type outputType, TypeConverter *converter,
443  SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations) {
444  Block *insertBlock = input.getParentBlock();
445  Block::iterator insertPt = insertBlock->begin();
446  if (OpResult inputRes = input.dyn_cast<OpResult>())
447  insertPt = ++inputRes.getOwner()->getIterator();
448 
450  UnresolvedMaterialization::Target, insertBlock, insertPt, loc, input,
451  outputType, outputType, converter, unresolvedMaterializations);
452 }
453 
454 //===----------------------------------------------------------------------===//
455 // ArgConverter
456 //===----------------------------------------------------------------------===//
457 namespace {
458 /// This class provides a simple interface for converting the types of block
459 /// arguments. This is done by creating a new block that contains the new legal
460 /// types and extracting the block that contains the old illegal types to allow
461 /// for undoing pending rewrites in the case of failure.
462 struct ArgConverter {
463  ArgConverter(
464  PatternRewriter &rewriter,
465  SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations)
466  : rewriter(rewriter),
467  unresolvedMaterializations(unresolvedMaterializations) {}
468 
469  /// This structure contains the information pertaining to an argument that has
470  /// been converted.
471  struct ConvertedArgInfo {
472  ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize,
473  Value castValue = nullptr)
474  : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {}
475 
476  /// The start index of in the new argument list that contains arguments that
477  /// replace the original.
478  unsigned newArgIdx;
479 
480  /// The number of arguments that replaced the original argument.
481  unsigned newArgSize;
482 
483  /// The cast value that was created to cast from the new arguments to the
484  /// old. This only used if 'newArgSize' > 1.
485  Value castValue;
486  };
487 
488  /// This structure contains information pertaining to a block that has had its
489  /// signature converted.
490  struct ConvertedBlockInfo {
491  ConvertedBlockInfo(Block *origBlock, TypeConverter *converter)
492  : origBlock(origBlock), converter(converter) {}
493 
494  /// The original block that was requested to have its signature converted.
495  Block *origBlock;
496 
497  /// The conversion information for each of the arguments. The information is
498  /// None if the argument was dropped during conversion.
500 
501  /// The type converter used to convert the arguments.
502  TypeConverter *converter;
503  };
504 
505  /// Return if the signature of the given block has already been converted.
506  bool hasBeenConverted(Block *block) const {
507  return conversionInfo.count(block) || convertedBlocks.count(block);
508  }
509 
510  /// Set the type converter to use for the given region.
511  void setConverter(Region *region, TypeConverter *typeConverter) {
512  assert(typeConverter && "expected valid type converter");
513  regionToConverter[region] = typeConverter;
514  }
515 
516  /// Return the type converter to use for the given region, or null if there
517  /// isn't one.
518  TypeConverter *getConverter(Region *region) {
519  return regionToConverter.lookup(region);
520  }
521 
522  //===--------------------------------------------------------------------===//
523  // Rewrite Application
524  //===--------------------------------------------------------------------===//
525 
526  /// Erase any rewrites registered for the blocks within the given operation
527  /// which is about to be removed. This merely drops the rewrites without
528  /// undoing them.
529  void notifyOpRemoved(Operation *op);
530 
531  /// Cleanup and undo any generated conversions for the arguments of block.
532  /// This method replaces the new block with the original, reverting the IR to
533  /// its original state.
534  void discardRewrites(Block *block);
535 
536  /// Fully replace uses of the old arguments with the new.
537  void applyRewrites(ConversionValueMapping &mapping);
538 
539  /// Materialize any necessary conversions for converted arguments that have
540  /// live users, using the provided `findLiveUser` to search for a user that
541  /// survives the conversion process.
543  materializeLiveConversions(ConversionValueMapping &mapping,
544  OpBuilder &builder,
545  function_ref<Operation *(Value)> findLiveUser);
546 
547  //===--------------------------------------------------------------------===//
548  // Conversion
549  //===--------------------------------------------------------------------===//
550 
551  /// Attempt to convert the signature of the given block, if successful a new
552  /// block is returned containing the new arguments. Returns `block` if it did
553  /// not require conversion.
555  convertSignature(Block *block, TypeConverter *converter,
556  ConversionValueMapping &mapping,
557  SmallVectorImpl<BlockArgument> &argReplacements);
558 
559  /// Apply the given signature conversion on the given block. The new block
560  /// containing the updated signature is returned. If no conversions were
561  /// necessary, e.g. if the block has no arguments, `block` is returned.
562  /// `converter` is used to generate any necessary cast operations that
563  /// translate between the origin argument types and those specified in the
564  /// signature conversion.
565  Block *applySignatureConversion(
566  Block *block, TypeConverter *converter,
567  TypeConverter::SignatureConversion &signatureConversion,
568  ConversionValueMapping &mapping,
569  SmallVectorImpl<BlockArgument> &argReplacements);
570 
571  /// Insert a new conversion into the cache.
572  void insertConversion(Block *newBlock, ConvertedBlockInfo &&info);
573 
574  /// A collection of blocks that have had their arguments converted. This is a
575  /// map from the new replacement block, back to the original block.
576  llvm::MapVector<Block *, ConvertedBlockInfo> conversionInfo;
577 
578  /// The set of original blocks that were converted.
579  DenseSet<Block *> convertedBlocks;
580 
581  /// A mapping from valid regions, to those containing the original blocks of a
582  /// conversion.
584 
585  /// A mapping of regions to type converters that should be used when
586  /// converting the arguments of blocks within that region.
587  DenseMap<Region *, TypeConverter *> regionToConverter;
588 
589  /// The pattern rewriter to use when materializing conversions.
590  PatternRewriter &rewriter;
591 
592  /// An ordered set of unresolved materializations during conversion.
593  SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations;
594 };
595 } // namespace
596 
597 //===----------------------------------------------------------------------===//
598 // Rewrite Application
599 
600 void ArgConverter::notifyOpRemoved(Operation *op) {
601  if (conversionInfo.empty())
602  return;
603 
604  for (Region &region : op->getRegions()) {
605  for (Block &block : region) {
606  // Drop any rewrites from within.
607  for (Operation &nestedOp : block)
608  if (nestedOp.getNumRegions())
609  notifyOpRemoved(&nestedOp);
610 
611  // Check if this block was converted.
612  auto it = conversionInfo.find(&block);
613  if (it == conversionInfo.end())
614  continue;
615 
616  // Drop all uses of the original arguments and delete the original block.
617  Block *origBlock = it->second.origBlock;
618  for (BlockArgument arg : origBlock->getArguments())
619  arg.dropAllUses();
620  conversionInfo.erase(it);
621  }
622  }
623 }
624 
625 void ArgConverter::discardRewrites(Block *block) {
626  auto it = conversionInfo.find(block);
627  if (it == conversionInfo.end())
628  return;
629  Block *origBlock = it->second.origBlock;
630 
631  // Drop all uses of the new block arguments and replace uses of the new block.
632  for (int i = block->getNumArguments() - 1; i >= 0; --i)
633  block->getArgument(i).dropAllUses();
634  block->replaceAllUsesWith(origBlock);
635 
636  // Move the operations back the original block and the delete the new block.
637  origBlock->getOperations().splice(origBlock->end(), block->getOperations());
638  origBlock->moveBefore(block);
639  block->erase();
640 
641  convertedBlocks.erase(origBlock);
642  conversionInfo.erase(it);
643 }
644 
645 void ArgConverter::applyRewrites(ConversionValueMapping &mapping) {
646  for (auto &info : conversionInfo) {
647  ConvertedBlockInfo &blockInfo = info.second;
648  Block *origBlock = blockInfo.origBlock;
649 
650  // Process the remapping for each of the original arguments.
651  for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) {
652  Optional<ConvertedArgInfo> &argInfo = blockInfo.argInfo[i];
653  BlockArgument origArg = origBlock->getArgument(i);
654 
655  // Handle the case of a 1->0 value mapping.
656  if (!argInfo) {
657  if (Value newArg = mapping.lookupOrNull(origArg, origArg.getType()))
658  origArg.replaceAllUsesWith(newArg);
659  continue;
660  }
661 
662  // Otherwise this is a 1->1+ value mapping.
663  Value castValue = argInfo->castValue;
664  assert(argInfo->newArgSize >= 1 && castValue && "expected 1->1+ mapping");
665 
666  // If the argument is still used, replace it with the generated cast.
667  if (!origArg.use_empty()) {
668  origArg.replaceAllUsesWith(
669  mapping.lookupOrDefault(castValue, origArg.getType()));
670  }
671  }
672  }
673 }
674 
675 LogicalResult ArgConverter::materializeLiveConversions(
676  ConversionValueMapping &mapping, OpBuilder &builder,
677  function_ref<Operation *(Value)> findLiveUser) {
678  for (auto &info : conversionInfo) {
679  Block *newBlock = info.first;
680  ConvertedBlockInfo &blockInfo = info.second;
681  Block *origBlock = blockInfo.origBlock;
682 
683  // Process the remapping for each of the original arguments.
684  for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) {
685  // If the type of this argument changed and the argument is still live, we
686  // need to materialize a conversion.
687  BlockArgument origArg = origBlock->getArgument(i);
688  if (mapping.lookupOrNull(origArg, origArg.getType()))
689  continue;
690  Operation *liveUser = findLiveUser(origArg);
691  if (!liveUser)
692  continue;
693 
694  Value replacementValue = mapping.lookupOrDefault(origArg);
695  bool isDroppedArg = replacementValue == origArg;
696  if (isDroppedArg)
697  rewriter.setInsertionPointToStart(newBlock);
698  else
699  rewriter.setInsertionPointAfterValue(replacementValue);
700  Value newArg;
701  if (blockInfo.converter) {
702  newArg = blockInfo.converter->materializeSourceConversion(
703  rewriter, origArg.getLoc(), origArg.getType(),
704  isDroppedArg ? ValueRange() : ValueRange(replacementValue));
705  assert((!newArg || newArg.getType() == origArg.getType()) &&
706  "materialization hook did not provide a value of the expected "
707  "type");
708  }
709  if (!newArg) {
711  emitError(origArg.getLoc())
712  << "failed to materialize conversion for block argument #" << i
713  << " that remained live after conversion, type was "
714  << origArg.getType();
715  if (!isDroppedArg)
716  diag << ", with target type " << replacementValue.getType();
717  diag.attachNote(liveUser->getLoc())
718  << "see existing live user here: " << *liveUser;
719  return failure();
720  }
721  mapping.map(origArg, newArg);
722  }
723  }
724  return success();
725 }
726 
727 //===----------------------------------------------------------------------===//
728 // Conversion
729 
730 FailureOr<Block *> ArgConverter::convertSignature(
731  Block *block, TypeConverter *converter, ConversionValueMapping &mapping,
732  SmallVectorImpl<BlockArgument> &argReplacements) {
733  // Check if the block was already converted. If the block is detached,
734  // conservatively assume it is going to be deleted.
735  if (hasBeenConverted(block) || !block->getParent())
736  return block;
737  // If a converter wasn't provided, and the block wasn't already converted,
738  // there is nothing we can do.
739  if (!converter)
740  return failure();
741 
742  // Try to convert the signature for the block with the provided converter.
743  if (auto conversion = converter->convertBlockSignature(block))
744  return applySignatureConversion(block, converter, *conversion, mapping,
745  argReplacements);
746  return failure();
747 }
748 
749 Block *ArgConverter::applySignatureConversion(
750  Block *block, TypeConverter *converter,
751  TypeConverter::SignatureConversion &signatureConversion,
752  ConversionValueMapping &mapping,
753  SmallVectorImpl<BlockArgument> &argReplacements) {
754  // If no arguments are being changed or added, there is nothing to do.
755  unsigned origArgCount = block->getNumArguments();
756  auto convertedTypes = signatureConversion.getConvertedTypes();
757  if (origArgCount == 0 && convertedTypes.empty())
758  return block;
759 
760  // Split the block at the beginning to get a new block to use for the updated
761  // signature.
762  Block *newBlock = block->splitBlock(block->begin());
763  block->replaceAllUsesWith(newBlock);
764 
765  // FIXME: We should map the new arguments to proper locations.
766  SmallVector<Location> newLocs(convertedTypes.size(),
767  rewriter.getUnknownLoc());
768  SmallVector<Value, 4> newArgRange(
769  newBlock->addArguments(convertedTypes, newLocs));
770  ArrayRef<Value> newArgs(newArgRange);
771 
772  // Remap each of the original arguments as determined by the signature
773  // conversion.
774  ConvertedBlockInfo info(block, converter);
775  info.argInfo.resize(origArgCount);
776 
777  OpBuilder::InsertionGuard guard(rewriter);
778  rewriter.setInsertionPointToStart(newBlock);
779  for (unsigned i = 0; i != origArgCount; ++i) {
780  auto inputMap = signatureConversion.getInputMapping(i);
781  if (!inputMap)
782  continue;
783  BlockArgument origArg = block->getArgument(i);
784 
785  // If inputMap->replacementValue is not nullptr, then the argument is
786  // dropped and a replacement value is provided to be the remappedValue.
787  if (inputMap->replacementValue) {
788  assert(inputMap->size == 0 &&
789  "invalid to provide a replacement value when the argument isn't "
790  "dropped");
791  mapping.map(origArg, inputMap->replacementValue);
792  argReplacements.push_back(origArg);
793  continue;
794  }
795 
796  // Otherwise, this is a 1->1+ mapping.
797  auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size);
798  Value newArg;
799 
800  // If this is a 1->1 mapping and the types of new and replacement arguments
801  // match (i.e. it's an identity map), then the argument is mapped to its
802  // original type.
803  // FIXME: We simply pass through the replacement argument if there wasn't a
804  // converter, which isn't great as it allows implicit type conversions to
805  // appear. We should properly restructure this code to handle cases where a
806  // converter isn't provided and also to properly handle the case where an
807  // argument materialization is actually a temporary source materialization
808  // (e.g. in the case of 1->N).
809  if (replArgs.size() == 1 &&
810  (!converter || replArgs[0].getType() == origArg.getType())) {
811  newArg = replArgs.front();
812  } else {
813  Type origOutputType = origArg.getType();
814 
815  // Legalize the argument output type.
816  Type outputType = origOutputType;
817  if (Type legalOutputType = converter->convertType(outputType))
818  outputType = legalOutputType;
819 
821  rewriter, origArg.getLoc(), replArgs, origOutputType, outputType,
822  converter, unresolvedMaterializations);
823  }
824 
825  mapping.map(origArg, newArg);
826  argReplacements.push_back(origArg);
827  info.argInfo[i] =
828  ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
829  }
830 
831  // Remove the original block from the region and return the new one.
832  insertConversion(newBlock, std::move(info));
833  return newBlock;
834 }
835 
836 void ArgConverter::insertConversion(Block *newBlock,
837  ConvertedBlockInfo &&info) {
838  // Get a region to insert the old block.
839  Region *region = newBlock->getParent();
840  std::unique_ptr<Region> &mappedRegion = regionMapping[region];
841  if (!mappedRegion)
842  mappedRegion = std::make_unique<Region>(region->getParentOp());
843 
844  // Move the original block to the mapped region and emplace the conversion.
845  mappedRegion->getBlocks().splice(mappedRegion->end(), region->getBlocks(),
846  info.origBlock->getIterator());
847  convertedBlocks.insert(info.origBlock);
848  conversionInfo.insert({newBlock, std::move(info)});
849 }
850 
851 //===----------------------------------------------------------------------===//
852 // ConversionPatternRewriterImpl
853 //===----------------------------------------------------------------------===//
854 namespace mlir {
855 namespace detail {
858  : argConverter(rewriter, unresolvedMaterializations),
859  notifyCallback(nullptr) {}
860 
861  /// Cleanup and destroy any generated rewrite operations. This method is
862  /// invoked when the conversion process fails.
863  void discardRewrites();
864 
865  /// Apply all requested operation rewrites. This method is invoked when the
866  /// conversion process succeeds.
867  void applyRewrites();
868 
869  //===--------------------------------------------------------------------===//
870  // State Management
871  //===--------------------------------------------------------------------===//
872 
873  /// Return the current state of the rewriter.
874  RewriterState getCurrentState();
875 
876  /// Reset the state of the rewriter to a previously saved point.
877  void resetState(RewriterState state);
878 
879  /// Erase any blocks that were unlinked from their regions and stored in block
880  /// actions.
881  void eraseDanglingBlocks();
882 
883  /// Undo the block actions (motions, splits) one by one in reverse order until
884  /// "numActionsToKeep" actions remain.
885  void undoBlockActions(unsigned numActionsToKeep = 0);
886 
887  /// Remap the given values to those with potentially different types. Returns
888  /// success if the values could be remapped, failure otherwise. `valueDiagTag`
889  /// is the tag used when describing a value within a diagnostic, e.g.
890  /// "operand".
891  LogicalResult remapValues(StringRef valueDiagTag, Optional<Location> inputLoc,
892  PatternRewriter &rewriter, ValueRange values,
893  SmallVectorImpl<Value> &remapped);
894 
895  /// Returns true if the given operation is ignored, and does not need to be
896  /// converted.
897  bool isOpIgnored(Operation *op) const;
898 
899  /// Recursively marks the nested operations under 'op' as ignored. This
900  /// removes them from being considered for legalization.
901  void markNestedOpsIgnored(Operation *op);
902 
903  //===--------------------------------------------------------------------===//
904  // Type Conversion
905  //===--------------------------------------------------------------------===//
906 
907  /// Convert the signature of the given block.
908  FailureOr<Block *> convertBlockSignature(
909  Block *block, TypeConverter *converter,
910  TypeConverter::SignatureConversion *conversion = nullptr);
911 
912  /// Apply a signature conversion on the given region, using `converter` for
913  /// materializations if not null.
914  Block *
915  applySignatureConversion(Region *region,
917  TypeConverter *converter);
918 
919  /// Convert the types of block arguments within the given region.
921  convertRegionTypes(Region *region, TypeConverter &converter,
922  TypeConverter::SignatureConversion *entryConversion);
923 
924  /// Convert the types of non-entry block arguments within the given region.
925  LogicalResult convertNonEntryRegionTypes(
926  Region *region, TypeConverter &converter,
927  ArrayRef<TypeConverter::SignatureConversion> blockConversions = {});
928 
929  //===--------------------------------------------------------------------===//
930  // Rewriter Notification Hooks
931  //===--------------------------------------------------------------------===//
932 
933  /// PatternRewriter hook for replacing the results of an operation.
934  void notifyOpReplaced(Operation *op, ValueRange newValues);
935 
936  /// Notifies that a block is about to be erased.
937  void notifyBlockIsBeingErased(Block *block);
938 
939  /// Notifies that a block was created.
940  void notifyCreatedBlock(Block *block);
941 
942  /// Notifies that a block was split.
943  void notifySplitBlock(Block *block, Block *continuation);
944 
945  /// Notifies that `block` is being merged with `srcBlock`.
946  void notifyBlocksBeingMerged(Block *block, Block *srcBlock);
947 
948  /// Notifies that the blocks of a region are about to be moved.
949  void notifyRegionIsBeingInlinedBefore(Region &region, Region &parent,
950  Region::iterator before);
951 
952  /// Notifies that the blocks of a region were cloned into another.
953  void notifyRegionWasClonedBefore(iterator_range<Region::iterator> &blocks,
954  Location origRegionLoc);
955 
956  /// Notifies that a pattern match failed for the given reason.
958  notifyMatchFailure(Location loc,
959  function_ref<void(Diagnostic &)> reasonCallback);
960 
961  //===--------------------------------------------------------------------===//
962  // State
963  //===--------------------------------------------------------------------===//
964 
965  // Mapping between replaced values that differ in type. This happens when
966  // replacing a value with one of a different type.
967  ConversionValueMapping mapping;
968 
969  /// Utility used to convert block arguments.
970  ArgConverter argConverter;
971 
972  /// Ordered vector of all of the newly created operations during conversion.
974 
975  /// Ordered vector of all unresolved type conversion materializations during
976  /// conversion.
978 
979  /// Ordered map of requested operation replacements.
980  llvm::MapVector<Operation *, OpReplacement> replacements;
981 
982  /// Ordered vector of any requested block argument replacements.
984 
985  /// Ordered list of block operations (creations, splits, motions).
987 
988  /// A set of operations that should no longer be considered for legalization,
989  /// but were not directly replaced/erased/etc. by a pattern. These are
990  /// generally child operations of other operations that were
991  /// replaced/erased/etc. This is not meant to be an exhaustive list of all
992  /// operations, but the minimal set that can be used to detect if a given
993  /// operation should be `ignored`. For example, we may add the operations that
994  /// define non-empty regions to the set, but not any of the others. This
995  /// simplifies the amount of memory needed as we can query if the parent
996  /// operation was ignored.
998 
999  /// A transaction state for each of operations that were updated in-place.
1001 
1002  /// A vector of indices into `replacements` of operations that were replaced
1003  /// with values with different result types than the original operation, e.g.
1004  /// 1->N conversion of some kind.
1006 
1007  /// The current type converter, or nullptr if no type converter is currently
1008  /// active.
1009  TypeConverter *currentTypeConverter = nullptr;
1010 
1011  /// This allows the user to collect the match failure message.
1013 
1014 #ifndef NDEBUG
1015  /// A set of operations that have pending updates. This tracking isn't
1016  /// strictly necessary, and is thus only active during debug builds for extra
1017  /// verification.
1019 
1020  /// A logger used to emit diagnostics during the conversion process.
1021  llvm::ScopedPrinter logger{llvm::dbgs()};
1022 #endif
1023 };
1024 } // namespace detail
1025 } // namespace mlir
1026 
1027 /// Detach any operations nested in the given operation from their parent
1028 /// blocks, and erase the given operation. This can be used when the nested
1029 /// operations are scheduled for erasure themselves, so deleting the regions of
1030 /// the given operation together with their content would result in double-free.
1031 /// This happens, for example, when rolling back op creation in the reverse
1032 /// order and if the nested ops were created before the parent op. This function
1033 /// does not need to collect nested ops recursively because it is expected to
1034 /// also be called for each nested op when it is about to be deleted.
1036  for (Region &region : op->getRegions()) {
1037  for (Block &block : region.getBlocks()) {
1038  while (!block.getOperations().empty())
1039  block.getOperations().remove(block.getOperations().begin());
1040  block.dropAllDefinedValueUses();
1041  }
1042  }
1043  op->dropAllUses();
1044  op->erase();
1045 }
1046 
1048  // Reset any operations that were updated in place.
1049  for (auto &state : rootUpdates)
1050  state.resetOperation();
1051 
1052  undoBlockActions();
1053 
1054  // Remove any newly created ops.
1055  for (UnresolvedMaterialization &materialization : unresolvedMaterializations)
1056  detachNestedAndErase(materialization.getOp());
1057  for (auto *op : llvm::reverse(createdOps))
1059 }
1060 
1062  // Apply all of the rewrites replacements requested during conversion.
1063  for (auto &repl : replacements) {
1064  for (OpResult result : repl.first->getResults())
1065  if (Value newValue = mapping.lookupOrNull(result, result.getType()))
1066  result.replaceAllUsesWith(newValue);
1067 
1068  // If this operation defines any regions, drop any pending argument
1069  // rewrites.
1070  if (repl.first->getNumRegions())
1071  argConverter.notifyOpRemoved(repl.first);
1072  }
1073 
1074  // Apply all of the requested argument replacements.
1075  for (BlockArgument arg : argReplacements) {
1076  Value repl = mapping.lookupOrNull(arg, arg.getType());
1077  if (!repl)
1078  continue;
1079 
1080  if (repl.isa<BlockArgument>()) {
1081  arg.replaceAllUsesWith(repl);
1082  continue;
1083  }
1084 
1085  // If the replacement value is an operation, we check to make sure that we
1086  // don't replace uses that are within the parent operation of the
1087  // replacement value.
1088  Operation *replOp = repl.cast<OpResult>().getOwner();
1089  Block *replBlock = replOp->getBlock();
1090  arg.replaceUsesWithIf(repl, [&](OpOperand &operand) {
1091  Operation *user = operand.getOwner();
1092  return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
1093  });
1094  }
1095 
1096  // Drop all of the unresolved materialization operations created during
1097  // conversion.
1098  for (auto &mat : unresolvedMaterializations) {
1099  mat.getOp()->dropAllUses();
1100  mat.getOp()->erase();
1101  }
1102 
1103  // In a second pass, erase all of the replaced operations in reverse. This
1104  // allows processing nested operations before their parent region is
1105  // destroyed. Because we process in reverse order, producers may be deleted
1106  // before their users (a pattern deleting a producer and then the consumer)
1107  // so we first drop all uses explicitly.
1108  for (auto &repl : llvm::reverse(replacements)) {
1109  repl.first->dropAllUses();
1110  repl.first->erase();
1111  }
1112 
1113  argConverter.applyRewrites(mapping);
1114 
1115  // Now that the ops have been erased, also erase dangling blocks.
1117 }
1118 
1119 //===----------------------------------------------------------------------===//
1120 // State Management
1121 
1123  return RewriterState(createdOps.size(), unresolvedMaterializations.size(),
1124  replacements.size(), argReplacements.size(),
1125  blockActions.size(), ignoredOps.size(),
1126  rootUpdates.size());
1127 }
1128 
1130  // Reset any operations that were updated in place.
1131  for (unsigned i = state.numRootUpdates, e = rootUpdates.size(); i != e; ++i)
1132  rootUpdates[i].resetOperation();
1133  rootUpdates.resize(state.numRootUpdates);
1134 
1135  // Reset any replaced arguments.
1136  for (BlockArgument replacedArg :
1137  llvm::drop_begin(argReplacements, state.numArgReplacements))
1138  mapping.erase(replacedArg);
1139  argReplacements.resize(state.numArgReplacements);
1140 
1141  // Undo any block actions.
1142  undoBlockActions(state.numBlockActions);
1143 
1144  // Reset any replaced operations and undo any saved mappings.
1145  for (auto &repl : llvm::drop_begin(replacements, state.numReplacements))
1146  for (auto result : repl.first->getResults())
1147  mapping.erase(result);
1148  while (replacements.size() != state.numReplacements)
1149  replacements.pop_back();
1150 
1151  // Pop all of the newly inserted materializations.
1152  while (unresolvedMaterializations.size() !=
1153  state.numUnresolvedMaterializations) {
1154  UnresolvedMaterialization mat = unresolvedMaterializations.pop_back_val();
1155  UnrealizedConversionCastOp op = mat.getOp();
1156 
1157  // If this was a target materialization, drop the mapping that was inserted.
1158  if (mat.getKind() == UnresolvedMaterialization::Target) {
1159  for (Value input : op->getOperands())
1160  mapping.erase(input);
1161  }
1163  }
1164 
1165  // Pop all of the newly created operations.
1166  while (createdOps.size() != state.numCreatedOps) {
1168  createdOps.pop_back();
1169  }
1170 
1171  // Pop all of the recorded ignored operations that are no longer valid.
1172  while (ignoredOps.size() != state.numIgnoredOperations)
1173  ignoredOps.pop_back();
1174 
1175  // Reset operations with changed results.
1176  while (!operationsWithChangedResults.empty() &&
1177  operationsWithChangedResults.back() >= state.numReplacements)
1178  operationsWithChangedResults.pop_back();
1179 }
1180 
1182  for (auto &action : blockActions)
1183  if (action.kind == BlockActionKind::Erase)
1184  delete action.block;
1185 }
1186 
1188  unsigned numActionsToKeep) {
1189  for (auto &action :
1190  llvm::reverse(llvm::drop_begin(blockActions, numActionsToKeep))) {
1191  switch (action.kind) {
1192  // Delete the created block.
1193  case BlockActionKind::Create: {
1194  // Unlink all of the operations within this block, they will be deleted
1195  // separately.
1196  auto &blockOps = action.block->getOperations();
1197  while (!blockOps.empty())
1198  blockOps.remove(blockOps.begin());
1199  action.block->dropAllDefinedValueUses();
1200  action.block->erase();
1201  break;
1202  }
1203  // Put the block (owned by action) back into its original position.
1204  case BlockActionKind::Erase: {
1205  auto &blockList = action.originalPosition.region->getBlocks();
1206  Block *insertAfterBlock = action.originalPosition.insertAfterBlock;
1207  blockList.insert((insertAfterBlock
1208  ? std::next(Region::iterator(insertAfterBlock))
1209  : blockList.begin()),
1210  action.block);
1211  break;
1212  }
1213  // Split the block at the position which was originally the end of the
1214  // destination block (owned by action), and put the instructions back into
1215  // the block used before the merge.
1216  case BlockActionKind::Merge: {
1217  Block *sourceBlock = action.mergeInfo.sourceBlock;
1218  Block::iterator splitPoint =
1219  (action.mergeInfo.destBlockLastInst
1220  ? ++Block::iterator(action.mergeInfo.destBlockLastInst)
1221  : action.block->begin());
1222  sourceBlock->getOperations().splice(sourceBlock->begin(),
1223  action.block->getOperations(),
1224  splitPoint, action.block->end());
1225  break;
1226  }
1227  // Move the block back to its original position.
1228  case BlockActionKind::Move: {
1229  Region *originalRegion = action.originalPosition.region;
1230  Block *insertAfterBlock = action.originalPosition.insertAfterBlock;
1231  originalRegion->getBlocks().splice(
1232  (insertAfterBlock ? std::next(Region::iterator(insertAfterBlock))
1233  : originalRegion->end()),
1234  action.block->getParent()->getBlocks(), action.block);
1235  break;
1236  }
1237  // Merge back the block that was split out.
1238  case BlockActionKind::Split: {
1239  action.originalBlock->getOperations().splice(
1240  action.originalBlock->end(), action.block->getOperations());
1241  action.block->dropAllDefinedValueUses();
1242  action.block->erase();
1243  break;
1244  }
1245  // Undo the type conversion.
1246  case BlockActionKind::TypeConversion: {
1247  argConverter.discardRewrites(action.block);
1248  break;
1249  }
1250  }
1251  }
1252  blockActions.resize(numActionsToKeep);
1253 }
1254 
1256  StringRef valueDiagTag, Optional<Location> inputLoc,
1257  PatternRewriter &rewriter, ValueRange values,
1258  SmallVectorImpl<Value> &remapped) {
1259  remapped.reserve(llvm::size(values));
1260 
1261  SmallVector<Type, 1> legalTypes;
1262  for (const auto &it : llvm::enumerate(values)) {
1263  Value operand = it.value();
1264  Type origType = operand.getType();
1265 
1266  // If a converter was provided, get the desired legal types for this
1267  // operand.
1268  Type desiredType;
1269  if (currentTypeConverter) {
1270  // If there is no legal conversion, fail to match this pattern.
1271  legalTypes.clear();
1272  if (failed(currentTypeConverter->convertType(origType, legalTypes))) {
1273  Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
1274  return notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
1275  diag << "unable to convert type for " << valueDiagTag << " #"
1276  << it.index() << ", type was " << origType;
1277  });
1278  }
1279  // TODO: There currently isn't any mechanism to do 1->N type conversion
1280  // via the PatternRewriter replacement API, so for now we just ignore it.
1281  if (legalTypes.size() == 1)
1282  desiredType = legalTypes.front();
1283  } else {
1284  // TODO: What we should do here is just set `desiredType` to `origType`
1285  // and then handle the necessary type conversions after the conversion
1286  // process has finished. Unfortunately a lot of patterns currently rely on
1287  // receiving the new operands even if the types change, so we keep the
1288  // original behavior here for now until all of the patterns relying on
1289  // this get updated.
1290  }
1291  Value newOperand = mapping.lookupOrDefault(operand, desiredType);
1292 
1293  // Handle the case where the conversion was 1->1 and the new operand type
1294  // isn't legal.
1295  Type newOperandType = newOperand.getType();
1296  if (currentTypeConverter && desiredType && newOperandType != desiredType) {
1297  Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
1299  operandLoc, newOperand, desiredType, currentTypeConverter,
1301  mapping.map(mapping.lookupOrDefault(newOperand), castValue);
1302  newOperand = castValue;
1303  }
1304  remapped.push_back(newOperand);
1305  }
1306  return success();
1307 }
1308 
1310  // Check to see if this operation was replaced or its parent ignored.
1311  return replacements.count(op) || ignoredOps.count(op->getParentOp());
1312 }
1313 
1315  // Walk this operation and collect nested operations that define non-empty
1316  // regions. We mark such operations as 'ignored' so that we know we don't have
1317  // to convert them, or their nested ops.
1318  if (op->getNumRegions() == 0)
1319  return;
1320  op->walk([&](Operation *op) {
1321  if (llvm::any_of(op->getRegions(),
1322  [](Region &region) { return !region.empty(); }))
1323  ignoredOps.insert(op);
1324  });
1325 }
1326 
1327 //===----------------------------------------------------------------------===//
1328 // Type Conversion
1329 
1331  Block *block, TypeConverter *converter,
1332  TypeConverter::SignatureConversion *conversion) {
1333  FailureOr<Block *> result =
1334  conversion ? argConverter.applySignatureConversion(
1335  block, converter, *conversion, mapping, argReplacements)
1336  : argConverter.convertSignature(block, converter, mapping,
1337  argReplacements);
1338  if (failed(result))
1339  return failure();
1340  if (Block *newBlock = *result) {
1341  if (newBlock != block)
1342  blockActions.push_back(BlockAction::getTypeConversion(newBlock));
1343  }
1344  return result;
1345 }
1346 
1348  Region *region, TypeConverter::SignatureConversion &conversion,
1349  TypeConverter *converter) {
1350  if (!region->empty())
1351  return *convertBlockSignature(&region->front(), converter, &conversion);
1352  return nullptr;
1353 }
1354 
1356  Region *region, TypeConverter &converter,
1357  TypeConverter::SignatureConversion *entryConversion) {
1358  argConverter.setConverter(region, &converter);
1359  if (region->empty())
1360  return nullptr;
1361 
1362  if (failed(convertNonEntryRegionTypes(region, converter)))
1363  return failure();
1364 
1365  FailureOr<Block *> newEntry =
1366  convertBlockSignature(&region->front(), &converter, entryConversion);
1367  return newEntry;
1368 }
1369 
1371  Region *region, TypeConverter &converter,
1373  argConverter.setConverter(region, &converter);
1374  if (region->empty())
1375  return success();
1376 
1377  // Convert the arguments of each block within the region.
1378  int blockIdx = 0;
1379  assert((blockConversions.empty() ||
1380  blockConversions.size() == region->getBlocks().size() - 1) &&
1381  "expected either to provide no SignatureConversions at all or to "
1382  "provide a SignatureConversion for each non-entry block");
1383 
1384  for (Block &block :
1385  llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
1386  TypeConverter::SignatureConversion *blockConversion =
1387  blockConversions.empty()
1388  ? nullptr
1389  : const_cast<TypeConverter::SignatureConversion *>(
1390  &blockConversions[blockIdx++]);
1391 
1392  if (failed(convertBlockSignature(&block, &converter, blockConversion)))
1393  return failure();
1394  }
1395  return success();
1396 }
1397 
1398 //===----------------------------------------------------------------------===//
1399 // Rewriter Notification Hooks
1400 
1402  ValueRange newValues) {
1403  assert(newValues.size() == op->getNumResults());
1404  assert(!replacements.count(op) && "operation was already replaced");
1405 
1406  // Track if any of the results changed, e.g. erased and replaced with null.
1407  bool resultChanged = false;
1408 
1409  // Create mappings for each of the new result values.
1410  for (auto [newValue, result] : llvm::zip(newValues, op->getResults())) {
1411  if (!newValue) {
1412  resultChanged = true;
1413  continue;
1414  }
1415  // Remap, and check for any result type changes.
1416  mapping.map(result, newValue);
1417  resultChanged |= (newValue.getType() != result.getType());
1418  }
1419  if (resultChanged)
1420  operationsWithChangedResults.push_back(replacements.size());
1421 
1422  // Record the requested operation replacement.
1423  replacements.insert(std::make_pair(op, OpReplacement(currentTypeConverter)));
1424 
1425  // Mark this operation as recursively ignored so that we don't need to
1426  // convert any nested operations.
1428 }
1429 
1431  Region *region = block->getParent();
1432  Block *origPrevBlock = block->getPrevNode();
1433  blockActions.push_back(BlockAction::getErase(block, {region, origPrevBlock}));
1434 }
1435 
1437  blockActions.push_back(BlockAction::getCreate(block));
1438 }
1439 
1441  Block *continuation) {
1442  blockActions.push_back(BlockAction::getSplit(continuation, block));
1443 }
1444 
1446  Block *srcBlock) {
1447  blockActions.push_back(BlockAction::getMerge(block, srcBlock));
1448 }
1449 
1451  Region &region, Region &parent, Region::iterator before) {
1452  if (region.empty())
1453  return;
1454  Block *laterBlock = &region.back();
1455  for (auto &earlierBlock : llvm::drop_begin(llvm::reverse(region), 1)) {
1456  blockActions.push_back(
1457  BlockAction::getMove(laterBlock, {&region, &earlierBlock}));
1458  laterBlock = &earlierBlock;
1459  }
1460  blockActions.push_back(BlockAction::getMove(laterBlock, {&region, nullptr}));
1461 }
1462 
1464  iterator_range<Region::iterator> &blocks, Location origRegionLoc) {
1465  for (Block &block : blocks)
1466  blockActions.push_back(BlockAction::getCreate(&block));
1467 
1468  // Compute the conversion set for the inlined region.
1469  auto result = computeConversionSet(blocks, origRegionLoc, createdOps);
1470 
1471  // This original region has already had its conversion set computed, so there
1472  // shouldn't be any new failures.
1473  (void)result;
1474  assert(succeeded(result) && "expected region to have no unreachable blocks");
1475 }
1476 
1478  Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
1479  LLVM_DEBUG({
1481  reasonCallback(diag);
1482  logger.startLine() << "** Failure : " << diag.str() << "\n";
1483  if (notifyCallback)
1485  });
1486  return failure();
1487 }
1488 
1489 //===----------------------------------------------------------------------===//
1490 // ConversionPatternRewriter
1491 //===----------------------------------------------------------------------===//
1492 
1494  : PatternRewriter(ctx),
1495  impl(new detail::ConversionPatternRewriterImpl(*this)) {}
1497 
1499  Operation *op, ValueRange newValues, bool *allUsesReplaced,
1500  llvm::unique_function<bool(OpOperand &) const> functor) {
1501  // TODO: To support this we will need to rework a bit of how replacements are
1502  // tracked, given that this isn't guaranteed to replace all of the uses of an
1503  // operation. The main change is that now an operation can be replaced
1504  // multiple times, in parts. The current "set" based tracking is mainly useful
1505  // for tracking if a replaced operation should be ignored, i.e. if all of the
1506  // uses will be replaced.
1507  llvm_unreachable(
1508  "replaceOpWithIf is currently not supported by DialectConversion");
1509 }
1510 
1512  LLVM_DEBUG({
1513  impl->logger.startLine()
1514  << "** Replace : '" << op->getName() << "'(" << op << ")\n";
1515  });
1516  impl->notifyOpReplaced(op, newValues);
1517 }
1518 
1520  LLVM_DEBUG({
1521  impl->logger.startLine()
1522  << "** Erase : '" << op->getName() << "'(" << op << ")\n";
1523  });
1524  SmallVector<Value, 1> nullRepls(op->getNumResults(), nullptr);
1525  impl->notifyOpReplaced(op, nullRepls);
1526 }
1527 
1529  impl->notifyBlockIsBeingErased(block);
1530 
1531  // Mark all ops for erasure.
1532  for (Operation &op : *block)
1533  eraseOp(&op);
1534 
1535  // Unlink the block from its parent region. The block is kept in the block
1536  // action and will be actually destroyed when rewrites are applied. This
1537  // allows us to keep the operations in the block live and undo the removal by
1538  // re-inserting the block.
1539  block->getParent()->getBlocks().remove(block);
1540 }
1541 
1543  Region *region, TypeConverter::SignatureConversion &conversion,
1544  TypeConverter *converter) {
1545  return impl->applySignatureConversion(region, conversion, converter);
1546 }
1547 
1549  Region *region, TypeConverter &converter,
1550  TypeConverter::SignatureConversion *entryConversion) {
1551  return impl->convertRegionTypes(region, converter, entryConversion);
1552 }
1553 
1555  Region *region, TypeConverter &converter,
1557  return impl->convertNonEntryRegionTypes(region, converter, blockConversions);
1558 }
1559 
1561  Value to) {
1562  LLVM_DEBUG({
1563  Operation *parentOp = from.getOwner()->getParentOp();
1564  impl->logger.startLine() << "** Replace Argument : '" << from
1565  << "'(in region of '" << parentOp->getName()
1566  << "'(" << from.getOwner()->getParentOp() << ")\n";
1567  });
1568  impl->argReplacements.push_back(from);
1569  impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
1570 }
1571 
1573  SmallVector<Value> remappedValues;
1574  if (failed(impl->remapValues("value", /*inputLoc=*/llvm::None, *this, key,
1575  remappedValues)))
1576  return nullptr;
1577  return remappedValues.front();
1578 }
1579 
1582  SmallVectorImpl<Value> &results) {
1583  if (keys.empty())
1584  return success();
1585  return impl->remapValues("value", /*inputLoc=*/llvm::None, *this, keys,
1586  results);
1587 }
1588 
1590  impl->notifyCreatedBlock(block);
1591 }
1592 
1594  Block::iterator before) {
1595  auto *continuation = PatternRewriter::splitBlock(block, before);
1596  impl->notifySplitBlock(block, continuation);
1597  return continuation;
1598 }
1599 
1601  ValueRange argValues) {
1602  impl->notifyBlocksBeingMerged(dest, source);
1603  assert(llvm::all_of(source->getPredecessors(),
1604  [dest](Block *succ) { return succ == dest; }) &&
1605  "expected 'source' to have no predecessors or only 'dest'");
1606  assert(argValues.size() == source->getNumArguments() &&
1607  "incorrect # of argument replacement values");
1608  for (auto it : llvm::zip(source->getArguments(), argValues))
1609  replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it));
1610  dest->getOperations().splice(dest->end(), source->getOperations());
1611  eraseBlock(source);
1612 }
1613 
1615  Region &parent,
1616  Region::iterator before) {
1617  impl->notifyRegionIsBeingInlinedBefore(region, parent, before);
1618  PatternRewriter::inlineRegionBefore(region, parent, before);
1619 }
1620 
1622  Region &region, Region &parent, Region::iterator before,
1623  BlockAndValueMapping &mapping) {
1624  if (region.empty())
1625  return;
1626  PatternRewriter::cloneRegionBefore(region, parent, before, mapping);
1627 
1628  // Collect the range of the cloned blocks.
1629  auto clonedBeginIt = mapping.lookup(&region.front())->getIterator();
1630  auto clonedBlocks = llvm::make_range(clonedBeginIt, before);
1631  impl->notifyRegionWasClonedBefore(clonedBlocks, region.getLoc());
1632 }
1633 
1635  LLVM_DEBUG({
1636  impl->logger.startLine()
1637  << "** Insert : '" << op->getName() << "'(" << op << ")\n";
1638  });
1639  impl->createdOps.push_back(op);
1640 }
1641 
1643 #ifndef NDEBUG
1644  impl->pendingRootUpdates.insert(op);
1645 #endif
1646  impl->rootUpdates.emplace_back(op);
1647 }
1648 
1650  // There is nothing to do here, we only need to track the operation at the
1651  // start of the update.
1652 #ifndef NDEBUG
1653  assert(impl->pendingRootUpdates.erase(op) &&
1654  "operation did not have a pending in-place update");
1655 #endif
1656 }
1657 
1659 #ifndef NDEBUG
1660  assert(impl->pendingRootUpdates.erase(op) &&
1661  "operation did not have a pending in-place update");
1662 #endif
1663  // Erase the last update for this operation.
1664  auto stateHasOp = [op](const auto &it) { return it.getOperation() == op; };
1665  auto &rootUpdates = impl->rootUpdates;
1666  auto it = llvm::find_if(llvm::reverse(rootUpdates), stateHasOp);
1667  assert(it != rootUpdates.rend() && "no root update started on op");
1668  (*it).resetOperation();
1669  int updateIdx = std::prev(rootUpdates.rend()) - it;
1670  rootUpdates.erase(rootUpdates.begin() + updateIdx);
1671 }
1672 
1674  Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
1675  return impl->notifyMatchFailure(loc, reasonCallback);
1676 }
1677 
1679  return *impl;
1680 }
1681 
1682 //===----------------------------------------------------------------------===//
1683 // ConversionPattern
1684 //===----------------------------------------------------------------------===//
1685 
1688  PatternRewriter &rewriter) const {
1689  auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter);
1690  auto &rewriterImpl = dialectRewriter.getImpl();
1691 
1692  // Track the current conversion pattern type converter in the rewriter.
1693  llvm::SaveAndRestore currentConverterGuard(rewriterImpl.currentTypeConverter,
1694  getTypeConverter());
1695 
1696  // Remap the operands of the operation.
1697  SmallVector<Value, 4> operands;
1698  if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter,
1699  op->getOperands(), operands))) {
1700  return failure();
1701  }
1702  return matchAndRewrite(op, operands, dialectRewriter);
1703 }
1704 
1705 //===----------------------------------------------------------------------===//
1706 // OperationLegalizer
1707 //===----------------------------------------------------------------------===//
1708 
1709 namespace {
1710 /// A set of rewrite patterns that can be used to legalize a given operation.
1711 using LegalizationPatterns = SmallVector<const Pattern *, 1>;
1712 
1713 /// This class defines a recursive operation legalizer.
1714 class OperationLegalizer {
1715 public:
1716  using LegalizationAction = ConversionTarget::LegalizationAction;
1717 
1718  OperationLegalizer(ConversionTarget &targetInfo,
1719  const FrozenRewritePatternSet &patterns);
1720 
1721  /// Returns true if the given operation is known to be illegal on the target.
1722  bool isIllegal(Operation *op) const;
1723 
1724  /// Attempt to legalize the given operation. Returns success if the operation
1725  /// was legalized, failure otherwise.
1726  LogicalResult legalize(Operation *op, ConversionPatternRewriter &rewriter);
1727 
1728  /// Returns the conversion target in use by the legalizer.
1729  ConversionTarget &getTarget() { return target; }
1730 
1731 private:
1732  /// Attempt to legalize the given operation by folding it.
1733  LogicalResult legalizeWithFold(Operation *op,
1734  ConversionPatternRewriter &rewriter);
1735 
1736  /// Attempt to legalize the given operation by applying a pattern. Returns
1737  /// success if the operation was legalized, failure otherwise.
1738  LogicalResult legalizeWithPattern(Operation *op,
1739  ConversionPatternRewriter &rewriter);
1740 
1741  /// Return true if the given pattern may be applied to the given operation,
1742  /// false otherwise.
1743  bool canApplyPattern(Operation *op, const Pattern &pattern,
1744  ConversionPatternRewriter &rewriter);
1745 
1746  /// Legalize the resultant IR after successfully applying the given pattern.
1747  LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
1748  ConversionPatternRewriter &rewriter,
1749  RewriterState &curState);
1750 
1751  /// Legalizes the actions registered during the execution of a pattern.
1752  LogicalResult legalizePatternBlockActions(Operation *op,
1753  ConversionPatternRewriter &rewriter,
1755  RewriterState &state,
1756  RewriterState &newState);
1757  LogicalResult legalizePatternCreatedOperations(
1759  RewriterState &state, RewriterState &newState);
1760  LogicalResult legalizePatternRootUpdates(ConversionPatternRewriter &rewriter,
1762  RewriterState &state,
1763  RewriterState &newState);
1764 
1765  //===--------------------------------------------------------------------===//
1766  // Cost Model
1767  //===--------------------------------------------------------------------===//
1768 
1769  /// Build an optimistic legalization graph given the provided patterns. This
1770  /// function populates 'anyOpLegalizerPatterns' and 'legalizerPatterns' with
1771  /// patterns for operations that are not directly legal, but may be
1772  /// transitively legal for the current target given the provided patterns.
1773  void buildLegalizationGraph(
1774  LegalizationPatterns &anyOpLegalizerPatterns,
1776 
1777  /// Compute the benefit of each node within the computed legalization graph.
1778  /// This orders the patterns within 'legalizerPatterns' based upon two
1779  /// criteria:
1780  /// 1) Prefer patterns that have the lowest legalization depth, i.e.
1781  /// represent the more direct mapping to the target.
1782  /// 2) When comparing patterns with the same legalization depth, prefer the
1783  /// pattern with the highest PatternBenefit. This allows for users to
1784  /// prefer specific legalizations over others.
1785  void computeLegalizationGraphBenefit(
1786  LegalizationPatterns &anyOpLegalizerPatterns,
1788 
1789  /// Compute the legalization depth when legalizing an operation of the given
1790  /// type.
1791  unsigned computeOpLegalizationDepth(
1792  OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
1794 
1795  /// Apply the conversion cost model to the given set of patterns, and return
1796  /// the smallest legalization depth of any of the patterns. See
1797  /// `computeLegalizationGraphBenefit` for the breakdown of the cost model.
1798  unsigned applyCostModelToPatterns(
1799  LegalizationPatterns &patterns,
1800  DenseMap<OperationName, unsigned> &minOpPatternDepth,
1802 
1803  /// The current set of patterns that have been applied.
1804  SmallPtrSet<const Pattern *, 8> appliedPatterns;
1805 
1806  /// The legalization information provided by the target.
1807  ConversionTarget &target;
1808 
1809  /// The pattern applicator to use for conversions.
1810  PatternApplicator applicator;
1811 };
1812 } // namespace
1813 
1814 OperationLegalizer::OperationLegalizer(ConversionTarget &targetInfo,
1815  const FrozenRewritePatternSet &patterns)
1816  : target(targetInfo), applicator(patterns) {
1817  // The set of patterns that can be applied to illegal operations to transform
1818  // them into legal ones.
1820  LegalizationPatterns anyOpLegalizerPatterns;
1821 
1822  buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
1823  computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
1824 }
1825 
1826 bool OperationLegalizer::isIllegal(Operation *op) const {
1827  return target.isIllegal(op);
1828 }
1829 
1831 OperationLegalizer::legalize(Operation *op,
1832  ConversionPatternRewriter &rewriter) {
1833 #ifndef NDEBUG
1834  const char *logLineComment =
1835  "//===-------------------------------------------===//\n";
1836 
1837  auto &logger = rewriter.getImpl().logger;
1838 #endif
1839  LLVM_DEBUG({
1840  logger.getOStream() << "\n";
1841  logger.startLine() << logLineComment;
1842  logger.startLine() << "Legalizing operation : '" << op->getName() << "'("
1843  << op << ") {\n";
1844  logger.indent();
1845 
1846  // If the operation has no regions, just print it here.
1847  if (op->getNumRegions() == 0) {
1848  op->print(logger.startLine(), OpPrintingFlags().printGenericOpForm());
1849  logger.getOStream() << "\n\n";
1850  }
1851  });
1852 
1853  // Check if this operation is legal on the target.
1854  if (auto legalityInfo = target.isLegal(op)) {
1855  LLVM_DEBUG({
1856  logSuccess(
1857  logger, "operation marked legal by the target{0}",
1858  legalityInfo->isRecursivelyLegal
1859  ? "; NOTE: operation is recursively legal; skipping internals"
1860  : "");
1861  logger.startLine() << logLineComment;
1862  });
1863 
1864  // If this operation is recursively legal, mark its children as ignored so
1865  // that we don't consider them for legalization.
1866  if (legalityInfo->isRecursivelyLegal)
1867  rewriter.getImpl().markNestedOpsIgnored(op);
1868  return success();
1869  }
1870 
1871  // Check to see if the operation is ignored and doesn't need to be converted.
1872  if (rewriter.getImpl().isOpIgnored(op)) {
1873  LLVM_DEBUG({
1874  logSuccess(logger, "operation marked 'ignored' during conversion");
1875  logger.startLine() << logLineComment;
1876  });
1877  return success();
1878  }
1879 
1880  // If the operation isn't legal, try to fold it in-place.
1881  // TODO: Should we always try to do this, even if the op is
1882  // already legal?
1883  if (succeeded(legalizeWithFold(op, rewriter))) {
1884  LLVM_DEBUG({
1885  logSuccess(logger, "operation was folded");
1886  logger.startLine() << logLineComment;
1887  });
1888  return success();
1889  }
1890 
1891  // Otherwise, we need to apply a legalization pattern to this operation.
1892  if (succeeded(legalizeWithPattern(op, rewriter))) {
1893  LLVM_DEBUG({
1894  logSuccess(logger, "");
1895  logger.startLine() << logLineComment;
1896  });
1897  return success();
1898  }
1899 
1900  LLVM_DEBUG({
1901  logFailure(logger, "no matched legalization pattern");
1902  logger.startLine() << logLineComment;
1903  });
1904  return failure();
1905 }
1906 
1908 OperationLegalizer::legalizeWithFold(Operation *op,
1909  ConversionPatternRewriter &rewriter) {
1910  auto &rewriterImpl = rewriter.getImpl();
1911  RewriterState curState = rewriterImpl.getCurrentState();
1912 
1913  LLVM_DEBUG({
1914  rewriterImpl.logger.startLine() << "* Fold {\n";
1915  rewriterImpl.logger.indent();
1916  });
1917 
1918  // Try to fold the operation.
1919  SmallVector<Value, 2> replacementValues;
1920  rewriter.setInsertionPoint(op);
1921  if (failed(rewriter.tryFold(op, replacementValues))) {
1922  LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
1923  return failure();
1924  }
1925 
1926  // Insert a replacement for 'op' with the folded replacement values.
1927  rewriter.replaceOp(op, replacementValues);
1928 
1929  // Recursively legalize any new constant operations.
1930  for (unsigned i = curState.numCreatedOps, e = rewriterImpl.createdOps.size();
1931  i != e; ++i) {
1932  Operation *cstOp = rewriterImpl.createdOps[i];
1933  if (failed(legalize(cstOp, rewriter))) {
1934  LLVM_DEBUG(logFailure(rewriterImpl.logger,
1935  "failed to legalize generated constant '{0}'",
1936  cstOp->getName()));
1937  rewriterImpl.resetState(curState);
1938  return failure();
1939  }
1940  }
1941 
1942  LLVM_DEBUG(logSuccess(rewriterImpl.logger, ""));
1943  return success();
1944 }
1945 
1947 OperationLegalizer::legalizeWithPattern(Operation *op,
1948  ConversionPatternRewriter &rewriter) {
1949  auto &rewriterImpl = rewriter.getImpl();
1950 
1951  // Functor that returns if the given pattern may be applied.
1952  auto canApply = [&](const Pattern &pattern) {
1953  return canApplyPattern(op, pattern, rewriter);
1954  };
1955 
1956  // Functor that cleans up the rewriter state after a pattern failed to match.
1957  RewriterState curState = rewriterImpl.getCurrentState();
1958  auto onFailure = [&](const Pattern &pattern) {
1959  LLVM_DEBUG({
1960  logFailure(rewriterImpl.logger, "pattern failed to match");
1961  if (rewriterImpl.notifyCallback) {
1963  diag << "Failed to apply pattern \"" << pattern.getDebugName()
1964  << "\" on op:\n"
1965  << *op;
1966  rewriterImpl.notifyCallback(diag);
1967  }
1968  });
1969  rewriterImpl.resetState(curState);
1970  appliedPatterns.erase(&pattern);
1971  };
1972 
1973  // Functor that performs additional legalization when a pattern is
1974  // successfully applied.
1975  auto onSuccess = [&](const Pattern &pattern) {
1976  auto result = legalizePatternResult(op, pattern, rewriter, curState);
1977  appliedPatterns.erase(&pattern);
1978  if (failed(result))
1979  rewriterImpl.resetState(curState);
1980  return result;
1981  };
1982 
1983  // Try to match and rewrite a pattern on this operation.
1984  return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
1985  onSuccess);
1986 }
1987 
1988 bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
1989  ConversionPatternRewriter &rewriter) {
1990  LLVM_DEBUG({
1991  auto &os = rewriter.getImpl().logger;
1992  os.getOStream() << "\n";
1993  os.startLine() << "* Pattern : '" << op->getName() << " -> (";
1994  llvm::interleaveComma(pattern.getGeneratedOps(), os.getOStream());
1995  os.getOStream() << ")' {\n";
1996  os.indent();
1997  });
1998 
1999  // Ensure that we don't cycle by not allowing the same pattern to be
2000  // applied twice in the same recursion stack if it is not known to be safe.
2001  if (!pattern.hasBoundedRewriteRecursion() &&
2002  !appliedPatterns.insert(&pattern).second) {
2003  LLVM_DEBUG(
2004  logFailure(rewriter.getImpl().logger, "pattern was already applied"));
2005  return false;
2006  }
2007  return true;
2008 }
2009 
2011 OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
2012  ConversionPatternRewriter &rewriter,
2013  RewriterState &curState) {
2014  auto &impl = rewriter.getImpl();
2015 
2016 #ifndef NDEBUG
2017  assert(impl.pendingRootUpdates.empty() && "dangling root updates");
2018 #endif
2019 
2020  // Check that the root was either replaced or updated in place.
2021  auto replacedRoot = [&] {
2022  return llvm::any_of(
2023  llvm::drop_begin(impl.replacements, curState.numReplacements),
2024  [op](auto &it) { return it.first == op; });
2025  };
2026  auto updatedRootInPlace = [&] {
2027  return llvm::any_of(
2028  llvm::drop_begin(impl.rootUpdates, curState.numRootUpdates),
2029  [op](auto &state) { return state.getOperation() == op; });
2030  };
2031  (void)replacedRoot;
2032  (void)updatedRootInPlace;
2033  assert((replacedRoot() || updatedRootInPlace()) &&
2034  "expected pattern to replace the root operation");
2035 
2036  // Legalize each of the actions registered during application.
2037  RewriterState newState = impl.getCurrentState();
2038  if (failed(legalizePatternBlockActions(op, rewriter, impl, curState,
2039  newState)) ||
2040  failed(legalizePatternRootUpdates(rewriter, impl, curState, newState)) ||
2041  failed(legalizePatternCreatedOperations(rewriter, impl, curState,
2042  newState))) {
2043  return failure();
2044  }
2045 
2046  LLVM_DEBUG(logSuccess(impl.logger, "pattern applied successfully"));
2047  return success();
2048 }
2049 
2050 LogicalResult OperationLegalizer::legalizePatternBlockActions(
2051  Operation *op, ConversionPatternRewriter &rewriter,
2052  ConversionPatternRewriterImpl &impl, RewriterState &state,
2053  RewriterState &newState) {
2054  SmallPtrSet<Operation *, 16> operationsToIgnore;
2055 
2056  // If the pattern moved or created any blocks, make sure the types of block
2057  // arguments get legalized.
2058  for (int i = state.numBlockActions, e = newState.numBlockActions; i != e;
2059  ++i) {
2060  auto &action = impl.blockActions[i];
2061  if (action.kind == BlockActionKind::TypeConversion ||
2062  action.kind == BlockActionKind::Erase)
2063  continue;
2064  // Only check blocks outside of the current operation.
2065  Operation *parentOp = action.block->getParentOp();
2066  if (!parentOp || parentOp == op || action.block->getNumArguments() == 0)
2067  continue;
2068 
2069  // If the region of the block has a type converter, try to convert the block
2070  // directly.
2071  if (auto *converter =
2072  impl.argConverter.getConverter(action.block->getParent())) {
2073  if (failed(impl.convertBlockSignature(action.block, converter))) {
2074  LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
2075  "block"));
2076  return failure();
2077  }
2078  continue;
2079  }
2080 
2081  // Otherwise, check that this operation isn't one generated by this pattern.
2082  // This is because we will attempt to legalize the parent operation, and
2083  // blocks in regions created by this pattern will already be legalized later
2084  // on. If we haven't built the set yet, build it now.
2085  if (operationsToIgnore.empty()) {
2086  auto createdOps = ArrayRef<Operation *>(impl.createdOps)
2087  .drop_front(state.numCreatedOps);
2088  operationsToIgnore.insert(createdOps.begin(), createdOps.end());
2089  }
2090 
2091  // If this operation should be considered for re-legalization, try it.
2092  if (operationsToIgnore.insert(parentOp).second &&
2093  failed(legalize(parentOp, rewriter))) {
2094  LLVM_DEBUG(logFailure(
2095  impl.logger, "operation '{0}'({1}) became illegal after block action",
2096  parentOp->getName(), parentOp));
2097  return failure();
2098  }
2099  }
2100  return success();
2101 }
2102 
2103 LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
2105  RewriterState &state, RewriterState &newState) {
2106  for (int i = state.numCreatedOps, e = newState.numCreatedOps; i != e; ++i) {
2107  Operation *op = impl.createdOps[i];
2108  if (failed(legalize(op, rewriter))) {
2109  LLVM_DEBUG(logFailure(impl.logger,
2110  "failed to legalize generated operation '{0}'({1})",
2111  op->getName(), op));
2112  return failure();
2113  }
2114  }
2115  return success();
2116 }
2117 
2118 LogicalResult OperationLegalizer::legalizePatternRootUpdates(
2120  RewriterState &state, RewriterState &newState) {
2121  for (int i = state.numRootUpdates, e = newState.numRootUpdates; i != e; ++i) {
2122  Operation *op = impl.rootUpdates[i].getOperation();
2123  if (failed(legalize(op, rewriter))) {
2124  LLVM_DEBUG(logFailure(
2125  impl.logger, "failed to legalize operation updated in-place '{0}'",
2126  op->getName()));
2127  return failure();
2128  }
2129  }
2130  return success();
2131 }
2132 
2133 //===----------------------------------------------------------------------===//
2134 // Cost Model
2135 
2136 void OperationLegalizer::buildLegalizationGraph(
2137  LegalizationPatterns &anyOpLegalizerPatterns,
2138  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2139  // A mapping between an operation and a set of operations that can be used to
2140  // generate it.
2142  // A mapping between an operation and any currently invalid patterns it has.
2144  // A worklist of patterns to consider for legality.
2145  SetVector<const Pattern *> patternWorklist;
2146 
2147  // Build the mapping from operations to the parent ops that may generate them.
2148  applicator.walkAllPatterns([&](const Pattern &pattern) {
2149  Optional<OperationName> root = pattern.getRootKind();
2150 
2151  // If the pattern has no specific root, we can't analyze the relationship
2152  // between the root op and generated operations. Given that, add all such
2153  // patterns to the legalization set.
2154  if (!root) {
2155  anyOpLegalizerPatterns.push_back(&pattern);
2156  return;
2157  }
2158 
2159  // Skip operations that are always known to be legal.
2160  if (target.getOpAction(*root) == LegalizationAction::Legal)
2161  return;
2162 
2163  // Add this pattern to the invalid set for the root op and record this root
2164  // as a parent for any generated operations.
2165  invalidPatterns[*root].insert(&pattern);
2166  for (auto op : pattern.getGeneratedOps())
2167  parentOps[op].insert(*root);
2168 
2169  // Add this pattern to the worklist.
2170  patternWorklist.insert(&pattern);
2171  });
2172 
2173  // If there are any patterns that don't have a specific root kind, we can't
2174  // make direct assumptions about what operations will never be legalized.
2175  // Note: Technically we could, but it would require an analysis that may
2176  // recurse into itself. It would be better to perform this kind of filtering
2177  // at a higher level than here anyways.
2178  if (!anyOpLegalizerPatterns.empty()) {
2179  for (const Pattern *pattern : patternWorklist)
2180  legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
2181  return;
2182  }
2183 
2184  while (!patternWorklist.empty()) {
2185  auto *pattern = patternWorklist.pop_back_val();
2186 
2187  // Check to see if any of the generated operations are invalid.
2188  if (llvm::any_of(pattern->getGeneratedOps(), [&](OperationName op) {
2189  Optional<LegalizationAction> action = target.getOpAction(op);
2190  return !legalizerPatterns.count(op) &&
2191  (!action || action == LegalizationAction::Illegal);
2192  }))
2193  continue;
2194 
2195  // Otherwise, if all of the generated operation are valid, this op is now
2196  // legal so add all of the child patterns to the worklist.
2197  legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
2198  invalidPatterns[*pattern->getRootKind()].erase(pattern);
2199 
2200  // Add any invalid patterns of the parent operations to see if they have now
2201  // become legal.
2202  for (auto op : parentOps[*pattern->getRootKind()])
2203  patternWorklist.set_union(invalidPatterns[op]);
2204  }
2205 }
2206 
2207 void OperationLegalizer::computeLegalizationGraphBenefit(
2208  LegalizationPatterns &anyOpLegalizerPatterns,
2209  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2210  // The smallest pattern depth, when legalizing an operation.
2211  DenseMap<OperationName, unsigned> minOpPatternDepth;
2212 
2213  // For each operation that is transitively legal, compute a cost for it.
2214  for (auto &opIt : legalizerPatterns)
2215  if (!minOpPatternDepth.count(opIt.first))
2216  computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
2217  legalizerPatterns);
2218 
2219  // Apply the cost model to the patterns that can match any operation. Those
2220  // with a specific operation type are already resolved when computing the op
2221  // legalization depth.
2222  if (!anyOpLegalizerPatterns.empty())
2223  applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
2224  legalizerPatterns);
2225 
2226  // Apply a cost model to the pattern applicator. We order patterns first by
2227  // depth then benefit. `legalizerPatterns` contains per-op patterns by
2228  // decreasing benefit.
2229  applicator.applyCostModel([&](const Pattern &pattern) {
2230  ArrayRef<const Pattern *> orderedPatternList;
2231  if (Optional<OperationName> rootName = pattern.getRootKind())
2232  orderedPatternList = legalizerPatterns[*rootName];
2233  else
2234  orderedPatternList = anyOpLegalizerPatterns;
2235 
2236  // If the pattern is not found, then it was removed and cannot be matched.
2237  auto *it = llvm::find(orderedPatternList, &pattern);
2238  if (it == orderedPatternList.end())
2240 
2241  // Patterns found earlier in the list have higher benefit.
2242  return PatternBenefit(std::distance(it, orderedPatternList.end()));
2243  });
2244 }
2245 
2246 unsigned OperationLegalizer::computeOpLegalizationDepth(
2247  OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
2248  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2249  // Check for existing depth.
2250  auto depthIt = minOpPatternDepth.find(op);
2251  if (depthIt != minOpPatternDepth.end())
2252  return depthIt->second;
2253 
2254  // If a mapping for this operation does not exist, then this operation
2255  // is always legal. Return 0 as the depth for a directly legal operation.
2256  auto opPatternsIt = legalizerPatterns.find(op);
2257  if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
2258  return 0u;
2259 
2260  // Record this initial depth in case we encounter this op again when
2261  // recursively computing the depth.
2262  minOpPatternDepth.try_emplace(op, std::numeric_limits<unsigned>::max());
2263 
2264  // Apply the cost model to the operation patterns, and update the minimum
2265  // depth.
2266  unsigned minDepth = applyCostModelToPatterns(
2267  opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
2268  minOpPatternDepth[op] = minDepth;
2269  return minDepth;
2270 }
2271 
2272 unsigned OperationLegalizer::applyCostModelToPatterns(
2273  LegalizationPatterns &patterns,
2274  DenseMap<OperationName, unsigned> &minOpPatternDepth,
2275  DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2276  unsigned minDepth = std::numeric_limits<unsigned>::max();
2277 
2278  // Compute the depth for each pattern within the set.
2280  patternsByDepth.reserve(patterns.size());
2281  for (const Pattern *pattern : patterns) {
2282  unsigned depth = 1;
2283  for (auto generatedOp : pattern->getGeneratedOps()) {
2284  unsigned generatedOpDepth = computeOpLegalizationDepth(
2285  generatedOp, minOpPatternDepth, legalizerPatterns);
2286  depth = std::max(depth, generatedOpDepth + 1);
2287  }
2288  patternsByDepth.emplace_back(pattern, depth);
2289 
2290  // Update the minimum depth of the pattern list.
2291  minDepth = std::min(minDepth, depth);
2292  }
2293 
2294  // If the operation only has one legalization pattern, there is no need to
2295  // sort them.
2296  if (patternsByDepth.size() == 1)
2297  return minDepth;
2298 
2299  // Sort the patterns by those likely to be the most beneficial.
2300  llvm::array_pod_sort(patternsByDepth.begin(), patternsByDepth.end(),
2301  [](const std::pair<const Pattern *, unsigned> *lhs,
2302  const std::pair<const Pattern *, unsigned> *rhs) {
2303  // First sort by the smaller pattern legalization
2304  // depth.
2305  if (lhs->second != rhs->second)
2306  return llvm::array_pod_sort_comparator<unsigned>(
2307  &lhs->second, &rhs->second);
2308 
2309  // Then sort by the larger pattern benefit.
2310  auto lhsBenefit = lhs->first->getBenefit();
2311  auto rhsBenefit = rhs->first->getBenefit();
2312  return llvm::array_pod_sort_comparator<PatternBenefit>(
2313  &rhsBenefit, &lhsBenefit);
2314  });
2315 
2316  // Update the legalization pattern to use the new sorted list.
2317  patterns.clear();
2318  for (auto &patternIt : patternsByDepth)
2319  patterns.push_back(patternIt.first);
2320  return minDepth;
2321 }
2322 
2323 //===----------------------------------------------------------------------===//
2324 // OperationConverter
2325 //===----------------------------------------------------------------------===//
2326 namespace {
2327 enum OpConversionMode {
2328  /// In this mode, the conversion will ignore failed conversions to allow
2329  /// illegal operations to co-exist in the IR.
2330  Partial,
2331 
2332  /// In this mode, all operations must be legal for the given target for the
2333  /// conversion to succeed.
2334  Full,
2335 
2336  /// In this mode, operations are analyzed for legality. No actual rewrites are
2337  /// applied to the operations on success.
2338  Analysis,
2339 };
2340 
2341 // This class converts operations to a given conversion target via a set of
2342 // rewrite patterns. The conversion behaves differently depending on the
2343 // conversion mode.
2344 struct OperationConverter {
2345  explicit OperationConverter(ConversionTarget &target,
2346  const FrozenRewritePatternSet &patterns,
2347  OpConversionMode mode,
2348  DenseSet<Operation *> *trackedOps = nullptr)
2349  : opLegalizer(target, patterns), mode(mode), trackedOps(trackedOps) {}
2350 
2351  /// Converts the given operations to the conversion target.
2353  convertOperations(ArrayRef<Operation *> ops,
2354  function_ref<void(Diagnostic &)> notifyCallback = nullptr);
2355 
2356 private:
2357  /// Converts an operation with the given rewriter.
2358  LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op);
2359 
2360  /// This method is called after the conversion process to legalize any
2361  /// remaining artifacts and complete the conversion.
2362  LogicalResult finalize(ConversionPatternRewriter &rewriter);
2363 
2364  /// Legalize the types of converted block arguments.
2366  legalizeConvertedArgumentTypes(ConversionPatternRewriter &rewriter,
2367  ConversionPatternRewriterImpl &rewriterImpl);
2368 
2369  /// Legalize any unresolved type materializations.
2370  LogicalResult legalizeUnresolvedMaterializations(
2371  ConversionPatternRewriter &rewriter,
2372  ConversionPatternRewriterImpl &rewriterImpl,
2373  Optional<DenseMap<Value, SmallVector<Value>>> &inverseMapping);
2374 
2375  /// Legalize an operation result that was marked as "erased".
2377  legalizeErasedResult(Operation *op, OpResult result,
2378  ConversionPatternRewriterImpl &rewriterImpl);
2379 
2380  /// Legalize an operation result that was replaced with a value of a different
2381  /// type.
2382  LogicalResult legalizeChangedResultType(
2383  Operation *op, OpResult result, Value newValue,
2384  TypeConverter *replConverter, ConversionPatternRewriter &rewriter,
2385  ConversionPatternRewriterImpl &rewriterImpl,
2386  const DenseMap<Value, SmallVector<Value>> &inverseMapping);
2387 
2388  /// The legalizer to use when converting operations.
2389  OperationLegalizer opLegalizer;
2390 
2391  /// The conversion mode to use when legalizing operations.
2392  OpConversionMode mode;
2393 
2394  /// A set of pre-existing operations. When mode == OpConversionMode::Analysis,
2395  /// this is populated with ops found to be legalizable to the target.
2396  /// When mode == OpConversionMode::Partial, this is populated with ops found
2397  /// *not* to be legalizable to the target.
2398  DenseSet<Operation *> *trackedOps;
2399 };
2400 } // namespace
2401 
2402 LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
2403  Operation *op) {
2404  // Legalize the given operation.
2405  if (failed(opLegalizer.legalize(op, rewriter))) {
2406  // Handle the case of a failed conversion for each of the different modes.
2407  // Full conversions expect all operations to be converted.
2408  if (mode == OpConversionMode::Full)
2409  return op->emitError()
2410  << "failed to legalize operation '" << op->getName() << "'";
2411  // Partial conversions allow conversions to fail iff the operation was not
2412  // explicitly marked as illegal. If the user provided a nonlegalizableOps
2413  // set, non-legalizable ops are included.
2414  if (mode == OpConversionMode::Partial) {
2415  if (opLegalizer.isIllegal(op))
2416  return op->emitError()
2417  << "failed to legalize operation '" << op->getName()
2418  << "' that was explicitly marked illegal";
2419  if (trackedOps)
2420  trackedOps->insert(op);
2421  }
2422  } else if (mode == OpConversionMode::Analysis) {
2423  // Analysis conversions don't fail if any operations fail to legalize,
2424  // they are only interested in the operations that were successfully
2425  // legalized.
2426  trackedOps->insert(op);
2427  }
2428  return success();
2429 }
2430 
2431 LogicalResult OperationConverter::convertOperations(
2433  function_ref<void(Diagnostic &)> notifyCallback) {
2434  if (ops.empty())
2435  return success();
2436  ConversionTarget &target = opLegalizer.getTarget();
2437 
2438  // Compute the set of operations and blocks to convert.
2439  SmallVector<Operation *> toConvert;
2440  for (auto *op : ops) {
2441  toConvert.emplace_back(op);
2442  for (auto &region : op->getRegions())
2443  if (failed(computeConversionSet(region.getBlocks(), region.getLoc(),
2444  toConvert, &target)))
2445  return failure();
2446  }
2447 
2448  // Convert each operation and discard rewrites on failure.
2449  ConversionPatternRewriter rewriter(ops.front()->getContext());
2450  ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2451  rewriterImpl.notifyCallback = notifyCallback;
2452 
2453  for (auto *op : toConvert)
2454  if (failed(convert(rewriter, op)))
2455  return rewriterImpl.discardRewrites(), failure();
2456 
2457  // Now that all of the operations have been converted, finalize the conversion
2458  // process to ensure any lingering conversion artifacts are cleaned up and
2459  // legalized.
2460  if (failed(finalize(rewriter)))
2461  return rewriterImpl.discardRewrites(), failure();
2462 
2463  // After a successful conversion, apply rewrites if this is not an analysis
2464  // conversion.
2465  if (mode == OpConversionMode::Analysis) {
2466  rewriterImpl.discardRewrites();
2467  } else {
2468  rewriterImpl.applyRewrites();
2469 
2470  // It is possible for a later pattern to erase an op that was originally
2471  // identified as illegal and added to the trackedOps, remove it now after
2472  // replacements have been computed.
2473  if (trackedOps)
2474  for (auto &repl : rewriterImpl.replacements)
2475  trackedOps->erase(repl.first);
2476  }
2477  return success();
2478 }
2479 
2481 OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
2483  ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2484  if (failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl,
2485  inverseMapping)) ||
2486  failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)))
2487  return failure();
2488 
2489  if (rewriterImpl.operationsWithChangedResults.empty())
2490  return success();
2491 
2492  // Process requested operation replacements.
2493  for (unsigned i = 0, e = rewriterImpl.operationsWithChangedResults.size();
2494  i != e; ++i) {
2495  unsigned replIdx = rewriterImpl.operationsWithChangedResults[i];
2496  auto &repl = *(rewriterImpl.replacements.begin() + replIdx);
2497  for (OpResult result : repl.first->getResults()) {
2498  Value newValue = rewriterImpl.mapping.lookupOrNull(result);
2499 
2500  // If the operation result was replaced with null, all of the uses of this
2501  // value should be replaced.
2502  if (!newValue) {
2503  if (failed(legalizeErasedResult(repl.first, result, rewriterImpl)))
2504  return failure();
2505  continue;
2506  }
2507 
2508  // Otherwise, check to see if the type of the result changed.
2509  if (result.getType() == newValue.getType())
2510  continue;
2511 
2512  // Compute the inverse mapping only if it is really needed.
2513  if (!inverseMapping)
2514  inverseMapping = rewriterImpl.mapping.getInverse();
2515 
2516  // Legalize this result.
2517  rewriter.setInsertionPoint(repl.first);
2518  if (failed(legalizeChangedResultType(repl.first, result, newValue,
2519  repl.second.converter, rewriter,
2520  rewriterImpl, *inverseMapping)))
2521  return failure();
2522 
2523  // Update the end iterator for this loop in the case it was updated
2524  // when legalizing generated conversion operations.
2525  e = rewriterImpl.operationsWithChangedResults.size();
2526  }
2527  }
2528  return success();
2529 }
2530 
2531 LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
2532  ConversionPatternRewriter &rewriter,
2533  ConversionPatternRewriterImpl &rewriterImpl) {
2534  // Functor used to check if all users of a value will be dead after
2535  // conversion.
2536  auto findLiveUser = [&](Value val) {
2537  auto liveUserIt = llvm::find_if_not(val.getUsers(), [&](Operation *user) {
2538  return rewriterImpl.isOpIgnored(user);
2539  });
2540  return liveUserIt == val.user_end() ? nullptr : *liveUserIt;
2541  };
2542  return rewriterImpl.argConverter.materializeLiveConversions(
2543  rewriterImpl.mapping, rewriter, findLiveUser);
2544 }
2545 
2546 /// Replace the results of a materialization operation with the given values.
2547 static void
2549  ResultRange matResults, ValueRange values,
2550  DenseMap<Value, SmallVector<Value>> &inverseMapping) {
2551  matResults.replaceAllUsesWith(values);
2552 
2553  // For each of the materialization results, update the inverse mappings to
2554  // point to the replacement values.
2555  for (auto [matResult, newValue] : llvm::zip(matResults, values)) {
2556  auto inverseMapIt = inverseMapping.find(matResult);
2557  if (inverseMapIt == inverseMapping.end())
2558  continue;
2559 
2560  // Update the reverse mapping, or remove the mapping if we couldn't update
2561  // it. Not being able to update signals that the mapping would have become
2562  // circular (i.e. %foo -> newValue -> %foo), which may occur as values are
2563  // propagated through temporary materializations. We simply drop the
2564  // mapping, and let the post-conversion replacement logic handle updating
2565  // uses.
2566  for (Value inverseMapVal : inverseMapIt->second)
2567  if (!rewriterImpl.mapping.tryMap(inverseMapVal, newValue))
2568  rewriterImpl.mapping.erase(inverseMapVal);
2569  }
2570 }
2571 
2572 /// Compute all of the unresolved materializations that will persist beyond the
2573 /// conversion process, and require inserting a proper user materialization for.
2576  ConversionPatternRewriter &rewriter,
2577  ConversionPatternRewriterImpl &rewriterImpl,
2578  DenseMap<Value, SmallVector<Value>> &inverseMapping,
2579  SetVector<UnresolvedMaterialization *> &necessaryMaterializations) {
2580  auto isLive = [&](Value value) {
2581  auto findFn = [&](Operation *user) {
2582  auto matIt = materializationOps.find(user);
2583  if (matIt != materializationOps.end())
2584  return !necessaryMaterializations.count(matIt->second);
2585  return rewriterImpl.isOpIgnored(user);
2586  };
2587  // This value may be replacing another value that has a live user.
2588  for (Value inv : inverseMapping.lookup(value))
2589  if (llvm::find_if_not(inv.getUsers(), findFn) != inv.user_end())
2590  return true;
2591  // Or have live users itself.
2592  return llvm::find_if_not(value.getUsers(), findFn) != value.user_end();
2593  };
2594 
2595  llvm::unique_function<Value(Value, Value, Type)> lookupRemappedValue =
2596  [&](Value invalidRoot, Value value, Type type) {
2597  // Check to see if the input operation was remapped to a variant of the
2598  // output.
2599  Value remappedValue = rewriterImpl.mapping.lookupOrDefault(value, type);
2600  if (remappedValue.getType() == type && remappedValue != invalidRoot)
2601  return remappedValue;
2602 
2603  // Check to see if the input is a materialization operation that
2604  // provides an inverse conversion. We just check blindly for
2605  // UnrealizedConversionCastOp here, but it has no effect on correctness.
2606  auto inputCastOp = value.getDefiningOp<UnrealizedConversionCastOp>();
2607  if (inputCastOp && inputCastOp->getNumOperands() == 1)
2608  return lookupRemappedValue(invalidRoot, inputCastOp->getOperand(0),
2609  type);
2610 
2611  return Value();
2612  };
2613 
2615  for (auto &mat : rewriterImpl.unresolvedMaterializations) {
2616  materializationOps.try_emplace(mat.getOp(), &mat);
2617  worklist.insert(&mat);
2618  }
2619  while (!worklist.empty()) {
2620  UnresolvedMaterialization *mat = worklist.pop_back_val();
2621  UnrealizedConversionCastOp op = mat->getOp();
2622 
2623  // We currently only handle target materializations here.
2624  assert(op->getNumResults() == 1 && "unexpected materialization type");
2625  OpResult opResult = op->getOpResult(0);
2626  Type outputType = opResult.getType();
2627  Operation::operand_range inputOperands = op.getOperands();
2628 
2629  // Try to forward propagate operands for user conversion casts that result
2630  // in the input types of the current cast.
2631  for (Operation *user : llvm::make_early_inc_range(opResult.getUsers())) {
2632  auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
2633  if (!castOp)
2634  continue;
2635  if (castOp->getResultTypes() == inputOperands.getTypes()) {
2636  replaceMaterialization(rewriterImpl, opResult, inputOperands,
2637  inverseMapping);
2638  necessaryMaterializations.remove(materializationOps.lookup(user));
2639  }
2640  }
2641 
2642  // Try to avoid materializing a resolved materialization if possible.
2643  // Handle the case of a 1-1 materialization.
2644  if (inputOperands.size() == 1) {
2645  // Check to see if the input operation was remapped to a variant of the
2646  // output.
2647  Value remappedValue =
2648  lookupRemappedValue(opResult, inputOperands[0], outputType);
2649  if (remappedValue && remappedValue != opResult) {
2650  replaceMaterialization(rewriterImpl, opResult, remappedValue,
2651  inverseMapping);
2652  necessaryMaterializations.remove(mat);
2653  continue;
2654  }
2655  } else {
2656  // TODO: Avoid materializing other types of conversions here.
2657  }
2658 
2659  // Check to see if this is an argument materialization.
2660  auto isBlockArg = [](Value v) { return v.isa<BlockArgument>(); };
2661  if (llvm::any_of(op->getOperands(), isBlockArg) ||
2662  llvm::any_of(inverseMapping[op->getResult(0)], isBlockArg)) {
2664  }
2665 
2666  // If the materialization does not have any live users, we don't need to
2667  // generate a user materialization for it.
2668  // FIXME: For argument materializations, we currently need to check if any
2669  // of the inverse mapped values are used because some patterns expect blind
2670  // value replacement even if the types differ in some cases. When those
2671  // patterns are fixed, we can drop the argument special case here.
2672  bool isMaterializationLive = isLive(opResult);
2673  if (mat->getKind() == UnresolvedMaterialization::Argument)
2674  isMaterializationLive |= llvm::any_of(inverseMapping[opResult], isLive);
2675  if (!isMaterializationLive)
2676  continue;
2677  if (!necessaryMaterializations.insert(mat))
2678  continue;
2679 
2680  // Reprocess input materializations to see if they have an updated status.
2681  for (Value input : inputOperands) {
2682  if (auto parentOp = input.getDefiningOp<UnrealizedConversionCastOp>()) {
2683  if (auto *mat = materializationOps.lookup(parentOp))
2684  worklist.insert(mat);
2685  }
2686  }
2687  }
2688 }
2689 
2690 /// Legalize the given unresolved materialization. Returns success if the
2691 /// materialization was legalized, failure otherwise.
2693  UnresolvedMaterialization &mat,
2695  ConversionPatternRewriter &rewriter,
2696  ConversionPatternRewriterImpl &rewriterImpl,
2697  DenseMap<Value, SmallVector<Value>> &inverseMapping) {
2698  auto findLiveUser = [&](auto &&users) {
2699  auto liveUserIt = llvm::find_if_not(
2700  users, [&](Operation *user) { return rewriterImpl.isOpIgnored(user); });
2701  return liveUserIt == users.end() ? nullptr : *liveUserIt;
2702  };
2703 
2704  llvm::unique_function<Value(Value, Type)> lookupRemappedValue =
2705  [&](Value value, Type type) {
2706  // Check to see if the input operation was remapped to a variant of the
2707  // output.
2708  Value remappedValue = rewriterImpl.mapping.lookupOrDefault(value, type);
2709  if (remappedValue.getType() == type)
2710  return remappedValue;
2711  return Value();
2712  };
2713 
2714  UnrealizedConversionCastOp op = mat.getOp();
2715  if (!rewriterImpl.ignoredOps.insert(op))
2716  return success();
2717 
2718  // We currently only handle target materializations here.
2719  OpResult opResult = op->getOpResult(0);
2720  Operation::operand_range inputOperands = op.getOperands();
2721  Type outputType = opResult.getType();
2722 
2723  // If any input to this materialization is another materialization, resolve
2724  // the input first.
2725  for (Value value : op->getOperands()) {
2726  auto valueCast = value.getDefiningOp<UnrealizedConversionCastOp>();
2727  if (!valueCast)
2728  continue;
2729 
2730  auto matIt = materializationOps.find(valueCast);
2731  if (matIt != materializationOps.end())
2733  *matIt->second, materializationOps, rewriter, rewriterImpl,
2734  inverseMapping)))
2735  return failure();
2736  }
2737 
2738  // Perform a last ditch attempt to avoid materializing a resolved
2739  // materialization if possible.
2740  // Handle the case of a 1-1 materialization.
2741  if (inputOperands.size() == 1) {
2742  // Check to see if the input operation was remapped to a variant of the
2743  // output.
2744  Value remappedValue = lookupRemappedValue(inputOperands[0], outputType);
2745  if (remappedValue && remappedValue != opResult) {
2746  replaceMaterialization(rewriterImpl, opResult, remappedValue,
2747  inverseMapping);
2748  return success();
2749  }
2750  } else {
2751  // TODO: Avoid materializing other types of conversions here.
2752  }
2753 
2754  // Try to materialize the conversion.
2755  if (TypeConverter *converter = mat.getConverter()) {
2756  // FIXME: Determine a suitable insertion location when there are multiple
2757  // inputs.
2758  if (inputOperands.size() == 1)
2759  rewriter.setInsertionPointAfterValue(inputOperands.front());
2760  else
2761  rewriter.setInsertionPoint(op);
2762 
2763  Value newMaterialization;
2764  switch (mat.getKind()) {
2766  // Try to materialize an argument conversion.
2767  // FIXME: The current argument materialization hook expects the original
2768  // output type, even though it doesn't use that as the actual output type
2769  // of the generated IR. The output type is just used as an indicator of
2770  // the type of materialization to do. This behavior is really awkward in
2771  // that it diverges from the behavior of the other hooks, and can be
2772  // easily misunderstood. We should clean up the argument hooks to better
2773  // represent the desired invariants we actually care about.
2774  newMaterialization = converter->materializeArgumentConversion(
2775  rewriter, op->getLoc(), mat.getOrigOutputType(), inputOperands);
2776  if (newMaterialization)
2777  break;
2778 
2779  // If an argument materialization failed, fallback to trying a target
2780  // materialization.
2781  [[fallthrough]];
2782  case UnresolvedMaterialization::Target:
2783  newMaterialization = converter->materializeTargetConversion(
2784  rewriter, op->getLoc(), outputType, inputOperands);
2785  break;
2786  }
2787  if (newMaterialization) {
2788  replaceMaterialization(rewriterImpl, opResult, newMaterialization,
2789  inverseMapping);
2790  return success();
2791  }
2792  }
2793 
2794  InFlightDiagnostic diag = op->emitError()
2795  << "failed to legalize unresolved materialization "
2796  "from "
2797  << inputOperands.getTypes() << " to " << outputType
2798  << " that remained live after conversion";
2799  if (Operation *liveUser = findLiveUser(op->getUsers())) {
2800  diag.attachNote(liveUser->getLoc())
2801  << "see existing live user here: " << *liveUser;
2802  }
2803  return failure();
2804 }
2805 
2806 LogicalResult OperationConverter::legalizeUnresolvedMaterializations(
2807  ConversionPatternRewriter &rewriter,
2808  ConversionPatternRewriterImpl &rewriterImpl,
2809  Optional<DenseMap<Value, SmallVector<Value>>> &inverseMapping) {
2810  if (rewriterImpl.unresolvedMaterializations.empty())
2811  return success();
2812  inverseMapping = rewriterImpl.mapping.getInverse();
2813 
2814  // As an initial step, compute all of the inserted materializations that we
2815  // expect to persist beyond the conversion process.
2817  SetVector<UnresolvedMaterialization *> necessaryMaterializations;
2818  computeNecessaryMaterializations(materializationOps, rewriter, rewriterImpl,
2819  *inverseMapping, necessaryMaterializations);
2820 
2821  // Once computed, legalize any necessary materializations.
2822  for (auto *mat : necessaryMaterializations) {
2824  *mat, materializationOps, rewriter, rewriterImpl, *inverseMapping)))
2825  return failure();
2826  }
2827  return success();
2828 }
2829 
2830 LogicalResult OperationConverter::legalizeErasedResult(
2831  Operation *op, OpResult result,
2832  ConversionPatternRewriterImpl &rewriterImpl) {
2833  // If the operation result was replaced with null, all of the uses of this
2834  // value should be replaced.
2835  auto liveUserIt = llvm::find_if_not(result.getUsers(), [&](Operation *user) {
2836  return rewriterImpl.isOpIgnored(user);
2837  });
2838  if (liveUserIt != result.user_end()) {
2839  InFlightDiagnostic diag = op->emitError("failed to legalize operation '")
2840  << op->getName() << "' marked as erased";
2841  diag.attachNote(liveUserIt->getLoc())
2842  << "found live user of result #" << result.getResultNumber() << ": "
2843  << *liveUserIt;
2844  return failure();
2845  }
2846  return success();
2847 }
2848 
2849 /// Finds a user of the given value, or of any other value that the given value
2850 /// replaced, that was not replaced in the conversion process.
2852  Value initialValue, ConversionPatternRewriterImpl &rewriterImpl,
2853  const DenseMap<Value, SmallVector<Value>> &inverseMapping) {
2854  SmallVector<Value> worklist(1, initialValue);
2855  while (!worklist.empty()) {
2856  Value value = worklist.pop_back_val();
2857 
2858  // Walk the users of this value to see if there are any live users that
2859  // weren't replaced during conversion.
2860  auto liveUserIt = llvm::find_if_not(value.getUsers(), [&](Operation *user) {
2861  return rewriterImpl.isOpIgnored(user);
2862  });
2863  if (liveUserIt != value.user_end())
2864  return *liveUserIt;
2865  auto mapIt = inverseMapping.find(value);
2866  if (mapIt != inverseMapping.end())
2867  worklist.append(mapIt->second);
2868  }
2869  return nullptr;
2870 }
2871 
2872 LogicalResult OperationConverter::legalizeChangedResultType(
2873  Operation *op, OpResult result, Value newValue,
2874  TypeConverter *replConverter, ConversionPatternRewriter &rewriter,
2875  ConversionPatternRewriterImpl &rewriterImpl,
2876  const DenseMap<Value, SmallVector<Value>> &inverseMapping) {
2877  Operation *liveUser =
2878  findLiveUserOfReplaced(result, rewriterImpl, inverseMapping);
2879  if (!liveUser)
2880  return success();
2881 
2882  // Functor used to emit a conversion error for a failed materialization.
2883  auto emitConversionError = [&] {
2885  << "failed to materialize conversion for result #"
2886  << result.getResultNumber() << " of operation '"
2887  << op->getName()
2888  << "' that remained live after conversion";
2889  diag.attachNote(liveUser->getLoc())
2890  << "see existing live user here: " << *liveUser;
2891  return failure();
2892  };
2893 
2894  // If the replacement has a type converter, attempt to materialize a
2895  // conversion back to the original type.
2896  if (!replConverter)
2897  return emitConversionError();
2898 
2899  // Materialize a conversion for this live result value.
2900  Type resultType = result.getType();
2901  Value convertedValue = replConverter->materializeSourceConversion(
2902  rewriter, op->getLoc(), resultType, newValue);
2903  if (!convertedValue)
2904  return emitConversionError();
2905 
2906  rewriterImpl.mapping.map(result, convertedValue);
2907  return success();
2908 }
2909 
2910 //===----------------------------------------------------------------------===//
2911 // Type Conversion
2912 //===----------------------------------------------------------------------===//
2913 
2915  ArrayRef<Type> types) {
2916  assert(!types.empty() && "expected valid types");
2917  remapInput(origInputNo, /*newInputNo=*/argTypes.size(), types.size());
2918  addInputs(types);
2919 }
2920 
2922  assert(!types.empty() &&
2923  "1->0 type remappings don't need to be added explicitly");
2924  argTypes.append(types.begin(), types.end());
2925 }
2926 
2927 void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
2928  unsigned newInputNo,
2929  unsigned newInputCount) {
2930  assert(!remappedInputs[origInputNo] && "input has already been remapped");
2931  assert(newInputCount != 0 && "expected valid input count");
2932  remappedInputs[origInputNo] =
2933  InputMapping{newInputNo, newInputCount, /*replacementValue=*/nullptr};
2934 }
2935 
2937  Value replacementValue) {
2938  assert(!remappedInputs[origInputNo] && "input has already been remapped");
2939  remappedInputs[origInputNo] =
2940  InputMapping{origInputNo, /*size=*/0, replacementValue};
2941 }
2942 
2944  SmallVectorImpl<Type> &results) {
2945  auto existingIt = cachedDirectConversions.find(t);
2946  if (existingIt != cachedDirectConversions.end()) {
2947  if (existingIt->second)
2948  results.push_back(existingIt->second);
2949  return success(existingIt->second != nullptr);
2950  }
2951  auto multiIt = cachedMultiConversions.find(t);
2952  if (multiIt != cachedMultiConversions.end()) {
2953  results.append(multiIt->second.begin(), multiIt->second.end());
2954  return success();
2955  }
2956 
2957  // Walk the added converters in reverse order to apply the most recently
2958  // registered first.
2959  size_t currentCount = results.size();
2960  conversionCallStack.push_back(t);
2961  auto popConversionCallStack =
2962  llvm::make_scope_exit([this]() { conversionCallStack.pop_back(); });
2963  for (ConversionCallbackFn &converter : llvm::reverse(conversions)) {
2964  if (Optional<LogicalResult> result =
2965  converter(t, results, conversionCallStack)) {
2966  if (!succeeded(*result)) {
2967  cachedDirectConversions.try_emplace(t, nullptr);
2968  return failure();
2969  }
2970  auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
2971  if (newTypes.size() == 1)
2972  cachedDirectConversions.try_emplace(t, newTypes.front());
2973  else
2974  cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
2975  return success();
2976  }
2977  }
2978  return failure();
2979 }
2980 
2982  // Use the multi-type result version to convert the type.
2983  SmallVector<Type, 1> results;
2984  if (failed(convertType(t, results)))
2985  return nullptr;
2986 
2987  // Check to ensure that only one type was produced.
2988  return results.size() == 1 ? results.front() : nullptr;
2989 }
2990 
2992  SmallVectorImpl<Type> &results) {
2993  for (Type type : types)
2994  if (failed(convertType(type, results)))
2995  return failure();
2996  return success();
2997 }
2998 
2999 bool TypeConverter::isLegal(Type type) { return convertType(type) == type; }
3001  return isLegal(op->getOperandTypes()) && isLegal(op->getResultTypes());
3002 }
3003 
3005  return llvm::all_of(*region, [this](Block &block) {
3006  return isLegal(block.getArgumentTypes());
3007  });
3008 }
3009 
3010 bool TypeConverter::isSignatureLegal(FunctionType ty) {
3011  return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults()));
3012 }
3013 
3015  SignatureConversion &result) {
3016  // Try to convert the given input type.
3017  SmallVector<Type, 1> convertedTypes;
3018  if (failed(convertType(type, convertedTypes)))
3019  return failure();
3020 
3021  // If this argument is being dropped, there is nothing left to do.
3022  if (convertedTypes.empty())
3023  return success();
3024 
3025  // Otherwise, add the new inputs.
3026  result.addInputs(inputNo, convertedTypes);
3027  return success();
3028 }
3030  SignatureConversion &result,
3031  unsigned origInputOffset) {
3032  for (unsigned i = 0, e = types.size(); i != e; ++i)
3033  if (failed(convertSignatureArg(origInputOffset + i, types[i], result)))
3034  return failure();
3035  return success();
3036 }
3037 
3038 Value TypeConverter::materializeConversion(
3040  OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) {
3041  for (MaterializationCallbackFn &fn : llvm::reverse(materializations))
3042  if (Optional<Value> result = fn(builder, resultType, inputs, loc))
3043  return *result;
3044  return nullptr;
3045 }
3046 
3049  SignatureConversion conversion(block->getNumArguments());
3050  if (failed(convertSignatureArgs(block->getArgumentTypes(), conversion)))
3051  return llvm::None;
3052  return conversion;
3053 }
3054 
3055 //===----------------------------------------------------------------------===//
3056 // FunctionOpInterfaceSignatureConversion
3057 //===----------------------------------------------------------------------===//
3058 
3059 static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
3060  TypeConverter &typeConverter,
3061  ConversionPatternRewriter &rewriter) {
3062  FunctionType type = funcOp.getFunctionType().cast<FunctionType>();
3063 
3064  // Convert the original function types.
3065  TypeConverter::SignatureConversion result(type.getNumInputs());
3066  SmallVector<Type, 1> newResults;
3067  if (failed(typeConverter.convertSignatureArgs(type.getInputs(), result)) ||
3068  failed(typeConverter.convertTypes(type.getResults(), newResults)) ||
3069  failed(rewriter.convertRegionTypes(&funcOp.getFunctionBody(),
3070  typeConverter, &result)))
3071  return failure();
3072 
3073  // Update the function signature in-place.
3074  auto newType = FunctionType::get(rewriter.getContext(),
3075  result.getConvertedTypes(), newResults);
3076 
3077  rewriter.updateRootInPlace(funcOp, [&] { funcOp.setType(newType); });
3078 
3079  return success();
3080 }
3081 
3082 /// Create a default conversion pattern that rewrites the type signature of a
3083 /// FunctionOpInterface op. This only supports ops which use FunctionType to
3084 /// represent their type.
3085 namespace {
3086 struct FunctionOpInterfaceSignatureConversion : public ConversionPattern {
3087  FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
3088  MLIRContext *ctx,
3089  TypeConverter &converter)
3090  : ConversionPattern(converter, functionLikeOpName, /*benefit=*/1, ctx) {}
3091 
3093  matchAndRewrite(Operation *op, ArrayRef<Value> /*operands*/,
3094  ConversionPatternRewriter &rewriter) const override {
3095  FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
3096  return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
3097  }
3098 };
3099 
3100 struct AnyFunctionOpInterfaceSignatureConversion
3101  : public OpInterfaceConversionPattern<FunctionOpInterface> {
3103 
3105  matchAndRewrite(FunctionOpInterface funcOp, ArrayRef<Value> /*operands*/,
3106  ConversionPatternRewriter &rewriter) const override {
3107  return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
3108  }
3109 };
3110 } // namespace
3111 
3113  StringRef functionLikeOpName, RewritePatternSet &patterns,
3114  TypeConverter &converter) {
3115  patterns.add<FunctionOpInterfaceSignatureConversion>(
3116  functionLikeOpName, patterns.getContext(), converter);
3117 }
3118 
3120  RewritePatternSet &patterns, TypeConverter &converter) {
3121  patterns.add<AnyFunctionOpInterfaceSignatureConversion>(
3122  converter, patterns.getContext());
3123 }
3124 
3125 //===----------------------------------------------------------------------===//
3126 // ConversionTarget
3127 //===----------------------------------------------------------------------===//
3128 
3130  LegalizationAction action) {
3131  legalOperations[op].action = action;
3132 }
3133 
3135  LegalizationAction action) {
3136  for (StringRef dialect : dialectNames)
3137  legalDialects[dialect] = action;
3138 }
3139 
3142  Optional<LegalizationInfo> info = getOpInfo(op);
3143  return info ? info->action : Optional<LegalizationAction>();
3144 }
3145 
3148  Optional<LegalizationInfo> info = getOpInfo(op->getName());
3149  if (!info)
3150  return llvm::None;
3151 
3152  // Returns true if this operation instance is known to be legal.
3153  auto isOpLegal = [&] {
3154  // Handle dynamic legality either with the provided legality function.
3155  if (info->action == LegalizationAction::Dynamic) {
3156  Optional<bool> result = info->legalityFn(op);
3157  if (result)
3158  return *result;
3159  }
3160 
3161  // Otherwise, the operation is only legal if it was marked 'Legal'.
3162  return info->action == LegalizationAction::Legal;
3163  };
3164  if (!isOpLegal())
3165  return llvm::None;
3166 
3167  // This operation is legal, compute any additional legality information.
3168  LegalOpDetails legalityDetails;
3169  if (info->isRecursivelyLegal) {
3170  auto legalityFnIt = opRecursiveLegalityFns.find(op->getName());
3171  if (legalityFnIt != opRecursiveLegalityFns.end()) {
3172  legalityDetails.isRecursivelyLegal =
3173  legalityFnIt->second(op).value_or(true);
3174  } else {
3175  legalityDetails.isRecursivelyLegal = true;
3176  }
3177  }
3178  return legalityDetails;
3179 }
3180 
3182  Optional<LegalizationInfo> info = getOpInfo(op->getName());
3183  if (!info)
3184  return false;
3185 
3186  if (info->action == LegalizationAction::Dynamic) {
3187  Optional<bool> result = info->legalityFn(op);
3188  if (!result)
3189  return false;
3190 
3191  return !(*result);
3192  }
3193 
3194  return info->action == LegalizationAction::Illegal;
3195 }
3196 
3200  if (!oldCallback)
3201  return newCallback;
3202 
3203  auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)](
3204  Operation *op) -> Optional<bool> {
3205  if (Optional<bool> result = newCl(op))
3206  return *result;
3207 
3208  return oldCl(op);
3209  };
3210  return chain;
3211 }
3212 
3213 void ConversionTarget::setLegalityCallback(
3214  OperationName name, const DynamicLegalityCallbackFn &callback) {
3215  assert(callback && "expected valid legality callback");
3216  auto infoIt = legalOperations.find(name);
3217  assert(infoIt != legalOperations.end() &&
3218  infoIt->second.action == LegalizationAction::Dynamic &&
3219  "expected operation to already be marked as dynamically legal");
3220  infoIt->second.legalityFn =
3221  composeLegalityCallbacks(std::move(infoIt->second.legalityFn), callback);
3222 }
3223 
3225  OperationName name, const DynamicLegalityCallbackFn &callback) {
3226  auto infoIt = legalOperations.find(name);
3227  assert(infoIt != legalOperations.end() &&
3228  infoIt->second.action != LegalizationAction::Illegal &&
3229  "expected operation to already be marked as legal");
3230  infoIt->second.isRecursivelyLegal = true;
3231  if (callback)
3232  opRecursiveLegalityFns[name] = composeLegalityCallbacks(
3233  std::move(opRecursiveLegalityFns[name]), callback);
3234  else
3235  opRecursiveLegalityFns.erase(name);
3236 }
3237 
3238 void ConversionTarget::setLegalityCallback(
3239  ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) {
3240  assert(callback && "expected valid legality callback");
3241  for (StringRef dialect : dialects)
3242  dialectLegalityFns[dialect] = composeLegalityCallbacks(
3243  std::move(dialectLegalityFns[dialect]), callback);
3244 }
3245 
3246 void ConversionTarget::setLegalityCallback(
3247  const DynamicLegalityCallbackFn &callback) {
3248  assert(callback && "expected valid legality callback");
3249  unknownLegalityFn = composeLegalityCallbacks(unknownLegalityFn, callback);
3250 }
3251 
3252 auto ConversionTarget::getOpInfo(OperationName op) const
3254  // Check for info for this specific operation.
3255  auto it = legalOperations.find(op);
3256  if (it != legalOperations.end())
3257  return it->second;
3258  // Check for info for the parent dialect.
3259  auto dialectIt = legalDialects.find(op.getDialectNamespace());
3260  if (dialectIt != legalDialects.end()) {
3261  DynamicLegalityCallbackFn callback;
3262  auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace());
3263  if (dialectFn != dialectLegalityFns.end())
3264  callback = dialectFn->second;
3265  return LegalizationInfo{dialectIt->second, /*isRecursivelyLegal=*/false,
3266  callback};
3267  }
3268  // Otherwise, check if we mark unknown operations as dynamic.
3269  if (unknownLegalityFn)
3270  return LegalizationInfo{LegalizationAction::Dynamic,
3271  /*isRecursivelyLegal=*/false, unknownLegalityFn};
3272  return llvm::None;
3273 }
3274 
3275 //===----------------------------------------------------------------------===//
3276 // PDL Configuration
3277 //===----------------------------------------------------------------------===//
3278 
3280  auto &rewriterImpl =
3281  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3282  rewriterImpl.currentTypeConverter = getTypeConverter();
3283 }
3284 
3286  auto &rewriterImpl =
3287  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3288  rewriterImpl.currentTypeConverter = nullptr;
3289 }
3290 
3291 /// Remap the given value using the rewriter and the type converter in the
3292 /// provided config.
3295  SmallVector<Value> mappedValues;
3296  if (failed(rewriter.getRemappedValues(values, mappedValues)))
3297  return failure();
3298  return std::move(mappedValues);
3299 }
3300 
3303  "convertValue",
3304  [](PatternRewriter &rewriter, Value value) -> FailureOr<Value> {
3305  auto results = pdllConvertValues(
3306  static_cast<ConversionPatternRewriter &>(rewriter), value);
3307  if (failed(results))
3308  return failure();
3309  return results->front();
3310  });
3312  "convertValues", [](PatternRewriter &rewriter, ValueRange values) {
3313  return pdllConvertValues(
3314  static_cast<ConversionPatternRewriter &>(rewriter), values);
3315  });
3317  "convertType",
3318  [](PatternRewriter &rewriter, Type type) -> FailureOr<Type> {
3319  auto &rewriterImpl =
3320  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3321  if (TypeConverter *converter = rewriterImpl.currentTypeConverter) {
3322  if (Type newType = converter->convertType(type))
3323  return newType;
3324  return failure();
3325  }
3326  return type;
3327  });
3329  "convertTypes",
3330  [](PatternRewriter &rewriter,
3332  auto &rewriterImpl =
3333  static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
3334  TypeConverter *converter = rewriterImpl.currentTypeConverter;
3335  if (!converter)
3336  return SmallVector<Type>(types);
3337 
3338  SmallVector<Type> remappedTypes;
3339  if (failed(converter->convertTypes(types, remappedTypes)))
3340  return failure();
3341  return std::move(remappedTypes);
3342  });
3343 }
3344 
3345 //===----------------------------------------------------------------------===//
3346 // Op Conversion Entry Points
3347 //===----------------------------------------------------------------------===//
3348 
3349 //===----------------------------------------------------------------------===//
3350 // Partial Conversion
3351 
3354  ConversionTarget &target,
3355  const FrozenRewritePatternSet &patterns,
3356  DenseSet<Operation *> *unconvertedOps) {
3357  OperationConverter opConverter(target, patterns, OpConversionMode::Partial,
3358  unconvertedOps);
3359  return opConverter.convertOperations(ops);
3360 }
3363  const FrozenRewritePatternSet &patterns,
3364  DenseSet<Operation *> *unconvertedOps) {
3365  return applyPartialConversion(llvm::makeArrayRef(op), target, patterns,
3366  unconvertedOps);
3367 }
3368 
3369 //===----------------------------------------------------------------------===//
3370 // Full Conversion
3371 
3374  const FrozenRewritePatternSet &patterns) {
3375  OperationConverter opConverter(target, patterns, OpConversionMode::Full);
3376  return opConverter.convertOperations(ops);
3377 }
3380  const FrozenRewritePatternSet &patterns) {
3381  return applyFullConversion(llvm::makeArrayRef(op), target, patterns);
3382 }
3383 
3384 //===----------------------------------------------------------------------===//
3385 // Analysis Conversion
3386 
3389  ConversionTarget &target,
3390  const FrozenRewritePatternSet &patterns,
3391  DenseSet<Operation *> &convertedOps,
3392  function_ref<void(Diagnostic &)> notifyCallback) {
3393  OperationConverter opConverter(target, patterns, OpConversionMode::Analysis,
3394  &convertedOps);
3395  return opConverter.convertOperations(ops, notifyCallback);
3396 }
3399  const FrozenRewritePatternSet &patterns,
3400  DenseSet<Operation *> &convertedOps,
3401  function_ref<void(Diagnostic &)> notifyCallback) {
3402  return applyAnalysisConversion(llvm::makeArrayRef(op), target, patterns,
3403  convertedOps, notifyCallback);
3404 }
static std::string diag(llvm::Value &value)
static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(ConversionTarget::DynamicLegalityCallbackFn oldCallback, ConversionTarget::DynamicLegalityCallbackFn newCallback)
static FailureOr< SmallVector< Value > > pdllConvertValues(ConversionPatternRewriter &rewriter, ValueRange values)
Remap the given value using the rewriter and the type converter in the provided config.
static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a failure result for the given reason.
static void detachNestedAndErase(Operation *op)
Detach any operations nested in the given operation from their parent blocks, and erase the given ope...
static Operation * findLiveUserOfReplaced(Value initialValue, ConversionPatternRewriterImpl &rewriterImpl, const DenseMap< Value, SmallVector< Value >> &inverseMapping)
Finds a user of the given value, or of any other value that the given value replaced,...
static Value buildUnresolvedTargetMaterialization(Location loc, Value input, Type outputType, TypeConverter *converter, SmallVectorImpl< UnresolvedMaterialization > &unresolvedMaterializations)
static void computeNecessaryMaterializations(DenseMap< Operation *, UnresolvedMaterialization * > &materializationOps, ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &rewriterImpl, DenseMap< Value, SmallVector< Value >> &inverseMapping, SetVector< UnresolvedMaterialization * > &necessaryMaterializations)
Compute all of the unresolved materializations that will persist beyond the conversion process,...
static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a successful result for the given reason.
static LogicalResult computeConversionSet(iterator_range< Region::iterator > region, Location regionLoc, SmallVectorImpl< Operation * > &toConvert, ConversionTarget *target=nullptr)
Recursively collect all of the operations to convert from within 'region'.
static void replaceMaterialization(ConversionPatternRewriterImpl &rewriterImpl, ResultRange matResults, ValueRange values, DenseMap< Value, SmallVector< Value >> &inverseMapping)
Replace the results of a materialization operation with the given values.
static Value buildUnresolvedMaterialization(UnresolvedMaterialization::Kind kind, Block *insertBlock, Block::iterator insertPt, Location loc, ValueRange inputs, Type outputType, Type origOutputType, TypeConverter *converter, SmallVectorImpl< UnresolvedMaterialization > &unresolvedMaterializations)
Build an unresolved materialization operation given an output type and set of input operands.
static LogicalResult legalizeUnresolvedMaterialization(UnresolvedMaterialization &mat, DenseMap< Operation *, UnresolvedMaterialization * > &materializationOps, ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &rewriterImpl, DenseMap< Value, SmallVector< Value >> &inverseMapping)
Legalize the given unresolved materialization.
static Value buildUnresolvedArgumentMaterialization(PatternRewriter &rewriter, Location loc, ValueRange inputs, Type origOutputType, Type outputType, TypeConverter *converter, SmallVectorImpl< UnresolvedMaterialization > &unresolvedMaterializations)
static constexpr const bool value
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
Block * lookup(Block *from) const
Lookup a mapped value within the map.
This class represents an argument of a Block.
Definition: Value.h:296
Block * getOwner() const
Returns the block that owns this argument.
Definition: Value.h:305
Location getLoc() const
Return the location for this argument.
Definition: Value.h:311
Block represents an ordered list of Operations.
Definition: Block.h:30
OpListType::iterator iterator
Definition: Block.h:129
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:137
bool empty()
Definition: Block.h:137
BlockArgument getArgument(unsigned i)
Definition: Block.h:118
unsigned getNumArguments()
Definition: Block.h:117
Operation & back()
Definition: Block.h:141
void erase()
Unlink this Block from its parent region and delete it.
Definition: Block.cpp:54
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Definition: Block.cpp:148
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Definition: Block.cpp:291
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:26
void dropAllDefinedValueUses()
This drops all uses of values defined in this block or in the blocks of nested regions wherever the u...
Definition: Block.cpp:82
SuccessorRange getSuccessors()
Definition: Block.h:253
iterator_range< pred_iterator > getPredecessors()
Definition: Block.h:223
OpListType & getOperations()
Definition: Block.h:126
BlockArgListType getArguments()
Definition: Block.h:76
iterator end()
Definition: Block.h:133
iterator begin()
Definition: Block.h:132
void moveBefore(Block *block)
Unlink this block from its current region and insert it right before the specific block.
Definition: Block.cpp:47
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
MLIRContext * getContext() const
Definition: Builders.h:54
This class implements a pattern rewriter for use with ConversionPatterns.
Block * applySignatureConversion(Region *region, TypeConverter::SignatureConversion &conversion, TypeConverter *converter=nullptr)
Apply a signature conversion to the entry block of the given region.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing the results of an operation.
LogicalResult convertNonEntryRegionTypes(Region *region, TypeConverter &converter, ArrayRef< TypeConverter::SignatureConversion > blockConversions)
Convert the types of block arguments within the given region except for the entry block.
LogicalResult getRemappedValues(ValueRange keys, SmallVectorImpl< Value > &results)
Return the converted values that replace 'keys' with types defined by the type converter of the curre...
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before) override
PatternRewriter hook for moving blocks out of a region.
void finalizeRootUpdate(Operation *op) override
PatternRewriter hook for updating the root operation in-place.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void replaceOpWithIf(Operation *op, ValueRange newValues, bool *allUsesReplaced, llvm::unique_function< bool(OpOperand &) const > functor) override
PatternRewriter hook for replacing the results of an operation when the given functor returns true.
void notifyBlockCreated(Block *block) override
PatternRewriter hook for creating a new block.
FailureOr< Block * > convertRegionTypes(Region *region, TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Convert the types of block arguments within the given region.
Block * splitBlock(Block *block, Block::iterator before) override
PatternRewriter hook for splitting a block into two parts.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues) override
PatternRewriter hook for merging a block into another.
detail::ConversionPatternRewriterImpl & getImpl()
Return a reference to the internal implementation.
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, BlockAndValueMapping &mapping) override
PatternRewriter hook for cloning blocks of one region into another.
void notifyOperationInserted(Operation *op) override
PatternRewriter hook for inserting a new operation.
void cancelRootUpdate(Operation *op) override
PatternRewriter hook for cancelling a pending in-place update of the root operation.
void eraseBlock(Block *block) override
PatternRewriter hook for erasing all operations in a block.
Value getRemappedValue(Value key)
Return the converted value of 'key' with a type defined by the type converter of the currently execut...
void replaceUsesOfBlockArgument(BlockArgument from, Value to)
Replace all the uses of the block argument from with value to.
void startRootUpdate(Operation *op) override
PatternRewriter hook for updating the root operation in-place.
Base class for the conversion patterns.
TypeConverter * getTypeConverter() const
Return the type converter held by this pattern, or nullptr if the pattern does not require type conve...
virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement combined matching and rewriting.
This class describes a specific conversion target.
void setDialectAction(ArrayRef< StringRef > dialectNames, LegalizationAction action)
Register a legality action for the given dialects.
void setOpAction(OperationName op, LegalizationAction action)
Register a legality action for the given operation.
std::function< Optional< bool >(Operation *)> DynamicLegalityCallbackFn
The signature of the callback used to determine if an operation is dynamically legal on the target.
Optional< LegalOpDetails > isLegal(Operation *op) const
If the given operation instance is legal on this target, a structure containing legality information ...
Optional< LegalizationAction > getOpAction(OperationName op) const
Get the legality action for the given operation.
LegalizationAction
This enumeration corresponds to the specific action to take when considering an operation legal for t...
void markOpRecursivelyLegal(OperationName name, const DynamicLegalityCallbackFn &callback)
Mark an operation, that must have either been set as Legal or DynamicallyLegal, as being recursively ...
bool isIllegal(Operation *op) const
Returns true if the operation instance is illegal on this target.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This class represents a frozen set of patterns that can be processed by a pattern applicator.
void dropAllUses()
Drop all uses of this object from their respective owners.
Definition: UseDefLists.h:179
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition: UseDefLists.h:188
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:307
Location objects represent source locations information in MLIR.
Definition: Location.h:32
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:64
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:56
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:300
This class helps build Operations.
Definition: Builders.h:198
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:397
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:350
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Definition: Builders.h:373
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:422
LogicalResult tryFold(Operation *op, SmallVectorImpl< Value > &results)
Attempts to fold the given operation and places new results within 'results'.
Definition: Builders.cpp:440
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:394
OpInterfaceConversionPattern is a wrapper around ConversionPattern that allows for matching and rewri...
OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class represents an operand of an operation.
Definition: Value.h:247
This is a value defined by a result of an operation.
Definition: Value.h:442
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:454
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:41
type_range getTypes() const
Definition: ValueRange.cpp:26
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:31
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
Definition: Operation.h:157
void dropAllUses()
Drop all uses of results of this operation.
Definition: Operation.h:611
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:265
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:477
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:154
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:165
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:225
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:144
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:480
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:50
operand_type_range getOperandTypes()
Definition: Operation.h:314
result_type_range getResultTypes()
Definition: Operation.h:345
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:295
result_range getResults()
Definition: Operation.h:332
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:574
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:418
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:321
void notifyRewriteEnd(PatternRewriter &rewriter) final
void notifyRewriteBegin(PatternRewriter &rewriter) final
Hooks that are invoked at the beginning and end of a rewrite of a matched pattern.
void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn)
Register a rewrite function with PDL.
This class manages the application of a group of rewrite patterns, with a user-provided cost model.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:32
static PatternBenefit impossibleToMatch()
Definition: PatternMatch.h:41
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:605
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:71
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e.
Definition: PatternMatch.h:127
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
Definition: PatternMatch.h:88
Optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
Definition: PatternMatch.h:92
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
bool empty()
Definition: Region.h:60
iterator end()
Definition: Region.h:56
Block & back()
Definition: Region.h:64
BlockListType & getBlocks()
Definition: Region.h:45
Location getLoc()
Return a location for this region.
Definition: Region.cpp:31
Block & front()
Definition: Region.h:65
BlockListType::iterator iterator
Definition: Region.h:52
This class implements the result iterators for the Operation class.
Definition: ValueRange.h:230
std::enable_if_t<!std::is_convertible< ValuesT, Operation * >::value > replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this range with the provided 'values'.
Definition: ValueRange.h:269
MLIRContext * getContext() const
PDLPatternModule & getPDLPatterns()
Return the PDL patterns held in this list.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, BlockAndValueMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
virtual Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
Definition: PatternMatch.h:499
virtual void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
Optional< InputMapping > getInputMapping(unsigned input) const
Get the input mapping for the given argument.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
void remapInput(unsigned origInputNo, Value replacement)
Remap an input of the original signature to another replacement value.
Type conversion class.
bool isSignatureLegal(FunctionType ty)
Return true if the inputs and outputs of the given function type are legal.
LogicalResult convertSignatureArg(unsigned inputNo, Type type, SignatureConversion &result)
This method allows for converting a specific argument of a signature.
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results)
Convert the given set of types, filling 'results' as necessary.
LogicalResult convertSignatureArgs(TypeRange types, SignatureConversion &result, unsigned origInputOffset=0)
bool isLegal(Type type)
Return true if the given type is legal for this type converter, i.e.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results)
Convert the given type.
Optional< SignatureConversion > convertBlockSignature(Block *block)
This function converts the type signature of the given block, by invoking 'convertSignatureArg' for e...
Value materializeArgumentConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs)
Materialize a conversion from a set of types into one result type by generating a cast sequence of so...
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs)
Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs)
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:349
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
bool use_empty() const
Returns true if this value has no uses.
Definition: Value.h:199
bool isa() const
Definition: Value.h:90
void dropAllUses() const
Drop all uses of this object from their respective owners.
Definition: Value.h:153
Type getType() const
Return the type of this value.
Definition: Value.h:114
U dyn_cast() const
Definition: Value.h:95
user_iterator user_end() const
Definition: Value.h:208
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:48
U cast() const
Definition: Value.h:105
void replaceAllUsesWith(Value newValue) const
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition: Value.h:158
user_range getUsers() const
Definition: Value.h:209
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:40
Detect if any of the given parameter types has a sub-element handler.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:230
@ Full
Documents are synced by always sending the full content of the document.
Kind
Tensor expression kind.
Definition: Merger.h:25
llvm::PointerUnion< NamedAttribute *, NamedTypeConstraint * > Argument
Definition: Argument.h:62
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
void populateFunctionOpInterfaceTypeConversionPattern(StringRef functionLikeOpName, RewritePatternSet &patterns, TypeConverter &converter)
Add a pattern to the given pattern list to convert the signature of a FunctionOpInterface op with the...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation * > *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns)
Apply a complete conversion on the given operations, and all nested operations.
void populateAnyFunctionOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, TypeConverter &converter)
void registerConversionPDLFunctions(RewritePatternSet &patterns)
Register the dialect conversion PDL functions with the given pattern set.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
LogicalResult applyAnalysisConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation * > &convertedOps, function_ref< void(Diagnostic &)> notifyCallback=nullptr)
Apply an analysis conversion on the given operations, and all nested operations.
A structure containing additional information describing a specific legal operation instance.
bool isRecursivelyLegal
A flag that indicates if this operation is 'recursively' legal.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
This struct represents a range of new types or a single value that remaps an existing signature input...
void eraseDanglingBlocks()
Erase any blocks that were unlinked from their regions and stored in block actions.
TypeConverter * currentTypeConverter
The current type converter, or nullptr if no type converter is currently active.
void undoBlockActions(unsigned numActionsToKeep=0)
Undo the block actions (motions, splits) one by one in reverse order until "numActionsToKeep" actions...
void discardRewrites()
Cleanup and destroy any generated rewrite operations.
ConversionPatternRewriterImpl(PatternRewriter &rewriter)
function_ref< void(Diagnostic &)> notifyCallback
This allows the user to collect the match failure message.
SmallVector< BlockAction, 4 > blockActions
Ordered list of block operations (creations, splits, motions).
llvm::MapVector< Operation *, OpReplacement > replacements
Ordered map of requested operation replacements.
void notifyBlocksBeingMerged(Block *block, Block *srcBlock)
Notifies that block is being merged with srcBlock.
LogicalResult convertNonEntryRegionTypes(Region *region, TypeConverter &converter, ArrayRef< TypeConverter::SignatureConversion > blockConversions={})
Convert the types of non-entry block arguments within the given region.
void resetState(RewriterState state)
Reset the state of the rewriter to a previously saved point.
void notifyRegionIsBeingInlinedBefore(Region &region, Region &parent, Region::iterator before)
Notifies that the blocks of a region are about to be moved.
void notifySplitBlock(Block *block, Block *continuation)
Notifies that a block was split.
void applyRewrites()
Apply all requested operation rewrites.
SmallVector< OperationTransactionState, 4 > rootUpdates
A transaction state for each of operations that were updated in-place.
RewriterState getCurrentState()
Return the current state of the rewriter.
Block * applySignatureConversion(Region *region, TypeConverter::SignatureConversion &conversion, TypeConverter *converter)
Apply a signature conversion on the given region, using converter for materializations if not null.
void notifyOpReplaced(Operation *op, ValueRange newValues)
PatternRewriter hook for replacing the results of an operation.
llvm::ScopedPrinter logger
A logger used to emit diagnostics during the conversion process.
SmallVector< UnresolvedMaterialization > unresolvedMaterializations
Ordered vector of all unresolved type conversion materializations during conversion.
ArgConverter argConverter
Utility used to convert block arguments.
SmallPtrSet< Operation *, 1 > pendingRootUpdates
A set of operations that have pending updates.
FailureOr< Block * > convertRegionTypes(Region *region, TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion)
Convert the types of block arguments within the given region.
bool isOpIgnored(Operation *op) const
Returns true if the given operation is ignored, and does not need to be converted.
FailureOr< Block * > convertBlockSignature(Block *block, TypeConverter *converter, TypeConverter::SignatureConversion *conversion=nullptr)
Convert the signature of the given block.
LogicalResult remapValues(StringRef valueDiagTag, Optional< Location > inputLoc, PatternRewriter &rewriter, ValueRange values, SmallVectorImpl< Value > &remapped)
Remap the given values to those with potentially different types.
void markNestedOpsIgnored(Operation *op)
Recursively marks the nested operations under 'op' as ignored.
SmallVector< Operation * > createdOps
Ordered vector of all of the newly created operations during conversion.
void notifyRegionWasClonedBefore(iterator_range< Region::iterator > &blocks, Location origRegionLoc)
Notifies that the blocks of a region were cloned into another.
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback)
Notifies that a pattern match failed for the given reason.
SetVector< Operation * > ignoredOps
A set of operations that should no longer be considered for legalization, but were not directly repla...
SmallVector< BlockArgument, 4 > argReplacements
Ordered vector of any requested block argument replacements.
void notifyBlockIsBeingErased(Block *block)
Notifies that a block is about to be erased.
void notifyCreatedBlock(Block *block)
Notifies that a block was created.
SmallVector< unsigned, 4 > operationsWithChangedResults
A vector of indices into replacements of operations that were replaced with values with different res...