MLIR  19.0.0git
DialectConversion.h
Go to the documentation of this file.
1 //===- DialectConversion.h - MLIR dialect conversion pass -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file declares a generic pass for converting between MLIR dialects.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_TRANSFORMS_DIALECTCONVERSION_H_
14 #define MLIR_TRANSFORMS_DIALECTCONVERSION_H_
15 
16 #include "mlir/Config/mlir-config.h"
18 #include "llvm/ADT/MapVector.h"
19 #include "llvm/ADT/StringMap.h"
20 #include <type_traits>
21 
22 namespace mlir {
23 
24 // Forward declarations.
25 class Attribute;
26 class Block;
27 struct ConversionConfig;
28 class ConversionPatternRewriter;
29 class MLIRContext;
30 class Operation;
31 struct OperationConverter;
32 class Type;
33 class Value;
34 
35 //===----------------------------------------------------------------------===//
36 // Type Conversion
37 //===----------------------------------------------------------------------===//
38 
39 /// Type conversion class. Specific conversions and materializations can be
40 /// registered using addConversion and addMaterialization, respectively.
42 public:
43  virtual ~TypeConverter() = default;
44  TypeConverter() = default;
45  // Copy the registered conversions, but not the caches
47  : conversions(other.conversions),
48  argumentMaterializations(other.argumentMaterializations),
49  sourceMaterializations(other.sourceMaterializations),
50  targetMaterializations(other.targetMaterializations),
51  typeAttributeConversions(other.typeAttributeConversions) {}
53  conversions = other.conversions;
54  argumentMaterializations = other.argumentMaterializations;
55  sourceMaterializations = other.sourceMaterializations;
56  targetMaterializations = other.targetMaterializations;
57  typeAttributeConversions = other.typeAttributeConversions;
58  return *this;
59  }
60 
61  /// This class provides all of the information necessary to convert a type
62  /// signature.
64  public:
65  SignatureConversion(unsigned numOrigInputs)
    // Size the remapping table with one slot per input of the original
    // signature. Each slot is a std::optional<InputMapping> and presumably
    // starts out empty until a mapping is recorded for it -- TODO confirm.
66  : remappedInputs(numOrigInputs) {}
67 
68  /// This struct represents a range of new types or a single value that
69  /// remaps an existing signature input.
70  struct InputMapping {
71  size_t inputNo, size;
73  };
74 
75  /// Return the argument types for the new (converted) signature. The result
    /// refers to the `argTypes` member of this object.
76  ArrayRef<Type> getConvertedTypes() const { return argTypes; }
77 
78  /// Get the input mapping recorded for the original argument at index
    /// `input`. The entry is read directly from `remappedInputs` without a
    /// bounds check here, so `input` is presumably expected to be a valid
    /// original-argument index. An empty optional presumably means that no
    /// remapping has been recorded for that argument -- TODO confirm.
79  std::optional<InputMapping> getInputMapping(unsigned input) const {
80  return remappedInputs[input];
81  }
82 
83  //===------------------------------------------------------------------===//
84  // Conversion Hooks
85  //===------------------------------------------------------------------===//
86 
87  /// Remap an input of the original signature with a new set of types. The
88  /// new types are appended to the new signature conversion.
89  void addInputs(unsigned origInputNo, ArrayRef<Type> types);
90 
91  /// Append new input types to the signature conversion, this should only be
92  /// used if the new types are not intended to remap an existing input.
93  void addInputs(ArrayRef<Type> types);
94 
95  /// Remap an input of the original signature to another `replacement`
96  /// value. This drops the original argument.
97  void remapInput(unsigned origInputNo, Value replacement);
98 
99  private:
100  /// Remap an input of the original signature with a range of types in the
101  /// new signature.
102  void remapInput(unsigned origInputNo, unsigned newInputNo,
103  unsigned newInputCount = 1);
104 
105  /// The remapping information for each of the original arguments.
106  SmallVector<std::optional<InputMapping>, 4> remappedInputs;
107 
108  /// The set of new argument types.
109  SmallVector<Type, 4> argTypes;
110  };
111 
112  /// The general result of a type attribute conversion callback, allowing
113  /// for early termination. The default constructor creates the na case.
115  public:
116  constexpr AttributeConversionResult() : impl() {}
117  AttributeConversionResult(Attribute attr) : impl(attr, resultTag) {}
118 
120  static AttributeConversionResult na();
122 
123  bool hasResult() const;
124  bool isNa() const;
125  bool isAbort() const;
126 
127  Attribute getResult() const;
128 
129  private:
130  AttributeConversionResult(Attribute attr, unsigned tag) : impl(attr, tag) {}
131 
132  llvm::PointerIntPair<Attribute, 2> impl;
133  // Note that na is 0 so that we can use PointerIntPair's default
134  // constructor.
135  static constexpr unsigned naTag = 0;
136  static constexpr unsigned resultTag = 1;
137  static constexpr unsigned abortTag = 2;
138  };
139 
140  /// Register a conversion function. A conversion function must be convertible
141  /// to any of the following forms (where `T` is a class derived from `Type`):
142  /// * std::optional<Type>(T)
143  /// - This form represents a 1-1 type conversion. It should return nullptr
144  /// or `std::nullopt` to signify failure. If `std::nullopt` is returned,
145  /// the converter is allowed to try another conversion function to
146  /// perform the conversion.
147  /// * std::optional<LogicalResult>(T, SmallVectorImpl<Type> &)
148  /// - This form represents a 1-N type conversion. It should return
149  /// `failure` or `std::nullopt` to signify a failed conversion. If the
150  /// new set of types is empty, the type is removed and any usages of the
151  /// existing value are expected to be removed during conversion. If
152  /// `std::nullopt` is returned, the converter is allowed to try another
153  /// conversion function to perform the conversion.
154  /// * std::optional<LogicalResult>(T, SmallVectorImpl<Type> &,
155  /// ArrayRef<Type>)
156  /// - This form represents a 1-N type conversion supporting recursive
157  /// types. The first two arguments and the return value are the same as
158  /// for the regular 1-N form. The third argument contains the
159  /// "call stack" of the recursive conversion: it contains the list of
160  /// types currently being converted, with the current type being the
161  /// last one. If it is present more than once in the list, the
162  /// conversion concerns a recursive type.
     /// NOTE(review): the `wrapCallback` overloads visible in this class only
     /// accept callbacks invocable as `(T)` or `(T, SmallVectorImpl<Type> &)`;
     /// confirm whether the three-argument form above is still supported.
163  /// Note: When attempting to convert a type, e.g. via 'convertType', the
164  /// most recently added conversions will be invoked first.
     ///
     /// `T` defaults to the type of the callback's first parameter (deduced
     /// with `llvm::function_traits`). The callback is adapted with
     /// `wrapCallback<T>` and then handed to `registerConversion`.
165  template <typename FnT, typename T = typename llvm::function_traits<
166  std::decay_t<FnT>>::template arg_t<0>>
167  void addConversion(FnT &&callback) {
168  registerConversion(wrapCallback<T>(std::forward<FnT>(callback)));
169  }
170 
171  /// Register a materialization function, which must be convertible to the
172  /// following form:
173  /// `std::optional<Value>(OpBuilder &, T, ValueRange, Location)`,
174  /// where `T` is any subclass of `Type`. This function is responsible for
175  /// creating an operation, using the OpBuilder and Location provided, that
176  /// "casts" a range of values into a single value of the given type `T`. It
177  /// must return a Value of the converted type on success, `std::nullopt` if
178  /// it failed but other materializations can be attempted, and `nullptr` on
179  /// unrecoverable failure. It will only be called for (sub)types of `T`.
180  /// Materialization functions must be provided when a type conversion may
181  /// persist after the conversion has finished.
182  ///
183  /// This method registers a materialization that will be called when
184  /// converting an illegal block argument type to a legal type.
     ///
     /// `T` defaults to the type of the callback's second parameter (deduced
     /// with `llvm::function_traits`). The callback is adapted with
     /// `wrapMaterialization<T>` and appended to `argumentMaterializations`.
