MLIR  14.0.0git
DialectConversion.h
Go to the documentation of this file.
1 //===- DialectConversion.h - MLIR dialect conversion pass -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file declares a generic pass for converting between MLIR dialects.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_TRANSFORMS_DIALECTCONVERSION_H_
14 #define MLIR_TRANSFORMS_DIALECTCONVERSION_H_
15 
17 #include "llvm/ADT/MapVector.h"
18 #include "llvm/ADT/StringMap.h"
19 
20 namespace mlir {
21 
// Forward declarations of classes that the declarations below refer to; their
// definitions are provided by other headers.
class Value;
class Type;
class Operation;
class MLIRContext;
class ConversionPatternRewriter;
class Block;
29 
30 //===----------------------------------------------------------------------===//
31 // Type Conversion
32 //===----------------------------------------------------------------------===//
33 
34 /// Type conversion class. Specific conversions and materializations can be
35 /// registered using addConversion and add*Materialization, respectively.
// NOTE(review): this listing skips source line 36 (numbering jumps 35 -> 37);
// presumably it is the `class TypeConverter {` header.
37 public:
38  /// This class provides all of the information necessary to convert a type
39  /// signature.
// NOTE(review): source line 40 is elided here; presumably it opens the nested
// `SignatureConversion` class that the constructor below belongs to.
41  public:
/// Create a conversion for a signature with `numOrigInputs` original inputs.
/// Every original input starts out with no recorded mapping (an empty
/// Optional in `remappedInputs`).
42  SignatureConversion(unsigned numOrigInputs)
43  : remappedInputs(numOrigInputs) {}
44 
45  /// This struct represents a range of new types or a single value that
46  /// remaps an existing signature input.
47  struct InputMapping {
// Presumably the start index and count of the new inputs that replace the
// original one (cf. the private remapInput(origInputNo, newInputNo,
// newInputCount) below) -- TODO confirm.
48  size_t inputNo, size;
// NOTE(review): source line 49 is elided; given the comment above it is
// presumably the replacement value used for single-value remaps.
50  };
51 
52  /// Return the argument types for the new signature.
/// The returned ArrayRef views internal storage (`argTypes`), so it may be
/// invalidated by later calls that append types.
53  ArrayRef<Type> getConvertedTypes() const { return argTypes; }
54 
55  /// Get the input mapping for the given argument.
/// `input` indexes `remappedInputs` directly, so it must be less than the
/// `numOrigInputs` passed to the constructor. Returns an empty Optional if no
/// mapping has been recorded for that input.
56  Optional<InputMapping> getInputMapping(unsigned input) const {
57  return remappedInputs[input];
58  }
59 
60  //===------------------------------------------------------------------===//
61  // Conversion Hooks
62  //===------------------------------------------------------------------===//
63 
64  /// Remap an input of the original signature with a new set of types. The
65  /// new types are appended to the new signature conversion.
66  void addInputs(unsigned origInputNo, ArrayRef<Type> types);
67 
68  /// Append new input types to the signature conversion. This should only be
69  /// used if the new types are not intended to remap an existing input.
70  void addInputs(ArrayRef<Type> types);
71 
72  /// Remap an input of the original signature to another `replacement`
73  /// value. This drops the original argument.
74  void remapInput(unsigned origInputNo, Value replacement);
75 
76  private:
77  /// Remap an input of the original signature with a range of types in the
78  /// new signature.
79  void remapInput(unsigned origInputNo, unsigned newInputNo,
80  unsigned newInputCount = 1);
81 
82  /// The remapping information for each of the original arguments.
83  SmallVector<Optional<InputMapping>, 4> remappedInputs;
84 
85  /// The set of new argument types.
86  SmallVector<Type, 4> argTypes;
87  };
88 
89  /// Register a conversion function. A conversion function must be convertible
90  /// to any of the following forms(where `T` is a class derived from `Type`:
91  /// * Optional<Type>(T)
92  /// - This form represents a 1-1 type conversion. It should return nullptr
93  /// or `llvm::None` to signify failure. If `llvm::None` is returned, the
94  /// converter is allowed to try another conversion function to perform
95  /// the conversion.
96  /// * Optional<LogicalResult>(T, SmallVectorImpl<Type> &)
97  /// - This form represents a 1-N type conversion. It should return
98  /// `failure` or `llvm::None` to signify a failed conversion. If the new
99  /// set of types is empty, the type is removed and any usages of the
100  /// existing value are expected to be removed during conversion. If
101  /// `llvm::None` is returned, the converter is allowed to try another
102  /// conversion function to perform the conversion.
103  /// * Optional<LogicalResult>(T, SmallVectorImpl<Type> &, ArrayRef<Type>)
104  /// - This form represents a 1-N type conversion supporting recursive
105  /// types. The first two arguments and the return value are the same as
106  /// for the regular 1-N form. The third argument is contains is the
107  /// "call stack" of the recursive conversion: it contains the list of
108  /// types currently being converted, with the current type being the
109  /// last one. If it is present more than once in the list, the
110  /// conversion concerns a recursive type.
111  /// Note: When attempting to convert a type, e.g. via 'convertType', the
112  /// mostly recently added conversions will be invoked first.
113  template <typename FnT, typename T = typename llvm::function_traits<
114  std::decay_t<FnT>>::template arg_t<0>>
115  void addConversion(FnT &&callback) {
116  registerConversion(wrapCallback<T>(std::forward<FnT>(callback)));
117  }
118 
119  /// Register a materialization function, which must be convertible to the
120  /// following form:
121  /// `Optional<Value>(OpBuilder &, T, ValueRange, Location)`,
122  /// where `T` is any subclass of `Type`. This function is responsible for
123  /// creating an operation, using the OpBuilder and Location provided, that
124  /// "casts" a range of values into a single value of the given type `T`. It
125  /// must return a Value of the converted type on success, an `llvm::None` if
126  /// it failed but other materialization can be attempted, and `nullptr` on
127  /// unrecoverable failure. It will only be called for (sub)types of `T`.
128  /// Materialization functions must be provided when a type conversion may
129  /// persist after the conversion has finished.
130  ///
131  /// This method registers a materialization that will be called when
132  /// converting an illegal block argument type, to a legal type.
133  template <typename FnT, typename T = typename llvm::function_traits<
134  std::decay_t<FnT>>::template arg_t<1>>
135  void addArgumentMaterialization(FnT &&callback) {
136  argumentMaterializations.emplace_back(
137  wrapMaterialization<T>(std::forward<FnT>(callback)));
138  }
139  /// This method registers a materialization that will be called when
140  /// converting a legal type to an illegal source type. This is used when
141  /// conversions to an illegal type must persist beyond the main conversion.
142  template <typename FnT, typename T = typename llvm::function_traits<
143  std::decay_t<FnT>>::template arg_t<1>>
144  void addSourceMaterialization(FnT &&callback) {
145  sourceMaterializations.emplace_back(
146  wrapMaterialization<T>(std::forward<FnT>(callback)));
147  }
148  /// This method registers a materialization that will be called when
149  /// converting type from an illegal, or source, type to a legal type.
150  template <typename FnT, typename T = typename llvm::function_traits<
151  std::decay_t<FnT>>::template arg_t<1>>
152  void addTargetMaterialization(FnT &&callback) {
153  targetMaterializations.emplace_back(
154  wrapMaterialization<T>(std::forward<FnT>(callback)));
155  }
156 
157  /// Convert the given type. This function should return failure if no valid
158  /// conversion exists, success otherwise. If the new set of types is empty,
159  /// the type is removed and any usages of the existing value are expected to
160  /// be removed during conversion.
// NOTE(review): source line 161 is elided, so the declaration this comment
// documents is not shown; presumably it is the 1-N
// `convertType(Type, SmallVectorImpl<Type> &)` overload -- confirm.
162 
163  /// This hook simplifies defining 1-1 type conversions. This function returns
164  /// the type to convert to on success, and a null type on failure.
165  Type convertType(Type t);
166 
167  /// Convert the given set of types, filling 'results' as necessary. This
168  /// returns failure if the conversion of any of the types fails, success
169  /// otherwise.
// NOTE(review): source line 170 (the declaration documented above) is elided
// from this listing.
171 
172  /// Return true if the given type is legal for this type converter, i.e. the
173  /// type converts to itself.
174  bool isLegal(Type type);
175  /// Return true if all of the given types are legal for this type converter.
/// Each element is checked with isLegal(Type) through llvm::all_of, so an
/// empty range is reported as legal.
176  template <typename RangeT>
// NOTE(review): source lines 177-178 are elided; they presumably hold the
// enable_if-style constraint whose tail is the `bool>` below.
179  bool>
180  isLegal(RangeT &&range) {
181  return llvm::all_of(range, [this](Type type) { return isLegal(type); });
182  }
183  /// Return true if the given operation has legal operand and result types.
184  bool isLegal(Operation *op);
185 
186  /// Return true if the types of block arguments within the region are legal.
187  bool isLegal(Region *region);
188 
189  /// Return true if the inputs and outputs of the given function type are
190  /// legal.
191  bool isSignatureLegal(FunctionType ty);
192 
193  /// This method allows for converting a specific argument of a signature. It
194  /// takes as inputs the original argument input number and its type.
195  /// On success, it populates 'result' with any new mappings.
// NOTE(review): source line 196 (the start of this declaration) and line 198
// (the start of the next one, presumably a multi-argument variant taking an
// `origInputOffset`) are elided from this listing.
197  SignatureConversion &result);
199  SignatureConversion &result,
200  unsigned origInputOffset = 0);
201 
202  /// This function converts the type signature of the given block, by invoking
203  /// 'convertSignatureArg' for each argument. This function should return a
204  /// valid conversion for the signature on success, None otherwise.
// NOTE(review): source line 205 (the declaration documented above) is elided
// from this listing.
206 
207  /// Materialize a conversion from a set of types into one result type by
208  /// generating a cast sequence of some kind. See the respective
209  /// `add*Materialization` for more information on the context for these
210  /// methods.
// NOTE(review): source lines 211, 216 and 221 (the leading lines of the three
// wrappers below, including their names and first parameters) are elided.
// Uses the callbacks registered via addArgumentMaterialization.
212  Type resultType, ValueRange inputs) {
213  return materializeConversion(argumentMaterializations, builder, loc,
214  resultType, inputs);
215  }
// Uses the callbacks registered via addSourceMaterialization.
217  Type resultType, ValueRange inputs) {
218  return materializeConversion(sourceMaterializations, builder, loc,
219  resultType, inputs);
220  }
// Uses the callbacks registered via addTargetMaterialization.
222  Type resultType, ValueRange inputs) {
223  return materializeConversion(targetMaterializations, builder, loc,
224  resultType, inputs);
225  }
226 
227 private:
228  /// The signature of the callback used to convert a type. If the new set of
229  /// types is empty, the type is removed and any usages of the existing value
230  /// are expected to be removed during conversion.
// NOTE(review): source line 232 (the remainder of this std::function
// signature) is elided from this listing.
231  using ConversionCallbackFn = std::function<Optional<LogicalResult>(
233 
234  /// The signature of the callback used to materialize a conversion.
235  using MaterializationCallbackFn =
236  std::function<Optional<Value>(OpBuilder &, Type, ValueRange, Location)>;
237 
238  /// Attempt to materialize a conversion using one of the provided
239  /// materialization functions.
// NOTE(review): source line 241 is elided; per the callers that pass a
// *Materializations list first, it presumably declares that list parameter.
240  Value materializeConversion(
242  OpBuilder &builder, Location loc, Type resultType, ValueRange inputs);
243 
244  /// Generate a wrapper for the given callback. This allows for accepting
245  /// different callback forms that all compose into a single version.
246  /// With callback of form: `Optional<Type>(T)`
///
/// The 1-1 callback is adapted to the call-stack-aware 1-N form: a non-null
/// result Type is appended to `results` and reported as success, a null Type
/// is reported as failure, and an empty Optional is propagated unchanged so
/// other conversions may be tried.
247  template <typename T, typename FnT>
// NOTE(review): source line 248 (this overload's return type / constraint) is
// elided from this listing.
249  wrapCallback(FnT &&callback) {
250  return wrapCallback<T>(
251  [callback = std::forward<FnT>(callback)](
252  T type, SmallVectorImpl<Type> &results, ArrayRef<Type>) {
253  if (Optional<Type> resultOpt = callback(type)) {
254  bool wasSuccess = static_cast<bool>(resultOpt.getValue());
255  if (wasSuccess)
256  results.push_back(resultOpt.getValue());
257  return Optional<LogicalResult>(success(wasSuccess));
258  }
259  return Optional<LogicalResult>();
260  });
261  }
262  /// With callback of form: `Optional<LogicalResult>(T, SmallVectorImpl<Type>
263  /// &)`
264  template <typename T, typename FnT>
265  std::enable_if_t<llvm::is_invocable<FnT, T, SmallVectorImpl<Type> &>::value,
266  ConversionCallbackFn>
267  wrapCallback(FnT &&callback) {
268  return wrapCallback<T>(
269  [callback = std::forward<FnT>(callback)](
270  T type, SmallVectorImpl<Type> &results, ArrayRef<Type>) {
271  return callback(type, results);
272  });
273  }
274  /// With callback of form: `Optional<LogicalResult>(T, SmallVectorImpl<Type>
275  /// &, ArrayRef<Type>)`.
///
/// This is the final, type-erasing form: the returned ConversionCallbackFn
/// accepts any Type, returns llvm::None (so other conversions can be tried)
/// when the type is not a `T`, and otherwise calls the user callback with the
/// downcast type, the results vector and the conversion call stack.
276  template <typename T, typename FnT>
// NOTE(review): source lines 278 (the rest of the is_invocable condition) and
// 283 (presumably the `callStack` parameter and lambda return type) are
// elided from this listing.
277  std::enable_if_t<llvm::is_invocable<FnT, T, SmallVectorImpl<Type> &,
279  ConversionCallbackFn>
280  wrapCallback(FnT &&callback) {
281  return [callback = std::forward<FnT>(callback)](
282  Type type, SmallVectorImpl<Type> &results,
284  T derivedType = type.dyn_cast<T>();
285  if (!derivedType)
286  return llvm::None;
287  return callback(derivedType, results, callStack);
288  };
289  }
290 
291  /// Register a type conversion.
292  void registerConversion(ConversionCallbackFn callback) {
293  conversions.emplace_back(std::move(callback));
294  cachedDirectConversions.clear();
295  cachedMultiConversions.clear();
296  }
297 
298  /// Generate a wrapper for the given materialization callback. The callback
299  /// may take any subclass of `Type` and the wrapper will check for the target
300  /// type to be of the expected class before calling the callback.
301  template <typename T, typename FnT>
302  MaterializationCallbackFn wrapMaterialization(FnT &&callback) {
303  return [callback = std::forward<FnT>(callback)](
304  OpBuilder &builder, Type resultType, ValueRange inputs,
305  Location loc) -> Optional<Value> {
306  if (T derivedType = resultType.dyn_cast<T>())
307  return callback(builder, derivedType, inputs, loc);
308  return llvm::None;
309  };
310  }
311 
312  /// The set of registered conversion functions.
// NOTE(review): source line 313 (the declaration documented above) is elided;
// registerConversion pushes into a member named `conversions`, so presumably
// that is what it declares.
314 
315  /// The list of registered materialization functions.
316  SmallVector<MaterializationCallbackFn, 2> argumentMaterializations;
317  SmallVector<MaterializationCallbackFn, 2> sourceMaterializations;
318  SmallVector<MaterializationCallbackFn, 2> targetMaterializations;
319 
320  /// A set of cached conversions to avoid recomputing in the common case.
321  /// Direct 1-1 conversions are the most common, so this cache stores the
322  /// successful 1-1 conversions as well as all failed conversions.
/// Cleared by registerConversion whenever a new conversion is added.
323  DenseMap<Type, Type> cachedDirectConversions;
324  /// This cache stores the successful 1->N conversions, where N != 1.
/// Cleared by registerConversion whenever a new conversion is added.
325  DenseMap<Type, SmallVector<Type, 2>> cachedMultiConversions;
326 
327  /// Stores the types that are being converted in the case when convertType
328  /// is being called recursively to convert nested types.
329  SmallVector<Type, 2> conversionCallStack;
330 };
331 
332 //===----------------------------------------------------------------------===//
333 // Conversion Patterns
334 //===----------------------------------------------------------------------===//
335 
336 /// Base class for the conversion patterns. This pattern class enables type
337 /// conversions, and other uses specific to the conversion framework. As such,
338 /// patterns of this type can only be used with the 'apply*' methods below.
// NOTE(review): source line 339 (presumably the `class ConversionPattern`
// header) is elided from this listing.
340 public:
341  /// Hook for derived classes to implement rewriting. `op` is the (first)
342  /// operation matched by the pattern, `operands` is a list of the rewritten
343  /// operand values that are passed to `op`, `rewriter` can be used to emit the
344  /// new operations. This function should not fail. If some specific cases of
345  /// the operation are not supported, these cases should not be matched.
/// The default body is llvm_unreachable: a derived pattern must override this
/// hook or matchAndRewrite.
346  virtual void rewrite(Operation *op, ArrayRef<Value> operands,
347  ConversionPatternRewriter &rewriter) const {
348  llvm_unreachable("unimplemented rewrite");
349  }
350 
351  /// Hook for derived classes to implement combined matching and rewriting.
/// The default runs `match`, and only if it succeeds calls `rewrite`.
// NOTE(review): source line 353 (this overload's name and leading parameters)
// is elided from this listing.
352  virtual LogicalResult
354  ConversionPatternRewriter &rewriter) const {
355  if (failed(match(op)))
356  return failure();
357  rewrite(op, operands, rewriter);
358  return success();
359  }
360 
361  /// Attempt to match and rewrite the IR root at the specified operation.
362  LogicalResult matchAndRewrite(Operation *op,
363  PatternRewriter &rewriter) const final;
364 
365  /// Return the type converter held by this pattern, or nullptr if the pattern
366  /// does not require type conversion.
367  TypeConverter *getTypeConverter() const { return typeConverter; }
368 
/// Return the held type converter static_cast to `ConverterTy`. No dynamic
/// type check is performed here, so the caller must know the converter's real
/// type.
// NOTE(review): source lines 370 and 372 (the return-type constraint and the
// function name) are elided; the constraint presumably restricts ConverterTy
// to TypeConverter subclasses -- confirm.
369  template <typename ConverterTy>
371  ConverterTy *>
373  return static_cast<ConverterTy *>(typeConverter);
374  }
375 
376 protected:
377  /// See `RewritePattern::RewritePattern` for information on the other
378  /// available constructors.
379  using RewritePattern::RewritePattern;
380  /// Construct a conversion pattern with the given converter, and forward the
381  /// remaining arguments to RewritePattern.
/// Only the converter's address is stored, so `typeConverter` must outlive
/// this pattern.
382  template <typename... Args>
383  ConversionPattern(TypeConverter &typeConverter, Args &&...args)
384  : RewritePattern(std::forward<Args>(args)...),
385  typeConverter(&typeConverter) {}
386 
387 protected:
388  /// An optional type converter for use by this pattern.
/// Remains nullptr when the pattern was built through an inherited
/// RewritePattern constructor.
389  TypeConverter *typeConverter = nullptr;
390 
391 private:
// NOTE(review): source line 392 (the private declaration(s) of this class) is
// elided from this listing.
393 };
394 
395 /// OpConversionPattern is a wrapper around ConversionPattern that allows for
396 /// matching and rewriting against an instance of a derived operation class as
397 /// opposed to a raw Operation.
398 template <typename SourceOp>
// NOTE(review): source line 399 (presumably the `class OpConversionPattern`
// header) is elided from this listing.
400 public:
401  using OpAdaptor = typename SourceOp::Adaptor;
402 
// NOTE(review): source lines 403 and 405 (the leading lines of the two
// constructors below) are elided. Both register the pattern against
// SourceOp::getOperationName(); the second also stores a TypeConverter.
404  : ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
406  PatternBenefit benefit = 1)
407  : ConversionPattern(typeConverter, SourceOp::getOperationName(), benefit,
408  context) {}
409 
410  /// Wrappers around the ConversionPattern methods that pass the derived op
411  /// type.
/// Each wrapper downcasts `op` with cast<SourceOp>. The rewrite wrappers build
/// an OpAdaptor from the remapped `operands` and the op's attribute
/// dictionary.
412  LogicalResult match(Operation *op) const final {
413  return match(cast<SourceOp>(op));
414  }
415  void rewrite(Operation *op, ArrayRef<Value> operands,
416  ConversionPatternRewriter &rewriter) const final {
417  rewrite(cast<SourceOp>(op), OpAdaptor(operands, op->getAttrDictionary()),
418  rewriter);
419  }
// NOTE(review): source lines 420-421 (the head of the matchAndRewrite wrapper
// below) are elided from this listing.
422  ConversionPatternRewriter &rewriter) const final {
423  return matchAndRewrite(cast<SourceOp>(op),
424  OpAdaptor(operands, op->getAttrDictionary()),
425  rewriter);
426  }
427 
428  /// Rewrite and Match methods that operate on the SourceOp type. These must be
429  /// overridden by the derived pattern class.
430  virtual LogicalResult match(SourceOp op) const {
431  llvm_unreachable("must override match or matchAndRewrite");
432  }
433  virtual void rewrite(SourceOp op, OpAdaptor adaptor,
434  ConversionPatternRewriter &rewriter) const {
435  llvm_unreachable("must override matchAndRewrite or a rewrite method");
436  }
437  virtual LogicalResult
438  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
439  ConversionPatternRewriter &rewriter) const {
440  if (failed(match(op)))
441  return failure();
442  rewrite(op, adaptor, rewriter);
443  return success();
444  }
445 
446 private:
// NOTE(review): source line 447 (the private declaration(s) of
// OpConversionPattern) is elided from this listing.
448 };
449 
450 /// OpInterfaceConversionPattern is a wrapper around ConversionPattern that
451 /// allows for matching and rewriting against an instance of an OpInterface
452 /// class as opposed to a raw Operation.
453 template <typename SourceOp>
// NOTE(review): source line 454 (presumably the class header) is elided.
455 public:
// NOTE(review): source lines 456-457, 459 and 461 (the heads of the two
// constructors) are elided. Both key the pattern on
// SourceOp::getInterfaceID().
458  SourceOp::getInterfaceID(), benefit, context) {}
460  MLIRContext *context, PatternBenefit benefit = 1)
462  SourceOp::getInterfaceID(), benefit, context) {}
463 
464  /// Wrappers around the ConversionPattern methods that pass the derived op
465  /// type.
/// Unlike OpConversionPattern, the remapped operands are forwarded as a raw
/// ArrayRef<Value> rather than through an adaptor.
466  void rewrite(Operation *op, ArrayRef<Value> operands,
467  ConversionPatternRewriter &rewriter) const final {
468  rewrite(cast<SourceOp>(op), operands, rewriter);
469  }
// NOTE(review): source lines 470-471 (the head of the matchAndRewrite wrapper
// below) are elided from this listing.
472  ConversionPatternRewriter &rewriter) const final {
473  return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
474  }
475 
476  /// Rewrite and Match methods that operate on the SourceOp type. These must be
477  /// overridden by the derived pattern class.
478  virtual void rewrite(SourceOp op, ArrayRef<Value> operands,
479  ConversionPatternRewriter &rewriter) const {
480  llvm_unreachable("must override matchAndRewrite or a rewrite method");
481  }
482  virtual LogicalResult
483  matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
484  ConversionPatternRewriter &rewriter) const {
485  if (failed(match(op)))
486  return failure();
487  rewrite(op, operands, rewriter);
488  return success();
489  }
490 
491 private:
// NOTE(review): source line 492 (the private declaration(s) of
// OpInterfaceConversionPattern) is elided from this listing.
493 };
494 
495 /// Add a pattern to the given pattern list to convert the signature of a
496 /// FunctionOpInterface op with the given type converter. This only supports
497 /// ops which use FunctionType to represent their type.
// NOTE(review): source line 498 (the return type and function name of this
// declaration) is elided from this listing.
499  StringRef functionLikeOpName, RewritePatternSet &patterns,
500  TypeConverter &converter);
501 
/// Convenience overload: uses FuncOpT::getOperationName() as the op name.
// NOTE(review): source line 503 (this template's return type and name) is
// elided from this listing.
502 template <typename FuncOpT>
504  RewritePatternSet &patterns, TypeConverter &converter) {
505  populateFunctionOpInterfaceTypeConversionPattern(FuncOpT::getOperationName(),
506  patterns, converter);
507 }
508 
509 //===----------------------------------------------------------------------===//
510 // Conversion PatternRewriter
511 //===----------------------------------------------------------------------===//
512 
// Opaque implementation type for ConversionPatternRewriter, which owns it
// through a std::unique_ptr; only a forward declaration is exposed here.
namespace detail {
struct ConversionPatternRewriterImpl;
} // end namespace detail
516 
517 /// This class implements a pattern rewriter for use with ConversionPatterns. It
518 /// extends the base PatternRewriter and provides special conversion-specific
519 /// hooks.
// NOTE(review): source line 520 (presumably the class header) is elided.
521 public:
522  explicit ConversionPatternRewriter(MLIRContext *ctx);
523  ~ConversionPatternRewriter() override;
524 
525  /// Apply a signature conversion to the entry block of the given region. This
526  /// replaces the entry block with a new block containing the updated
527  /// signature. The new entry block to the region is returned for convenience.
528  ///
529  /// If provided, `converter` will be used for any materializations.
// NOTE(review): source line 532 (the middle parameter, presumably the
// signature conversion to apply) is elided from this listing.
530  Block *
531  applySignatureConversion(Region *region,
533  TypeConverter *converter = nullptr);
534 
535  /// Convert the types of block arguments within the given region. This
536  /// replaces each block with a new block containing the updated signature. The
537  /// entry block may have a special conversion if `entryConversion` is
538  /// provided. On success, the new entry block to the region is returned for
539  /// convenience. Otherwise, failure is returned.
540  FailureOr<Block *> convertRegionTypes(
541  Region *region, TypeConverter &converter,
542  TypeConverter::SignatureConversion *entryConversion = nullptr);
543 
544  /// Convert the types of block arguments within the given region except for
545  /// the entry block. This replaces each non-entry block with a new block
546  /// containing the updated signature.
547  ///
548  /// If special conversion behavior is needed for the non-entry blocks (for
549  /// example, we need to convert only a subset of a block's arguments), such
550  /// behavior can be specified in blockConversions.
// NOTE(review): source line 553 (the final parameter, presumably
// `blockConversions`) is elided from this listing.
551  LogicalResult convertNonEntryRegionTypes(
552  Region *region, TypeConverter &converter,
554 
555  /// Replace all the uses of the block argument `from` with value `to`.
556  void replaceUsesOfBlockArgument(BlockArgument from, Value to);
557 
558  /// Return the converted value of 'key' with a type defined by the type
559  /// converter of the currently executing pattern. Return nullptr in the case
560  /// of failure, the remapped value otherwise.
561  Value getRemappedValue(Value key);
562 
563  /// Return the converted values that replace 'keys' with types defined by the
564  /// type converter of the currently executing pattern. Returns failure if the
565  /// remap failed, success otherwise.
566  LogicalResult getRemappedValues(ValueRange keys,
567  SmallVectorImpl<Value> &results);
568 
569  //===--------------------------------------------------------------------===//
570  // PatternRewriter Hooks
571  //===--------------------------------------------------------------------===//
572 
573  /// PatternRewriter hook for replacing the results of an operation when the
574  /// given functor returns true.
575  void replaceOpWithIf(
576  Operation *op, ValueRange newValues, bool *allUsesReplaced,
577  llvm::unique_function<bool(OpOperand &) const> functor) override;
578 
579  /// PatternRewriter hook for replacing the results of an operation.
580  void replaceOp(Operation *op, ValueRange newValues) override;
// NOTE(review): source line 581 is elided from this listing.
582 
583  /// PatternRewriter hook for erasing a dead operation. The uses of this
584  /// operation *must* be made dead by the end of the conversion process,
585  /// otherwise an assert will be issued.
586  void eraseOp(Operation *op) override;
587 
588  /// PatternRewriter hook for erasing all operations in a block. This is not
589  /// yet implemented for dialect conversion.
590  void eraseBlock(Block *block) override;
591 
592  /// PatternRewriter hook for creating a new block.
593  void notifyBlockCreated(Block *block) override;
594 
595  /// PatternRewriter hook for splitting a block into two parts.
596  Block *splitBlock(Block *block, Block::iterator before) override;
597 
598  /// PatternRewriter hook for merging a block into another.
599  void mergeBlocks(Block *source, Block *dest, ValueRange argValues) override;
600 
601  /// PatternRewriter hook for moving blocks out of a region.
602  void inlineRegionBefore(Region &region, Region &parent,
603  Region::iterator before) override;
// NOTE(review): source line 604 is elided from this listing.
605 
606  /// PatternRewriter hook for cloning blocks of one region into another. The
607  /// given region to clone *must* not have been modified as part of conversion
608  /// yet, i.e. it must be within an operation that is either in the process of
609  /// conversion, or has not yet been converted.
610  void cloneRegionBefore(Region &region, Region &parent,
611  Region::iterator before,
612  BlockAndValueMapping &mapping) override;
// NOTE(review): source line 613 is elided from this listing.
614 
615  /// PatternRewriter hook for inserting a new operation.
616  void notifyOperationInserted(Operation *op) override;
617 
618  /// PatternRewriter hook for updating the root operation in-place.
619  /// Note: These methods only track updates to the top-level operation itself,
620  /// and not nested regions. Updates to regions will still require notification
621  /// through other more specific hooks above.
622  void startRootUpdate(Operation *op) override;
623 
624  /// PatternRewriter hook for updating the root operation in-place.
625  void finalizeRootUpdate(Operation *op) override;
626 
627  /// PatternRewriter hook for updating the root operation in-place.
628  void cancelRootUpdate(Operation *op) override;
629 
630  /// PatternRewriter hook for notifying match failure reasons.
// NOTE(review): source lines 631 (this declaration's return type) and 634 are
// elided from this listing.
632  notifyMatchFailure(Operation *op,
633  function_ref<void(Diagnostic &)> reasonCallback) override;
635 
636  /// Return a reference to the internal implementation.
// NOTE(review): source line 637 (the accessor documented above) is elided.
638 
639 private:
// Owning pointer to the out-of-line implementation state.
640  std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
641 };
642 
643 //===----------------------------------------------------------------------===//
644 // ConversionTarget
645 //===----------------------------------------------------------------------===//
646 
647 /// This class describes a specific conversion target.
// NOTE(review): source line 648 (presumably the class header) is elided.
649 public:
650  /// This enumeration corresponds to the specific action to take when
651  /// considering an operation legal for this conversion target.
652  enum class LegalizationAction {
653  /// The target supports this operation.
654  Legal,
655 
656  /// This operation has dynamic legalization constraints that must be checked
657  /// by the target.
658  Dynamic,
659 
660  /// The target explicitly does not support this operation.
661  Illegal,
662  };
663 
664  /// A structure containing additional information describing a specific legal
665  /// operation instance.
666  struct LegalOpDetails {
667  /// A flag that indicates if this operation is 'recursively' legal. This
668  /// means that if an operation is legal, either statically or dynamically,
669  /// all of the operations nested within are also considered legal.
670  bool isRecursivelyLegal = false;
671  };
672 
673  /// The signature of the callback used to determine if an operation is
674  /// dynamically legal on the target.
/// Returning None leaves the legality unknown (see isLegal/isIllegal below).
675  using DynamicLegalityCallbackFn = std::function<Optional<bool>(Operation *)>;
676 
// NOTE(review): this single-argument constructor is not `explicit`, so an
// MLIRContext & converts implicitly to ConversionTarget; consider marking it
// explicit if no caller relies on that conversion.
677  ConversionTarget(MLIRContext &ctx) : ctx(ctx) {}
678  virtual ~ConversionTarget() = default;
679 
680  //===--------------------------------------------------------------------===//
681  // Legality Registration
682  //===--------------------------------------------------------------------===//
683 
684  /// Register a legality action for the given operation.
685  void setOpAction(OperationName op, LegalizationAction action);
/// Typed convenience overload: resolves OpT's name in this target's context.
// NOTE(review): source line 687 (this template's signature) is elided.
686  template <typename OpT>
688  setOpAction(OperationName(OpT::getOperationName(), &ctx), action);
689  }
690 
691  /// Register the given operations as legal.
692  template <typename OpT>
693  void addLegalOp() {
694  setOpAction<OpT>(LegalizationAction::Legal);
695  }
696  template <typename OpT, typename OpT2, typename... OpTs>
697  void addLegalOp() {
698  addLegalOp<OpT>();
699  addLegalOp<OpT2, OpTs...>();
700  }
701 
702  /// Register the given operation as dynamically legal and set the dynamic
703  /// legalization callback to the one provided.
// NOTE(review): source lines 705, 711 and 717 (the signatures of the three
// overloads / the enable_if condition) are elided from this listing.
704  template <typename OpT>
706  OperationName opName(OpT::getOperationName(), &ctx);
707  setOpAction(opName, LegalizationAction::Dynamic);
708  setLegalityCallback(opName, callback);
709  }
710  template <typename OpT, typename OpT2, typename... OpTs>
712  addDynamicallyLegalOp<OpT>(callback);
713  addDynamicallyLegalOp<OpT2, OpTs...>(callback);
714  }
/// Overload taking a callback on the typed op: it is adapted to the
/// Operation * form by casting with cast<OpT>.
// NOTE(review): `[=]` copies `callback` into the lambda even when it was
// passed as an rvalue; an init-capture move would avoid that copy.
715  template <typename OpT, class Callable>
716  typename std::enable_if<
718  addDynamicallyLegalOp(Callable &&callback) {
719  addDynamicallyLegalOp<OpT>(
720  [=](Operation *op) { return callback(cast<OpT>(op)); });
721  }
722 
723  /// Register the given operation as illegal, i.e. this operation is known to
724  /// not be supported by this target.
725  template <typename OpT>
726  void addIllegalOp() {
727  setOpAction<OpT>(LegalizationAction::Illegal);
728  }
729  template <typename OpT, typename OpT2, typename... OpTs>
730  void addIllegalOp() {
731  addIllegalOp<OpT>();
732  addIllegalOp<OpT2, OpTs...>();
733  }
734 
735  /// Mark an operation, that *must* have either been set as `Legal` or
736  /// `Dynamic` (e.g. via addLegalOp or addDynamicallyLegalOp), as being
736  /// recursively legal. This means that in
737  /// addition to the operation itself, all of the operations nested within are
738  /// also considered legal. An optional dynamic legality callback may be
739  /// provided to mark subsets of legal instances as recursively legal.
// NOTE(review): source lines 741, 746 and 752 (the signatures of the three
// overloads / the enable_if condition) are elided from this listing.
740  template <typename OpT>
742  OperationName opName(OpT::getOperationName(), &ctx);
743  markOpRecursivelyLegal(opName, callback);
744  }
745  template <typename OpT, typename OpT2, typename... OpTs>
747  markOpRecursivelyLegal<OpT>(callback);
748  markOpRecursivelyLegal<OpT2, OpTs...>(callback);
749  }
/// Overload taking a callback on the typed op: it is adapted to the
/// Operation * form by casting with cast<OpT>.
750  template <typename OpT, class Callable>
751  typename std::enable_if<
753  markOpRecursivelyLegal(Callable &&callback) {
754  markOpRecursivelyLegal<OpT>(
755  [=](Operation *op) { return callback(cast<OpT>(op)); });
756  }
757 
758  /// Register a legality action for the given dialects.
759  void setDialectAction(ArrayRef<StringRef> dialectNames,
760  LegalizationAction action);
761 
762  /// Register the operations of the given dialects as legal.
/// Dialects are identified by namespace string here; the overload below takes
/// dialect classes instead.
763  template <typename... Names>
764  void addLegalDialect(StringRef name, Names... names) {
765  SmallVector<StringRef, 2> dialectNames({name, names...});
766  setDialectAction(dialectNames, LegalizationAction::Legal);
767  }
// NOTE(review): source line 769 (this overload's signature) is elided.
768  template <typename... Args>
770  SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
771  setDialectAction(dialectNames, LegalizationAction::Legal);
772  }
773 
774  /// Register the operations of the given dialects as dynamically legal, i.e.
775  /// requiring custom handling by the callback.
// NOTE(review): source lines 777 and 784 (the leading signature lines of the
// two overloads) are elided from this listing.
776  template <typename... Names>
778  StringRef name, Names... names) {
779  SmallVector<StringRef, 2> dialectNames({name, names...});
780  setDialectAction(dialectNames, LegalizationAction::Dynamic);
781  setLegalityCallback(dialectNames, callback);
782  }
/// Class-based overload: forwards to the string-based one using each dialect's
/// namespace.
783  template <typename... Args>
785  addDynamicallyLegalDialect(std::move(callback),
786  Args::getDialectNamespace()...);
787  }
788 
789  /// Register unknown operations as dynamically legal. For operations (and
790  /// dialects) that do not have a set legalization action, treat them as
791  /// dynamically legal and invoke the given callback.
// NOTE(review): source line 792 (this method's signature) is elided.
793  setLegalityCallback(fn);
794  }
795 
796  /// Register the operations of the given dialects as illegal, i.e.
797  /// operations of these dialects are not supported by the target.
798  template <typename... Names>
799  void addIllegalDialect(StringRef name, Names... names) {
800  SmallVector<StringRef, 2> dialectNames({name, names...}); // Gather the first name and the variadic rest into one list.
801  setDialectAction(dialectNames, LegalizationAction::Illegal);
802  }
803  template <typename... Args>
805  SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
806  setDialectAction(dialectNames, LegalizationAction::Illegal);
807  }
808 
809  //===--------------------------------------------------------------------===//
810  // Legality Querying
811  //===--------------------------------------------------------------------===//
812 
813  /// Get the legality action for the given operation.
814  Optional<LegalizationAction> getOpAction(OperationName op) const;
815 
816  /// If the given operation instance is legal on this target, a structure
817  /// containing legality information is returned. If the operation is not
818  /// legal, None is returned. Also returns None if the operation's legality
819  /// wasn't registered by the user or the dynamic legality callback returned None.
820  ///
821  /// Note: Legality is actually a 4-state: Legal(recursive=true),
822  /// Legal(recursive=false), Illegal or Unknown, where Unknown is treated
823  /// either as Legal or Illegal depending on context.
825 
826  /// Returns true if the operation instance is illegal on this target. Returns
827  /// false if the operation is legal, if its legality wasn't registered by the
828  /// user, or if the dynamic legality callback returned None.
829  bool isIllegal(Operation *op) const;
830 
831 private:
832  /// Set the dynamic legality callback for the given operation.
833  void setLegalityCallback(OperationName name,
834  const DynamicLegalityCallbackFn &callback);
835 
836  /// Set the dynamic legality callback for the given dialects.
837  void setLegalityCallback(ArrayRef<StringRef> dialects,
838  const DynamicLegalityCallbackFn &callback);
839 
840  /// Set the dynamic legality callback for the unknown ops.
841  void setLegalityCallback(const DynamicLegalityCallbackFn &callback);
842 
843  /// Set the recursive legality callback for the given operation and mark the
844  /// operation as recursively legal.
845  void markOpRecursivelyLegal(OperationName name,
846  const DynamicLegalityCallbackFn &callback);
847 
848  /// The set of information that configures the legalization of an operation.
849  struct LegalizationInfo {
850  /// The legality action this operation was given (Illegal unless set otherwise).
851  LegalizationAction action = LegalizationAction::Illegal;
852 
853  /// Whether some legal instances of this operation may also be recursively legal.
854  bool isRecursivelyLegal = false;
855 
856  /// The legality callback, used if this operation is dynamically legal.
857  DynamicLegalityCallbackFn legalityFn;
858  };
859 
860  /// Get the legalization information for the given operation.
861  Optional<LegalizationInfo> getOpInfo(OperationName op) const;
862 
863  /// A deterministic mapping from operation name to its respective legality
864  /// information.
865  llvm::MapVector<OperationName, LegalizationInfo> legalOperations;
866 
867  /// A set of legality callbacks for given operation names that are used to
868  /// check if an operation instance is recursively legal.
870 
871  /// A deterministic mapping of dialect name to the specific legality action to
872  /// take.
873  llvm::StringMap<LegalizationAction> legalDialects;
874 
875  /// A mapping from dialect name to its dynamic legality callback.
876  llvm::StringMap<DynamicLegalityCallbackFn> dialectLegalityFns;
877 
878  /// An optional legality callback for unknown operations.
879  DynamicLegalityCallbackFn unknownLegalityFn;
880 
881  /// The current context this target applies to.
882  MLIRContext &ctx;
883 };
884 
885 //===----------------------------------------------------------------------===//
886 // Op Conversion Entry Points
887 //===----------------------------------------------------------------------===//
888 
889 /// Below we define several entry points for operation conversion. It is
890 /// important to note that the patterns provided to the conversion framework may
891 /// have additional constraints. See the `PatternRewriter Hooks` section of the
892 /// ConversionPatternRewriter, to see what additional constraints are imposed on
893 /// the use of the PatternRewriter.
894 
895 /// Apply a partial conversion on the given operations and all nested
896 /// operations. This method converts as many operations to the target as
897 /// possible, ignoring operations that failed to legalize. This method only
898 /// returns failure if there are ops explicitly marked as illegal. If an
899 /// `unconvertedOps` set is provided, all operations that are found not to be
900 /// legalizable to the given `target` are placed within that set. (Note that if
901 /// there is an op explicitly marked as illegal, the conversion terminates and
902 /// the `unconvertedOps` set will not necessarily be complete.)
905  const FrozenRewritePatternSet &patterns,
906  DenseSet<Operation *> *unconvertedOps = nullptr);
909  const FrozenRewritePatternSet &patterns,
910  DenseSet<Operation *> *unconvertedOps = nullptr);
911 
912 /// Apply a complete conversion on the given operations, and all nested
913 /// operations. This method returns failure if the conversion of any operation
914 /// fails, or if there are unreachable blocks in any of the regions nested
915 /// within 'ops'.
917  ConversionTarget &target,
918  const FrozenRewritePatternSet &patterns);
920  const FrozenRewritePatternSet &patterns);
921 
922 /// Apply an analysis conversion on the given operations, and all nested
923 /// operations. This method analyzes which operations would be successfully
924 /// converted to the target if a conversion was applied. All operations that
925 /// were found to be legalizable to the given 'target' are placed within the
926 /// provided 'convertedOps' set; note that no actual rewrites are applied to the
927 /// operations on success and only pre-existing operations are added to the set.
928 /// This method only returns failure if there are unreachable blocks in any of
929 /// the regions nested within 'ops'. There's an additional argument
930 /// `notifyCallback` which is used for collecting match failure diagnostics
931 /// generated during the conversion. Diagnostics reported to this callback
932 /// may only be available in debug mode.
935  const FrozenRewritePatternSet &patterns,
936  DenseSet<Operation *> &convertedOps,
937  function_ref<void(Diagnostic &)> notifyCallback = nullptr);
939  Operation *op, ConversionTarget &target,
940  const FrozenRewritePatternSet &patterns,
941  DenseSet<Operation *> &convertedOps,
942  function_ref<void(Diagnostic &)> notifyCallback = nullptr);
943 } // namespace mlir
944 
945 #endif // MLIR_TRANSFORMS_DIALECTCONVERSION_H_
Include the generated interface declarations.
void addDynamicallyLegalOp(const DynamicLegalityCallbackFn &callback)
This class contains a list of basic blocks and a link to the parent operation it is attached to...
Definition: Region.h:26
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results)
Convert the given set of types, filling 'results' as necessary.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
void markUnknownOpDynamicallyLegal(const DynamicLegalityCallbackFn &fn)
Register unknown operations as dynamically legal.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:881
This class acts as a special tag that makes the desire to match any operation that implements a given...
Definition: PatternMatch.h:162
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
LogicalResult applyPartialConversion(ArrayRef< Operation *> ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation *> *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
TypeConverter * getTypeConverter() const
Return the type converter held by this pattern, or nullptr if the pattern does not require type conve...
Value materializeArgumentConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs)
Materialize a conversion from a set of types into one result type by generating a cast sequence of so...
This class represents a frozen set of patterns that can be processed by a pattern applicator...
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results)
Convert the given type.
Block represents an ordered list of Operations.
Definition: Block.h:29
Base class for the conversion patterns.
void addDynamicallyLegalDialect(DynamicLegalityCallbackFn callback)
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
BlockListType::iterator iterator
Definition: Region.h:52
typename async::YieldOp ::Adaptor OpAdaptor
void addDynamicallyLegalOp(const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the on...
void addDynamicallyLegalDialect(const DynamicLegalityCallbackFn &callback, StringRef name, Names... names)
Register the operations of the given dialects as dynamically legal, i.e.
virtual void rewrite(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const
ConversionPattern(TypeConverter &typeConverter, Args &&...args)
Construct a conversion pattern with the given converter, and forward the remaining arguments to Rewri...
void setOpAction(LegalizationAction action)
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting type from an illegal...
This struct represents a range of new types or a single value that remaps an existing signature input...
A structure containing additional information describing a specific legal operation instance...
virtual void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent"...
void rewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Wrappers around the ConversionPattern methods that pass the derived op type.
void addArgumentMaterialization(FnT &&callback)
Register a materialization function, which must be convertible to the following form: Optional<Value>...
virtual LogicalResult matchAndRewrite(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const
static constexpr const bool value
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
ConversionTarget(MLIRContext &ctx)
std::enable_if_t< std::is_base_of< TypeConverter, ConverterTy >::value, ConverterTy * > getTypeConverter() const
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:71
std::function< Optional< bool >(Operation *)> DynamicLegalityCallbackFn
The signature of the callback used to determine if an operation is dynamically legal on the target...
LogicalResult convertSignatureArg(unsigned inputNo, Type type, SignatureConversion &result)
This method allows for converting a specific argument of a signature.
LogicalResult applyFullConversion(ArrayRef< Operation *> ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns)
Apply a complete conversion on the given operations, and all nested operations.
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:244
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs)
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine...
Definition: Diagnostics.h:157
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
This class provides all of the information necessary to convert a type signature. ...
virtual void rewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement rewriting.
OpListType::iterator iterator
Definition: Block.h:131
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
This class provides support for representing a failure result, or a valid value of type T...
Definition: LogicalResult.h:77
U dyn_cast() const
Definition: Types.h:244
LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Hook for derived classes to implement combined matching and rewriting.
LegalizationAction
This enumeration corresponds to the specific action to take when considering an operation legal for t...
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:32
virtual void rewrite(SourceOp op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Rewrite and Match methods that operate on the SourceOp type.
void rewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Hook for derived classes to implement rewriting.
std::enable_if< !llvm::is_invocable< Callable, Operation * >::value >::type markOpRecursivelyLegal(Callable &&callback)
void addSourceMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a legal type to an illega...
virtual LogicalResult match(SourceOp op) const
Rewrite and Match methods that operate on the SourceOp type.
OpInterfaceConversionPattern is a wrapper around ConversionPattern that allows for matching and rewri...
virtual LogicalResult matchAndRewrite(SourceOp op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:38
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Optional< InputMapping > getInputMapping(unsigned input) const
Get the input mapping for the given argument.
LogicalResult applyAnalysisConversion(ArrayRef< Operation *> ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation *> &convertedOps, function_ref< void(Diagnostic &)> notifyCallback=nullptr)
Apply an analysis conversion on the given operations, and all nested operations.
LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Hook for derived classes to implement combined matching and rewriting.
static void rewrite(SCCPAnalysis &analysis, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
Definition: SCCP.cpp:195
bool isLegal(Type type)
Return true if the given type is legal for this type converter, i.e.
This class represents an argument of a Block.
Definition: Value.h:298
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
void addLegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as legal.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
virtual void rewrite(Operation *op, PatternRewriter &rewriter) const
Rewrite the IR rooted at the specified operation with the result of this pattern, generating any new ...
void addIllegalOp()
Register the given operation as illegal, i.e.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:84
void markOpRecursivelyLegal(const DynamicLegalityCallbackFn &callback={})
Mark an operation, that must have either been set as Legal or DynamicallyLegal, as being recursively ...
std::enable_if< !llvm::is_invocable< Callable, Operation * >::value >::type addDynamicallyLegalOp(Callable &&callback)
void addConversion(FnT &&callback)
Register a conversion function.
void addIllegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as illegal, i.e.
virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement combined matching and rewriting.
Type conversion class.
bool isSignatureLegal(FunctionType ty)
Return true if the inputs and outputs of the given function type are legal.
Optional< SignatureConversion > convertBlockSignature(Block *block)
This function converts the type signature of the given block, by invoking 'convertSignatureArg' for each argument.
virtual void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, BlockAndValueMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent"...
Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs)
void addLegalOp()
Register the given operations as legal.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
This class represents an operand of an operation.
Definition: Value.h:249
This class implements a pattern rewriter for use with ConversionPatterns.
LogicalResult convertSignatureArgs(TypeRange types, SignatureConversion &result, unsigned origInputOffset=0)
void populateFunctionOpInterfaceTypeConversionPattern(StringRef functionLikeOpName, RewritePatternSet &patterns, TypeConverter &converter)
Add a pattern to the given pattern list to convert the signature of a FunctionOpInterface op with the...
OpConversionPattern(TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit=1)
void markOpRecursivelyLegal(const DynamicLegalityCallbackFn &callback={})
This class describes a specific conversion target.
LogicalResult match(Operation *op) const final
Wrappers around the ConversionPattern methods that pass the derived op type.
OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Operation *op, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure...
Definition: PatternMatch.h:802
This class helps build Operations.
Definition: Builders.h:177
This class provides an abstraction over the different types of ranges over Values.
OpInterfaceConversionPattern(TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit=1)
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
std::enable_if_t<!std::is_convertible< RangeT, Type >::value &&!std::is_convertible< RangeT, Operation * >::value, bool > isLegal(RangeT &&range)
Return true if all of the given types are legal for this type converter.
void remapInput(unsigned origInputNo, Value replacement)
Remap an input of the original signature to another replacement value.