MLIR  22.0.0git
DialectConversion.h
Go to the documentation of this file.
1 //===- DialectConversion.h - MLIR dialect conversion pass -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file declares a generic pass for converting between MLIR dialects.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_TRANSFORMS_DIALECTCONVERSION_H_
14 #define MLIR_TRANSFORMS_DIALECTCONVERSION_H_
15 
16 #include "mlir/Config/mlir-config.h"
18 #include "llvm/ADT/MapVector.h"
19 #include "llvm/ADT/StringMap.h"
20 #include <type_traits>
21 
22 namespace mlir {
23 
24 // Forward declarations.
25 class Attribute;
26 class Block;
27 struct ConversionConfig;
28 class ConversionPatternRewriter;
29 class MLIRContext;
30 class Operation;
31 struct OperationConverter;
32 class Type;
33 class Value;
34 
35 //===----------------------------------------------------------------------===//
36 // Type Conversion
37 //===----------------------------------------------------------------------===//
38 
39 /// Type conversion class. Specific conversions and materializations can be
40 /// registered using addConversion and addMaterialization, respectively.
42 public:
43  virtual ~TypeConverter() = default;
44  TypeConverter() = default;
45  // Copy the registered conversions, but not the caches
47  : conversions(other.conversions),
48  sourceMaterializations(other.sourceMaterializations),
49  targetMaterializations(other.targetMaterializations),
50  typeAttributeConversions(other.typeAttributeConversions) {}
52  conversions = other.conversions;
53  sourceMaterializations = other.sourceMaterializations;
54  targetMaterializations = other.targetMaterializations;
55  typeAttributeConversions = other.typeAttributeConversions;
56  return *this;
57  }
58 
59  /// This class provides all of the information necessary to convert a type
60  /// signature.
62  public:
63  SignatureConversion(unsigned numOrigInputs)
64  : remappedInputs(numOrigInputs) {}
65 
66  /// This struct represents a range of new types or a range of values that
67  /// remaps an existing signature input.
68  struct InputMapping {
69  size_t inputNo, size;
71 
72  /// Return "true" if this input was replaces with one or multiple values.
73  bool replacedWithValues() const { return !replacementValues.empty(); }
74  };
75 
76  /// Return the argument types for the new signature.
77  ArrayRef<Type> getConvertedTypes() const { return argTypes; }
78 
79  /// Get the input mapping for the given argument.
80  std::optional<InputMapping> getInputMapping(unsigned input) const {
81  return remappedInputs[input];
82  }
83 
84  //===------------------------------------------------------------------===//
85  // Conversion Hooks
86  //===------------------------------------------------------------------===//
87 
88  /// Remap an input of the original signature with a new set of types. The
89  /// new types are appended to the new signature conversion.
90  void addInputs(unsigned origInputNo, ArrayRef<Type> types);
91 
92  /// Append new input types to the signature conversion, this should only be
93  /// used if the new types are not intended to remap an existing input.
94  void addInputs(ArrayRef<Type> types);
95 
96  /// Remap an input of the original signature to `replacements`
97  /// values. This drops the original argument.
98  void remapInput(unsigned origInputNo, ArrayRef<Value> replacements);
99 
100  private:
101  /// Remap an input of the original signature with a range of types in the
102  /// new signature.
103  void remapInput(unsigned origInputNo, unsigned newInputNo,
104  unsigned newInputCount = 1);
105 
106  /// The remapping information for each of the original arguments.
107  SmallVector<std::optional<InputMapping>, 4> remappedInputs;
108 
109  /// The set of new argument types.
110  SmallVector<Type, 4> argTypes;
111  };
112 
113  /// The general result of a type attribute conversion callback, allowing
114  /// for early termination. The default constructor creates the na case.
116  public:
117  constexpr AttributeConversionResult() : impl() {}
118  AttributeConversionResult(Attribute attr) : impl(attr, resultTag) {}
119 
121  static AttributeConversionResult na();
123 
124  bool hasResult() const;
125  bool isNa() const;
126  bool isAbort() const;
127 
128  Attribute getResult() const;
129 
130  private:
131  AttributeConversionResult(Attribute attr, unsigned tag) : impl(attr, tag) {}
132 
133  llvm::PointerIntPair<Attribute, 2> impl;
134  // Note that na is 0 so that we can use PointerIntPair's default
135  // constructor.
136  static constexpr unsigned naTag = 0;
137  static constexpr unsigned resultTag = 1;
138  static constexpr unsigned abortTag = 2;
139  };
140 
141  /// Register a conversion function. A conversion function must be convertible
142  /// to any of the following forms (where `T` is a class derived from `Type`):
143  ///
144  /// * std::optional<Type>(T)
145  /// - This form represents a 1-1 type conversion. It should return nullptr
146  /// or `std::nullopt` to signify failure. If `std::nullopt` is returned,
147  /// the converter is allowed to try another conversion function to
148  /// perform the conversion.
149  /// * std::optional<LogicalResult>(T, SmallVectorImpl<Type> &)
150  /// - This form represents a 1-N type conversion. It should return
151  /// `failure` or `std::nullopt` to signify a failed conversion. If the
152  /// new set of types is empty, the type is removed and any usages of the
153  /// existing value are expected to be removed during conversion. If
154  /// `std::nullopt` is returned, the converter is allowed to try another
155  /// conversion function to perform the conversion.
156  ///
157  /// Note: When attempting to convert a type, e.g. via 'convertType', the
158  /// mostly recently added conversions will be invoked first.
159  template <typename FnT, typename T = typename llvm::function_traits<
160  std::decay_t<FnT>>::template arg_t<0>>
161  void addConversion(FnT &&callback) {
162  registerConversion(wrapCallback<T>(std::forward<FnT>(callback)));
163  }
164 
165  /// All of the following materializations require function objects that are
166  /// convertible to the following form:
167  /// `Value(OpBuilder &, T, ValueRange, Location)`,
168  /// where `T` is any subclass of `Type`. This function is responsible for
169  /// creating an operation, using the OpBuilder and Location provided, that
170  /// "casts" a range of values into a single value of the given type `T`. It
171  /// must return a Value of the type `T` on success and `nullptr` if
172  /// it failed but other materializations should be attempted. Materialization
173  /// functions must be provided when a type conversion may persist after the
174  /// conversion has finished.
175  ///
176  /// Note: Target materializations may optionally accept an additional Type
177  /// parameter, which is the original type of the SSA value. Furthermore, `T`
178  /// can be a TypeRange; in that case, the function must return a
179  /// SmallVector<Value>.
180 
181  /// This method registers a materialization that will be called when
182  /// converting a replacement value back to its original source type.
183  /// This is used when some uses of the original value persist beyond the main
184  /// conversion.
185  template <typename FnT, typename T = typename llvm::function_traits<
186  std::decay_t<FnT>>::template arg_t<1>>
187  void addSourceMaterialization(FnT &&callback) {
188  sourceMaterializations.emplace_back(
189  wrapSourceMaterialization<T>(std::forward<FnT>(callback)));
190  }
191 
192  /// This method registers a materialization that will be called when
193  /// converting a value to a target type according to a pattern's type
194  /// converter.
195  ///
196  /// Note: Target materializations can optionally inspect the "original"
197  /// type. This type may be different from the type of the input value.
198  /// For example, let's assume that a conversion pattern "P1" replaced an SSA
199  /// value "v1" (type "t1") with "v2" (type "t2"). Then a different conversion
200  /// pattern "P2" matches an op that has "v1" as an operand. Let's furthermore
201  /// assume that "P2" determines that the converted target type of "t1" is
202  /// "t3", which may be different from "t2". In this example, the target
203  /// materialization will be invoked with: outputType = "t3", inputs = "v2",
204  /// originalType = "t1". Note that the original type "t1" cannot be recovered
205  /// from just "t3" and "v2"; that's why the originalType parameter exists.
206  ///
207  /// Note: During a 1:N conversion, the result types can be a TypeRange. In
208  /// that case the materialization produces a SmallVector<Value>.
209  template <typename FnT, typename T = typename llvm::function_traits<
210  std::decay_t<FnT>>::template arg_t<1>>
211  void addTargetMaterialization(FnT &&callback) {
212  targetMaterializations.emplace_back(
213  wrapTargetMaterialization<T>(std::forward<FnT>(callback)));
214  }
215 
216  /// Register a conversion function for attributes within types. Type
217  /// converters may call this function in order to allow hoking into the
218  /// translation of attributes that exist within types. For example, a type
219  /// converter for the `memref` type could use these conversions to convert
220  /// memory spaces or layouts in an extensible way.
221  ///
222  /// The conversion functions take a non-null Type or subclass of Type and a
223  /// non-null Attribute (or subclass of Attribute), and returns a
224  /// `AttributeConversionResult`. This result can either contain an
225  /// `Attribute`, which may be `nullptr`, representing the conversion's
226  /// success, `AttributeConversionResult::na()` (the default empty value),
227  /// indicating that the conversion function did not apply and that further
228  /// conversion functions should be checked, or
229  /// `AttributeConversionResult::abort()` indicating that the conversion
230  /// process should be aborted.
231  ///
232  /// Registered conversion functions are callled in the reverse of the order in
233  /// which they were registered.
234  template <
235  typename FnT,
236  typename T =
237  typename llvm::function_traits<std::decay_t<FnT>>::template arg_t<0>,
238  typename A =
239  typename llvm::function_traits<std::decay_t<FnT>>::template arg_t<1>>
240  void addTypeAttributeConversion(FnT &&callback) {
241  registerTypeAttributeConversion(
242  wrapTypeAttributeConversion<T, A>(std::forward<FnT>(callback)));
243  }
244 
245  /// Convert the given type. This function should return failure if no valid
246  /// conversion exists, success otherwise. If the new set of types is empty,
247  /// the type is removed and any usages of the existing value are expected to
248  /// be removed during conversion.
249  LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) const;
250 
251  /// This hook simplifies defining 1-1 type conversions. This function returns
252  /// the type to convert to on success, and a null type on failure.
253  Type convertType(Type t) const;
254 
255  /// Attempts a 1-1 type conversion, expecting the result type to be
256  /// `TargetType`. Returns the converted type cast to `TargetType` on success,
257  /// and a null type on conversion or cast failure.
258  template <typename TargetType>
259  TargetType convertType(Type t) const {
260  return dyn_cast_or_null<TargetType>(convertType(t));
261  }
262 
263  /// Convert the given set of types, filling 'results' as necessary. This
264  /// returns failure if the conversion of any of the types fails, success
265  /// otherwise.
266  LogicalResult convertTypes(TypeRange types,
267  SmallVectorImpl<Type> &results) const;
268 
269  /// Return true if the given type is legal for this type converter, i.e. the
270  /// type converts to itself.
271  bool isLegal(Type type) const;
272 
273  /// Return true if all of the given types are legal for this type converter.
274  template <typename RangeT>
275  std::enable_if_t<!std::is_convertible<RangeT, Type>::value &&
276  !std::is_convertible<RangeT, Operation *>::value,
277  bool>
278  isLegal(RangeT &&range) const {
279  return llvm::all_of(range, [this](Type type) { return isLegal(type); });
280  }
281  /// Return true if the given operation has legal operand and result types.
282  bool isLegal(Operation *op) const;
283 
284  /// Return true if the types of block arguments within the region are legal.
285  bool isLegal(Region *region) const;
286 
287  /// Return true if the inputs and outputs of the given function type are
288  /// legal.
289  bool isSignatureLegal(FunctionType ty) const;
290 
291  /// This method allows for converting a specific argument of a signature. It
292  /// takes as inputs the original argument input number, type.
293  /// On success, it populates 'result' with any new mappings.
294  LogicalResult convertSignatureArg(unsigned inputNo, Type type,
295  SignatureConversion &result) const;
296  LogicalResult convertSignatureArgs(TypeRange types,
297  SignatureConversion &result,
298  unsigned origInputOffset = 0) const;
299 
300  /// This function converts the type signature of the given block, by invoking
301  /// 'convertSignatureArg' for each argument. This function should return a
302  /// valid conversion for the signature on success, std::nullopt otherwise.
303  std::optional<SignatureConversion> convertBlockSignature(Block *block) const;
304 
305  /// Materialize a conversion from a set of types into one result type by
306  /// generating a cast sequence of some kind. See the respective
307  /// `add*Materialization` for more information on the context for these
308  /// methods.
310  Type resultType, ValueRange inputs) const;
312  Type resultType, ValueRange inputs,
313  Type originalType = {}) const;
314  SmallVector<Value> materializeTargetConversion(OpBuilder &builder,
315  Location loc,
316  TypeRange resultType,
317  ValueRange inputs,
318  Type originalType = {}) const;
319 
320  /// Convert an attribute present `attr` from within the type `type` using
321  /// the registered conversion functions. If no applicable conversion has been
322  /// registered, return std::nullopt. Note that the empty attribute/`nullptr`
323  /// is a valid return value for this function.
324  std::optional<Attribute> convertTypeAttribute(Type type,
325  Attribute attr) const;
326 
327 private:
328  /// The signature of the callback used to convert a type. If the new set of
329  /// types is empty, the type is removed and any usages of the existing value
330  /// are expected to be removed during conversion.
331  using ConversionCallbackFn = std::function<std::optional<LogicalResult>(
332  Type, SmallVectorImpl<Type> &)>;
333 
334  /// The signature of the callback used to materialize a source conversion.
335  ///
336  /// Arguments: builder, result type, inputs, location
337  using SourceMaterializationCallbackFn =
338  std::function<Value(OpBuilder &, Type, ValueRange, Location)>;
339 
340  /// The signature of the callback used to materialize a target conversion.
341  ///
342  /// Arguments: builder, result types, inputs, location, original type
343  using TargetMaterializationCallbackFn = std::function<SmallVector<Value>(
344  OpBuilder &, TypeRange, ValueRange, Location, Type)>;
345 
346  /// The signature of the callback used to convert a type attribute.
347  using TypeAttributeConversionCallbackFn =
348  std::function<AttributeConversionResult(Type, Attribute)>;
349 
350  /// Generate a wrapper for the given callback. This allows for accepting
351  /// different callback forms, that all compose into a single version.
352  /// With callback of form: `std::optional<Type>(T)`
353  template <typename T, typename FnT>
354  std::enable_if_t<std::is_invocable_v<FnT, T>, ConversionCallbackFn>
355  wrapCallback(FnT &&callback) const {
356  return wrapCallback<T>([callback = std::forward<FnT>(callback)](
357  T type, SmallVectorImpl<Type> &results) {
358  if (std::optional<Type> resultOpt = callback(type)) {
359  bool wasSuccess = static_cast<bool>(*resultOpt);
360  if (wasSuccess)
361  results.push_back(*resultOpt);
362  return std::optional<LogicalResult>(success(wasSuccess));
363  }
364  return std::optional<LogicalResult>();
365  });
366  }
367  /// With callback of form: `std::optional<LogicalResult>(
368  /// T, SmallVectorImpl<Type> &, ArrayRef<Type>)`.
369  template <typename T, typename FnT>
370  std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &>,
371  ConversionCallbackFn>
372  wrapCallback(FnT &&callback) const {
373  return [callback = std::forward<FnT>(callback)](
374  Type type,
375  SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
376  T derivedType = dyn_cast<T>(type);
377  if (!derivedType)
378  return std::nullopt;
379  return callback(derivedType, results);
380  };
381  }
382 
383  /// Register a type conversion.
384  void registerConversion(ConversionCallbackFn callback) {
385  conversions.emplace_back(std::move(callback));
386  cachedDirectConversions.clear();
387  cachedMultiConversions.clear();
388  }
389 
390  /// Generate a wrapper for the given source materialization callback. The
391  /// callback may take any subclass of `Type` and the wrapper will check for
392  /// the target type to be of the expected class before calling the callback.
393  template <typename T, typename FnT>
394  SourceMaterializationCallbackFn
395  wrapSourceMaterialization(FnT &&callback) const {
396  return [callback = std::forward<FnT>(callback)](
397  OpBuilder &builder, Type resultType, ValueRange inputs,
398  Location loc) -> Value {
399  if (T derivedType = dyn_cast<T>(resultType))
400  return callback(builder, derivedType, inputs, loc);
401  return Value();
402  };
403  }
404 
405  /// Generate a wrapper for the given target materialization callback.
406  /// The callback may take any subclass of `Type` and the wrapper will check
407  /// for the target type to be of the expected class before calling the
408  /// callback.
409  ///
410  /// With callback of form:
411  /// - Value(OpBuilder &, T, ValueRange, Location, Type)
412  /// - SmallVector<Value>(OpBuilder &, TypeRange, ValueRange, Location, Type)
413  template <typename T, typename FnT>
414  std::enable_if_t<
415  std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location, Type>,
416  TargetMaterializationCallbackFn>
417  wrapTargetMaterialization(FnT &&callback) const {
418  return [callback = std::forward<FnT>(callback)](
419  OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
420  Location loc, Type originalType) -> SmallVector<Value> {
421  SmallVector<Value> result;
422  if constexpr (std::is_same<T, TypeRange>::value) {
423  // This is a 1:N target materialization. Return the produces values
424  // directly.
425  result = callback(builder, resultTypes, inputs, loc, originalType);
426  } else if constexpr (std::is_assignable<Type, T>::value) {
427  // This is a 1:1 target materialization. Invoke the callback only if a
428  // single SSA value is requested.
429  if (resultTypes.size() == 1) {
430  // Invoke the callback only if the type class of the callback matches
431  // the requested result type.
432  if (T derivedType = dyn_cast<T>(resultTypes.front())) {
433  // 1:1 materializations produce single values, but we store 1:N
434  // target materialization functions in the type converter. Wrap the
435  // result value in a SmallVector<Value>.
436  Value val =
437  callback(builder, derivedType, inputs, loc, originalType);
438  if (val)
439  result.push_back(val);
440  }
441  }
442  } else {
443  static_assert(sizeof(T) == 0, "T must be a Type or a TypeRange");
444  }
445  return result;
446  };
447  }
448  /// With callback of form:
449  /// - Value(OpBuilder &, T, ValueRange, Location)
450  /// - SmallVector<Value>(OpBuilder &, TypeRange, ValueRange, Location)
451  template <typename T, typename FnT>
452  std::enable_if_t<
453  std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location>,
454  TargetMaterializationCallbackFn>
455  wrapTargetMaterialization(FnT &&callback) const {
456  return wrapTargetMaterialization<T>(
457  [callback = std::forward<FnT>(callback)](
458  OpBuilder &builder, T resultTypes, ValueRange inputs, Location loc,
459  Type originalType) {
460  return callback(builder, resultTypes, inputs, loc);
461  });
462  }
463 
464  /// Generate a wrapper for the given memory space conversion callback. The
465  /// callback may take any subclass of `Attribute` and the wrapper will check
466  /// for the target attribute to be of the expected class before calling the
467  /// callback.
468  template <typename T, typename A, typename FnT>
469  TypeAttributeConversionCallbackFn
470  wrapTypeAttributeConversion(FnT &&callback) const {
471  return [callback = std::forward<FnT>(callback)](
472  Type type, Attribute attr) -> AttributeConversionResult {
473  if (T derivedType = dyn_cast<T>(type)) {
474  if (A derivedAttr = dyn_cast_or_null<A>(attr))
475  return callback(derivedType, derivedAttr);
476  }
478  };
479  }
480 
481  /// Register a memory space conversion, clearing caches.
482  void
483  registerTypeAttributeConversion(TypeAttributeConversionCallbackFn callback) {
484  typeAttributeConversions.emplace_back(std::move(callback));
485  // Clear type conversions in case a memory space is lingering inside.
486  cachedDirectConversions.clear();
487  cachedMultiConversions.clear();
488  }
489 
490  /// The set of registered conversion functions.
491  SmallVector<ConversionCallbackFn, 4> conversions;
492 
493  /// The list of registered materialization functions.
494  SmallVector<SourceMaterializationCallbackFn, 2> sourceMaterializations;
495  SmallVector<TargetMaterializationCallbackFn, 2> targetMaterializations;
496 
497  /// The list of registered type attribute conversion functions.
498  SmallVector<TypeAttributeConversionCallbackFn, 2> typeAttributeConversions;
499 
500  /// A set of cached conversions to avoid recomputing in the common case.
501  /// Direct 1-1 conversions are the most common, so this cache stores the
502  /// successful 1-1 conversions as well as all failed conversions.
503  mutable DenseMap<Type, Type> cachedDirectConversions;
504  /// This cache stores the successful 1->N conversions, where N != 1.
505  mutable DenseMap<Type, SmallVector<Type, 2>> cachedMultiConversions;
506  /// A mutex used for cache access
507  mutable llvm::sys::SmartRWMutex<true> cacheMutex;
508 };
509 
510 //===----------------------------------------------------------------------===//
511 // Conversion Patterns
512 //===----------------------------------------------------------------------===//
513 
514 /// Base class for the conversion patterns. This pattern class enables type
515 /// conversions, and other uses specific to the conversion framework. As such,
516 /// patterns of this type can only be used with the 'apply*' methods below.
518 public:
521 
522  /// Hook for derived classes to implement combined matching and rewriting.
523  /// This overload supports only 1:1 replacements. The 1:N overload is called
524  /// by the driver. By default, it calls this 1:1 overload or reports a fatal
525  /// error if 1:N replacements were found.
526  virtual LogicalResult
528  ConversionPatternRewriter &rewriter) const {
529  llvm_unreachable("matchAndRewrite is not implemented");
530  }
531 
532  /// Hook for derived classes to implement combined matching and rewriting.
533  /// This overload supports 1:N replacements.
534  virtual LogicalResult
536  ConversionPatternRewriter &rewriter) const {
537  return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
538  }
539 
540  /// Attempt to match and rewrite the IR root at the specified operation.
541  LogicalResult matchAndRewrite(Operation *op,
542  PatternRewriter &rewriter) const final;
543 
544  /// Return the type converter held by this pattern, or nullptr if the pattern
545  /// does not require type conversion.
546  const TypeConverter *getTypeConverter() const { return typeConverter; }
547 
548  template <typename ConverterTy>
549  std::enable_if_t<std::is_base_of<TypeConverter, ConverterTy>::value,
550  const ConverterTy *>
552  return static_cast<const ConverterTy *>(typeConverter);
553  }
554 
555 protected:
556  /// See `RewritePattern::RewritePattern` for information on the other
557  /// available constructors.
558  using RewritePattern::RewritePattern;
559  /// Construct a conversion pattern with the given converter, and forward the
560  /// remaining arguments to RewritePattern.
561  template <typename... Args>
563  : RewritePattern(std::forward<Args>(args)...),
565 
566  /// Given an array of value ranges, which are the inputs to a 1:N adaptor,
567  /// try to extract the single value of each range to construct a the inputs
568  /// for a 1:1 adaptor.
569  ///
570  /// This function produces a fatal error if at least one range has 0 or
571  /// more than 1 value: "pattern 'name' does not support 1:N conversion"
574 
575 protected:
576  /// An optional type converter for use by this pattern.
577  const TypeConverter *typeConverter = nullptr;
578 };
579 
580 /// OpConversionPattern is a wrapper around ConversionPattern that allows for
581 /// matching and rewriting against an instance of a derived operation class as
582 /// opposed to a raw Operation.
583 template <typename SourceOp>
585 public:
586  using OpAdaptor = typename SourceOp::Adaptor;
588  typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
589 
591  : ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
593  PatternBenefit benefit = 1)
594  : ConversionPattern(typeConverter, SourceOp::getOperationName(), benefit,
595  context) {}
596 
597  /// Wrappers around the ConversionPattern methods that pass the derived op
598  /// type.
599  LogicalResult
601  ConversionPatternRewriter &rewriter) const final {
602  auto sourceOp = cast<SourceOp>(op);
603  return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
604  }
605  LogicalResult
607  ConversionPatternRewriter &rewriter) const final {
608  auto sourceOp = cast<SourceOp>(op);
609  return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
610  rewriter);
611  }
612 
613  /// Methods that operate on the SourceOp type. One of these must be
614  /// overridden by the derived pattern class.
615  virtual LogicalResult
616  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
617  ConversionPatternRewriter &rewriter) const {
618  llvm_unreachable("matchAndRewrite is not implemented");
619  }
620  virtual LogicalResult
621  matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
622  ConversionPatternRewriter &rewriter) const {
623  SmallVector<Value> oneToOneOperands =
624  getOneToOneAdaptorOperands(adaptor.getOperands());
625  return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
626  }
627 
628 private:
630 };
631 
632 /// OpInterfaceConversionPattern is a wrapper around ConversionPattern that
633 /// allows for matching and rewriting against an instance of an OpInterface
634 /// class as opposed to a raw Operation.
635 template <typename SourceOp>
637 public:
640  SourceOp::getInterfaceID(), benefit, context) {}
642  MLIRContext *context, PatternBenefit benefit = 1)
644  SourceOp::getInterfaceID(), benefit, context) {}
645 
646  /// Wrappers around the ConversionPattern methods that pass the derived op
647  /// type.
648  LogicalResult
650  ConversionPatternRewriter &rewriter) const final {
651  return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
652  }
653  LogicalResult
655  ConversionPatternRewriter &rewriter) const final {
656  return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
657  }
658 
659  /// Methods that operate on the SourceOp type. One of these must be
660  /// overridden by the derived pattern class.
661  virtual LogicalResult
662  matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
663  ConversionPatternRewriter &rewriter) const {
664  llvm_unreachable("matchAndRewrite is not implemented");
665  }
666  virtual LogicalResult
667  matchAndRewrite(SourceOp op, ArrayRef<ValueRange> operands,
668  ConversionPatternRewriter &rewriter) const {
669  return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
670  }
671 
672 private:
674 };
675 
676 /// OpTraitConversionPattern is a wrapper around ConversionPattern that allows
677 /// for matching and rewriting against instances of an operation that possess a
678 /// given trait.
679 template <template <typename> class TraitType>
681 public:
684  TypeID::get<TraitType>(), benefit, context) {}
686  MLIRContext *context, PatternBenefit benefit = 1)
688  TypeID::get<TraitType>(), benefit, context) {}
689 };
690 
691 /// Generic utility to convert op result types according to type converter
692 /// without knowing exact op type.
693 /// Clones existing op with new result types and returns it.
694 FailureOr<Operation *>
695 convertOpResultTypes(Operation *op, ValueRange operands,
696  const TypeConverter &converter,
697  ConversionPatternRewriter &rewriter);
698 
699 /// Add a pattern to the given pattern list to convert the signature of a
700 /// FunctionOpInterface op with the given type converter. This only supports
701 /// ops which use FunctionType to represent their type.
703  StringRef functionLikeOpName, RewritePatternSet &patterns,
704  const TypeConverter &converter);
705 
706 template <typename FuncOpT>
708  RewritePatternSet &patterns, const TypeConverter &converter) {
709  populateFunctionOpInterfaceTypeConversionPattern(FuncOpT::getOperationName(),
710  patterns, converter);
711 }
712 
714  RewritePatternSet &patterns, const TypeConverter &converter);
715 
716 //===----------------------------------------------------------------------===//
717 // Conversion PatternRewriter
718 //===----------------------------------------------------------------------===//
719 
720 namespace detail {
721 struct ConversionPatternRewriterImpl;
722 } // namespace detail
723 
724 /// This class implements a pattern rewriter for use with ConversionPatterns. It
725 /// extends the base PatternRewriter and provides special conversion specific
726 /// hooks.
728 public:
730 
731  /// Apply a signature conversion to given block. This replaces the block with
732  /// a new block containing the updated signature. The operations of the given
733  /// block are inlined into the newly-created block, which is returned.
734  ///
735  /// If no block argument types are changing, the original block will be
736  /// left in place and returned.
737  ///
738  /// A signature conversion must be provided. (Type converters can construct
739  /// a signature conversion with `convertBlockSignature`.)
740  ///
741  /// Optionally, a type converter can be provided to build materializations.
742  /// Note: If no type converter was provided or the type converter does not
743  /// specify any suitable source/target materialization rules, the dialect
744  /// conversion may fail to legalize unresolved materializations.
745  Block *
748  const TypeConverter *converter = nullptr);
749 
750  /// Apply a signature conversion to each block in the given region. This
751  /// replaces each block with a new block containing the updated signature. If
752  /// an updated signature would match the current signature, the respective
753  /// block is left in place as is. (See `applySignatureConversion` for
754  /// details.) The new entry block of the region is returned.
755  ///
756  /// SignatureConversions are computed with the specified type converter.
757  /// This function returns "failure" if the type converter failed to compute
758  /// a SignatureConversion for at least one block.
759  ///
760  /// Optionally, a special SignatureConversion can be specified for the entry
761  /// block. This is because the types of the entry block arguments are often
762  /// tied semantically to the operation.
763  FailureOr<Block *> convertRegionTypes(
764  Region *region, const TypeConverter &converter,
765  TypeConverter::SignatureConversion *entryConversion = nullptr);
766 
767  /// Replace all the uses of the block argument `from` with `to`. This
768  /// function supports both 1:1 and 1:N replacements.
770 
771  /// Return the converted value of 'key' with a type defined by the type
772  /// converter of the currently executing pattern. Return nullptr in the case
773  /// of failure, the remapped value otherwise.
775 
776  /// Return the converted values that replace 'keys' with types defined by the
777  /// type converter of the currently executing pattern. Returns failure if the
778  /// remap failed, success otherwise.
779  LogicalResult getRemappedValues(ValueRange keys,
780  SmallVectorImpl<Value> &results);
781 
782  //===--------------------------------------------------------------------===//
783  // PatternRewriter Hooks
784  //===--------------------------------------------------------------------===//
785 
786  /// Indicate that the conversion rewriter can recover from rewrite failure.
787  /// Recovery is supported via rollback, allowing for continued processing of
788  /// patterns even if a failure is encountered during the rewrite step.
789  bool canRecoverFromRewriteFailure() const override { return true; }
790 
791  /// Replace the given operation with the new values. The number of op results
792  /// and replacement values must match. The types may differ: the dialect
793  /// conversion driver will reconcile any surviving type mismatches at the end
794  /// of the conversion process with source materializations. The given
795  /// operation is erased.
796  void replaceOp(Operation *op, ValueRange newValues) override;
797 
798  /// Replace the given operation with the results of the new op. The number of
799  /// op results must match. The types may differ: the dialect conversion
800  /// driver will reconcile any surviving type mismatches at the end of the
801  /// conversion process with source materializations. The original operation
802  /// is erased.
803  void replaceOp(Operation *op, Operation *newOp) override;
804 
805  /// Replace the given operation with the new value ranges. The number of op
806  /// results and value ranges must match. The given operation is erased.
808  SmallVector<SmallVector<Value>> &&newValues);
809  template <typename RangeT = ValueRange>
812  llvm::to_vector_of<SmallVector<Value>>(newValues));
813  }
814  template <typename RangeT>
815  void replaceOpWithMultiple(Operation *op, RangeT &&newValues) {
817  ArrayRef(llvm::to_vector_of<ValueRange>(newValues)));
818  }
819 
820  /// PatternRewriter hook for erasing a dead operation. The uses of this
821  /// operation *must* be made dead by the end of the conversion process,
822  /// otherwise an assert will be issued.
823  void eraseOp(Operation *op) override;
824 
825  /// PatternRewriter hook for erasing all operations in a block. This is not yet
826  /// implemented for dialect conversion.
827  void eraseBlock(Block *block) override;
828 
829  /// PatternRewriter hook for inlining the ops of a block into another block.
830  void inlineBlockBefore(Block *source, Block *dest, Block::iterator before,
831  ValueRange argValues = {}) override;
833 
834  /// PatternRewriter hook for updating the given operation in-place.
835  /// Note: These methods only track updates to the given operation itself,
836  /// and not nested regions. Updates to regions will still require notification
837  /// through other more specific hooks above.
838  void startOpModification(Operation *op) override;
839 
840  /// PatternRewriter hook for updating the given operation in-place.
841  void finalizeOpModification(Operation *op) override;
842 
843  /// PatternRewriter hook for updating the given operation in-place.
844  void cancelOpModification(Operation *op) override;
845 
846  /// Return a reference to the internal implementation.
847  detail::ConversionPatternRewriterImpl &getImpl();
848 
849 private:
850  // Allow OperationConverter to construct new rewriters.
851  friend struct OperationConverter;
852 
853  /// Conversion pattern rewriters must not be used outside of dialect
854  /// conversions. They apply some IR rewrites in a delayed fashion and could
855  /// bring the IR into an inconsistent state when used standalone.
857  const ConversionConfig &config);
858 
859  // Hide unsupported pattern rewriter API.
861 
862  std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
863 };
864 
865 //===----------------------------------------------------------------------===//
866 // ConversionTarget
867 //===----------------------------------------------------------------------===//
868 
869 /// This class describes a specific conversion target.
871 public:
872  /// This enumeration corresponds to the specific action to take when
873  /// considering an operation legal for this conversion target.
874  enum class LegalizationAction {
875  /// The target supports this operation.
876  Legal,
877 
878  /// This operation has dynamic legalization constraints that must be checked
879  /// by the target.
880  Dynamic,
881 
882  /// The target explicitly does not support this operation.
883  Illegal,
884  };
885 
886  /// A structure containing additional information describing a specific legal
887  /// operation instance.
888  struct LegalOpDetails {
889  /// A flag that indicates if this operation is 'recursively' legal. This
890  /// means that if an operation is legal, either statically or dynamically,
891  /// all of the operations nested within are also considered legal.
892  bool isRecursivelyLegal = false;
893  };
894 
895  /// The signature of the callback used to determine if an operation is
896  /// dynamically legal on the target.
898  std::function<std::optional<bool>(Operation *)>;
899 
900  ConversionTarget(MLIRContext &ctx) : ctx(ctx) {}
901  virtual ~ConversionTarget() = default;
902 
903  //===--------------------------------------------------------------------===//
904  // Legality Registration
905  //===--------------------------------------------------------------------===//
906 
907  /// Register a legality action for the given operation.
909  template <typename OpT>
911  setOpAction(OperationName(OpT::getOperationName(), &ctx), action);
912  }
913 
914  /// Register the given operations as legal.
917  }
918  template <typename OpT>
919  void addLegalOp() {
920  addLegalOp(OperationName(OpT::getOperationName(), &ctx));
921  }
922  template <typename OpT, typename OpT2, typename... OpTs>
923  void addLegalOp() {
924  addLegalOp<OpT>();
925  addLegalOp<OpT2, OpTs...>();
926  }
927 
928  /// Register the given operation as dynamically legal and set the dynamic
929  /// legalization callback to the one provided.
931  const DynamicLegalityCallbackFn &callback) {
933  setLegalityCallback(op, callback);
934  }
935  template <typename OpT>
937  addDynamicallyLegalOp(OperationName(OpT::getOperationName(), &ctx),
938  callback);
939  }
940  template <typename OpT, typename OpT2, typename... OpTs>
942  addDynamicallyLegalOp<OpT>(callback);
943  addDynamicallyLegalOp<OpT2, OpTs...>(callback);
944  }
945  template <typename OpT, class Callable>
946  std::enable_if_t<!std::is_invocable_v<Callable, Operation *>>
947  addDynamicallyLegalOp(Callable &&callback) {
948  addDynamicallyLegalOp<OpT>(
949  [=](Operation *op) { return callback(cast<OpT>(op)); });
950  }
951 
952  /// Register the given operation as illegal, i.e. this operation is known to
953  /// not be supported by this target.
956  }
957  template <typename OpT>
958  void addIllegalOp() {
959  addIllegalOp(OperationName(OpT::getOperationName(), &ctx));
960  }
961  template <typename OpT, typename OpT2, typename... OpTs>
962  void addIllegalOp() {
963  addIllegalOp<OpT>();
964  addIllegalOp<OpT2, OpTs...>();
965  }
966 
967  /// Mark an operation, that *must* have either been set as `Legal` or
968  /// `Dynamic`, as being recursively legal. This means that in
969  /// addition to the operation itself, all of the operations nested within are
970  /// also considered legal. An optional dynamic legality callback may be
971  /// provided to mark subsets of legal instances as recursively legal.
973  const DynamicLegalityCallbackFn &callback);
974  template <typename OpT>
976  OperationName opName(OpT::getOperationName(), &ctx);
977  markOpRecursivelyLegal(opName, callback);
978  }
979  template <typename OpT, typename OpT2, typename... OpTs>
981  markOpRecursivelyLegal<OpT>(callback);
982  markOpRecursivelyLegal<OpT2, OpTs...>(callback);
983  }
984  template <typename OpT, class Callable>
985  std::enable_if_t<!std::is_invocable_v<Callable, Operation *>>
986  markOpRecursivelyLegal(Callable &&callback) {
987  markOpRecursivelyLegal<OpT>(
988  [=](Operation *op) { return callback(cast<OpT>(op)); });
989  }
990 
991  /// Register a legality action for the given dialects.
992  void setDialectAction(ArrayRef<StringRef> dialectNames,
993  LegalizationAction action);
994 
995  /// Register the operations of the given dialects as legal.
996  template <typename... Names>
997  void addLegalDialect(StringRef name, Names... names) {
998  SmallVector<StringRef, 2> dialectNames({name, names...});
1000  }
1001  template <typename... Args>
1003  SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
1005  }
1006 
1007  /// Register the operations of the given dialects as dynamically legal, i.e.
1008  /// requiring custom handling by the callback.
1009  template <typename... Names>
1011  StringRef name, Names... names) {
1012  SmallVector<StringRef, 2> dialectNames({name, names...});
1014  setLegalityCallback(dialectNames, callback);
1015  }
1016  template <typename... Args>
1018  addDynamicallyLegalDialect(std::move(callback),
1019  Args::getDialectNamespace()...);
1020  }
1021 
1022  /// Register unknown operations as dynamically legal. For operations (and
1023  /// dialects) that do not have a set legalization action, treat them as
1024  /// dynamically legal and invoke the given callback.
1026  setLegalityCallback(fn);
1027  }
1028 
1029  /// Register the operations of the given dialects as illegal, i.e.
1030  /// operations of this dialect are not supported by the target.
1031  template <typename... Names>
1032  void addIllegalDialect(StringRef name, Names... names) {
1033  SmallVector<StringRef, 2> dialectNames({name, names...});
1035  }
1036  template <typename... Args>
1038  SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
1040  }
1041 
1042  //===--------------------------------------------------------------------===//
1043  // Legality Querying
1044  //===--------------------------------------------------------------------===//
1045 
1046  /// Get the legality action for the given operation.
1047  std::optional<LegalizationAction> getOpAction(OperationName op) const;
1048 
1049  /// If the given operation instance is legal on this target, a structure
1050  /// containing legality information is returned. If the operation is not
1051  /// legal, std::nullopt is returned. Also returns std::nullopt if operation
1052  /// legality wasn't registered by the user, or if dynamic legality callbacks
1053  /// returned std::nullopt.
1054  ///
1055  /// Note: Legality is actually a 4-state: Legal(recursive=true),
1056  /// Legal(recursive=false), Illegal or Unknown, where Unknown is treated
1057  /// either as Legal or Illegal depending on context.
1058  std::optional<LegalOpDetails> isLegal(Operation *op) const;
1059 
1060  /// Returns true if the operation instance is illegal on this target. Returns
1061  /// false if the operation is legal, its legality wasn't registered by the user,
1062  /// or dynamic legality callbacks returned std::nullopt.
1063  bool isIllegal(Operation *op) const;
1064 
1065 private:
1066  /// Set the dynamic legality callback for the given operation.
1067  void setLegalityCallback(OperationName name,
1068  const DynamicLegalityCallbackFn &callback);
1069 
1070  /// Set the dynamic legality callback for the given dialects.
1071  void setLegalityCallback(ArrayRef<StringRef> dialects,
1072  const DynamicLegalityCallbackFn &callback);
1073 
1074  /// Set the dynamic legality callback for the unknown ops.
1075  void setLegalityCallback(const DynamicLegalityCallbackFn &callback);
1076 
1077  /// The set of information that configures the legalization of an operation.
1078  struct LegalizationInfo {
1079  /// The legality action this operation was given.
1081 
1082  /// If some legal instances of this operation may also be recursively legal.
1083  bool isRecursivelyLegal = false;
1084 
1085  /// The legality callback if this operation is dynamically legal.
1086  DynamicLegalityCallbackFn legalityFn;
1087  };
1088 
1089  /// Get the legalization information for the given operation.
1090  std::optional<LegalizationInfo> getOpInfo(OperationName op) const;
1091 
1092  /// A deterministic mapping of operation name and its respective legality
1093  /// information.
1094  llvm::MapVector<OperationName, LegalizationInfo> legalOperations;
1095 
1096  /// A set of legality callbacks for given operation names that are used to
1097  /// check if an operation instance is recursively legal.
1098  DenseMap<OperationName, DynamicLegalityCallbackFn> opRecursiveLegalityFns;
1099 
1100  /// A deterministic mapping of dialect name to the specific legality action to
1101  /// take.
1102  llvm::StringMap<LegalizationAction> legalDialects;
1103 
1104  /// A set of dynamic legality callbacks for given dialect names.
1105  llvm::StringMap<DynamicLegalityCallbackFn> dialectLegalityFns;
1106 
1107  /// An optional legality callback for unknown operations.
1108  DynamicLegalityCallbackFn unknownLegalityFn;
1109 
1110  /// The current context this target applies to.
1111  MLIRContext &ctx;
1112 };
1113 
1114 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
1115 //===----------------------------------------------------------------------===//
1116 // PDL Configuration
1117 //===----------------------------------------------------------------------===//
1118 
1119 /// A PDL configuration that is used to support dialect conversion
1120 /// functionality.
1122  : public PDLPatternConfigBase<PDLConversionConfig> {
1123 public:
1124  PDLConversionConfig(const TypeConverter *converter) : converter(converter) {}
1125  ~PDLConversionConfig() final = default;
1126 
1127  /// Return the type converter used by this configuration, which may be nullptr
1128  /// if no type conversions are expected.
1129  const TypeConverter *getTypeConverter() const { return converter; }
1130 
1131  /// Hooks that are invoked at the beginning and end of a rewrite of a matched
1132  /// pattern.
1133  void notifyRewriteBegin(PatternRewriter &rewriter) final;
1134  void notifyRewriteEnd(PatternRewriter &rewriter) final;
1135 
1136 private:
1137  /// An optional type converter to use for the pattern.
1138  const TypeConverter *converter;
1139 };
1140 
1141 /// Register the dialect conversion PDL functions with the given pattern set.
1142 void registerConversionPDLFunctions(RewritePatternSet &patterns);
1143 
1144 #else
1145 
1146 // Stubs for when PDL in rewriting is not enabled.
1147 
1148 inline void registerConversionPDLFunctions(RewritePatternSet &patterns) {}
1149 
1150 class PDLConversionConfig final {
1151 public:
1152  PDLConversionConfig(const TypeConverter * /*converter*/) {}
1153 };
1154 
1155 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
1156 
1157 //===----------------------------------------------------------------------===//
1158 // ConversionConfig
1159 //===----------------------------------------------------------------------===//
1160 
1161 /// Dialect conversion configuration.
1163  /// An optional callback used to notify about match failure diagnostics during
1164  /// the conversion. Diagnostics reported to this callback may only be
1165  /// available in debug mode.
1167 
1168  /// Partial conversion only. All operations that are found not to be
1169  /// legalizable are placed in this set. (Note that if there is an op
1170  /// explicitly marked as illegal, the conversion terminates and the set will
1171  /// not necessarily be complete.)
1173 
1174  /// Analysis conversion only. All operations that are found to be legalizable
1175  /// are placed in this set. Note that no actual rewrites are applied to the
1176  /// IR during an analysis conversion and only pre-existing operations are
1177  /// added to the set.
1179 
1180  /// An optional listener that is notified about all IR modifications in case
1181  /// dialect conversion succeeds. If the dialect conversion fails and no IR
1182  /// modifications are visible (i.e., they were all rolled back), or if the
1183  /// dialect conversion is an "analysis conversion", no notifications are
1184  /// sent (apart from `notifyPatternBegin`/`notifyPatternEnd`).
1185  ///
1186  /// Note: Notifications are sent in a delayed fashion, when the dialect
1187  /// conversion is guaranteed to succeed. At that point, some IR modifications
1188  /// may already have been materialized. Consequently, operations/blocks that
1189  /// are passed to listener callbacks should not be accessed. (Ops/blocks are
1190  /// guaranteed to be valid pointers and accessing op names is allowed. But
1191  /// there are no guarantees about the state of ops/blocks at the time that a
1192  /// callback is triggered.)
1193  ///
1194  /// Example: Consider a dialect conversion in which a new op ("test.foo") is created
1195  /// and inserted, and later moved to another block. (Moving ops also triggers
1196  /// "notifyOperationInserted".)
1197  ///
1198  /// (1) notifyOperationInserted: "test.foo" (into block "b1")
1199  /// (2) notifyOperationInserted: "test.foo" (moved to another block "b2")
1200  ///
1201  /// When querying "op->getBlock()" during the first "notifyOperationInserted",
1202  /// "b2" would be returned because "moving an op" is a kind of rewrite that is
1203  /// immediately performed by the dialect conversion (and rolled back upon
1204  /// failure).
1205  //
1206  // Note: When receiving a "notifyBlockInserted"/"notifyOperationInserted"
1207  // callback, the previous region/block is provided to the callback, but not
1208  // the iterator pointing to the exact location within the region/block. That
1209  // is because these notifications are sent with a delay (after the IR has
1210  // already been modified) and iterators into past IR state cannot be
1211  // represented at the moment.
1213 
1214  /// If set to "true", the dialect conversion attempts to build source/target
1215  /// materializations through the type converter API in lieu of
1216  /// "builtin.unrealized_conversion_cast" ops. The conversion process fails if
1217  /// at least one materialization could not be built.
1218  ///
1219  /// If set to "false", the dialect conversion does not build any custom
1220  /// materializations and instead inserts "builtin.unrealized_conversion_cast"
1221  /// ops to ensure that the resulting IR is valid.
1223 
1224  /// If set to "true", pattern rollback is allowed. The conversion driver
1225  /// rolls back IR modifications in the following situations.
1226  ///
1227  /// 1. Pattern implementation returns "failure" after modifying IR.
1228  /// 2. Pattern produces IR (in-place modification or new IR) that is illegal
1229  /// and cannot be legalized by subsequent foldings / pattern applications.
1230  ///
1231  /// If set to "false", the conversion driver will produce an LLVM fatal error
1232  /// instead of rolling back IR modifications. Moreover, in case of a failed
1233  /// conversion, the original IR is not restored. The resulting IR may be a
1234  /// mix of original and rewritten IR. (Same as a failed greedy pattern
1235  /// rewrite.)
1236  ///
1237  /// Note: This flag was added in preparation for the One-Shot Dialect
1238  /// Conversion refactoring, which will remove the ability to roll back IR
1239  /// modifications from the conversion driver. Use this flag to ensure that
1240  /// your patterns do not trigger any IR rollbacks. For details, see
1241  /// https://discourse.llvm.org/t/rfc-a-new-one-shot-dialect-conversion-driver/79083.
1243 };
1244 
1245 //===----------------------------------------------------------------------===//
1246 // Reconcile Unrealized Casts
1247 //===----------------------------------------------------------------------===//
1248 
1249 /// Try to reconcile all given UnrealizedConversionCastOps and store the
1250 /// left-over ops in `remainingCastOps` (if provided).
1251 ///
1252 /// This function processes cast ops in a worklist-driven fashion. For each
1253 /// cast op, if the chain of input casts eventually reaches a cast op where the
1254 /// input types match the output types of the matched op, replace the matched
1255 /// op with the inputs.
1256 ///
1257 /// Example:
1258 /// %1 = unrealized_conversion_cast %0 : !A to !B
1259 /// %2 = unrealized_conversion_cast %1 : !B to !C
1260 /// %3 = unrealized_conversion_cast %2 : !C to !A
1261 ///
1262 /// In the above example, %0 can be used instead of %3 and all cast ops are
1263 /// folded away.
1266  SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps = nullptr);
1267 
1268 //===----------------------------------------------------------------------===//
1269 // Op Conversion Entry Points
1270 //===----------------------------------------------------------------------===//
1271 
1272 /// Below we define several entry points for operation conversion. It is
1273 /// important to note that the patterns provided to the conversion framework may
1274 /// have additional constraints. See the `PatternRewriter Hooks` section of the
1275 /// ConversionPatternRewriter, to see what additional constraints are imposed on
1276 /// the use of the PatternRewriter.
1277 
1278 /// Apply a partial conversion on the given operations and all nested
1279 /// operations. This method converts as many operations to the target as
1280 /// possible, ignoring operations that failed to legalize. This method only
1281 /// returns failure if there are ops explicitly marked as illegal.
1282 LogicalResult
1284  const ConversionTarget &target,
1287 LogicalResult
1291 
1292 /// Apply a complete conversion on the given operations, and all nested
1293 /// operations. This method returns failure if the conversion of any operation
1294 /// fails, or if there are unreachable blocks in any of the regions nested
1295 /// within 'ops'.
1296 LogicalResult applyFullConversion(ArrayRef<Operation *> ops,
1297  const ConversionTarget &target,
1300 LogicalResult applyFullConversion(Operation *op, const ConversionTarget &target,
1303 
1304 /// Apply an analysis conversion on the given operations, and all nested
1305 /// operations. This method analyzes which operations would be successfully
1306 /// converted to the target if a conversion was applied. All operations that
1307 /// were found to be legalizable to the given 'target' are placed within the
1308 /// provided 'config.legalizableOps' set; note that no actual rewrites are
1309 /// applied to the operations on success. This method only returns failure if
1310 /// there are unreachable blocks in any of the regions nested within 'ops'.
1311 LogicalResult
1315 LogicalResult
1319 } // namespace mlir
1320 
1321 #endif // MLIR_TRANSFORMS_DIALECTCONVERSION_H_
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:309
Block represents an ordered list of Operations.
Definition: Block.h:33
OpListType::iterator iterator
Definition: Block.h:140
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
LogicalResult getRemappedValues(ValueRange keys, SmallVectorImpl< Value > &results)
Return the converted values that replace 'keys' with types defined by the type converter of the curre...
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void replaceOpWithMultiple(Operation *op, SmallVector< SmallVector< Value >> &&newValues)
Replace the given operation with the new value ranges.
Block * applySignatureConversion(Block *block, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter=nullptr)
Apply a signature conversion to given block.
void replaceOpWithMultiple(Operation *op, ArrayRef< RangeT > newValues)
void startOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void replaceOpWithMultiple(Operation *op, RangeT &&newValues)
detail::ConversionPatternRewriterImpl & getImpl()
Return a reference to the internal implementation.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={}) override
PatternRewriter hook for inlining the ops of a block into another block.
void eraseBlock(Block *block) override
PatternRewriter hook for erasing all operations in a block.
void cancelOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
bool canRecoverFromRewriteFailure() const override
Indicate that the conversion rewriter can recover from rewrite failure.
Value getRemappedValue(Value key)
Return the converted value of 'key' with a type defined by the type converter of the currently execut...
void finalizeOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to)
Replace all the uses of the block argument from with to.
Base class for the conversion patterns.
SmallVector< Value > getOneToOneAdaptorOperands(ArrayRef< ValueRange > operands) const
Given an array of value ranges, which are the inputs to a 1:N adaptor, try to extract the single valu...
const TypeConverter * typeConverter
An optional type converter for use by this pattern.
ConversionPattern(const TypeConverter &typeConverter, Args &&...args)
Construct a conversion pattern with the given converter, and forward the remaining arguments to Rewri...
const TypeConverter * getTypeConverter() const
Return the type converter held by this pattern, or nullptr if the pattern does not require type conve...
std::enable_if_t< std::is_base_of< TypeConverter, ConverterTy >::value, const ConverterTy * > getTypeConverter() const
virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement combined matching and rewriting.
virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef< ValueRange > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement combined matching and rewriting.
This class describes a specific conversion target.
void setDialectAction(ArrayRef< StringRef > dialectNames, LegalizationAction action)
Register a legality action for the given dialects.
void setOpAction(LegalizationAction action)
void addLegalOp(OperationName op)
Register the given operations as legal.
void addDynamicallyLegalDialect(DynamicLegalityCallbackFn callback)
void setOpAction(OperationName op, LegalizationAction action)
Register a legality action for the given operation.
void addDynamicallyLegalDialect(const DynamicLegalityCallbackFn &callback, StringRef name, Names... names)
Register the operations of the given dialects as dynamically legal, i.e.
void addDynamicallyLegalOp(const DynamicLegalityCallbackFn &callback)
void addLegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as legal.
std::optional< LegalOpDetails > isLegal(Operation *op) const
If the given operation instance is legal on this target, a structure containing legality information ...
std::enable_if_t<!std::is_invocable_v< Callable, Operation * > > markOpRecursivelyLegal(Callable &&callback)
void markUnknownOpDynamicallyLegal(const DynamicLegalityCallbackFn &fn)
Register unknown operations as dynamically legal.
std::optional< LegalizationAction > getOpAction(OperationName op) const
Get the legality action for the given operation.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the on...
void addIllegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as illegal, i.e.
std::enable_if_t<!std::is_invocable_v< Callable, Operation * > > addDynamicallyLegalOp(Callable &&callback)
void addDynamicallyLegalOp(const DynamicLegalityCallbackFn &callback)
void markOpRecursivelyLegal(const DynamicLegalityCallbackFn &callback={})
void addIllegalOp(OperationName op)
Register the given operation as illegal, i.e.
LegalizationAction
This enumeration corresponds to the specific action to take when considering an operation legal for t...
@ Illegal
The target explicitly does not support this operation.
@ Dynamic
This operation has dynamic legalization constraints that must be checked by the target.
@ Legal
The target supports this operation.
ConversionTarget(MLIRContext &ctx)
void markOpRecursivelyLegal(OperationName name, const DynamicLegalityCallbackFn &callback)
Mark an operation, that must have either been set as Legal or Dynamic, as being recursively ...
void markOpRecursivelyLegal(const DynamicLegalityCallbackFn &callback={})
std::function< std::optional< bool >(Operation *)> DynamicLegalityCallbackFn
The signature of the callback used to determine if an operation is dynamically legal on the target.
bool isIllegal(Operation *op) const
Returns true if the operation instance is illegal on this target.
virtual ~ConversionTarget()=default
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:205
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
Definition: Builders.h:314
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
LogicalResult matchAndRewrite(Operation *op, ArrayRef< ValueRange > operands, ConversionPatternRewriter &rewriter) const final
Hook for derived classes to implement combined matching and rewriting.
typename SourceOp::Adaptor OpAdaptor
virtual LogicalResult matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
typename SourceOp::template GenericAdaptor< ArrayRef< ValueRange > > OneToNOpAdaptor
OpConversionPattern(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit=1)
virtual LogicalResult matchAndRewrite(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const
Methods that operate on the SourceOp type.
LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Wrappers around the ConversionPattern methods that pass the derived op type.
OpInterfaceConversionPattern is a wrapper around ConversionPattern that allows for matching and rewri...
LogicalResult matchAndRewrite(Operation *op, ArrayRef< ValueRange > operands, ConversionPatternRewriter &rewriter) const final
Hook for derived classes to implement combined matching and rewriting.
virtual LogicalResult matchAndRewrite(SourceOp op, ArrayRef< ValueRange > operands, ConversionPatternRewriter &rewriter) const
LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Wrappers around the ConversionPattern methods that pass the derived op type.
OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
OpInterfaceConversionPattern(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit=1)
virtual LogicalResult matchAndRewrite(SourceOp op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Methods that operate on the SourceOp type.
OpTraitConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting...
OpTraitConversionPattern(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit=1)
OpTraitConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
A PDL configuration that is used to support dialect conversion functionality.
void notifyRewriteEnd(PatternRewriter &rewriter) final
void notifyRewriteBegin(PatternRewriter &rewriter) final
Hooks that are invoked at the beginning and end of a rewrite of a matched pattern.
const TypeConverter * getTypeConverter() const
Return the type converter used by this configuration, which may be nullptr if no type conversions are...
PDLConversionConfig(const TypeConverter *converter)
~PDLConversionConfig() final=default
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:767
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:73
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:238
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
The general result of a type attribute conversion callback, allowing for early termination.
static AttributeConversionResult result(Attribute attr)
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
std::optional< InputMapping > getInputMapping(unsigned input) const
Get the input mapping for the given argument.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
void remapInput(unsigned origInputNo, ArrayRef< Value > replacements)
Remap an input of the original signature with replacement values.
Type conversion class.
std::optional< Attribute > convertTypeAttribute(Type type, Attribute attr) const
Convert an attribute present attr from within the type type using the registered conversion functions...
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
Materialize a conversion from a set of types into one result type by generating a cast sequence of so...
void addConversion(FnT &&callback)
Register a conversion function.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
LogicalResult convertSignatureArgs(TypeRange types, SignatureConversion &result, unsigned origInputOffset=0) const
TypeConverter()=default
std::enable_if_t<!std::is_convertible< RangeT, Type >::value &&!std::is_convertible< RangeT, Operation * >::value, bool > isLegal(RangeT &&range) const
Return true if all of the given types are legal for this type converter.
TargetType convertType(Type t) const
Attempts a 1-1 type conversion, expecting the result type to be TargetType.
LogicalResult convertSignatureArg(unsigned inputNo, Type type, SignatureConversion &result) const
This method allows for converting a specific argument of a signature.
TypeConverter(const TypeConverter &other)
TypeConverter & operator=(const TypeConverter &other)
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
std::optional< SignatureConversion > convertBlockSignature(Block *block) const
This function converts the type signature of the given block, by invoking 'convertSignatureArg' for e...
void addSourceMaterialization(FnT &&callback)
All of the following materializations require function objects that are convertible to the following ...
virtual ~TypeConverter()=default
void addTypeAttributeConversion(FnT &&callback)
Register a conversion function for attributes within types.
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a value to a target type ...
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
bool isSignatureLegal(FunctionType ty) const
Return true if the inputs and outputs of the given function type are legal.
Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs, Type originalType={}) const
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:107
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
@ Type
An inlay hint for a type annotation.
Include the generated interface declarations.
void populateFunctionOpInterfaceTypeConversionPattern(StringRef functionLikeOpName, RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to convert the signature of a FunctionOpInterface op with the...
void populateAnyFunctionOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
FailureOr< Operation * > convertOpResultTypes(Operation *op, ValueRange operands, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Generic utility to convert op result types according to type converter without knowing exact op type.
const FrozenRewritePatternSet & patterns
LogicalResult applyAnalysisConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply an analysis conversion on the given operations, and all nested operations.
void reconcileUnrealizedCasts(ArrayRef< UnrealizedConversionCastOp > castOps, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps=nullptr)
Try to reconcile all given UnrealizedConversionCastOps and store the left-over ops in remainingCastOp...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void registerConversionPDLFunctions(RewritePatternSet &patterns)
Register the dialect conversion PDL functions with the given pattern set.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
Dialect conversion configuration.
bool allowPatternRollback
If set to "true", pattern rollback is allowed.
RewriterBase::Listener * listener
An optional listener that is notified about all IR modifications in case dialect conversion succeeds.
function_ref< void(Diagnostic &)> notifyCallback
An optional callback used to notify about match failure diagnostics during the conversion.
DenseSet< Operation * > * legalizableOps
Analysis conversion only.
DenseSet< Operation * > * unlegalizedOps
Partial conversion only.
bool buildMaterializations
If set to "true", the dialect conversion attempts to build source/target materializations through the...
A structure containing additional information describing a specific legal operation instance.
bool isRecursivelyLegal
A flag that indicates if this operation is 'recursively' legal.
This class acts as a special tag that makes the desire to match any operation that implements a given...
Definition: PatternMatch.h:164
This class acts as a special tag that makes the desire to match any operation that implements a given...
Definition: PatternMatch.h:169
This struct represents a range of new types or a range of values that remaps an existing signature in...
bool replacedWithValues() const
Return "true" if this input was replaced with one or multiple values.