185  template <typename FnT, typename T = typename llvm::function_traits<
186  std::decay_t<FnT>>::template arg_t<1>>
187  void addArgumentMaterialization(FnT &&callback) {
188  argumentMaterializations.emplace_back(
189  wrapMaterialization<T>(std::forward<FnT>(callback)));
190  }
191  /// This method registers a materialization that will be called when
192  /// converting a legal type to an illegal source type. This is used when
193  /// conversions to an illegal type must persist beyond the main conversion.
     ///
     /// The callback must have the materialization form
     /// `std::optional<Value>(OpBuilder &, T, ValueRange, Location)`; `T`
     /// defaults to the type of its second parameter. The callback is adapted
     /// with `wrapMaterialization<T>` and appended to `sourceMaterializations`.
194  template <typename FnT, typename T = typename llvm::function_traits<
195  std::decay_t<FnT>>::template arg_t<1>>
196  void addSourceMaterialization(FnT &&callback) {
197  sourceMaterializations.emplace_back(
198  wrapMaterialization<T>(std::forward<FnT>(callback)));
199  }
200  /// This method registers a materialization that will be called when
201  /// converting a value from an illegal, or source, type to a legal type.
     ///
     /// The callback must have the materialization form
     /// `std::optional<Value>(OpBuilder &, T, ValueRange, Location)`; `T`
     /// defaults to the type of its second parameter. The callback is adapted
     /// with `wrapMaterialization<T>` and appended to `targetMaterializations`.
202  template <typename FnT, typename T = typename llvm::function_traits<
203  std::decay_t<FnT>>::template arg_t<1>>
204  void addTargetMaterialization(FnT &&callback) {
205  targetMaterializations.emplace_back(
206  wrapMaterialization<T>(std::forward<FnT>(callback)));
207  }
208 
209  /// Register a conversion function for attributes within types. Type
210  /// converters may call this function in order to allow hooking into the
211  /// translation of attributes that exist within types. For example, a type
212  /// converter for the `memref` type could use these conversions to convert
213  /// memory spaces or layouts in an extensible way.
214  ///
215  /// The conversion functions take a non-null Type or subclass of Type and a
216  /// non-null Attribute (or subclass of Attribute), and return an
217  /// `AttributeConversionResult`. This result can either contain an `Attribute`,
218  /// which may be `nullptr`, representing the conversion's success,
219  /// `AttributeConversionResult::na()` (the default empty value), indicating
220  /// that the conversion function did not apply and that further conversion
221  /// functions should be checked, or `AttributeConversionResult::abort()`
222  /// indicating that the conversion process should be aborted.
223  ///
224  /// Registered conversion functions are called in the reverse of the order in
225  /// which they were registered.
     ///
     /// `T` and `A` default to the types of the callback's first and second
     /// parameters respectively (deduced with `llvm::function_traits`). The
     /// callback is adapted with `wrapTypeAttributeConversion<T, A>` and then
     /// handed to `registerTypeAttributeConversion`.
226  template <
227  typename FnT,
228  typename T =
229  typename llvm::function_traits<std::decay_t<FnT>>::template arg_t<0>,
230  typename A =
231  typename llvm::function_traits<std::decay_t<FnT>>::template arg_t<1>>
232  void addTypeAttributeConversion(FnT &&callback) {
233  registerTypeAttributeConversion(
234  wrapTypeAttributeConversion<T, A>(std::forward<FnT>(callback)));
235  }
236 
237  /// Convert the given type. This function should return failure if no valid
238  /// conversion exists, success otherwise. If the new set of types is empty,
239  /// the type is removed and any usages of the existing value are expected to
240  /// be removed during conversion.
242 
243  /// This hook simplifies defining 1-1 type conversions. This function returns
244  /// the type to convert to on success, and a null type on failure.
245  Type convertType(Type t) const;
246 
247  /// Attempts a 1-1 type conversion, expecting the result type to be
248  /// `TargetType`. Returns the converted type cast to `TargetType` on success,
249  /// and a null type on conversion or cast failure.
     ///
     /// Implemented by calling the non-template `convertType(Type)` overload
     /// and applying `dyn_cast_or_null<TargetType>` to its result, so a null
     /// result from that overload is tolerated rather than dereferenced.
250  template <typename TargetType> TargetType convertType(Type t) const {
251  return dyn_cast_or_null<TargetType>(convertType(t));
252  }
253 
254  /// Convert the given set of types, filling 'results' as necessary. This
255  /// returns failure if the conversion of any of the types fails, success
256  /// otherwise.
258  SmallVectorImpl<Type> &results) const;
259 
260  /// Return true if the given type is legal for this type converter, i.e. the
261  /// type converts to itself.
262  bool isLegal(Type type) const;
263 
264  /// Return true if all of the given types are legal for this type converter.
     ///
     /// Each element of `range` is checked with the `isLegal(Type)` overload via
     /// `llvm::all_of`. The `enable_if` constraint removes this template from
     /// overload resolution when `RangeT` is convertible to `Type` or to
     /// `Operation *`, presumably so that such arguments select the dedicated
     /// non-template overloads instead of being treated as ranges.
