MLIR  20.0.0git
DialectConversion.h
Go to the documentation of this file.
1 //===- DialectConversion.h - MLIR dialect conversion pass -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file declares a generic pass for converting between MLIR dialects.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_TRANSFORMS_DIALECTCONVERSION_H_
14 #define MLIR_TRANSFORMS_DIALECTCONVERSION_H_
15 
16 #include "mlir/Config/mlir-config.h"
18 #include "llvm/ADT/MapVector.h"
19 #include "llvm/ADT/StringMap.h"
20 #include <type_traits>
21 
22 namespace mlir {
23 
24 // Forward declarations.
25 class Attribute;
26 class Block;
27 struct ConversionConfig;
28 class ConversionPatternRewriter;
29 class MLIRContext;
30 class Operation;
31 struct OperationConverter;
32 class Type;
33 class Value;
34 
35 //===----------------------------------------------------------------------===//
36 // Type Conversion
37 //===----------------------------------------------------------------------===//
38 
39 /// Type conversion class. Specific conversions and materializations can be
40 /// registered using addConversion and addMaterialization, respectively.
42 public:
43  virtual ~TypeConverter() = default;
44  TypeConverter() = default;
45  // Copy the registered conversions, but not the caches
47  : conversions(other.conversions),
48  argumentMaterializations(other.argumentMaterializations),
49  sourceMaterializations(other.sourceMaterializations),
50  targetMaterializations(other.targetMaterializations),
51  typeAttributeConversions(other.typeAttributeConversions) {}
53  conversions = other.conversions;
54  argumentMaterializations = other.argumentMaterializations;
55  sourceMaterializations = other.sourceMaterializations;
56  targetMaterializations = other.targetMaterializations;
57  typeAttributeConversions = other.typeAttributeConversions;
58  return *this;
59  }
60 
61  /// This class provides all of the information necessary to convert a type
62  /// signature.
64  public:
65  SignatureConversion(unsigned numOrigInputs)
66  : remappedInputs(numOrigInputs) {}
67 
68  /// This struct represents a range of new types or a single value that
69  /// remaps an existing signature input.
70  struct InputMapping {
71  size_t inputNo, size;
73  };
74 
75  /// Return the argument types for the new signature.
76  ArrayRef<Type> getConvertedTypes() const { return argTypes; }
77 
78  /// Get the input mapping for the given argument.
79  std::optional<InputMapping> getInputMapping(unsigned input) const {
80  return remappedInputs[input];
81  }
82 
83  //===------------------------------------------------------------------===//
84  // Conversion Hooks
85  //===------------------------------------------------------------------===//
86 
87  /// Remap an input of the original signature with a new set of types. The
88  /// new types are appended to the new signature conversion.
89  void addInputs(unsigned origInputNo, ArrayRef<Type> types);
90 
91  /// Append new input types to the signature conversion, this should only be
92  /// used if the new types are not intended to remap an existing input.
93  void addInputs(ArrayRef<Type> types);
94 
95  /// Remap an input of the original signature to another `replacement`
96  /// value. This drops the original argument.
97  void remapInput(unsigned origInputNo, Value replacement);
98 
99  private:
100  /// Remap an input of the original signature with a range of types in the
101  /// new signature.
102  void remapInput(unsigned origInputNo, unsigned newInputNo,
103  unsigned newInputCount = 1);
104 
105  /// The remapping information for each of the original arguments.
106  SmallVector<std::optional<InputMapping>, 4> remappedInputs;
107 
108  /// The set of new argument types.
109  SmallVector<Type, 4> argTypes;
110  };
111 
112  /// The general result of a type attribute conversion callback, allowing
113  /// for early termination. The default constructor creates the na case.
115  public:
116  constexpr AttributeConversionResult() : impl() {}
117  AttributeConversionResult(Attribute attr) : impl(attr, resultTag) {}
118 
120  static AttributeConversionResult na();
122 
123  bool hasResult() const;
124  bool isNa() const;
125  bool isAbort() const;
126 
127  Attribute getResult() const;
128 
129  private:
130  AttributeConversionResult(Attribute attr, unsigned tag) : impl(attr, tag) {}
131 
132  llvm::PointerIntPair<Attribute, 2> impl;
133  // Note that na is 0 so that we can use PointerIntPair's default
134  // constructor.
135  static constexpr unsigned naTag = 0;
136  static constexpr unsigned resultTag = 1;
137  static constexpr unsigned abortTag = 2;
138  };
139 
140  /// Register a conversion function. A conversion function must be convertible
141  /// to any of the following forms (where `T` is a class derived from `Type`):
142  ///
143  /// * std::optional<Type>(T)
144  /// - This form represents a 1-1 type conversion. It should return nullptr
145  /// or `std::nullopt` to signify failure. If `std::nullopt` is returned,
146  /// the converter is allowed to try another conversion function to
147  /// perform the conversion.
148  /// * std::optional<LogicalResult>(T, SmallVectorImpl<Type> &)
149  /// - This form represents a 1-N type conversion. It should return
150  /// `failure` or `std::nullopt` to signify a failed conversion. If the
151  /// new set of types is empty, the type is removed and any usages of the
152  /// existing value are expected to be removed during conversion. If
153  /// `std::nullopt` is returned, the converter is allowed to try another
154  /// conversion function to perform the conversion.
155  ///
156  /// Note: When attempting to convert a type, e.g. via 'convertType', the
157  /// mostly recently added conversions will be invoked first.
158  template <typename FnT, typename T = typename llvm::function_traits<
159  std::decay_t<FnT>>::template arg_t<0>>
160  void addConversion(FnT &&callback) {
161  registerConversion(wrapCallback<T>(std::forward<FnT>(callback)));
162  }
163 
164  /// All of the following materializations require function objects that are
165  /// convertible to the following form:
166  /// `Value(OpBuilder &, T, ValueRange, Location)`,
167  /// where `T` is any subclass of `Type`. This function is responsible for
168  /// creating an operation, using the OpBuilder and Location provided, that
169  /// "casts" a range of values into a single value of the given type `T`. It
170  /// must return a Value of the type `T` on success and `nullptr` if
171  /// it failed but other materialization should be attempted. Materialization
172  /// functions must be provided when a type conversion may persist after the
173  /// conversion has finished.
174  ///
175  /// Note: Target materializations may optionally accept an additional Type
176  /// parameter, which is the original type of the SSA value. Furthermore, `T`
177  /// can be a TypeRange; in that case, the function must return a
178  /// SmallVector<Value>.
179 
180  /// This method registers a materialization that will be called when
181  /// converting (potentially multiple) block arguments that were the result of
182  /// a signature conversion of a single block argument, to a single SSA value
183  /// with the old block argument type.
184  ///
185  /// Note: Argument materializations are used only with the 1:N dialect
186  /// conversion driver. The 1:N dialect conversion driver will be removed soon
187  /// and so will be argument materializations.
188  template <typename FnT, typename T = typename llvm::function_traits<
189  std::decay_t<FnT>>::template arg_t<1>>
190  void addArgumentMaterialization(FnT &&callback) {
191  argumentMaterializations.emplace_back(
192  wrapMaterialization<T>(std::forward<FnT>(callback)));
193  }
194 
195  /// This method registers a materialization that will be called when
196  /// converting a replacement value back to its original source type.
197  /// This is used when some uses of the original value persist beyond the main
198  /// conversion.
199  template <typename FnT, typename T = typename llvm::function_traits<
200  std::decay_t<FnT>>::template arg_t<1>>
201  void addSourceMaterialization(FnT &&callback) {
202  sourceMaterializations.emplace_back(
203  wrapMaterialization<T>(std::forward<FnT>(callback)));
204  }
205 
206  /// This method registers a materialization that will be called when
207  /// converting a value to a target type according to a pattern's type
208  /// converter.
209  ///
210  /// Note: Target materializations can optionally inspect the "original"
211  /// type. This type may be different from the type of the input value.
212  /// For example, let's assume that a conversion pattern "P1" replaced an SSA
213  /// value "v1" (type "t1") with "v2" (type "t2"). Then a different conversion
214  /// pattern "P2" matches an op that has "v1" as an operand. Let's furthermore
215  /// assume that "P2" determines that the converted target type of "t1" is
216  /// "t3", which may be different from "t2". In this example, the target
217  /// materialization will be invoked with: outputType = "t3", inputs = "v2",
218  /// originalType = "t1". Note that the original type "t1" cannot be recovered
219  /// from just "t3" and "v2"; that's why the originalType parameter exists.
220  ///
221  /// Note: During a 1:N conversion, the result types can be a TypeRange. In
222  /// that case the materialization produces a SmallVector<Value>.
223  template <typename FnT, typename T = typename llvm::function_traits<
224  std::decay_t<FnT>>::template arg_t<1>>
225  void addTargetMaterialization(FnT &&callback) {
226  targetMaterializations.emplace_back(
227  wrapTargetMaterialization<T>(std::forward<FnT>(callback)));
228  }
229 
230  /// Register a conversion function for attributes within types. Type
231  /// converters may call this function in order to allow hoking into the
232  /// translation of attributes that exist within types. For example, a type
233  /// converter for the `memref` type could use these conversions to convert
234  /// memory spaces or layouts in an extensible way.
235  ///
236  /// The conversion functions take a non-null Type or subclass of Type and a
237  /// non-null Attribute (or subclass of Attribute), and returns a
238  /// `AttributeConversionResult`. This result can either contan an `Attribute`,
239  /// which may be `nullptr`, representing the conversion's success,
240  /// `AttributeConversionResult::na()` (the default empty value), indicating
241  /// that the conversion function did not apply and that further conversion
242  /// functions should be checked, or `AttributeConversionResult::abort()`
243  /// indicating that the conversion process should be aborted.
244  ///
245  /// Registered conversion functions are callled in the reverse of the order in
246  /// which they were registered.
247  template <
248  typename FnT,
249  typename T =
250  typename llvm::function_traits<std::decay_t<FnT>>::template arg_t<0>,
251  typename A =
252  typename llvm::function_traits<std::decay_t<FnT>>::template arg_t<1>>
253  void addTypeAttributeConversion(FnT &&callback) {
254  registerTypeAttributeConversion(
255  wrapTypeAttributeConversion<T, A>(std::forward<FnT>(callback)));
256  }
257 
258  /// Convert the given type. This function should return failure if no valid
259  /// conversion exists, success otherwise. If the new set of types is empty,
260  /// the type is removed and any usages of the existing value are expected to
261  /// be removed during conversion.
262  LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) const;
263 
264  /// This hook simplifies defining 1-1 type conversions. This function returns
265  /// the type to convert to on success, and a null type on failure.
266  Type convertType(Type t) const;
267 
268  /// Attempts a 1-1 type conversion, expecting the result type to be
269  /// `TargetType`. Returns the converted type cast to `TargetType` on success,
270  /// and a null type on conversion or cast failure.
271  template <typename TargetType>
272  TargetType convertType(Type t) const {
273  return dyn_cast_or_null<TargetType>(convertType(t));
274  }
275 
276  /// Convert the given set of types, filling 'results' as necessary. This
277  /// returns failure if the conversion of any of the types fails, success
278  /// otherwise.
279  LogicalResult convertTypes(TypeRange types,
280  SmallVectorImpl<Type> &results) const;
281 
282  /// Return true if the given type is legal for this type converter, i.e. the
283  /// type converts to itself.
284  bool isLegal(Type type) const;
285 
286  /// Return true if all of the given types are legal for this type converter.
287  template <typename RangeT>
288  std::enable_if_t<!std::is_convertible<RangeT, Type>::value &&
289  !std::is_convertible<RangeT, Operation *>::value,
290  bool>
291  isLegal(RangeT &&range) const {
292  return llvm::all_of(range, [this](Type type) { return isLegal(type); });
293  }
294  /// Return true if the given operation has legal operand and result types.
295  bool isLegal(Operation *op) const;
296 
297  /// Return true if the types of block arguments within the region are legal.
298  bool isLegal(Region *region) const;
299 
300  /// Return true if the inputs and outputs of the given function type are
301  /// legal.
302  bool isSignatureLegal(FunctionType ty) const;
303 
304  /// This method allows for converting a specific argument of a signature. It
305  /// takes as inputs the original argument input number, type.
306  /// On success, it populates 'result' with any new mappings.
307  LogicalResult convertSignatureArg(unsigned inputNo, Type type,
308  SignatureConversion &result) const;
309  LogicalResult convertSignatureArgs(TypeRange types,
310  SignatureConversion &result,
311  unsigned origInputOffset = 0) const;
312 
313  /// This function converts the type signature of the given block, by invoking
314  /// 'convertSignatureArg' for each argument. This function should return a
315  /// valid conversion for the signature on success, std::nullopt otherwise.
316  std::optional<SignatureConversion> convertBlockSignature(Block *block) const;
317 
318  /// Materialize a conversion from a set of types into one result type by
319  /// generating a cast sequence of some kind. See the respective
320  /// `add*Materialization` for more information on the context for these
321  /// methods.
323  Type resultType, ValueRange inputs) const;
325  Type resultType, ValueRange inputs) const;
327  Type resultType, ValueRange inputs,
328  Type originalType = {}) const;
329  SmallVector<Value> materializeTargetConversion(OpBuilder &builder,
330  Location loc,
331  TypeRange resultType,
332  ValueRange inputs,
333  Type originalType = {}) const;
334 
335  /// Convert an attribute present `attr` from within the type `type` using
336  /// the registered conversion functions. If no applicable conversion has been
337  /// registered, return std::nullopt. Note that the empty attribute/`nullptr`
338  /// is a valid return value for this function.
339  std::optional<Attribute> convertTypeAttribute(Type type,
340  Attribute attr) const;
341 
342 private:
343  /// The signature of the callback used to convert a type. If the new set of
344  /// types is empty, the type is removed and any usages of the existing value
345  /// are expected to be removed during conversion.
346  using ConversionCallbackFn = std::function<std::optional<LogicalResult>(
347  Type, SmallVectorImpl<Type> &)>;
348 
349  /// The signature of the callback used to materialize a source/argument
350  /// conversion.
351  ///
352  /// Arguments: builder, result type, inputs, location
353  using MaterializationCallbackFn =
354  std::function<Value(OpBuilder &, Type, ValueRange, Location)>;
355 
356  /// The signature of the callback used to materialize a target conversion.
357  ///
358  /// Arguments: builder, result types, inputs, location, original type
359  using TargetMaterializationCallbackFn = std::function<SmallVector<Value>(
360  OpBuilder &, TypeRange, ValueRange, Location, Type)>;
361 
362  /// The signature of the callback used to convert a type attribute.
363  using TypeAttributeConversionCallbackFn =
364  std::function<AttributeConversionResult(Type, Attribute)>;
365 
366  /// Generate a wrapper for the given callback. This allows for accepting
367  /// different callback forms, that all compose into a single version.
368  /// With callback of form: `std::optional<Type>(T)`
369  template <typename T, typename FnT>
370  std::enable_if_t<std::is_invocable_v<FnT, T>, ConversionCallbackFn>
371  wrapCallback(FnT &&callback) const {
372  return wrapCallback<T>([callback = std::forward<FnT>(callback)](
373  T type, SmallVectorImpl<Type> &results) {
374  if (std::optional<Type> resultOpt = callback(type)) {
375  bool wasSuccess = static_cast<bool>(*resultOpt);
376  if (wasSuccess)
377  results.push_back(*resultOpt);
378  return std::optional<LogicalResult>(success(wasSuccess));
379  }
380  return std::optional<LogicalResult>();
381  });
382  }
383  /// With callback of form: `std::optional<LogicalResult>(
384  /// T, SmallVectorImpl<Type> &, ArrayRef<Type>)`.
385  template <typename T, typename FnT>
386  std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &>,
387  ConversionCallbackFn>
388  wrapCallback(FnT &&callback) const {
389  return [callback = std::forward<FnT>(callback)](
390  Type type,
391  SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
392  T derivedType = dyn_cast<T>(type);
393  if (!derivedType)
394  return std::nullopt;
395  return callback(derivedType, results);
396  };
397  }
398 
399  /// Register a type conversion.
400  void registerConversion(ConversionCallbackFn callback) {
401  conversions.emplace_back(std::move(callback));
402  cachedDirectConversions.clear();
403  cachedMultiConversions.clear();
404  }
405 
406  /// Generate a wrapper for the given argument/source materialization
407  /// callback. The callback may take any subclass of `Type` and the
408  /// wrapper will check for the target type to be of the expected class
409  /// before calling the callback.
410  template <typename T, typename FnT>
411  MaterializationCallbackFn wrapMaterialization(FnT &&callback) const {
412  return [callback = std::forward<FnT>(callback)](
413  OpBuilder &builder, Type resultType, ValueRange inputs,
414  Location loc) -> Value {
415  if (T derivedType = dyn_cast<T>(resultType))
416  return callback(builder, derivedType, inputs, loc);
417  return Value();
418  };
419  }
420 
421  /// Generate a wrapper for the given target materialization callback.
422  /// The callback may take any subclass of `Type` and the wrapper will check
423  /// for the target type to be of the expected class before calling the
424  /// callback.
425  ///
426  /// With callback of form:
427  /// - Value(OpBuilder &, T, ValueRange, Location, Type)
428  /// - SmallVector<Value>(OpBuilder &, TypeRange, ValueRange, Location, Type)
429  template <typename T, typename FnT>
430  std::enable_if_t<
431  std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location, Type>,
432  TargetMaterializationCallbackFn>
433  wrapTargetMaterialization(FnT &&callback) const {
434  return [callback = std::forward<FnT>(callback)](
435  OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
436  Location loc, Type originalType) -> SmallVector<Value> {
437  SmallVector<Value> result;
438  if constexpr (std::is_same<T, TypeRange>::value) {
439  // This is a 1:N target materialization. Return the produces values
440  // directly.
441  result = callback(builder, resultTypes, inputs, loc, originalType);
442  } else if constexpr (std::is_assignable<Type, T>::value) {
443  // This is a 1:1 target materialization. Invoke the callback only if a
444  // single SSA value is requested.
445  if (resultTypes.size() == 1) {
446  // Invoke the callback only if the type class of the callback matches
447  // the requested result type.
448  if (T derivedType = dyn_cast<T>(resultTypes.front())) {
449  // 1:1 materializations produce single values, but we store 1:N
450  // target materialization functions in the type converter. Wrap the
451  // result value in a SmallVector<Value>.
452  Value val =
453  callback(builder, derivedType, inputs, loc, originalType);
454  if (val)
455  result.push_back(val);
456  }
457  }
458  } else {
459  static_assert(sizeof(T) == 0, "T must be a Type or a TypeRange");
460  }
461  return result;
462  };
463  }
464  /// With callback of form:
465  /// - Value(OpBuilder &, T, ValueRange, Location)
466  /// - SmallVector<Value>(OpBuilder &, TypeRange, ValueRange, Location)
467  template <typename T, typename FnT>
468  std::enable_if_t<
469  std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location>,
470  TargetMaterializationCallbackFn>
471  wrapTargetMaterialization(FnT &&callback) const {
472  return wrapTargetMaterialization<T>(
473  [callback = std::forward<FnT>(callback)](
474  OpBuilder &builder, T resultTypes, ValueRange inputs, Location loc,
475  Type originalType) {
476  return callback(builder, resultTypes, inputs, loc);
477  });
478  }
479 
480  /// Generate a wrapper for the given type attribute conversion callback. The
481  /// callback may take any subclass of `Attribute` and the wrapper will check
482  /// for the target attribute to be of the expected class before calling the
483  /// callback.
484  template <typename T, typename A, typename FnT>
485  TypeAttributeConversionCallbackFn
486  wrapTypeAttributeConversion(FnT &&callback) const {
487  return [callback = std::forward<FnT>(callback)](
488  Type type, Attribute attr) -> AttributeConversionResult {
489  if (T derivedType = dyn_cast<T>(type)) {
490  if (A derivedAttr = dyn_cast_or_null<A>(attr))
491  return callback(derivedType, derivedAttr);
492  }
494  };
495  }
496 
497  /// Register a memory space conversion, clearing caches.
498  void
499  registerTypeAttributeConversion(TypeAttributeConversionCallbackFn callback) {
500  typeAttributeConversions.emplace_back(std::move(callback));
501  // Clear type conversions in case a memory space is lingering inside.
502  cachedDirectConversions.clear();
503  cachedMultiConversions.clear();
504  }
505 
506  /// The set of registered conversion functions.
507  SmallVector<ConversionCallbackFn, 4> conversions;
508 
509  /// The list of registered materialization functions.
510  SmallVector<MaterializationCallbackFn, 2> argumentMaterializations;
511  SmallVector<MaterializationCallbackFn, 2> sourceMaterializations;
512  SmallVector<TargetMaterializationCallbackFn, 2> targetMaterializations;
513 
514  /// The list of registered type attribute conversion functions.
515  SmallVector<TypeAttributeConversionCallbackFn, 2> typeAttributeConversions;
516 
517  /// A set of cached conversions to avoid recomputing in the common case.
518  /// Direct 1-1 conversions are the most common, so this cache stores the
519  /// successful 1-1 conversions as well as all failed conversions.
520  mutable DenseMap<Type, Type> cachedDirectConversions;
521  /// This cache stores the successful 1->N conversions, where N != 1.
522  mutable DenseMap<Type, SmallVector<Type, 2>> cachedMultiConversions;
523  /// A mutex used for cache access
524  mutable llvm::sys::SmartRWMutex<true> cacheMutex;
525 };
526 
527 //===----------------------------------------------------------------------===//
528 // Conversion Patterns
529 //===----------------------------------------------------------------------===//
530 
531 /// Base class for the conversion patterns. This pattern class enables type
532 /// conversions, and other uses specific to the conversion framework. As such,
533 /// patterns of this type can only be used with the 'apply*' methods below.
535 public:
536  /// Hook for derived classes to implement rewriting. `op` is the (first)
537  /// operation matched by the pattern, `operands` is a list of the rewritten
538  /// operand values that are passed to `op`, `rewriter` can be used to emit the
539  /// new operations. This function should not fail. If some specific cases of
540  /// the operation are not supported, these cases should not be matched.
541  virtual void rewrite(Operation *op, ArrayRef<Value> operands,
542  ConversionPatternRewriter &rewriter) const {
543  llvm_unreachable("unimplemented rewrite");
544  }
545  virtual void rewrite(Operation *op, ArrayRef<ValueRange> operands,
546  ConversionPatternRewriter &rewriter) const {
547  rewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
548  }
549 
550  /// Hook for derived classes to implement combined matching and rewriting.
551  /// This overload supports only 1:1 replacements. The 1:N overload is called
552  /// by the driver. By default, it calls this 1:1 overload or reports a fatal
553  /// error if 1:N replacements were found.
554  virtual LogicalResult
556  ConversionPatternRewriter &rewriter) const {
557  if (failed(match(op)))
558  return failure();
559  rewrite(op, operands, rewriter);
560  return success();
561  }
562 
563  /// Hook for derived classes to implement combined matching and rewriting.
564  /// This overload supports 1:N replacements.
565  virtual LogicalResult
567  ConversionPatternRewriter &rewriter) const {
568  return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
569  }
570 
571  /// Attempt to match and rewrite the IR root at the specified operation.
572  LogicalResult matchAndRewrite(Operation *op,
573  PatternRewriter &rewriter) const final;
574 
575  /// Return the type converter held by this pattern, or nullptr if the pattern
576  /// does not require type conversion.
577  const TypeConverter *getTypeConverter() const { return typeConverter; }
578 
579  template <typename ConverterTy>
580  std::enable_if_t<std::is_base_of<TypeConverter, ConverterTy>::value,
581  const ConverterTy *>
583  return static_cast<const ConverterTy *>(typeConverter);
584  }
585 
586 protected:
587  /// See `RewritePattern::RewritePattern` for information on the other
588  /// available constructors.
589  using RewritePattern::RewritePattern;
590  /// Construct a conversion pattern with the given converter, and forward the
591  /// remaining arguments to RewritePattern.
592  template <typename... Args>
594  : RewritePattern(std::forward<Args>(args)...),
596 
597  /// Given an array of value ranges, which are the inputs to a 1:N adaptor,
598  /// try to extract the single value of each range to construct the inputs
599  /// for a 1:1 adaptor.
600  ///
601  /// This function produces a fatal error if at least one range has 0 or
602  /// more than 1 value: "pattern 'name' does not support 1:N conversion"
605 
606 protected:
607  /// An optional type converter for use by this pattern.
608  const TypeConverter *typeConverter = nullptr;
609 
610 private:
612 };
613 
614 /// OpConversionPattern is a wrapper around ConversionPattern that allows for
615 /// matching and rewriting against an instance of a derived operation class as
616 /// opposed to a raw Operation.
617 template <typename SourceOp>
619 public:
620  using OpAdaptor = typename SourceOp::Adaptor;
622  typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
623 
625  : ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
627  PatternBenefit benefit = 1)
628  : ConversionPattern(typeConverter, SourceOp::getOperationName(), benefit,
629  context) {}
630 
631  /// Wrappers around the ConversionPattern methods that pass the derived op
632  /// type.
633  LogicalResult match(Operation *op) const final {
634  return match(cast<SourceOp>(op));
635  }
636  void rewrite(Operation *op, ArrayRef<Value> operands,
637  ConversionPatternRewriter &rewriter) const final {
638  auto sourceOp = cast<SourceOp>(op);
639  rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
640  }
642  ConversionPatternRewriter &rewriter) const final {
643  auto sourceOp = cast<SourceOp>(op);
644  rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
645  }
646  LogicalResult
648  ConversionPatternRewriter &rewriter) const final {
649  auto sourceOp = cast<SourceOp>(op);
650  return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
651  }
652  LogicalResult
654  ConversionPatternRewriter &rewriter) const final {
655  auto sourceOp = cast<SourceOp>(op);
656  return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
657  rewriter);
658  }
659 
660  /// Rewrite and Match methods that operate on the SourceOp type. These must be
661  /// overridden by the derived pattern class.
662  virtual LogicalResult match(SourceOp op) const {
663  llvm_unreachable("must override match or matchAndRewrite");
664  }
665  virtual void rewrite(SourceOp op, OpAdaptor adaptor,
666  ConversionPatternRewriter &rewriter) const {
667  llvm_unreachable("must override matchAndRewrite or a rewrite method");
668  }
669  virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
670  ConversionPatternRewriter &rewriter) const {
671  SmallVector<Value> oneToOneOperands =
672  getOneToOneAdaptorOperands(adaptor.getOperands());
673  rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
674  }
675  virtual LogicalResult
676  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
677  ConversionPatternRewriter &rewriter) const {
678  if (failed(match(op)))
679  return failure();
680  rewrite(op, adaptor, rewriter);
681  return success();
682  }
683  virtual LogicalResult
684  matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
685  ConversionPatternRewriter &rewriter) const {
686  SmallVector<Value> oneToOneOperands =
687  getOneToOneAdaptorOperands(adaptor.getOperands());
688  return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
689  }
690 
691 private:
693 };
694 
695 /// OpInterfaceConversionPattern is a wrapper around ConversionPattern that
696 /// allows for matching and rewriting against an instance of an OpInterface
697 /// class as opposed to a raw Operation.
698 template <typename SourceOp>
700 public:
703  SourceOp::getInterfaceID(), benefit, context) {}
705  MLIRContext *context, PatternBenefit benefit = 1)
707  SourceOp::getInterfaceID(), benefit, context) {}
708 
709  /// Wrappers around the ConversionPattern methods that pass the derived op
710  /// type.
711  void rewrite(Operation *op, ArrayRef<Value> operands,
712  ConversionPatternRewriter &rewriter) const final {
713  rewrite(cast<SourceOp>(op), operands, rewriter);
714  }
716  ConversionPatternRewriter &rewriter) const final {
717  rewrite(cast<SourceOp>(op), operands, rewriter);
718  }
719  LogicalResult
721  ConversionPatternRewriter &rewriter) const final {
722  return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
723  }
724  LogicalResult
726  ConversionPatternRewriter &rewriter) const final {
727  return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
728  }
729 
730  /// Rewrite and Match methods that operate on the SourceOp type. A derived
731  /// pattern must override matchAndRewrite or one of the rewrite methods.
  ///
  /// The Operation*-based `final` overrides above cast the op to SourceOp and
  /// dispatch to these hooks. Default behavior, as implemented below:
  ///  - rewrite(SourceOp, ArrayRef<Value>): hits llvm_unreachable, so it must
  ///    be overridden if anything ends up calling it.
  ///  - rewrite(SourceOp, ArrayRef<ValueRange>): forwards to the 1:1 rewrite
  ///    with getOneToOneAdaptorOperands(operands), which tries to extract a
  ///    single value from each operand range.
  ///  - matchAndRewrite(SourceOp, ArrayRef<Value>): returns failure() if
  ///    match(op) fails; otherwise calls the 1:1 rewrite and returns success().
  ///  - matchAndRewrite(SourceOp, ArrayRef<ValueRange>): forwards to the 1:1
  ///    matchAndRewrite with getOneToOneAdaptorOperands(operands).
