MLIR  20.0.0git
DialectConversion.h
Go to the documentation of this file.
1 //===- DialectConversion.h - MLIR dialect conversion pass -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file declares a generic pass for converting between MLIR dialects.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_TRANSFORMS_DIALECTCONVERSION_H_
14 #define MLIR_TRANSFORMS_DIALECTCONVERSION_H_
15 
16 #include "mlir/Config/mlir-config.h"
18 #include "llvm/ADT/MapVector.h"
19 #include "llvm/ADT/StringMap.h"
20 #include <type_traits>
21 
22 namespace mlir {
23 
24 // Forward declarations.
25 class Attribute;
26 class Block;
27 struct ConversionConfig;
28 class ConversionPatternRewriter;
29 class MLIRContext;
30 class Operation;
31 struct OperationConverter;
32 class Type;
33 class Value;
34 
35 //===----------------------------------------------------------------------===//
36 // Type Conversion
37 //===----------------------------------------------------------------------===//
38 
39 /// Type conversion class. Specific conversions and materializations can be
40 /// registered using addConversion and addMaterialization, respectively.
42 public:
43  virtual ~TypeConverter() = default;
44  TypeConverter() = default;
45  // Copy the registered conversions, but not the caches
47  : conversions(other.conversions),
48  argumentMaterializations(other.argumentMaterializations),
49  sourceMaterializations(other.sourceMaterializations),
50  targetMaterializations(other.targetMaterializations),
51  typeAttributeConversions(other.typeAttributeConversions) {}
53  conversions = other.conversions;
54  argumentMaterializations = other.argumentMaterializations;
55  sourceMaterializations = other.sourceMaterializations;
56  targetMaterializations = other.targetMaterializations;
57  typeAttributeConversions = other.typeAttributeConversions;
58  return *this;
59  }
60 
61  /// This class provides all of the information necessary to convert a type
62  /// signature.
64  public:
65  SignatureConversion(unsigned numOrigInputs)
66  : remappedInputs(numOrigInputs) {}
67 
68  /// This struct represents a range of new types or a single value that
69  /// remaps an existing signature input.
70  struct InputMapping {
71  size_t inputNo, size;
73  };
74 
75  /// Return the argument types for the new signature.
76  ArrayRef<Type> getConvertedTypes() const { return argTypes; }
77 
78  /// Get the input mapping for the given argument.
79  std::optional<InputMapping> getInputMapping(unsigned input) const {
80  return remappedInputs[input];
81  }
82 
83  //===------------------------------------------------------------------===//
84  // Conversion Hooks
85  //===------------------------------------------------------------------===//
86 
87  /// Remap an input of the original signature with a new set of types. The
88  /// new types are appended to the new signature conversion.
89  void addInputs(unsigned origInputNo, ArrayRef<Type> types);
90 
91  /// Append new input types to the signature conversion, this should only be
92  /// used if the new types are not intended to remap an existing input.
93  void addInputs(ArrayRef<Type> types);
94 
95  /// Remap an input of the original signature to another `replacement`
96  /// value. This drops the original argument.
97  void remapInput(unsigned origInputNo, Value replacement);
98 
99  private:
100  /// Remap an input of the original signature with a range of types in the
101  /// new signature.
102  void remapInput(unsigned origInputNo, unsigned newInputNo,
103  unsigned newInputCount = 1);
104 
105  /// The remapping information for each of the original arguments.
106  SmallVector<std::optional<InputMapping>, 4> remappedInputs;
107 
108  /// The set of new argument types.
109  SmallVector<Type, 4> argTypes;
110  };
111 
112  /// The general result of a type attribute conversion callback, allowing
113  /// for early termination. The default constructor creates the na case.
115  public:
116  constexpr AttributeConversionResult() : impl() {}
117  AttributeConversionResult(Attribute attr) : impl(attr, resultTag) {}
118 
120  static AttributeConversionResult na();
122 
123  bool hasResult() const;
124  bool isNa() const;
125  bool isAbort() const;
126 
127  Attribute getResult() const;
128 
129  private:
130  AttributeConversionResult(Attribute attr, unsigned tag) : impl(attr, tag) {}
131 
132  llvm::PointerIntPair<Attribute, 2> impl;
133  // Note that na is 0 so that we can use PointerIntPair's default
134  // constructor.
135  static constexpr unsigned naTag = 0;
136  static constexpr unsigned resultTag = 1;
137  static constexpr unsigned abortTag = 2;
138  };
139 
140  /// Register a conversion function. A conversion function must be convertible
141  /// to any of the following forms (where `T` is a class derived from `Type`):
142  ///
143  /// * std::optional<Type>(T)
144  /// - This form represents a 1-1 type conversion. It should return nullptr
145  /// or `std::nullopt` to signify failure. If `std::nullopt` is returned,
146  /// the converter is allowed to try another conversion function to
147  /// perform the conversion.
148  /// * std::optional<LogicalResult>(T, SmallVectorImpl<Type> &)
149  /// - This form represents a 1-N type conversion. It should return
150  /// `failure` or `std::nullopt` to signify a failed conversion. If the
151  /// new set of types is empty, the type is removed and any usages of the
152  /// existing value are expected to be removed during conversion. If
153  /// `std::nullopt` is returned, the converter is allowed to try another
154  /// conversion function to perform the conversion.
155  ///
156  /// Note: When attempting to convert a type, e.g. via 'convertType', the
157  /// mostly recently added conversions will be invoked first.
158  template <typename FnT, typename T = typename llvm::function_traits<
159  std::decay_t<FnT>>::template arg_t<0>>
160  void addConversion(FnT &&callback) {
161  registerConversion(wrapCallback<T>(std::forward<FnT>(callback)));
162  }
163 
164  /// All of the following materializations require function objects that are
165  /// convertible to the following form:
166  /// `Value(OpBuilder &, T, ValueRange, Location)`,
167  /// where `T` is any subclass of `Type`. This function is responsible for
168  /// creating an operation, using the OpBuilder and Location provided, that
169  /// "casts" a range of values into a single value of the given type `T`. It
170  /// must return a Value of the type `T` on success and `nullptr` if
171  /// it failed but other materializations should be attempted. Materialization
172  /// functions must be provided when a type conversion may persist after the
173  /// conversion has finished.
174  ///
175  /// Note: Target materializations may optionally accept an additional Type
176  /// parameter, which is the original type of the SSA value. Furthermore, `T`
177  /// can be a TypeRange; in that case, the function must return a
178  /// SmallVector<Value>.
179 
180  /// This method registers a materialization that will be called when
181  /// converting (potentially multiple) block arguments that were the result of
182  /// a signature conversion of a single block argument, to a single SSA value
183  /// with the old block argument type.
184  template <typename FnT, typename T = typename llvm::function_traits<
185  std::decay_t<FnT>>::template arg_t<1>>
186  void addArgumentMaterialization(FnT &&callback) {
187  argumentMaterializations.emplace_back(
188  wrapMaterialization<T>(std::forward<FnT>(callback)));
189  }
190 
191  /// This method registers a materialization that will be called when
192  /// converting a replacement value back to its original source type.
193  /// This is used when some uses of the original value persist beyond the main
194  /// conversion.
195  template <typename FnT, typename T = typename llvm::function_traits<
196  std::decay_t<FnT>>::template arg_t<1>>
197  void addSourceMaterialization(FnT &&callback) {
198  sourceMaterializations.emplace_back(
199  wrapMaterialization<T>(std::forward<FnT>(callback)));
200  }
201 
202  /// This method registers a materialization that will be called when
203  /// converting a value to a target type according to a pattern's type
204  /// converter.
205  ///
206  /// Note: Target materializations can optionally inspect the "original"
207  /// type. This type may be different from the type of the input value.
208  /// For example, let's assume that a conversion pattern "P1" replaced an SSA
209  /// value "v1" (type "t1") with "v2" (type "t2"). Then a different conversion
210  /// pattern "P2" matches an op that has "v1" as an operand. Let's furthermore
211  /// assume that "P2" determines that the converted target type of "t1" is
212  /// "t3", which may be different from "t2". In this example, the target
213  /// materialization will be invoked with: outputType = "t3", inputs = "v2",
214  /// originalType = "t1". Note that the original type "t1" cannot be recovered
215  /// from just "t3" and "v2"; that's why the originalType parameter exists.
216  ///
217  /// Note: During a 1:N conversion, the result types can be a TypeRange. In
218  /// that case the materialization produces a SmallVector<Value>.
219  template <typename FnT, typename T = typename llvm::function_traits<
220  std::decay_t<FnT>>::template arg_t<1>>
221  void addTargetMaterialization(FnT &&callback) {
222  targetMaterializations.emplace_back(
223  wrapTargetMaterialization<T>(std::forward<FnT>(callback)));
224  }
225 
226  /// Register a conversion function for attributes within types. Type
227  /// converters may call this function in order to allow hoking into the
228  /// translation of attributes that exist within types. For example, a type
229  /// converter for the `memref` type could use these conversions to convert
230  /// memory spaces or layouts in an extensible way.
231  ///
232  /// The conversion functions take a non-null Type or subclass of Type and a
233  /// non-null Attribute (or subclass of Attribute), and returns a
234  /// `AttributeConversionResult`. This result can either contan an `Attribute`,
235  /// which may be `nullptr`, representing the conversion's success,
236  /// `AttributeConversionResult::na()` (the default empty value), indicating
237  /// that the conversion function did not apply and that further conversion
238  /// functions should be checked, or `AttributeConversionResult::abort()`
239  /// indicating that the conversion process should be aborted.
240  ///
241  /// Registered conversion functions are callled in the reverse of the order in
242  /// which they were registered.
243  template <
244  typename FnT,
245  typename T =
246  typename llvm::function_traits<std::decay_t<FnT>>::template arg_t<0>,
247  typename A =
248  typename llvm::function_traits<std::decay_t<FnT>>::template arg_t<1>>
249  void addTypeAttributeConversion(FnT &&callback) {
250  registerTypeAttributeConversion(
251  wrapTypeAttributeConversion<T, A>(std::forward<FnT>(callback)));
252  }
253 
254  /// Convert the given type. This function should return failure if no valid
255  /// conversion exists, success otherwise. If the new set of types is empty,
256  /// the type is removed and any usages of the existing value are expected to
257  /// be removed during conversion.
258  LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) const;
259 
260  /// This hook simplifies defining 1-1 type conversions. This function returns
261  /// the type to convert to on success, and a null type on failure.
262  Type convertType(Type t) const;
263 
264  /// Attempts a 1-1 type conversion, expecting the result type to be
265  /// `TargetType`. Returns the converted type cast to `TargetType` on success,
266  /// and a null type on conversion or cast failure.
267  template <typename TargetType>
268  TargetType convertType(Type t) const {
269  return dyn_cast_or_null<TargetType>(convertType(t));
270  }
271 
272  /// Convert the given set of types, filling 'results' as necessary. This
273  /// returns failure if the conversion of any of the types fails, success
274  /// otherwise.
275  LogicalResult convertTypes(TypeRange types,
276  SmallVectorImpl<Type> &results) const;
277 
278  /// Return true if the given type is legal for this type converter, i.e. the
279  /// type converts to itself.
280  bool isLegal(Type type) const;
281 
282  /// Return true if all of the given types are legal for this type converter.
283  template <typename RangeT>
284  std::enable_if_t<!std::is_convertible<RangeT, Type>::value &&
285  !std::is_convertible<RangeT, Operation *>::value,
286  bool>
287  isLegal(RangeT &&range) const {
288  return llvm::all_of(range, [this](Type type) { return isLegal(type); });
289  }
290  /// Return true if the given operation has legal operand and result types.
291  bool isLegal(Operation *op) const;
292 
293  /// Return true if the types of block arguments within the region are legal.
294  bool isLegal(Region *region) const;
295 
296  /// Return true if the inputs and outputs of the given function type are
297  /// legal.
298  bool isSignatureLegal(FunctionType ty) const;
299 
300  /// This method allows for converting a specific argument of a signature. It
301  /// takes as inputs the original argument input number, type.
302  /// On success, it populates 'result' with any new mappings.
303  LogicalResult convertSignatureArg(unsigned inputNo, Type type,
304  SignatureConversion &result) const;
305  LogicalResult convertSignatureArgs(TypeRange types,
306  SignatureConversion &result,
307  unsigned origInputOffset = 0) const;
308 
309  /// This function converts the type signature of the given block, by invoking
310  /// 'convertSignatureArg' for each argument. This function should return a
311  /// valid conversion for the signature on success, std::nullopt otherwise.
312  std::optional<SignatureConversion> convertBlockSignature(Block *block) const;
313 
314  /// Materialize a conversion from a set of types into one result type by
315  /// generating a cast sequence of some kind. See the respective
316  /// `add*Materialization` for more information on the context for these
317  /// methods.
319  Type resultType, ValueRange inputs) const;
321  Type resultType, ValueRange inputs) const;
323  Type resultType, ValueRange inputs,
324  Type originalType = {}) const;
325  SmallVector<Value> materializeTargetConversion(OpBuilder &builder,
326  Location loc,
327  TypeRange resultType,
328  ValueRange inputs,
329  Type originalType = {}) const;
330 
331  /// Convert an attribute present `attr` from within the type `type` using
332  /// the registered conversion functions. If no applicable conversion has been
333  /// registered, return std::nullopt. Note that the empty attribute/`nullptr`
334  /// is a valid return value for this function.
335  std::optional<Attribute> convertTypeAttribute(Type type,
336  Attribute attr) const;
337 
338 private:
339  /// The signature of the callback used to convert a type. If the new set of
340  /// types is empty, the type is removed and any usages of the existing value
341  /// are expected to be removed during conversion.
342  using ConversionCallbackFn = std::function<std::optional<LogicalResult>(
343  Type, SmallVectorImpl<Type> &)>;
344 
345  /// The signature of the callback used to materialize a source/argument
346  /// conversion.
347  ///
348  /// Arguments: builder, result type, inputs, location
349  using MaterializationCallbackFn =
350  std::function<Value(OpBuilder &, Type, ValueRange, Location)>;
351 
352  /// The signature of the callback used to materialize a target conversion.
353  ///
354  /// Arguments: builder, result types, inputs, location, original type
355  using TargetMaterializationCallbackFn = std::function<SmallVector<Value>(
356  OpBuilder &, TypeRange, ValueRange, Location, Type)>;
357 
358  /// The signature of the callback used to convert a type attribute.
359  using TypeAttributeConversionCallbackFn =
360  std::function<AttributeConversionResult(Type, Attribute)>;
361 
362  /// Generate a wrapper for the given callback. This allows for accepting
363  /// different callback forms, that all compose into a single version.
364  /// With callback of form: `std::optional<Type>(T)`
365  template <typename T, typename FnT>
366  std::enable_if_t<std::is_invocable_v<FnT, T>, ConversionCallbackFn>
367  wrapCallback(FnT &&callback) const {
368  return wrapCallback<T>([callback = std::forward<FnT>(callback)](
369  T type, SmallVectorImpl<Type> &results) {
370  if (std::optional<Type> resultOpt = callback(type)) {
371  bool wasSuccess = static_cast<bool>(*resultOpt);
372  if (wasSuccess)
373  results.push_back(*resultOpt);
374  return std::optional<LogicalResult>(success(wasSuccess));
375  }
376  return std::optional<LogicalResult>();
377  });
378  }
379  /// With callback of form: `std::optional<LogicalResult>(
380  /// T, SmallVectorImpl<Type> &, ArrayRef<Type>)`.
381  template <typename T, typename FnT>
382  std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &>,
383  ConversionCallbackFn>
384  wrapCallback(FnT &&callback) const {
385  return [callback = std::forward<FnT>(callback)](
386  Type type,
387  SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
388  T derivedType = dyn_cast<T>(type);
389  if (!derivedType)
390  return std::nullopt;
391  return callback(derivedType, results);
392  };
393  }
394 
395  /// Register a type conversion.
396  void registerConversion(ConversionCallbackFn callback) {
397  conversions.emplace_back(std::move(callback));
398  cachedDirectConversions.clear();
399  cachedMultiConversions.clear();
400  }
401 
402  /// Generate a wrapper for the given argument/source materialization
403  /// callback. The callback may take any subclass of `Type` and the
404  /// wrapper will check for the target type to be of the expected class
405  /// before calling the callback.
406  template <typename T, typename FnT>
407  MaterializationCallbackFn wrapMaterialization(FnT &&callback) const {
408  return [callback = std::forward<FnT>(callback)](
409  OpBuilder &builder, Type resultType, ValueRange inputs,
410  Location loc) -> Value {
411  if (T derivedType = dyn_cast<T>(resultType))
412  return callback(builder, derivedType, inputs, loc);
413  return Value();
414  };
415  }
416 
417  /// Generate a wrapper for the given target materialization callback.
418  /// The callback may take any subclass of `Type` and the wrapper will check
419  /// for the target type to be of the expected class before calling the
420  /// callback.
421  ///
422  /// With callback of form:
423  /// - Value(OpBuilder &, T, ValueRange, Location, Type)
424  /// - SmallVector<Value>(OpBuilder &, TypeRange, ValueRange, Location, Type)
425  template <typename T, typename FnT>
426  std::enable_if_t<
427  std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location, Type>,
428  TargetMaterializationCallbackFn>
429  wrapTargetMaterialization(FnT &&callback) const {
430  return [callback = std::forward<FnT>(callback)](
431  OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
432  Location loc, Type originalType) -> SmallVector<Value> {
433  SmallVector<Value> result;
434  if constexpr (std::is_same<T, TypeRange>::value) {
435  // This is a 1:N target materialization. Return the produces values
436  // directly.
437  result = callback(builder, resultTypes, inputs, loc, originalType);
438  } else if constexpr (std::is_assignable<Type, T>::value) {
439  // This is a 1:1 target materialization. Invoke the callback only if a
440  // single SSA value is requested.
441  if (resultTypes.size() == 1) {
442  // Invoke the callback only if the type class of the callback matches
443  // the requested result type.
444  if (T derivedType = dyn_cast<T>(resultTypes.front())) {
445  // 1:1 materializations produce single values, but we store 1:N
446  // target materialization functions in the type converter. Wrap the
447  // result value in a SmallVector<Value>.
448  Value val =
449  callback(builder, derivedType, inputs, loc, originalType);
450  if (val)
451  result.push_back(val);
452  }
453  }
454  } else {
455  static_assert(sizeof(T) == 0, "T must be a Type or a TypeRange");
456  }
457  return result;
458  };
459  }
460  /// With callback of form:
461  /// - Value(OpBuilder &, T, ValueRange, Location)
462  /// - SmallVector<Value>(OpBuilder &, TypeRange, ValueRange, Location)
463  template <typename T, typename FnT>
464  std::enable_if_t<
465  std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location>,
466  TargetMaterializationCallbackFn>
467  wrapTargetMaterialization(FnT &&callback) const {
468  return wrapTargetMaterialization<T>(
469  [callback = std::forward<FnT>(callback)](
470  OpBuilder &builder, T resultTypes, ValueRange inputs, Location loc,
471  Type originalType) {
472  return callback(builder, resultTypes, inputs, loc);
473  });
474  }
475 
476  /// Generate a wrapper for the given memory space conversion callback. The
477  /// callback may take any subclass of `Attribute` and the wrapper will check
478  /// for the target attribute to be of the expected class before calling the
479  /// callback.
480  template <typename T, typename A, typename FnT>
481  TypeAttributeConversionCallbackFn
482  wrapTypeAttributeConversion(FnT &&callback) const {
483  return [callback = std::forward<FnT>(callback)](
484  Type type, Attribute attr) -> AttributeConversionResult {
485  if (T derivedType = dyn_cast<T>(type)) {
486  if (A derivedAttr = dyn_cast_or_null<A>(attr))
487  return callback(derivedType, derivedAttr);
488  }
490  };
491  }
492 
493  /// Register a memory space conversion, clearing caches.
494  void
495  registerTypeAttributeConversion(TypeAttributeConversionCallbackFn callback) {
496  typeAttributeConversions.emplace_back(std::move(callback));
497  // Clear type conversions in case a memory space is lingering inside.
498  cachedDirectConversions.clear();
499  cachedMultiConversions.clear();
500  }
501 
502  /// The set of registered conversion functions.
503  SmallVector<ConversionCallbackFn, 4> conversions;
504 
505  /// The list of registered materialization functions.
506  SmallVector<MaterializationCallbackFn, 2> argumentMaterializations;
507  SmallVector<MaterializationCallbackFn, 2> sourceMaterializations;
508  SmallVector<TargetMaterializationCallbackFn, 2> targetMaterializations;
509 
510  /// The list of registered type attribute conversion functions.
511  SmallVector<TypeAttributeConversionCallbackFn, 2> typeAttributeConversions;
512 
513  /// A set of cached conversions to avoid recomputing in the common case.
514  /// Direct 1-1 conversions are the most common, so this cache stores the
515  /// successful 1-1 conversions as well as all failed conversions.
516  mutable DenseMap<Type, Type> cachedDirectConversions;
517  /// This cache stores the successful 1->N conversions, where N != 1.
518  mutable DenseMap<Type, SmallVector<Type, 2>> cachedMultiConversions;
519  /// A mutex used for cache access
520  mutable llvm::sys::SmartRWMutex<true> cacheMutex;
521 };
522 
523 //===----------------------------------------------------------------------===//
524 // Conversion Patterns
525 //===----------------------------------------------------------------------===//
526 
527 /// Base class for the conversion patterns. This pattern class enables type
528 /// conversions, and other uses specific to the conversion framework. As such,
529 /// patterns of this type can only be used with the 'apply*' methods below.
531 public:
532  /// Hook for derived classes to implement rewriting. `op` is the (first)
533  /// operation matched by the pattern, `operands` is a list of the rewritten
534  /// operand values that are passed to `op`, `rewriter` can be used to emit the
535  /// new operations. This function should not fail. If some specific cases of
536  /// the operation are not supported, these cases should not be matched.
537  virtual void rewrite(Operation *op, ArrayRef<Value> operands,
538  ConversionPatternRewriter &rewriter) const {
539  llvm_unreachable("unimplemented rewrite");
540  }
541  virtual void rewrite(Operation *op, ArrayRef<ValueRange> operands,
542  ConversionPatternRewriter &rewriter) const {
543  rewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
544  }
545 
546  /// Hook for derived classes to implement combined matching and rewriting.
547  /// This overload supports only 1:1 replacements. The 1:N overload is called
548  /// by the driver. By default, it calls this 1:1 overload or reports a fatal
549  /// error if 1:N replacements were found.
550  virtual LogicalResult
552  ConversionPatternRewriter &rewriter) const {
553  if (failed(match(op)))
554  return failure();
555  rewrite(op, operands, rewriter);
556  return success();
557  }
558 
559  /// Hook for derived classes to implement combined matching and rewriting.
560  /// This overload supports 1:N replacements.
561  virtual LogicalResult
563  ConversionPatternRewriter &rewriter) const {
564  return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
565  }
566 
567  /// Attempt to match and rewrite the IR root at the specified operation.
568  LogicalResult matchAndRewrite(Operation *op,
569  PatternRewriter &rewriter) const final;
570 
571  /// Return the type converter held by this pattern, or nullptr if the pattern
572  /// does not require type conversion.
573  const TypeConverter *getTypeConverter() const { return typeConverter; }
574 
575  template <typename ConverterTy>
576  std::enable_if_t<std::is_base_of<TypeConverter, ConverterTy>::value,
577  const ConverterTy *>
579  return static_cast<const ConverterTy *>(typeConverter);
580  }
581 
582 protected:
583  /// See `RewritePattern::RewritePattern` for information on the other
584  /// available constructors.
585  using RewritePattern::RewritePattern;
586  /// Construct a conversion pattern with the given converter, and forward the
587  /// remaining arguments to RewritePattern.
588  template <typename... Args>
590  : RewritePattern(std::forward<Args>(args)...),
592 
593  /// Given an array of value ranges, which are the inputs to a 1:N adaptor,
594  /// try to extract the single value of each range to construct the inputs
595  /// for a 1:1 adaptor.
596  ///
597  /// This function produces a fatal error if at least one range has 0 or
598  /// more than 1 value: "pattern 'name' does not support 1:N conversion"
601 
602 protected:
603  /// An optional type converter for use by this pattern.
604  const TypeConverter *typeConverter = nullptr;
605 
606 private:
608 };
609 
610 /// OpConversionPattern is a wrapper around ConversionPattern that allows for
611 /// matching and rewriting against an instance of a derived operation class as
612 /// opposed to a raw Operation.
613 template <typename SourceOp>
615 public:
616  using OpAdaptor = typename SourceOp::Adaptor;
618  typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
619 
621  : ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
623  PatternBenefit benefit = 1)
624  : ConversionPattern(typeConverter, SourceOp::getOperationName(), benefit,
625  context) {}
626 
627  /// Wrappers around the ConversionPattern methods that pass the derived op
628  /// type.
629  LogicalResult match(Operation *op) const final {
630  return match(cast<SourceOp>(op));
631  }
632  void rewrite(Operation *op, ArrayRef<Value> operands,
633  ConversionPatternRewriter &rewriter) const final {
634  auto sourceOp = cast<SourceOp>(op);
635  rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
636  }
638  ConversionPatternRewriter &rewriter) const final {
639  auto sourceOp = cast<SourceOp>(op);
640  rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
641  }
642  LogicalResult
644  ConversionPatternRewriter &rewriter) const final {
645  auto sourceOp = cast<SourceOp>(op);
646  return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
647  }
648  LogicalResult
650  ConversionPatternRewriter &rewriter) const final {
651  auto sourceOp = cast<SourceOp>(op);
652  return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
653  rewriter);
654  }
655 
656  /// Rewrite and Match methods that operate on the SourceOp type. These must be
657  /// overridden by the derived pattern class.
658  virtual LogicalResult match(SourceOp op) const {
659  llvm_unreachable("must override match or matchAndRewrite");
660  }
661  virtual void rewrite(SourceOp op, OpAdaptor adaptor,
662  ConversionPatternRewriter &rewriter) const {
663  llvm_unreachable("must override matchAndRewrite or a rewrite method");
664  }
665  virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
666  ConversionPatternRewriter &rewriter) const {
667  SmallVector<Value> oneToOneOperands =
668  getOneToOneAdaptorOperands(adaptor.getOperands());
669  rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
670  }
671  virtual LogicalResult
672  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
673  ConversionPatternRewriter &rewriter) const {
674  if (failed(match(op)))
675  return failure();
676  rewrite(op, adaptor, rewriter);
677  return success();
678  }
679  virtual LogicalResult
680  matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
681  ConversionPatternRewriter &rewriter) const {
682  SmallVector<Value> oneToOneOperands =
683  getOneToOneAdaptorOperands(adaptor.getOperands());
684  return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
685  }
686 
687 private:
689 };
690 
691 /// OpInterfaceConversionPattern is a wrapper around ConversionPattern that
692 /// allows for matching and rewriting against an instance of an OpInterface
693 /// class as opposed to a raw Operation.
694 template <typename SourceOp>
696 public:
699  SourceOp::getInterfaceID(), benefit, context) {}
701  MLIRContext *context, PatternBenefit benefit = 1)
703  SourceOp::getInterfaceID(), benefit, context) {}
704 
705  /// Wrappers around the ConversionPattern methods that pass the derived op
706  /// type.
707  void rewrite(Operation *op, ArrayRef<Value> operands,
708  ConversionPatternRewriter &rewriter) const final {
709  rewrite(cast<SourceOp>(op), operands, rewriter);
710  }
712  ConversionPatternRewriter &rewriter) const final {
713  rewrite(cast<SourceOp>(op), operands, rewriter);
714  }
715  LogicalResult
717  ConversionPatternRewriter &rewriter) const final {
718  return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
719  }
720  LogicalResult
722  ConversionPatternRewriter &rewriter) const final {
723  return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
724  }
725 
726  /// Rewrite and Match methods that operate on the SourceOp type. These must be
727  /// overridden by the derived pattern class.
728  virtual void rewrite(SourceOp op, ArrayRef<Value> operands,
729  ConversionPatternRewriter &rewriter) const {
730  llvm_unreachable("must override matchAndRewrite or a rewrite method");
731  }
732  virtual void rewrite(SourceOp op, ArrayRef<ValueRange> operands,
733  ConversionPatternRewriter &rewriter) const {
734  rewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
735  }
736  virtual LogicalResult
737  matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
738  ConversionPatternRewriter &rewriter) const {
739  if (failed(match(op)))
740  return failure();
741  rewrite(op, operands, rewriter);
742  return success();
743  }
744  virtual LogicalResult
745  matchAndRewrite(SourceOp op, ArrayRef<ValueRange> operands,
746  ConversionPatternRewriter &rewriter) const {
747  return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
748  }
749 
750 private:
752 };
753 
754 /// OpTraitConversionPattern is a wrapper around ConversionPattern that allows
755 /// for matching and rewriting against instances of an operation that possess a
756 /// given trait.
757 template <template <typename> class TraitType>
759 public:
762  TypeID::get<TraitType>(), benefit, context) {}
764  MLIRContext *context, PatternBenefit benefit = 1)
766  TypeID::get<TraitType>(), benefit, context) {}
767 };
768 
769 /// Generic utility to convert op result types according to type converter
770 /// without knowing exact op type.
771 /// Clones existing op with new result types and returns it.
772 FailureOr<Operation *>
773 convertOpResultTypes(Operation *op, ValueRange operands,
774  const TypeConverter &converter,
775  ConversionPatternRewriter &rewriter);
776 
777 /// Add a pattern to the given pattern list to convert the signature of a
778 /// FunctionOpInterface op with the given type converter. This only supports
779 /// ops which use FunctionType to represent their type.
781  StringRef functionLikeOpName, RewritePatternSet &patterns,
782  const TypeConverter &converter);
783 
784 template <typename FuncOpT>
786  RewritePatternSet &patterns, const TypeConverter &converter) {
787  populateFunctionOpInterfaceTypeConversionPattern(FuncOpT::getOperationName(),
788  patterns, converter);
789 }
790 
792  RewritePatternSet &patterns, const TypeConverter &converter);
793 
794 //===----------------------------------------------------------------------===//
795 // Conversion PatternRewriter
796 //===----------------------------------------------------------------------===//
797 
798 namespace detail {
799 struct ConversionPatternRewriterImpl;
800 } // namespace detail
801 
802 /// This class implements a pattern rewriter for use with ConversionPatterns. It
803 /// extends the base PatternRewriter and provides special conversion specific
804 /// hooks.
806 public:
808 
809  /// Apply a signature conversion to given block. This replaces the block with
810  /// a new block containing the updated signature. The operations of the given
811  /// block are inlined into the newly-created block, which is returned.
812  ///
813  /// If no block argument types are changing, the original block will be
814  /// left in place and returned.
815  ///
816  /// A signature conversion must be provided. (Type converters can construct
817  /// a signature conversion with `convertBlockSignature`.)
818  ///
819  /// Optionally, a type converter can be provided to build materializations.
820  /// Note: If no type converter was provided or the type converter does not
821  /// specify any suitable argument/target materialization rules, the dialect
822  /// conversion may fail to legalize unresolved materializations.
823  Block *
826  const TypeConverter *converter = nullptr);
827 
828  /// Apply a signature conversion to each block in the given region. This
829  /// replaces each block with a new block containing the updated signature. If
830  /// an updated signature would match the current signature, the respective
831  /// block is left in place as is. (See `applySignatureConversion` for
832  /// details.) The new entry block of the region is returned.
833  ///
834  /// SignatureConversions are computed with the specified type converter.
835  /// This function returns "failure" if the type converter failed to compute
836  /// a SignatureConversion for at least one block.
837  ///
838  /// Optionally, a special SignatureConversion can be specified for the entry
839  /// block. This is because the types of the entry block arguments are often
840  /// tied semantically to the operation.
841  FailureOr<Block *> convertRegionTypes(
842  Region *region, const TypeConverter &converter,
843  TypeConverter::SignatureConversion *entryConversion = nullptr);
844 
845  /// Replace all the uses of the block argument `from` with value `to`.
847 
848  /// Return the converted value of 'key' with a type defined by the type
849  /// converter of the currently executing pattern. Return nullptr in the case
850  /// of failure, the remapped value otherwise.
852 
853  /// Return the converted values that replace 'keys' with types defined by the
854  /// type converter of the currently executing pattern. Returns failure if the
855  /// remap failed, success otherwise.
856  LogicalResult getRemappedValues(ValueRange keys,
857  SmallVectorImpl<Value> &results);
858 
859  //===--------------------------------------------------------------------===//
860  // PatternRewriter Hooks
861  //===--------------------------------------------------------------------===//
862 
863  /// Indicate that the conversion rewriter can recover from rewrite failure.
864  /// Recovery is supported via rollback, allowing for continued processing of
865  /// patterns even if a failure is encountered during the rewrite step.
866  bool canRecoverFromRewriteFailure() const override { return true; }
867 
868  /// Replace the given operation with the new values. The number of op results
869  /// and replacement values must match. The types may differ: the dialect
870  /// conversion driver will reconcile any surviving type mismatches at the end
871  /// of the conversion process with source materializations. The given
872  /// operation is erased.
873  void replaceOp(Operation *op, ValueRange newValues) override;
874 
875  /// Replace the given operation with the results of the new op. The number of
876  /// op results must match. The types may differ: the dialect conversion
877  /// driver will reconcile any surviving type mismatches at the end of the
878  /// conversion process with source materializations. The original operation
879  /// is erased.
880  void replaceOp(Operation *op, Operation *newOp) override;
881 
882  /// Replace the given operation with the new value ranges. The number of op
883  /// results and value ranges must match. If an original SSA value is replaced
884  /// by multiple SSA values (i.e., a value range has more than 1 element), the
885  /// conversion driver will insert an argument materialization to convert the
886  /// N SSA values back into 1 SSA value of the original type. The given
887  /// operation is erased.
888  ///
889  /// Note: The argument materialization is a workaround until we have full 1:N
890  /// support in the dialect conversion. (It is going to disappear from both
891  /// `replaceOpWithMultiple` and `applySignatureConversion`.)
893 
894  /// PatternRewriter hook for erasing a dead operation. The uses of this
895  /// operation *must* be made dead by the end of the conversion process,
896  /// otherwise an assert will be issued.
897  void eraseOp(Operation *op) override;
898 
899  /// PatternRewriter hook for erasing all operations in a block. This is not yet
900  /// implemented for dialect conversion.
901  void eraseBlock(Block *block) override;
902 
903  /// PatternRewriter hook for inlining the ops of a block into another block.
904  void inlineBlockBefore(Block *source, Block *dest, Block::iterator before,
905  ValueRange argValues = std::nullopt) override;
907 
908  /// PatternRewriter hook for updating the given operation in-place.
909  /// Note: These methods only track updates to the given operation itself,
910  /// and not nested regions. Updates to regions will still require notification
911  /// through other more specific hooks above.
912  void startOpModification(Operation *op) override;
913 
914  /// PatternRewriter hook for updating the given operation in-place.
915  void finalizeOpModification(Operation *op) override;
916 
917  /// PatternRewriter hook for updating the given operation in-place.
918  void cancelOpModification(Operation *op) override;
919 
920  /// Return a reference to the internal implementation.
922 
923 private:
924  // Allow OperationConverter to construct new rewriters.
925  friend struct OperationConverter;
926 
927  /// Conversion pattern rewriters must not be used outside of dialect
928  /// conversions. They apply some IR rewrites in a delayed fashion and could
929  /// bring the IR into an inconsistent state when used standalone.
931  const ConversionConfig &config);
932 
933  // Hide unsupported pattern rewriter API.
935 
936  std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
937 };
938 
939 //===----------------------------------------------------------------------===//
940 // ConversionTarget
941 //===----------------------------------------------------------------------===//
942 
943 /// This class describes a specific conversion target.
945 public:
946  /// This enumeration corresponds to the specific action to take when
947  /// considering an operation legal for this conversion target.
948  enum class LegalizationAction {
949  /// The target supports this operation.
950  Legal,
951 
952  /// This operation has dynamic legalization constraints that must be checked
953  /// by the target.
954  Dynamic,
955 
956  /// The target explicitly does not support this operation.
957  Illegal,
958  };
959 
960  /// A structure containing additional information describing a specific legal
961  /// operation instance.
962  struct LegalOpDetails {
963  /// A flag that indicates if this operation is 'recursively' legal. This
964  /// means that if an operation is legal, either statically or dynamically,
965  /// all of the operations nested within are also considered legal.
966  bool isRecursivelyLegal = false;
967  };
968 
969  /// The signature of the callback used to determine if an operation is
970  /// dynamically legal on the target.
972  std::function<std::optional<bool>(Operation *)>;
973 
974  ConversionTarget(MLIRContext &ctx) : ctx(ctx) {}
975  virtual ~ConversionTarget() = default;
976 
977  //===--------------------------------------------------------------------===//
978  // Legality Registration
979  //===--------------------------------------------------------------------===//
980 
981  /// Register a legality action for the given operation.
983  template <typename OpT>
985  setOpAction(OperationName(OpT::getOperationName(), &ctx), action);
986  }
987 
988  /// Register the given operations as legal.
991  }
992  template <typename OpT>
993  void addLegalOp() {
994  addLegalOp(OperationName(OpT::getOperationName(), &ctx));
995  }
996  template <typename OpT, typename OpT2, typename... OpTs>
997  void addLegalOp() {
998  addLegalOp<OpT>();
999  addLegalOp<OpT2, OpTs...>();
1000  }
1001 
1002  /// Register the given operation as dynamically legal and set the dynamic
1003  /// legalization callback to the one provided.
1005  const DynamicLegalityCallbackFn &callback) {
1007  setLegalityCallback(op, callback);
1008  }
1009  template <typename OpT>
1011  addDynamicallyLegalOp(OperationName(OpT::getOperationName(), &ctx),
1012  callback);
1013  }
1014  template <typename OpT, typename OpT2, typename... OpTs>
1016  addDynamicallyLegalOp<OpT>(callback);
1017  addDynamicallyLegalOp<OpT2, OpTs...>(callback);
1018  }
1019  template <typename OpT, class Callable>
1020  std::enable_if_t<!std::is_invocable_v<Callable, Operation *>>
1021  addDynamicallyLegalOp(Callable &&callback) {
1022  addDynamicallyLegalOp<OpT>(
1023  [=](Operation *op) { return callback(cast<OpT>(op)); });
1024  }
1025 
1026  /// Register the given operation as illegal, i.e. this operation is known to
1027  /// not be supported by this target.
1030  }
1031  template <typename OpT>
1032  void addIllegalOp() {
1033  addIllegalOp(OperationName(OpT::getOperationName(), &ctx));
1034  }
1035  template <typename OpT, typename OpT2, typename... OpTs>
1036  void addIllegalOp() {
1037  addIllegalOp<OpT>();
1038  addIllegalOp<OpT2, OpTs...>();
1039  }
1040 
1041  /// Mark an operation, that *must* have either been set as `Legal` or
1042  /// `DynamicallyLegal`, as being recursively legal. This means that in
1043  /// addition to the operation itself, all of the operations nested within are
1044  /// also considered legal. An optional dynamic legality callback may be
1045  /// provided to mark subsets of legal instances as recursively legal.
1047  const DynamicLegalityCallbackFn &callback);
1048  template <typename OpT>
1050  OperationName opName(OpT::getOperationName(), &ctx);
1051  markOpRecursivelyLegal(opName, callback);
1052  }
1053  template <typename OpT, typename OpT2, typename... OpTs>
1055  markOpRecursivelyLegal<OpT>(callback);
1056  markOpRecursivelyLegal<OpT2, OpTs...>(callback);
1057  }
1058  template <typename OpT, class Callable>
1059  std::enable_if_t<!std::is_invocable_v<Callable, Operation *>>
1060  markOpRecursivelyLegal(Callable &&callback) {
1061  markOpRecursivelyLegal<OpT>(
1062  [=](Operation *op) { return callback(cast<OpT>(op)); });
1063  }
1064 
1065  /// Register a legality action for the given dialects.
1066  void setDialectAction(ArrayRef<StringRef> dialectNames,
1067  LegalizationAction action);
1068 
1069  /// Register the operations of the given dialects as legal.
1070  template <typename... Names>
1071  void addLegalDialect(StringRef name, Names... names) {
1072  SmallVector<StringRef, 2> dialectNames({name, names...});
1074  }
1075  template <typename... Args>
1077  SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
1079  }
1080 
1081  /// Register the operations of the given dialects as dynamically legal, i.e.
1082  /// requiring custom handling by the callback.
1083  template <typename... Names>
1085  StringRef name, Names... names) {
1086  SmallVector<StringRef, 2> dialectNames({name, names...});
1088  setLegalityCallback(dialectNames, callback);
1089  }
1090  template <typename... Args>
1092  addDynamicallyLegalDialect(std::move(callback),
1093  Args::getDialectNamespace()...);
1094  }
1095 
1096  /// Register unknown operations as dynamically legal. For operations (and
1097  /// dialects) that do not have a set legalization action, treat them as
1098  /// dynamically legal and invoke the given callback.
1100  setLegalityCallback(fn);
1101  }
1102 
1103  /// Register the operations of the given dialects as illegal, i.e.
1104  /// operations of this dialect are not supported by the target.
1105  template <typename... Names>
1106  void addIllegalDialect(StringRef name, Names... names) {
1107  SmallVector<StringRef, 2> dialectNames({name, names...});
1109  }
1110  template <typename... Args>
1112  SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
1114  }
1115 
1116  //===--------------------------------------------------------------------===//
1117  // Legality Querying
1118  //===--------------------------------------------------------------------===//
1119 
1120  /// Get the legality action for the given operation.
1121  std::optional<LegalizationAction> getOpAction(OperationName op) const;
1122 
1123  /// If the given operation instance is legal on this target, a structure
1124  /// containing legality information is returned. If the operation is not
1125  /// legal, std::nullopt is returned. Also returns std::nullopt if the operation's
1126  /// legality wasn't registered by the user or a dynamic legality callback
1127  /// returned std::nullopt.
1128  ///
1129  /// Note: Legality is actually a 4-state: Legal(recursive=true),
1130  /// Legal(recursive=false), Illegal or Unknown, where Unknown is treated
1131  /// either as Legal or Illegal depending on context.
1132  std::optional<LegalOpDetails> isLegal(Operation *op) const;
1133 
1134  /// Returns true if the operation instance is illegal on this target. Returns
1135  /// false if the operation is legal, its legality wasn't registered by the user,
1136  /// or a dynamic legality callback returned std::nullopt.
1137  bool isIllegal(Operation *op) const;
1138 
1139 private:
1140  /// Set the dynamic legality callback for the given operation.
1141  void setLegalityCallback(OperationName name,
1142  const DynamicLegalityCallbackFn &callback);
1143 
1144  /// Set the dynamic legality callback for the given dialects.
1145  void setLegalityCallback(ArrayRef<StringRef> dialects,
1146  const DynamicLegalityCallbackFn &callback);
1147 
1148  /// Set the dynamic legality callback for the unknown ops.
1149  void setLegalityCallback(const DynamicLegalityCallbackFn &callback);
1150 
1151  /// The set of information that configures the legalization of an operation.
1152  struct LegalizationInfo {
1153  /// The legality action this operation was given.
1155 
1156  /// If some legal instances of this operation may also be recursively legal.
1157  bool isRecursivelyLegal = false;
1158 
1159  /// The legality callback if this operation is dynamically legal.
1160  DynamicLegalityCallbackFn legalityFn;
1161  };
1162 
1163  /// Get the legalization information for the given operation.
1164  std::optional<LegalizationInfo> getOpInfo(OperationName op) const;
1165 
1166  /// A deterministic mapping of operation name and its respective legality
1167  /// information.
1168  llvm::MapVector<OperationName, LegalizationInfo> legalOperations;
1169 
1170  /// A set of legality callbacks for given operation names that are used to
1171  /// check if an operation instance is recursively legal.
1172  DenseMap<OperationName, DynamicLegalityCallbackFn> opRecursiveLegalityFns;
1173 
1174  /// A deterministic mapping of dialect name to the specific legality action to
1175  /// take.
1176  llvm::StringMap<LegalizationAction> legalDialects;
1177 
1178  /// A set of dynamic legality callbacks for given dialect names.
1179  llvm::StringMap<DynamicLegalityCallbackFn> dialectLegalityFns;
1180 
1181  /// An optional legality callback for unknown operations.
1182  DynamicLegalityCallbackFn unknownLegalityFn;
1183 
1184  /// The current context this target applies to.
1185  MLIRContext &ctx;
1186 };
1187 
1188 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
1189 //===----------------------------------------------------------------------===//
1190 // PDL Configuration
1191 //===----------------------------------------------------------------------===//
1192 
1193 /// A PDL configuration that is used to support dialect conversion
1194 /// functionality.
1196  : public PDLPatternConfigBase<PDLConversionConfig> {
1197 public:
1198  PDLConversionConfig(const TypeConverter *converter) : converter(converter) {}
1199  ~PDLConversionConfig() final = default;
1200 
1201  /// Return the type converter used by this configuration, which may be nullptr
1202  /// if no type conversions are expected.
1203  const TypeConverter *getTypeConverter() const { return converter; }
1204 
1205  /// Hooks that are invoked at the beginning and end of a rewrite of a matched
1206  /// pattern.
1207  void notifyRewriteBegin(PatternRewriter &rewriter) final;
1208  void notifyRewriteEnd(PatternRewriter &rewriter) final;
1209 
1210 private:
1211  /// An optional type converter to use for the pattern.
1212  const TypeConverter *converter;
1213 };
1214 
1215 /// Register the dialect conversion PDL functions with the given pattern set.
1216 void registerConversionPDLFunctions(RewritePatternSet &patterns);
1217 
1218 #else
1219 
1220 // Stubs for when PDL in rewriting is not enabled.
1221 
1222 inline void registerConversionPDLFunctions(RewritePatternSet &patterns) {}
1223 
1224 class PDLConversionConfig final {
1225 public:
1226  PDLConversionConfig(const TypeConverter * /*converter*/) {}
1227 };
1228 
1229 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
1230 
1231 //===----------------------------------------------------------------------===//
1232 // ConversionConfig
1233 //===----------------------------------------------------------------------===//
1234 
1235 /// Dialect conversion configuration.
1237  /// An optional callback used to notify about match failure diagnostics during
1238  /// the conversion. Diagnostics reported to this callback may only be
1239  /// available in debug mode.
1241 
1242  /// Partial conversion only. All operations that are found not to be
1243  /// legalizable are placed in this set. (Note that if there is an op
1244  /// explicitly marked as illegal, the conversion terminates and the set will
1245  /// not necessarily be complete.)
1247 
1248  /// Analysis conversion only. All operations that are found to be legalizable
1249  /// are placed in this set. Note that no actual rewrites are applied to the
1250  /// IR during an analysis conversion and only pre-existing operations are
1251  /// added to the set.
1253 
1254  /// An optional listener that is notified about all IR modifications in case
1255  /// dialect conversion succeeds. If the dialect conversion fails and no IR
1256  /// modifications are visible (i.e., they were all rolled back), or if the
1257  /// dialect conversion is an "analysis conversion", no notifications are
1258  /// sent (apart from `notifyPatternBegin`/`notifyPatternEnd`).
1259  ///
1260  /// Note: Notifications are sent in a delayed fashion, when the dialect
1261  /// conversion is guaranteed to succeed. At that point, some IR modifications
1262  /// may already have been materialized. Consequently, operations/blocks that
1263  /// are passed to listener callbacks should not be accessed. (Ops/blocks are
1264  /// guaranteed to be valid pointers and accessing op names is allowed. But
1265  /// there are no guarantees about the state of ops/blocks at the time that a
1266  /// callback is triggered.)
1267  ///
1268  /// Example: Consider a dialect conversion in which a new op ("test.foo") is created
1269  /// and inserted, and later moved to another block. (Moving ops also triggers
1270  /// "notifyOperationInserted".)
1271  ///
1272  /// (1) notifyOperationInserted: "test.foo" (into block "b1")
1273  /// (2) notifyOperationInserted: "test.foo" (moved to another block "b2")
1274  ///
1275  /// When querying "op->getBlock()" during the first "notifyOperationInserted",
1276  /// "b2" would be returned because "moving an op" is a kind of rewrite that is
1277  /// immediately performed by the dialect conversion (and rolled back upon
1278  /// failure).
1279  //
1280  // Note: When receiving a "notifyBlockInserted"/"notifyOperationInserted"
1281  // callback, the previous region/block is provided to the callback, but not
1282  // the iterator pointing to the exact location within the region/block. That
1283  // is because these notifications are sent with a delay (after the IR has
1284  // already been modified) and iterators into past IR state cannot be
1285  // represented at the moment.
1287 
1288  /// If set to "true", the dialect conversion attempts to build source/target/
1289  /// argument materializations through the type converter API in lieu of
1290  /// "builtin.unrealized_conversion_cast" ops. The conversion process fails if
1291  /// at least one materialization could not be built.
1292  ///
1293  /// If set to "false", the dialect conversion does not build any custom
1294  /// materializations and instead inserts "builtin.unrealized_conversion_cast"
1295  /// ops to ensure that the resulting IR is valid.
1297 };
1298 
1299 //===----------------------------------------------------------------------===//
1300 // Reconcile Unrealized Casts
1301 //===----------------------------------------------------------------------===//
1302 
1303 /// Try to reconcile all given UnrealizedConversionCastOps and store the
1304 /// left-over ops in `remainingCastOps` (if provided).
1305 ///
1306 /// This function processes cast ops in a worklist-driven fashion. For each
1307 /// cast op, if the chain of input casts eventually reaches a cast op where the
1308 /// input types match the output types of the matched op, replace the matched
1309 /// op with the inputs.
1310 ///
1311 /// Example:
1312 /// %1 = unrealized_conversion_cast %0 : !A to !B
1313 /// %2 = unrealized_conversion_cast %1 : !B to !C
1314 /// %3 = unrealized_conversion_cast %2 : !C to !A
1315 ///
1316 /// In the above example, %0 can be used instead of %3 and all cast ops are
1317 /// folded away.
1320  SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps = nullptr);
1321 
1322 //===----------------------------------------------------------------------===//
1323 // Op Conversion Entry Points
1324 //===----------------------------------------------------------------------===//
1325 
1326 /// Below we define several entry points for operation conversion. It is
1327 /// important to note that the patterns provided to the conversion framework may
1328 /// have additional constraints. See the `PatternRewriter Hooks` section of the
1329 /// ConversionPatternRewriter, to see what additional constraints are imposed on
1330 /// the use of the PatternRewriter.
1331 
1332 /// Apply a partial conversion on the given operations and all nested
1333 /// operations. This method converts as many operations to the target as
1334 /// possible, ignoring operations that failed to legalize. This method only
1335 /// returns failure if there are ops explicitly marked as illegal.
1336 LogicalResult
1338  const ConversionTarget &target,
1341 LogicalResult
1345 
1346 /// Apply a complete conversion on the given operations, and all nested
1347 /// operations. This method returns failure if the conversion of any operation
1348 /// fails, or if there are unreachable blocks in any of the regions nested
1349 /// within 'ops'.
1350 LogicalResult applyFullConversion(ArrayRef<Operation *> ops,
1351  const ConversionTarget &target,
1354 LogicalResult applyFullConversion(Operation *op, const ConversionTarget &target,
1357 
1358 /// Apply an analysis conversion on the given operations, and all nested
1359 /// operations. This method analyzes which operations would be successfully
1360 /// converted to the target if a conversion was applied. All operations that
1361 /// were found to be legalizable to the given 'target' are placed within the
1362 /// provided 'config.legalizableOps' set; note that no actual rewrites are
1363 /// applied to the operations on success. This method only returns failure if
1364 /// there are unreachable blocks in any of the regions nested within 'ops'.
1365 LogicalResult
1369 LogicalResult
1373 } // namespace mlir
1374 
1375 #endif // MLIR_TRANSFORMS_DIALECTCONVERSION_H_
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
Block represents an ordered list of Operations.
Definition: Block.h:33
OpListType::iterator iterator
Definition: Block.h:140
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
LogicalResult getRemappedValues(ValueRange keys, SmallVectorImpl< Value > &results)
Return the converted values that replace 'keys' with types defined by the type converter of the curre...
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt) override
PatternRewriter hook for inlining the ops of a block into another block.
Block * applySignatureConversion(Block *block, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter=nullptr)
Apply a signature conversion to given block.
void startOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void replaceOpWithMultiple(Operation *op, ArrayRef< ValueRange > newValues)
Replace the given operation with the new value ranges.
detail::ConversionPatternRewriterImpl & getImpl()
Return a reference to the internal implementation.
void eraseBlock(Block *block) override
PatternRewriter hook for erasing all operations in a block.
void cancelOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
bool canRecoverFromRewriteFailure() const override
Indicate that the conversion rewriter can recover from rewrite failure.
Value getRemappedValue(Value key)
Return the converted value of 'key' with a type defined by the type converter of the currently execut...
void finalizeOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void replaceUsesOfBlockArgument(BlockArgument from, Value to)
Replace all the uses of the block argument from with value to.
Base class for the conversion patterns.
SmallVector< Value > getOneToOneAdaptorOperands(ArrayRef< ValueRange > operands) const
Given an array of value ranges, which are the inputs to a 1:N adaptor, try to extract the single valu...
virtual void rewrite(Operation *op, ArrayRef< ValueRange > operands, ConversionPatternRewriter &rewriter) const
const TypeConverter * typeConverter
An optional type converter for use by this pattern.
ConversionPattern(const TypeConverter &typeConverter, Args &&...args)
Construct a conversion pattern with the given converter, and forward the remaining arguments to Rewri...
virtual void rewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement rewriting.
const TypeConverter * getTypeConverter() const
Return the type converter held by this pattern, or nullptr if the pattern does not require type conve...
std::enable_if_t< std::is_base_of< TypeConverter, ConverterTy >::value, const ConverterTy * > getTypeConverter() const
virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement combined matching and rewriting.
virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef< ValueRange > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement combined matching and rewriting.
This class describes a specific conversion target.
void setDialectAction(ArrayRef< StringRef > dialectNames, LegalizationAction action)
Register a legality action for the given dialects.
void setOpAction(LegalizationAction action)
void addLegalOp(OperationName op)
Register the given operations as legal.
void addDynamicallyLegalDialect(DynamicLegalityCallbackFn callback)
void setOpAction(OperationName op, LegalizationAction action)
Register a legality action for the given operation.
void addDynamicallyLegalDialect(const DynamicLegalityCallbackFn &callback, StringRef name, Names... names)
Register the operations of the given dialects as dynamically legal, i.e.
void addDynamicallyLegalOp(const DynamicLegalityCallbackFn &callback)
void addLegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as legal.
std::optional< LegalOpDetails > isLegal(Operation *op) const
If the given operation instance is legal on this target, a structure containing legality information ...
std::enable_if_t<!std::is_invocable_v< Callable, Operation * > > markOpRecursivelyLegal(Callable &&callback)
void markUnknownOpDynamicallyLegal(const DynamicLegalityCallbackFn &fn)
Register unknown operations as dynamically legal.
std::optional< LegalizationAction > getOpAction(OperationName op) const
Get the legality action for the given operation.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the on...
void addIllegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as illegal, i.e.
std::enable_if_t<!std::is_invocable_v< Callable, Operation * > > addDynamicallyLegalOp(Callable &&callback)
void addDynamicallyLegalOp(const DynamicLegalityCallbackFn &callback)
void markOpRecursivelyLegal(const DynamicLegalityCallbackFn &callback={})
void addIllegalOp(OperationName op)
Register the given operation as illegal, i.e.
LegalizationAction
This enumeration corresponds to the specific action to take when considering an operation legal for t...
@ Illegal
The target explicitly does not support this operation.
@ Dynamic
This operation has dynamic legalization constraints that must be checked by the target.
@ Legal
The target supports this operation.
ConversionTarget(MLIRContext &ctx)
void markOpRecursivelyLegal(OperationName name, const DynamicLegalityCallbackFn &callback)
Mark an operation, that must have either been set as Legal or DynamicallyLegal, as being recursively ...
void markOpRecursivelyLegal(const DynamicLegalityCallbackFn &callback={})
std::function< std::optional< bool >(Operation *)> DynamicLegalityCallbackFn
The signature of the callback used to determine if an operation is dynamically legal on the target.
bool isIllegal(Operation *op) const
Returns true if the operation instance is illegal on this target.
virtual ~ConversionTarget()=default
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:216
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
Definition: Builders.h:325
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
void rewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Hook for derived classes to implement rewriting.
LogicalResult match(Operation *op) const final
Wrappers around the ConversionPattern methods that pass the derived op type.
void rewrite(Operation *op, ArrayRef< ValueRange > operands, ConversionPatternRewriter &rewriter) const final
LogicalResult matchAndRewrite(Operation *op, ArrayRef< ValueRange > operands, ConversionPatternRewriter &rewriter) const final
Hook for derived classes to implement combined matching and rewriting.
typename SourceOp::Adaptor OpAdaptor
virtual LogicalResult matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
typename SourceOp::template GenericAdaptor< ArrayRef< ValueRange > > OneToNOpAdaptor
virtual LogicalResult match(SourceOp op) const
Rewrite and Match methods that operate on the SourceOp type.
OpConversionPattern(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit=1)
virtual LogicalResult matchAndRewrite(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const
virtual void rewrite(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const
LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Hook for derived classes to implement combined matching and rewriting.
virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const
OpInterfaceConversionPattern is a wrapper around ConversionPattern that allows for matching and rewri...
void rewrite(Operation *op, ArrayRef< ValueRange > operands, ConversionPatternRewriter &rewriter) const final
LogicalResult matchAndRewrite(Operation *op, ArrayRef< ValueRange > operands, ConversionPatternRewriter &rewriter) const final
Hook for derived classes to implement combined matching and rewriting.
void rewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Wrappers around the ConversionPattern methods that pass the derived op type.
virtual LogicalResult matchAndRewrite(SourceOp op, ArrayRef< ValueRange > operands, ConversionPatternRewriter &rewriter) const
LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Hook for derived classes to implement combined matching and rewriting.
OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
OpInterfaceConversionPattern(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit=1)
virtual void rewrite(SourceOp op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Rewrite and Match methods that operate on the SourceOp type.
virtual void rewrite(SourceOp op, ArrayRef< ValueRange > operands, ConversionPatternRewriter &rewriter) const
virtual LogicalResult matchAndRewrite(SourceOp op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
OpTraitConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting...
OpTraitConversionPattern(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit=1)
OpTraitConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
A PDL configuration that is used to support dialect conversion functionality.
void notifyRewriteEnd(PatternRewriter &rewriter) final
void notifyRewriteBegin(PatternRewriter &rewriter) final
Hooks that are invoked at the beginning and end of a rewrite of a matched pattern.
const TypeConverter * getTypeConverter() const
Return the type converter used by this configuration, which may be nullptr if no type conversions are...
PDLConversionConfig(const TypeConverter *converter)
~PDLConversionConfig() final=default
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:73
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:246
virtual LogicalResult match(Operation *op) const
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
virtual void rewrite(Operation *op, PatternRewriter &rewriter) const
Rewrite the IR rooted at the specified operation with the result of this pattern, generating any new ...
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
The general result of a type attribute conversion callback, allowing for early termination.
static AttributeConversionResult result(Attribute attr)
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
std::optional< InputMapping > getInputMapping(unsigned input) const
Get the input mapping for the given argument.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
void remapInput(unsigned origInputNo, Value replacement)
Remap an input of the original signature to another replacement value.
Type conversion class.
std::optional< Attribute > convertTypeAttribute(Type type, Attribute attr) const
Convert an attribute `attr` present within the type `type` using the registered conversion functions...
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
void addConversion(FnT &&callback)
Register a conversion function.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
void addArgumentMaterialization(FnT &&callback)
All of the following materializations require function objects that are convertible to the following ...
LogicalResult convertSignatureArgs(TypeRange types, SignatureConversion &result, unsigned origInputOffset=0) const
TypeConverter()=default
std::enable_if_t<!std::is_convertible< RangeT, Type >::value &&!std::is_convertible< RangeT, Operation * >::value, bool > isLegal(RangeT &&range) const
Return true if all of the given types are legal for this type converter.
TargetType convertType(Type t) const
Attempts a 1-1 type conversion, expecting the result type to be TargetType.
LogicalResult convertSignatureArg(unsigned inputNo, Type type, SignatureConversion &result) const
This method allows for converting a specific argument of a signature.
TypeConverter(const TypeConverter &other)
TypeConverter & operator=(const TypeConverter &other)
Value materializeArgumentConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
Materialize a conversion from a set of types into one result type by generating a cast sequence of so...
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
std::optional< SignatureConversion > convertBlockSignature(Block *block) const
This function converts the type signature of the given block, by invoking 'convertSignatureArg' for e...
void addSourceMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a replacement value back ...
virtual ~TypeConverter()=default
void addTypeAttributeConversion(FnT &&callback)
Register a conversion function for attributes within types.
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a value to a target type ...
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
bool isSignatureLegal(FunctionType ty) const
Return true if the inputs and outputs of the given function type are legal.
Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs, Type originalType={}) const
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:104
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
@ Type
An inlay hint for a type annotation.
Include the generated interface declarations.
void populateFunctionOpInterfaceTypeConversionPattern(StringRef functionLikeOpName, RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to convert the signature of a FunctionOpInterface op with the...
void populateAnyFunctionOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
FailureOr< Operation * > convertOpResultTypes(Operation *op, ValueRange operands, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Generic utility to convert op result types according to type converter without knowing exact op type.
const FrozenRewritePatternSet & patterns
LogicalResult applyAnalysisConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply an analysis conversion on the given operations, and all nested operations.
void reconcileUnrealizedCasts(ArrayRef< UnrealizedConversionCastOp > castOps, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps=nullptr)
Try to reconcile all given UnrealizedConversionCastOps and store the left-over ops in remainingCastOp...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void registerConversionPDLFunctions(RewritePatternSet &patterns)
Register the dialect conversion PDL functions with the given pattern set.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
Dialect conversion configuration.
RewriterBase::Listener * listener
An optional listener that is notified about all IR modifications in case dialect conversion succeeds.
function_ref< void(Diagnostic &)> notifyCallback
An optional callback used to notify about match failure diagnostics during the conversion.
DenseSet< Operation * > * legalizableOps
Analysis conversion only.
DenseSet< Operation * > * unlegalizedOps
Partial conversion only.
bool buildMaterializations
If set to "true", the dialect conversion attempts to build source/target/argument materializations t...
A structure containing additional information describing a specific legal operation instance.
bool isRecursivelyLegal
A flag that indicates if this operation is 'recursively' legal.
This class acts as a special tag that makes the desire to match any operation that implements a given...
Definition: PatternMatch.h:164
This class acts as a special tag that makes the desire to match any operation that implements a given...
Definition: PatternMatch.h:169
This struct represents a range of new types or a single value that remaps an existing signature input...