265  template <typename RangeT>
266  std::enable_if_t<!std::is_convertible<RangeT, Type>::value &&
267  !std::is_convertible<RangeT, Operation *>::value,
268  bool>
269  isLegal(RangeT &&range) const {
270  return llvm::all_of(range, [this](Type type) { return isLegal(type); });
271  }
272  /// Return true if the given operation has legal operand and result types.
273  bool isLegal(Operation *op) const;
274 
275  /// Return true if the types of block arguments within the region are legal.
276  bool isLegal(Region *region) const;
277 
278  /// Return true if the inputs and outputs of the given function type are
279  /// legal.
280  bool isSignatureLegal(FunctionType ty) const;
281 
282  /// This method allows for converting a specific argument of a signature. It
283  /// takes as inputs the original argument's input number and type.
284  /// On success, it populates 'result' with any new mappings.
285  LogicalResult convertSignatureArg(unsigned inputNo, Type type,
286  SignatureConversion &result) const;
288  SignatureConversion &result,
289  unsigned origInputOffset = 0) const;
290 
291  /// This function converts the type signature of the given block, by invoking
292  /// 'convertSignatureArg' for each argument. This function should return a
293  /// valid conversion for the signature on success, std::nullopt otherwise.
294  std::optional<SignatureConversion> convertBlockSignature(Block *block) const;
295 
296  /// Materialize a conversion from a set of types into one result type by
297  /// generating a cast sequence of some kind. See the respective
298  /// `add*Materialization` for more information on the context for these
299  /// methods.
301  Type resultType,
302  ValueRange inputs) const {
303  return materializeConversion(argumentMaterializations, builder, loc,
304  resultType, inputs);
305  }
307  Type resultType, ValueRange inputs) const {
308  return materializeConversion(sourceMaterializations, builder, loc,
309  resultType, inputs);
310  }
312  Type resultType, ValueRange inputs) const {
313  return materializeConversion(targetMaterializations, builder, loc,
314  resultType, inputs);
315  }
316 
317  /// Convert an attribute `attr` present within the type `type` using
318  /// the registered conversion functions. If no applicable conversion has been
319  /// registered, return std::nullopt. Note that the empty attribute/`nullptr`
320  /// is a valid return value for this function.
321  std::optional<Attribute> convertTypeAttribute(Type type,
322  Attribute attr) const;
323 
324 private:
325  /// The signature of the callback used to convert a type. If the new set of
326  /// types is empty, the type is removed and any usages of the existing value
327  /// are expected to be removed during conversion.
328  using ConversionCallbackFn = std::function<std::optional<LogicalResult>(
330 
331  /// The signature of the callback used to materialize a conversion.
332  using MaterializationCallbackFn = std::function<std::optional<Value>(
334 
335  /// The signature of the callback used to convert a type attribute.
336  using TypeAttributeConversionCallbackFn =
337  std::function<AttributeConversionResult(Type, Attribute)>;
338 
339  /// Attempt to materialize a conversion using one of the provided
340  /// materialization functions.
341  Value
342  materializeConversion(ArrayRef<MaterializationCallbackFn> materializations,
343  OpBuilder &builder, Location loc, Type resultType,
344  ValueRange inputs) const;
345 
346  /// Generate a wrapper for the given callback. This allows for accepting
347  /// different callback forms, that all compose into a single version.
348  /// With callback of form: `std::optional<Type>(T)`
     ///
     /// This overload is selected when `FnT` is invocable with a single `T`. It
     /// adapts the 1-1 callback to the `(T, SmallVectorImpl<Type> &)` form and
     /// forwards to the other `wrapCallback<T>` overload.
349  template <typename T, typename FnT>
350  std::enable_if_t<std::is_invocable_v<FnT, T>, ConversionCallbackFn>
351  wrapCallback(FnT &&callback) const {
352  return wrapCallback<T>([callback = std::forward<FnT>(callback)](
353  T type, SmallVectorImpl<Type> &results) {
     // An engaged optional means the callback handled the type: a non-null
     // Type is a successful conversion, a null Type is a failure.
354  if (std::optional<Type> resultOpt = callback(type)) {
355  bool wasSuccess = static_cast<bool>(*resultOpt);
356  if (wasSuccess)
357  results.push_back(*resultOpt);
358  return std::optional<LogicalResult>(success(wasSuccess));
359  }
     // `std::nullopt` from the callback is passed through as an empty
     // optional, i.e. "did not apply".
360  return std::optional<LogicalResult>();
361  });
362  }
363  /// With callback of form: `std::optional<LogicalResult>(
364  /// T, SmallVectorImpl<Type> &)`.
     ///
     /// This overload is selected when `FnT` is invocable with a `T` and a
     /// `SmallVectorImpl<Type> &`. The returned wrapper accepts any `Type`,
     /// returns `std::nullopt` when the type is not a `T` (the callback is not
     /// invoked), and otherwise returns the callback's result unchanged.
365  template <typename T, typename FnT>
366  std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &>,
367  ConversionCallbackFn>
368  wrapCallback(FnT &&callback) const {
369  return [callback = std::forward<FnT>(callback)](
370  Type type,
371  SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
372  T derivedType = dyn_cast<T>(type);
     // A failed cast means this conversion does not apply to `type`.
373  if (!derivedType)
374  return std::nullopt;
375  return callback(derivedType, results);
376  };
377  }
378 
379  /// Register a type conversion. The callback is appended to `conversions`
     /// and both conversion caches are cleared, presumably because previously
     /// cached results may no longer match the set of registered conversions.
     /// NOTE(review): the caches are cleared here without taking `cacheMutex`;
     /// confirm that registration never runs concurrently with conversion.
380  void registerConversion(ConversionCallbackFn callback) {
381  conversions.emplace_back(std::move(callback));
382  cachedDirectConversions.clear();
383  cachedMultiConversions.clear();
384  }
385 
386  /// Generate a wrapper for the given materialization callback. The callback
387  /// may take any subclass of `Type` and the wrapper will check for the target
388  /// type to be of the expected class before calling the callback.
     ///
     /// The returned wrapper takes `(OpBuilder &, Type, ValueRange, Location)`.
     /// If `resultType` is a `T`, it returns the callback's result; otherwise
     /// it returns `std::nullopt` without invoking the callback.
389  template <typename T, typename FnT>
390  MaterializationCallbackFn wrapMaterialization(FnT &&callback) const {
391  return [callback = std::forward<FnT>(callback)](
392  OpBuilder &builder, Type resultType, ValueRange inputs,
393  Location loc) -> std::optional<Value> {
394  if (T derivedType = dyn_cast<T>(resultType))
395  return callback(builder, derivedType, inputs, loc);
396  return std::nullopt;
397  };
398  }
399 
400  /// Generate a wrapper for the given type attribute conversion callback. The
401  /// callback may take any subclass of `Type` and of `Attribute`, and the
402  /// wrapper will check for the type and attribute to be of the expected
403  /// classes before calling the callback.
404  template <typename T, typename A, typename FnT>
405  TypeAttributeConversionCallbackFn
406  wrapTypeAttributeConversion(FnT &&callback) const {
407  return [callback = std::forward<FnT>(callback)](
408  Type type, Attribute attr) -> AttributeConversionResult {
409  if (T derivedType = dyn_cast<T>(type)) {
410  if (A derivedAttr = dyn_cast_or_null<A>(attr))
411  return callback(derivedType, derivedAttr);
412  }
414  };
415  }
416 
417  /// Register a type attribute conversion (for example a memory space
     /// conversion), clearing caches. The callback is appended to
     /// `typeAttributeConversions`.
418  void
419  registerTypeAttributeConversion(TypeAttributeConversionCallbackFn callback) {
420  typeAttributeConversions.emplace_back(std::move(callback));
421  // Clear type conversions in case a memory space is lingering inside.
422  cachedDirectConversions.clear();
423  cachedMultiConversions.clear();
424  }
425 
426  /// The set of registered conversion functions.
427  SmallVector<ConversionCallbackFn, 4> conversions;
428 
429  /// The list of registered materialization functions.
430  SmallVector<MaterializationCallbackFn, 2> argumentMaterializations;
431  SmallVector<MaterializationCallbackFn, 2> sourceMaterializations;
432  SmallVector<MaterializationCallbackFn, 2> targetMaterializations;
433 
434  /// The list of registered type attribute conversion functions.
435  SmallVector<TypeAttributeConversionCallbackFn, 2> typeAttributeConversions;
436 
437  /// A set of cached conversions to avoid recomputing in the common case.
438  /// Direct 1-1 conversions are the most common, so this cache stores the
439  /// successful 1-1 conversions as well as all failed conversions.
440  mutable DenseMap<Type, Type> cachedDirectConversions;
441  /// This cache stores the successful 1->N conversions, where N != 1.
442  mutable DenseMap<Type, SmallVector<Type, 2>> cachedMultiConversions;
443  /// A mutex used for cache access
444  mutable llvm::sys::SmartRWMutex<true> cacheMutex;
445 };
446 
447 //===----------------------------------------------------------------------===//
448 // Conversion Patterns
449 //===----------------------------------------------------------------------===//
450 
451 /// Base class for the conversion patterns. This pattern class enables type
452 /// conversions, and other uses specific to the conversion framework. As such,
453 /// patterns of this type can only be used with the 'apply*' methods below.
455 public:
456  /// Hook for derived classes to implement rewriting. `op` is the (first)
457  /// operation matched by the pattern, `operands` is a list of the rewritten
458  /// operand values that are passed to `op`, `rewriter` can be used to emit the
459  /// new operations. This function should not fail. If some specific cases of
460  /// the operation are not supported, these cases should not be matched.
     ///
     /// The default implementation is `llvm_unreachable`, so a derived class
     /// must override it whenever this hook can actually be reached.
461  virtual void rewrite(Operation *op, ArrayRef<Value> operands,
462  ConversionPatternRewriter &rewriter) const {
463  llvm_unreachable("unimplemented rewrite");
464  }
465 
466  /// Hook for derived classes to implement combined matching and rewriting.
467  virtual LogicalResult
469  ConversionPatternRewriter &rewriter) const {
470  if (failed(match(op)))
471  return failure();
472  rewrite(op, operands, rewriter);
473  return success();
474  }
475 
476  /// Attempt to match and rewrite the IR root at the specified operation.
478  PatternRewriter &rewriter) const final;
479 
480  /// Return the type converter held by this pattern, or nullptr if the pattern
481  /// does not require type conversion. This simply returns the `typeConverter`
     /// member, which is initialized to nullptr by default.
