MLIR  20.0.0git
DialectConversion.h
Go to the documentation of this file.
1 //===- DialectConversion.h - MLIR dialect conversion pass -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file declares a generic pass for converting between MLIR dialects.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_TRANSFORMS_DIALECTCONVERSION_H_
14 #define MLIR_TRANSFORMS_DIALECTCONVERSION_H_
15 
16 #include "mlir/Config/mlir-config.h"
18 #include "llvm/ADT/MapVector.h"
19 #include "llvm/ADT/StringMap.h"
20 #include <type_traits>
21 
22 namespace mlir {
23 
24 // Forward declarations.
25 class Attribute;
26 class Block;
27 struct ConversionConfig;
28 class ConversionPatternRewriter;
29 class MLIRContext;
30 class Operation;
31 struct OperationConverter;
32 class Type;
33 class Value;
34 
35 //===----------------------------------------------------------------------===//
36 // Type Conversion
37 //===----------------------------------------------------------------------===//
38 
39 /// Type conversion class. Specific conversions and materializations can be
40 /// registered using addConversion and addMaterialization, respectively.
42 public:
43  virtual ~TypeConverter() = default;
44  TypeConverter() = default;
45  // Copy the registered conversions, but not the caches
47  : conversions(other.conversions),
48  argumentMaterializations(other.argumentMaterializations),
49  sourceMaterializations(other.sourceMaterializations),
50  targetMaterializations(other.targetMaterializations),
51  typeAttributeConversions(other.typeAttributeConversions) {}
53  conversions = other.conversions;
54  argumentMaterializations = other.argumentMaterializations;
55  sourceMaterializations = other.sourceMaterializations;
56  targetMaterializations = other.targetMaterializations;
57  typeAttributeConversions = other.typeAttributeConversions;
58  return *this;
59  }
60 
61  /// This class provides all of the information necessary to convert a type
62  /// signature.
64  public:
65  SignatureConversion(unsigned numOrigInputs)
66  : remappedInputs(numOrigInputs) {}
67 
68  /// This struct represents a range of new types or a single value that
69  /// remaps an existing signature input.
70  struct InputMapping {
71  size_t inputNo, size;
73  };
74 
75  /// Return the argument types for the new signature.
76  ArrayRef<Type> getConvertedTypes() const { return argTypes; }
77 
78  /// Get the input mapping for the given argument.
79  std::optional<InputMapping> getInputMapping(unsigned input) const {
80  return remappedInputs[input];
81  }
82 
83  //===------------------------------------------------------------------===//
84  // Conversion Hooks
85  //===------------------------------------------------------------------===//
86 
87  /// Remap an input of the original signature with a new set of types. The
88  /// new types are appended to the new signature conversion.
89  void addInputs(unsigned origInputNo, ArrayRef<Type> types);
90 
91  /// Append new input types to the signature conversion, this should only be
92  /// used if the new types are not intended to remap an existing input.
93  void addInputs(ArrayRef<Type> types);
94 
95  /// Remap an input of the original signature to another `replacement`
96  /// value. This drops the original argument.
97  void remapInput(unsigned origInputNo, Value replacement);
98 
99  private:
100  /// Remap an input of the original signature with a range of types in the
101  /// new signature.
102  void remapInput(unsigned origInputNo, unsigned newInputNo,
103  unsigned newInputCount = 1);
104 
105  /// The remapping information for each of the original arguments.
106  SmallVector<std::optional<InputMapping>, 4> remappedInputs;
107 
108  /// The set of new argument types.
109  SmallVector<Type, 4> argTypes;
110  };
111 
112  /// The general result of a type attribute conversion callback, allowing
113  /// for early termination. The default constructor creates the na case.
115  public:
116  constexpr AttributeConversionResult() : impl() {}
117  AttributeConversionResult(Attribute attr) : impl(attr, resultTag) {}
118 
120  static AttributeConversionResult na();
122 
123  bool hasResult() const;
124  bool isNa() const;
125  bool isAbort() const;
126 
127  Attribute getResult() const;
128 
129  private:
130  AttributeConversionResult(Attribute attr, unsigned tag) : impl(attr, tag) {}
131 
132  llvm::PointerIntPair<Attribute, 2> impl;
133  // Note that na is 0 so that we can use PointerIntPair's default
134  // constructor.
135  static constexpr unsigned naTag = 0;
136  static constexpr unsigned resultTag = 1;
137  static constexpr unsigned abortTag = 2;
138  };
139 
140  /// Register a conversion function. A conversion function must be convertible
141  /// to any of the following forms (where `T` is a class derived from `Type`):
142  ///
143  /// * std::optional<Type>(T)
144  /// - This form represents a 1-1 type conversion. It should return nullptr
145  /// or `std::nullopt` to signify failure. If `std::nullopt` is returned,
146  /// the converter is allowed to try another conversion function to
147  /// perform the conversion.
148  /// * std::optional<LogicalResult>(T, SmallVectorImpl<Type> &)
149  /// - This form represents a 1-N type conversion. It should return
150  /// `failure` or `std::nullopt` to signify a failed conversion. If the
151  /// new set of types is empty, the type is removed and any usages of the
152  /// existing value are expected to be removed during conversion. If
153  /// `std::nullopt` is returned, the converter is allowed to try another
154  /// conversion function to perform the conversion.
155  ///
156  /// Note: When attempting to convert a type, e.g. via 'convertType', the
157  /// most recently added conversions will be invoked first.
158  template <typename FnT, typename T = typename llvm::function_traits<
159  std::decay_t<FnT>>::template arg_t<0>>
160  void addConversion(FnT &&callback) {
161  registerConversion(wrapCallback<T>(std::forward<FnT>(callback)));
162  }
163 
164  /// All of the following materializations require function objects that are
165  /// convertible to the following form:
166  /// `Value(OpBuilder &, T, ValueRange, Location)`,
167  /// where `T` is any subclass of `Type`. This function is responsible for
168  /// creating an operation, using the OpBuilder and Location provided, that
169  /// "casts" a range of values into a single value of the given type `T`. It
170  /// must return a Value of the type `T` on success and `nullptr` if
171  /// it failed but other materializations should be attempted. Materialization
172  /// functions must be provided when a type conversion may persist after the
173  /// conversion has finished.
174  ///
175  /// Note: Target materializations may optionally accept an additional Type
176  /// parameter, which is the original type of the SSA value. Furthermore, `T`
177  /// can be a TypeRange; in that case, the function must return a
178  /// SmallVector<Value>.
179 
180  /// This method registers a materialization that will be called when
181  /// converting (potentially multiple) block arguments that were the result of
182  /// a signature conversion of a single block argument, to a single SSA value
183  /// with the old block argument type.
184  template <typename FnT, typename T = typename llvm::function_traits<
185  std::decay_t<FnT>>::template arg_t<1>>
186  void addArgumentMaterialization(FnT &&callback) {
187  argumentMaterializations.emplace_back(
188  wrapMaterialization<T>(std::forward<FnT>(callback)));
189  }
190 
191  /// This method registers a materialization that will be called when
192  /// converting a legal replacement value back to an illegal source type.
193  /// This is used when some uses of the original, illegal value must persist
194  /// beyond the main conversion.
195  template <typename FnT, typename T = typename llvm::function_traits<
196  std::decay_t<FnT>>::template arg_t<1>>
197  void addSourceMaterialization(FnT &&callback) {
198  sourceMaterializations.emplace_back(
199  wrapMaterialization<T>(std::forward<FnT>(callback)));
200  }
201 
202  /// This method registers a materialization that will be called when
203  /// converting an illegal (source) value to a legal (target) type.
204  ///
205  /// Note: For target materializations, users can optionally take the original
206  /// type. This type may be different from the type of the input. For example,
207  /// let's assume that a conversion pattern "P1" replaced an SSA value "v1"
208  /// (type "t1") with "v2" (type "t2"). Then a different conversion pattern
209  /// "P2" matches an op that has "v1" as an operand. Let's furthermore assume
210  /// that "P2" determines that the legalized type of "t1" is "t3", which may
211  /// be different from "t2". In this example, the target materialization
212  /// will be invoked with: outputType = "t3", inputs = "v2",
213  /// originalType = "t1". Note that the original type "t1" cannot be recovered
214  /// from just "t3" and "v2"; that's why the originalType parameter exists.
215  ///
216  /// Note: During a 1:N conversion, the result types can be a TypeRange. In
217  /// that case the materialization produces a SmallVector<Value>.
218  template <typename FnT, typename T = typename llvm::function_traits<
219  std::decay_t<FnT>>::template arg_t<1>>
220  void addTargetMaterialization(FnT &&callback) {
221  targetMaterializations.emplace_back(
222  wrapTargetMaterialization<T>(std::forward<FnT>(callback)));
223  }
224 
225  /// Register a conversion function for attributes within types. Type
226  /// converters may call this function in order to allow hooking into the
227  /// translation of attributes that exist within types. For example, a type
228  /// converter for the `memref` type could use these conversions to convert
229  /// memory spaces or layouts in an extensible way.
230  ///
231  /// The conversion functions take a non-null Type or subclass of Type and a
232  /// non-null Attribute (or subclass of Attribute), and return an
233  /// `AttributeConversionResult`. This result can either contain an `Attribute`,
234  /// which may be `nullptr`, representing the conversion's success,
235  /// `AttributeConversionResult::na()` (the default empty value), indicating
236  /// that the conversion function did not apply and that further conversion
237  /// functions should be checked, or `AttributeConversionResult::abort()`
238  /// indicating that the conversion process should be aborted.
239  ///
240  /// Registered conversion functions are called in the reverse of the order in
241  /// which they were registered.
242  template <
243  typename FnT,
244  typename T =
245  typename llvm::function_traits<std::decay_t<FnT>>::template arg_t<0>,
246  typename A =
247  typename llvm::function_traits<std::decay_t<FnT>>::template arg_t<1>>
248  void addTypeAttributeConversion(FnT &&callback) {
249  registerTypeAttributeConversion(
250  wrapTypeAttributeConversion<T, A>(std::forward<FnT>(callback)));
251  }
252 
253  /// Convert the given type. This function should return failure if no valid
254  /// conversion exists, success otherwise. If the new set of types is empty,
255  /// the type is removed and any usages of the existing value are expected to
256  /// be removed during conversion.
257  LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) const;
258 
259  /// This hook simplifies defining 1-1 type conversions. This function returns
260  /// the type to convert to on success, and a null type on failure.
261  Type convertType(Type t) const;
262 
263  /// Attempts a 1-1 type conversion, expecting the result type to be
264  /// `TargetType`. Returns the converted type cast to `TargetType` on success,
265  /// and a null type on conversion or cast failure.
266  template <typename TargetType>
267  TargetType convertType(Type t) const {
268  return dyn_cast_or_null<TargetType>(convertType(t));
269  }
270 
271  /// Convert the given set of types, filling 'results' as necessary. This
272  /// returns failure if the conversion of any of the types fails, success
273  /// otherwise.
274  LogicalResult convertTypes(TypeRange types,
275  SmallVectorImpl<Type> &results) const;
276 
277  /// Return true if the given type is legal for this type converter, i.e. the
278  /// type converts to itself.
279  bool isLegal(Type type) const;
280 
281  /// Return true if all of the given types are legal for this type converter.
282  template <typename RangeT>
283  std::enable_if_t<!std::is_convertible<RangeT, Type>::value &&
284  !std::is_convertible<RangeT, Operation *>::value,
285  bool>
286  isLegal(RangeT &&range) const {
287  return llvm::all_of(range, [this](Type type) { return isLegal(type); });
288  }
289  /// Return true if the given operation has legal operand and result types.
290  bool isLegal(Operation *op) const;
291 
292  /// Return true if the types of block arguments within the region are legal.
293  bool isLegal(Region *region) const;
294 
295  /// Return true if the inputs and outputs of the given function type are
296  /// legal.
297  bool isSignatureLegal(FunctionType ty) const;
298 
299  /// This method allows for converting a specific argument of a signature. It
300  /// takes as inputs the original argument input number and type.
301  /// On success, it populates 'result' with any new mappings.
302  LogicalResult convertSignatureArg(unsigned inputNo, Type type,
303  SignatureConversion &result) const;
304  LogicalResult convertSignatureArgs(TypeRange types,
305  SignatureConversion &result,
306  unsigned origInputOffset = 0) const;
307 
308  /// This function converts the type signature of the given block, by invoking
309  /// 'convertSignatureArg' for each argument. This function should return a
310  /// valid conversion for the signature on success, std::nullopt otherwise.
311  std::optional<SignatureConversion> convertBlockSignature(Block *block) const;
312 
313  /// Materialize a conversion from a set of types into one result type by
314  /// generating a cast sequence of some kind. See the respective
315  /// `add*Materialization` for more information on the context for these
316  /// methods.
318  Type resultType, ValueRange inputs) const;
320  Type resultType, ValueRange inputs) const;
322  Type resultType, ValueRange inputs,
323  Type originalType = {}) const;
324  SmallVector<Value> materializeTargetConversion(OpBuilder &builder,
325  Location loc,
326  TypeRange resultType,
327  ValueRange inputs,
328  Type originalType = {}) const;
329 
330  /// Convert the attribute `attr` present within the type `type` using
331  /// the registered conversion functions. If no applicable conversion has been
332  /// registered, return std::nullopt. Note that the empty attribute/`nullptr`
333  /// is a valid return value for this function.
334  std::optional<Attribute> convertTypeAttribute(Type type,
335  Attribute attr) const;
336 
337 private:
338  /// The signature of the callback used to convert a type. If the new set of
339  /// types is empty, the type is removed and any usages of the existing value
340  /// are expected to be removed during conversion.
341  using ConversionCallbackFn = std::function<std::optional<LogicalResult>(
342  Type, SmallVectorImpl<Type> &)>;
343 
344  /// The signature of the callback used to materialize a source/argument
345  /// conversion.
346  ///
347  /// Arguments: builder, result type, inputs, location
348  using MaterializationCallbackFn =
349  std::function<Value(OpBuilder &, Type, ValueRange, Location)>;
350 
351  /// The signature of the callback used to materialize a target conversion.
352  ///
353  /// Arguments: builder, result types, inputs, location, original type
354  using TargetMaterializationCallbackFn = std::function<SmallVector<Value>(
355  OpBuilder &, TypeRange, ValueRange, Location, Type)>;
356 
357  /// The signature of the callback used to convert a type attribute.
358  using TypeAttributeConversionCallbackFn =
359  std::function<AttributeConversionResult(Type, Attribute)>;
360 
361  /// Generate a wrapper for the given callback. This allows for accepting
362  /// different callback forms that all compose into a single version.
363  /// With callback of form: `std::optional<Type>(T)`
364  template <typename T, typename FnT>
365  std::enable_if_t<std::is_invocable_v<FnT, T>, ConversionCallbackFn>
366  wrapCallback(FnT &&callback) const {
367  return wrapCallback<T>([callback = std::forward<FnT>(callback)](
368  T type, SmallVectorImpl<Type> &results) {
369  if (std::optional<Type> resultOpt = callback(type)) {
370  bool wasSuccess = static_cast<bool>(*resultOpt);
371  if (wasSuccess)
372  results.push_back(*resultOpt);
373  return std::optional<LogicalResult>(success(wasSuccess));
374  }
375  return std::optional<LogicalResult>();
376  });
377  }
378  /// With callback of form: `std::optional<LogicalResult>(
379  /// T, SmallVectorImpl<Type> &)`.
380  template <typename T, typename FnT>
381  std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &>,
382  ConversionCallbackFn>
383  wrapCallback(FnT &&callback) const {
384  return [callback = std::forward<FnT>(callback)](
385  Type type,
386  SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
387  T derivedType = dyn_cast<T>(type);
388  if (!derivedType)
389  return std::nullopt;
390  return callback(derivedType, results);
391  };
392  }
393 
394  /// Register a type conversion.
395  void registerConversion(ConversionCallbackFn callback) {
396  conversions.emplace_back(std::move(callback));
397  cachedDirectConversions.clear();
398  cachedMultiConversions.clear();
399  }
400 
401  /// Generate a wrapper for the given argument/source materialization
402  /// callback. The callback may take any subclass of `Type` and the
403  /// wrapper will check for the target type to be of the expected class
404  /// before calling the callback.
405  template <typename T, typename FnT>
406  MaterializationCallbackFn wrapMaterialization(FnT &&callback) const {
407  return [callback = std::forward<FnT>(callback)](
408  OpBuilder &builder, Type resultType, ValueRange inputs,
409  Location loc) -> Value {
410  if (T derivedType = dyn_cast<T>(resultType))
411  return callback(builder, derivedType, inputs, loc);
412  return Value();
413  };
414  }
415 
416  /// Generate a wrapper for the given target materialization callback.
417  /// The callback may take any subclass of `Type` and the wrapper will check
418  /// for the target type to be of the expected class before calling the
419  /// callback.
420  ///
421  /// With callback of form:
422  /// - Value(OpBuilder &, T, ValueRange, Location, Type)
423  /// - SmallVector<Value>(OpBuilder &, TypeRange, ValueRange, Location, Type)
424  template <typename T, typename FnT>
425  std::enable_if_t<
426  std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location, Type>,
427  TargetMaterializationCallbackFn>
428  wrapTargetMaterialization(FnT &&callback) const {
429  return [callback = std::forward<FnT>(callback)](
430  OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
431  Location loc, Type originalType) -> SmallVector<Value> {
432  SmallVector<Value> result;
433  if constexpr (std::is_same<T, TypeRange>::value) {
434  // This is a 1:N target materialization. Return the produced values
435  // directly.
436  result = callback(builder, resultTypes, inputs, loc, originalType);
437  } else if constexpr (std::is_assignable<Type, T>::value) {
438  // This is a 1:1 target materialization. Invoke the callback only if a
439  // single SSA value is requested.
440  if (resultTypes.size() == 1) {
441  // Invoke the callback only if the type class of the callback matches
442  // the requested result type.
443  if (T derivedType = dyn_cast<T>(resultTypes.front())) {
444  // 1:1 materializations produce single values, but we store 1:N
445  // target materialization functions in the type converter. Wrap the
446  // result value in a SmallVector<Value>.
447  Value val =
448  callback(builder, derivedType, inputs, loc, originalType);
449  if (val)
450  result.push_back(val);
451  }
452  }
453  } else {
454  static_assert(sizeof(T) == 0, "T must be a Type or a TypeRange");
455  }
456  return result;
457  };
458  }
459  /// With callback of form (the `originalType` argument is ignored):
460  /// - Value(OpBuilder &, T, ValueRange, Location)
461  /// - SmallVector<Value>(OpBuilder &, TypeRange, ValueRange, Location)
462  template <typename T, typename FnT>
463  std::enable_if_t<
464  std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location>,
465  TargetMaterializationCallbackFn>
466  wrapTargetMaterialization(FnT &&callback) const {
467  return wrapTargetMaterialization<T>(
468  [callback = std::forward<FnT>(callback)](
469  OpBuilder &builder, T resultTypes, ValueRange inputs, Location loc,
470  Type originalType) {
471  return callback(builder, resultTypes, inputs, loc);
472  });
473  }
474 
475  /// Generate a wrapper for the given type attribute conversion callback. The
476  /// callback may take any subclass of `Type` and of `Attribute`; the wrapper
477  /// checks that both are of the expected classes before calling the callback,
478  /// and otherwise returns `AttributeConversionResult::na()`.
479  template <typename T, typename A, typename FnT>
480  TypeAttributeConversionCallbackFn
481  wrapTypeAttributeConversion(FnT &&callback) const {
482  return [callback = std::forward<FnT>(callback)](
483  Type type, Attribute attr) -> AttributeConversionResult {
484  if (T derivedType = dyn_cast<T>(type)) {
485  if (A derivedAttr = dyn_cast_or_null<A>(attr))
486  return callback(derivedType, derivedAttr);
487  }
488  return AttributeConversionResult::na(); // Not applicable; try others.
489  };
490  }
491 
492  /// Register a type attribute conversion, clearing caches.
493  void
494  registerTypeAttributeConversion(TypeAttributeConversionCallbackFn callback) {
495  typeAttributeConversions.emplace_back(std::move(callback));
496  // Clear cached type conversions in case an attribute (e.g. a memory space) is lingering inside.
497  cachedDirectConversions.clear();
498  cachedMultiConversions.clear();
499  }
500 
501  /// The set of registered conversion functions.
502  SmallVector<ConversionCallbackFn, 4> conversions;
503 
504  /// The list of registered materialization functions.
505  SmallVector<MaterializationCallbackFn, 2> argumentMaterializations;
506  SmallVector<MaterializationCallbackFn, 2> sourceMaterializations;
507  SmallVector<TargetMaterializationCallbackFn, 2> targetMaterializations;
508 
509  /// The list of registered type attribute conversion functions.
510  SmallVector<TypeAttributeConversionCallbackFn, 2> typeAttributeConversions;
511 
512  /// A set of cached conversions to avoid recomputing in the common case.
513  /// Direct 1-1 conversions are the most common, so this cache stores the
514  /// successful 1-1 conversions as well as all failed conversions.
515  mutable DenseMap<Type, Type> cachedDirectConversions;
516  /// This cache stores the successful 1->N conversions, where N != 1.
517  mutable DenseMap<Type, SmallVector<Type, 2>> cachedMultiConversions;
518  /// A mutex used for cache access
519  mutable llvm::sys::SmartRWMutex<true> cacheMutex;
520 };
521 
522 //===----------------------------------------------------------------------===//
523 // Conversion Patterns
524 //===----------------------------------------------------------------------===//
525 
526 /// Base class for the conversion patterns. This pattern class enables type
527 /// conversions, and other uses specific to the conversion framework. As such,
528 /// patterns of this type can only be used with the 'apply*' methods below.
530 public:
531  /// Hook for derived classes to implement rewriting. `op` is the (first)
532  /// operation matched by the pattern, `operands` is a list of the rewritten
533  /// operand values that are passed to `op`, `rewriter` can be used to emit the
534  /// new operations. This function should not fail. If some specific cases of
535  /// the operation are not supported, these cases should not be matched.
536  virtual void rewrite(Operation *op, ArrayRef<Value> operands,
537  ConversionPatternRewriter &rewriter) const {
538  llvm_unreachable("unimplemented rewrite");
539  }
540 
541  /// Hook for derived classes to implement combined matching and rewriting.
542  virtual LogicalResult
544  ConversionPatternRewriter &rewriter) const {
545  if (failed(match(op)))
546  return failure();
547  rewrite(op, operands, rewriter);
548  return success();
549  }
550 
551  /// Attempt to match and rewrite the IR root at the specified operation.
552  LogicalResult matchAndRewrite(Operation *op,
553  PatternRewriter &rewriter) const final;
554 
555  /// Return the type converter held by this pattern, or nullptr if the pattern
556  /// does not require type conversion.
557  const TypeConverter *getTypeConverter() const { return typeConverter; }
558 
559  template <typename ConverterTy>
560  std::enable_if_t<std::is_base_of<TypeConverter, ConverterTy>::value,
561  const ConverterTy *>
563  return static_cast<const ConverterTy *>(typeConverter);
564  }
565 
566 protected:
567  /// See `RewritePattern::RewritePattern` for information on the other
568  /// available constructors.
569  using RewritePattern::RewritePattern;
570  /// Construct a conversion pattern with the given converter, and forward the
571  /// remaining arguments to RewritePattern.
572  template <typename... Args>
574  : RewritePattern(std::forward<Args>(args)...),
576 
577 protected:
578  /// An optional type converter for use by this pattern.
579  const TypeConverter *typeConverter = nullptr;
580 
581 private:
583 };
584 
585 /// OpConversionPattern is a wrapper around ConversionPattern that allows for
586 /// matching and rewriting against an instance of a derived operation class as
587 /// opposed to a raw Operation.
588 template <typename SourceOp>
590 public:
591  using OpAdaptor = typename SourceOp::Adaptor;
592 
594  : ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
596  PatternBenefit benefit = 1)
597  : ConversionPattern(typeConverter, SourceOp::getOperationName(), benefit,
598  context) {}
599 
600  /// Wrappers around the ConversionPattern methods that pass the derived op
601  /// type.
602  LogicalResult match(Operation *op) const final {
603  return match(cast<SourceOp>(op));
604  }
605  void rewrite(Operation *op, ArrayRef<Value> operands,
606  ConversionPatternRewriter &rewriter) const final {
607  auto sourceOp = cast<SourceOp>(op);
608  rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
609  }
610  LogicalResult
612  ConversionPatternRewriter &rewriter) const final {
613  auto sourceOp = cast<SourceOp>(op);
614  return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
615  }
616 
617  /// Rewrite and Match methods that operate on the SourceOp type. These must be
618  /// overridden by the derived pattern class.
619  virtual LogicalResult match(SourceOp op) const {
620  llvm_unreachable("must override match or matchAndRewrite");
621  }
622  virtual void rewrite(SourceOp op, OpAdaptor adaptor,
623  ConversionPatternRewriter &rewriter) const {
624  llvm_unreachable("must override matchAndRewrite or a rewrite method");
625  }
626  virtual LogicalResult
627  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
628  ConversionPatternRewriter &rewriter) const {
629  if (failed(match(op)))
630  return failure();
631  rewrite(op, adaptor, rewriter);
632  return success();
633  }
634 
635 private:
637 };
638 
639 /// OpInterfaceConversionPattern is a wrapper around ConversionPattern that
640 /// allows for matching and rewriting against an instance of an OpInterface
641 /// class as opposed to a raw Operation.
642 template <typename SourceOp>
644 public:
647  SourceOp::getInterfaceID(), benefit, context) {}
649  MLIRContext *context, PatternBenefit benefit = 1)
651  SourceOp::getInterfaceID(), benefit, context) {}
652 
653  /// Wrappers around the ConversionPattern methods that pass the derived op
654  /// type.
655  void rewrite(Operation *op, ArrayRef<Value> operands,
656  ConversionPatternRewriter &rewriter) const final {
657  rewrite(cast<SourceOp>(op), operands, rewriter);
658  }
659  LogicalResult
661  ConversionPatternRewriter &rewriter) const final {
662  return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
663  }
664 
665  /// Rewrite and Match methods that operate on the SourceOp type. These must be
666  /// overridden by the derived pattern class.
667  virtual void rewrite(SourceOp op, ArrayRef<Value> operands,
668  ConversionPatternRewriter &rewriter) const {
669  llvm_unreachable("must override matchAndRewrite or a rewrite method");
670  }
671  virtual LogicalResult
672  matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
673  ConversionPatternRewriter &rewriter) const {
674  if (failed(match(op)))
675  return failure();
676  rewrite(op, operands, rewriter);
677  return success();
678  }
679 
680 private:
682 };
683 
684 /// OpTraitConversionPattern is a wrapper around ConversionPattern that allows
685 /// for matching and rewriting against instances of an operation that possess a
686 /// given trait.
687 template <template <typename> class TraitType>
689 public:
692  TypeID::get<TraitType>(), benefit, context) {}
694  MLIRContext *context, PatternBenefit benefit = 1)
696  TypeID::get<TraitType>(), benefit, context) {}
697 };
698 
699 /// Generic utility to convert op result types according to a type converter
700 /// without knowing the exact op type.
701 /// Clones the existing op with the new result types and returns the clone.
702 FailureOr<Operation *>
703 convertOpResultTypes(Operation *op, ValueRange operands,
704  const TypeConverter &converter,
705  ConversionPatternRewriter &rewriter);
706 
707 /// Add a pattern to the given pattern list to convert the signature of a
708 /// FunctionOpInterface op with the given type converter. This only supports
709 /// ops which use FunctionType to represent their type.
711  StringRef functionLikeOpName, RewritePatternSet &patterns,
712  const TypeConverter &converter);
713 
714 template <typename FuncOpT>
716  RewritePatternSet &patterns, const TypeConverter &converter) {
717  populateFunctionOpInterfaceTypeConversionPattern(FuncOpT::getOperationName(),
718  patterns, converter);
719 }
720 
722  RewritePatternSet &patterns, const TypeConverter &converter);
723 
724 //===----------------------------------------------------------------------===//
725 // Conversion PatternRewriter
726 //===----------------------------------------------------------------------===//
727 
728 namespace detail {
729 struct ConversionPatternRewriterImpl;
730 } // namespace detail
731 
732 /// This class implements a pattern rewriter for use with ConversionPatterns. It
733 /// extends the base PatternRewriter and provides special conversion specific
734 /// hooks.
736 public:
738 
739  /// Apply a signature conversion to given block. This replaces the block with
740  /// a new block containing the updated signature. The operations of the given
741  /// block are inlined into the newly-created block, which is returned.
742  ///
743  /// If no block argument types are changing, the original block will be
744  /// left in place and returned.
745  ///
746  /// A signature converison must be provided. (Type converters can construct
747  /// a signature conversion with `convertBlockSignature`.)
748  ///
749  /// Optionally, a type converter can be provided to build materializations.
750  /// Note: If no type converter was provided or the type converter does not
751  /// specify any suitable argument/target materialization rules, the dialect
752  /// conversion may fail to legalize unresolved materializations.
753  Block *
756  const TypeConverter *converter = nullptr);
757 
758  /// Apply a signature conversion to each block in the given region. This
759  /// replaces each block with a new block containing the updated signature. If
760  /// an updated signature would match the current signature, the respective
761  /// block is left in place as is. (See `applySignatureConversion` for
762  /// details.) The new entry block of the region is returned.
763  ///
764  /// SignatureConversions are computed with the specified type converter.
765  /// This function returns "failure" if the type converter failed to compute
766  /// a SignatureConversion for at least one block.
767  ///
768  /// Optionally, a special SignatureConversion can be specified for the entry
769  /// block. This is because the types of the entry block arguments are often
770  /// tied semantically to the operation.
771  FailureOr<Block *> convertRegionTypes(
772  Region *region, const TypeConverter &converter,
773  TypeConverter::SignatureConversion *entryConversion = nullptr);
774 
775  /// Replace all the uses of the block argument `from` with value `to`.
777 
778  /// Return the converted value of 'key' with a type defined by the type
779  /// converter of the currently executing pattern. Return nullptr in the case
780  /// of failure, the remapped value otherwise.
782 
783  /// Return the converted values that replace 'keys' with types defined by the
784  /// type converter of the currently executing pattern. Returns failure if the
785  /// remap failed, success otherwise.
786  LogicalResult getRemappedValues(ValueRange keys,
787  SmallVectorImpl<Value> &results);
788 
789  //===--------------------------------------------------------------------===//
790  // PatternRewriter Hooks
791  //===--------------------------------------------------------------------===//
792 
793  /// Indicate that the conversion rewriter can recover from rewrite failure.
794  /// Recovery is supported via rollback, allowing for continued processing of
795  /// patterns even if a failure is encountered during the rewrite step.
796  bool canRecoverFromRewriteFailure() const override { return true; }
797 
798  /// Replace the given operation with the new values. The number of op results
799  /// and replacement values must match. The types may differ: the dialect
800  /// conversion driver will reconcile any surviving type mismatches at the end
801  /// of the conversion process with source materializations. The given
802  /// operation is erased.
803  void replaceOp(Operation *op, ValueRange newValues) override;
804 
805  /// Replace the given operation with the results of the new op. The number of
806  /// op results must match. The types may differ: the dialect conversion
807  /// driver will reconcile any surviving type mismatches at the end of the
808  /// conversion process with source materializations. The original operation
809  /// is erased.
810  void replaceOp(Operation *op, Operation *newOp) override;
811 
812  /// Replace the given operation with the new value ranges. The number of op
813  /// results and value ranges must match. If an original SSA value is replaced
814  /// by multiple SSA values (i.e., a value range has more than 1 element), the
815  /// conversion driver will insert an argument materialization to convert the
816  /// N SSA values back into 1 SSA value of the original type. The given
817  /// operation is erased.
818  ///
819  /// Note: The argument materialization is a workaround until we have full 1:N
820  /// support in the dialect conversion. (It is going to disappear from both
821  /// `replaceOpWithMultiple` and `applySignatureConversion`.)
823 
824  /// PatternRewriter hook for erasing a dead operation. The uses of this
825  /// operation *must* be made dead by the end of the conversion process,
826  /// otherwise an assert will be issued.
827  void eraseOp(Operation *op) override;
828 
829  /// PatternRewriter hook for erasing all operations in a block. This is not yet
830  /// implemented for dialect conversion.
831  void eraseBlock(Block *block) override;
832 
833  /// PatternRewriter hook for inlining the ops of a block into another block.
834  void inlineBlockBefore(Block *source, Block *dest, Block::iterator before,
835  ValueRange argValues = std::nullopt) override;
837 
838  /// PatternRewriter hook for updating the given operation in-place.
839  /// Note: These methods only track updates to the given operation itself,
840  /// and not nested regions. Updates to regions will still require notification
841  /// through other more specific hooks above.
842  void startOpModification(Operation *op) override;
843 
844  /// PatternRewriter hook for updating the given operation in-place.
845  void finalizeOpModification(Operation *op) override;
846 
847  /// PatternRewriter hook for updating the given operation in-place.
848  void cancelOpModification(Operation *op) override;
849 
850  /// Return a reference to the internal implementation.
852 
853 private:
854  // Allow OperationConverter to construct new rewriters.
855  friend struct OperationConverter;
856 
857  /// Conversion pattern rewriters must not be used outside of dialect
858  /// conversions. They apply some IR rewrites in a delayed fashion and could
859  /// bring the IR into an inconsistent state when used standalone.
861  const ConversionConfig &config);
862 
863  // Hide unsupported pattern rewriter API.
865 
866  std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
867 };
868 
869 //===----------------------------------------------------------------------===//
870 // ConversionTarget
871 //===----------------------------------------------------------------------===//
872 
873 /// This class describes a specific conversion target.
875 public:
876  /// This enumeration corresponds to the specific action to take when
877  /// considering an operation legal for this conversion target.
878  enum class LegalizationAction {
879  /// The target supports this operation.
880  Legal,
881 
882  /// This operation has dynamic legalization constraints that must be checked
883  /// by the target.
884  Dynamic,
885 
886  /// The target explicitly does not support this operation.
887  Illegal,
888  };
889 
890  /// A structure containing additional information describing a specific legal
891  /// operation instance.
892  struct LegalOpDetails {
893  /// A flag that indicates if this operation is 'recursively' legal. This
894  /// means that if an operation is legal, either statically or dynamically,
895  /// all of the operations nested within are also considered legal.
896  bool isRecursivelyLegal = false;
897  };
898 
899  /// The signature of the callback used to determine if an operation is
900  /// dynamically legal on the target.
902  std::function<std::optional<bool>(Operation *)>;
903 
904  ConversionTarget(MLIRContext &ctx) : ctx(ctx) {}
905  virtual ~ConversionTarget() = default;
906 
907  //===--------------------------------------------------------------------===//
908  // Legality Registration
909  //===--------------------------------------------------------------------===//
910 
911  /// Register a legality action for the given operation.
913  template <typename OpT>
915  setOpAction(OperationName(OpT::getOperationName(), &ctx), action);
916  }
917 
918  /// Register the given operations as legal.
921  }
922  template <typename OpT>
923  void addLegalOp() {
924  addLegalOp(OperationName(OpT::getOperationName(), &ctx));
925  }
926  template <typename OpT, typename OpT2, typename... OpTs>
927  void addLegalOp() {
928  addLegalOp<OpT>();
929  addLegalOp<OpT2, OpTs...>();
930  }
931 
932  /// Register the given operation as dynamically legal and set the dynamic
933  /// legalization callback to the one provided.
935  const DynamicLegalityCallbackFn &callback) {
937  setLegalityCallback(op, callback);
938  }
939  template <typename OpT>
941  addDynamicallyLegalOp(OperationName(OpT::getOperationName(), &ctx),
942  callback);
943  }
944  template <typename OpT, typename OpT2, typename... OpTs>
946  addDynamicallyLegalOp<OpT>(callback);
947  addDynamicallyLegalOp<OpT2, OpTs...>(callback);
948  }
949  template <typename OpT, class Callable>
950  std::enable_if_t<!std::is_invocable_v<Callable, Operation *>>
951  addDynamicallyLegalOp(Callable &&callback) {
952  addDynamicallyLegalOp<OpT>(
953  [=](Operation *op) { return callback(cast<OpT>(op)); });
954  }
955 
956  /// Register the given operation as illegal, i.e. this operation is known to
957  /// not be supported by this target.
960  }
961  template <typename OpT>
962  void addIllegalOp() {
963  addIllegalOp(OperationName(OpT::getOperationName(), &ctx));
964  }
965  template <typename OpT, typename OpT2, typename... OpTs>
966  void addIllegalOp() {
967  addIllegalOp<OpT>();
968  addIllegalOp<OpT2, OpTs...>();
969  }
970 
971  /// Mark an operation, that *must* have either been set as `Legal` or
972  /// `DynamicallyLegal`, as being recursively legal. This means that in
973  /// addition to the operation itself, all of the operations nested within are
974  /// also considered legal. An optional dynamic legality callback may be
975  /// provided to mark subsets of legal instances as recursively legal.
977  const DynamicLegalityCallbackFn &callback);
978  template <typename OpT>
980  OperationName opName(OpT::getOperationName(), &ctx);
981  markOpRecursivelyLegal(opName, callback);
982  }
983  template <typename OpT, typename OpT2, typename... OpTs>
985  markOpRecursivelyLegal<OpT>(callback);
986  markOpRecursivelyLegal<OpT2, OpTs...>(callback);
987  }
988  template <typename OpT, class Callable>
989  std::enable_if_t<!std::is_invocable_v<Callable, Operation *>>
990  markOpRecursivelyLegal(Callable &&callback) {
991  markOpRecursivelyLegal<OpT>(
992  [=](Operation *op) { return callback(cast<OpT>(op)); });
993  }
994 
995  /// Register a legality action for the given dialects.
996  void setDialectAction(ArrayRef<StringRef> dialectNames,
997  LegalizationAction action);
998 
999  /// Register the operations of the given dialects as legal.
1000  template <typename... Names>
1001  void addLegalDialect(StringRef name, Names... names) {
1002  SmallVector<StringRef, 2> dialectNames({name, names...});
1004  }
1005  template <typename... Args>
1007  SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
1009  }
1010 
1011  /// Register the operations of the given dialects as dynamically legal, i.e.
1012  /// requiring custom handling by the callback.
1013  template <typename... Names>
1015  StringRef name, Names... names) {
1016  SmallVector<StringRef, 2> dialectNames({name, names...});
1018  setLegalityCallback(dialectNames, callback);
1019  }
1020  template <typename... Args>
1022  addDynamicallyLegalDialect(std::move(callback),
1023  Args::getDialectNamespace()...);
1024  }
1025 
1026  /// Register unknown operations as dynamically legal. For operations (and
1027  /// dialects) that do not have a set legalization action, treat them as
1028  /// dynamically legal and invoke the given callback.
1030  setLegalityCallback(fn);
1031  }
1032 
1033  /// Register the operations of the given dialects as illegal, i.e.
1034  /// operations of this dialect are not supported by the target.
1035  template <typename... Names>
1036  void addIllegalDialect(StringRef name, Names... names) {
1037  SmallVector<StringRef, 2> dialectNames({name, names...});
1039  }
1040  template <typename... Args>
1042  SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
1044  }
1045 
1046  //===--------------------------------------------------------------------===//
1047  // Legality Querying
1048  //===--------------------------------------------------------------------===//
1049 
1050  /// Get the legality action for the given operation.
1051  std::optional<LegalizationAction> getOpAction(OperationName op) const;
1052 
1053  /// If the given operation instance is legal on this target, a structure
1054  /// containing legality information is returned. If the operation is not
1055  /// legal, std::nullopt is returned. Also returns std::nullopt if operation
1056  /// legality wasn't registered by user or dynamic legality callbacks returned
1057  /// None.
1058  ///
1059  /// Note: Legality is actually a 4-state: Legal(recursive=true),
1060  /// Legal(recursive=false), Illegal or Unknown, where Unknown is treated
1061  /// either as Legal or Illegal depending on context.
1062  std::optional<LegalOpDetails> isLegal(Operation *op) const;
1063 
1064  /// Returns true if the operation instance is illegal on this target. Returns
1065  /// false if the operation is legal, operation legality wasn't registered by user
1066  /// or dynamic legality callbacks returned std::nullopt.
1067  bool isIllegal(Operation *op) const;
1068 
1069 private:
1070  /// Set the dynamic legality callback for the given operation.
1071  void setLegalityCallback(OperationName name,
1072  const DynamicLegalityCallbackFn &callback);
1073 
1074  /// Set the dynamic legality callback for the given dialects.
1075  void setLegalityCallback(ArrayRef<StringRef> dialects,
1076  const DynamicLegalityCallbackFn &callback);
1077 
1078  /// Set the dynamic legality callback for the unknown ops.
1079  void setLegalityCallback(const DynamicLegalityCallbackFn &callback);
1080 
1081  /// The set of information that configures the legalization of an operation.
1082  struct LegalizationInfo {
1083  /// The legality action this operation was given.
1085 
1086  /// If some legal instances of this operation may also be recursively legal.
1087  bool isRecursivelyLegal = false;
1088 
1089  /// The legality callback if this operation is dynamically legal.
1090  DynamicLegalityCallbackFn legalityFn;
1091  };
1092 
1093  /// Get the legalization information for the given operation.
1094  std::optional<LegalizationInfo> getOpInfo(OperationName op) const;
1095 
1096  /// A deterministic mapping of operation name and its respective legality
1097  /// information.
1098  llvm::MapVector<OperationName, LegalizationInfo> legalOperations;
1099 
1100  /// A set of legality callbacks for given operation names that are used to
1101  /// check if an operation instance is recursively legal.
1102  DenseMap<OperationName, DynamicLegalityCallbackFn> opRecursiveLegalityFns;
1103 
1104  /// A mapping (not iteration-order deterministic) of dialect name to the specific legality action to
1105  /// take.
1106  llvm::StringMap<LegalizationAction> legalDialects;
1107 
1108  /// A set of dynamic legality callbacks for given dialect names.
1109  llvm::StringMap<DynamicLegalityCallbackFn> dialectLegalityFns;
1110 
1111  /// An optional legality callback for unknown operations.
1112  DynamicLegalityCallbackFn unknownLegalityFn;
1113 
1114  /// The current context this target applies to.
1115  MLIRContext &ctx;
1116 };
1117 
1118 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
1119 //===----------------------------------------------------------------------===//
1120 // PDL Configuration
1121 //===----------------------------------------------------------------------===//
1122 
1123 /// A PDL configuration that is used to support dialect conversion
1124 /// functionality.
1126  : public PDLPatternConfigBase<PDLConversionConfig> {
1127 public:
1128  PDLConversionConfig(const TypeConverter *converter) : converter(converter) {}
1129  ~PDLConversionConfig() final = default;
1130 
1131  /// Return the type converter used by this configuration, which may be nullptr
1132  /// if no type conversions are expected.
1133  const TypeConverter *getTypeConverter() const { return converter; }
1134 
1135  /// Hooks that are invoked at the beginning and end of a rewrite of a matched
1136  /// pattern.
1137  void notifyRewriteBegin(PatternRewriter &rewriter) final;
1138  void notifyRewriteEnd(PatternRewriter &rewriter) final;
1139 
1140 private:
1141  /// An optional type converter to use for the pattern.
1142  const TypeConverter *converter;
1143 };
1144 
1145 /// Register the dialect conversion PDL functions with the given pattern set.
1146 void registerConversionPDLFunctions(RewritePatternSet &patterns);
1147 
1148 #else
1149 
1150 // Stubs for when PDL in rewriting is not enabled: registerConversionPDLFunctions does nothing and this PDLConversionConfig only accepts (and ignores) a type converter.
1151 
1152 inline void registerConversionPDLFunctions(RewritePatternSet &patterns) {}
1153 
1154 class PDLConversionConfig final {
1155 public:
1156  PDLConversionConfig(const TypeConverter * /*converter*/) {}
1157 };
1158 
1159 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
1160 
1161 //===----------------------------------------------------------------------===//
1162 // ConversionConfig
1163 //===----------------------------------------------------------------------===//
1164 
1165 /// Dialect conversion configuration.
1167  /// An optional callback used to notify about match failure diagnostics during
1168  /// the conversion. Diagnostics reported to this callback may only be
1169  /// available in debug mode.
1171 
1172  /// Partial conversion only. All operations that are found not to be
1173  /// legalizable are placed in this set. (Note that if there is an op
1174  /// explicitly marked as illegal, the conversion terminates and the set will
1175  /// not necessarily be complete.)
1177 
1178  /// Analysis conversion only. All operations that are found to be legalizable
1179  /// are placed in this set. Note that no actual rewrites are applied to the
1180  /// IR during an analysis conversion and only pre-existing operations are
1181  /// added to the set.
1183 
1184  /// An optional listener that is notified about all IR modifications in case
1185  /// dialect conversion succeeds. If the dialect conversion fails and no IR
1186  /// modifications are visible (i.e., they were all rolled back), or if the
1187  /// dialect conversion is an "analysis conversion", no notifications are
1188  /// sent (apart from `notifyPatternBegin`/`notifyPatternEnd`).
1189  ///
1190  /// Note: Notifications are sent in a delayed fashion, when the dialect
1191  /// conversion is guaranteed to succeed. At that point, some IR modifications
1192  /// may already have been materialized. Consequently, operations/blocks that
1193  /// are passed to listener callbacks should not be accessed. (Ops/blocks are
1194  /// guaranteed to be valid pointers and accessing op names is allowed. But
1195  /// there are no guarantees about the state of ops/blocks at the time that a
1196  /// callback is triggered.)
1197  ///
1198  /// Example: Consider a dialect conversion in which a new op ("test.foo") is created
1199  /// and inserted, and later moved to another block. (Moving ops also triggers
1200  /// "notifyOperationInserted".)
1201  ///
1202  /// (1) notifyOperationInserted: "test.foo" (into block "b1")
1203  /// (2) notifyOperationInserted: "test.foo" (moved to another block "b2")
1204  ///
1205  /// When querying "op->getBlock()" during the first "notifyOperationInserted",
1206  /// "b2" would be returned because "moving an op" is a kind of rewrite that is
1207  /// immediately performed by the dialect conversion (and rolled back upon
1208  /// failure).
1209  //
1210  // Note: When receiving a "notifyBlockInserted"/"notifyOperationInserted"
1211  // callback, the previous region/block is provided to the callback, but not
1212  // the iterator pointing to the exact location within the region/block. That
1213  // is because these notifications are sent with a delay (after the IR has
1214  // already been modified) and iterators into past IR state cannot be
1215  // represented at the moment.
1217 
1218  /// If set to "true", the dialect conversion attempts to build source/target/
1219  /// argument materializations through the type converter API in lieu of
1220  /// "builtin.unrealized_conversion_cast ops". The conversion process fails if
1221  /// at least one materialization could not be built.
1222  ///
1223  /// If set to "false", the dialect conversion does not build any custom
1224  /// materializations and instead inserts "builtin.unrealized_conversion_cast"
1225  /// ops to ensure that the resulting IR is valid.
1227 };
1228 
1229 //===----------------------------------------------------------------------===//
1230 // Reconcile Unrealized Casts
1231 //===----------------------------------------------------------------------===//
1232 
1233 /// Try to reconcile all given UnrealizedConversionCastOps and store the
1234 /// left-over ops in `remainingCastOps` (if provided).
1235 ///
1236 /// This function processes cast ops in a worklist-driven fashion. For each
1237 /// cast op, if the chain of input casts eventually reaches a cast op where the
1238 /// input types match the output types of the matched op, replace the matched
1239 /// op with the inputs.
1240 ///
1241 /// Example:
1242 /// %1 = unrealized_conversion_cast %0 : !A to !B
1243 /// %2 = unrealized_conversion_cast %1 : !B to !C
1244 /// %3 = unrealized_conversion_cast %2 : !C to !A
1245 ///
1246 /// In the above example, %0 can be used instead of %3 and all cast ops are
1247 /// folded away.
1250  SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps = nullptr);
1251 
1252 //===----------------------------------------------------------------------===//
1253 // Op Conversion Entry Points
1254 //===----------------------------------------------------------------------===//
1255 
1256 /// Below we define several entry points for operation conversion. It is
1257 /// important to note that the patterns provided to the conversion framework may
1258 /// have additional constraints. See the `PatternRewriter Hooks` section of the
1259 /// ConversionPatternRewriter, to see what additional constraints are imposed on
1260 /// the use of the PatternRewriter.
1261 
1262 /// Apply a partial conversion on the given operations and all nested
1263 /// operations. This method converts as many operations to the target as
1264 /// possible, ignoring operations that failed to legalize. This method only
1265 /// returns failure if there are ops explicitly marked as illegal.
1266 LogicalResult
1268  const ConversionTarget &target,
1269  const FrozenRewritePatternSet &patterns,
1270  ConversionConfig config = ConversionConfig());
1271 LogicalResult
1273  const FrozenRewritePatternSet &patterns,
1274  ConversionConfig config = ConversionConfig());
1275 
1276 /// Apply a complete conversion on the given operations, and all nested
1277 /// operations. This method returns failure if the conversion of any operation
1278 /// fails, or if there are unreachable blocks in any of the regions nested
1279 /// within 'ops'.
1280 LogicalResult applyFullConversion(ArrayRef<Operation *> ops,
1281  const ConversionTarget &target,
1282  const FrozenRewritePatternSet &patterns,
1283  ConversionConfig config = ConversionConfig());
1284 LogicalResult applyFullConversion(Operation *op, const ConversionTarget &target,
1285  const FrozenRewritePatternSet &patterns,
1286  ConversionConfig config = ConversionConfig());
1287 
1288 /// Apply an analysis conversion on the given operations, and all nested
1289 /// operations. This method analyzes which operations would be successfully
1290 /// converted to the target if a conversion was applied. All operations that
1291 /// were found to be legalizable to the given 'target' are placed within the
1292 /// provided 'config.legalizableOps' set; note that no actual rewrites are
1293 /// applied to the operations on success. This method only returns failure if
1294 /// there are unreachable blocks in any of the regions nested within 'ops'.
1295 LogicalResult
1297  const FrozenRewritePatternSet &patterns,
1298  ConversionConfig config = ConversionConfig());
1299 LogicalResult
1301  const FrozenRewritePatternSet &patterns,
1302  ConversionConfig config = ConversionConfig());
1303 } // namespace mlir
1304 
1305 #endif // MLIR_TRANSFORMS_DIALECTCONVERSION_H_
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
Block represents an ordered list of Operations.
Definition: Block.h:33
OpListType::iterator iterator
Definition: Block.h:140
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
LogicalResult getRemappedValues(ValueRange keys, SmallVectorImpl< Value > &results)
Return the converted values that replace 'keys' with types defined by the type converter of the curre...
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt) override
PatternRewriter hook for inlining the ops of a block into another block.
Block * applySignatureConversion(Block *block, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter=nullptr)
Apply a signature conversion to given block.
void startOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void replaceOpWithMultiple(Operation *op, ArrayRef< ValueRange > newValues)
Replace the given operation with the new value ranges.
detail::ConversionPatternRewriterImpl & getImpl()
Return a reference to the internal implementation.
void eraseBlock(Block *block) override
PatternRewriter hook for erasing all operations in a block.
void cancelOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
bool canRecoverFromRewriteFailure() const override
Indicate that the conversion rewriter can recover from rewrite failure.
Value getRemappedValue(Value key)
Return the converted value of 'key' with a type defined by the type converter of the currently execut...
void finalizeOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void replaceUsesOfBlockArgument(BlockArgument from, Value to)
Replace all the uses of the block argument `from` with value `to`.
Base class for the conversion patterns.
const TypeConverter * typeConverter
An optional type converter for use by this pattern.
ConversionPattern(const TypeConverter &typeConverter, Args &&...args)
Construct a conversion pattern with the given converter, and forward the remaining arguments to Rewri...
virtual void rewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement rewriting.
const TypeConverter * getTypeConverter() const
Return the type converter held by this pattern, or nullptr if the pattern does not require type conve...
std::enable_if_t< std::is_base_of< TypeConverter, ConverterTy >::value, const ConverterTy * > getTypeConverter() const
virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement combined matching and rewriting.
This class describes a specific conversion target.
void setDialectAction(ArrayRef< StringRef > dialectNames, LegalizationAction action)
Register a legality action for the given dialects.
void setOpAction(LegalizationAction action)
void addLegalOp(OperationName op)
Register the given operations as legal.
void addDynamicallyLegalDialect(DynamicLegalityCallbackFn callback)
void setOpAction(OperationName op, LegalizationAction action)
Register a legality action for the given operation.
void addDynamicallyLegalDialect(const DynamicLegalityCallbackFn &callback, StringRef name, Names... names)
Register the operations of the given dialects as dynamically legal, i.e.
void addDynamicallyLegalOp(const DynamicLegalityCallbackFn &callback)
void addLegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as legal.
std::optional< LegalOpDetails > isLegal(Operation *op) const
If the given operation instance is legal on this target, a structure containing legality information ...
std::enable_if_t<!std::is_invocable_v< Callable, Operation * > > markOpRecursivelyLegal(Callable &&callback)
void markUnknownOpDynamicallyLegal(const DynamicLegalityCallbackFn &fn)
Register unknown operations as dynamically legal.
std::optional< LegalizationAction > getOpAction(OperationName op) const
Get the legality action for the given operation.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the on...
void addIllegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as illegal, i.e.
std::enable_if_t<!std::is_invocable_v< Callable, Operation * > > addDynamicallyLegalOp(Callable &&callback)
void addDynamicallyLegalOp(const DynamicLegalityCallbackFn &callback)
void markOpRecursivelyLegal(const DynamicLegalityCallbackFn &callback={})
void addIllegalOp(OperationName op)
Register the given operation as illegal, i.e.
LegalizationAction
This enumeration corresponds to the specific action to take when considering an operation legal for t...
@ Illegal
The target explicitly does not support this operation.
@ Dynamic
This operation has dynamic legalization constraints that must be checked by the target.
@ Legal
The target supports this operation.
ConversionTarget(MLIRContext &ctx)
void markOpRecursivelyLegal(OperationName name, const DynamicLegalityCallbackFn &callback)
Mark an operation, that must have either been set as Legal or DynamicallyLegal, as being recursively ...
void markOpRecursivelyLegal(const DynamicLegalityCallbackFn &callback={})
std::function< std::optional< bool >(Operation *)> DynamicLegalityCallbackFn
The signature of the callback used to determine if an operation is dynamically legal on the target.
bool isIllegal(Operation *op) const
Returns true if the operation instance is illegal on this target.
virtual ~ConversionTarget()=default
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:215
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
Definition: Builders.h:324
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
void rewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Hook for derived classes to implement rewriting.
LogicalResult match(Operation *op) const final
Wrappers around the ConversionPattern methods that pass the derived op type.
typename SourceOp::Adaptor OpAdaptor
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
virtual LogicalResult match(SourceOp op) const
Rewrite and Match methods that operate on the SourceOp type.
OpConversionPattern(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit=1)
virtual LogicalResult matchAndRewrite(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const
virtual void rewrite(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const
LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Hook for derived classes to implement combined matching and rewriting.
OpInterfaceConversionPattern is a wrapper around ConversionPattern that allows for matching and rewri...
void rewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Wrappers around the ConversionPattern methods that pass the derived op type.
LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Hook for derived classes to implement combined matching and rewriting.
OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
OpInterfaceConversionPattern(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit=1)
virtual void rewrite(SourceOp op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Rewrite and Match methods that operate on the SourceOp type.
virtual LogicalResult matchAndRewrite(SourceOp op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
OpTraitConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting...
OpTraitConversionPattern(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit=1)
OpTraitConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
A PDL configuration that is used to support dialect conversion functionality.
void notifyRewriteEnd(PatternRewriter &rewriter) final
void notifyRewriteBegin(PatternRewriter &rewriter) final
Hooks that are invoked at the beginning and end of a rewrite of a matched pattern.
const TypeConverter * getTypeConverter() const
Return the type converter used by this configuration, which may be nullptr if no type conversions are...
PDLConversionConfig(const TypeConverter *converter)
~PDLConversionConfig() final=default
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:73
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:246
virtual LogicalResult match(Operation *op) const
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
virtual void rewrite(Operation *op, PatternRewriter &rewriter) const
Rewrite the IR rooted at the specified operation with the result of this pattern, generating any new ...
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
The general result of a type attribute conversion callback, allowing for early termination.
static AttributeConversionResult result(Attribute attr)
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
std::optional< InputMapping > getInputMapping(unsigned input) const
Get the input mapping for the given argument.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
void remapInput(unsigned origInputNo, Value replacement)
Remap an input of the original signature to another replacement value.
Type conversion class.
std::optional< Attribute > convertTypeAttribute(Type type, Attribute attr) const
Convert an attribute `attr` present within the type `type` using the registered conversion functions...
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
void addConversion(FnT &&callback)
Register a conversion function.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
void addArgumentMaterialization(FnT &&callback)
All of the following materializations require function objects that are convertible to the following ...
LogicalResult convertSignatureArgs(TypeRange types, SignatureConversion &result, unsigned origInputOffset=0) const
TypeConverter()=default
std::enable_if_t<!std::is_convertible< RangeT, Type >::value &&!std::is_convertible< RangeT, Operation * >::value, bool > isLegal(RangeT &&range) const
Return true if all of the given types are legal for this type converter.
TargetType convertType(Type t) const
Attempts a 1-1 type conversion, expecting the result type to be TargetType.
LogicalResult convertSignatureArg(unsigned inputNo, Type type, SignatureConversion &result) const
This method allows for converting a specific argument of a signature.
TypeConverter(const TypeConverter &other)
TypeConverter & operator=(const TypeConverter &other)
Value materializeArgumentConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
Materialize a conversion from a set of types into one result type by generating a cast sequence of so...
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
std::optional< SignatureConversion > convertBlockSignature(Block *block) const
This function converts the type signature of the given block, by invoking 'convertSignatureArg' for e...
void addSourceMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a legal replacement value...
virtual ~TypeConverter()=default
void addTypeAttributeConversion(FnT &&callback)
Register a conversion function for attributes within types.
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting an illegal (source) value...
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
bool isSignatureLegal(FunctionType ty) const
Return true if the inputs and outputs of the given function type are legal.
Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs, Type originalType={}) const
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:104
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
@ Type
An inlay hint for a type annotation.
Include the generated interface declarations.
void populateFunctionOpInterfaceTypeConversionPattern(StringRef functionLikeOpName, RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to convert the signature of a FunctionOpInterface op with the...
void populateAnyFunctionOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
FailureOr< Operation * > convertOpResultTypes(Operation *op, ValueRange operands, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Generic utility to convert op result types according to type converter without knowing exact op type.
LogicalResult applyAnalysisConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply an analysis conversion on the given operations, and all nested operations.
void reconcileUnrealizedCasts(ArrayRef< UnrealizedConversionCastOp > castOps, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps=nullptr)
Try to reconcile all given UnrealizedConversionCastOps and store the left-over ops in remainingCastOp...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void registerConversionPDLFunctions(RewritePatternSet &patterns)
Register the dialect conversion PDL functions with the given pattern set.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
Dialect conversion configuration.
RewriterBase::Listener * listener
An optional listener that is notified about all IR modifications in case dialect conversion succeeds.
function_ref< void(Diagnostic &)> notifyCallback
An optional callback used to notify about match failure diagnostics during the conversion.
DenseSet< Operation * > * legalizableOps
Analysis conversion only.
DenseSet< Operation * > * unlegalizedOps
Partial conversion only.
bool buildMaterializations
If set to "true", the dialect conversion attempts to build source/target/argument materializations t...
A structure containing additional information describing a specific legal operation instance.
bool isRecursivelyLegal
A flag that indicates if this operation is 'recursively' legal.
This class acts as a special tag that makes the desire to match any operation that implements a given...
Definition: PatternMatch.h:164
This class acts as a special tag that makes the desire to match any operation that implements a given...
Definition: PatternMatch.h:169
This struct represents a range of new types or a single value that remaps an existing signature input...