MLIR  21.0.0git
DialectConversion.h
Go to the documentation of this file.
1 //===- DialectConversion.h - MLIR dialect conversion pass -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file declares a generic pass for converting between MLIR dialects.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_TRANSFORMS_DIALECTCONVERSION_H_
14 #define MLIR_TRANSFORMS_DIALECTCONVERSION_H_
15 
16 #include "mlir/Config/mlir-config.h"
18 #include "llvm/ADT/MapVector.h"
19 #include "llvm/ADT/StringMap.h"
20 #include <type_traits>
21 
22 namespace mlir {
23 
24 // Forward declarations.
25 class Attribute;
26 class Block;
27 struct ConversionConfig;
28 class ConversionPatternRewriter;
29 class MLIRContext;
30 class Operation;
31 struct OperationConverter;
32 class Type;
33 class Value;
34 
35 //===----------------------------------------------------------------------===//
36 // Type Conversion
37 //===----------------------------------------------------------------------===//
38 
39 /// Type conversion class. Specific conversions and materializations can be
40 /// registered using addConversion and addMaterialization, respectively.
42 public:
43  virtual ~TypeConverter() = default;
44  TypeConverter() = default;
45  // Copy the registered conversions, but not the caches
47  : conversions(other.conversions),
48  argumentMaterializations(other.argumentMaterializations),
49  sourceMaterializations(other.sourceMaterializations),
50  targetMaterializations(other.targetMaterializations),
51  typeAttributeConversions(other.typeAttributeConversions) {}
53  conversions = other.conversions;
54  argumentMaterializations = other.argumentMaterializations;
55  sourceMaterializations = other.sourceMaterializations;
56  targetMaterializations = other.targetMaterializations;
57  typeAttributeConversions = other.typeAttributeConversions;
58  return *this;
59  }
60 
61  /// This class provides all of the information necessary to convert a type
62  /// signature.
64  public:
65  SignatureConversion(unsigned numOrigInputs)
66  : remappedInputs(numOrigInputs) {}
67 
68  /// This struct represents a range of new types or a range of values that
69  /// remaps an existing signature input.
70  struct InputMapping {
71  size_t inputNo, size;
73 
74  /// Return "true" if this input was replaces with one or multiple values.
75  bool replacedWithValues() const { return !replacementValues.empty(); }
76  };
77 
78  /// Return the argument types for the new signature.
79  ArrayRef<Type> getConvertedTypes() const { return argTypes; }
80 
81  /// Get the input mapping for the given argument.
82  std::optional<InputMapping> getInputMapping(unsigned input) const {
83  return remappedInputs[input];
84  }
85 
86  //===------------------------------------------------------------------===//
87  // Conversion Hooks
88  //===------------------------------------------------------------------===//
89 
90  /// Remap an input of the original signature with a new set of types. The
91  /// new types are appended to the new signature conversion.
92  void addInputs(unsigned origInputNo, ArrayRef<Type> types);
93 
94  /// Append new input types to the signature conversion, this should only be
95  /// used if the new types are not intended to remap an existing input.
96  void addInputs(ArrayRef<Type> types);
97 
98  /// Remap an input of the original signature to `replacements`
99  /// values. This drops the original argument.
100  void remapInput(unsigned origInputNo, ArrayRef<Value> replacements);
101 
102  private:
103  /// Remap an input of the original signature with a range of types in the
104  /// new signature.
105  void remapInput(unsigned origInputNo, unsigned newInputNo,
106  unsigned newInputCount = 1);
107 
108  /// The remapping information for each of the original arguments.
109  SmallVector<std::optional<InputMapping>, 4> remappedInputs;
110 
111  /// The set of new argument types.
112  SmallVector<Type, 4> argTypes;
113  };
114 
115  /// The general result of a type attribute conversion callback, allowing
116  /// for early termination. The default constructor creates the na case.
118  public:
119  constexpr AttributeConversionResult() : impl() {}
120  AttributeConversionResult(Attribute attr) : impl(attr, resultTag) {}
121 
123  static AttributeConversionResult na();
125 
126  bool hasResult() const;
127  bool isNa() const;
128  bool isAbort() const;
129 
130  Attribute getResult() const;
131 
132  private:
133  AttributeConversionResult(Attribute attr, unsigned tag) : impl(attr, tag) {}
134 
135  llvm::PointerIntPair<Attribute, 2> impl;
136  // Note that na is 0 so that we can use PointerIntPair's default
137  // constructor.
138  static constexpr unsigned naTag = 0;
139  static constexpr unsigned resultTag = 1;
140  static constexpr unsigned abortTag = 2;
141  };
142 
143  /// Register a conversion function. A conversion function must be convertible
144  /// to any of the following forms (where `T` is a class derived from `Type`):
145  ///
146  /// * std::optional<Type>(T)
147  /// - This form represents a 1-1 type conversion. It should return nullptr
148  /// or `std::nullopt` to signify failure. If `std::nullopt` is returned,
149  /// the converter is allowed to try another conversion function to
150  /// perform the conversion.
151  /// * std::optional<LogicalResult>(T, SmallVectorImpl<Type> &)
152  /// - This form represents a 1-N type conversion. It should return
153  /// `failure` or `std::nullopt` to signify a failed conversion. If the
154  /// new set of types is empty, the type is removed and any usages of the
155  /// existing value are expected to be removed during conversion. If
156  /// `std::nullopt` is returned, the converter is allowed to try another
157  /// conversion function to perform the conversion.
158  ///
159  /// Note: When attempting to convert a type, e.g. via 'convertType', the
160  /// mostly recently added conversions will be invoked first.
161  template <typename FnT, typename T = typename llvm::function_traits<
162  std::decay_t<FnT>>::template arg_t<0>>
163  void addConversion(FnT &&callback) {
164  registerConversion(wrapCallback<T>(std::forward<FnT>(callback)));
165  }
166 
167  /// All of the following materializations require function objects that are
168  /// convertible to the following form:
169  /// `Value(OpBuilder &, T, ValueRange, Location)`,
170  /// where `T` is any subclass of `Type`. This function is responsible for
171  /// creating an operation, using the OpBuilder and Location provided, that
172  /// "casts" a range of values into a single value of the given type `T`. It
173  /// must return a Value of the type `T` on success and `nullptr` if
174  /// it failed but other materialization should be attempted. Materialization
175  /// functions must be provided when a type conversion may persist after the
176  /// conversion has finished.
177  ///
178  /// Note: Target materializations may optionally accept an additional Type
179  /// parameter, which is the original type of the SSA value. Furthermore, `T`
180  /// can be a TypeRange; in that case, the function must return a
181  /// SmallVector<Value>.
182 
183  /// This method registers a materialization that will be called when
184  /// converting (potentially multiple) block arguments that were the result of
185  /// a signature conversion of a single block argument, to a single SSA value
186  /// with the old block argument type.
187  ///
188  /// Note: Argument materializations are used only with the 1:N dialect
189  /// conversion driver. The 1:N dialect conversion driver will be removed soon
190  /// and so will be argument materializations.
191  template <typename FnT, typename T = typename llvm::function_traits<
192  std::decay_t<FnT>>::template arg_t<1>>
193  void addArgumentMaterialization(FnT &&callback) {
194  argumentMaterializations.emplace_back(
195  wrapMaterialization<T>(std::forward<FnT>(callback)));
196  }
197 
198  /// This method registers a materialization that will be called when
199  /// converting a replacement value back to its original source type.
200  /// This is used when some uses of the original value persist beyond the main
201  /// conversion.
202  template <typename FnT, typename T = typename llvm::function_traits<
203  std::decay_t<FnT>>::template arg_t<1>>
204  void addSourceMaterialization(FnT &&callback) {
205  sourceMaterializations.emplace_back(
206  wrapMaterialization<T>(std::forward<FnT>(callback)));
207  }
208 
209  /// This method registers a materialization that will be called when
210  /// converting a value to a target type according to a pattern's type
211  /// converter.
212  ///
213  /// Note: Target materializations can optionally inspect the "original"
214  /// type. This type may be different from the type of the input value.
215  /// For example, let's assume that a conversion pattern "P1" replaced an SSA
216  /// value "v1" (type "t1") with "v2" (type "t2"). Then a different conversion
217  /// pattern "P2" matches an op that has "v1" as an operand. Let's furthermore
218  /// assume that "P2" determines that the converted target type of "t1" is
219  /// "t3", which may be different from "t2". In this example, the target
220  /// materialization will be invoked with: outputType = "t3", inputs = "v2",
221  /// originalType = "t1". Note that the original type "t1" cannot be recovered
222  /// from just "t3" and "v2"; that's why the originalType parameter exists.
223  ///
224  /// Note: During a 1:N conversion, the result types can be a TypeRange. In
225  /// that case the materialization produces a SmallVector<Value>.
226  template <typename FnT, typename T = typename llvm::function_traits<
227  std::decay_t<FnT>>::template arg_t<1>>
228  void addTargetMaterialization(FnT &&callback) {
229  targetMaterializations.emplace_back(
230  wrapTargetMaterialization<T>(std::forward<FnT>(callback)));
231  }
232 
233  /// Register a conversion function for attributes within types. Type
234  /// converters may call this function in order to allow hoking into the
235  /// translation of attributes that exist within types. For example, a type
236  /// converter for the `memref` type could use these conversions to convert
237  /// memory spaces or layouts in an extensible way.
238  ///
239  /// The conversion functions take a non-null Type or subclass of Type and a
240  /// non-null Attribute (or subclass of Attribute), and returns a
241  /// `AttributeConversionResult`. This result can either contan an `Attribute`,
242  /// which may be `nullptr`, representing the conversion's success,
243  /// `AttributeConversionResult::na()` (the default empty value), indicating
244  /// that the conversion function did not apply and that further conversion
245  /// functions should be checked, or `AttributeConversionResult::abort()`
246  /// indicating that the conversion process should be aborted.
247  ///
248  /// Registered conversion functions are callled in the reverse of the order in
249  /// which they were registered.
250  template <
251  typename FnT,
252  typename T =
253  typename llvm::function_traits<std::decay_t<FnT>>::template arg_t<0>,
254  typename A =
255  typename llvm::function_traits<std::decay_t<FnT>>::template arg_t<1>>
256  void addTypeAttributeConversion(FnT &&callback) {
257  registerTypeAttributeConversion(
258  wrapTypeAttributeConversion<T, A>(std::forward<FnT>(callback)));
259  }
260 
261  /// Convert the given type. This function should return failure if no valid
262  /// conversion exists, success otherwise. If the new set of types is empty,
263  /// the type is removed and any usages of the existing value are expected to
264  /// be removed during conversion.
265  LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) const;
266 
267  /// This hook simplifies defining 1-1 type conversions. This function returns
268  /// the type to convert to on success, and a null type on failure.
269  Type convertType(Type t) const;
270 
271  /// Attempts a 1-1 type conversion, expecting the result type to be
272  /// `TargetType`. Returns the converted type cast to `TargetType` on success,
273  /// and a null type on conversion or cast failure.
274  template <typename TargetType>
275  TargetType convertType(Type t) const {
276  return dyn_cast_or_null<TargetType>(convertType(t));
277  }
278 
279  /// Convert the given set of types, filling 'results' as necessary. This
280  /// returns failure if the conversion of any of the types fails, success
281  /// otherwise.
282  LogicalResult convertTypes(TypeRange types,
283  SmallVectorImpl<Type> &results) const;
284 
285  /// Return true if the given type is legal for this type converter, i.e. the
286  /// type converts to itself.
287  bool isLegal(Type type) const;
288 
289  /// Return true if all of the given types are legal for this type converter.
290  template <typename RangeT>
291  std::enable_if_t<!std::is_convertible<RangeT, Type>::value &&
292  !std::is_convertible<RangeT, Operation *>::value,
293  bool>
294  isLegal(RangeT &&range) const {
295  return llvm::all_of(range, [this](Type type) { return isLegal(type); });
296  }
297  /// Return true if the given operation has legal operand and result types.
298  bool isLegal(Operation *op) const;
299 
300  /// Return true if the types of block arguments within the region are legal.
301  bool isLegal(Region *region) const;
302 
303  /// Return true if the inputs and outputs of the given function type are
304  /// legal.
305  bool isSignatureLegal(FunctionType ty) const;
306 
307  /// This method allows for converting a specific argument of a signature. It
308  /// takes as inputs the original argument input number, type.
309  /// On success, it populates 'result' with any new mappings.
310  LogicalResult convertSignatureArg(unsigned inputNo, Type type,
311  SignatureConversion &result) const;
312  LogicalResult convertSignatureArgs(TypeRange types,
313  SignatureConversion &result,
314  unsigned origInputOffset = 0) const;
315 
316  /// This function converts the type signature of the given block, by invoking
317  /// 'convertSignatureArg' for each argument. This function should return a
318  /// valid conversion for the signature on success, std::nullopt otherwise.
319  std::optional<SignatureConversion> convertBlockSignature(Block *block) const;
320 
321  /// Materialize a conversion from a set of types into one result type by
322  /// generating a cast sequence of some kind. See the respective
323  /// `add*Materialization` for more information on the context for these
324  /// methods.
326  Type resultType, ValueRange inputs) const;
328  Type resultType, ValueRange inputs) const;
330  Type resultType, ValueRange inputs,
331  Type originalType = {}) const;
332  SmallVector<Value> materializeTargetConversion(OpBuilder &builder,
333  Location loc,
334  TypeRange resultType,
335  ValueRange inputs,
336  Type originalType = {}) const;
337 
338  /// Convert an attribute present `attr` from within the type `type` using
339  /// the registered conversion functions. If no applicable conversion has been
340  /// registered, return std::nullopt. Note that the empty attribute/`nullptr`
341  /// is a valid return value for this function.
342  std::optional<Attribute> convertTypeAttribute(Type type,
343  Attribute attr) const;
344 
345 private:
346  /// The signature of the callback used to convert a type. If the new set of
347  /// types is empty, the type is removed and any usages of the existing value
348  /// are expected to be removed during conversion.
349  using ConversionCallbackFn = std::function<std::optional<LogicalResult>(
350  Type, SmallVectorImpl<Type> &)>;
351 
352  /// The signature of the callback used to materialize a source/argument
353  /// conversion.
354  ///
355  /// Arguments: builder, result type, inputs, location
356  using MaterializationCallbackFn =
357  std::function<Value(OpBuilder &, Type, ValueRange, Location)>;
358 
359  /// The signature of the callback used to materialize a target conversion.
360  ///
361  /// Arguments: builder, result types, inputs, location, original type
362  using TargetMaterializationCallbackFn = std::function<SmallVector<Value>(
363  OpBuilder &, TypeRange, ValueRange, Location, Type)>;
364 
365  /// The signature of the callback used to convert a type attribute.
366  using TypeAttributeConversionCallbackFn =
367  std::function<AttributeConversionResult(Type, Attribute)>;
368 
369  /// Generate a wrapper for the given callback. This allows for accepting
370  /// different callback forms, that all compose into a single version.
371  /// With callback of form: `std::optional<Type>(T)`
372  template <typename T, typename FnT>
373  std::enable_if_t<std::is_invocable_v<FnT, T>, ConversionCallbackFn>
374  wrapCallback(FnT &&callback) const {
375  return wrapCallback<T>([callback = std::forward<FnT>(callback)](
376  T type, SmallVectorImpl<Type> &results) {
377  if (std::optional<Type> resultOpt = callback(type)) {
378  bool wasSuccess = static_cast<bool>(*resultOpt);
379  if (wasSuccess)
380  results.push_back(*resultOpt);
381  return std::optional<LogicalResult>(success(wasSuccess));
382  }
383  return std::optional<LogicalResult>();
384  });
385  }
386  /// With callback of form: `std::optional<LogicalResult>(
387  /// T, SmallVectorImpl<Type> &, ArrayRef<Type>)`.
388  template <typename T, typename FnT>
389  std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &>,
390  ConversionCallbackFn>
391  wrapCallback(FnT &&callback) const {
392  return [callback = std::forward<FnT>(callback)](
393  Type type,
394  SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
395  T derivedType = dyn_cast<T>(type);
396  if (!derivedType)
397  return std::nullopt;
398  return callback(derivedType, results);
399  };
400  }
401 
402  /// Register a type conversion.
403  void registerConversion(ConversionCallbackFn callback) {
404  conversions.emplace_back(std::move(callback));
405  cachedDirectConversions.clear();
406  cachedMultiConversions.clear();
407  }
408 
409  /// Generate a wrapper for the given argument/source materialization
410  /// callback. The callback may take any subclass of `Type` and the
411  /// wrapper will check for the target type to be of the expected class
412  /// before calling the callback.
413  template <typename T, typename FnT>
414  MaterializationCallbackFn wrapMaterialization(FnT &&callback) const {
415  return [callback = std::forward<FnT>(callback)](
416  OpBuilder &builder, Type resultType, ValueRange inputs,
417  Location loc) -> Value {
418  if (T derivedType = dyn_cast<T>(resultType))
419  return callback(builder, derivedType, inputs, loc);
420  return Value();
421  };
422  }
423 
424  /// Generate a wrapper for the given target materialization callback.
425  /// The callback may take any subclass of `Type` and the wrapper will check
426  /// for the target type to be of the expected class before calling the
427  /// callback.
428  ///
429  /// With callback of form:
430  /// - Value(OpBuilder &, T, ValueRange, Location, Type)
431  /// - SmallVector<Value>(OpBuilder &, TypeRange, ValueRange, Location, Type)
432  template <typename T, typename FnT>
433  std::enable_if_t<
434  std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location, Type>,
435  TargetMaterializationCallbackFn>
436  wrapTargetMaterialization(FnT &&callback) const {
437  return [callback = std::forward<FnT>(callback)](
438  OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
439  Location loc, Type originalType) -> SmallVector<Value> {
440  SmallVector<Value> result;
441  if constexpr (std::is_same<T, TypeRange>::value) {
442  // This is a 1:N target materialization. Return the produces values
443  // directly.
444  result = callback(builder, resultTypes, inputs, loc, originalType);
445  } else if constexpr (std::is_assignable<Type, T>::value) {
446  // This is a 1:1 target materialization. Invoke the callback only if a
447  // single SSA value is requested.
448  if (resultTypes.size() == 1) {
449  // Invoke the callback only if the type class of the callback matches
450  // the requested result type.
451  if (T derivedType = dyn_cast<T>(resultTypes.front())) {
452  // 1:1 materializations produce single values, but we store 1:N
453  // target materialization functions in the type converter. Wrap the
454  // result value in a SmallVector<Value>.
455  Value val =
456  callback(builder, derivedType, inputs, loc, originalType);
457  if (val)
458  result.push_back(val);
459  }
460  }
461  } else {
462  static_assert(sizeof(T) == 0, "T must be a Type or a TypeRange");
463  }
464  return result;
465  };
466  }
467  /// With callback of form:
468  /// - Value(OpBuilder &, T, ValueRange, Location)
469  /// - SmallVector<Value>(OpBuilder &, TypeRange, ValueRange, Location)
470  template <typename T, typename FnT>
471  std::enable_if_t<
472  std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location>,
473  TargetMaterializationCallbackFn>
474  wrapTargetMaterialization(FnT &&callback) const {
475  return wrapTargetMaterialization<T>(
476  [callback = std::forward<FnT>(callback)](
477  OpBuilder &builder, T resultTypes, ValueRange inputs, Location loc,
478  Type originalType) {
479  return callback(builder, resultTypes, inputs, loc);
480  });
481  }
482 
483  /// Generate a wrapper for the given memory space conversion callback. The
484  /// callback may take any subclass of `Attribute` and the wrapper will check
485  /// for the target attribute to be of the expected class before calling the
486  /// callback.
487  template <typename T, typename A, typename FnT>
488  TypeAttributeConversionCallbackFn
489  wrapTypeAttributeConversion(FnT &&callback) const {
490  return [callback = std::forward<FnT>(callback)](
491  Type type, Attribute attr) -> AttributeConversionResult {
492  if (T derivedType = dyn_cast<T>(type)) {
493  if (A derivedAttr = dyn_cast_or_null<A>(attr))
494  return callback(derivedType, derivedAttr);
495  }
497  };
498  }
499 
500  /// Register a memory space conversion, clearing caches.
501  void
502  registerTypeAttributeConversion(TypeAttributeConversionCallbackFn callback) {
503  typeAttributeConversions.emplace_back(std::move(callback));
504  // Clear type conversions in case a memory space is lingering inside.
505  cachedDirectConversions.clear();
506  cachedMultiConversions.clear();
507  }
508 
509  /// The set of registered conversion functions.
510  SmallVector<ConversionCallbackFn, 4> conversions;
511 
512  /// The list of registered materialization functions.
513  SmallVector<MaterializationCallbackFn, 2> argumentMaterializations;
514  SmallVector<MaterializationCallbackFn, 2> sourceMaterializations;
515  SmallVector<TargetMaterializationCallbackFn, 2> targetMaterializations;
516 
517  /// The list of registered type attribute conversion functions.
518  SmallVector<TypeAttributeConversionCallbackFn, 2> typeAttributeConversions;
519 
520  /// A set of cached conversions to avoid recomputing in the common case.
521  /// Direct 1-1 conversions are the most common, so this cache stores the
522  /// successful 1-1 conversions as well as all failed conversions.
523  mutable DenseMap<Type, Type> cachedDirectConversions;
524  /// This cache stores the successful 1->N conversions, where N != 1.
525  mutable DenseMap<Type, SmallVector<Type, 2>> cachedMultiConversions;
526  /// A mutex used for cache access
527  mutable llvm::sys::SmartRWMutex<true> cacheMutex;
528 };
529 
530 //===----------------------------------------------------------------------===//
531 // Conversion Patterns
532 //===----------------------------------------------------------------------===//
533 
534 namespace detail {
535 /// Helper class that derives from a ConversionRewritePattern class and
536 /// provides separate `match` and `rewrite` entry points instead of a combined
537 /// `matchAndRewrite`.
538 template <typename PatternT>
540  using PatternT::PatternT;
541 
542  /// Attempt to match against IR rooted at the specified operation, which is
543  /// the same operation kind as getRootKind().
544  ///
545  /// Note: This function must not modify the IR.
546  virtual LogicalResult match(typename PatternT::OperationT op) const = 0;
547 
548  /// Rewrite the IR rooted at the specified operation with the result of
549  /// this pattern, generating any new operations with the specified
550  /// rewriter.
551  virtual void rewrite(typename PatternT::OperationT op,
552  typename PatternT::OpAdaptor adaptor,
553  ConversionPatternRewriter &rewriter) const {
554  // One of the two `rewrite` functions must be implemented.
555  llvm_unreachable("rewrite is not implemented");
556  }
557 
558  virtual void rewrite(typename PatternT::OperationT op,
559  typename PatternT::OneToNOpAdaptor adaptor,
560  ConversionPatternRewriter &rewriter) const {
561  if constexpr (std::is_same<typename PatternT::OpAdaptor,
562  ArrayRef<Value>>::value) {
563  rewrite(op, PatternT::getOneToOneAdaptorOperands(adaptor), rewriter);
564  } else {
565  SmallVector<Value> oneToOneOperands =
566  PatternT::getOneToOneAdaptorOperands(adaptor.getOperands());
567  rewrite(op, typename PatternT::OpAdaptor(oneToOneOperands, adaptor),
568  rewriter);
569  }
570  }
571 
572  LogicalResult
573  matchAndRewrite(typename PatternT::OperationT op,
574  typename PatternT::OneToNOpAdaptor adaptor,
575  ConversionPatternRewriter &rewriter) const final {
576  if (succeeded(match(op))) {
577  rewrite(op, adaptor, rewriter);
578  return success();
579  }
580  return failure();
581  }
582 
583  LogicalResult
584  matchAndRewrite(typename PatternT::OperationT op,
585  typename PatternT::OpAdaptor adaptor,
586  ConversionPatternRewriter &rewriter) const final {
587  // Users would normally override this function in conversion patterns to
588  // implement a 1:1 pattern. Patterns that are derived from this class have
589  // separate `match` and `rewrite` functions, so this `matchAndRewrite`
590  // overload is obsolete.
591  llvm_unreachable("this function is unreachable");
592  }
593 };
594 } // namespace detail
595 
596 /// Base class for the conversion patterns. This pattern class enables type
597 /// conversions, and other uses specific to the conversion framework. As such,
598 /// patterns of this type can only be used with the 'apply*' methods below.
600 public:
604 
605  /// `SplitMatchAndRewrite` is deprecated. Use `matchAndRewrite` instead of
606  /// separate `match` and `rewrite`.
609 
610  /// Hook for derived classes to implement combined matching and rewriting.
611  /// This overload supports only 1:1 replacements. The 1:N overload is called
612  /// by the driver. By default, it calls this 1:1 overload or reports a fatal
613  /// error if 1:N replacements were found.
614  virtual LogicalResult
616  ConversionPatternRewriter &rewriter) const {
617  llvm_unreachable("matchAndRewrite is not implemented");
618  }
619 
620  /// Hook for derived classes to implement combined matching and rewriting.
621  /// This overload supports 1:N replacements.
622  virtual LogicalResult
624  ConversionPatternRewriter &rewriter) const {
625  return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
626  }
627 
628  /// Attempt to match and rewrite the IR root at the specified operation.
629  LogicalResult matchAndRewrite(Operation *op,
630  PatternRewriter &rewriter) const final;
631 
632  /// Return the type converter held by this pattern, or nullptr if the pattern
633  /// does not require type conversion.
634  const TypeConverter *getTypeConverter() const { return typeConverter; }
635 
636  template <typename ConverterTy>
637  std::enable_if_t<std::is_base_of<TypeConverter, ConverterTy>::value,
638  const ConverterTy *>
640  return static_cast<const ConverterTy *>(typeConverter);
641  }
642 
643 protected:
644  /// See `RewritePattern::RewritePattern` for information on the other
645  /// available constructors.
646  using RewritePattern::RewritePattern;
647  /// Construct a conversion pattern with the given converter, and forward the
648  /// remaining arguments to RewritePattern.
649  template <typename... Args>
651  : RewritePattern(std::forward<Args>(args)...),
653 
654  /// Given an array of value ranges, which are the inputs to a 1:N adaptor,
655  /// try to extract the single value of each range to construct the inputs
656  /// for a 1:1 adaptor.
657  ///
658  /// This function produces a fatal error if at least one range has 0 or
659  /// more than 1 value: "pattern 'name' does not support 1:N conversion"
662 
663 protected:
664  /// An optional type converter for use by this pattern.
665  const TypeConverter *typeConverter = nullptr;
666 };
667 
668 /// OpConversionPattern is a wrapper around ConversionPattern that allows for
669 /// matching and rewriting against an instance of a derived operation class as
670 /// opposed to a raw Operation.
671 template <typename SourceOp>
673 public:
674  using OperationT = SourceOp;
675  using OpAdaptor = typename SourceOp::Adaptor;
677  typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
678 
679  /// `SplitMatchAndRewrite` is deprecated. Use `matchAndRewrite` instead of
680  /// separate `match` and `rewrite`.
683 
685  : ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
687  PatternBenefit benefit = 1)
688  : ConversionPattern(typeConverter, SourceOp::getOperationName(), benefit,
689  context) {}
690 
691  /// Wrappers around the ConversionPattern methods that pass the derived op
692  /// type.
693  LogicalResult
695  ConversionPatternRewriter &rewriter) const final {
696  auto sourceOp = cast<SourceOp>(op);
697  return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
698  }
699  LogicalResult
701  ConversionPatternRewriter &rewriter) const final {
702  auto sourceOp = cast<SourceOp>(op);
703  return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
704  rewriter);
705  }
706 
707  /// Methods that operate on the SourceOp type. One of these must be
708  /// overridden by the derived pattern class.
709  virtual LogicalResult
710  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
711  ConversionPatternRewriter &rewriter) const {
712  llvm_unreachable("matchAndRewrite is not implemented");
713  }
714  virtual LogicalResult
715  matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
716  ConversionPatternRewriter &rewriter) const {
717  SmallVector<Value> oneToOneOperands =
718  getOneToOneAdaptorOperands(adaptor.getOperands());
719  return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
720  }
721 
722 private:
724 };
725 
726 /// OpInterfaceConversionPattern is a wrapper around ConversionPattern that
727 /// allows for matching and rewriting against an instance of an OpInterface
728 /// class as opposed to a raw Operation.
729 template <typename SourceOp>
731 public:
734  SourceOp::getInterfaceID(), benefit, context) {}
736  MLIRContext *context, PatternBenefit benefit = 1)
738  SourceOp::getInterfaceID(), benefit, context) {}
739 
740  /// Wrappers around the ConversionPattern methods that pass the derived op
741  /// type.
742  LogicalResult
744  ConversionPatternRewriter &rewriter) const final {
745  return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
746  }
747  LogicalResult
749  ConversionPatternRewriter &rewriter) const final {
750  return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
751  }
752 
753  /// Methods that operate on the SourceOp type. One of these must be
754  /// overridden by the derived pattern class.
755  virtual LogicalResult
756  matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
757  ConversionPatternRewriter &rewriter) const {
758  llvm_unreachable("matchAndRewrite is not implemented");
759  }
760  virtual LogicalResult
761  matchAndRewrite(SourceOp op, ArrayRef<ValueRange> operands,
762  ConversionPatternRewriter &rewriter) const {
763  return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
764  }
765 
766 private:
768 };
769 
770 /// OpTraitConversionPattern is a wrapper around ConversionPattern that allows
771 /// for matching and rewriting against instances of an operation that possess a
772 /// given trait.
773 template <template <typename> class TraitType>
775 public:
778  TypeID::get<TraitType>(), benefit, context) {}
780  MLIRContext *context, PatternBenefit benefit = 1)
782  TypeID::get<TraitType>(), benefit, context) {}
783 };
784 
785 /// Generic utility to convert the result types of `op` according to
786 /// `converter` without knowing the exact op type. Clones the existing op with
787 /// the new result types and returns the clone.
/// NOTE(review): `operands` are presumably the (already remapped) operands to
/// use for the clone, and the FailureOr return presumably signals that some
/// result type could not be converted — confirm against the definition.
788 FailureOr<Operation *>
789 convertOpResultTypes(Operation *op, ValueRange operands,
790  const TypeConverter &converter,
791  ConversionPatternRewriter &rewriter);
792 
793 /// Add a pattern to the given pattern list to convert the signature of a
794 /// FunctionOpInterface op with the given type converter. This only supports
795 /// ops which use FunctionType to represent their type.
797  StringRef functionLikeOpName, RewritePatternSet &patterns,
798  const TypeConverter &converter);
799 
800 template <typename FuncOpT>
802  RewritePatternSet &patterns, const TypeConverter &converter) {
803  populateFunctionOpInterfaceTypeConversionPattern(FuncOpT::getOperationName(),
804  patterns, converter);
805 }
806 
808  RewritePatternSet &patterns, const TypeConverter &converter);
809 
810 //===----------------------------------------------------------------------===//
811 // Conversion PatternRewriter
812 //===----------------------------------------------------------------------===//
813 
814 namespace detail {
815 struct ConversionPatternRewriterImpl;
816 } // namespace detail
817 
818 /// This class implements a pattern rewriter for use with ConversionPatterns. It
819 /// extends the base PatternRewriter and provides special conversion specific
820 /// hooks.
822 public:
824 
825  /// Apply a signature conversion to given block. This replaces the block with
826  /// a new block containing the updated signature. The operations of the given
827  /// block are inlined into the newly-created block, which is returned.
828  ///
829  /// If no block argument types are changing, the original block will be
830  /// left in place and returned.
831  ///
832  /// A signature conversion must be provided. (Type converters can construct
833  /// a signature conversion with `convertBlockSignature`.)
834  ///
835  /// Optionally, a type converter can be provided to build materializations.
836  /// Note: If no type converter was provided or the type converter does not
837  /// specify any suitable argument/target materialization rules, the dialect
838  /// conversion may fail to legalize unresolved materializations.
839  Block *
842  const TypeConverter *converter = nullptr);
843 
844  /// Apply a signature conversion to each block in the given region. This
845  /// replaces each block with a new block containing the updated signature. If
846  /// an updated signature would match the current signature, the respective
847  /// block is left in place as is. (See `applySignatureConversion` for
848  /// details.) The new entry block of the region is returned.
849  ///
850  /// SignatureConversions are computed with the specified type converter.
851  /// This function returns "failure" if the type converter failed to compute
852  /// a SignatureConversion for at least one block.
853  ///
854  /// Optionally, a special SignatureConversion can be specified for the entry
855  /// block. This is because the types of the entry block arguments are often
856  /// tied semantically to the operation.
857  FailureOr<Block *> convertRegionTypes(
858  Region *region, const TypeConverter &converter,
859  TypeConverter::SignatureConversion *entryConversion = nullptr);
860 
861  /// Replace all the uses of the block argument `from` with value `to`.
863 
864  /// Return the converted value of 'key' with a type defined by the type
865  /// converter of the currently executing pattern. Return nullptr in the case
866  /// of failure, the remapped value otherwise.
868 
869  /// Return the converted values that replace 'keys' with types defined by the
870  /// type converter of the currently executing pattern. Returns failure if the
871  /// remap failed, success otherwise.
872  LogicalResult getRemappedValues(ValueRange keys,
873  SmallVectorImpl<Value> &results);
874 
875  //===--------------------------------------------------------------------===//
876  // PatternRewriter Hooks
877  //===--------------------------------------------------------------------===//
878 
879  /// Indicate that the conversion rewriter can recover from rewrite failure.
880  /// Recovery is supported via rollback, allowing for continued processing of
881  /// patterns even if a failure is encountered during the rewrite step.
882  bool canRecoverFromRewriteFailure() const override { return true; }
883 
884  /// Replace the given operation with the new values. The number of op results
885  /// and replacement values must match. The types may differ: the dialect
886  /// conversion driver will reconcile any surviving type mismatches at the end
887  /// of the conversion process with source materializations. The given
888  /// operation is erased.
889  void replaceOp(Operation *op, ValueRange newValues) override;
890 
891  /// Replace the given operation with the results of the new op. The number of
892  /// op results must match. The types may differ: the dialect conversion
893  /// driver will reconcile any surviving type mismatches at the end of the
894  /// conversion process with source materializations. The original operation
895  /// is erased.
896  void replaceOp(Operation *op, Operation *newOp) override;
897 
898  /// Replace the given operation with the new value ranges. The number of op
899  /// results and value ranges must match. The given operation is erased.
901  SmallVector<SmallVector<Value>> &&newValues);
902  template <typename RangeT = ValueRange>
905  llvm::to_vector_of<SmallVector<Value>>(newValues));
906  }
907  template <typename RangeT>
908  void replaceOpWithMultiple(Operation *op, RangeT &&newValues) {
910  ArrayRef(llvm::to_vector_of<ValueRange>(newValues)));
911  }
912 
913  /// PatternRewriter hook for erasing a dead operation. The uses of this
914  /// operation *must* be made dead by the end of the conversion process,
915  /// otherwise an assert will be issued.
916  void eraseOp(Operation *op) override;
917 
918  /// PatternRewriter hook for erasing all operations in a block. This is not yet
919  /// implemented for dialect conversion.
920  void eraseBlock(Block *block) override;
921 
922  /// PatternRewriter hook for inlining the ops of a block into another block.
923  void inlineBlockBefore(Block *source, Block *dest, Block::iterator before,
924  ValueRange argValues = std::nullopt) override;
926 
927  /// PatternRewriter hook for updating the given operation in-place.
928  /// Note: These methods only track updates to the given operation itself,
929  /// and not nested regions. Updates to regions will still require notification
930  /// through other more specific hooks above.
931  void startOpModification(Operation *op) override;
932 
933  /// PatternRewriter hook for updating the given operation in-place.
934  void finalizeOpModification(Operation *op) override;
935 
936  /// PatternRewriter hook for updating the given operation in-place.
937  void cancelOpModification(Operation *op) override;
938 
939  /// Return a reference to the internal implementation.
941 
942 private:
943  // Allow OperationConverter to construct new rewriters.
944  friend struct OperationConverter;
945 
946  /// Conversion pattern rewriters must not be used outside of dialect
947  /// conversions. They apply some IR rewrites in a delayed fashion and could
948  /// bring the IR into an inconsistent state when used standalone.
950  const ConversionConfig &config);
951 
952  // Hide unsupported pattern rewriter API.
954 
955  std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
956 };
957 
958 //===----------------------------------------------------------------------===//
959 // ConversionTarget
960 //===----------------------------------------------------------------------===//
961 
962 /// This class describes a specific conversion target.
964 public:
965  /// This enumeration corresponds to the specific action to take when
966  /// considering an operation legal for this conversion target.
967  enum class LegalizationAction {
968  /// The target supports this operation.
969  Legal,
970 
971  /// This operation has dynamic legalization constraints that must be checked
972  /// by the target.
973  Dynamic,
974 
975  /// The target explicitly does not support this operation.
976  Illegal,
977  };
978 
979  /// A structure containing additional information describing a specific legal
980  /// operation instance.
981  struct LegalOpDetails {
982  /// A flag that indicates if this operation is 'recursively' legal. This
983  /// means that if an operation is legal, either statically or dynamically,
984  /// all of the operations nested within are also considered legal.
985  bool isRecursivelyLegal = false;
986  };
987 
988  /// The signature of the callback used to determine if an operation is
989  /// dynamically legal on the target.
991  std::function<std::optional<bool>(Operation *)>;
992 
993  ConversionTarget(MLIRContext &ctx) : ctx(ctx) {}
994  virtual ~ConversionTarget() = default;
995 
996  //===--------------------------------------------------------------------===//
997  // Legality Registration
998  //===--------------------------------------------------------------------===//
999 
1000  /// Register a legality action for the given operation.
1001  void setOpAction(OperationName op, LegalizationAction action);
1002  template <typename OpT>
1004  setOpAction(OperationName(OpT::getOperationName(), &ctx), action);
1005  }
1006 
1007  /// Register the given operations as legal.
1010  }
1011  template <typename OpT>
1012  void addLegalOp() {
1013  addLegalOp(OperationName(OpT::getOperationName(), &ctx));
1014  }
1015  template <typename OpT, typename OpT2, typename... OpTs>
1016  void addLegalOp() {
1017  addLegalOp<OpT>();
1018  addLegalOp<OpT2, OpTs...>();
1019  }
1020 
1021  /// Register the given operation as dynamically legal and set the dynamic
1022  /// legalization callback to the one provided.
1024  const DynamicLegalityCallbackFn &callback) {
1026  setLegalityCallback(op, callback);
1027  }
1028  template <typename OpT>
1030  addDynamicallyLegalOp(OperationName(OpT::getOperationName(), &ctx),
1031  callback);
1032  }
1033  template <typename OpT, typename OpT2, typename... OpTs>
1035  addDynamicallyLegalOp<OpT>(callback);
1036  addDynamicallyLegalOp<OpT2, OpTs...>(callback);
1037  }
1038  template <typename OpT, class Callable>
1039  std::enable_if_t<!std::is_invocable_v<Callable, Operation *>>
1040  addDynamicallyLegalOp(Callable &&callback) {
1041  addDynamicallyLegalOp<OpT>(
1042  [=](Operation *op) { return callback(cast<OpT>(op)); });
1043  }
1044 
1045  /// Register the given operation as illegal, i.e. this operation is known to
1046  /// not be supported by this target.
1049  }
1050  template <typename OpT>
1051  void addIllegalOp() {
1052  addIllegalOp(OperationName(OpT::getOperationName(), &ctx));
1053  }
1054  template <typename OpT, typename OpT2, typename... OpTs>
1055  void addIllegalOp() {
1056  addIllegalOp<OpT>();
1057  addIllegalOp<OpT2, OpTs...>();
1058  }
1059 
1060  /// Mark an operation, that *must* have either been set as `Legal` or
1061  /// `DynamicallyLegal`, as being recursively legal. This means that in
1062  /// addition to the operation itself, all of the operations nested within are
1063  /// also considered legal. An optional dynamic legality callback may be
1064  /// provided to mark subsets of legal instances as recursively legal.
1066  const DynamicLegalityCallbackFn &callback);
1067  template <typename OpT>
1069  OperationName opName(OpT::getOperationName(), &ctx);
1070  markOpRecursivelyLegal(opName, callback);
1071  }
1072  template <typename OpT, typename OpT2, typename... OpTs>
1074  markOpRecursivelyLegal<OpT>(callback);
1075  markOpRecursivelyLegal<OpT2, OpTs...>(callback);
1076  }
1077  template <typename OpT, class Callable>
1078  std::enable_if_t<!std::is_invocable_v<Callable, Operation *>>
1079  markOpRecursivelyLegal(Callable &&callback) {
1080  markOpRecursivelyLegal<OpT>(
1081  [=](Operation *op) { return callback(cast<OpT>(op)); });
1082  }
1083 
1084  /// Register a legality action for the given dialects.
1085  void setDialectAction(ArrayRef<StringRef> dialectNames,
1086  LegalizationAction action);
1087 
1088  /// Register the operations of the given dialects as legal.
1089  template <typename... Names>
1090  void addLegalDialect(StringRef name, Names... names) {
1091  SmallVector<StringRef, 2> dialectNames({name, names...});
1093  }
1094  template <typename... Args>
1096  SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
1098  }
1099 
1100  /// Register the operations of the given dialects as dynamically legal, i.e.
1101  /// requiring custom handling by the callback.
1102  template <typename... Names>
1104  StringRef name, Names... names) {
1105  SmallVector<StringRef, 2> dialectNames({name, names...});
1107  setLegalityCallback(dialectNames, callback);
1108  }
1109  template <typename... Args>
1111  addDynamicallyLegalDialect(std::move(callback),
1112  Args::getDialectNamespace()...);
1113  }
1114 
1115  /// Register unknown operations as dynamically legal. For operations (and
1116  /// dialects) that do not have a set legalization action, treat them as
1117  /// dynamically legal and invoke the given callback.
1119  setLegalityCallback(fn);
1120  }
1121 
1122  /// Register the operations of the given dialects as illegal, i.e.
1123  /// operations of this dialect are not supported by the target.
1124  template <typename... Names>
1125  void addIllegalDialect(StringRef name, Names... names) {
1126  SmallVector<StringRef, 2> dialectNames({name, names...});
1128  }
1129  template <typename... Args>
1131  SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
1133  }
1134 
1135  //===--------------------------------------------------------------------===//
1136  // Legality Querying
1137  //===--------------------------------------------------------------------===//
1138 
1139  /// Get the legality action for the given operation.
1140  std::optional<LegalizationAction> getOpAction(OperationName op) const;
1141 
1142  /// If the given operation instance is legal on this target, a structure
1143  /// containing legality information is returned. If the operation is not
1144  /// legal, std::nullopt is returned. Also returns std::nullopt if operation
1145  /// legality wasn't registered by the user or dynamic legality callbacks
1146  /// returned std::nullopt.
1147  ///
1148  /// Note: Legality is actually a 4-state: Legal(recursive=true),
1149  /// Legal(recursive=false), Illegal or Unknown, where Unknown is treated
1150  /// either as Legal or Illegal depending on context.
1151  std::optional<LegalOpDetails> isLegal(Operation *op) const;
1152 
1153  /// Returns true if the operation instance is illegal on this target. Returns
1154  /// false if the operation is legal, its legality wasn't registered by the
1155  /// user, or dynamic legality callbacks returned std::nullopt.
1156  bool isIllegal(Operation *op) const;
1157 
1158 private:
1159  /// Set the dynamic legality callback for the given operation.
1160  void setLegalityCallback(OperationName name,
1161  const DynamicLegalityCallbackFn &callback);
1162 
1163  /// Set the dynamic legality callback for the given dialects.
1164  void setLegalityCallback(ArrayRef<StringRef> dialects,
1165  const DynamicLegalityCallbackFn &callback);
1166 
1167  /// Set the dynamic legality callback for the unknown ops.
1168  void setLegalityCallback(const DynamicLegalityCallbackFn &callback);
1169 
1170  /// The set of information that configures the legalization of an operation.
1171  struct LegalizationInfo {
1172  /// The legality action this operation was given.
1174 
1175  /// If some legal instances of this operation may also be recursively legal.
1176  bool isRecursivelyLegal = false;
1177 
1178  /// The legality callback if this operation is dynamically legal.
1179  DynamicLegalityCallbackFn legalityFn;
1180  };
1181 
1182  /// Get the legalization information for the given operation.
1183  std::optional<LegalizationInfo> getOpInfo(OperationName op) const;
1184 
1185  /// A deterministic mapping of operation name and its respective legality
1186  /// information.
1187  llvm::MapVector<OperationName, LegalizationInfo> legalOperations;
1188 
1189  /// A set of legality callbacks for given operation names that are used to
1190  /// check if an operation instance is recursively legal.
1191  DenseMap<OperationName, DynamicLegalityCallbackFn> opRecursiveLegalityFns;
1192 
1193  /// A deterministic mapping of dialect name to the specific legality action to
1194  /// take.
1195  llvm::StringMap<LegalizationAction> legalDialects;
1196 
1197  /// A set of dynamic legality callbacks for given dialect names.
1198  llvm::StringMap<DynamicLegalityCallbackFn> dialectLegalityFns;
1199 
1200  /// An optional legality callback for unknown operations.
1201  DynamicLegalityCallbackFn unknownLegalityFn;
1202 
1203  /// The current context this target applies to.
1204  MLIRContext &ctx;
1205 };
1206 
1207 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
1208 //===----------------------------------------------------------------------===//
1209 // PDL Configuration
1210 //===----------------------------------------------------------------------===//
1211 
1212 /// A PDL configuration that is used to support dialect conversion
1213 /// functionality.
1215  : public PDLPatternConfigBase<PDLConversionConfig> {
1216 public:
1217  PDLConversionConfig(const TypeConverter *converter) : converter(converter) {}
1218  ~PDLConversionConfig() final = default;
1219 
1220  /// Return the type converter used by this configuration, which may be nullptr
1221  /// if no type conversions are expected.
1222  const TypeConverter *getTypeConverter() const { return converter; }
1223 
1224  /// Hooks that are invoked at the beginning and end of a rewrite of a matched
1225  /// pattern.
1226  void notifyRewriteBegin(PatternRewriter &rewriter) final;
1227  void notifyRewriteEnd(PatternRewriter &rewriter) final;
1228 
1229 private:
1230  /// An optional type converter to use for the pattern.
1231  const TypeConverter *converter;
1232 };
1233 
1234 /// Register the dialect conversion PDL functions with the given pattern set.
1235 void registerConversionPDLFunctions(RewritePatternSet &patterns);
1236 
1237 #else
1238 
1239 // Stubs used when PDL in pattern rewriting is disabled
// (MLIR_ENABLE_PDL_IN_PATTERNMATCH is false). They mirror the names of the
// PDL-enabled declarations, presumably so that callers compile unchanged in
// either configuration.
1240 
/// Does nothing: `patterns` is left unchanged when PDL is disabled.
1241 inline void registerConversionPDLFunctions(RewritePatternSet &patterns) {}
1242 
/// Placeholder configuration for builds without PDL. The constructor accepts
/// a type converter pointer but ignores it; the class holds no state.
1243 class PDLConversionConfig final {
1244 public:
1245  PDLConversionConfig(const TypeConverter * /*converter*/) {}
1246 };
1247 
1248 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
1249 
1250 //===----------------------------------------------------------------------===//
1251 // ConversionConfig
1252 //===----------------------------------------------------------------------===//
1253 
1254 /// Dialect conversion configuration.
1256  /// An optional callback used to notify about match failure diagnostics during
1257  /// the conversion. Diagnostics reported to this callback may only be
1258  /// available in debug mode.
1260 
1261  /// Partial conversion only. All operations that are found not to be
1262  /// legalizable are placed in this set. (Note that if there is an op
1263  /// explicitly marked as illegal, the conversion terminates and the set will
1264  /// not necessarily be complete.)
1266 
1267  /// Analysis conversion only. All operations that are found to be legalizable
1268  /// are placed in this set. Note that no actual rewrites are applied to the
1269  /// IR during an analysis conversion and only pre-existing operations are
1270  /// added to the set.
1272 
1273  /// An optional listener that is notified about all IR modifications in case
1274  /// dialect conversion succeeds. If the dialect conversion fails and no IR
1275  /// modifications are visible (i.e., they were all rolled back), or if the
1276  /// dialect conversion is an "analysis conversion", no notifications are
1277  /// sent (apart from `notifyPatternBegin`/`notifyPatternEnd`).
1278  ///
1279  /// Note: Notifications are sent in a delayed fashion, when the dialect
1280  /// conversion is guaranteed to succeed. At that point, some IR modifications
1281  /// may already have been materialized. Consequently, operations/blocks that
1282  /// are passed to listener callbacks should not be accessed. (Ops/blocks are
1283  /// guaranteed to be valid pointers and accessing op names is allowed. But
1284  /// there are no guarantees about the state of ops/blocks at the time that a
1285  /// callback is triggered.)
1286  ///
1287  /// Example: Consider a dialect conversion in which a new op ("test.foo") is created
1288  /// and inserted, and later moved to another block. (Moving ops also triggers
1289  /// "notifyOperationInserted".)
1290  ///
1291  /// (1) notifyOperationInserted: "test.foo" (into block "b1")
1292  /// (2) notifyOperationInserted: "test.foo" (moved to another block "b2")
1293  ///
1294  /// When querying "op->getBlock()" during the first "notifyOperationInserted",
1295  /// "b2" would be returned because "moving an op" is a kind of rewrite that is
1296  /// immediately performed by the dialect conversion (and rolled back upon
1297  /// failure).
1298  //
1299  // Note: When receiving a "notifyBlockInserted"/"notifyOperationInserted"
1300  // callback, the previous region/block is provided to the callback, but not
1301  // the iterator pointing to the exact location within the region/block. That
1302  // is because these notifications are sent with a delay (after the IR has
1303  // already been modified) and iterators into past IR state cannot be
1304  // represented at the moment.
1306 
1307  /// If set to "true", the dialect conversion attempts to build source/target
1308  /// materializations through the type converter API in lieu of
1309  /// "builtin.unrealized_conversion_cast ops". The conversion process fails if
1310  /// at least one materialization could not be built.
1311  ///
1312  /// If set to "false", the dialect conversion does not build any custom
1313  /// materializations and instead inserts "builtin.unrealized_conversion_cast"
1314  /// ops to ensure that the resulting IR is valid.
1316 };
1317 
1318 //===----------------------------------------------------------------------===//
1319 // Reconcile Unrealized Casts
1320 //===----------------------------------------------------------------------===//
1321 
1322 /// Try to reconcile all given UnrealizedConversionCastOps and store the
1323 /// left-over ops in `remainingCastOps` (if provided).
1324 ///
1325 /// This function processes cast ops in a worklist-driven fashion. For each
1326 /// cast op, if the chain of input casts eventually reaches a cast op where the
1327 /// input types match the output types of the matched op, replace the matched
1328 /// op with the inputs.
1329 ///
1330 /// Example:
1331 /// %1 = unrealized_conversion_cast %0 : !A to !B
1332 /// %2 = unrealized_conversion_cast %1 : !B to !C
1333 /// %3 = unrealized_conversion_cast %2 : !C to !A
1334 ///
1335 /// In the above example, %0 can be used instead of %3 and all cast ops are
1336 /// folded away.
1339  SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps = nullptr);
1340 
1341 //===----------------------------------------------------------------------===//
1342 // Op Conversion Entry Points
1343 //===----------------------------------------------------------------------===//
1344 
1345 /// Below we define several entry points for operation conversion. It is
1346 /// important to note that the patterns provided to the conversion framework may
1347 /// have additional constraints. See the `PatternRewriter Hooks` section of the
1348 /// ConversionPatternRewriter, to see what additional constraints are imposed on
1349 /// the use of the PatternRewriter.
1350 
1351 /// Apply a partial conversion on the given operations and all nested
1352 /// operations. This method converts as many operations to the target as
1353 /// possible, ignoring operations that failed to legalize. This method only
1354 /// returns failure if there are ops explicitly marked as illegal.
1355 LogicalResult
1357  const ConversionTarget &target,
1360 LogicalResult
1364 
1365 /// Apply a complete conversion on the given operations, and all nested
1366 /// operations. This method returns failure if the conversion of any operation
1367 /// fails, or if there are unreachable blocks in any of the regions nested
1368 /// within 'ops'.
1369 LogicalResult applyFullConversion(ArrayRef<Operation *> ops,
1370  const ConversionTarget &target,
1373 LogicalResult applyFullConversion(Operation *op, const ConversionTarget &target,
1376 
1377 /// Apply an analysis conversion on the given operations, and all nested
1378 /// operations. This method analyzes which operations would be successfully
1379 /// converted to the target if a conversion was applied. All operations that
1380 /// were found to be legalizable to the given 'target' are placed within the
1381 /// provided 'config.legalizableOps' set; note that no actual rewrites are
1382 /// applied to the operations on success. This method only returns failure if
1383 /// there are unreachable blocks in any of the regions nested within 'ops'.
1384 LogicalResult
1388 LogicalResult
1392 } // namespace mlir
1393 
1394 #endif // MLIR_TRANSFORMS_DIALECTCONVERSION_H_
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
Block represents an ordered list of Operations.
Definition: Block.h:33
OpListType::iterator iterator
Definition: Block.h:140
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
LogicalResult getRemappedValues(ValueRange keys, SmallVectorImpl< Value > &results)
Return the converted values that replace 'keys' with types defined by the type converter of the curre...
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt) override
PatternRewriter hook for inlining the ops of a block into another block.
void replaceOpWithMultiple(Operation *op, SmallVector< SmallVector< Value >> &&newValues)
Replace the given operation with the new value ranges.
Block * applySignatureConversion(Block *block, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter=nullptr)
Apply a signature conversion to given block.
void replaceOpWithMultiple(Operation *op, ArrayRef< RangeT > newValues)
void startOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void replaceOpWithMultiple(Operation *op, RangeT &&newValues)
detail::ConversionPatternRewriterImpl & getImpl()
Return a reference to the internal implementation.
void eraseBlock(Block *block) override
PatternRewriter hook for erasing all operations in a block.
void cancelOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
bool canRecoverFromRewriteFailure() const override
Indicate that the conversion rewriter can recover from rewrite failure.
Value getRemappedValue(Value key)
Return the converted value of 'key' with a type defined by the type converter of the currently execut...
void finalizeOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void replaceUsesOfBlockArgument(BlockArgument from, Value to)
Replace all the uses of the block argument from with value to.
Base class for the conversion patterns.
SmallVector< Value > getOneToOneAdaptorOperands(ArrayRef< ValueRange > operands) const
Given an array of value ranges, which are the inputs to a 1:N adaptor, try to extract the single valu...
const TypeConverter * typeConverter
An optional type converter for use by this pattern.
ConversionPattern(const TypeConverter &typeConverter, Args &&...args)
Construct a conversion pattern with the given converter, and forward the remaining arguments to Rewri...
const TypeConverter * getTypeConverter() const
Return the type converter held by this pattern, or nullptr if the pattern does not require type conve...
std::enable_if_t< std::is_base_of< TypeConverter, ConverterTy >::value, const ConverterTy * > getTypeConverter() const
virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement combined matching and rewriting.
virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef< ValueRange > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement combined matching and rewriting.
This class describes a specific conversion target.
void setDialectAction(ArrayRef< StringRef > dialectNames, LegalizationAction action)
Register a legality action for the given dialects.
void setOpAction(LegalizationAction action)
void addLegalOp(OperationName op)
Register the given operations as legal.
void addDynamicallyLegalDialect(DynamicLegalityCallbackFn callback)
void setOpAction(OperationName op, LegalizationAction action)
Register a legality action for the given operation.
void addDynamicallyLegalDialect(const DynamicLegalityCallbackFn &callback, StringRef name, Names... names)
Register the operations of the given dialects as dynamically legal, i.e.
void addDynamicallyLegalOp(const DynamicLegalityCallbackFn &callback)
void addLegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as legal.
std::optional< LegalOpDetails > isLegal(Operation *op) const
If the given operation instance is legal on this target, a structure containing legality information ...
std::enable_if_t<!std::is_invocable_v< Callable, Operation * > > markOpRecursivelyLegal(Callable &&callback)
void markUnknownOpDynamicallyLegal(const DynamicLegalityCallbackFn &fn)
Register unknown operations as dynamically legal.
std::optional< LegalizationAction > getOpAction(OperationName op) const
Get the legality action for the given operation.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the on...
void addIllegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as illegal, i.e.
std::enable_if_t<!std::is_invocable_v< Callable, Operation * > > addDynamicallyLegalOp(Callable &&callback)
void addDynamicallyLegalOp(const DynamicLegalityCallbackFn &callback)
void markOpRecursivelyLegal(const DynamicLegalityCallbackFn &callback={})
void addIllegalOp(OperationName op)
Register the given operation as illegal, i.e.
LegalizationAction
This enumeration corresponds to the specific action to take when considering an operation legal for t...
@ Illegal
The target explicitly does not support this operation.
@ Dynamic
This operation has dynamic legalization constraints that must be checked by the target.
@ Legal
The target supports this operation.
ConversionTarget(MLIRContext &ctx)
void markOpRecursivelyLegal(OperationName name, const DynamicLegalityCallbackFn &callback)
Mark an operation, that must have either been set as Legal or DynamicallyLegal, as being recursively ...
void markOpRecursivelyLegal(const DynamicLegalityCallbackFn &callback={})
std::function< std::optional< bool >(Operation *)> DynamicLegalityCallbackFn
The signature of the callback used to determine if an operation is dynamically legal on the target.
bool isIllegal(Operation *op) const
Returns true if the operation instance is illegal on this target.
virtual ~ConversionTarget()=default
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:205
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
Definition: Builders.h:314
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
LogicalResult matchAndRewrite(Operation *op, ArrayRef< ValueRange > operands, ConversionPatternRewriter &rewriter) const final
Hook for derived classes to implement combined matching and rewriting.
typename SourceOp::Adaptor OpAdaptor
virtual LogicalResult matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
typename SourceOp::template GenericAdaptor< ArrayRef< ValueRange > > OneToNOpAdaptor
OpConversionPattern(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit=1)
virtual LogicalResult matchAndRewrite(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const
Methods that operate on the SourceOp type.
LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Wrappers around the ConversionPattern methods that pass the derived op type.
OpInterfaceConversionPattern is a wrapper around ConversionPattern that allows for matching and rewri...
LogicalResult matchAndRewrite(Operation *op, ArrayRef< ValueRange > operands, ConversionPatternRewriter &rewriter) const final
Hook for derived classes to implement combined matching and rewriting.
virtual LogicalResult matchAndRewrite(SourceOp op, ArrayRef< ValueRange > operands, ConversionPatternRewriter &rewriter) const
LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Wrappers around the ConversionPattern methods that pass the derived op type.
OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
OpInterfaceConversionPattern(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit=1)
virtual LogicalResult matchAndRewrite(SourceOp op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Methods that operate on the SourceOp type.
OpTraitConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting...
OpTraitConversionPattern(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit=1)
OpTraitConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
A PDL configuration that is used to support dialect conversion functionality.
void notifyRewriteEnd(PatternRewriter &rewriter) final
void notifyRewriteBegin(PatternRewriter &rewriter) final
Hooks that are invoked at the beginning and end of a rewrite of a matched pattern.
const TypeConverter * getTypeConverter() const
Return the type converter used by this configuration, which may be nullptr if no type conversions are...
PDLConversionConfig(const TypeConverter *converter)
~PDLConversionConfig() final=default
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:803
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:73
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:271
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
The general result of a type attribute conversion callback, allowing for early termination.
static AttributeConversionResult result(Attribute attr)
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
std::optional< InputMapping > getInputMapping(unsigned input) const
Get the input mapping for the given argument.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
void remapInput(unsigned origInputNo, ArrayRef< Value > replacements)
Remap an input of the original signature to replacement values.
Type conversion class.
std::optional< Attribute > convertTypeAttribute(Type type, Attribute attr) const
Convert an attribute `attr` present within the type `type` using the registered conversion functions...
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
void addConversion(FnT &&callback)
Register a conversion function.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
void addArgumentMaterialization(FnT &&callback)
All of the following materializations require function objects that are convertible to the following ...
LogicalResult convertSignatureArgs(TypeRange types, SignatureConversion &result, unsigned origInputOffset=0) const
TypeConverter()=default
std::enable_if_t<!std::is_convertible< RangeT, Type >::value &&!std::is_convertible< RangeT, Operation * >::value, bool > isLegal(RangeT &&range) const
Return true if all of the given types are legal for this type converter.
TargetType convertType(Type t) const
Attempts a 1-1 type conversion, expecting the result type to be TargetType.
LogicalResult convertSignatureArg(unsigned inputNo, Type type, SignatureConversion &result) const
This method allows for converting a specific argument of a signature.
TypeConverter(const TypeConverter &other)
TypeConverter & operator=(const TypeConverter &other)
Value materializeArgumentConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
Materialize a conversion from a set of types into one result type by generating a cast sequence of so...
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
std::optional< SignatureConversion > convertBlockSignature(Block *block) const
This function converts the type signature of the given block, by invoking 'convertSignatureArg' for e...
void addSourceMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a replacement value back ...
virtual ~TypeConverter()=default
void addTypeAttributeConversion(FnT &&callback)
Register a conversion function for attributes within types.
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a value to a target type ...
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
bool isSignatureLegal(FunctionType ty) const
Return true if the inputs and outputs of the given function type are legal.
Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs, Type originalType={}) const
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:107
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Helper class that derives from a ConversionRewritePattern class and provides separate match and rewri...
Helper class that derives from a RewritePattern class and provides separate match and rewrite entry p...
Definition: PatternMatch.h:244
@ Type
An inlay hint for a type annotation.
Include the generated interface declarations.
void populateFunctionOpInterfaceTypeConversionPattern(StringRef functionLikeOpName, RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to convert the signature of a FunctionOpInterface op with the...
void populateAnyFunctionOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
FailureOr< Operation * > convertOpResultTypes(Operation *op, ValueRange operands, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Generic utility to convert op result types according to type converter without knowing exact op type.
const FrozenRewritePatternSet & patterns
LogicalResult applyAnalysisConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply an analysis conversion on the given operations, and all nested operations.
void reconcileUnrealizedCasts(ArrayRef< UnrealizedConversionCastOp > castOps, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps=nullptr)
Try to reconcile all given UnrealizedConversionCastOps and store the left-over ops in remainingCastOp...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void registerConversionPDLFunctions(RewritePatternSet &patterns)
Register the dialect conversion PDL functions with the given pattern set.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
Dialect conversion configuration.
RewriterBase::Listener * listener
An optional listener that is notified about all IR modifications in case dialect conversion succeeds.
function_ref< void(Diagnostic &)> notifyCallback
An optional callback used to notify about match failure diagnostics during the conversion.
DenseSet< Operation * > * legalizableOps
Analysis conversion only.
DenseSet< Operation * > * unlegalizedOps
Partial conversion only.
bool buildMaterializations
If set to "true", the dialect conversion attempts to build source/target materializations through the...
A structure containing additional information describing a specific legal operation instance.
bool isRecursivelyLegal
A flag that indicates if this operation is 'recursively' legal.
This class acts as a special tag that makes the desire to match any operation that implements a given...
Definition: PatternMatch.h:164
This class acts as a special tag that makes the desire to match any operation that implements a given...
Definition: PatternMatch.h:169
This struct represents a range of new types or a range of values that remaps an existing signature in...
bool replacedWithValues() const
Return "true" if this input was replaced with one or multiple values.