482  const TypeConverter *getTypeConverter() const { return typeConverter; }
483 
484  template <typename ConverterTy>
485  std::enable_if_t<std::is_base_of<TypeConverter, ConverterTy>::value,
486  const ConverterTy *>
488  return static_cast<const ConverterTy *>(typeConverter);
489  }
490 
491 protected:
492  /// See `RewritePattern::RewritePattern` for information on the other
493  /// available constructors.
494  using RewritePattern::RewritePattern;
495  /// Construct a conversion pattern with the given converter, and forward the
496  /// remaining arguments to RewritePattern.
497  template <typename... Args>
499  : RewritePattern(std::forward<Args>(args)...),
501 
502 protected:
503  /// An optional type converter for use by this pattern.
504  const TypeConverter *typeConverter = nullptr;
505 
506 private:
508 };
509 
510 /// OpConversionPattern is a wrapper around ConversionPattern that allows for
511 /// matching and rewriting against an instance of a derived operation class as
512 /// opposed to a raw Operation.
513 template <typename SourceOp>
515 public:
516  using OpAdaptor = typename SourceOp::Adaptor;
517 
519  : ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
521  PatternBenefit benefit = 1)
522  : ConversionPattern(typeConverter, SourceOp::getOperationName(), benefit,
523  context) {}
524 
525  /// Wrappers around the ConversionPattern methods that pass the derived op
526  /// type. This one casts `op` to `SourceOp` with `cast<>` and forwards to the
     /// `match(SourceOp)` overload; it is `final`, so derived patterns
     /// customize the typed overload instead.
527  LogicalResult match(Operation *op) const final {
528  return match(cast<SourceOp>(op));
529  }
530  void rewrite(Operation *op, ArrayRef<Value> operands,
531  ConversionPatternRewriter &rewriter) const final {
     // Type-erased entry point: cast to the concrete op, wrap the rewritten
     // operands in the op's adaptor, and forward to the typed `rewrite`.
532  auto sourceOp = cast<SourceOp>(op);
533  rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
534  }
537  ConversionPatternRewriter &rewriter) const final {
538  auto sourceOp = cast<SourceOp>(op);
539  return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
540  }
541 
542  /// Rewrite and Match methods that operate on the SourceOp type. These must be
543  /// overridden by the derived pattern class.
     ///
     /// The default `match` is `llvm_unreachable`: a derived pattern must
     /// override either this method or `matchAndRewrite`.
544  virtual LogicalResult match(SourceOp op) const {
545  llvm_unreachable("must override match or matchAndRewrite");
546  }
547  virtual void rewrite(SourceOp op, OpAdaptor adaptor,
548  ConversionPatternRewriter &rewriter) const {
     // Default typed rewrite hook: reaching it means the derived pattern
     // overrode neither `matchAndRewrite` nor this method.
549  llvm_unreachable("must override matchAndRewrite or a rewrite method");
550  }
551  virtual LogicalResult
552  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
553  ConversionPatternRewriter &rewriter) const {
     // Default combined hook: run `match` first and only call `rewrite` when
     // it succeeded; a failed match is reported as failure without rewriting.
554  if (failed(match(op)))
555  return failure();
556  rewrite(op, adaptor, rewriter);
557  return success();
558  }
559 
560 private:
562 };
563 
564 /// OpInterfaceConversionPattern is a wrapper around ConversionPattern that
565 /// allows for matching and rewriting against an instance of an OpInterface
566 /// class as opposed to a raw Operation.
567 template <typename SourceOp>
569 public:
572  SourceOp::getInterfaceID(), benefit, context) {}
574  MLIRContext *context, PatternBenefit benefit = 1)
576  SourceOp::getInterfaceID(), benefit, context) {}
577 
578  /// Wrappers around the ConversionPattern methods that pass the derived op
579  /// type. This one casts `op` to `SourceOp` with `cast<>` and forwards to the
     /// typed `rewrite` overload, passing `operands` through unchanged.
580  void rewrite(Operation *op, ArrayRef<Value> operands,
581  ConversionPatternRewriter &rewriter) const final {
582  rewrite(cast<SourceOp>(op), operands, rewriter);
583  }
586  ConversionPatternRewriter &rewriter) const final {
587  return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
588  }
589 
590  /// Rewrite and Match methods that operate on the SourceOp type. These must be
591  /// overridden by the derived pattern class.
     ///
     /// The default `rewrite` is `llvm_unreachable`: a derived pattern must
     /// override either `matchAndRewrite` or this method.
