MLIR  21.0.0git
DialectConversion.h
Go to the documentation of this file.
1 //===- DialectConversion.h - MLIR dialect conversion pass -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file declares a generic pass for converting between MLIR dialects.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_TRANSFORMS_DIALECTCONVERSION_H_
14 #define MLIR_TRANSFORMS_DIALECTCONVERSION_H_
15 
16 #include "mlir/Config/mlir-config.h"
18 #include "llvm/ADT/MapVector.h"
19 #include "llvm/ADT/StringMap.h"
20 #include <type_traits>
21 
22 namespace mlir {
23 
24 // Forward declarations.
25 class Attribute;
26 class Block;
27 struct ConversionConfig;
28 class ConversionPatternRewriter;
29 class MLIRContext;
30 class Operation;
31 struct OperationConverter;
32 class Type;
33 class Value;
34 
35 //===----------------------------------------------------------------------===//
36 // Type Conversion
37 //===----------------------------------------------------------------------===//
38 
39 /// Type conversion class. Specific conversions and materializations can be
40 /// registered using addConversion and addMaterialization, respectively.
42 public:
43  virtual ~TypeConverter() = default;
44  TypeConverter() = default;
45  // Copy the registered conversions, but not the caches
47  : conversions(other.conversions),
48  sourceMaterializations(other.sourceMaterializations),
49  targetMaterializations(other.targetMaterializations),
50  typeAttributeConversions(other.typeAttributeConversions) {}
52  conversions = other.conversions;
53  sourceMaterializations = other.sourceMaterializations;
54  targetMaterializations = other.targetMaterializations;
55  typeAttributeConversions = other.typeAttributeConversions;
56  return *this;
57  }
58 
59  /// This class provides all of the information necessary to convert a type
60  /// signature.
62  public:
63  SignatureConversion(unsigned numOrigInputs)
64  : remappedInputs(numOrigInputs) {}
65 
66  /// This struct represents a range of new types or a range of values that
67  /// remaps an existing signature input.
68  struct InputMapping {
69  size_t inputNo, size;
71 
72  /// Return "true" if this input was replaced with one or multiple values.
73  bool replacedWithValues() const { return !replacementValues.empty(); }
74  };
75 
76  /// Return the argument types for the new signature.
77  ArrayRef<Type> getConvertedTypes() const { return argTypes; }
78 
79  /// Get the input mapping for the given argument.
80  std::optional<InputMapping> getInputMapping(unsigned input) const {
81  return remappedInputs[input];
82  }
83 
84  //===------------------------------------------------------------------===//
85  // Conversion Hooks
86  //===------------------------------------------------------------------===//
87 
88  /// Remap an input of the original signature with a new set of types. The
89  /// new types are appended to the new signature conversion.
90  void addInputs(unsigned origInputNo, ArrayRef<Type> types);
91 
92  /// Append new input types to the signature conversion, this should only be
93  /// used if the new types are not intended to remap an existing input.
94  void addInputs(ArrayRef<Type> types);
95 
96  /// Remap an input of the original signature to `replacements`
97  /// values. This drops the original argument.
98  void remapInput(unsigned origInputNo, ArrayRef<Value> replacements);
99 
100  private:
101  /// Remap an input of the original signature with a range of types in the
102  /// new signature.
103  void remapInput(unsigned origInputNo, unsigned newInputNo,
104  unsigned newInputCount = 1);
105 
106  /// The remapping information for each of the original arguments.
107  SmallVector<std::optional<InputMapping>, 4> remappedInputs;
108 
109  /// The set of new argument types.
110  SmallVector<Type, 4> argTypes;
111  };
112 
113  /// The general result of a type attribute conversion callback, allowing
114  /// for early termination. The default constructor creates the na case.
116  public:
117  constexpr AttributeConversionResult() : impl() {}
118  AttributeConversionResult(Attribute attr) : impl(attr, resultTag) {}
119 
121  static AttributeConversionResult na();
123 
124  bool hasResult() const;
125  bool isNa() const;
126  bool isAbort() const;
127 
128  Attribute getResult() const;
129 
130  private:
131  AttributeConversionResult(Attribute attr, unsigned tag) : impl(attr, tag) {}
132 
133  llvm::PointerIntPair<Attribute, 2> impl;
134  // Note that na is 0 so that we can use PointerIntPair's default
135  // constructor.
136  static constexpr unsigned naTag = 0;
137  static constexpr unsigned resultTag = 1;
138  static constexpr unsigned abortTag = 2;
139  };
140 
141  /// Register a conversion function. A conversion function must be convertible
142  /// to any of the following forms (where `T` is a class derived from `Type`):
143  ///
144  /// * std::optional<Type>(T)
145  /// - This form represents a 1-1 type conversion. It should return nullptr
146  /// or `std::nullopt` to signify failure. If `std::nullopt` is returned,
147  /// the converter is allowed to try another conversion function to
148  /// perform the conversion.
149  /// * std::optional<LogicalResult>(T, SmallVectorImpl<Type> &)
150  /// - This form represents a 1-N type conversion. It should return
151  /// `failure` or `std::nullopt` to signify a failed conversion. If the
152  /// new set of types is empty, the type is removed and any usages of the
153  /// existing value are expected to be removed during conversion. If
154  /// `std::nullopt` is returned, the converter is allowed to try another
155  /// conversion function to perform the conversion.
156  ///
157  /// Note: When attempting to convert a type, e.g. via 'convertType', the
158  /// most recently added conversions will be invoked first.
159  template <typename FnT, typename T = typename llvm::function_traits<
160  std::decay_t<FnT>>::template arg_t<0>>
161  void addConversion(FnT &&callback) {
162  registerConversion(wrapCallback<T>(std::forward<FnT>(callback)));
163  }
164 
165  /// All of the following materializations require function objects that are
166  /// convertible to the following form:
167  /// `Value(OpBuilder &, T, ValueRange, Location)`,
168  /// where `T` is any subclass of `Type`. This function is responsible for
169  /// creating an operation, using the OpBuilder and Location provided, that
170  /// "casts" a range of values into a single value of the given type `T`. It
171  /// must return a Value of the type `T` on success and `nullptr` if
172  /// it failed but other materialization should be attempted. Materialization
173  /// functions must be provided when a type conversion may persist after the
174  /// conversion has finished.
175  ///
176  /// Note: Target materializations may optionally accept an additional Type
177  /// parameter, which is the original type of the SSA value. Furthermore, `T`
178  /// can be a TypeRange; in that case, the function must return a
179  /// SmallVector<Value>.
180 
181  /// This method registers a materialization that will be called when
182  /// converting a replacement value back to its original source type.
183  /// This is used when some uses of the original value persist beyond the main
184  /// conversion.
185  template <typename FnT, typename T = typename llvm::function_traits<
186  std::decay_t<FnT>>::template arg_t<1>>
187  void addSourceMaterialization(FnT &&callback) {
188  sourceMaterializations.emplace_back(
189  wrapSourceMaterialization<T>(std::forward<FnT>(callback)));
190  }
191 
192  /// This method registers a materialization that will be called when
193  /// converting a value to a target type according to a pattern's type
194  /// converter.
195  ///
196  /// Note: Target materializations can optionally inspect the "original"
197  /// type. This type may be different from the type of the input value.
198  /// For example, let's assume that a conversion pattern "P1" replaced an SSA
199  /// value "v1" (type "t1") with "v2" (type "t2"). Then a different conversion
200  /// pattern "P2" matches an op that has "v1" as an operand. Let's furthermore
201  /// assume that "P2" determines that the converted target type of "t1" is
202  /// "t3", which may be different from "t2". In this example, the target
203  /// materialization will be invoked with: outputType = "t3", inputs = "v2",
204  /// originalType = "t1". Note that the original type "t1" cannot be recovered
205  /// from just "t3" and "v2"; that's why the originalType parameter exists.
206  ///
207  /// Note: During a 1:N conversion, the result types can be a TypeRange. In
208  /// that case the materialization produces a SmallVector<Value>.
209  template <typename FnT, typename T = typename llvm::function_traits<
210  std::decay_t<FnT>>::template arg_t<1>>
211  void addTargetMaterialization(FnT &&callback) {
212  targetMaterializations.emplace_back(
213  wrapTargetMaterialization<T>(std::forward<FnT>(callback)));
214  }
215 
216  /// Register a conversion function for attributes within types. Type
217  /// converters may call this function in order to allow hooking into the
218  /// translation of attributes that exist within types. For example, a type
219  /// converter for the `memref` type could use these conversions to convert
220  /// memory spaces or layouts in an extensible way.
221  ///
222  /// The conversion functions take a non-null Type or subclass of Type and a
223  /// non-null Attribute (or subclass of Attribute), and returns a
224  /// `AttributeConversionResult`. This result can either contain an `Attribute`,
225  /// which may be `nullptr`, representing the conversion's success,
226  /// `AttributeConversionResult::na()` (the default empty value), indicating
227  /// that the conversion function did not apply and that further conversion
228  /// functions should be checked, or `AttributeConversionResult::abort()`
229  /// indicating that the conversion process should be aborted.
230  ///
231  /// Registered conversion functions are called in the reverse of the order in
232  /// which they were registered.
233  template <
234  typename FnT,
235  typename T =
236  typename llvm::function_traits<std::decay_t<FnT>>::template arg_t<0>,
237  typename A =
238  typename llvm::function_traits<std::decay_t<FnT>>::template arg_t<1>>
239  void addTypeAttributeConversion(FnT &&callback) {
240  registerTypeAttributeConversion(
241  wrapTypeAttributeConversion<T, A>(std::forward<FnT>(callback)));
242  }
243 
244  /// Convert the given type. This function should return failure if no valid
245  /// conversion exists, success otherwise. If the new set of types is empty,
246  /// the type is removed and any usages of the existing value are expected to
247  /// be removed during conversion.
248  LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) const;
249 
250  /// This hook simplifies defining 1-1 type conversions. This function returns
251  /// the type to convert to on success, and a null type on failure.
252  Type convertType(Type t) const;
253 
254  /// Attempts a 1-1 type conversion, expecting the result type to be
255  /// `TargetType`. Returns the converted type cast to `TargetType` on success,
256  /// and a null type on conversion or cast failure.
257  template <typename TargetType>
258  TargetType convertType(Type t) const {
259  return dyn_cast_or_null<TargetType>(convertType(t));
260  }
261 
262  /// Convert the given set of types, filling 'results' as necessary. This
263  /// returns failure if the conversion of any of the types fails, success
264  /// otherwise.
265  LogicalResult convertTypes(TypeRange types,
266  SmallVectorImpl<Type> &results) const;
267 
268  /// Return true if the given type is legal for this type converter, i.e. the
269  /// type converts to itself.
270  bool isLegal(Type type) const;
271 
272  /// Return true if all of the given types are legal for this type converter.
273  template <typename RangeT>
274  std::enable_if_t<!std::is_convertible<RangeT, Type>::value &&
275  !std::is_convertible<RangeT, Operation *>::value,
276  bool>
277  isLegal(RangeT &&range) const {
278  return llvm::all_of(range, [this](Type type) { return isLegal(type); });
279  }
280  /// Return true if the given operation has legal operand and result types.
281  bool isLegal(Operation *op) const;
282 
283  /// Return true if the types of block arguments within the region are legal.
284  bool isLegal(Region *region) const;
285 
286  /// Return true if the inputs and outputs of the given function type are
287  /// legal.
288  bool isSignatureLegal(FunctionType ty) const;
289 
290  /// This method allows for converting a specific argument of a signature. It
291  /// takes as inputs the original argument's input number and type.
292  /// On success, it populates 'result' with any new mappings.
293  LogicalResult convertSignatureArg(unsigned inputNo, Type type,
294  SignatureConversion &result) const;
295  LogicalResult convertSignatureArgs(TypeRange types,
296  SignatureConversion &result,
297  unsigned origInputOffset = 0) const;
298 
299  /// This function converts the type signature of the given block, by invoking
300  /// 'convertSignatureArg' for each argument. This function should return a
301  /// valid conversion for the signature on success, std::nullopt otherwise.
302  std::optional<SignatureConversion> convertBlockSignature(Block *block) const;
303 
304  /// Materialize a conversion from a set of types into one result type by
305  /// generating a cast sequence of some kind. See the respective
306  /// `add*Materialization` for more information on the context for these
307  /// methods.
309  Type resultType, ValueRange inputs) const;
311  Type resultType, ValueRange inputs,
312  Type originalType = {}) const;
313  SmallVector<Value> materializeTargetConversion(OpBuilder &builder,
314  Location loc,
315  TypeRange resultType,
316  ValueRange inputs,
317  Type originalType = {}) const;
318 
319  /// Convert an attribute present `attr` from within the type `type` using
320  /// the registered conversion functions. If no applicable conversion has been
321  /// registered, return std::nullopt. Note that the empty attribute/`nullptr`
322  /// is a valid return value for this function.
323  std::optional<Attribute> convertTypeAttribute(Type type,
324  Attribute attr) const;
325 
326 private:
327  /// The signature of the callback used to convert a type. If the new set of
328  /// types is empty, the type is removed and any usages of the existing value
329  /// are expected to be removed during conversion.
330  using ConversionCallbackFn = std::function<std::optional<LogicalResult>(
331  Type, SmallVectorImpl<Type> &)>;
332 
333  /// The signature of the callback used to materialize a source conversion.
334  ///
335  /// Arguments: builder, result type, inputs, location
336  using SourceMaterializationCallbackFn =
337  std::function<Value(OpBuilder &, Type, ValueRange, Location)>;
338 
339  /// The signature of the callback used to materialize a target conversion.
340  ///
341  /// Arguments: builder, result types, inputs, location, original type
342  using TargetMaterializationCallbackFn = std::function<SmallVector<Value>(
343  OpBuilder &, TypeRange, ValueRange, Location, Type)>;
344 
345  /// The signature of the callback used to convert a type attribute.
346  using TypeAttributeConversionCallbackFn =
347  std::function<AttributeConversionResult(Type, Attribute)>;
348 
349  /// Generate a wrapper for the given callback. This allows for accepting
350  /// different callback forms, that all compose into a single version.
351  /// With callback of form: `std::optional<Type>(T)`
352  template <typename T, typename FnT>
353  std::enable_if_t<std::is_invocable_v<FnT, T>, ConversionCallbackFn>
354  wrapCallback(FnT &&callback) const {
355  return wrapCallback<T>([callback = std::forward<FnT>(callback)](
356  T type, SmallVectorImpl<Type> &results) {
357  if (std::optional<Type> resultOpt = callback(type)) {
358  bool wasSuccess = static_cast<bool>(*resultOpt);
359  if (wasSuccess)
360  results.push_back(*resultOpt);
361  return std::optional<LogicalResult>(success(wasSuccess));
362  }
363  return std::optional<LogicalResult>();
364  });
365  }
366  /// With callback of form: `std::optional<LogicalResult>(
367  /// T, SmallVectorImpl<Type> &, ArrayRef<Type>)`.
368  template <typename T, typename FnT>
369  std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &>,
370  ConversionCallbackFn>
371  wrapCallback(FnT &&callback) const {
372  return [callback = std::forward<FnT>(callback)](
373  Type type,
374  SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
375  T derivedType = dyn_cast<T>(type);
376  if (!derivedType)
377  return std::nullopt;
378  return callback(derivedType, results);
379  };
380  }
381 
382  /// Register a type conversion.
383  void registerConversion(ConversionCallbackFn callback) {
384  conversions.emplace_back(std::move(callback));
385  cachedDirectConversions.clear();
386  cachedMultiConversions.clear();
387  }
388 
389  /// Generate a wrapper for the given source materialization callback. The
390  /// callback may take any subclass of `Type` and the wrapper will check for
391  /// the target type to be of the expected class before calling the callback.
392  template <typename T, typename FnT>
393  SourceMaterializationCallbackFn
394  wrapSourceMaterialization(FnT &&callback) const {
395  return [callback = std::forward<FnT>(callback)](
396  OpBuilder &builder, Type resultType, ValueRange inputs,
397  Location loc) -> Value {
398  if (T derivedType = dyn_cast<T>(resultType))
399  return callback(builder, derivedType, inputs, loc);
400  return Value();
401  };
402  }
403 
404  /// Generate a wrapper for the given target materialization callback.
405  /// The callback may take any subclass of `Type` and the wrapper will check
406  /// for the target type to be of the expected class before calling the
407  /// callback.
408  ///
409  /// With callback of form:
410  /// - Value(OpBuilder &, T, ValueRange, Location, Type)
411  /// - SmallVector<Value>(OpBuilder &, TypeRange, ValueRange, Location, Type)
412  template <typename T, typename FnT>
413  std::enable_if_t<
414  std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location, Type>,
415  TargetMaterializationCallbackFn>
416  wrapTargetMaterialization(FnT &&callback) const {
417  return [callback = std::forward<FnT>(callback)](
418  OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
419  Location loc, Type originalType) -> SmallVector<Value> {
420  SmallVector<Value> result;
421  if constexpr (std::is_same<T, TypeRange>::value) {
422  // This is a 1:N target materialization. Return the produces values
423  // directly.
424  result = callback(builder, resultTypes, inputs, loc, originalType);
425  } else if constexpr (std::is_assignable<Type, T>::value) {
426  // This is a 1:1 target materialization. Invoke the callback only if a
427  // single SSA value is requested.
428  if (resultTypes.size() == 1) {
429  // Invoke the callback only if the type class of the callback matches
430  // the requested result type.
431  if (T derivedType = dyn_cast<T>(resultTypes.front())) {
432  // 1:1 materializations produce single values, but we store 1:N
433  // target materialization functions in the type converter. Wrap the
434  // result value in a SmallVector<Value>.
435  Value val =
436  callback(builder, derivedType, inputs, loc, originalType);
437  if (val)
438  result.push_back(val);
439  }
440  }
441  } else {
442  static_assert(sizeof(T) == 0, "T must be a Type or a TypeRange");
443  }
444  return result;
445  };
446  }
447  /// With callback of form:
448  /// - Value(OpBuilder &, T, ValueRange, Location)
449  /// - SmallVector<Value>(OpBuilder &, TypeRange, ValueRange, Location)
450  template <typename T, typename FnT>
451  std::enable_if_t<
452  std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location>,
453  TargetMaterializationCallbackFn>
454  wrapTargetMaterialization(FnT &&callback) const {
455  return wrapTargetMaterialization<T>(
456  [callback = std::forward<FnT>(callback)](
457  OpBuilder &builder, T resultTypes, ValueRange inputs, Location loc,
458  Type originalType) {
459  return callback(builder, resultTypes, inputs, loc);
460  });
461  }
462 
463  /// Generate a wrapper for the given memory space conversion callback. The
464  /// callback may take any subclass of `Attribute` and the wrapper will check
465  /// for the target attribute to be of the expected class before calling the
466  /// callback.
467  template <typename T, typename A, typename FnT>
468  TypeAttributeConversionCallbackFn
469  wrapTypeAttributeConversion(FnT &&callback) const {
470  return [callback = std::forward<FnT>(callback)](
471  Type type, Attribute attr) -> AttributeConversionResult {
472  if (T derivedType = dyn_cast<T>(type)) {
473  if (A derivedAttr = dyn_cast_or_null<A>(attr))
474  return callback(derivedType, derivedAttr);
475  }
477  };
478  }
479 
480  /// Register a memory space conversion, clearing caches.
481  void
482  registerTypeAttributeConversion(TypeAttributeConversionCallbackFn callback) {
483  typeAttributeConversions.emplace_back(std::move(callback));
484  // Clear type conversions in case a memory space is lingering inside.
485  cachedDirectConversions.clear();
486  cachedMultiConversions.clear();
487  }
488 
489  /// The set of registered conversion functions.
490  SmallVector<ConversionCallbackFn, 4> conversions;
491 
492  /// The list of registered materialization functions.
493  SmallVector<SourceMaterializationCallbackFn, 2> sourceMaterializations;
494  SmallVector<TargetMaterializationCallbackFn, 2> targetMaterializations;
495 
496  /// The list of registered type attribute conversion functions.
497  SmallVector<TypeAttributeConversionCallbackFn, 2> typeAttributeConversions;
498 
499  /// A set of cached conversions to avoid recomputing in the common case.
500  /// Direct 1-1 conversions are the most common, so this cache stores the
501  /// successful 1-1 conversions as well as all failed conversions.
502  mutable DenseMap<Type, Type> cachedDirectConversions;
503  /// This cache stores the successful 1->N conversions, where N != 1.
504  mutable DenseMap<Type, SmallVector<Type, 2>> cachedMultiConversions;
505  /// A mutex used for cache access
506  mutable llvm::sys::SmartRWMutex<true> cacheMutex;
507 };
508 
509 //===----------------------------------------------------------------------===//
510 // Conversion Patterns
511 //===----------------------------------------------------------------------===//
512 
513 /// Base class for the conversion patterns. This pattern class enables type
514 /// conversions, and other uses specific to the conversion framework. As such,
515 /// patterns of this type can only be used with the 'apply*' methods below.
517 public:
520 
521  /// Hook for derived classes to implement combined matching and rewriting.
522  /// This overload supports only 1:1 replacements. The 1:N overload is called
523  /// by the driver. By default, it calls this 1:1 overload or reports a fatal
524  /// error if 1:N replacements were found.
525  virtual LogicalResult
527  ConversionPatternRewriter &rewriter) const {
528  llvm_unreachable("matchAndRewrite is not implemented");
529  }
530 
531  /// Hook for derived classes to implement combined matching and rewriting.
532  /// This overload supports 1:N replacements.
533  virtual LogicalResult
535  ConversionPatternRewriter &rewriter) const {
536  return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
537  }
538 
539  /// Attempt to match and rewrite the IR root at the specified operation.
540  LogicalResult matchAndRewrite(Operation *op,
541  PatternRewriter &rewriter) const final;
542 
543  /// Return the type converter held by this pattern, or nullptr if the pattern
544  /// does not require type conversion.
545  const TypeConverter *getTypeConverter() const { return typeConverter; }
546 
547  template <typename ConverterTy>
548  std::enable_if_t<std::is_base_of<TypeConverter, ConverterTy>::value,
549  const ConverterTy *>
551  return static_cast<const ConverterTy *>(typeConverter);
552  }
553 
554 protected:
555  /// See `RewritePattern::RewritePattern` for information on the other
556  /// available constructors.
557  using RewritePattern::RewritePattern;
558  /// Construct a conversion pattern with the given converter, and forward the
559  /// remaining arguments to RewritePattern.
560  template <typename... Args>
562  : RewritePattern(std::forward<Args>(args)...),
564 
565  /// Given an array of value ranges, which are the inputs to a 1:N adaptor,
567  /// try to extract the single value of each range to construct the inputs
567  /// for a 1:1 adaptor.
568  ///
569  /// This function produces a fatal error if at least one range has 0 or
570  /// more than 1 value: "pattern 'name' does not support 1:N conversion"
573 
574 protected:
575  /// An optional type converter for use by this pattern.
576  const TypeConverter *typeConverter = nullptr;
577 };
578 
579 /// OpConversionPattern is a wrapper around ConversionPattern that allows for
580 /// matching and rewriting against an instance of a derived operation class as
581 /// opposed to a raw Operation.
582 template <typename SourceOp>
584 public:
585  using OpAdaptor = typename SourceOp::Adaptor;
587  typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
588 
590  : ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
592  PatternBenefit benefit = 1)
593  : ConversionPattern(typeConverter, SourceOp::getOperationName(), benefit,
594  context) {}
595 
596  /// Wrappers around the ConversionPattern methods that pass the derived op
597  /// type.
598  LogicalResult
600  ConversionPatternRewriter &rewriter) const final {
601  auto sourceOp = cast<SourceOp>(op);
602  return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
603  }
604  LogicalResult
606  ConversionPatternRewriter &rewriter) const final {
607  auto sourceOp = cast<SourceOp>(op);
608  return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
609  rewriter);
610  }
611 
612  /// Methods that operate on the SourceOp type. One of these must be
613  /// overridden by the derived pattern class.
614  virtual LogicalResult
615  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
616  ConversionPatternRewriter &rewriter) const {
617  llvm_unreachable("matchAndRewrite is not implemented");
618  }
619  virtual LogicalResult
620  matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
621  ConversionPatternRewriter &rewriter) const {
622  SmallVector<Value> oneToOneOperands =
623  getOneToOneAdaptorOperands(adaptor.getOperands());
624  return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
625  }
626 
627 private:
629 };
630 
631 /// OpInterfaceConversionPattern is a wrapper around ConversionPattern that
632 /// allows for matching and rewriting against an instance of an OpInterface
633 /// class as opposed to a raw Operation.
634 template <typename SourceOp>
636 public:
639  SourceOp::getInterfaceID(), benefit, context) {}
641  MLIRContext *context, PatternBenefit benefit = 1)
643  SourceOp::getInterfaceID(), benefit, context) {}
644 
645  /// Wrappers around the ConversionPattern methods that pass the derived op
646  /// type.
647  LogicalResult
649  ConversionPatternRewriter &rewriter) const final {
650  return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
651  }
652  LogicalResult
654  ConversionPatternRewriter &rewriter) const final {
655  return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
656  }
657 
658  /// Methods that operate on the SourceOp type. One of these must be
659  /// overridden by the derived pattern class.
660  virtual LogicalResult
661  matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
662  ConversionPatternRewriter &rewriter) const {
663  llvm_unreachable("matchAndRewrite is not implemented");
664  }
665  virtual LogicalResult
666  matchAndRewrite(SourceOp op, ArrayRef<ValueRange> operands,
667  ConversionPatternRewriter &rewriter) const {
668  return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
669  }
670 
671 private:
673 };
674 
675 /// OpTraitConversionPattern is a wrapper around ConversionPattern that allows
676 /// for matching and rewriting against instances of an operation that possess a
677 /// given trait.
678 template <template <typename> class TraitType>
680 public:
683  TypeID::get<TraitType>(), benefit, context) {}
685  MLIRContext *context, PatternBenefit benefit = 1)
687  TypeID::get<TraitType>(), benefit, context) {}
688 };
689 
690 /// Generic utility to convert op result types according to type converter
691 /// without knowing exact op type.
692 /// Clones existing op with new result types and returns it.
693 FailureOr<Operation *>
694 convertOpResultTypes(Operation *op, ValueRange operands,
695  const TypeConverter &converter,
696  ConversionPatternRewriter &rewriter);
697 
698 /// Add a pattern to the given pattern list to convert the signature of a
699 /// FunctionOpInterface op with the given type converter. This only supports
700 /// ops which use FunctionType to represent their type.
702  StringRef functionLikeOpName, RewritePatternSet &patterns,
703  const TypeConverter &converter);
704 
705 template <typename FuncOpT>
707  RewritePatternSet &patterns, const TypeConverter &converter) {
708  populateFunctionOpInterfaceTypeConversionPattern(FuncOpT::getOperationName(),
709  patterns, converter);
710 }
711 
713  RewritePatternSet &patterns, const TypeConverter &converter);
714 
715 //===----------------------------------------------------------------------===//
716 // Conversion PatternRewriter
717 //===----------------------------------------------------------------------===//
718 
719 namespace detail {
720 struct ConversionPatternRewriterImpl;
721 } // namespace detail
722 
723 /// This class implements a pattern rewriter for use with ConversionPatterns. It
724 /// extends the base PatternRewriter and provides special conversion specific
725 /// hooks.
727 public:
729 
730  /// Apply a signature conversion to given block. This replaces the block with
731  /// a new block containing the updated signature. The operations of the given
732  /// block are inlined into the newly-created block, which is returned.
733  ///
734  /// If no block argument types are changing, the original block will be
735  /// left in place and returned.
736  ///
737  /// A signature conversion must be provided. (Type converters can construct
738  /// a signature conversion with `convertBlockSignature`.)
739  ///
740  /// Optionally, a type converter can be provided to build materializations.
741  /// Note: If no type converter was provided or the type converter does not
742  /// specify any suitable source/target materialization rules, the dialect
743  /// conversion may fail to legalize unresolved materializations.
744  Block *
747  const TypeConverter *converter = nullptr);
748 
749  /// Apply a signature conversion to each block in the given region. This
750  /// replaces each block with a new block containing the updated signature. If
751  /// an updated signature would match the current signature, the respective
752  /// block is left in place as is. (See `applySignatureConversion` for
753  /// details.) The new entry block of the region is returned.
754  ///
755  /// SignatureConversions are computed with the specified type converter.
756  /// This function returns "failure" if the type converter failed to compute
757  /// a SignatureConversion for at least one block.
758  ///
759  /// Optionally, a special SignatureConversion can be specified for the entry
760  /// block. This is because the types of the entry block arguments are often
761  /// tied semantically to the operation.
762  FailureOr<Block *> convertRegionTypes(
763  Region *region, const TypeConverter &converter,
764  TypeConverter::SignatureConversion *entryConversion = nullptr);
765 
766  /// Replace all the uses of the block argument `from` with value `to`.
768 
769  /// Return the converted value of 'key' with a type defined by the type
770  /// converter of the currently executing pattern. Return nullptr in the case
771  /// of failure, the remapped value otherwise.
773 
774  /// Return the converted values that replace 'keys' with types defined by the
775  /// type converter of the currently executing pattern. Returns failure if the
776  /// remap failed, success otherwise.
777  LogicalResult getRemappedValues(ValueRange keys,
778  SmallVectorImpl<Value> &results);
779 
780  //===--------------------------------------------------------------------===//
781  // PatternRewriter Hooks
782  //===--------------------------------------------------------------------===//
783 
784  /// Indicate that the conversion rewriter can recover from rewrite failure.
785  /// Recovery is supported via rollback, allowing for continued processing of
786  /// patterns even if a failure is encountered during the rewrite step.
787  bool canRecoverFromRewriteFailure() const override { return true; }
788 
789  /// Replace the given operation with the new values. The number of op results
790  /// and replacement values must match. The types may differ: the dialect
791  /// conversion driver will reconcile any surviving type mismatches at the end
792  /// of the conversion process with source materializations. The given
793  /// operation is erased.
794  void replaceOp(Operation *op, ValueRange newValues) override;
795 
796  /// Replace the given operation with the results of the new op. The number of
797  /// op results must match. The types may differ: the dialect conversion
798  /// driver will reconcile any surviving type mismatches at the end of the
799  /// conversion process with source materializations. The original operation
800  /// is erased.
801  void replaceOp(Operation *op, Operation *newOp) override;
802 
803  /// Replace the given operation with the new value ranges. The number of op
804  /// results and value ranges must match. The given operation is erased.
806  SmallVector<SmallVector<Value>> &&newValues);
807  template <typename RangeT = ValueRange>
810  llvm::to_vector_of<SmallVector<Value>>(newValues));
811  }
812  template <typename RangeT>
813  void replaceOpWithMultiple(Operation *op, RangeT &&newValues) {
815  ArrayRef(llvm::to_vector_of<ValueRange>(newValues)));
816  }
817 
818  /// PatternRewriter hook for erasing a dead operation. The uses of this
819  /// operation *must* be made dead by the end of the conversion process,
820  /// otherwise an assert will be issued.
821  void eraseOp(Operation *op) override;
822 
823  /// PatternRewriter hook for erasing all operations in a block. This is not yet
824  /// implemented for dialect conversion.
825  void eraseBlock(Block *block) override;
826 
827  /// PatternRewriter hook for inlining the ops of a block into another block.
828  void inlineBlockBefore(Block *source, Block *dest, Block::iterator before,
829  ValueRange argValues = std::nullopt) override;
831 
832  /// PatternRewriter hook for updating the given operation in-place.
833  /// Note: These methods only track updates to the given operation itself,
834  /// and not nested regions. Updates to regions will still require notification
835  /// through other more specific hooks above.
836  void startOpModification(Operation *op) override;
837 
838  /// PatternRewriter hook for updating the given operation in-place.
839  void finalizeOpModification(Operation *op) override;
840 
841  /// PatternRewriter hook for updating the given operation in-place.
842  void cancelOpModification(Operation *op) override;
843 
844  /// Return a reference to the internal implementation.
846 
847 private:
848  // Allow OperationConverter to construct new rewriters.
849  friend struct OperationConverter;
850 
851  /// Conversion pattern rewriters must not be used outside of dialect
852  /// conversions. They apply some IR rewrites in a delayed fashion and could
853  /// bring the IR into an inconsistent state when used standalone.
855  const ConversionConfig &config);
856 
857  // Hide unsupported pattern rewriter API.
859 
860  std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
861 };
862 
863 //===----------------------------------------------------------------------===//
864 // ConversionTarget
865 //===----------------------------------------------------------------------===//
866 
867 /// This class describes a specific conversion target.
869 public:
870  /// This enumeration corresponds to the specific action to take when
871  /// considering an operation legal for this conversion target.
872  enum class LegalizationAction {
873  /// The target supports this operation.
874  Legal,
875 
876  /// This operation has dynamic legalization constraints that must be checked
877  /// by the target.
878  Dynamic,
879 
880  /// The target explicitly does not support this operation.
881  Illegal,
882  };
883 
884  /// A structure containing additional information describing a specific legal
885  /// operation instance.
886  struct LegalOpDetails {
887  /// A flag that indicates if this operation is 'recursively' legal. This
888  /// means that if an operation is legal, either statically or dynamically,
889  /// all of the operations nested within are also considered legal.
890  bool isRecursivelyLegal = false;
891  };
892 
893  /// The signature of the callback used to determine if an operation is
894  /// dynamically legal on the target.
896  std::function<std::optional<bool>(Operation *)>;
897 
898  ConversionTarget(MLIRContext &ctx) : ctx(ctx) {}
899  virtual ~ConversionTarget() = default;
900 
901  //===--------------------------------------------------------------------===//
902  // Legality Registration
903  //===--------------------------------------------------------------------===//
904 
905  /// Register a legality action for the given operation.
907  template <typename OpT>
909  setOpAction(OperationName(OpT::getOperationName(), &ctx), action);
910  }
911 
912  /// Register the given operations as legal.
915  }
916  template <typename OpT>
917  void addLegalOp() {
918  addLegalOp(OperationName(OpT::getOperationName(), &ctx));
919  }
920  template <typename OpT, typename OpT2, typename... OpTs>
921  void addLegalOp() {
922  addLegalOp<OpT>();
923  addLegalOp<OpT2, OpTs...>();
924  }
925 
926  /// Register the given operation as dynamically legal and set the dynamic
927  /// legalization callback to the one provided.
929  const DynamicLegalityCallbackFn &callback) {
931  setLegalityCallback(op, callback);
932  }
933  template <typename OpT>
935  addDynamicallyLegalOp(OperationName(OpT::getOperationName(), &ctx),
936  callback);
937  }
938  template <typename OpT, typename OpT2, typename... OpTs>
940  addDynamicallyLegalOp<OpT>(callback);
941  addDynamicallyLegalOp<OpT2, OpTs...>(callback);
942  }
943  template <typename OpT, class Callable>
944  std::enable_if_t<!std::is_invocable_v<Callable, Operation *>>
945  addDynamicallyLegalOp(Callable &&callback) {
946  addDynamicallyLegalOp<OpT>(
947  [=](Operation *op) { return callback(cast<OpT>(op)); });
948  }
949 
950  /// Register the given operation as illegal, i.e. this operation is known to
951  /// not be supported by this target.
954  }
955  template <typename OpT>
956  void addIllegalOp() {
957  addIllegalOp(OperationName(OpT::getOperationName(), &ctx));
958  }
959  template <typename OpT, typename OpT2, typename... OpTs>
960  void addIllegalOp() {
961  addIllegalOp<OpT>();
962  addIllegalOp<OpT2, OpTs...>();
963  }
964 
965  /// Mark an operation, that *must* have either been set as `Legal` or
966  /// `DynamicallyLegal`, as being recursively legal. This means that in
967  /// addition to the operation itself, all of the operations nested within are
968  /// also considered legal. An optional dynamic legality callback may be
969  /// provided to mark subsets of legal instances as recursively legal.
971  const DynamicLegalityCallbackFn &callback);
972  template <typename OpT>
974  OperationName opName(OpT::getOperationName(), &ctx);
975  markOpRecursivelyLegal(opName, callback);
976  }
977  template <typename OpT, typename OpT2, typename... OpTs>
979  markOpRecursivelyLegal<OpT>(callback);
980  markOpRecursivelyLegal<OpT2, OpTs...>(callback);
981  }
982  template <typename OpT, class Callable>
983  std::enable_if_t<!std::is_invocable_v<Callable, Operation *>>
984  markOpRecursivelyLegal(Callable &&callback) {
985  markOpRecursivelyLegal<OpT>(
986  [=](Operation *op) { return callback(cast<OpT>(op)); });
987  }
988 
989  /// Register a legality action for the given dialects.
990  void setDialectAction(ArrayRef<StringRef> dialectNames,
991  LegalizationAction action);
992 
993  /// Register the operations of the given dialects as legal.
994  template <typename... Names>
995  void addLegalDialect(StringRef name, Names... names) {
996  SmallVector<StringRef, 2> dialectNames({name, names...});
998  }
999  template <typename... Args>
1001  SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
1003  }
1004 
1005  /// Register the operations of the given dialects as dynamically legal, i.e.
1006  /// requiring custom handling by the callback.
1007  template <typename... Names>
1009  StringRef name, Names... names) {
1010  SmallVector<StringRef, 2> dialectNames({name, names...});
1012  setLegalityCallback(dialectNames, callback);
1013  }
1014  template <typename... Args>
1016  addDynamicallyLegalDialect(std::move(callback),
1017  Args::getDialectNamespace()...);
1018  }
1019 
1020  /// Register unknown operations as dynamically legal. For operations (and
1021  /// dialects) that do not have a set legalization action, treat them as
1022  /// dynamically legal and invoke the given callback.
1024  setLegalityCallback(fn);
1025  }
1026 
1027  /// Register the operations of the given dialects as illegal, i.e.
1028  /// operations of these dialects are not supported by the target.
1029  template <typename... Names>
1030  void addIllegalDialect(StringRef name, Names... names) {
1031  SmallVector<StringRef, 2> dialectNames({name, names...});
1033  }
1034  template <typename... Args>
1036  SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
1038  }
1039 
1040  //===--------------------------------------------------------------------===//
1041  // Legality Querying
1042  //===--------------------------------------------------------------------===//
1043 
1044  /// Get the legality action for the given operation.
1045  std::optional<LegalizationAction> getOpAction(OperationName op) const;
1046 
1047  /// If the given operation instance is legal on this target, a structure
1048  /// containing legality information is returned. If the operation is not
1049  /// legal, std::nullopt is returned. Also returns std::nullopt if the operation's
1050  /// legality wasn't registered by the user or the dynamic legality callbacks
1051  /// returned std::nullopt.
1052  ///
1053  /// Note: Legality is actually a 4-state: Legal(recursive=true),
1054  /// Legal(recursive=false), Illegal or Unknown, where Unknown is treated
1055  /// either as Legal or Illegal depending on context.
1056  std::optional<LegalOpDetails> isLegal(Operation *op) const;
1057 
1058  /// Returns true if the operation instance is illegal on this target. Returns
1059  /// false if the operation is legal, if its legality wasn't registered by the
1060  /// user, or if the dynamic legality callbacks returned std::nullopt.
1061  bool isIllegal(Operation *op) const;
1062 
1063 private:
1064  /// Set the dynamic legality callback for the given operation.
1065  void setLegalityCallback(OperationName name,
1066  const DynamicLegalityCallbackFn &callback);
1067 
1068  /// Set the dynamic legality callback for the given dialects.
1069  void setLegalityCallback(ArrayRef<StringRef> dialects,
1070  const DynamicLegalityCallbackFn &callback);
1071 
1072  /// Set the dynamic legality callback for the unknown ops.
1073  void setLegalityCallback(const DynamicLegalityCallbackFn &callback);
1074 
1075  /// The set of information that configures the legalization of an operation.
1076  struct LegalizationInfo {
1077  /// The legality action this operation was given.
1079 
1080  /// If some legal instances of this operation may also be recursively legal.
1081  bool isRecursivelyLegal = false;
1082 
1083  /// The legality callback if this operation is dynamically legal.
1084  DynamicLegalityCallbackFn legalityFn;
1085  };
1086 
1087  /// Get the legalization information for the given operation.
1088  std::optional<LegalizationInfo> getOpInfo(OperationName op) const;
1089 
1090  /// A deterministic mapping of operation name and its respective legality
1091  /// information.
1092  llvm::MapVector<OperationName, LegalizationInfo> legalOperations;
1093 
1094  /// A set of legality callbacks for given operation names that are used to
1095  /// check if an operation instance is recursively legal.
1096  DenseMap<OperationName, DynamicLegalityCallbackFn> opRecursiveLegalityFns;
1097 
1098  /// A mapping of dialect name to the specific legality action to take.
1099  /// (Unlike `legalOperations`, StringMap iteration order is not deterministic.)
1100  llvm::StringMap<LegalizationAction> legalDialects;
1101 
1102  /// A set of dynamic legality callbacks for given dialect names.
1103  llvm::StringMap<DynamicLegalityCallbackFn> dialectLegalityFns;
1104 
1105  /// An optional legality callback for unknown operations.
1106  DynamicLegalityCallbackFn unknownLegalityFn;
1107 
1108  /// The current context this target applies to.
1109  MLIRContext &ctx;
1110 };
1111 
1112 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
1113 //===----------------------------------------------------------------------===//
1114 // PDL Configuration
1115 //===----------------------------------------------------------------------===//
1116 
1117 /// A PDL configuration that is used to support dialect conversion
1118 /// functionality.
1120  : public PDLPatternConfigBase<PDLConversionConfig> {
1121 public:
1122  PDLConversionConfig(const TypeConverter *converter) : converter(converter) {}
1123  ~PDLConversionConfig() final = default;
1124 
1125  /// Return the type converter used by this configuration, which may be nullptr
1126  /// if no type conversions are expected.
1127  const TypeConverter *getTypeConverter() const { return converter; }
1128 
1129  /// Hooks that are invoked at the beginning and end of a rewrite of a matched
1130  /// pattern.
1131  void notifyRewriteBegin(PatternRewriter &rewriter) final;
1132  void notifyRewriteEnd(PatternRewriter &rewriter) final;
1133 
1134 private:
1135  /// An optional type converter to use for the pattern.
1136  const TypeConverter *converter;
1137 };
1138 
1139 /// Register the dialect conversion PDL functions with the given pattern set.
1140 void registerConversionPDLFunctions(RewritePatternSet &patterns);
1141 
1142 #else
1143 
1144 // Stubs for when PDL in rewriting is not enabled.
1145 
1146 inline void registerConversionPDLFunctions(RewritePatternSet &patterns) {}
1147 
1148 class PDLConversionConfig final {
1149 public:
1150  PDLConversionConfig(const TypeConverter * /*converter*/) {}
1151 };
1152 
1153 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
1154 
1155 //===----------------------------------------------------------------------===//
1156 // ConversionConfig
1157 //===----------------------------------------------------------------------===//
1158 
1159 /// Dialect conversion configuration.
1161  /// An optional callback used to notify about match failure diagnostics during
1162  /// the conversion. Diagnostics reported to this callback may only be
1163  /// available in debug mode.
1165 
1166  /// Partial conversion only. All operations that are found not to be
1167  /// legalizable are placed in this set. (Note that if there is an op
1168  /// explicitly marked as illegal, the conversion terminates and the set will
1169  /// not necessarily be complete.)
1171 
1172  /// Analysis conversion only. All operations that are found to be legalizable
1173  /// are placed in this set. Note that no actual rewrites are applied to the
1174  /// IR during an analysis conversion and only pre-existing operations are
1175  /// added to the set.
1177 
1178  /// An optional listener that is notified about all IR modifications in case
1179  /// dialect conversion succeeds. If the dialect conversion fails and no IR
1180  /// modifications are visible (i.e., they were all rolled back), or if the
1181  /// dialect conversion is an "analysis conversion", no notifications are
1182  /// sent (apart from `notifyPatternBegin`/`notifyPatternEnd`).
1183  ///
1184  /// Note: Notifications are sent in a delayed fashion, when the dialect
1185  /// conversion is guaranteed to succeed. At that point, some IR modifications
1186  /// may already have been materialized. Consequently, operations/blocks that
1187  /// are passed to listener callbacks should not be accessed. (Ops/blocks are
1188  /// guaranteed to be valid pointers and accessing op names is allowed. But
1189  /// there are no guarantees about the state of ops/blocks at the time that a
1190  /// callback is triggered.)
1191  ///
1192  /// Example: Consider a dialect conversion in which a new op ("test.foo") is
1193  /// created and inserted, and later moved to another block. (Moving ops also
1194  /// triggers "notifyOperationInserted".)
1195  ///
1196  /// (1) notifyOperationInserted: "test.foo" (into block "b1")
1197  /// (2) notifyOperationInserted: "test.foo" (moved to another block "b2")
1198  ///
1199  /// When querying "op->getBlock()" during the first "notifyOperationInserted",
1200  /// "b2" would be returned because "moving an op" is a kind of rewrite that is
1201  /// immediately performed by the dialect conversion (and rolled back upon
1202  /// failure).
1203  //
1204  // Note: When receiving a "notifyBlockInserted"/"notifyOperationInserted"
1205  // callback, the previous region/block is provided to the callback, but not
1206  // the iterator pointing to the exact location within the region/block. That
1207  // is because these notifications are sent with a delay (after the IR has
1208  // already been modified) and iterators into past IR state cannot be
1209  // represented at the moment.
1211 
1212  /// If set to "true", the dialect conversion attempts to build source/target
1213  /// materializations through the type converter API in lieu of
1214  /// "builtin.unrealized_conversion_cast" ops. The conversion process fails if
1215  /// at least one materialization could not be built.
1216  ///
1217  /// If set to "false", the dialect conversion does not build any custom
1218  /// materializations and instead inserts "builtin.unrealized_conversion_cast"
1219  /// ops to ensure that the resulting IR is valid.
1221 
1222  /// If set to "true", pattern rollback is allowed. The conversion driver
1223  /// rolls back IR modifications in the following situations.
1224  ///
1225  /// 1. Pattern implementation returns "failure" after modifying IR.
1226  /// 2. Pattern produces IR (in-place modification or new IR) that is illegal
1227  /// and cannot be legalized by subsequent foldings / pattern applications.
1228  ///
1229  /// If set to "false", the conversion driver will produce an LLVM fatal error
1230  /// instead of rolling back IR modifications. Moreover, in case of a failed
1231  /// conversion, the original IR is not restored. The resulting IR may be a
1232  /// mix of original and rewritten IR. (Same as a failed greedy pattern
1233  /// rewrite.)
1234  ///
1235  /// Note: This flag was added in preparation of the One-Shot Dialect
1236  /// Conversion refactoring, which will remove the ability to roll back IR
1237  /// modifications from the conversion driver. Use this flag to ensure that
1238  /// your patterns do not trigger any IR rollbacks. For details, see
1239  /// https://discourse.llvm.org/t/rfc-a-new-one-shot-dialect-conversion-driver/79083.
1241 };
1242 
1243 //===----------------------------------------------------------------------===//
1244 // Reconcile Unrealized Casts
1245 //===----------------------------------------------------------------------===//
1246 
1247 /// Try to reconcile all given UnrealizedConversionCastOps and store the
1248 /// left-over ops in `remainingCastOps` (if provided).
1249 ///
1250 /// This function processes cast ops in a worklist-driven fashion. For each
1251 /// cast op, if the chain of input casts eventually reaches a cast op where the
1252 /// input types match the output types of the matched op, replace the matched
1253 /// op with the inputs.
1254 ///
1255 /// Example:
1256 /// %1 = unrealized_conversion_cast %0 : !A to !B
1257 /// %2 = unrealized_conversion_cast %1 : !B to !C
1258 /// %3 = unrealized_conversion_cast %2 : !C to !A
1259 ///
1260 /// In the above example, %0 can be used instead of %3 and all cast ops are
1261 /// folded away.
1264  SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps = nullptr);
1265 
1266 //===----------------------------------------------------------------------===//
1267 // Op Conversion Entry Points
1268 //===----------------------------------------------------------------------===//
1269 
1270 /// Below we define several entry points for operation conversion. It is
1271 /// important to note that the patterns provided to the conversion framework may
1272 /// have additional constraints. See the `PatternRewriter Hooks` section of the
1273 /// ConversionPatternRewriter, to see what additional constraints are imposed on
1274 /// the use of the PatternRewriter.
1275 
1276 /// Apply a partial conversion on the given operations and all nested
1277 /// operations. This method converts as many operations to the target as
1278 /// possible, ignoring operations that failed to legalize. This method only
1279 /// returns failure if there are ops explicitly marked as illegal.
1280 LogicalResult
1282  const ConversionTarget &target,
1285 LogicalResult
1289 
1290 /// Apply a complete conversion on the given operations, and all nested
1291 /// operations. This method returns failure if the conversion of any operation
1292 /// fails, or if there are unreachable blocks in any of the regions nested
1293 /// within 'ops'.
1294 LogicalResult applyFullConversion(ArrayRef<Operation *> ops,
1295  const ConversionTarget &target,
1298 LogicalResult applyFullConversion(Operation *op, const ConversionTarget &target,
1301 
1302 /// Apply an analysis conversion on the given operations, and all nested
1303 /// operations. This method analyzes which operations would be successfully
1304 /// converted to the target if a conversion was applied. All operations that
1305 /// were found to be legalizable to the given 'target' are placed within the
1306 /// provided 'config.legalizableOps' set; note that no actual rewrites are
1307 /// applied to the operations on success. This method only returns failure if
1308 /// there are unreachable blocks in any of the regions nested within 'ops'.
1309 LogicalResult
1313 LogicalResult
1317 } // namespace mlir
1318 
1319 #endif // MLIR_TRANSFORMS_DIALECTCONVERSION_H_
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:295
Block represents an ordered list of Operations.
Definition: Block.h:33
OpListType::iterator iterator
Definition: Block.h:140
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
LogicalResult getRemappedValues(ValueRange keys, SmallVectorImpl< Value > &results)
Return the converted values that replace 'keys' with types defined by the type converter of the curre...
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt) override
PatternRewriter hook for inlining the ops of a block into another block.
void replaceOpWithMultiple(Operation *op, SmallVector< SmallVector< Value >> &&newValues)
Replace the given operation with the new value ranges.
Block * applySignatureConversion(Block *block, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter=nullptr)
Apply a signature conversion to given block.
void replaceOpWithMultiple(Operation *op, ArrayRef< RangeT > newValues)
void startOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void replaceOpWithMultiple(Operation *op, RangeT &&newValues)
detail::ConversionPatternRewriterImpl & getImpl()
Return a reference to the internal implementation.
void eraseBlock(Block *block) override
PatternRewriter hook for erasing all operations in a block.
void cancelOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
bool canRecoverFromRewriteFailure() const override
Indicate that the conversion rewriter can recover from rewrite failure.
Value getRemappedValue(Value key)
Return the converted value of 'key' with a type defined by the type converter of the currently execut...
void finalizeOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void replaceUsesOfBlockArgument(BlockArgument from, Value to)
Replace all the uses of the block argument `from` with value `to`.
Base class for the conversion patterns.
SmallVector< Value > getOneToOneAdaptorOperands(ArrayRef< ValueRange > operands) const
Given an array of value ranges, which are the inputs to a 1:N adaptor, try to extract the single valu...
const TypeConverter * typeConverter
An optional type converter for use by this pattern.
ConversionPattern(const TypeConverter &typeConverter, Args &&...args)
Construct a conversion pattern with the given converter, and forward the remaining arguments to Rewri...
const TypeConverter * getTypeConverter() const
Return the type converter held by this pattern, or nullptr if the pattern does not require type conve...
std::enable_if_t< std::is_base_of< TypeConverter, ConverterTy >::value, const ConverterTy * > getTypeConverter() const
virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement combined matching and rewriting.
virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef< ValueRange > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement combined matching and rewriting.
This class describes a specific conversion target.
void setDialectAction(ArrayRef< StringRef > dialectNames, LegalizationAction action)
Register a legality action for the given dialects.
void setOpAction(LegalizationAction action)
void addLegalOp(OperationName op)
Register the given operations as legal.
void addDynamicallyLegalDialect(DynamicLegalityCallbackFn callback)
void setOpAction(OperationName op, LegalizationAction action)
Register a legality action for the given operation.
void addDynamicallyLegalDialect(const DynamicLegalityCallbackFn &callback, StringRef name, Names... names)
Register the operations of the given dialects as dynamically legal, i.e.
void addDynamicallyLegalOp(const DynamicLegalityCallbackFn &callback)
void addLegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as legal.
std::optional< LegalOpDetails > isLegal(Operation *op) const
If the given operation instance is legal on this target, a structure containing legality information ...
std::enable_if_t<!std::is_invocable_v< Callable, Operation * > > markOpRecursivelyLegal(Callable &&callback)
void markUnknownOpDynamicallyLegal(const DynamicLegalityCallbackFn &fn)
Register unknown operations as dynamically legal.
std::optional< LegalizationAction > getOpAction(OperationName op) const
Get the legality action for the given operation.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the on...
void addIllegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as illegal, i.e.
std::enable_if_t<!std::is_invocable_v< Callable, Operation * > > addDynamicallyLegalOp(Callable &&callback)
void addDynamicallyLegalOp(const DynamicLegalityCallbackFn &callback)
void markOpRecursivelyLegal(const DynamicLegalityCallbackFn &callback={})
void addIllegalOp(OperationName op)
Register the given operation as illegal, i.e.
LegalizationAction
This enumeration corresponds to the specific action to take when considering an operation legal for t...
@ Illegal
The target explicitly does not support this operation.
@ Dynamic
This operation has dynamic legalization constraints that must be checked by the target.
@ Legal
The target supports this operation.
ConversionTarget(MLIRContext &ctx)
void markOpRecursivelyLegal(OperationName name, const DynamicLegalityCallbackFn &callback)
Mark an operation, that must have either been set as Legal or DynamicallyLegal, as being recursively ...
void markOpRecursivelyLegal(const DynamicLegalityCallbackFn &callback={})
std::function< std::optional< bool >(Operation *)> DynamicLegalityCallbackFn
The signature of the callback used to determine if an operation is dynamically legal on the target.
bool isIllegal(Operation *op) const
Returns true if the operation instance is illegal on this target.
virtual ~ConversionTarget()=default
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:204
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
Definition: Builders.h:313
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
LogicalResult matchAndRewrite(Operation *op, ArrayRef< ValueRange > operands, ConversionPatternRewriter &rewriter) const final
Hook for derived classes to implement combined matching and rewriting.
typename SourceOp::Adaptor OpAdaptor
virtual LogicalResult matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
typename SourceOp::template GenericAdaptor< ArrayRef< ValueRange > > OneToNOpAdaptor
OpConversionPattern(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit=1)
virtual LogicalResult matchAndRewrite(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const
Methods that operate on the SourceOp type.
LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Wrappers around the ConversionPattern methods that pass the derived op type.
OpInterfaceConversionPattern is a wrapper around ConversionPattern that allows for matching and rewri...
LogicalResult matchAndRewrite(Operation *op, ArrayRef< ValueRange > operands, ConversionPatternRewriter &rewriter) const final
Hook for derived classes to implement combined matching and rewriting.
virtual LogicalResult matchAndRewrite(SourceOp op, ArrayRef< ValueRange > operands, ConversionPatternRewriter &rewriter) const
LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Wrappers around the ConversionPattern methods that pass the derived op type.
OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
OpInterfaceConversionPattern(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit=1)
virtual LogicalResult matchAndRewrite(SourceOp op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Methods that operate on the SourceOp type.
OpTraitConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting...
OpTraitConversionPattern(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit=1)
OpTraitConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
A PDL configuration that is used to support dialect conversion functionality.
void notifyRewriteEnd(PatternRewriter &rewriter) final
void notifyRewriteBegin(PatternRewriter &rewriter) final
Hooks that are invoked at the beginning and end of a rewrite of a matched pattern.
const TypeConverter * getTypeConverter() const
Return the type converter used by this configuration, which may be nullptr if no type conversions are...
PDLConversionConfig(const TypeConverter *converter)
~PDLConversionConfig() final=default
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:749
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:73
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:238
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
The general result of a type attribute conversion callback, allowing for early termination.
static AttributeConversionResult result(Attribute attr)
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
std::optional< InputMapping > getInputMapping(unsigned input) const
Get the input mapping for the given argument.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
void remapInput(unsigned origInputNo, ArrayRef< Value > replacements)
Remap an input of the original signature to replacement values.
Type conversion class.
std::optional< Attribute > convertTypeAttribute(Type type, Attribute attr) const
Convert the attribute `attr`, present within the type `type`, using the registered conversion functions...
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
Materialize a conversion from a set of types into one result type by generating a cast sequence of so...
void addConversion(FnT &&callback)
Register a conversion function.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
LogicalResult convertSignatureArgs(TypeRange types, SignatureConversion &result, unsigned origInputOffset=0) const
TypeConverter()=default
std::enable_if_t<!std::is_convertible< RangeT, Type >::value &&!std::is_convertible< RangeT, Operation * >::value, bool > isLegal(RangeT &&range) const
Return true if all of the given types are legal for this type converter.
TargetType convertType(Type t) const
Attempts a 1-1 type conversion, expecting the result type to be TargetType.
LogicalResult convertSignatureArg(unsigned inputNo, Type type, SignatureConversion &result) const
This method allows for converting a specific argument of a signature.
TypeConverter(const TypeConverter &other)
TypeConverter & operator=(const TypeConverter &other)
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
std::optional< SignatureConversion > convertBlockSignature(Block *block) const
This function converts the type signature of the given block, by invoking 'convertSignatureArg' for e...
void addSourceMaterialization(FnT &&callback)
All of the following materializations require function objects that are convertible to the following ...
virtual ~TypeConverter()=default
void addTypeAttributeConversion(FnT &&callback)
Register a conversion function for attributes within types.
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a value to a target type ...
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
bool isSignatureLegal(FunctionType ty) const
Return true if the inputs and outputs of the given function type are legal.
Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs, Type originalType={}) const
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:107
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
@ Type
An inlay hint for a type annotation.
Include the generated interface declarations.
void populateFunctionOpInterfaceTypeConversionPattern(StringRef functionLikeOpName, RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to convert the signature of a FunctionOpInterface op with the...
void populateAnyFunctionOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
FailureOr< Operation * > convertOpResultTypes(Operation *op, ValueRange operands, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Generic utility to convert op result types according to type converter without knowing exact op type.
const FrozenRewritePatternSet & patterns
LogicalResult applyAnalysisConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply an analysis conversion on the given operations, and all nested operations.
void reconcileUnrealizedCasts(ArrayRef< UnrealizedConversionCastOp > castOps, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps=nullptr)
Try to reconcile all given UnrealizedConversionCastOps and store the left-over ops in remainingCastOp...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void registerConversionPDLFunctions(RewritePatternSet &patterns)
Register the dialect conversion PDL functions with the given pattern set.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
Dialect conversion configuration.
bool allowPatternRollback
If set to "true", pattern rollback is allowed.
RewriterBase::Listener * listener
An optional listener that is notified about all IR modifications in case dialect conversion succeeds.
function_ref< void(Diagnostic &)> notifyCallback
An optional callback used to notify about match failure diagnostics during the conversion.
DenseSet< Operation * > * legalizableOps
Analysis conversion only.
DenseSet< Operation * > * unlegalizedOps
Partial conversion only.
bool buildMaterializations
If set to "true", the dialect conversion attempts to build source/target materializations through the...
A structure containing additional information describing a specific legal operation instance.
bool isRecursivelyLegal
A flag that indicates if this operation is 'recursively' legal.
This class acts as a special tag that makes the desire to match any operation that implements a given...
Definition: PatternMatch.h:164
This class acts as a special tag that makes the desire to match any operation that implements a given...
Definition: PatternMatch.h:169
This struct represents a range of new types or a range of values that remaps an existing signature in...
bool replacedWithValues() const
Return "true" if this input was replaced with one or multiple values.