732  virtual void rewrite(SourceOp op, ArrayRef<Value> operands,
733  ConversionPatternRewriter &rewriter) const {
  // Default trap: a derived pattern overrode neither this nor matchAndRewrite.
734  llvm_unreachable("must override matchAndRewrite or a rewrite method");
735  }
736  virtual void rewrite(SourceOp op, ArrayRef<ValueRange> operands,
737  ConversionPatternRewriter &rewriter) const {
  // 1:N entry point; defaults to the 1:1 hook on the extracted single values.
738  rewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
739  }
740  virtual LogicalResult
741  matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
742  ConversionPatternRewriter &rewriter) const {
  // Split match/rewrite form: only rewrite once the match hook accepts the op.
743  if (failed(match(op)))
744  return failure();
745  rewrite(op, operands, rewriter);
746  return success();
747  }
748  virtual LogicalResult
749  matchAndRewrite(SourceOp op, ArrayRef<ValueRange> operands,
750  ConversionPatternRewriter &rewriter) const {
  // 1:N entry point; defaults to the 1:1 matchAndRewrite.
751  return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
752  }
753 
754 private:
756 };
757 
758 /// OpTraitConversionPattern is a wrapper around ConversionPattern that allows
759 /// for matching and rewriting against instances of an operation that possess a
760 /// given trait.
761 template <template <typename> class TraitType>
763 public:
766  TypeID::get<TraitType>(), benefit, context) {}
768  MLIRContext *context, PatternBenefit benefit = 1)
770  TypeID::get<TraitType>(), benefit, context) {}
771 };
772 
773 /// Generic utility to convert op result types according to type converter
774 /// without knowing exact op type.
775 /// Clones existing op with new result types and returns it.
776 FailureOr<Operation *>
777 convertOpResultTypes(Operation *op, ValueRange operands,
778  const TypeConverter &converter,
779  ConversionPatternRewriter &rewriter);
780 
781 /// Add a pattern to the given pattern list to convert the signature of a
782 /// FunctionOpInterface op with the given type converter. This only supports
783 /// ops which use FunctionType to represent their type.
785  StringRef functionLikeOpName, RewritePatternSet &patterns,
786  const TypeConverter &converter);
787 
788 template <typename FuncOpT>
790  RewritePatternSet &patterns, const TypeConverter &converter) {
791  populateFunctionOpInterfaceTypeConversionPattern(FuncOpT::getOperationName(),
792  patterns, converter);
793 }
794 
796  RewritePatternSet &patterns, const TypeConverter &converter);
797 
798 //===----------------------------------------------------------------------===//
799 // Conversion PatternRewriter
800 //===----------------------------------------------------------------------===//
801 
802 namespace detail {
803 struct ConversionPatternRewriterImpl;
804 } // namespace detail
805 
806 /// This class implements a pattern rewriter for use with ConversionPatterns. It
807 /// extends the base PatternRewriter and provides special conversion specific
808 /// hooks.
810 public:
812 
813  /// Apply a signature conversion to given block. This replaces the block with
814  /// a new block containing the updated signature. The operations of the given
815  /// block are inlined into the newly-created block, which is returned.
816  ///
817  /// If no block argument types are changing, the original block will be
818  /// left in place and returned.
819  ///
820  /// A signature conversion must be provided. (Type converters can construct
821  /// a signature conversion with `convertBlockSignature`.)
822  ///
823  /// Optionally, a type converter can be provided to build materializations.
824  /// Note: If no type converter was provided or the type converter does not
825  /// specify any suitable argument/target materialization rules, the dialect
826  /// conversion may fail to legalize unresolved materializations.
827  Block *
830  const TypeConverter *converter = nullptr);
831 
832  /// Apply a signature conversion to each block in the given region. This
833  /// replaces each block with a new block containing the updated signature. If
834  /// an updated signature would match the current signature, the respective
835  /// block is left in place as is. (See `applySignatureConversion` for
836  /// details.) The new entry block of the region is returned.
837  ///
838  /// SignatureConversions are computed with the specified type converter.
839  /// This function returns "failure" if the type converter failed to compute
840  /// a SignatureConversion for at least one block.
841  ///
842  /// Optionally, a special SignatureConversion can be specified for the entry
843  /// block. This is because the types of the entry block arguments are often
844  /// tied semantically to the operation.
845  FailureOr<Block *> convertRegionTypes(
846  Region *region, const TypeConverter &converter,
847  TypeConverter::SignatureConversion *entryConversion = nullptr);
848 
849  /// Replace all the uses of the block argument `from` with value `to`.
851 
852  /// Return the converted value of 'key' with a type defined by the type
853  /// converter of the currently executing pattern. Return nullptr in the case
854  /// of failure, the remapped value otherwise.
856 
857  /// Return the converted values that replace 'keys' with types defined by the
858  /// type converter of the currently executing pattern. Returns failure if the
859  /// remap failed, success otherwise.
860  LogicalResult getRemappedValues(ValueRange keys,
861  SmallVectorImpl<Value> &results);
862 
863  //===--------------------------------------------------------------------===//
864  // PatternRewriter Hooks
865  //===--------------------------------------------------------------------===//
866 
867  /// Indicate that the conversion rewriter can recover from rewrite failure.
868  /// Recovery is supported via rollback, allowing for continued processing of
869  /// patterns even if a failure is encountered during the rewrite step.
870  bool canRecoverFromRewriteFailure() const override { return true; }
871 
872  /// Replace the given operation with the new values. The number of op results
873  /// and replacement values must match. The types may differ: the dialect
874  /// conversion driver will reconcile any surviving type mismatches at the end
875  /// of the conversion process with source materializations. The given
876  /// operation is erased.
877  void replaceOp(Operation *op, ValueRange newValues) override;
878 
879  /// Replace the given operation with the results of the new op. The number of
880  /// op results must match. The types may differ: the dialect conversion
881  /// driver will reconcile any surviving type mismatches at the end of the
882  /// conversion process with source materializations. The original operation
883  /// is erased.
884  void replaceOp(Operation *op, Operation *newOp) override;
885 
886  /// Replace the given operation with the new value ranges. The number of op
887  /// results and value ranges must match. The given operation is erased.
889 
890  /// PatternRewriter hook for erasing a dead operation. The uses of this
891  /// operation *must* be made dead by the end of the conversion process,
892  /// otherwise an assert will be issued.
893  void eraseOp(Operation *op) override;
894 
895  /// PatternRewriter hook for erasing all operations in a block. This is not yet
896  /// implemented for dialect conversion.
897  void eraseBlock(Block *block) override;
898 
899  /// PatternRewriter hook for inlining the ops of a block into another block.
900  void inlineBlockBefore(Block *source, Block *dest, Block::iterator before,
901  ValueRange argValues = std::nullopt) override;
903 
904  /// PatternRewriter hook for updating the given operation in-place.
905  /// Note: These methods only track updates to the given operation itself,
906  /// and not nested regions. Updates to regions will still require notification
907  /// through other more specific hooks above.
908  void startOpModification(Operation *op) override;
909 
910  /// PatternRewriter hook for finalizing an in-place update of the given operation.
911  void finalizeOpModification(Operation *op) override;
912 
913  /// PatternRewriter hook for canceling an in-place update of the given operation.
914  void cancelOpModification(Operation *op) override;
915 
916  /// Return a reference to the internal implementation.
918 
919 private:
920  // Allow OperationConverter to construct new rewriters.
921  friend struct OperationConverter;
922 
923  /// Conversion pattern rewriters must not be used outside of dialect
924  /// conversions. They apply some IR rewrites in a delayed fashion and could
925  /// bring the IR into an inconsistent state when used standalone.
927  const ConversionConfig &config);
928 
929  // Hide unsupported pattern rewriter API.
931 
932  std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
933 };
934 
935 //===----------------------------------------------------------------------===//
936 // ConversionTarget
937 //===----------------------------------------------------------------------===//
938 
939 /// This class describes a specific conversion target.
941 public:
942  /// This enumeration corresponds to the specific action to take when
943  /// considering an operation legal for this conversion target.
944  enum class LegalizationAction {
945  /// The target supports this operation.
946  Legal,
947 
948  /// This operation has dynamic legalization constraints that must be checked
949  /// by the target.
950  Dynamic,
951 
952  /// The target explicitly does not support this operation.
953  Illegal,
954  };
955 
956  /// A structure containing additional information describing a specific legal
957  /// operation instance.
958  struct LegalOpDetails {
959  /// A flag that indicates if this operation is 'recursively' legal. This
960  /// means that if an operation is legal, either statically or dynamically,
961  /// all of the operations nested within are also considered legal.
962  bool isRecursivelyLegal = false;
963  };
964 
965  /// The signature of the callback used to determine if an operation is
966  /// dynamically legal on the target.
968  std::function<std::optional<bool>(Operation *)>;
969 
970  ConversionTarget(MLIRContext &ctx) : ctx(ctx) {}
971  virtual ~ConversionTarget() = default;
972 
973  //===--------------------------------------------------------------------===//
974  // Legality Registration
975  //===--------------------------------------------------------------------===//
976 
977  /// Register a legality action for the given operation.
979  template <typename OpT>
981  setOpAction(OperationName(OpT::getOperationName(), &ctx), action);
982  }
983 
984  /// Register the given operations as legal.
987  }
988  template <typename OpT>
989  void addLegalOp() {
  // Typed convenience overload: name OpT in this target's context and defer
  // to the OperationName-based overload.
990  addLegalOp(OperationName(OpT::getOperationName(), &ctx));
991  }
992  template <typename OpT, typename OpT2, typename... OpTs>
993  void addLegalOp() {
  // Variadic form: register the first op type, then recurse on the rest.
994  addLegalOp<OpT>();
995  addLegalOp<OpT2, OpTs...>();
996  }
997 
998  /// Register the given operation as dynamically legal and set the dynamic
999  /// legalization callback to the one provided.
1001  const DynamicLegalityCallbackFn &callback) {
1003  setLegalityCallback(op, callback);
1004  }
1005  template <typename OpT>
1007  addDynamicallyLegalOp(OperationName(OpT::getOperationName(), &ctx),
1008  callback);
1009  }
1010  template <typename OpT, typename OpT2, typename... OpTs>
1012  addDynamicallyLegalOp<OpT>(callback);
1013  addDynamicallyLegalOp<OpT2, OpTs...>(callback);
1014  }
1015  template <typename OpT, class Callable>
1016  std::enable_if_t<!std::is_invocable_v<Callable, Operation *>>
1017  addDynamicallyLegalOp(Callable &&callback) {
  // Selected only for callables NOT invocable with `Operation *` (e.g. ones
  // taking OpT directly); other callables go to the overload taking a
  // DynamicLegalityCallbackFn. The callback is copied into a wrapper ([=])
  // that casts the generic op to OpT before invoking it.
1018  addDynamicallyLegalOp<OpT>(
1019  [=](Operation *op) { return callback(cast<OpT>(op)); });
1020  }
1021 
1022  /// Register the given operation as illegal, i.e. this operation is known to
1023  /// not be supported by this target.
1026  }
1027  template <typename OpT>
1028  void addIllegalOp() {
  // Typed convenience overload: name OpT in this target's context and defer
  // to the OperationName-based overload.
1029  addIllegalOp(OperationName(OpT::getOperationName(), &ctx));
1030  }
1031  template <typename OpT, typename OpT2, typename... OpTs>
1032  void addIllegalOp() {
  // Variadic form: register the first op type, then recurse on the rest.
1033  addIllegalOp<OpT>();
1034  addIllegalOp<OpT2, OpTs...>();
1035  }
1036 
1037  /// Mark an operation, that *must* have either been set as `Legal` or
1038  /// `DynamicallyLegal`, as being recursively legal. This means that in
1039  /// addition to the operation itself, all of the operations nested within are
1040  /// also considered legal. An optional dynamic legality callback may be
1041  /// provided to mark subsets of legal instances as recursively legal.
1043  const DynamicLegalityCallbackFn &callback);
1044  template <typename OpT>
1046  OperationName opName(OpT::getOperationName(), &ctx);
1047  markOpRecursivelyLegal(opName, callback);
1048  }
1049  template <typename OpT, typename OpT2, typename... OpTs>
1051  markOpRecursivelyLegal<OpT>(callback);
1052  markOpRecursivelyLegal<OpT2, OpTs...>(callback);
1053  }
1054  template <typename OpT, class Callable>
1055  std::enable_if_t<!std::is_invocable_v<Callable, Operation *>>
1056  markOpRecursivelyLegal(Callable &&callback) {
  // Selected only for callables NOT invocable with `Operation *` (e.g. ones
  // taking OpT directly); other callables go to the overload taking a
  // DynamicLegalityCallbackFn. The callback is copied into a wrapper ([=])
  // that casts the generic op to OpT before invoking it.
1057  markOpRecursivelyLegal<OpT>(
1058  [=](Operation *op) { return callback(cast<OpT>(op)); });
1059  }
1060 
1061  /// Register a legality action for the given dialects.
1062  void setDialectAction(ArrayRef<StringRef> dialectNames,
1063  LegalizationAction action);
1064 
1065  /// Register the operations of the given dialects as legal.
1066  template <typename... Names>
1067  void addLegalDialect(StringRef name, Names... names) {
1068  SmallVector<StringRef, 2> dialectNames({name, names...});
1070  }
1071  template <typename... Args>
1073  SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
1075  }
1076 
1077  /// Register the operations of the given dialects as dynamically legal, i.e.
1078  /// requiring custom handling by the callback.
1079  template <typename... Names>
1081  StringRef name, Names... names) {
1082  SmallVector<StringRef, 2> dialectNames({name, names...});
1084  setLegalityCallback(dialectNames, callback);
1085  }
1086  template <typename... Args>
1088  addDynamicallyLegalDialect(std::move(callback),
1089  Args::getDialectNamespace()...);
1090  }
1091 
1092  /// Register unknown operations as dynamically legal. For operations (and
1093  /// dialects) that do not have a set legalization action, treat them as
1094  /// dynamically legal and invoke the given callback.
1096  setLegalityCallback(fn);
1097  }
1098 
1099  /// Register the operations of the given dialects as illegal, i.e.
1100  /// operations of this dialect are not supported by the target.
1101  template <typename... Names>
1102  void addIllegalDialect(StringRef name, Names... names) {
1103  SmallVector<StringRef, 2> dialectNames({name, names...});
1105  }
1106  template <typename... Args>
1108  SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
1110  }
1111 
1112  //===--------------------------------------------------------------------===//
1113  // Legality Querying
1114  //===--------------------------------------------------------------------===//
1115 
1116  /// Get the legality action for the given operation.
1117  std::optional<LegalizationAction> getOpAction(OperationName op) const;
1118 
1119  /// If the given operation instance is legal on this target, a structure
1120  /// containing legality information is returned. If the operation is not
1121  /// legal, std::nullopt is returned. Also returns std::nullopt if operation
1122  /// legality wasn't registered by user or dynamic legality callbacks returned
1123  /// std::nullopt.
1124  ///
1125  /// Note: Legality is actually a 4-state: Legal(recursive=true),
1126  /// Legal(recursive=false), Illegal or Unknown, where Unknown is treated
1127  /// either as Legal or Illegal depending on context.
1128  std::optional<LegalOpDetails> isLegal(Operation *op) const;
1129 
1130  /// Returns true if the operation instance is illegal on this target. Returns
1131  /// false if the operation is legal, its legality wasn't registered by the user,
1132  /// or dynamic legality callbacks returned std::nullopt.
1133  bool isIllegal(Operation *op) const;
1134 
1135 private:
1136  /// Set the dynamic legality callback for the given operation.
1137  void setLegalityCallback(OperationName name,
1138  const DynamicLegalityCallbackFn &callback);
1139 
1140  /// Set the dynamic legality callback for the given dialects.
1141  void setLegalityCallback(ArrayRef<StringRef> dialects,
1142  const DynamicLegalityCallbackFn &callback);
1143 
1144  /// Set the dynamic legality callback for the unknown ops.
1145  void setLegalityCallback(const DynamicLegalityCallbackFn &callback);
1146 
1147  /// The set of information that configures the legalization of an operation.
1148  struct LegalizationInfo {
1149  /// The legality action this operation was given.
1151 
1152  /// If some legal instances of this operation may also be recursively legal.
1153  bool isRecursivelyLegal = false;
1154 
1155  /// The legality callback if this operation is dynamically legal.
1156  DynamicLegalityCallbackFn legalityFn;
1157  };
1158 
1159  /// Get the legalization information for the given operation.
1160  std::optional<LegalizationInfo> getOpInfo(OperationName op) const;
1161 
1162  /// A deterministic mapping of operation name and its respective legality
1163  /// information.
1164  llvm::MapVector<OperationName, LegalizationInfo> legalOperations;
1165 
1166  /// A set of legality callbacks for given operation names that are used to
1167  /// check if an operation instance is recursively legal.
1168  DenseMap<OperationName, DynamicLegalityCallbackFn> opRecursiveLegalityFns;
1169 
1170  /// A deterministic mapping of dialect name to the specific legality action to
1171  /// take.
1172  llvm::StringMap<LegalizationAction> legalDialects;
1173 
1174  /// A set of dynamic legality callbacks for given dialect names.
1175  llvm::StringMap<DynamicLegalityCallbackFn> dialectLegalityFns;
1176 
1177  /// An optional legality callback for unknown operations.
1178  DynamicLegalityCallbackFn unknownLegalityFn;
1179 
1180  /// The current context this target applies to.
1181  MLIRContext &ctx;
1182 };
1183 
1184 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
1185 //===----------------------------------------------------------------------===//
1186 // PDL Configuration
1187 //===----------------------------------------------------------------------===//
1188 
1189 /// A PDL configuration that is used to support dialect conversion
1190 /// functionality.
1192  : public PDLPatternConfigBase<PDLConversionConfig> {
1193 public:
1194  PDLConversionConfig(const TypeConverter *converter) : converter(converter) {}
1195  ~PDLConversionConfig() final = default;
1196 
1197  /// Return the type converter used by this configuration, which may be nullptr
1198  /// if no type conversions are expected.
1199  const TypeConverter *getTypeConverter() const { return converter; }
1200 
1201  /// Hooks that are invoked at the beginning and end of a rewrite of a matched
1202  /// pattern.
1203  void notifyRewriteBegin(PatternRewriter &rewriter) final;
1204  void notifyRewriteEnd(PatternRewriter &rewriter) final;
1205 
1206 private:
1207  /// An optional type converter to use for the pattern.
1208  const TypeConverter *converter;
1209 };
1210 
1211 /// Register the dialect conversion PDL functions with the given pattern set.
1212 void registerConversionPDLFunctions(RewritePatternSet &patterns);
1213 
1214 #else
1215 
1216 // Stubs for when PDL in rewriting is not enabled.
1217 
/// No-op stub used when MLIR_ENABLE_PDL_IN_PATTERNMATCH is off; it has the
/// same signature as the PDL-enabled declaration and leaves `patterns` as is.
1218 inline void registerConversionPDLFunctions(RewritePatternSet &patterns) {}
1219 
/// Placeholder used when MLIR_ENABLE_PDL_IN_PATTERNMATCH is off. It mirrors
/// the constructor signature of the PDL-enabled PDLConversionConfig but
/// stores nothing.
1220 class PDLConversionConfig final {
1221 public:
  // The type converter argument is accepted and ignored.
1222  PDLConversionConfig(const TypeConverter * /*converter*/) {}
1223 };
1224 
1225 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
1226 
1227 //===----------------------------------------------------------------------===//
1228 // ConversionConfig
1229 //===----------------------------------------------------------------------===//
1230 
1231 /// Dialect conversion configuration.
1233  /// An optional callback used to notify about match failure diagnostics during
1234  /// the conversion. Diagnostics reported to this callback may only be
1235  /// available in debug mode.
1237 
1238  /// Partial conversion only. All operations that are found not to be
1239  /// legalizable are placed in this set. (Note that if there is an op
1240  /// explicitly marked as illegal, the conversion terminates and the set will
1241  /// not necessarily be complete.)
1243 
1244  /// Analysis conversion only. All operations that are found to be legalizable
1245  /// are placed in this set. Note that no actual rewrites are applied to the
1246  /// IR during an analysis conversion and only pre-existing operations are
1247  /// added to the set.
1249 
1250  /// An optional listener that is notified about all IR modifications in case
1251  /// dialect conversion succeeds. If the dialect conversion fails and no IR
1252  /// modifications are visible (i.e., they were all rolled back), or if the
1253  /// dialect conversion is an "analysis conversion", no notifications are
1254  /// sent (apart from `notifyPatternBegin`/`notifyPatternEnd`).
1255  ///
1256  /// Note: Notifications are sent in a delayed fashion, when the dialect
1257  /// conversion is guaranteed to succeed. At that point, some IR modifications
1258  /// may already have been materialized. Consequently, operations/blocks that
1259  /// are passed to listener callbacks should not be accessed. (Ops/blocks are
1260  /// guaranteed to be valid pointers and accessing op names is allowed. But
1261  /// there are no guarantees about the state of ops/blocks at the time that a
1262  /// callback is triggered.)
1263  ///
1264  /// Example: Consider a dialect conversion in which a new op ("test.foo") is created
1265  /// and inserted, and later moved to another block. (Moving ops also triggers
1266  /// "notifyOperationInserted".)
1267  ///
1268  /// (1) notifyOperationInserted: "test.foo" (into block "b1")
1269  /// (2) notifyOperationInserted: "test.foo" (moved to another block "b2")
1270  ///
1271  /// When querying "op->getBlock()" during the first "notifyOperationInserted",
1272  /// "b2" would be returned because "moving an op" is a kind of rewrite that is
1273  /// immediately performed by the dialect conversion (and rolled back upon
1274  /// failure).
1275  //
1276  // Note: When receiving a "notifyBlockInserted"/"notifyOperationInserted"
1277  // callback, the previous region/block is provided to the callback, but not
1278  // the iterator pointing to the exact location within the region/block. That
1279  // is because these notifications are sent with a delay (after the IR has
1280  // already been modified) and iterators into past IR state cannot be
1281  // represented at the moment.
1283 
1284  /// If set to "true", the dialect conversion attempts to build source/target
1285  /// materializations through the type converter API in lieu of
1286  /// "builtin.unrealized_conversion_cast" ops. The conversion process fails if
1287  /// at least one materialization could not be built.
1288  ///
1289  /// If set to "false", the dialect conversion does not build any custom
1290  /// materializations and instead inserts "builtin.unrealized_conversion_cast"
1291  /// ops to ensure that the resulting IR is valid.
1293 };
1294 
1295 //===----------------------------------------------------------------------===//
1296 // Reconcile Unrealized Casts
1297 //===----------------------------------------------------------------------===//
1298 
1299 /// Try to reconcile all given UnrealizedConversionCastOps and store the
1300 /// left-over ops in `remainingCastOps` (if provided).
1301 ///
1302 /// This function processes cast ops in a worklist-driven fashion. For each
1303 /// cast op, if the chain of input casts eventually reaches a cast op where the
1304 /// input types match the output types of the matched op, replace the matched
1305 /// op with the inputs.
1306 ///
1307 /// Example:
1308 /// %1 = unrealized_conversion_cast %0 : !A to !B
1309 /// %2 = unrealized_conversion_cast %1 : !B to !C
1310 /// %3 = unrealized_conversion_cast %2 : !C to !A
1311 ///
1312 /// In the above example, %0 can be used instead of %3 and all cast ops are
1313 /// folded away.
1316  SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps = nullptr);
1317 
1318 //===----------------------------------------------------------------------===//
1319 // Op Conversion Entry Points
1320 //===----------------------------------------------------------------------===//
1321 
1322 /// Below we define several entry points for operation conversion. It is
1323 /// important to note that the patterns provided to the conversion framework may
1324 /// have additional constraints. See the `PatternRewriter Hooks` section of the
1325 /// ConversionPatternRewriter, to see what additional constraints are imposed on
1326 /// the use of the PatternRewriter.
1327 
1328 /// Apply a partial conversion on the given operations and all nested
1329 /// operations. This method converts as many operations to the target as
1330 /// possible, ignoring operations that failed to legalize. This method only
1331 /// returns failure if there are ops explicitly marked as illegal.
1332 LogicalResult
1334  const ConversionTarget &target,
1337 LogicalResult
1341 
1342 /// Apply a complete conversion on the given operations, and all nested
1343 /// operations. This method returns failure if the conversion of any operation
1344 /// fails, or if there are unreachable blocks in any of the regions nested
1345 /// within 'ops'.
1346 LogicalResult applyFullConversion(ArrayRef<Operation *> ops,
1347  const ConversionTarget &target,
1350 LogicalResult applyFullConversion(Operation *op, const ConversionTarget &target,
1353 
1354 /// Apply an analysis conversion on the given operations, and all nested
1355 /// operations. This method analyzes which operations would be successfully
1356 /// converted to the target if a conversion was applied. All operations that
1357 /// were found to be legalizable to the given 'target' are placed within the
1358 /// provided 'config.legalizableOps' set; note that no actual rewrites are
1359 /// applied to the operations on success. This method only returns failure if
1360 /// there are unreachable blocks in any of the regions nested within 'ops'.
1361 LogicalResult
1365 LogicalResult
1369 } // namespace mlir
1370 
1371 #endif // MLIR_TRANSFORMS_DIALECTCONVERSION_H_
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
Block represents an ordered list of Operations.
Definition: Block.h:33
OpListType::iterator iterator
Definition: Block.h:140
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
LogicalResult getRemappedValues(ValueRange keys, SmallVectorImpl< Value > &results)
Return the converted values that replace 'keys' with types defined by the type converter of the curre...
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt) override
PatternRewriter hook for inlining the ops of a block into another block.
Block * applySignatureConversion(Block *block, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter=nullptr)
Apply a signature conversion to given block.
void startOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void replaceOpWithMultiple(Operation *op, ArrayRef< ValueRange > newValues)
Replace the given operation with the new value ranges.
detail::ConversionPatternRewriterImpl & getImpl()
Return a reference to the internal implementation.
void eraseBlock(Block *block) override
PatternRewriter hook for erasing all operations in a block.
void cancelOpModification(Operation *op) override
PatternRewriter hook for canceling an in-place update of the given operation.
bool canRecoverFromRewriteFailure() const override
Indicate that the conversion rewriter can recover from rewrite failure.
Value getRemappedValue(Value key)
Return the converted value of 'key' with a type defined by the type converter of the currently execut...
void finalizeOpModification(Operation *op) override
PatternRewriter hook for finalizing an in-place update of the given operation.
void replaceUsesOfBlockArgument(BlockArgument from, Value to)
Replace all the uses of the block argument from with value to.
Base class for the conversion patterns.
SmallVector< Value > getOneToOneAdaptorOperands(ArrayRef< ValueRange > operands) const
Given an array of value ranges, which are the inputs to a 1:N adaptor, try to extract the single valu...
virtual void rewrite(Operation *op, ArrayRef< ValueRange > operands, ConversionPatternRewriter &rewriter) const
const TypeConverter * typeConverter
An optional type converter for use by this pattern.
ConversionPattern(const TypeConverter &typeConverter, Args &&...args)
Construct a conversion pattern with the given converter, and forward the remaining arguments to Rewri...
virtual void rewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement rewriting.
const TypeConverter * getTypeConverter() const
Return the type converter held by this pattern, or nullptr if the pattern does not require type conve...
std::enable_if_t< std::is_base_of< TypeConverter, ConverterTy >::value, const ConverterTy * > getTypeConverter() const
virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement combined matching and rewriting.
virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef< ValueRange > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement combined matching and rewriting.
This class describes a specific conversion target.
void setDialectAction(ArrayRef< StringRef > dialectNames, LegalizationAction action)
Register a legality action for the given dialects.
void setOpAction(LegalizationAction action)
void addLegalOp(OperationName op)
Register the given operations as legal.
void addDynamicallyLegalDialect(DynamicLegalityCallbackFn callback)
void setOpAction(OperationName op, LegalizationAction action)
Register a legality action for the given operation.
void addDynamicallyLegalDialect(const DynamicLegalityCallbackFn &callback, StringRef name, Names... names)
Register the operations of the given dialects as dynamically legal, i.e.
void addDynamicallyLegalOp(const DynamicLegalityCallbackFn &callback)
void addLegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as legal.
std::optional< LegalOpDetails > isLegal(Operation *op) const
If the given operation instance is legal on this target, a structure containing legality information ...
std::enable_if_t<!std::is_invocable_v< Callable, Operation * > > markOpRecursivelyLegal(Callable &&callback)
void markUnknownOpDynamicallyLegal(const DynamicLegalityCallbackFn &fn)
Register unknown operations as dynamically legal.
std::optional< LegalizationAction > getOpAction(OperationName op) const
Get the legality action for the given operation.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the on...
void addIllegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as illegal, i.e.
std::enable_if_t<!std::is_invocable_v< Callable, Operation * > > addDynamicallyLegalOp(Callable &&callback)
void addDynamicallyLegalOp(const DynamicLegalityCallbackFn &callback)
void markOpRecursivelyLegal(const DynamicLegalityCallbackFn &callback={})
void addIllegalOp(OperationName op)
Register the given operation as illegal, i.e.
LegalizationAction
This enumeration corresponds to the specific action to take when considering an operation legal for t...
@ Illegal
The target explicitly does not support this operation.
@ Dynamic
This operation has dynamic legalization constraints that must be checked by the target.
@ Legal
The target supports this operation.
ConversionTarget(MLIRContext &ctx)
void markOpRecursivelyLegal(OperationName name, const DynamicLegalityCallbackFn &callback)
Mark an operation, that must have either been set as Legal or DynamicallyLegal, as being recursively ...
void markOpRecursivelyLegal(const DynamicLegalityCallbackFn &callback={})
std::function< std::optional< bool >(Operation *)> DynamicLegalityCallbackFn
The signature of the callback used to determine if an operation is dynamically legal on the target.
bool isIllegal(Operation *op) const
Returns true is operation instance is illegal on this target.
virtual ~ConversionTarget()=default
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:205
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
Definition: Builders.h:314
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
void rewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Hook for derived classes to implement rewriting.
LogicalResult match(Operation *op) const final
Wrappers around the ConversionPattern methods that pass the derived op type.
void rewrite(Operation *op, ArrayRef< ValueRange > operands, ConversionPatternRewriter &rewriter) const final
LogicalResult matchAndRewrite(Operation *op, ArrayRef< ValueRange > operands, ConversionPatternRewriter &rewriter) const final
Hook for derived classes to implement combined matching and rewriting.
typename SourceOp::Adaptor OpAdaptor
virtual LogicalResult matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
typename SourceOp::template GenericAdaptor< ArrayRef< ValueRange > > OneToNOpAdaptor
virtual LogicalResult match(SourceOp op) const
Rewrite and Match methods that operate on the SourceOp type.
OpConversionPattern(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit=1)
virtual LogicalResult matchAndRewrite(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const
virtual void rewrite(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const
LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Hook for derived classes to implement combined matching and rewriting.
virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const
OpInterfaceConversionPattern is a wrapper around ConversionPattern that allows for matching and rewri...
void rewrite(Operation *op, ArrayRef< ValueRange > operands, ConversionPatternRewriter &rewriter) const final
LogicalResult matchAndRewrite(Operation *op, ArrayRef< ValueRange > operands, ConversionPatternRewriter &rewriter) const final
Hook for derived classes to implement combined matching and rewriting.
void rewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Wrappers around the ConversionPattern methods that pass the derived op type.
virtual LogicalResult matchAndRewrite(SourceOp op, ArrayRef< ValueRange > operands, ConversionPatternRewriter &rewriter) const
LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Hook for derived classes to implement combined matching and rewriting.
OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
OpInterfaceConversionPattern(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit=1)
virtual void rewrite(SourceOp op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Rewrite and Match methods that operate on the SourceOp type.
virtual void rewrite(SourceOp op, ArrayRef< ValueRange > operands, ConversionPatternRewriter &rewriter) const
virtual LogicalResult matchAndRewrite(SourceOp op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
OpTraitConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting...
OpTraitConversionPattern(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit=1)
OpTraitConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
A PDL configuration that is used to supported dialect conversion functionality.
void notifyRewriteEnd(PatternRewriter &rewriter) final
void notifyRewriteBegin(PatternRewriter &rewriter) final
Hooks that are invoked at the beginning and end of a rewrite of a matched pattern.
const TypeConverter * getTypeConverter() const
Return the type converter used by this configuration, which may be nullptr if no type conversions are...
PDLConversionConfig(const TypeConverter *converter)
~PDLConversionConfig() final=default
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:73
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:246
virtual LogicalResult match(Operation *op) const
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
virtual void rewrite(Operation *op, PatternRewriter &rewriter) const
Rewrite the IR rooted at the specified operation with the result of this pattern, generating any new ...
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
The general result of a type attribute conversion callback, allowing for early termination.
static AttributeConversionResult result(Attribute attr)
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
std::optional< InputMapping > getInputMapping(unsigned input) const
Get the input mapping for the given argument.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
void remapInput(unsigned origInputNo, Value replacement)
Remap an input of the original signature to another replacement value.
Type conversion class.
std::optional< Attribute > convertTypeAttribute(Type type, Attribute attr) const
Convert an attribute present attr from within the type type using the registered conversion functions...
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
void addConversion(FnT &&callback)
Register a conversion function.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
void addArgumentMaterialization(FnT &&callback)
All of the following materializations require function objects that are convertible to the following ...
LogicalResult convertSignatureArgs(TypeRange types, SignatureConversion &result, unsigned origInputOffset=0) const
TypeConverter()=default
std::enable_if_t<!std::is_convertible< RangeT, Type >::value &&!std::is_convertible< RangeT, Operation * >::value, bool > isLegal(RangeT &&range) const
Return true if all of the given types are legal for this type converter.
TargetType convertType(Type t) const
Attempts a 1-1 type conversion, expecting the result type to be TargetType.
LogicalResult convertSignatureArg(unsigned inputNo, Type type, SignatureConversion &result) const
This method allows for converting a specific argument of a signature.
TypeConverter(const TypeConverter &other)
TypeConverter & operator=(const TypeConverter &other)
Value materializeArgumentConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
Materialize a conversion from a set of types into one result type by generating a cast sequence of so...
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
std::optional< SignatureConversion > convertBlockSignature(Block *block) const
This function converts the type signature of the given block, by invoking 'convertSignatureArg' for e...
void addSourceMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a replacement value back ...
virtual ~TypeConverter()=default
void addTypeAttributeConversion(FnT &&callback)
Register a conversion function for attributes within types.
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a value to a target type ...
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
bool isSignatureLegal(FunctionType ty) const
Return true if the inputs and outputs of the given function type are legal.
Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs, Type originalType={}) const
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:104
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
@ Type
An inlay hint that for a type annotation.
Include the generated interface declarations.
void populateFunctionOpInterfaceTypeConversionPattern(StringRef functionLikeOpName, RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to convert the signature of a FunctionOpInterface op with the...
void populateAnyFunctionOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
FailureOr< Operation * > convertOpResultTypes(Operation *op, ValueRange operands, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Generic utility to convert op result types according to type converter without knowing exact op type.
const FrozenRewritePatternSet & patterns
LogicalResult applyAnalysisConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply an analysis conversion on the given operations, and all nested operations.
void reconcileUnrealizedCasts(ArrayRef< UnrealizedConversionCastOp > castOps, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps=nullptr)
Try to reconcile all given UnrealizedConversionCastOps and store the left-over ops in remainingCastOp...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void registerConversionPDLFunctions(RewritePatternSet &patterns)
Register the dialect conversion PDL functions with the given pattern set.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
Dialect conversion configuration.
RewriterBase::Listener * listener
An optional listener that is notified about all IR modifications in case dialect conversion succeeds.
function_ref< void(Diagnostic &)> notifyCallback
An optional callback used to notify about match failure diagnostics during the conversion.
DenseSet< Operation * > * legalizableOps
Analysis conversion only.
DenseSet< Operation * > * unlegalizedOps
Partial conversion only.
bool buildMaterializations
If set to "true", the dialect conversion attempts to build source/target materializations through the...
A structure containing additional information describing a specific legal operation instance.
bool isRecursivelyLegal
A flag that indicates if this operation is 'recursively' legal.
This class acts as a special tag that makes the desire to match any operation that implements a given...
Definition: PatternMatch.h:164
This class acts as a special tag that makes the desire to match any operation that implements a given...
Definition: PatternMatch.h:169
This struct represents a range of new types or a single value that remaps an existing signature input...