592  virtual void rewrite(SourceOp op, ArrayRef<Value> operands,
593  ConversionPatternRewriter &rewriter) const {
594  llvm_unreachable("must override matchAndRewrite or a rewrite method");
595  }
596  virtual LogicalResult
597  matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
598  ConversionPatternRewriter &rewriter) const {
     // Default combined hook: run `match` first and only call the typed
     // `rewrite` when it succeeded. No `match(SourceOp)` overload is visible
     // in this class, so `match(op)` presumably resolves to an inherited
     // `match` taking `Operation *` -- TODO confirm.
599  if (failed(match(op)))
600  return failure();
601  rewrite(op, operands, rewriter);
602  return success();
603  }
604 
605 private:
607 };
608 
609 /// OpTraitConversionPattern is a wrapper around ConversionPattern that allows
610 /// for matching and rewriting against instances of an operation that possess a
611 /// given trait.
612 template <template <typename> class TraitType>
614 public:
617  TypeID::get<TraitType>(), benefit, context) {}
619  MLIRContext *context, PatternBenefit benefit = 1)
621  TypeID::get<TraitType>(), benefit, context) {}
622 };
623 
624 /// Generic utility to convert op result types according to type converter
625 /// without knowing exact op type.
626 /// Clones existing op with new result types and returns it.
627 FailureOr<Operation *>
628 convertOpResultTypes(Operation *op, ValueRange operands,
629  const TypeConverter &converter,
630  ConversionPatternRewriter &rewriter);
631 
632 /// Add a pattern to the given pattern list to convert the signature of a
633 /// FunctionOpInterface op with the given type converter. This only supports
634 /// ops which use FunctionType to represent their type.
636  StringRef functionLikeOpName, RewritePatternSet &patterns,
637  const TypeConverter &converter);
638 
639 template <typename FuncOpT>
641  RewritePatternSet &patterns, const TypeConverter &converter) {
642  populateFunctionOpInterfaceTypeConversionPattern(FuncOpT::getOperationName(),
643  patterns, converter);
644 }
645 
647  RewritePatternSet &patterns, const TypeConverter &converter);
648 
649 //===----------------------------------------------------------------------===//
650 // Conversion PatternRewriter
651 //===----------------------------------------------------------------------===//
652 
653 namespace detail {
// Forward declaration only. ConversionPatternRewriter holds this type behind a
// std::unique_ptr (its `impl` member), so the definition is not needed here.
654 struct ConversionPatternRewriterImpl;
655 } // namespace detail
656 
657 /// This class implements a pattern rewriter for use with ConversionPatterns. It
658 /// extends the base PatternRewriter and provides special conversion specific
659 /// hooks.
661 public:
663 
664  /// Apply a signature conversion to the entry block of the given region. This
665  /// replaces the entry block with a new block containing the updated
666  /// signature. The new entry block to the region is returned for convenience.
667  /// If no block argument types are changing, the original entry block will be
668  /// left in place and returned.
669  ///
670  /// If provided, `converter` will be used for any materializations.
671  Block *
674  const TypeConverter *converter = nullptr);
675 
676  /// Convert the types of block arguments within the given region. This
677  /// replaces each block with a new block containing the updated signature. If
678  /// an updated signature would match the current signature, the respective
679  /// block is left in place as is.
680  ///
681  /// The entry block may have a special conversion if `entryConversion` is
682  /// provided. On success, the new entry block to the region is returned for
683  /// convenience. Otherwise, failure is returned.
685  Region *region, const TypeConverter &converter,
686  TypeConverter::SignatureConversion *entryConversion = nullptr);
687 
688  /// Convert the types of block arguments within the given region except for
689  /// the entry block. This replaces each non-entry block with a new block
690  /// containing the updated signature. If an updated signature would match the
691  /// current signature, the respective block is left in place as is.
692  ///
693  /// If special conversion behavior is needed for the non-entry blocks (for
694  /// example, we need to convert only a subset of a block's arguments), such
695  /// behavior can be specified in blockConversions.
697  Region *region, const TypeConverter &converter,
699 
700  /// Replace all the uses of the block argument `from` with value `to`.
702 
703  /// Return the converted value of 'key' with a type defined by the type
704  /// converter of the currently executing pattern. Return nullptr in the case
705  /// of failure, the remapped value otherwise.
707 
708  /// Return the converted values that replace 'keys' with types defined by the
709  /// type converter of the currently executing pattern. Returns failure if the
710  /// remap failed, success otherwise.
712  SmallVectorImpl<Value> &results);
713 
714  //===--------------------------------------------------------------------===//
715  // PatternRewriter Hooks
716  //===--------------------------------------------------------------------===//
717 
718  /// Indicate that the conversion rewriter can recover from rewrite failure.
719  /// Recovery is supported via rollback, allowing for continued processing of
720  /// patterns even if a failure is encountered during the rewrite step.
     /// This override always returns true.
