MLIR  20.0.0git
DialectConversion.h
Go to the documentation of this file.
1 //===- DialectConversion.h - MLIR dialect conversion pass -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file declares a generic pass for converting between MLIR dialects.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_TRANSFORMS_DIALECTCONVERSION_H_
14 #define MLIR_TRANSFORMS_DIALECTCONVERSION_H_
15 
16 #include "mlir/Config/mlir-config.h"
18 #include "llvm/ADT/MapVector.h"
19 #include "llvm/ADT/StringMap.h"
20 #include <type_traits>
21 
22 namespace mlir {
23 
24 // Forward declarations.
25 class Attribute;
26 class Block;
27 struct ConversionConfig;
28 class ConversionPatternRewriter;
29 class MLIRContext;
30 class Operation;
31 struct OperationConverter;
32 class Type;
33 class Value;
34 
35 //===----------------------------------------------------------------------===//
36 // Type Conversion
37 //===----------------------------------------------------------------------===//
38 
39 /// Type conversion class. Specific conversions and materializations can be
40 /// registered using addConversion and addMaterialization, respectively.
42 public:
43  virtual ~TypeConverter() = default;
44  TypeConverter() = default;
45  // Copy the registered conversions, but not the caches
47  : conversions(other.conversions),
48  argumentMaterializations(other.argumentMaterializations),
49  sourceMaterializations(other.sourceMaterializations),
50  targetMaterializations(other.targetMaterializations),
51  typeAttributeConversions(other.typeAttributeConversions) {}
53  conversions = other.conversions;
54  argumentMaterializations = other.argumentMaterializations;
55  sourceMaterializations = other.sourceMaterializations;
56  targetMaterializations = other.targetMaterializations;
57  typeAttributeConversions = other.typeAttributeConversions;
58  return *this;
59  }
60 
61  /// This class provides all of the information necessary to convert a type
62  /// signature.
64  public:
65  SignatureConversion(unsigned numOrigInputs)
66  : remappedInputs(numOrigInputs) {}
67 
68  /// This struct represents a range of new types or a single value that
69  /// remaps an existing signature input.
70  struct InputMapping {
71  size_t inputNo, size;
73  };
74 
75  /// Return the argument types for the new signature.
76  ArrayRef<Type> getConvertedTypes() const { return argTypes; }
77 
78  /// Get the input mapping for the given argument.
79  std::optional<InputMapping> getInputMapping(unsigned input) const {
80  return remappedInputs[input];
81  }
82 
83  //===------------------------------------------------------------------===//
84  // Conversion Hooks
85  //===------------------------------------------------------------------===//
86 
87  /// Remap an input of the original signature with a new set of types. The
88  /// new types are appended to the new signature conversion.
89  void addInputs(unsigned origInputNo, ArrayRef<Type> types);
90 
91  /// Append new input types to the signature conversion, this should only be
92  /// used if the new types are not intended to remap an existing input.
93  void addInputs(ArrayRef<Type> types);
94 
95  /// Remap an input of the original signature to another `replacement`
96  /// value. This drops the original argument.
97  void remapInput(unsigned origInputNo, Value replacement);
98 
99  private:
100  /// Remap an input of the original signature with a range of types in the
101  /// new signature.
102  void remapInput(unsigned origInputNo, unsigned newInputNo,
103  unsigned newInputCount = 1);
104 
105  /// The remapping information for each of the original arguments.
106  SmallVector<std::optional<InputMapping>, 4> remappedInputs;
107 
108  /// The set of new argument types.
109  SmallVector<Type, 4> argTypes;
110  };
111 
112  /// The general result of a type attribute conversion callback, allowing
113  /// for early termination. The default constructor creates the na case.
115  public:
116  constexpr AttributeConversionResult() : impl() {}
117  AttributeConversionResult(Attribute attr) : impl(attr, resultTag) {}
118 
120  static AttributeConversionResult na();
122 
123  bool hasResult() const;
124  bool isNa() const;
125  bool isAbort() const;
126 
127  Attribute getResult() const;
128 
129  private:
130  AttributeConversionResult(Attribute attr, unsigned tag) : impl(attr, tag) {}
131 
132  llvm::PointerIntPair<Attribute, 2> impl;
133  // Note that na is 0 so that we can use PointerIntPair's default
134  // constructor.
135  static constexpr unsigned naTag = 0;
136  static constexpr unsigned resultTag = 1;
137  static constexpr unsigned abortTag = 2;
138  };
139 
140  /// Register a conversion function. A conversion function must be convertible
141  /// to any of the following forms(where `T` is a class derived from `Type`:
142  /// * std::optional<Type>(T)
143  /// - This form represents a 1-1 type conversion. It should return nullptr
144  /// or `std::nullopt` to signify failure. If `std::nullopt` is returned,
145  /// the converter is allowed to try another conversion function to
146  /// perform the conversion.
147  /// * std::optional<LogicalResult>(T, SmallVectorImpl<Type> &)
148  /// - This form represents a 1-N type conversion. It should return
149  /// `failure` or `std::nullopt` to signify a failed conversion. If the
150  /// new set of types is empty, the type is removed and any usages of the
151  /// existing value are expected to be removed during conversion. If
152  /// `std::nullopt` is returned, the converter is allowed to try another
153  /// conversion function to perform the conversion.
154  /// * std::optional<LogicalResult>(T, SmallVectorImpl<Type> &,
155  /// ArrayRef<Type>)
156  /// - This form represents a 1-N type conversion supporting recursive
157  /// types. The first two arguments and the return value are the same as
158  /// for the regular 1-N form. The third argument is contains is the
159  /// "call stack" of the recursive conversion: it contains the list of
160  /// types currently being converted, with the current type being the
161  /// last one. If it is present more than once in the list, the
162  /// conversion concerns a recursive type.
163  /// Note: When attempting to convert a type, e.g. via 'convertType', the
164  /// mostly recently added conversions will be invoked first.
165  template <typename FnT, typename T = typename llvm::function_traits<
166  std::decay_t<FnT>>::template arg_t<0>>
167  void addConversion(FnT &&callback) {
168  registerConversion(wrapCallback<T>(std::forward<FnT>(callback)));
169  }
170 
171  /// All of the following materializations require function objects that are
172  /// convertible to the following form:
173  /// `std::optional<Value>(OpBuilder &, T, ValueRange, Location)`,
174  /// where `T` is any subclass of `Type`. This function is responsible for
175  /// creating an operation, using the OpBuilder and Location provided, that
176  /// "casts" a range of values into a single value of the given type `T`. It
177  /// must return a Value of the type `T` on success, an `std::nullopt` if
178  /// it failed but other materialization can be attempted, and `nullptr` on
179  /// unrecoverable failure. Materialization functions must be provided when a
180  /// type conversion may persist after the conversion has finished.
181 
182  /// This method registers a materialization that will be called when
183  /// converting (potentially multiple) block arguments that were the result of
184  /// a signature conversion of a single block argument, to a single SSA value
185  /// with the old block argument type.
186  template <typename FnT, typename T = typename llvm::function_traits<
187  std::decay_t<FnT>>::template arg_t<1>>
188  void addArgumentMaterialization(FnT &&callback) {
189  argumentMaterializations.emplace_back(
190  wrapMaterialization<T>(std::forward<FnT>(callback)));
191  }
192 
193  /// This method registers a materialization that will be called when
194  /// converting a legal replacement value back to an illegal source type.
195  /// This is used when some uses of the original, illegal value must persist
196  /// beyond the main conversion.
197  template <typename FnT, typename T = typename llvm::function_traits<
198  std::decay_t<FnT>>::template arg_t<1>>
199  void addSourceMaterialization(FnT &&callback) {
200  sourceMaterializations.emplace_back(
201  wrapMaterialization<T>(std::forward<FnT>(callback)));
202  }
203 
204  /// This method registers a materialization that will be called when
205  /// converting an illegal (source) value to a legal (target) type.
206  template <typename FnT, typename T = typename llvm::function_traits<
207  std::decay_t<FnT>>::template arg_t<1>>
208  void addTargetMaterialization(FnT &&callback) {
209  targetMaterializations.emplace_back(
210  wrapMaterialization<T>(std::forward<FnT>(callback)));
211  }
212 
213  /// Register a conversion function for attributes within types. Type
214  /// converters may call this function in order to allow hoking into the
215  /// translation of attributes that exist within types. For example, a type
216  /// converter for the `memref` type could use these conversions to convert
217  /// memory spaces or layouts in an extensible way.
218  ///
219  /// The conversion functions take a non-null Type or subclass of Type and a
220  /// non-null Attribute (or subclass of Attribute), and returns a
221  /// `AttributeConversionResult`. This result can either contan an `Attribute`,
222  /// which may be `nullptr`, representing the conversion's success,
223  /// `AttributeConversionResult::na()` (the default empty value), indicating
224  /// that the conversion function did not apply and that further conversion
225  /// functions should be checked, or `AttributeConversionResult::abort()`
226  /// indicating that the conversion process should be aborted.
227  ///
228  /// Registered conversion functions are callled in the reverse of the order in
229  /// which they were registered.
230  template <
231  typename FnT,
232  typename T =
233  typename llvm::function_traits<std::decay_t<FnT>>::template arg_t<0>,
234  typename A =
235  typename llvm::function_traits<std::decay_t<FnT>>::template arg_t<1>>
236  void addTypeAttributeConversion(FnT &&callback) {
237  registerTypeAttributeConversion(
238  wrapTypeAttributeConversion<T, A>(std::forward<FnT>(callback)));
239  }
240 
241  /// Convert the given type. This function should return failure if no valid
242  /// conversion exists, success otherwise. If the new set of types is empty,
243  /// the type is removed and any usages of the existing value are expected to
244  /// be removed during conversion.
245  LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) const;
246 
247  /// This hook simplifies defining 1-1 type conversions. This function returns
248  /// the type to convert to on success, and a null type on failure.
249  Type convertType(Type t) const;
250 
251  /// Attempts a 1-1 type conversion, expecting the result type to be
252  /// `TargetType`. Returns the converted type cast to `TargetType` on success,
253  /// and a null type on conversion or cast failure.
254  template <typename TargetType>
255  TargetType convertType(Type t) const {
256  return dyn_cast_or_null<TargetType>(convertType(t));
257  }
258 
259  /// Convert the given set of types, filling 'results' as necessary. This
260  /// returns failure if the conversion of any of the types fails, success
261  /// otherwise.
262  LogicalResult convertTypes(TypeRange types,
263  SmallVectorImpl<Type> &results) const;
264 
265  /// Return true if the given type is legal for this type converter, i.e. the
266  /// type converts to itself.
267  bool isLegal(Type type) const;
268 
269  /// Return true if all of the given types are legal for this type converter.
270  template <typename RangeT>
271  std::enable_if_t<!std::is_convertible<RangeT, Type>::value &&
272  !std::is_convertible<RangeT, Operation *>::value,
273  bool>
274  isLegal(RangeT &&range) const {
275  return llvm::all_of(range, [this](Type type) { return isLegal(type); });
276  }
277  /// Return true if the given operation has legal operand and result types.
278  bool isLegal(Operation *op) const;
279 
280  /// Return true if the types of block arguments within the region are legal.
281  bool isLegal(Region *region) const;
282 
283  /// Return true if the inputs and outputs of the given function type are
284  /// legal.
285  bool isSignatureLegal(FunctionType ty) const;
286 
287  /// This method allows for converting a specific argument of a signature. It
288  /// takes as inputs the original argument input number, type.
289  /// On success, it populates 'result' with any new mappings.
290  LogicalResult convertSignatureArg(unsigned inputNo, Type type,
291  SignatureConversion &result) const;
292  LogicalResult convertSignatureArgs(TypeRange types,
293  SignatureConversion &result,
294  unsigned origInputOffset = 0) const;
295 
296  /// This function converts the type signature of the given block, by invoking
297  /// 'convertSignatureArg' for each argument. This function should return a
298  /// valid conversion for the signature on success, std::nullopt otherwise.
299  std::optional<SignatureConversion> convertBlockSignature(Block *block) const;
300 
301  /// Materialize a conversion from a set of types into one result type by
302  /// generating a cast sequence of some kind. See the respective
303  /// `add*Materialization` for more information on the context for these
304  /// methods.
306  Type resultType,
307  ValueRange inputs) const {
308  return materializeConversion(argumentMaterializations, builder, loc,
309  resultType, inputs);
310  }
312  Type resultType, ValueRange inputs) const {
313  return materializeConversion(sourceMaterializations, builder, loc,
314  resultType, inputs);
315  }
317  Type resultType, ValueRange inputs) const {
318  return materializeConversion(targetMaterializations, builder, loc,
319  resultType, inputs);
320  }
321 
322  /// Convert an attribute present `attr` from within the type `type` using
323  /// the registered conversion functions. If no applicable conversion has been
324  /// registered, return std::nullopt. Note that the empty attribute/`nullptr`
325  /// is a valid return value for this function.
326  std::optional<Attribute> convertTypeAttribute(Type type,
327  Attribute attr) const;
328 
329 private:
330  /// The signature of the callback used to convert a type. If the new set of
331  /// types is empty, the type is removed and any usages of the existing value
332  /// are expected to be removed during conversion.
333  using ConversionCallbackFn = std::function<std::optional<LogicalResult>(
335 
336  /// The signature of the callback used to materialize a conversion.
337  using MaterializationCallbackFn = std::function<std::optional<Value>(
339 
340  /// The signature of the callback used to convert a type attribute.
341  using TypeAttributeConversionCallbackFn =
342  std::function<AttributeConversionResult(Type, Attribute)>;
343 
344  /// Attempt to materialize a conversion using one of the provided
345  /// materialization functions.
346  Value
347  materializeConversion(ArrayRef<MaterializationCallbackFn> materializations,
348  OpBuilder &builder, Location loc, Type resultType,
349  ValueRange inputs) const;
350 
351  /// Generate a wrapper for the given callback. This allows for accepting
352  /// different callback forms, that all compose into a single version.
353  /// With callback of form: `std::optional<Type>(T)`
354  template <typename T, typename FnT>
355  std::enable_if_t<std::is_invocable_v<FnT, T>, ConversionCallbackFn>
356  wrapCallback(FnT &&callback) const {
357  return wrapCallback<T>([callback = std::forward<FnT>(callback)](
358  T type, SmallVectorImpl<Type> &results) {
359  if (std::optional<Type> resultOpt = callback(type)) {
360  bool wasSuccess = static_cast<bool>(*resultOpt);
361  if (wasSuccess)
362  results.push_back(*resultOpt);
363  return std::optional<LogicalResult>(success(wasSuccess));
364  }
365  return std::optional<LogicalResult>();
366  });
367  }
368  /// With callback of form: `std::optional<LogicalResult>(
369  /// T, SmallVectorImpl<Type> &, ArrayRef<Type>)`.
370  template <typename T, typename FnT>
371  std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &>,
372  ConversionCallbackFn>
373  wrapCallback(FnT &&callback) const {
374  return [callback = std::forward<FnT>(callback)](
375  Type type,
376  SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
377  T derivedType = dyn_cast<T>(type);
378  if (!derivedType)
379  return std::nullopt;
380  return callback(derivedType, results);
381  };
382  }
383 
384  /// Register a type conversion.
385  void registerConversion(ConversionCallbackFn callback) {
386  conversions.emplace_back(std::move(callback));
387  cachedDirectConversions.clear();
388  cachedMultiConversions.clear();
389  }
390 
391  /// Generate a wrapper for the given materialization callback. The callback
392  /// may take any subclass of `Type` and the wrapper will check for the target
393  /// type to be of the expected class before calling the callback.
394  template <typename T, typename FnT>
395  MaterializationCallbackFn wrapMaterialization(FnT &&callback) const {
396  return [callback = std::forward<FnT>(callback)](
397  OpBuilder &builder, Type resultType, ValueRange inputs,
398  Location loc) -> std::optional<Value> {
399  if (T derivedType = dyn_cast<T>(resultType))
400  return callback(builder, derivedType, inputs, loc);
401  return std::nullopt;
402  };
403  }
404 
405  /// Generate a wrapper for the given memory space conversion callback. The
406  /// callback may take any subclass of `Attribute` and the wrapper will check
407  /// for the target attribute to be of the expected class before calling the
408  /// callback.
409  template <typename T, typename A, typename FnT>
410  TypeAttributeConversionCallbackFn
411  wrapTypeAttributeConversion(FnT &&callback) const {
412  return [callback = std::forward<FnT>(callback)](
413  Type type, Attribute attr) -> AttributeConversionResult {
414  if (T derivedType = dyn_cast<T>(type)) {
415  if (A derivedAttr = dyn_cast_or_null<A>(attr))
416  return callback(derivedType, derivedAttr);
417  }
419  };
420  }
421 
422  /// Register a memory space conversion, clearing caches.
423  void
424  registerTypeAttributeConversion(TypeAttributeConversionCallbackFn callback) {
425  typeAttributeConversions.emplace_back(std::move(callback));
426  // Clear type conversions in case a memory space is lingering inside.
427  cachedDirectConversions.clear();
428  cachedMultiConversions.clear();
429  }
430 
431  /// The set of registered conversion functions.
432  SmallVector<ConversionCallbackFn, 4> conversions;
433 
434  /// The list of registered materialization functions.
435  SmallVector<MaterializationCallbackFn, 2> argumentMaterializations;
436  SmallVector<MaterializationCallbackFn, 2> sourceMaterializations;
437  SmallVector<MaterializationCallbackFn, 2> targetMaterializations;
438 
439  /// The list of registered type attribute conversion functions.
440  SmallVector<TypeAttributeConversionCallbackFn, 2> typeAttributeConversions;
441 
442  /// A set of cached conversions to avoid recomputing in the common case.
443  /// Direct 1-1 conversions are the most common, so this cache stores the
444  /// successful 1-1 conversions as well as all failed conversions.
445  mutable DenseMap<Type, Type> cachedDirectConversions;
446  /// This cache stores the successful 1->N conversions, where N != 1.
447  mutable DenseMap<Type, SmallVector<Type, 2>> cachedMultiConversions;
448  /// A mutex used for cache access
449  mutable llvm::sys::SmartRWMutex<true> cacheMutex;
450 };
451 
452 //===----------------------------------------------------------------------===//
453 // Conversion Patterns
454 //===----------------------------------------------------------------------===//
455 
456 /// Base class for the conversion patterns. This pattern class enables type
457 /// conversions, and other uses specific to the conversion framework. As such,
458 /// patterns of this type can only be used with the 'apply*' methods below.
460 public:
461  /// Hook for derived classes to implement rewriting. `op` is the (first)
462  /// operation matched by the pattern, `operands` is a list of the rewritten
463  /// operand values that are passed to `op`, `rewriter` can be used to emit the
464  /// new operations. This function should not fail. If some specific cases of
465  /// the operation are not supported, these cases should not be matched.
466  virtual void rewrite(Operation *op, ArrayRef<Value> operands,
467  ConversionPatternRewriter &rewriter) const {
468  llvm_unreachable("unimplemented rewrite");
469  }
470 
471  /// Hook for derived classes to implement combined matching and rewriting.
472  virtual LogicalResult
474  ConversionPatternRewriter &rewriter) const {
475  if (failed(match(op)))
476  return failure();
477  rewrite(op, operands, rewriter);
478  return success();
479  }
480 
481  /// Attempt to match and rewrite the IR root at the specified operation.
482  LogicalResult matchAndRewrite(Operation *op,
483  PatternRewriter &rewriter) const final;
484 
485  /// Return the type converter held by this pattern, or nullptr if the pattern
486  /// does not require type conversion.
487  const TypeConverter *getTypeConverter() const { return typeConverter; }
488 
489  template <typename ConverterTy>
490  std::enable_if_t<std::is_base_of<TypeConverter, ConverterTy>::value,
491  const ConverterTy *>
493  return static_cast<const ConverterTy *>(typeConverter);
494  }
495 
496 protected:
497  /// See `RewritePattern::RewritePattern` for information on the other
498  /// available constructors.
499  using RewritePattern::RewritePattern;
500  /// Construct a conversion pattern with the given converter, and forward the
501  /// remaining arguments to RewritePattern.
502  template <typename... Args>
504  : RewritePattern(std::forward<Args>(args)...),
506 
507 protected:
508  /// An optional type converter for use by this pattern.
509  const TypeConverter *typeConverter = nullptr;
510 
511 private:
513 };
514 
515 /// OpConversionPattern is a wrapper around ConversionPattern that allows for
516 /// matching and rewriting against an instance of a derived operation class as
517 /// opposed to a raw Operation.
518 template <typename SourceOp>
520 public:
521  using OpAdaptor = typename SourceOp::Adaptor;
522 
524  : ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
526  PatternBenefit benefit = 1)
527  : ConversionPattern(typeConverter, SourceOp::getOperationName(), benefit,
528  context) {}
529 
530  /// Wrappers around the ConversionPattern methods that pass the derived op
531  /// type.
532  LogicalResult match(Operation *op) const final {
533  return match(cast<SourceOp>(op));
534  }
535  void rewrite(Operation *op, ArrayRef<Value> operands,
536  ConversionPatternRewriter &rewriter) const final {
537  auto sourceOp = cast<SourceOp>(op);
538  rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
539  }
540  LogicalResult
542  ConversionPatternRewriter &rewriter) const final {
543  auto sourceOp = cast<SourceOp>(op);
544  return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
545  }
546 
547  /// Rewrite and Match methods that operate on the SourceOp type. These must be
548  /// overridden by the derived pattern class.
549  virtual LogicalResult match(SourceOp op) const {
550  llvm_unreachable("must override match or matchAndRewrite");
551  }
552  virtual void rewrite(SourceOp op, OpAdaptor adaptor,
553  ConversionPatternRewriter &rewriter) const {
554  llvm_unreachable("must override matchAndRewrite or a rewrite method");
555  }
556  virtual LogicalResult
557  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
558  ConversionPatternRewriter &rewriter) const {
559  if (failed(match(op)))
560  return failure();
561  rewrite(op, adaptor, rewriter);
562  return success();
563  }
564 
565 private:
567 };
568 
569 /// OpInterfaceConversionPattern is a wrapper around ConversionPattern that
570 /// allows for matching and rewriting against an instance of an OpInterface
571 /// class as opposed to a raw Operation.
572 template <typename SourceOp>
574 public:
577  SourceOp::getInterfaceID(), benefit, context) {}
579  MLIRContext *context, PatternBenefit benefit = 1)
581  SourceOp::getInterfaceID(), benefit, context) {}
582 
583  /// Wrappers around the ConversionPattern methods that pass the derived op
584  /// type.
585  void rewrite(Operation *op, ArrayRef<Value> operands,
586  ConversionPatternRewriter &rewriter) const final {
587  rewrite(cast<SourceOp>(op), operands, rewriter);
588  }
589  LogicalResult
591  ConversionPatternRewriter &rewriter) const final {
592  return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
593  }
594 
595  /// Rewrite and Match methods that operate on the SourceOp type. These must be
596  /// overridden by the derived pattern class.
597  virtual void rewrite(SourceOp op, ArrayRef<Value> operands,
598  ConversionPatternRewriter &rewriter) const {
599  llvm_unreachable("must override matchAndRewrite or a rewrite method");
600  }
601  virtual LogicalResult
602  matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
603  ConversionPatternRewriter &rewriter) const {
604  if (failed(match(op)))
605  return failure();
606  rewrite(op, operands, rewriter);
607  return success();
608  }
609 
610 private:
612 };
613 
614 /// OpTraitConversionPattern is a wrapper around ConversionPattern that allows
615 /// for matching and rewriting against instances of an operation that possess a
616 /// given trait.
617 template <template <typename> class TraitType>
619 public:
622  TypeID::get<TraitType>(), benefit, context) {}
624  MLIRContext *context, PatternBenefit benefit = 1)
626  TypeID::get<TraitType>(), benefit, context) {}
627 };
628 
629 /// Generic utility to convert op result types according to type converter
630 /// without knowing exact op type.
631 /// Clones existing op with new result types and returns it.
632 FailureOr<Operation *>
633 convertOpResultTypes(Operation *op, ValueRange operands,
634  const TypeConverter &converter,
635  ConversionPatternRewriter &rewriter);
636 
637 /// Add a pattern to the given pattern list to convert the signature of a
638 /// FunctionOpInterface op with the given type converter. This only supports
639 /// ops which use FunctionType to represent their type.
641  StringRef functionLikeOpName, RewritePatternSet &patterns,
642  const TypeConverter &converter);
643 
644 template <typename FuncOpT>
646  RewritePatternSet &patterns, const TypeConverter &converter) {
647  populateFunctionOpInterfaceTypeConversionPattern(FuncOpT::getOperationName(),
648  patterns, converter);
649 }
650 
652  RewritePatternSet &patterns, const TypeConverter &converter);
653 
654 //===----------------------------------------------------------------------===//
655 // Conversion PatternRewriter
656 //===----------------------------------------------------------------------===//
657 
658 namespace detail {
659 struct ConversionPatternRewriterImpl;
660 } // namespace detail
661 
662 /// This class implements a pattern rewriter for use with ConversionPatterns. It
663 /// extends the base PatternRewriter and provides special conversion specific
664 /// hooks.
666 public:
668 
669  /// Apply a signature conversion to given block. This replaces the block with
670  /// a new block containing the updated signature. The operations of the given
671  /// block are inlined into the newly-created block, which is returned.
672  ///
673  /// If no block argument types are changing, the original block will be
674  /// left in place and returned.
675  ///
676  /// A signature conversion must be provided. (Type converters can construct
677  /// a signature conversion with `convertBlockSignature`.)
678  ///
679  /// Optionally, a type converter can be provided to build materializations.
680  /// Note: If no type converter was provided or the type converter does not
681  /// specify any suitable argument/target materialization rules, the dialect
682  /// conversion may fail to legalize unresolved materializations.
683  Block *
686  const TypeConverter *converter = nullptr);
687 
688  /// Apply a signature conversion to each block in the given region. This
689  /// replaces each block with a new block containing the updated signature. If
690  /// an updated signature would match the current signature, the respective
691  /// block is left in place as is. (See `applySignatureConversion` for
692  /// details.) The new entry block of the region is returned.
693  ///
694  /// SignatureConversions are computed with the specified type converter.
695  /// This function returns "failure" if the type converter failed to compute
696  /// a SignatureConversion for at least one block.
697  ///
698  /// Optionally, a special SignatureConversion can be specified for the entry
699  /// block. This is because the types of the entry block arguments are often
700  /// tied semantically to the operation.
701  FailureOr<Block *> convertRegionTypes(
702  Region *region, const TypeConverter &converter,
703  TypeConverter::SignatureConversion *entryConversion = nullptr);
704 
705  /// Replace all the uses of the block argument `from` with value `to`.
707 
708  /// Return the converted value of 'key' with a type defined by the type
709  /// converter of the currently executing pattern. Return nullptr in the case
710  /// of failure, the remapped value otherwise.
712 
713  /// Return the converted values that replace 'keys' with types defined by the
714  /// type converter of the currently executing pattern. Returns failure if the
715  /// remap failed, success otherwise.
716  LogicalResult getRemappedValues(ValueRange keys,
717  SmallVectorImpl<Value> &results);
718 
719  //===--------------------------------------------------------------------===//
720  // PatternRewriter Hooks
721  //===--------------------------------------------------------------------===//
722 
723  /// Indicate that the conversion rewriter can recover from rewrite failure.
724  /// Recovery is supported via rollback, allowing for continued processing of
725  /// patterns even if a failure is encountered during the rewrite step.
726  bool canRecoverFromRewriteFailure() const override { return true; }
727 
728  /// PatternRewriter hook for replacing an operation.
729  void replaceOp(Operation *op, ValueRange newValues) override;
730 
731  /// PatternRewriter hook for replacing an operation.
732  void replaceOp(Operation *op, Operation *newOp) override;
733 
734  /// PatternRewriter hook for erasing a dead operation. The uses of this
735  /// operation *must* be made dead by the end of the conversion process,
736  /// otherwise an assert will be issued.
737  void eraseOp(Operation *op) override;
738 
739  /// PatternRewriter hook for erase all operations in a block. This is not yet
740  /// implemented for dialect conversion.
741  void eraseBlock(Block *block) override;
742 
743  /// PatternRewriter hook for inlining the ops of a block into another block.
744  void inlineBlockBefore(Block *source, Block *dest, Block::iterator before,
745  ValueRange argValues = std::nullopt) override;
747 
748  /// PatternRewriter hook for updating the given operation in-place.
749  /// Note: These methods only track updates to the given operation itself,
750  /// and not nested regions. Updates to regions will still require notification
751  /// through other more specific hooks above.
752  void startOpModification(Operation *op) override;
753 
754  /// PatternRewriter hook for updating the given operation in-place.
755  void finalizeOpModification(Operation *op) override;
756 
757  /// PatternRewriter hook for updating the given operation in-place.
758  void cancelOpModification(Operation *op) override;
759 
760  /// Return a reference to the internal implementation.
762 
763 private:
764  // Allow OperationConverter to construct new rewriters.
765  friend struct OperationConverter;
766 
767  /// Conversion pattern rewriters must not be used outside of dialect
768  /// conversions. They apply some IR rewrites in a delayed fashion and could
769  /// bring the IR into an inconsistent state when used standalone.
771  const ConversionConfig &config);
772 
773  // Hide unsupported pattern rewriter API.
775 
776  std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
777 };
778 
779 //===----------------------------------------------------------------------===//
780 // ConversionTarget
781 //===----------------------------------------------------------------------===//
782 
783 /// This class describes a specific conversion target.
785 public:
786  /// This enumeration corresponds to the specific action to take when
787  /// considering an operation legal for this conversion target.
788  enum class LegalizationAction {
789  /// The target supports this operation.
790  Legal,
791 
792  /// This operation has dynamic legalization constraints that must be checked
793  /// by the target.
794  Dynamic,
795 
796  /// The target explicitly does not support this operation.
797  Illegal,
798  };
799 
800  /// A structure containing additional information describing a specific legal
801  /// operation instance.
802  struct LegalOpDetails {
803  /// A flag that indicates if this operation is 'recursively' legal. This
804  /// means that if an operation is legal, either statically or dynamically,
805  /// all of the operations nested within are also considered legal.
806  bool isRecursivelyLegal = false;
807  };
808 
809  /// The signature of the callback used to determine if an operation is
810  /// dynamically legal on the target.
812  std::function<std::optional<bool>(Operation *)>;
813 
814  ConversionTarget(MLIRContext &ctx) : ctx(ctx) {}
815  virtual ~ConversionTarget() = default;
816 
817  //===--------------------------------------------------------------------===//
818  // Legality Registration
819  //===--------------------------------------------------------------------===//
820 
821  /// Register a legality action for the given operation.
823  template <typename OpT>
825  setOpAction(OperationName(OpT::getOperationName(), &ctx), action);
826  }
827 
828  /// Register the given operations as legal.
831  }
832  template <typename OpT>
833  void addLegalOp() {
834  addLegalOp(OperationName(OpT::getOperationName(), &ctx));
835  }
836  template <typename OpT, typename OpT2, typename... OpTs>
837  void addLegalOp() {
838  addLegalOp<OpT>();
839  addLegalOp<OpT2, OpTs...>();
840  }
841 
842  /// Register the given operation as dynamically legal and set the dynamic
843  /// legalization callback to the one provided.
845  const DynamicLegalityCallbackFn &callback) {
847  setLegalityCallback(op, callback);
848  }
849  template <typename OpT>
851  addDynamicallyLegalOp(OperationName(OpT::getOperationName(), &ctx),
852  callback);
853  }
854  template <typename OpT, typename OpT2, typename... OpTs>
856  addDynamicallyLegalOp<OpT>(callback);
857  addDynamicallyLegalOp<OpT2, OpTs...>(callback);
858  }
859  template <typename OpT, class Callable>
860  std::enable_if_t<!std::is_invocable_v<Callable, Operation *>>
861  addDynamicallyLegalOp(Callable &&callback) {
862  addDynamicallyLegalOp<OpT>(
863  [=](Operation *op) { return callback(cast<OpT>(op)); });
864  }
865 
866  /// Register the given operation as illegal, i.e. this operation is known to
867  /// not be supported by this target.
870  }
871  template <typename OpT>
872  void addIllegalOp() {
873  addIllegalOp(OperationName(OpT::getOperationName(), &ctx));
874  }
875  template <typename OpT, typename OpT2, typename... OpTs>
876  void addIllegalOp() {
877  addIllegalOp<OpT>();
878  addIllegalOp<OpT2, OpTs...>();
879  }
880 
881  /// Mark an operation, that *must* have either been set as `Legal` or
882  /// `DynamicallyLegal`, as being recursively legal. This means that in
883  /// addition to the operation itself, all of the operations nested within are
884  /// also considered legal. An optional dynamic legality callback may be
885  /// provided to mark subsets of legal instances as recursively legal.
887  const DynamicLegalityCallbackFn &callback);
888  template <typename OpT>
890  OperationName opName(OpT::getOperationName(), &ctx);
891  markOpRecursivelyLegal(opName, callback);
892  }
893  template <typename OpT, typename OpT2, typename... OpTs>
895  markOpRecursivelyLegal<OpT>(callback);
896  markOpRecursivelyLegal<OpT2, OpTs...>(callback);
897  }
898  template <typename OpT, class Callable>
899  std::enable_if_t<!std::is_invocable_v<Callable, Operation *>>
900  markOpRecursivelyLegal(Callable &&callback) {
901  markOpRecursivelyLegal<OpT>(
902  [=](Operation *op) { return callback(cast<OpT>(op)); });
903  }
904 
905  /// Register a legality action for the given dialects.
906  void setDialectAction(ArrayRef<StringRef> dialectNames,
907  LegalizationAction action);
908 
909  /// Register the operations of the given dialects as legal.
910  template <typename... Names>
911  void addLegalDialect(StringRef name, Names... names) {
912  SmallVector<StringRef, 2> dialectNames({name, names...});
914  }
915  template <typename... Args>
917  SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
919  }
920 
921  /// Register the operations of the given dialects as dynamically legal, i.e.
922  /// requiring custom handling by the callback.
923  template <typename... Names>
925  StringRef name, Names... names) {
926  SmallVector<StringRef, 2> dialectNames({name, names...});
928  setLegalityCallback(dialectNames, callback);
929  }
930  template <typename... Args>
932  addDynamicallyLegalDialect(std::move(callback),
933  Args::getDialectNamespace()...);
934  }
935 
936  /// Register unknown operations as dynamically legal. For operations (and
937  /// dialects) that do not have a set legalization action, treat them as
938  /// dynamically legal and invoke the given callback.
940  setLegalityCallback(fn);
941  }
942 
943  /// Register the operations of the given dialects as illegal, i.e.
944  /// operations of this dialect are not supported by the target.
945  template <typename... Names>
946  void addIllegalDialect(StringRef name, Names... names) {
947  SmallVector<StringRef, 2> dialectNames({name, names...});
949  }
950  template <typename... Args>
952  SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
954  }
955 
956  //===--------------------------------------------------------------------===//
957  // Legality Querying
958  //===--------------------------------------------------------------------===//
959 
960  /// Get the legality action for the given operation.
961  std::optional<LegalizationAction> getOpAction(OperationName op) const;
962 
963  /// If the given operation instance is legal on this target, a structure
964  /// containing legality information is returned. If the operation is not
965  /// legal, std::nullopt is returned. Also returns std::nullopt if the
966  /// operation's legality wasn't registered by the user or if the dynamic
967  /// legality callbacks returned std::nullopt.
968  ///
969  /// Note: Legality is actually a 4-state: Legal(recursive=true),
970  /// Legal(recursive=false), Illegal or Unknown, where Unknown is treated
971  /// either as Legal or Illegal depending on context.
972  std::optional<LegalOpDetails> isLegal(Operation *op) const;
973 
974  /// Returns true if the operation instance is illegal on this target.
975  /// Returns false if the operation is legal, if its legality wasn't registered
976  /// by the user, or if the dynamic legality callbacks returned std::nullopt.
977  bool isIllegal(Operation *op) const;
978 
979 private:
980  /// Set the dynamic legality callback for the given operation.
981  void setLegalityCallback(OperationName name,
982  const DynamicLegalityCallbackFn &callback);
983 
984  /// Set the dynamic legality callback for the given dialects.
985  void setLegalityCallback(ArrayRef<StringRef> dialects,
986  const DynamicLegalityCallbackFn &callback);
987 
988  /// Set the dynamic legality callback for the unknown ops.
989  void setLegalityCallback(const DynamicLegalityCallbackFn &callback);
990 
991  /// The set of information that configures the legalization of an operation.
992  struct LegalizationInfo {
993  /// The legality action this operation was given.
995 
996  /// If some legal instances of this operation may also be recursively legal.
997  bool isRecursivelyLegal = false;
998 
999  /// The legality callback if this operation is dynamically legal.
1000  DynamicLegalityCallbackFn legalityFn;
1001  };
1002 
1003  /// Get the legalization information for the given operation.
1004  std::optional<LegalizationInfo> getOpInfo(OperationName op) const;
1005 
1006  /// A deterministic mapping of operation name and its respective legality
1007  /// information.
1008  llvm::MapVector<OperationName, LegalizationInfo> legalOperations;
1009 
1010  /// A set of legality callbacks for given operation names that are used to
1011  /// check if an operation instance is recursively legal.
1012  DenseMap<OperationName, DynamicLegalityCallbackFn> opRecursiveLegalityFns;
1013 
1014  /// A deterministic mapping of dialect name to the specific legality action to
1015  /// take.
1016  llvm::StringMap<LegalizationAction> legalDialects;
1017 
1018  /// A set of dynamic legality callbacks for given dialect names.
1019  llvm::StringMap<DynamicLegalityCallbackFn> dialectLegalityFns;
1020 
1021  /// An optional legality callback for unknown operations.
1022  DynamicLegalityCallbackFn unknownLegalityFn;
1023 
1024  /// The current context this target applies to.
1025  MLIRContext &ctx;
1026 };
1027 
1028 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
1029 //===----------------------------------------------------------------------===//
1030 // PDL Configuration
1031 //===----------------------------------------------------------------------===//
1032 
1033 /// A PDL configuration that is used to support dialect conversion
1034 /// functionality.
1036  : public PDLPatternConfigBase<PDLConversionConfig> {
1037 public:
1038  PDLConversionConfig(const TypeConverter *converter) : converter(converter) {}
1039  ~PDLConversionConfig() final = default;
1040 
1041  /// Return the type converter used by this configuration, which may be nullptr
1042  /// if no type conversions are expected.
1043  const TypeConverter *getTypeConverter() const { return converter; }
1044 
1045  /// Hooks that are invoked at the beginning and end of a rewrite of a matched
1046  /// pattern.
1047  void notifyRewriteBegin(PatternRewriter &rewriter) final;
1048  void notifyRewriteEnd(PatternRewriter &rewriter) final;
1049 
1050 private:
1051  /// An optional type converter to use for the pattern.
1052  const TypeConverter *converter;
1053 };
1054 
1055 /// Register the dialect conversion PDL functions with the given pattern set.
1056 void registerConversionPDLFunctions(RewritePatternSet &patterns);
1057 
1058 #else
1059 
1060 // Stubs for when PDL in rewriting is not enabled.
1061 
1062 inline void registerConversionPDLFunctions(RewritePatternSet &patterns) {} // No-op stub for builds without PDL support: nothing is registered.
1063 
1064 class PDLConversionConfig final { // Stateless placeholder used when MLIR_ENABLE_PDL_IN_PATTERNMATCH is off.
1065 public:
1066  PDLConversionConfig(const TypeConverter * /*converter*/) {} // Mirrors the PDL-enabled constructor signature; the converter is ignored.
1067 };
1068 
1069 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
1070 
1071 //===----------------------------------------------------------------------===//
1072 // ConversionConfig
1073 //===----------------------------------------------------------------------===//
1074 
1075 /// Dialect conversion configuration.
1077  /// An optional callback used to notify about match failure diagnostics during
1078  /// the conversion. Diagnostics reported to this callback may only be
1079  /// available in debug mode.
1081 
1082  /// Partial conversion only. All operations that are found not to be
1083  /// legalizable are placed in this set. (Note that if there is an op
1084  /// explicitly marked as illegal, the conversion terminates and the set will
1085  /// not necessarily be complete.)
1087 
1088  /// Analysis conversion only. All operations that are found to be legalizable
1089  /// are placed in this set. Note that no actual rewrites are applied to the
1090  /// IR during an analysis conversion and only pre-existing operations are
1091  /// added to the set.
1093 
1094  /// An optional listener that is notified about all IR modifications in case
1095  /// dialect conversion succeeds. If the dialect conversion fails and no IR
1096  /// modifications are visible (i.e., they were all rolled back), or if the
1097  /// dialect conversion is an "analysis conversion", no notifications are
1098  /// sent (apart from `notifyPatternBegin`/`notifyPatternEnd`).
1099  ///
1100  /// Note: Notifications are sent in a delayed fashion, when the dialect
1101  /// conversion is guaranteed to succeed. At that point, some IR modifications
1102  /// may already have been materialized. Consequently, operations/blocks that
1103  /// are passed to listener callbacks should not be accessed. (Ops/blocks are
1104  /// guaranteed to be valid pointers and accessing op names is allowed. But
1105  /// there are no guarantees about the state of ops/blocks at the time that a
1106  /// callback is triggered.)
1107  ///
1108  /// Example: Consider a dialect conversion in which a new op ("test.foo") is
1109  /// created and inserted, and later moved to another block. (Moving ops also
1110  /// triggers "notifyOperationInserted".)
1111  ///
1112  /// (1) notifyOperationInserted: "test.foo" (into block "b1")
1113  /// (2) notifyOperationInserted: "test.foo" (moved to another block "b2")
1114  ///
1115  /// When querying "op->getBlock()" during the first "notifyOperationInserted",
1116  /// "b2" would be returned because "moving an op" is a kind of rewrite that is
1117  /// immediately performed by the dialect conversion (and rolled back upon
1118  /// failure).
1119  //
1120  // Note: When receiving a "notifyBlockInserted"/"notifyOperationInserted"
1121  // callback, the previous region/block is provided to the callback, but not
1122  // the iterator pointing to the exact location within the region/block. That
1123  // is because these notifications are sent with a delay (after the IR has
1124  // already been modified) and iterators into past IR state cannot be
1125  // represented at the moment.
1127 
1128  /// If set to "true", the dialect conversion attempts to build source/target/
1129  /// argument materializations through the type converter API in lieu of
1130  /// "builtin.unrealized_conversion_cast" ops. The conversion process fails if
1131  /// at least one materialization could not be built.
1132  ///
1133  /// If set to "false", the dialect conversion does not build any custom
1134  /// materializations and instead inserts "builtin.unrealized_conversion_cast"
1135  /// ops to ensure that the resulting IR is valid.
1137 };
1138 
1139 //===----------------------------------------------------------------------===//
1140 // Reconcile Unrealized Casts
1141 //===----------------------------------------------------------------------===//
1142 
1143 /// Try to reconcile all given UnrealizedConversionCastOps and store the
1144 /// left-over ops in `remainingCastOps` (if provided).
1145 ///
1146 /// This function processes cast ops in a worklist-driven fashion. For each
1147 /// cast op, if the chain of input casts eventually reaches a cast op where the
1148 /// input types match the output types of the matched op, replace the matched
1149 /// op with the inputs.
1150 ///
1151 /// Example:
1152 /// %1 = unrealized_conversion_cast %0 : !A to !B
1153 /// %2 = unrealized_conversion_cast %1 : !B to !C
1154 /// %3 = unrealized_conversion_cast %2 : !C to !A
1155 ///
1156 /// In the above example, %0 can be used instead of %3 and all cast ops are
1157 /// folded away.
1160  SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps = nullptr);
1161 
1162 //===----------------------------------------------------------------------===//
1163 // Op Conversion Entry Points
1164 //===----------------------------------------------------------------------===//
1165 
1166 /// Below we define several entry points for operation conversion. It is
1167 /// important to note that the patterns provided to the conversion framework may
1168 /// have additional constraints. See the `PatternRewriter Hooks` section of the
1169 /// ConversionPatternRewriter, to see what additional constraints are imposed on
1170 /// the use of the PatternRewriter.
1171 
1172 /// Apply a partial conversion on the given operations and all nested
1173 /// operations. This method converts as many operations to the target as
1174 /// possible, ignoring operations that failed to legalize. This method only
1175 /// returns failure if there are ops explicitly marked as illegal.
1176 LogicalResult
1178  const ConversionTarget &target,
1179  const FrozenRewritePatternSet &patterns,
1180  ConversionConfig config = ConversionConfig());
1181 LogicalResult
1183  const FrozenRewritePatternSet &patterns,
1184  ConversionConfig config = ConversionConfig());
1185 
1186 /// Apply a complete conversion on the given operations, and all nested
1187 /// operations. This method returns failure if the conversion of any operation
1188 /// fails, or if there are unreachable blocks in any of the regions nested
1189 /// within 'ops'.
1190 LogicalResult applyFullConversion(ArrayRef<Operation *> ops,
1191  const ConversionTarget &target,
1192  const FrozenRewritePatternSet &patterns,
1193  ConversionConfig config = ConversionConfig());
1194 LogicalResult applyFullConversion(Operation *op, const ConversionTarget &target,
1195  const FrozenRewritePatternSet &patterns,
1196  ConversionConfig config = ConversionConfig());
1197 
1198 /// Apply an analysis conversion on the given operations, and all nested
1199 /// operations. This method analyzes which operations would be successfully
1200 /// converted to the target if a conversion was applied. All operations that
1201 /// were found to be legalizable to the given 'target' are placed within the
1202 /// provided 'config.legalizableOps' set; note that no actual rewrites are
1203 /// applied to the operations on success. This method only returns failure if
1204 /// there are unreachable blocks in any of the regions nested within 'ops'.
1205 LogicalResult
1207  const FrozenRewritePatternSet &patterns,
1208  ConversionConfig config = ConversionConfig());
1209 LogicalResult
1211  const FrozenRewritePatternSet &patterns,
1212  ConversionConfig config = ConversionConfig());
1213 } // namespace mlir
1214 
1215 #endif // MLIR_TRANSFORMS_DIALECTCONVERSION_H_
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
Block represents an ordered list of Operations.
Definition: Block.h:31
OpListType::iterator iterator
Definition: Block.h:138
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
LogicalResult getRemappedValues(ValueRange keys, SmallVectorImpl< Value > &results)
Return the converted values that replace 'keys' with types defined by the type converter of the curre...
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt) override
PatternRewriter hook for inlining the ops of a block into another block.
Block * applySignatureConversion(Block *block, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter=nullptr)
Apply a signature conversion to given block.
void startOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
detail::ConversionPatternRewriterImpl & getImpl()
Return a reference to the internal implementation.
void eraseBlock(Block *block) override
PatternRewriter hook for erasing all operations in a block.
void cancelOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
bool canRecoverFromRewriteFailure() const override
Indicate that the conversion rewriter can recover from rewrite failure.
Value getRemappedValue(Value key)
Return the converted value of 'key' with a type defined by the type converter of the currently execut...
void finalizeOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void replaceUsesOfBlockArgument(BlockArgument from, Value to)
Replace all the uses of the block argument from with value to.
Base class for the conversion patterns.
const TypeConverter * typeConverter
An optional type converter for use by this pattern.
ConversionPattern(const TypeConverter &typeConverter, Args &&...args)
Construct a conversion pattern with the given converter, and forward the remaining arguments to Rewri...
virtual void rewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement rewriting.
const TypeConverter * getTypeConverter() const
Return the type converter held by this pattern, or nullptr if the pattern does not require type conve...
std::enable_if_t< std::is_base_of< TypeConverter, ConverterTy >::value, const ConverterTy * > getTypeConverter() const
virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement combined matching and rewriting.
This class describes a specific conversion target.
void setDialectAction(ArrayRef< StringRef > dialectNames, LegalizationAction action)
Register a legality action for the given dialects.
void setOpAction(LegalizationAction action)
void addLegalOp(OperationName op)
Register the given operations as legal.
void addDynamicallyLegalDialect(DynamicLegalityCallbackFn callback)
void setOpAction(OperationName op, LegalizationAction action)
Register a legality action for the given operation.
void addDynamicallyLegalDialect(const DynamicLegalityCallbackFn &callback, StringRef name, Names... names)
Register the operations of the given dialects as dynamically legal, i.e.
void addDynamicallyLegalOp(const DynamicLegalityCallbackFn &callback)
void addLegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as legal.
std::optional< LegalOpDetails > isLegal(Operation *op) const
If the given operation instance is legal on this target, a structure containing legality information ...
std::enable_if_t<!std::is_invocable_v< Callable, Operation * > > markOpRecursivelyLegal(Callable &&callback)
void markUnknownOpDynamicallyLegal(const DynamicLegalityCallbackFn &fn)
Register unknown operations as dynamically legal.
std::optional< LegalizationAction > getOpAction(OperationName op) const
Get the legality action for the given operation.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the on...
void addIllegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as illegal, i.e.
std::enable_if_t<!std::is_invocable_v< Callable, Operation * > > addDynamicallyLegalOp(Callable &&callback)
void addDynamicallyLegalOp(const DynamicLegalityCallbackFn &callback)
void markOpRecursivelyLegal(const DynamicLegalityCallbackFn &callback={})
void addIllegalOp(OperationName op)
Register the given operation as illegal, i.e.
LegalizationAction
This enumeration corresponds to the specific action to take when considering an operation legal for t...
@ Illegal
The target explicitly does not support this operation.
@ Dynamic
This operation has dynamic legalization constraints that must be checked by the target.
@ Legal
The target supports this operation.
ConversionTarget(MLIRContext &ctx)
void markOpRecursivelyLegal(OperationName name, const DynamicLegalityCallbackFn &callback)
Mark an operation, that must have either been set as Legal or DynamicallyLegal, as being recursively ...
void markOpRecursivelyLegal(const DynamicLegalityCallbackFn &callback={})
std::function< std::optional< bool >(Operation *)> DynamicLegalityCallbackFn
The signature of the callback used to determine if an operation is dynamically legal on the target.
bool isIllegal(Operation *op) const
Returns true if the operation instance is illegal on this target.
virtual ~ConversionTarget()=default
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:215
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
Definition: Builders.h:324
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
void rewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Hook for derived classes to implement rewriting.
LogicalResult match(Operation *op) const final
Wrappers around the ConversionPattern methods that pass the derived op type.
typename SourceOp::Adaptor OpAdaptor
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
virtual LogicalResult match(SourceOp op) const
Rewrite and Match methods that operate on the SourceOp type.
OpConversionPattern(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit=1)
virtual LogicalResult matchAndRewrite(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const
virtual void rewrite(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const
LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Hook for derived classes to implement combined matching and rewriting.
OpInterfaceConversionPattern is a wrapper around ConversionPattern that allows for matching and rewri...
void rewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Wrappers around the ConversionPattern methods that pass the derived op type.
LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Hook for derived classes to implement combined matching and rewriting.
OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
OpInterfaceConversionPattern(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit=1)
virtual void rewrite(SourceOp op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Rewrite and Match methods that operate on the SourceOp type.
virtual LogicalResult matchAndRewrite(SourceOp op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
OpTraitConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting...
OpTraitConversionPattern(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit=1)
OpTraitConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
A PDL configuration that is used to support dialect conversion functionality.
void notifyRewriteEnd(PatternRewriter &rewriter) final
void notifyRewriteBegin(PatternRewriter &rewriter) final
Hooks that are invoked at the beginning and end of a rewrite of a matched pattern.
const TypeConverter * getTypeConverter() const
Return the type converter used by this configuration, which may be nullptr if no type conversions are...
PDLConversionConfig(const TypeConverter *converter)
~PDLConversionConfig() final=default
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:73
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:246
virtual LogicalResult match(Operation *op) const
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
virtual void rewrite(Operation *op, PatternRewriter &rewriter) const
Rewrite the IR rooted at the specified operation with the result of this pattern, generating any new ...
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
The general result of a type attribute conversion callback, allowing for early termination.
static AttributeConversionResult result(Attribute attr)
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
std::optional< InputMapping > getInputMapping(unsigned input) const
Get the input mapping for the given argument.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
void remapInput(unsigned origInputNo, Value replacement)
Remap an input of the original signature to another replacement value.
Type conversion class.
std::optional< Attribute > convertTypeAttribute(Type type, Attribute attr) const
Convert an attribute `attr` present within the type `type` using the registered conversion functions...
void addConversion(FnT &&callback)
Register a conversion function.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
void addArgumentMaterialization(FnT &&callback)
All of the following materializations require function objects that are convertible to the following ...
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
Value materializeArgumentConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
Materialize a conversion from a set of types into one result type by generating a cast sequence of so...
LogicalResult convertSignatureArgs(TypeRange types, SignatureConversion &result, unsigned origInputOffset=0) const
TypeConverter()=default
std::enable_if_t<!std::is_convertible< RangeT, Type >::value &&!std::is_convertible< RangeT, Operation * >::value, bool > isLegal(RangeT &&range) const
Return true if all of the given types are legal for this type converter.
TargetType convertType(Type t) const
Attempts a 1-1 type conversion, expecting the result type to be TargetType.
LogicalResult convertSignatureArg(unsigned inputNo, Type type, SignatureConversion &result) const
This method allows for converting a specific argument of a signature.
TypeConverter(const TypeConverter &other)
TypeConverter & operator=(const TypeConverter &other)
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
std::optional< SignatureConversion > convertBlockSignature(Block *block) const
This function converts the type signature of the given block, by invoking 'convertSignatureArg' for e...
void addSourceMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a legal replacement value...
virtual ~TypeConverter()=default
void addTypeAttributeConversion(FnT &&callback)
Register a conversion function for attributes within types.
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting an illegal (source) value...
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
bool isSignatureLegal(FunctionType ty) const
Return true if the inputs and outputs of the given function type are legal.
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:104
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
@ Type
An inlay hint that for a type annotation.
Include the generated interface declarations.
void populateFunctionOpInterfaceTypeConversionPattern(StringRef functionLikeOpName, RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to convert the signature of a FunctionOpInterface op with the...
void populateAnyFunctionOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
FailureOr< Operation * > convertOpResultTypes(Operation *op, ValueRange operands, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Generic utility to convert op result types according to type converter without knowing exact op type.
LogicalResult applyAnalysisConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply an analysis conversion on the given operations, and all nested operations.
void reconcileUnrealizedCasts(ArrayRef< UnrealizedConversionCastOp > castOps, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps=nullptr)
Try to reconcile all given UnrealizedConversionCastOps and store the left-over ops in remainingCastOp...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void registerConversionPDLFunctions(RewritePatternSet &patterns)
Register the dialect conversion PDL functions with the given pattern set.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
Dialect conversion configuration.
RewriterBase::Listener * listener
An optional listener that is notified about all IR modifications in case dialect conversion succeeds.
function_ref< void(Diagnostic &)> notifyCallback
An optional callback used to notify about match failure diagnostics during the conversion.
DenseSet< Operation * > * legalizableOps
Analysis conversion only.
DenseSet< Operation * > * unlegalizedOps
Partial conversion only.
bool buildMaterializations
If set to "true", the dialect conversion attempts to build source/target/ argument materializations t...
A structure containing additional information describing a specific legal operation instance.
bool isRecursivelyLegal
A flag that indicates if this operation is 'recursively' legal.
This class acts as a special tag that makes the desire to match any operation that implements a given...
Definition: PatternMatch.h:164
This class acts as a special tag that makes the desire to match any operation that implements a given...
Definition: PatternMatch.h:169
This struct represents a range of new types or a single value that remaps an existing signature input...