721  bool canRecoverFromRewriteFailure() const override { return true; }
722 
723  /// PatternRewriter hook for replacing an operation.
724  void replaceOp(Operation *op, ValueRange newValues) override;
725 
726  /// PatternRewriter hook for replacing an operation.
727  void replaceOp(Operation *op, Operation *newOp) override;
728 
729  /// PatternRewriter hook for erasing a dead operation. The uses of this
730  /// operation *must* be made dead by the end of the conversion process,
731  /// otherwise an assert will be issued.
732  void eraseOp(Operation *op) override;
733 
734  /// PatternRewriter hook for erasing all operations in a block. This is not yet
735  /// implemented for dialect conversion.
736  void eraseBlock(Block *block) override;
737 
738  /// PatternRewriter hook for inlining the ops of a block into another block.
739  void inlineBlockBefore(Block *source, Block *dest, Block::iterator before,
740  ValueRange argValues = std::nullopt) override;
742 
743  /// PatternRewriter hook for updating the given operation in-place.
744  /// Note: These methods only track updates to the given operation itself,
745  /// and not nested regions. Updates to regions will still require notification
746  /// through other more specific hooks above.
747  void startOpModification(Operation *op) override;
748 
749  /// PatternRewriter hook for updating the given operation in-place.
750  void finalizeOpModification(Operation *op) override;
751 
752  /// PatternRewriter hook for updating the given operation in-place.
753  void cancelOpModification(Operation *op) override;
754 
755  /// Return a reference to the internal implementation.
757 
758 private:
759  // Allow OperationConverter to construct new rewriters.
760  friend struct OperationConverter;
761 
762  /// Conversion pattern rewriters must not be used outside of dialect
763  /// conversions. They apply some IR rewrites in a delayed fashion and could
764  /// bring the IR into an inconsistent state when used standalone.
766  const ConversionConfig &config);
767 
768  // Hide unsupported pattern rewriter API.
770 
771  std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
772 };
773 
774 //===----------------------------------------------------------------------===//
775 // ConversionTarget
776 //===----------------------------------------------------------------------===//
777 
778 /// This class describes a specific conversion target.
780 public:
781  /// This enumeration corresponds to the specific action to take when
782  /// considering an operation legal for this conversion target.
783  enum class LegalizationAction {
784  /// The target supports this operation.
785  Legal,
786 
787  /// This operation has dynamic legalization constraints that must be checked
788  /// by the target.
789  Dynamic,
790 
791  /// The target explicitly does not support this operation.
792  Illegal,
793  };
794 
795  /// A structure containing additional information describing a specific legal
796  /// operation instance.
797  struct LegalOpDetails {
798  /// A flag that indicates if this operation is 'recursively' legal. This
799  /// means that if an operation is legal, either statically or dynamically,
800  /// all of the operations nested within are also considered legal.
801  bool isRecursivelyLegal = false;
802  };
803 
804  /// The signature of the callback used to determine if an operation is
805  /// dynamically legal on the target.
807  std::function<std::optional<bool>(Operation *)>;
808 
809  ConversionTarget(MLIRContext &ctx) : ctx(ctx) {}
810  virtual ~ConversionTarget() = default;
811 
812  //===--------------------------------------------------------------------===//
813  // Legality Registration
814  //===--------------------------------------------------------------------===//
815 
816  /// Register a legality action for the given operation.
818  template <typename OpT>
820  setOpAction(OperationName(OpT::getOperationName(), &ctx), action);
821  }
822 
823  /// Register the given operations as legal.
826  }
827  template <typename OpT>
828  void addLegalOp() {
829  addLegalOp(OperationName(OpT::getOperationName(), &ctx));
830  }
831  template <typename OpT, typename OpT2, typename... OpTs>
832  void addLegalOp() {
833  addLegalOp<OpT>();
834  addLegalOp<OpT2, OpTs...>();
835  }
836 
837  /// Register the given operation as dynamically legal and set the dynamic
838  /// legalization callback to the one provided.
840  const DynamicLegalityCallbackFn &callback) {
842  setLegalityCallback(op, callback);
843  }
844  template <typename OpT>
846  addDynamicallyLegalOp(OperationName(OpT::getOperationName(), &ctx),
847  callback);
848  }
849  template <typename OpT, typename OpT2, typename... OpTs>
851  addDynamicallyLegalOp<OpT>(callback);
852  addDynamicallyLegalOp<OpT2, OpTs...>(callback);
853  }
854  template <typename OpT, class Callable>
855  std::enable_if_t<!std::is_invocable_v<Callable, Operation *>>
856  addDynamicallyLegalOp(Callable &&callback) {
857  addDynamicallyLegalOp<OpT>(
858  [=](Operation *op) { return callback(cast<OpT>(op)); });
859  }
860 
861  /// Register the given operation as illegal, i.e. this operation is known to
862  /// not be supported by this target.
865  }
866  template <typename OpT>
867  void addIllegalOp() {
868  addIllegalOp(OperationName(OpT::getOperationName(), &ctx));
869  }
870  template <typename OpT, typename OpT2, typename... OpTs>
871  void addIllegalOp() {
872  addIllegalOp<OpT>();
873  addIllegalOp<OpT2, OpTs...>();
874  }
875 
876  /// Mark an operation, which *must* already have been set as `Legal` or
877  /// `Dynamic`, as being recursively legal. This means that in
878  /// addition to the operation itself, all of the operations nested within are
879  /// also considered legal. An optional dynamic legality callback may be
880  /// provided to mark subsets of legal instances as recursively legal.
882  const DynamicLegalityCallbackFn &callback);
883  template <typename OpT>
885  OperationName opName(OpT::getOperationName(), &ctx);
886  markOpRecursivelyLegal(opName, callback);
887  }
888  template <typename OpT, typename OpT2, typename... OpTs>
890  markOpRecursivelyLegal<OpT>(callback);
891  markOpRecursivelyLegal<OpT2, OpTs...>(callback);
892  }
893  template <typename OpT, class Callable>
894  std::enable_if_t<!std::is_invocable_v<Callable, Operation *>>
895  markOpRecursivelyLegal(Callable &&callback) {
896  markOpRecursivelyLegal<OpT>(
897  [=](Operation *op) { return callback(cast<OpT>(op)); });
898  }
899 
900  /// Register a legality action for the given dialects.
901  void setDialectAction(ArrayRef<StringRef> dialectNames,
902  LegalizationAction action);
903 
904  /// Register the operations of the given dialects as legal.
905  template <typename... Names>
906  void addLegalDialect(StringRef name, Names... names) {
907  SmallVector<StringRef, 2> dialectNames({name, names...});
909  }
910  template <typename... Args>
912  SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
914  }
915 
916  /// Register the operations of the given dialects as dynamically legal, i.e.
917  /// requiring custom handling by the callback.
918  template <typename... Names>
920  StringRef name, Names... names) {
921  SmallVector<StringRef, 2> dialectNames({name, names...});
923  setLegalityCallback(dialectNames, callback);
924  }
925  template <typename... Args>
927  addDynamicallyLegalDialect(std::move(callback),
928  Args::getDialectNamespace()...);
929  }
930 
931  /// Register unknown operations as dynamically legal. For operations (and
932  /// dialects) that do not have a set legalization action, treat them as
933  /// dynamically legal and invoke the given callback.
935  setLegalityCallback(fn);
936  }
937 
938  /// Register the operations of the given dialects as illegal, i.e.
939  /// operations of this dialect are not supported by the target.
940  template <typename... Names>
941  void addIllegalDialect(StringRef name, Names... names) {
942  SmallVector<StringRef, 2> dialectNames({name, names...});
944  }
945  template <typename... Args>
947  SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
949  }
950 
951  //===--------------------------------------------------------------------===//
952  // Legality Querying
953  //===--------------------------------------------------------------------===//
954 
955  /// Get the legality action for the given operation.
956  std::optional<LegalizationAction> getOpAction(OperationName op) const;
957 
958  /// If the given operation instance is legal on this target, a structure
959  /// containing legality information is returned. If the operation is not
960  /// legal, std::nullopt is returned. Also returns std::nullopt if operation
961  /// legality wasn't registered by user or dynamic legality callbacks returned
962  /// None.
963  ///
964  /// Note: Legality is actually a 4-state: Legal(recursive=true),
965  /// Legal(recursive=false), Illegal or Unknown, where Unknown is treated
966  /// either as Legal or Illegal depending on context.
967  std::optional<LegalOpDetails> isLegal(Operation *op) const;
968 
969  /// Returns true if the operation instance is illegal on this target. Returns
970  /// false if the operation is legal, if its legality wasn't registered by the
971  /// user, or if the dynamic legality callbacks returned None.
972  bool isIllegal(Operation *op) const;
973 
974 private:
975  /// Set the dynamic legality callback for the given operation.
976  void setLegalityCallback(OperationName name,
977  const DynamicLegalityCallbackFn &callback);
978 
979  /// Set the dynamic legality callback for the given dialects.
980  void setLegalityCallback(ArrayRef<StringRef> dialects,
981  const DynamicLegalityCallbackFn &callback);
982 
983  /// Set the dynamic legality callback for the unknown ops.
984  void setLegalityCallback(const DynamicLegalityCallbackFn &callback);
985 
986  /// The set of information that configures the legalization of an operation.
987  struct LegalizationInfo {
988  /// The legality action this operation was given.
990 
991  /// If some legal instances of this operation may also be recursively legal.
992  bool isRecursivelyLegal = false;
993 
994  /// The legality callback if this operation is dynamically legal.
995  DynamicLegalityCallbackFn legalityFn;
996  };
997 
998  /// Get the legalization information for the given operation.
999  std::optional<LegalizationInfo> getOpInfo(OperationName op) const;
1000 
1001  /// A deterministic mapping of operation name and its respective legality
1002  /// information.
1003  llvm::MapVector<OperationName, LegalizationInfo> legalOperations;
1004 
1005  /// A set of legality callbacks for given operation names that are used to
1006  /// check if an operation instance is recursively legal.
1007  DenseMap<OperationName, DynamicLegalityCallbackFn> opRecursiveLegalityFns;
1008 
1009  /// A mapping of dialect name to the specific legality action to take.
1010  /// Note: llvm::StringMap does not guarantee a deterministic iteration order.
1011  llvm::StringMap<LegalizationAction> legalDialects;
1012 
1013  /// A set of dynamic legality callbacks for given dialect names.
1014  llvm::StringMap<DynamicLegalityCallbackFn> dialectLegalityFns;
1015 
1016  /// An optional legality callback for unknown operations.
1017  DynamicLegalityCallbackFn unknownLegalityFn;
1018 
1019  /// The current context this target applies to.
1020  MLIRContext &ctx;
1021 };
1022 
1023 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
1024 //===----------------------------------------------------------------------===//
1025 // PDL Configuration
1026 //===----------------------------------------------------------------------===//
1027 
1028 /// A PDL configuration that is used to support dialect conversion
1029 /// functionality.
1031  : public PDLPatternConfigBase<PDLConversionConfig> {
1032 public:
1033  PDLConversionConfig(const TypeConverter *converter) : converter(converter) {}
1034  ~PDLConversionConfig() final = default;
1035 
1036  /// Return the type converter used by this configuration, which may be nullptr
1037  /// if no type conversions are expected.
1038  const TypeConverter *getTypeConverter() const { return converter; }
1039 
1040  /// Hooks that are invoked at the beginning and end of a rewrite of a matched
1041  /// pattern.
1042  void notifyRewriteBegin(PatternRewriter &rewriter) final;
1043  void notifyRewriteEnd(PatternRewriter &rewriter) final;
1044 
1045 private:
1046  /// An optional type converter to use for the pattern.
1047  const TypeConverter *converter;
1048 };
1049 
1050 /// Register the dialect conversion PDL functions with the given pattern set.
1051 void registerConversionPDLFunctions(RewritePatternSet &patterns);
1052 
1053 #else
1054 
1055 // Stubs used when PDL support in pattern rewriting is not enabled.
1056 
1057 inline void registerConversionPDLFunctions(RewritePatternSet &patterns) {}
1058 
1059 class PDLConversionConfig final {
1060 public:
1061  PDLConversionConfig(const TypeConverter * /*converter*/) {}
1062 };
1063 
1064 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
1065 
1066 //===----------------------------------------------------------------------===//
1067 // ConversionConfig
1068 //===----------------------------------------------------------------------===//
1069 
1070 /// Dialect conversion configuration.
1072  /// An optional callback used to notify about match failure diagnostics during
1073  /// the conversion. Diagnostics reported to this callback may only be
1074  /// available in debug mode.
1076 
1077  /// Partial conversion only. All operations that are found not to be
1078  /// legalizable are placed in this set. (Note that if there is an op
1079  /// explicitly marked as illegal, the conversion terminates and the set will
1080  /// not necessarily be complete.)
1082 
1083  /// Analysis conversion only. All operations that are found to be legalizable
1084  /// are placed in this set. Note that no actual rewrites are applied to the
1085  /// IR during an analysis conversion and only pre-existing operations are
1086  /// added to the set.
1088 
1089  /// An optional listener that is notified about all IR modifications in case
1090  /// dialect conversion succeeds. If the dialect conversion fails and no IR
1091  /// modifications are visible (i.e., they were all rolled back), no
1092  /// notifications are sent.
1093  ///
1094  /// Note: Notifications are sent in a delayed fashion, when the dialect
1095  /// conversion is guaranteed to succeed. At that point, some IR modifications
1096  /// may already have been materialized. Consequently, operations/blocks that
1097  /// are passed to listener callbacks should not be accessed. (Ops/blocks are
1098  /// guaranteed to be valid pointers and accessing op names is allowed. But
1099  /// there are no guarantees about the state of ops/blocks at the time that a
1100  /// callback is triggered.)
1101  ///
1102  /// Example: Consider a dialect conversion in which a new op ("test.foo") is
1103  /// created and inserted, and later moved to another block. (Moving ops also
1104  /// triggers "notifyOperationInserted".)
1105  ///
1106  /// (1) notifyOperationInserted: "test.foo" (into block "b1")
1107  /// (2) notifyOperationInserted: "test.foo" (moved to another block "b2")
1108  ///
1109  /// When querying "op->getBlock()" during the first "notifyOperationInserted",
1110  /// "b2" would be returned because "moving an op" is a kind of rewrite that is
1111  /// immediately performed by the dialect conversion (and rolled back upon
1112  /// failure).
1113  //
1114  // Note: When receiving a "notifyBlockInserted"/"notifyOperationInserted"
1115  // callback, the previous region/block is provided to the callback, but not
1116  // the iterator pointing to the exact location within the region/block. That
1117  // is because these notifications are sent with a delay (after the IR has
1118  // already been modified) and iterators into past IR state cannot be
1119  // represented at the moment.
1121 };
1122 
1123 //===----------------------------------------------------------------------===//
1124 // Op Conversion Entry Points
1125 //===----------------------------------------------------------------------===//
1126 
1127 /// Below we define several entry points for operation conversion. It is
1128 /// important to note that the patterns provided to the conversion framework may
1129 /// have additional constraints. See the `PatternRewriter Hooks` section of the
1130 /// ConversionPatternRewriter, to see what additional constraints are imposed on
1131 /// the use of the PatternRewriter.
1132 
1133 /// Apply a partial conversion on the given operations and all nested
1134 /// operations. This method converts as many operations to the target as
1135 /// possible, ignoring operations that failed to legalize. This method only
1136 /// returns failure if there are ops explicitly marked as illegal.
1139  const ConversionTarget &target,
1140  const FrozenRewritePatternSet &patterns,
1141  ConversionConfig config = ConversionConfig());
1144  const FrozenRewritePatternSet &patterns,
1145  ConversionConfig config = ConversionConfig());
1146 
1147 /// Apply a complete conversion on the given operations, and all nested
1148 /// operations. This method returns failure if the conversion of any operation
1149 /// fails, or if there are unreachable blocks in any of the regions nested
1150 /// within 'ops'.
1152  const ConversionTarget &target,
1153  const FrozenRewritePatternSet &patterns,
1154  ConversionConfig config = ConversionConfig());
1156  const FrozenRewritePatternSet &patterns,
1157  ConversionConfig config = ConversionConfig());
1158 
1159 /// Apply an analysis conversion on the given operations, and all nested
1160 /// operations. This method analyzes which operations would be successfully
1161 /// converted to the target if a conversion was applied. All operations that
1162 /// were found to be legalizable to the given 'target' are placed within the
1163 /// provided 'config.legalizableOps' set; note that no actual rewrites are
1164 /// applied to the operations on success. This method only returns failure if
1165 /// there are unreachable blocks in any of the regions nested within 'ops'.
1168  const FrozenRewritePatternSet &patterns,
1169  ConversionConfig config = ConversionConfig());
1172  const FrozenRewritePatternSet &patterns,
1173  ConversionConfig config = ConversionConfig());
1174 } // namespace mlir
1175 
1176 #endif // MLIR_TRANSFORMS_DIALECTCONVERSION_H_
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:315
Block represents an ordered list of Operations.
Definition: Block.h:30
OpListType::iterator iterator
Definition: Block.h:137
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
LogicalResult getRemappedValues(ValueRange keys, SmallVectorImpl< Value > &results)
Return the converted values that replace 'keys' with types defined by the type converter of the curre...
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Convert the types of block arguments within the given region.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt) override
PatternRewriter hook for inlining the ops of a block into another block.
void startOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
detail::ConversionPatternRewriterImpl & getImpl()
Return a reference to the internal implementation.
void eraseBlock(Block *block) override
PatternRewriter hook for erasing all operations in a block.
void cancelOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
bool canRecoverFromRewriteFailure() const override
Indicate that the conversion rewriter can recover from rewrite failure.
Value getRemappedValue(Value key)
Return the converted value of 'key' with a type defined by the type converter of the currently execut...
Block * applySignatureConversion(Region *region, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter=nullptr)
Apply a signature conversion to the entry block of the given region.
void finalizeOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void replaceUsesOfBlockArgument(BlockArgument from, Value to)
Replace all the uses of the block argument `from` with value `to`.
LogicalResult convertNonEntryRegionTypes(Region *region, const TypeConverter &converter, ArrayRef< TypeConverter::SignatureConversion > blockConversions)
Convert the types of block arguments within the given region except for the entry region.
Base class for the conversion patterns.
const TypeConverter * typeConverter
An optional type converter for use by this pattern.
ConversionPattern(const TypeConverter &typeConverter, Args &&...args)
Construct a conversion pattern with the given converter, and forward the remaining arguments to Rewri...
virtual void rewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement rewriting.
const TypeConverter * getTypeConverter() const
Return the type converter held by this pattern, or nullptr if the pattern does not require type conve...
std::enable_if_t< std::is_base_of< TypeConverter, ConverterTy >::value, const ConverterTy * > getTypeConverter() const
virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement combined matching and rewriting.
This class describes a specific conversion target.
void setDialectAction(ArrayRef< StringRef > dialectNames, LegalizationAction action)
Register a legality action for the given dialects.
void setOpAction(LegalizationAction action)
void addLegalOp(OperationName op)
Register the given operations as legal.
void addDynamicallyLegalDialect(DynamicLegalityCallbackFn callback)
void setOpAction(OperationName op, LegalizationAction action)
Register a legality action for the given operation.
void addDynamicallyLegalDialect(const DynamicLegalityCallbackFn &callback, StringRef name, Names... names)
Register the operations of the given dialects as dynamically legal, i.e.
void addDynamicallyLegalOp(const DynamicLegalityCallbackFn &callback)
void addLegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as legal.
std::optional< LegalOpDetails > isLegal(Operation *op) const
If the given operation instance is legal on this target, a structure containing legality information ...
std::enable_if_t<!std::is_invocable_v< Callable, Operation * > > markOpRecursivelyLegal(Callable &&callback)
void markUnknownOpDynamicallyLegal(const DynamicLegalityCallbackFn &fn)
Register unknown operations as dynamically legal.
std::optional< LegalizationAction > getOpAction(OperationName op) const
Get the legality action for the given operation.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the on...
void addIllegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as illegal, i.e.
std::enable_if_t<!std::is_invocable_v< Callable, Operation * > > addDynamicallyLegalOp(Callable &&callback)
void addDynamicallyLegalOp(const DynamicLegalityCallbackFn &callback)
void markOpRecursivelyLegal(const DynamicLegalityCallbackFn &callback={})
void addIllegalOp(OperationName op)
Register the given operation as illegal, i.e.
LegalizationAction
This enumeration corresponds to the specific action to take when considering an operation legal for t...
@ Illegal
The target explicitly does not support this operation.
@ Dynamic
This operation has dynamic legalization constraints that must be checked by the target.
@ Legal
The target supports this operation.
ConversionTarget(MLIRContext &ctx)
void markOpRecursivelyLegal(OperationName name, const DynamicLegalityCallbackFn &callback)
Mark an operation, that must have either been set as Legal or DynamicallyLegal, as being recursively ...
void markOpRecursivelyLegal(const DynamicLegalityCallbackFn &callback={})
std::function< std::optional< bool >(Operation *)> DynamicLegalityCallbackFn
The signature of the callback used to determine if an operation is dynamically legal on the target.
bool isIllegal(Operation *op) const
Returns true if the operation instance is illegal on this target.
virtual ~ConversionTarget()=default
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:156
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:209
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
Definition: Builders.h:318
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
void rewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Hook for derived classes to implement rewriting.
LogicalResult match(Operation *op) const final
Wrappers around the ConversionPattern methods that pass the derived op type.
typename SourceOp::Adaptor OpAdaptor
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
virtual LogicalResult match(SourceOp op) const
Rewrite and Match methods that operate on the SourceOp type.
OpConversionPattern(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit=1)
virtual LogicalResult matchAndRewrite(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const
virtual void rewrite(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const
LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Hook for derived classes to implement combined matching and rewriting.
OpInterfaceConversionPattern is a wrapper around ConversionPattern that allows for matching and rewri...
void rewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Wrappers around the ConversionPattern methods that pass the derived op type.
LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Hook for derived classes to implement combined matching and rewriting.
OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
OpInterfaceConversionPattern(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit=1)
virtual void rewrite(SourceOp op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Rewrite and Match methods that operate on the SourceOp type.
virtual LogicalResult matchAndRewrite(SourceOp op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
OpTraitConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting...
OpTraitConversionPattern(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit=1)
OpTraitConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
A PDL configuration that is used to support dialect conversion functionality.
void notifyRewriteEnd(PatternRewriter &rewriter) final
void notifyRewriteBegin(PatternRewriter &rewriter) final
Hooks that are invoked at the beginning and end of a rewrite of a matched pattern.
const TypeConverter * getTypeConverter() const
Return the type converter used by this configuration, which may be nullptr if no type conversions are...
PDLConversionConfig(const TypeConverter *converter)
~PDLConversionConfig() final=default
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:33
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:775
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:72
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:245
virtual LogicalResult match(Operation *op) const
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
virtual void rewrite(Operation *op, PatternRewriter &rewriter) const
Rewrite the IR rooted at the specified operation with the result of this pattern, generating any new ...
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
The general result of a type attribute conversion callback, allowing for early termination.
static AttributeConversionResult result(Attribute attr)
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
std::optional< InputMapping > getInputMapping(unsigned input) const
Get the input mapping for the given argument.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
void remapInput(unsigned origInputNo, Value replacement)
Remap an input of the original signature to another replacement value.
Type conversion class.
std::optional< Attribute > convertTypeAttribute(Type type, Attribute attr) const
Convert an attribute `attr` present within the type `type` using the registered conversion functions...
void addConversion(FnT &&callback)
Register a conversion function.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
void addArgumentMaterialization(FnT &&callback)
Register a materialization function, which must be convertible to the following form: std::optional<V...
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
Value materializeArgumentConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
Materialize a conversion from a set of types into one result type by generating a cast sequence of so...
LogicalResult convertSignatureArgs(TypeRange types, SignatureConversion &result, unsigned origInputOffset=0) const
TypeConverter()=default
std::enable_if_t<!std::is_convertible< RangeT, Type >::value &&!std::is_convertible< RangeT, Operation * >::value, bool > isLegal(RangeT &&range) const
Return true if all of the given types are legal for this type converter.
TargetType convertType(Type t) const
Attempts a 1-1 type conversion, expecting the result type to be TargetType.
LogicalResult convertSignatureArg(unsigned inputNo, Type type, SignatureConversion &result) const
This method allows for converting a specific argument of a signature.
TypeConverter(const TypeConverter &other)
TypeConverter & operator=(const TypeConverter &other)
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
std::optional< SignatureConversion > convertBlockSignature(Block *block) const
This function converts the type signature of the given block, by invoking 'convertSignatureArg' for e...
void addSourceMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a legal type to an illega...
virtual ~TypeConverter()=default
void addTypeAttributeConversion(FnT &&callback)
Register a conversion function for attributes within types.
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting type from an illegal,...
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
bool isSignatureLegal(FunctionType ty) const
Return true if the inputs and outputs of the given function type are legal.
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:104
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
@ Type
An inlay hint that is for a type annotation.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
void populateFunctionOpInterfaceTypeConversionPattern(StringRef functionLikeOpName, RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to convert the signature of a FunctionOpInterface op with the...
void populateAnyFunctionOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
FailureOr< Operation * > convertOpResultTypes(Operation *op, ValueRange operands, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Generic utility to convert op result types according to type converter without knowing exact op type.
LogicalResult applyAnalysisConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply an analysis conversion on the given operations, and all nested operations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void registerConversionPDLFunctions(RewritePatternSet &patterns)
Register the dialect conversion PDL functions with the given pattern set.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion. Apply a partial conversion on the given operations, and all nested operations.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
Dialect conversion configuration.
RewriterBase::Listener * listener
An optional listener that is notified about all IR modifications in case dialect conversion succeeds.
function_ref< void(Diagnostic &)> notifyCallback
An optional callback used to notify about match failure diagnostics during the conversion.
DenseSet< Operation * > * legalizableOps
Analysis conversion only.
DenseSet< Operation * > * unlegalizedOps
Partial conversion only.
A structure containing additional information describing a specific legal operation instance.
bool isRecursivelyLegal
A flag that indicates if this operation is 'recursively' legal.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
This class acts as a special tag that makes the desire to match any operation that implements a given...
Definition: PatternMatch.h:163
This class acts as a special tag that makes the desire to match any operation that implements a given...
Definition: PatternMatch.h:168
This struct represents a range of new types or a single value that remaps an existing signature input...