MLIR  20.0.0git
DialectConversion.h
Go to the documentation of this file.
1 //===- DialectConversion.h - MLIR dialect conversion pass -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file declares a generic pass for converting between MLIR dialects.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_TRANSFORMS_DIALECTCONVERSION_H_
14 #define MLIR_TRANSFORMS_DIALECTCONVERSION_H_
15 
16 #include "mlir/Config/mlir-config.h"
18 #include "llvm/ADT/MapVector.h"
19 #include "llvm/ADT/StringMap.h"
20 #include <type_traits>
21 
22 namespace mlir {
23 
24 // Forward declarations.
25 class Attribute;
26 class Block;
27 struct ConversionConfig;
28 class ConversionPatternRewriter;
29 class MLIRContext;
30 class Operation;
31 struct OperationConverter;
32 class Type;
33 class Value;
34 
35 //===----------------------------------------------------------------------===//
36 // Type Conversion
37 //===----------------------------------------------------------------------===//
38 
39 /// Type conversion class. Specific conversions and materializations can be
40 /// registered using addConversion and addMaterialization, respectively.
42 public:
43  virtual ~TypeConverter() = default;
44  TypeConverter() = default;
45  // Copy the registered conversions, but not the caches
47  : conversions(other.conversions),
48  argumentMaterializations(other.argumentMaterializations),
49  sourceMaterializations(other.sourceMaterializations),
50  targetMaterializations(other.targetMaterializations),
51  typeAttributeConversions(other.typeAttributeConversions) {}
53  conversions = other.conversions;
54  argumentMaterializations = other.argumentMaterializations;
55  sourceMaterializations = other.sourceMaterializations;
56  targetMaterializations = other.targetMaterializations;
57  typeAttributeConversions = other.typeAttributeConversions;
58  return *this;
59  }
60 
61  /// This class provides all of the information necessary to convert a type
62  /// signature.
64  public:
65  SignatureConversion(unsigned numOrigInputs)
66  : remappedInputs(numOrigInputs) {} // one (initially empty) slot per original input
67 
68  /// This struct represents a range of new types or a single value that
69  /// remaps an existing signature input.
70  struct InputMapping {
71  size_t inputNo, size;
73  };
74 
75  /// Return the argument types of the new signature (a view of `argTypes`).
76  ArrayRef<Type> getConvertedTypes() const { return argTypes; }
77 
78  /// Get the input mapping recorded for the original argument `input`, if any.
79  std::optional<InputMapping> getInputMapping(unsigned input) const {
80  return remappedInputs[input]; // `input` indexes the per-original-input slots
81  }
82 
83  //===------------------------------------------------------------------===//
84  // Conversion Hooks
85  //===------------------------------------------------------------------===//
86 
87  /// Remap an input of the original signature with a new set of types. The
88  /// new types are appended to the new signature conversion.
89  void addInputs(unsigned origInputNo, ArrayRef<Type> types);
90 
91  /// Append new input types to the signature conversion, this should only be
92  /// used if the new types are not intended to remap an existing input.
93  void addInputs(ArrayRef<Type> types);
94 
95  /// Remap an input of the original signature to another `replacement`
96  /// value. This drops the original argument.
97  void remapInput(unsigned origInputNo, Value replacement);
98 
99  private:
100  /// Remap an input of the original signature with a range of types in the
101  /// new signature.
102  void remapInput(unsigned origInputNo, unsigned newInputNo,
103  unsigned newInputCount = 1);
104 
105  /// The remapping information for each of the original arguments.
106  SmallVector<std::optional<InputMapping>, 4> remappedInputs;
107 
108  /// The set of new argument types.
109  SmallVector<Type, 4> argTypes;
110  };
111 
112  /// The general result of a type attribute conversion callback, allowing
113  /// for early termination. The default constructor creates the na case.
115  public:
116  constexpr AttributeConversionResult() : impl() {}
117  AttributeConversionResult(Attribute attr) : impl(attr, resultTag) {}
118 
120  static AttributeConversionResult na();
122 
123  bool hasResult() const;
124  bool isNa() const;
125  bool isAbort() const;
126 
127  Attribute getResult() const;
128 
129  private:
130  AttributeConversionResult(Attribute attr, unsigned tag) : impl(attr, tag) {}
131 
132  llvm::PointerIntPair<Attribute, 2> impl;
133  // Note that na is 0 so that we can use PointerIntPair's default
134  // constructor.
135  static constexpr unsigned naTag = 0;
136  static constexpr unsigned resultTag = 1;
137  static constexpr unsigned abortTag = 2;
138  };
139 
140  /// Register a conversion function. A conversion function must be convertible
141  /// to one of the following forms (where `T` is a class derived from `Type`):
142  ///
143  /// * std::optional<Type>(T)
144  /// - This form represents a 1-1 type conversion. It should return nullptr
145  /// or `std::nullopt` to signify failure. If `std::nullopt` is returned,
146  /// the converter is allowed to try another conversion function to
147  /// perform the conversion.
148  /// * std::optional<LogicalResult>(T, SmallVectorImpl<Type> &)
149  /// - This form represents a 1-N type conversion. It should return
150  /// `failure` or `std::nullopt` to signify a failed conversion. If the
151  /// new set of types is empty, the type is removed and any usages of the
152  /// existing value are expected to be removed during conversion. If
153  /// `std::nullopt` is returned, the converter is allowed to try another
154  /// conversion function to perform the conversion.
155  ///
156  /// Note: When attempting to convert a type, e.g. via 'convertType', the
157  /// most recently added conversions will be invoked first.
158  template <typename FnT, typename T = typename llvm::function_traits<
159  std::decay_t<FnT>>::template arg_t<0>> // T: type of the callback's first parameter
160  void addConversion(FnT &&callback) {
161  registerConversion(wrapCallback<T>(std::forward<FnT>(callback)));
162  }
163 
164  /// All of the following materializations require function objects that are
165  /// convertible to the following form:
166  /// `std::optional<Value>(OpBuilder &, T, ValueRange, Location)`,
167  /// where `T` is any subclass of `Type`. This function is responsible for
168  /// creating an operation, using the OpBuilder and Location provided, that
169  /// "casts" a range of values into a single value of the given type `T`. It
170  /// must return a Value of the type `T` on success, an `std::nullopt` if
171  /// it failed but other materializations can be attempted, and `nullptr` on
172  /// unrecoverable failure. Materialization functions must be provided when a
173  /// type conversion may persist after the conversion has finished.
174  ///
175  /// Note: Target materializations may optionally accept an additional Type
176  /// parameter, which is the original type of the SSA value.
177 
178  /// Register a materialization that will be called when converting
179  /// (potentially multiple) block arguments that were the result of
180  /// a signature conversion of a single block argument, to a single SSA value
181  /// with the old block argument type.
182  template <typename FnT, typename T = typename llvm::function_traits<
183  std::decay_t<FnT>>::template arg_t<1>> // T: type of the callback's second parameter
184  void addArgumentMaterialization(FnT &&callback) {
185  argumentMaterializations.emplace_back(
186  wrapMaterialization<T>(std::forward<FnT>(callback)));
187  }
188 
189  /// Register a materialization that will be called when
190  /// converting a legal replacement value back to an illegal source type.
191  /// This is used when some uses of the original, illegal value must persist
192  /// beyond the main conversion.
193  template <typename FnT, typename T = typename llvm::function_traits<
194  std::decay_t<FnT>>::template arg_t<1>> // T: type of the callback's second parameter
195  void addSourceMaterialization(FnT &&callback) {
196  sourceMaterializations.emplace_back(
197  wrapMaterialization<T>(std::forward<FnT>(callback)));
198  }
199 
200  /// This method registers a materialization that will be called when
201  /// converting an illegal (source) value to a legal (target) type.
202  ///
203  /// Note: For target materializations, users can optionally take the original
204  /// type. This type may be different from the type of the input. For example,
205  /// let's assume that a conversion pattern "P1" replaced an SSA value "v1"
206  /// (type "t1") with "v2" (type "t2"). Then a different conversion pattern
207  /// "P2" matches an op that has "v1" as an operand. Let's furthermore assume
208  /// that "P2" determines that the legalized type of "t1" is "t3", which may
209  /// be different from "t2". In this example, the target materialization
210  /// will be invoked with: outputType = "t3", inputs = "v2",
211  /// originalType = "t1". Note that the original type "t1" cannot be recovered
212  /// from just "t3" and "v2"; that's why the originalType parameter exists.
213  template <typename FnT, typename T = typename llvm::function_traits<
214  std::decay_t<FnT>>::template arg_t<1>>
215  void addTargetMaterialization(FnT &&callback) {
216  targetMaterializations.emplace_back(
217  wrapTargetMaterialization<T>(std::forward<FnT>(callback)));
218  }
219 
220  /// Register a conversion function for attributes within types. Type
221  /// converters may call this function in order to allow hooking into the
222  /// translation of attributes that exist within types. For example, a type
223  /// converter for the `memref` type could use these conversions to convert
224  /// memory spaces or layouts in an extensible way.
225  ///
226  /// The conversion functions take a non-null Type or subclass of Type and a
227  /// non-null Attribute (or subclass of Attribute), and return an
228  /// `AttributeConversionResult`. This result can either contain an `Attribute`,
229  /// which may be `nullptr`, representing the conversion's success,
230  /// `AttributeConversionResult::na()` (the default empty value), indicating
231  /// that the conversion function did not apply and that further conversion
232  /// functions should be checked, or `AttributeConversionResult::abort()`
233  /// indicating that the conversion process should be aborted.
234  ///
235  /// Registered conversion functions are called in the reverse of the order in
236  /// which they were registered.
237  template <
238  typename FnT,
239  typename T =
240  typename llvm::function_traits<std::decay_t<FnT>>::template arg_t<0>,
241  typename A =
242  typename llvm::function_traits<std::decay_t<FnT>>::template arg_t<1>>
243  void addTypeAttributeConversion(FnT &&callback) {
244  registerTypeAttributeConversion(
245  wrapTypeAttributeConversion<T, A>(std::forward<FnT>(callback)));
246  }
247 
248  /// Convert the given type. This function should return failure if no valid
249  /// conversion exists, success otherwise. If the new set of types is empty,
250  /// the type is removed and any usages of the existing value are expected to
251  /// be removed during conversion.
252  LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) const;
253 
254  /// This hook simplifies defining 1-1 type conversions. This function returns
255  /// the type to convert to on success, and a null type on failure.
256  Type convertType(Type t) const;
257 
258  /// Attempts a 1-1 type conversion, expecting the result type to be
259  /// `TargetType`. Returns the converted type cast to `TargetType` on success,
260  /// and a null type on conversion or cast failure.
261  template <typename TargetType>
262  TargetType convertType(Type t) const {
263  return dyn_cast_or_null<TargetType>(convertType(t)); // tolerates a null (failed) conversion
264  }
265 
266  /// Convert the given set of types, filling 'results' as necessary. This
267  /// returns failure if the conversion of any of the types fails, success
268  /// otherwise.
269  LogicalResult convertTypes(TypeRange types,
270  SmallVectorImpl<Type> &results) const;
271 
272  /// Return true if the given type is legal for this type converter, i.e. the
273  /// type converts to itself.
274  bool isLegal(Type type) const;
275 
276  /// Return true if all of the given types are legal for this type converter.
277  template <typename RangeT>
278  std::enable_if_t<!std::is_convertible<RangeT, Type>::value && // not a single Type
279  !std::is_convertible<RangeT, Operation *>::value, // not an Operation *
280  bool>
281  isLegal(RangeT &&range) const {
282  return llvm::all_of(range, [this](Type type) { return isLegal(type); });
283  }
284  /// Return true if the given operation has legal operand and result types.
285  bool isLegal(Operation *op) const;
286 
287  /// Return true if the types of block arguments within the region are legal.
288  bool isLegal(Region *region) const;
289 
290  /// Return true if the inputs and outputs of the given function type are
291  /// legal.
292  bool isSignatureLegal(FunctionType ty) const;
293 
294  /// This method allows for converting a specific argument of a signature. It
295  /// takes as inputs the original argument's input number and type.
296  /// On success, it populates 'result' with any new mappings.
297  LogicalResult convertSignatureArg(unsigned inputNo, Type type,
298  SignatureConversion &result) const;
299  LogicalResult convertSignatureArgs(TypeRange types,
300  SignatureConversion &result,
301  unsigned origInputOffset = 0) const;
302 
303  /// This function converts the type signature of the given block, by invoking
304  /// 'convertSignatureArg' for each argument. This function should return a
305  /// valid conversion for the signature on success, std::nullopt otherwise.
306  std::optional<SignatureConversion> convertBlockSignature(Block *block) const;
307 
308  /// Materialize a conversion from a set of types into one result type by
309  /// generating a cast sequence of some kind. See the respective
310  /// `add*Materialization` for more information on the context for these
311  /// methods.
313  Type resultType, ValueRange inputs) const;
315  Type resultType, ValueRange inputs) const;
317  Type resultType, ValueRange inputs,
318  Type originalType = {}) const;
319 
320  /// Convert an attribute `attr` that is present within the type `type` using
321  /// the registered conversion functions. If no applicable conversion has been
322  /// registered, return std::nullopt. Note that the empty attribute/`nullptr`
323  /// is a valid return value for this function.
324  std::optional<Attribute> convertTypeAttribute(Type type,
325  Attribute attr) const;
326 
327 private:
328  /// The signature of the callback used to convert a type. If the new set of
329  /// types is empty, the type is removed and any usages of the existing value
330  /// are expected to be removed during conversion.
331  using ConversionCallbackFn = std::function<std::optional<LogicalResult>(
332  Type, SmallVectorImpl<Type> &)>;
333 
334  /// The signature of the callback used to materialize a source/argument
335  /// conversion.
336  ///
337  /// Arguments: builder, result type, inputs, location
338  using MaterializationCallbackFn = std::function<std::optional<Value>(
339  OpBuilder &, Type, ValueRange, Location)>;
340 
341  /// The signature of the callback used to materialize a target conversion.
342  ///
343  /// Arguments: builder, result type, inputs, location, original type
344  using TargetMaterializationCallbackFn = std::function<std::optional<Value>(
345  OpBuilder &, Type, ValueRange, Location, Type)>;
346 
347  /// The signature of the callback used to convert a type attribute.
348  using TypeAttributeConversionCallbackFn =
349  std::function<AttributeConversionResult(Type, Attribute)>;
350 
351  /// Generate a wrapper for the given callback. This allows for accepting
352  /// different callback forms that all compose into a single version.
353  /// With callback of form: `std::optional<Type>(T)`
354  template <typename T, typename FnT>
355  std::enable_if_t<std::is_invocable_v<FnT, T>, ConversionCallbackFn>
356  wrapCallback(FnT &&callback) const {
357  return wrapCallback<T>([callback = std::forward<FnT>(callback)]( // adapt to the 1-N form
358  T type, SmallVectorImpl<Type> &results) {
359  if (std::optional<Type> resultOpt = callback(type)) {
360  bool wasSuccess = static_cast<bool>(*resultOpt); // a null Type means failure
361  if (wasSuccess)
362  results.push_back(*resultOpt);
363  return std::optional<LogicalResult>(success(wasSuccess));
364  }
365  return std::optional<LogicalResult>(); // empty: the callback returned std::nullopt
366  });
367  }
368  /// With callback of form: `std::optional<LogicalResult>(
369  /// T, SmallVectorImpl<Type> &)`.
370  template <typename T, typename FnT>
371  std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &>,
372  ConversionCallbackFn>
373  wrapCallback(FnT &&callback) const {
374  return [callback = std::forward<FnT>(callback)](
375  Type type,
376  SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
377  T derivedType = dyn_cast<T>(type);
378  if (!derivedType) // `type` is not a `T`: the callback is not invoked
379  return std::nullopt;
380  return callback(derivedType, results);
381  };
382  }
383 
384  /// Register a type conversion and clear both conversion caches.
385  void registerConversion(ConversionCallbackFn callback) {
386  conversions.emplace_back(std::move(callback));
387  cachedDirectConversions.clear();
388  cachedMultiConversions.clear();
389  }
390 
391  /// Generate a wrapper for the given argument/source materialization
392  /// callback. The callback may take any subclass of `Type` and the
393  /// wrapper checks that the result type is of the expected class `T`
394  /// before calling the callback.
395  template <typename T, typename FnT>
396  MaterializationCallbackFn wrapMaterialization(FnT &&callback) const {
397  return [callback = std::forward<FnT>(callback)](
398  OpBuilder &builder, Type resultType, ValueRange inputs,
399  Location loc) -> std::optional<Value> {
400  if (T derivedType = dyn_cast<T>(resultType))
401  return callback(builder, derivedType, inputs, loc);
402  return std::nullopt; // result type is not a `T`: the callback is not invoked
403  };
404  }
405 
406  /// Generate a wrapper for the given target materialization callback.
407  /// The callback may take any subclass of `Type` and the wrapper checks
408  /// that the result type is of the expected class `T` before calling the
409  /// callback.
410  ///
411  /// With callback of form:
412  /// `Value(OpBuilder &, T, ValueRange, Location, Type)`
413  template <typename T, typename FnT>
414  std::enable_if_t<
415  std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location, Type>,
416  TargetMaterializationCallbackFn>
417  wrapTargetMaterialization(FnT &&callback) const {
418  return [callback = std::forward<FnT>(callback)](
419  OpBuilder &builder, Type resultType, ValueRange inputs,
420  Location loc, Type originalType) -> std::optional<Value> {
421  if (T derivedType = dyn_cast<T>(resultType))
422  return callback(builder, derivedType, inputs, loc, originalType);
423  return std::nullopt; // result type is not a `T`: the callback is not invoked
424  };
425  }
426  /// With callback of form:
427  /// `Value(OpBuilder &, T, ValueRange, Location)`
428  template <typename T, typename FnT>
429  std::enable_if_t<
430  std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location>,
431  TargetMaterializationCallbackFn>
432  wrapTargetMaterialization(FnT &&callback) const {
433  return wrapTargetMaterialization<T>( // adapt to the five-argument form
434  [callback = std::forward<FnT>(callback)](
435  OpBuilder &builder, T resultType, ValueRange inputs, Location loc,
436  Type originalType) -> std::optional<Value> {
437  return callback(builder, resultType, inputs, loc); // originalType is not forwarded
438  });
439  }
440 
441  /// Generate a wrapper for the given type attribute conversion callback. The
442  /// callback may take any subclass of `Attribute` and the wrapper will check
443  /// for the target attribute to be of the expected class before calling the
444  /// callback.
445  template <typename T, typename A, typename FnT>
446  TypeAttributeConversionCallbackFn
447  wrapTypeAttributeConversion(FnT &&callback) const {
448  return [callback = std::forward<FnT>(callback)](
449  Type type, Attribute attr) -> AttributeConversionResult {
450  if (T derivedType = dyn_cast<T>(type)) {
451  if (A derivedAttr = dyn_cast_or_null<A>(attr))
452  return callback(derivedType, derivedAttr);
453  }
455  };
456  }
457 
458  /// Register a type attribute conversion, clearing caches.
459  void
460  registerTypeAttributeConversion(TypeAttributeConversionCallbackFn callback) {
461  typeAttributeConversions.emplace_back(std::move(callback));
462  // Clear cached type conversions in case an affected attribute (e.g. a memory space) is lingering inside.
463  cachedDirectConversions.clear();
464  cachedMultiConversions.clear();
465  }
466 
467  /// The set of registered conversion functions.
468  SmallVector<ConversionCallbackFn, 4> conversions;
469 
470  /// The list of registered materialization functions.
471  SmallVector<MaterializationCallbackFn, 2> argumentMaterializations;
472  SmallVector<MaterializationCallbackFn, 2> sourceMaterializations;
473  SmallVector<TargetMaterializationCallbackFn, 2> targetMaterializations;
474 
475  /// The list of registered type attribute conversion functions.
476  SmallVector<TypeAttributeConversionCallbackFn, 2> typeAttributeConversions;
477 
478  /// A set of cached conversions to avoid recomputing in the common case.
479  /// Direct 1-1 conversions are the most common, so this cache stores the
480  /// successful 1-1 conversions as well as all failed conversions.
481  mutable DenseMap<Type, Type> cachedDirectConversions;
482  /// This cache stores the successful 1->N conversions, where N != 1.
483  mutable DenseMap<Type, SmallVector<Type, 2>> cachedMultiConversions;
484  /// A mutex used for cache access
485  mutable llvm::sys::SmartRWMutex<true> cacheMutex;
486 };
487 
488 //===----------------------------------------------------------------------===//
489 // Conversion Patterns
490 //===----------------------------------------------------------------------===//
491 
492 /// Base class for the conversion patterns. This pattern class enables type
493 /// conversions, and other uses specific to the conversion framework. As such,
494 /// patterns of this type can only be used with the 'apply*' methods below.
496 public:
497  /// Hook for derived classes to implement rewriting. `op` is the (first)
498  /// operation matched by the pattern, `operands` is a list of the rewritten
499  /// operand values that are passed to `op`, and `rewriter` can be used to emit
500  /// the new operations. This function should not fail. If some specific cases of
501  /// the operation are not supported, these cases should not be matched.
502  virtual void rewrite(Operation *op, ArrayRef<Value> operands,
503  ConversionPatternRewriter &rewriter) const {
504  llvm_unreachable("unimplemented rewrite"); // default body: no implementation
505  }
506 
507  /// Hook for derived classes to implement combined matching and rewriting.
508  virtual LogicalResult
510  ConversionPatternRewriter &rewriter) const {
511  if (failed(match(op)))
512  return failure();
513  rewrite(op, operands, rewriter);
514  return success();
515  }
516 
517  /// Attempt to match and rewrite the IR root at the specified operation.
518  LogicalResult matchAndRewrite(Operation *op,
519  PatternRewriter &rewriter) const final;
520 
521  /// Return the type converter held by this pattern, or nullptr if the pattern
522  /// does not require type conversion (`typeConverter` defaults to nullptr).
523  const TypeConverter *getTypeConverter() const { return typeConverter; }
524 
525  template <typename ConverterTy>
526  std::enable_if_t<std::is_base_of<TypeConverter, ConverterTy>::value,
527  const ConverterTy *>
529  return static_cast<const ConverterTy *>(typeConverter);
530  }
531 
532 protected:
533  /// See `RewritePattern::RewritePattern` for information on the other
534  /// available constructors.
535  using RewritePattern::RewritePattern;
536  /// Construct a conversion pattern with the given converter, and forward the
537  /// remaining arguments to RewritePattern.
538  template <typename... Args>
540  : RewritePattern(std::forward<Args>(args)...),
542 
543 protected:
544  /// An optional type converter for use by this pattern.
545  const TypeConverter *typeConverter = nullptr;
546 
547 private:
549 };
550 
551 /// OpConversionPattern is a wrapper around ConversionPattern that allows for
552 /// matching and rewriting against an instance of a derived operation class as
553 /// opposed to a raw Operation.
554 template <typename SourceOp>
556 public:
557  using OpAdaptor = typename SourceOp::Adaptor;
558 
560  : ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
562  PatternBenefit benefit = 1)
563  : ConversionPattern(typeConverter, SourceOp::getOperationName(), benefit,
564  context) {}
565 
566  /// Wrappers around the ConversionPattern methods that pass the derived op
567  /// type.
568  LogicalResult match(Operation *op) const final {
569  return match(cast<SourceOp>(op)); // forwards to the SourceOp overload
570  }
571  void rewrite(Operation *op, ArrayRef<Value> operands,
572  ConversionPatternRewriter &rewriter) const final {
573  auto sourceOp = cast<SourceOp>(op);
574  rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter); // forwards to the SourceOp overload
575  }
576  LogicalResult
578  ConversionPatternRewriter &rewriter) const final {
579  auto sourceOp = cast<SourceOp>(op);
580  return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
581  }
582 
583  /// Match and rewrite methods that operate on the SourceOp type. A derived pattern
584  /// class must override `matchAndRewrite`, or else both `match` and `rewrite`.
585  virtual LogicalResult match(SourceOp op) const {
586  llvm_unreachable("must override match or matchAndRewrite");
587  }
588  virtual void rewrite(SourceOp op, OpAdaptor adaptor,
589  ConversionPatternRewriter &rewriter) const {
590  llvm_unreachable("must override matchAndRewrite or a rewrite method"); // default body: no implementation
591  }
592  virtual LogicalResult
593  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
594  ConversionPatternRewriter &rewriter) const {
595  if (failed(match(op))) // default: run match() first ...
596  return failure();
597  rewrite(op, adaptor, rewriter); // ... and rewrite() only after a successful match
598  return success();
599  }
600 
601 private:
603 };
604 
605 /// OpInterfaceConversionPattern is a wrapper around ConversionPattern that
606 /// allows for matching and rewriting against an instance of an OpInterface
607 /// class as opposed to a raw Operation.
608 template <typename SourceOp>
610 public:
613  SourceOp::getInterfaceID(), benefit, context) {}
615  MLIRContext *context, PatternBenefit benefit = 1)
617  SourceOp::getInterfaceID(), benefit, context) {}
618 
619  /// Wrappers around the ConversionPattern methods that pass the op cast to
620  /// the `SourceOp` type.
621  void rewrite(Operation *op, ArrayRef<Value> operands,
622  ConversionPatternRewriter &rewriter) const final {
623  rewrite(cast<SourceOp>(op), operands, rewriter); // forwards to the SourceOp overload
624  }
625  LogicalResult
627  ConversionPatternRewriter &rewriter) const final {
628  return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
629  }
630 
631  /// Rewrite and match methods that operate on the SourceOp type. A derived
632  /// pattern class must override `matchAndRewrite` or this `rewrite` method.
633  virtual void rewrite(SourceOp op, ArrayRef<Value> operands,
634  ConversionPatternRewriter &rewriter) const {
635  llvm_unreachable("must override matchAndRewrite or a rewrite method");
636  }
637  virtual LogicalResult
638  matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
639  ConversionPatternRewriter &rewriter) const {
640  if (failed(match(op))) // default: run match() first ...
641  return failure();
642  rewrite(op, operands, rewriter); // ... and rewrite() only after a successful match
643  return success();
644  }
645 
646 private:
648 };
649 
650 /// OpTraitConversionPattern is a wrapper around ConversionPattern that allows
651 /// for matching and rewriting against instances of an operation that possess a
652 /// given trait.
653 template <template <typename> class TraitType>
655 public:
658  TypeID::get<TraitType>(), benefit, context) {}
660  MLIRContext *context, PatternBenefit benefit = 1)
662  TypeID::get<TraitType>(), benefit, context) {}
663 };
664 
665 /// Generic utility to convert op result types according to type converter
666 /// without knowing exact op type.
667 /// Clones existing op with new result types and returns it.
668 FailureOr<Operation *>
669 convertOpResultTypes(Operation *op, ValueRange operands,
670  const TypeConverter &converter,
671  ConversionPatternRewriter &rewriter);
672 
673 /// Add a pattern to the given pattern list to convert the signature of a
674 /// FunctionOpInterface op with the given type converter. This only supports
675 /// ops which use FunctionType to represent their type.
677  StringRef functionLikeOpName, RewritePatternSet &patterns,
678  const TypeConverter &converter);
679 
680 template <typename FuncOpT>
682  RewritePatternSet &patterns, const TypeConverter &converter) {
683  populateFunctionOpInterfaceTypeConversionPattern(FuncOpT::getOperationName(),
684  patterns, converter);
685 }
686 
688  RewritePatternSet &patterns, const TypeConverter &converter);
689 
690 //===----------------------------------------------------------------------===//
691 // Conversion PatternRewriter
692 //===----------------------------------------------------------------------===//
693 
694 namespace detail {
695 struct ConversionPatternRewriterImpl;
696 } // namespace detail
697 
698 /// This class implements a pattern rewriter for use with ConversionPatterns. It
699 /// extends the base PatternRewriter and provides special conversion specific
700 /// hooks.
702 public:
704 
705  /// Apply a signature conversion to the given block. This replaces the block with
706  /// a new block containing the updated signature. The operations of the given
707  /// block are inlined into the newly-created block, which is returned.
708  ///
709  /// If no block argument types are changing, the original block will be
710  /// left in place and returned.
711  ///
712  /// A signature conversion must be provided. (Type converters can construct
713  /// a signature conversion with `convertBlockSignature`.)
714  ///
715  /// Optionally, a type converter can be provided to build materializations.
716  /// Note: If no type converter was provided or the type converter does not
717  /// specify any suitable argument/target materialization rules, the dialect
718  /// conversion may fail to legalize unresolved materializations.
719  Block *
722  const TypeConverter *converter = nullptr);
723 
724  /// Apply a signature conversion to each block in the given region. This
725  /// replaces each block with a new block containing the updated signature. If
726  /// an updated signature would match the current signature, the respective
727  /// block is left in place as is. (See `applySignatureConversion` for
728  /// details.) The new entry block of the region is returned.
729  ///
730  /// SignatureConversions are computed with the specified type converter.
731  /// This function returns "failure" if the type converter failed to compute
732  /// a SignatureConversion for at least one block.
733  ///
734  /// Optionally, a special SignatureConversion can be specified for the entry
735  /// block. This is because the types of the entry block arguments are often
736  /// tied semantically to the operation.
737  FailureOr<Block *> convertRegionTypes(
738  Region *region, const TypeConverter &converter,
739  TypeConverter::SignatureConversion *entryConversion = nullptr);
740 
741  /// Replace all the uses of the block argument `from` with value `to`.
743 
744  /// Return the converted value of 'key' with a type defined by the type
745  /// converter of the currently executing pattern. Return nullptr in the case
746  /// of failure, the remapped value otherwise.
748 
749  /// Return the converted values that replace 'keys' with types defined by the
750  /// type converter of the currently executing pattern. Returns failure if the
751  /// remap failed, success otherwise.
752  LogicalResult getRemappedValues(ValueRange keys,
753  SmallVectorImpl<Value> &results);
754 
755  //===--------------------------------------------------------------------===//
756  // PatternRewriter Hooks
757  //===--------------------------------------------------------------------===//
758 
759  /// Indicate that the conversion rewriter can recover from rewrite failure.
760  /// Recovery is supported via rollback, allowing for continued processing of
761  /// patterns even if a failure is encountered during the rewrite step.
762  bool canRecoverFromRewriteFailure() const override { return true; }
763 
764  /// PatternRewriter hook for replacing an operation.
765  void replaceOp(Operation *op, ValueRange newValues) override;
766 
767  /// PatternRewriter hook for replacing an operation.
768  void replaceOp(Operation *op, Operation *newOp) override;
769 
770  /// PatternRewriter hook for erasing a dead operation. The uses of this
771  /// operation *must* be made dead by the end of the conversion process,
772  /// otherwise an assert will be issued.
773  void eraseOp(Operation *op) override;
774 
775  /// PatternRewriter hook for erasing all operations in a block. This is not
776  /// yet implemented for dialect conversion.
777  void eraseBlock(Block *block) override;
778 
779  /// PatternRewriter hook for inlining the ops of a block into another block.
780  void inlineBlockBefore(Block *source, Block *dest, Block::iterator before,
781  ValueRange argValues = std::nullopt) override;
783 
784  /// PatternRewriter hook for updating the given operation in-place.
785  /// Note: These methods only track updates to the given operation itself,
786  /// and not nested regions. Updates to regions will still require notification
787  /// through other more specific hooks above.
788  void startOpModification(Operation *op) override;
789 
790  /// PatternRewriter hook for updating the given operation in-place.
791  void finalizeOpModification(Operation *op) override;
792 
793  /// PatternRewriter hook for updating the given operation in-place.
794  void cancelOpModification(Operation *op) override;
795 
796  /// Return a reference to the internal implementation.
798 
799 private:
800  // Allow OperationConverter to construct new rewriters.
801  friend struct OperationConverter;
802 
803  /// Conversion pattern rewriters must not be used outside of dialect
804  /// conversions. They apply some IR rewrites in a delayed fashion and could
805  /// bring the IR into an inconsistent state when used standalone.
807  const ConversionConfig &config);
808 
809  // Hide unsupported pattern rewriter API.
811 
812  std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
813 };
814 
815 //===----------------------------------------------------------------------===//
816 // ConversionTarget
817 //===----------------------------------------------------------------------===//
818 
819 /// This class describes a specific conversion target.
821 public:
822  /// This enumeration corresponds to the specific action to take when
823  /// considering an operation legal for this conversion target.
824  enum class LegalizationAction {
825  /// The target supports this operation.
826  Legal,
827 
828  /// This operation has dynamic legalization constraints that must be checked
829  /// by the target.
830  Dynamic,
831 
832  /// The target explicitly does not support this operation.
833  Illegal,
834  };
835 
836  /// A structure containing additional information describing a specific legal
837  /// operation instance.
838  struct LegalOpDetails {
839  /// A flag that indicates if this operation is 'recursively' legal. This
840  /// means that if an operation is legal, either statically or dynamically,
841  /// all of the operations nested within are also considered legal.
842  bool isRecursivelyLegal = false;
843  };
844 
845  /// The signature of the callback used to determine if an operation is
846  /// dynamically legal on the target.
848  std::function<std::optional<bool>(Operation *)>;
849 
850  ConversionTarget(MLIRContext &ctx) : ctx(ctx) {}
851  virtual ~ConversionTarget() = default;
852 
853  //===--------------------------------------------------------------------===//
854  // Legality Registration
855  //===--------------------------------------------------------------------===//
856 
857  /// Register a legality action for the given operation.
859  template <typename OpT>
861  setOpAction(OperationName(OpT::getOperationName(), &ctx), action);
862  }
863 
864  /// Register the given operations as legal.
867  }
868  template <typename OpT>
869  void addLegalOp() {
870  addLegalOp(OperationName(OpT::getOperationName(), &ctx));
871  }
872  template <typename OpT, typename OpT2, typename... OpTs>
873  void addLegalOp() {
874  addLegalOp<OpT>();
875  addLegalOp<OpT2, OpTs...>();
876  }
877 
878  /// Register the given operation as dynamically legal and set the dynamic
879  /// legalization callback to the one provided.
881  const DynamicLegalityCallbackFn &callback) {
883  setLegalityCallback(op, callback);
884  }
885  template <typename OpT>
887  addDynamicallyLegalOp(OperationName(OpT::getOperationName(), &ctx),
888  callback);
889  }
890  template <typename OpT, typename OpT2, typename... OpTs>
892  addDynamicallyLegalOp<OpT>(callback);
893  addDynamicallyLegalOp<OpT2, OpTs...>(callback);
894  }
895  template <typename OpT, class Callable>
896  std::enable_if_t<!std::is_invocable_v<Callable, Operation *>>
897  addDynamicallyLegalOp(Callable &&callback) {
898  addDynamicallyLegalOp<OpT>(
899  [=](Operation *op) { return callback(cast<OpT>(op)); });
900  }
901 
902  /// Register the given operation as illegal, i.e. this operation is known to
903  /// not be supported by this target.
906  }
907  template <typename OpT>
908  void addIllegalOp() {
909  addIllegalOp(OperationName(OpT::getOperationName(), &ctx));
910  }
911  template <typename OpT, typename OpT2, typename... OpTs>
912  void addIllegalOp() {
913  addIllegalOp<OpT>();
914  addIllegalOp<OpT2, OpTs...>();
915  }
916 
917  /// Mark an operation, that *must* have either been set as `Legal` or
918  /// `DynamicallyLegal`, as being recursively legal. This means that in
919  /// addition to the operation itself, all of the operations nested within are
920  /// also considered legal. An optional dynamic legality callback may be
921  /// provided to mark subsets of legal instances as recursively legal.
923  const DynamicLegalityCallbackFn &callback);
924  template <typename OpT>
926  OperationName opName(OpT::getOperationName(), &ctx);
927  markOpRecursivelyLegal(opName, callback);
928  }
929  template <typename OpT, typename OpT2, typename... OpTs>
931  markOpRecursivelyLegal<OpT>(callback);
932  markOpRecursivelyLegal<OpT2, OpTs...>(callback);
933  }
934  template <typename OpT, class Callable>
935  std::enable_if_t<!std::is_invocable_v<Callable, Operation *>>
936  markOpRecursivelyLegal(Callable &&callback) {
937  markOpRecursivelyLegal<OpT>(
938  [=](Operation *op) { return callback(cast<OpT>(op)); });
939  }
940 
941  /// Register a legality action for the given dialects.
942  void setDialectAction(ArrayRef<StringRef> dialectNames,
943  LegalizationAction action);
944 
945  /// Register the operations of the given dialects as legal.
946  template <typename... Names>
947  void addLegalDialect(StringRef name, Names... names) {
948  SmallVector<StringRef, 2> dialectNames({name, names...});
950  }
951  template <typename... Args>
953  SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
955  }
956 
957  /// Register the operations of the given dialects as dynamically legal, i.e.
958  /// requiring custom handling by the callback.
959  template <typename... Names>
961  StringRef name, Names... names) {
962  SmallVector<StringRef, 2> dialectNames({name, names...});
964  setLegalityCallback(dialectNames, callback);
965  }
966  template <typename... Args>
968  addDynamicallyLegalDialect(std::move(callback),
969  Args::getDialectNamespace()...);
970  }
971 
972  /// Register unknown operations as dynamically legal. For operations (and
973  /// dialects) that do not have a set legalization action, treat them as
974  /// dynamically legal and invoke the given callback.
976  setLegalityCallback(fn);
977  }
978 
979  /// Register the operations of the given dialects as illegal, i.e.
980  /// operations of this dialect are not supported by the target.
981  template <typename... Names>
982  void addIllegalDialect(StringRef name, Names... names) {
983  SmallVector<StringRef, 2> dialectNames({name, names...});
985  }
986  template <typename... Args>
988  SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
990  }
991 
992  //===--------------------------------------------------------------------===//
993  // Legality Querying
994  //===--------------------------------------------------------------------===//
995 
996  /// Get the legality action for the given operation.
997  std::optional<LegalizationAction> getOpAction(OperationName op) const;
998 
999  /// If the given operation instance is legal on this target, a structure
1000  /// containing legality information is returned. If the operation is not
1001  /// legal, std::nullopt is returned. Also returns std::nullopt if operation
1002  /// legality wasn't registered by user or dynamic legality callbacks returned
1003  /// None.
1004  ///
1005  /// Note: Legality is actually a 4-state: Legal(recursive=true),
1006  /// Legal(recursive=false), Illegal or Unknown, where Unknown is treated
1007  /// either as Legal or Illegal depending on context.
1008  std::optional<LegalOpDetails> isLegal(Operation *op) const;
1009 
1010  /// Returns true if the operation instance is illegal on this target. Returns
1011  /// false if operation is legal, operation legality wasn't registered by user
1012  /// or dynamic legality callbacks returned std::nullopt.
1013  bool isIllegal(Operation *op) const;
1014 
1015 private:
1016  /// Set the dynamic legality callback for the given operation.
1017  void setLegalityCallback(OperationName name,
1018  const DynamicLegalityCallbackFn &callback);
1019 
1020  /// Set the dynamic legality callback for the given dialects.
1021  void setLegalityCallback(ArrayRef<StringRef> dialects,
1022  const DynamicLegalityCallbackFn &callback);
1023 
1024  /// Set the dynamic legality callback for the unknown ops.
1025  void setLegalityCallback(const DynamicLegalityCallbackFn &callback);
1026 
1027  /// The set of information that configures the legalization of an operation.
1028  struct LegalizationInfo {
1029  /// The legality action this operation was given.
1031 
1032  /// If some legal instances of this operation may also be recursively legal.
1033  bool isRecursivelyLegal = false;
1034 
1035  /// The legality callback if this operation is dynamically legal.
1036  DynamicLegalityCallbackFn legalityFn;
1037  };
1038 
1039  /// Get the legalization information for the given operation.
1040  std::optional<LegalizationInfo> getOpInfo(OperationName op) const;
1041 
1042  /// A deterministic mapping of operation name and its respective legality
1043  /// information.
1044  llvm::MapVector<OperationName, LegalizationInfo> legalOperations;
1045 
1046  /// A set of legality callbacks for given operation names that are used to
1047  /// check if an operation instance is recursively legal.
1048  DenseMap<OperationName, DynamicLegalityCallbackFn> opRecursiveLegalityFns;
1049 
1050  /// A deterministic mapping of dialect name to the specific legality action to
1051  /// take.
1052  llvm::StringMap<LegalizationAction> legalDialects;
1053 
1054  /// A set of dynamic legality callbacks for given dialect names.
1055  llvm::StringMap<DynamicLegalityCallbackFn> dialectLegalityFns;
1056 
1057  /// An optional legality callback for unknown operations.
1058  DynamicLegalityCallbackFn unknownLegalityFn;
1059 
1060  /// The current context this target applies to.
1061  MLIRContext &ctx;
1062 };
1063 
1064 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
1065 //===----------------------------------------------------------------------===//
1066 // PDL Configuration
1067 //===----------------------------------------------------------------------===//
1068 
1069 /// A PDL configuration that is used to support dialect conversion
1070 /// functionality.
1072  : public PDLPatternConfigBase<PDLConversionConfig> {
1073 public:
1074  PDLConversionConfig(const TypeConverter *converter) : converter(converter) {}
1075  ~PDLConversionConfig() final = default;
1076 
1077  /// Return the type converter used by this configuration, which may be nullptr
1078  /// if no type conversions are expected.
1079  const TypeConverter *getTypeConverter() const { return converter; }
1080 
1081  /// Hooks that are invoked at the beginning and end of a rewrite of a matched
1082  /// pattern.
1083  void notifyRewriteBegin(PatternRewriter &rewriter) final;
1084  void notifyRewriteEnd(PatternRewriter &rewriter) final;
1085 
1086 private:
1087  /// An optional type converter to use for the pattern.
1088  const TypeConverter *converter;
1089 };
1090 
1091 /// Register the dialect conversion PDL functions with the given pattern set.
1092 void registerConversionPDLFunctions(RewritePatternSet &patterns);
1093 
1094 #else
1095 
1096 // Stubs for when PDL in rewriting is not enabled.
1097 
1098 inline void registerConversionPDLFunctions(RewritePatternSet &patterns) {}
1099 
1100 class PDLConversionConfig final {
1101 public:
1102  PDLConversionConfig(const TypeConverter * /*converter*/) {}
1103 };
1104 
1105 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
1106 
1107 //===----------------------------------------------------------------------===//
1108 // ConversionConfig
1109 //===----------------------------------------------------------------------===//
1110 
1111 /// Dialect conversion configuration.
1113  /// An optional callback used to notify about match failure diagnostics during
1114  /// the conversion. Diagnostics reported to this callback may only be
1115  /// available in debug mode.
1117 
1118  /// Partial conversion only. All operations that are found not to be
1119  /// legalizable are placed in this set. (Note that if there is an op
1120  /// explicitly marked as illegal, the conversion terminates and the set will
1121  /// not necessarily be complete.)
1123 
1124  /// Analysis conversion only. All operations that are found to be legalizable
1125  /// are placed in this set. Note that no actual rewrites are applied to the
1126  /// IR during an analysis conversion and only pre-existing operations are
1127  /// added to the set.
1129 
1130  /// An optional listener that is notified about all IR modifications in case
1131  /// dialect conversion succeeds. If the dialect conversion fails and no IR
1132  /// modifications are visible (i.e., they were all rolled back), or if the
1133  /// dialect conversion is an "analysis conversion", no notifications are
1134  /// sent (apart from `notifyPatternBegin`/`notifyPatternEnd`).
1135  ///
1136  /// Note: Notifications are sent in a delayed fashion, when the dialect
1137  /// conversion is guaranteed to succeed. At that point, some IR modifications
1138  /// may already have been materialized. Consequently, operations/blocks that
1139  /// are passed to listener callbacks should not be accessed. (Ops/blocks are
1140  /// guaranteed to be valid pointers and accessing op names is allowed. But
1141  /// there are no guarantees about the state of ops/blocks at the time that a
1142  /// callback is triggered.)
1143  ///
1144  /// Example: Consider a dialect conversion in which a new op ("test.foo") is
1145  /// created and inserted, and later moved to another block. (Moving ops also
1146  /// triggers "notifyOperationInserted".)
1147  ///
1148  /// (1) notifyOperationInserted: "test.foo" (into block "b1")
1149  /// (2) notifyOperationInserted: "test.foo" (moved to another block "b2")
1150  ///
1151  /// When querying "op->getBlock()" during the first "notifyOperationInserted",
1152  /// "b2" would be returned because "moving an op" is a kind of rewrite that is
1153  /// immediately performed by the dialect conversion (and rolled back upon
1154  /// failure).
1155  //
1156  // Note: When receiving a "notifyBlockInserted"/"notifyOperationInserted"
1157  // callback, the previous region/block is provided to the callback, but not
1158  // the iterator pointing to the exact location within the region/block. That
1159  // is because these notifications are sent with a delay (after the IR has
1160  // already been modified) and iterators into past IR state cannot be
1161  // represented at the moment.
1163 
1164  /// If set to "true", the dialect conversion attempts to build source/target/
1165  /// argument materializations through the type converter API in lieu of
1166  /// "builtin.unrealized_conversion_cast ops". The conversion process fails if
1167  /// at least one materialization could not be built.
1168  ///
1169  /// If set to "false", the dialect conversion does not build any custom
1170  /// materializations and instead inserts "builtin.unrealized_conversion_cast"
1171  /// ops to ensure that the resulting IR is valid.
1173 };
1174 
1175 //===----------------------------------------------------------------------===//
1176 // Reconcile Unrealized Casts
1177 //===----------------------------------------------------------------------===//
1178 
1179 /// Try to reconcile all given UnrealizedConversionCastOps and store the
1180 /// left-over ops in `remainingCastOps` (if provided).
1181 ///
1182 /// This function processes cast ops in a worklist-driven fashion. For each
1183 /// cast op, if the chain of input casts eventually reaches a cast op where the
1184 /// input types match the output types of the matched op, replace the matched
1185 /// op with the inputs.
1186 ///
1187 /// Example:
1188 /// %1 = unrealized_conversion_cast %0 : !A to !B
1189 /// %2 = unrealized_conversion_cast %1 : !B to !C
1190 /// %3 = unrealized_conversion_cast %2 : !C to !A
1191 ///
1192 /// In the above example, %0 can be used instead of %3 and all cast ops are
1193 /// folded away.
1196  SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps = nullptr);
1197 
1198 //===----------------------------------------------------------------------===//
1199 // Op Conversion Entry Points
1200 //===----------------------------------------------------------------------===//
1201 
1202 /// Below we define several entry points for operation conversion. It is
1203 /// important to note that the patterns provided to the conversion framework may
1204 /// have additional constraints. See the `PatternRewriter Hooks` section of the
1205 /// ConversionPatternRewriter, to see what additional constraints are imposed on
1206 /// the use of the PatternRewriter.
1207 
1208 /// Apply a partial conversion on the given operations and all nested
1209 /// operations. This method converts as many operations to the target as
1210 /// possible, ignoring operations that failed to legalize. This method only
1211 /// returns failure if there are ops explicitly marked as illegal.
1212 LogicalResult
1214  const ConversionTarget &target,
1215  const FrozenRewritePatternSet &patterns,
1216  ConversionConfig config = ConversionConfig());
1217 LogicalResult
1219  const FrozenRewritePatternSet &patterns,
1220  ConversionConfig config = ConversionConfig());
1221 
1222 /// Apply a complete conversion on the given operations, and all nested
1223 /// operations. This method returns failure if the conversion of any operation
1224 /// fails, or if there are unreachable blocks in any of the regions nested
1225 /// within 'ops'.
1226 LogicalResult applyFullConversion(ArrayRef<Operation *> ops,
1227  const ConversionTarget &target,
1228  const FrozenRewritePatternSet &patterns,
1229  ConversionConfig config = ConversionConfig());
1230 LogicalResult applyFullConversion(Operation *op, const ConversionTarget &target,
1231  const FrozenRewritePatternSet &patterns,
1232  ConversionConfig config = ConversionConfig());
1233 
1234 /// Apply an analysis conversion on the given operations, and all nested
1235 /// operations. This method analyzes which operations would be successfully
1236 /// converted to the target if a conversion was applied. All operations that
1237 /// were found to be legalizable to the given 'target' are placed within the
1238 /// provided 'config.legalizableOps' set; note that no actual rewrites are
1239 /// applied to the operations on success. This method only returns failure if
1240 /// there are unreachable blocks in any of the regions nested within 'ops'.
1241 LogicalResult
1243  const FrozenRewritePatternSet &patterns,
1244  ConversionConfig config = ConversionConfig());
1245 LogicalResult
1247  const FrozenRewritePatternSet &patterns,
1248  ConversionConfig config = ConversionConfig());
1249 } // namespace mlir
1250 
1251 #endif // MLIR_TRANSFORMS_DIALECTCONVERSION_H_
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
Block represents an ordered list of Operations.
Definition: Block.h:31
OpListType::iterator iterator
Definition: Block.h:138
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
LogicalResult getRemappedValues(ValueRange keys, SmallVectorImpl< Value > &results)
Return the converted values that replace 'keys' with types defined by the type converter of the curre...
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt) override
PatternRewriter hook for inlining the ops of a block into another block.
Block * applySignatureConversion(Block *block, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter=nullptr)
Apply a signature conversion to given block.
void startOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
detail::ConversionPatternRewriterImpl & getImpl()
Return a reference to the internal implementation.
void eraseBlock(Block *block) override
PatternRewriter hook for erasing all operations in a block.
void cancelOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
bool canRecoverFromRewriteFailure() const override
Indicate that the conversion rewriter can recover from rewrite failure.
Value getRemappedValue(Value key)
Return the converted value of 'key' with a type defined by the type converter of the currently execut...
void finalizeOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void replaceUsesOfBlockArgument(BlockArgument from, Value to)
Replace all the uses of the block argument from with value to.
Base class for the conversion patterns.
const TypeConverter * typeConverter
An optional type converter for use by this pattern.
ConversionPattern(const TypeConverter &typeConverter, Args &&...args)
Construct a conversion pattern with the given converter, and forward the remaining arguments to Rewri...
virtual void rewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement rewriting.
const TypeConverter * getTypeConverter() const
Return the type converter held by this pattern, or nullptr if the pattern does not require type conve...
std::enable_if_t< std::is_base_of< TypeConverter, ConverterTy >::value, const ConverterTy * > getTypeConverter() const
virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement combined matching and rewriting.
This class describes a specific conversion target.
void setDialectAction(ArrayRef< StringRef > dialectNames, LegalizationAction action)
Register a legality action for the given dialects.
void setOpAction(LegalizationAction action)
void addLegalOp(OperationName op)
Register the given operations as legal.
void addDynamicallyLegalDialect(DynamicLegalityCallbackFn callback)
void setOpAction(OperationName op, LegalizationAction action)
Register a legality action for the given operation.
void addDynamicallyLegalDialect(const DynamicLegalityCallbackFn &callback, StringRef name, Names... names)
Register the operations of the given dialects as dynamically legal, i.e.
void addDynamicallyLegalOp(const DynamicLegalityCallbackFn &callback)
void addLegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as legal.
std::optional< LegalOpDetails > isLegal(Operation *op) const
If the given operation instance is legal on this target, a structure containing legality information ...
std::enable_if_t<!std::is_invocable_v< Callable, Operation * > > markOpRecursivelyLegal(Callable &&callback)
void markUnknownOpDynamicallyLegal(const DynamicLegalityCallbackFn &fn)
Register unknown operations as dynamically legal.
std::optional< LegalizationAction > getOpAction(OperationName op) const
Get the legality action for the given operation.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the on...
void addIllegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as illegal, i.e.
std::enable_if_t<!std::is_invocable_v< Callable, Operation * > > addDynamicallyLegalOp(Callable &&callback)
void addDynamicallyLegalOp(const DynamicLegalityCallbackFn &callback)
void markOpRecursivelyLegal(const DynamicLegalityCallbackFn &callback={})
void addIllegalOp(OperationName op)
Register the given operation as illegal, i.e.
LegalizationAction
This enumeration corresponds to the specific action to take when considering an operation legal for t...
@ Illegal
The target explicitly does not support this operation.
@ Dynamic
This operation has dynamic legalization constraints that must be checked by the target.
@ Legal
The target supports this operation.
ConversionTarget(MLIRContext &ctx)
void markOpRecursivelyLegal(OperationName name, const DynamicLegalityCallbackFn &callback)
Mark an operation, that must have either been set as Legal or DynamicallyLegal, as being recursively ...
void markOpRecursivelyLegal(const DynamicLegalityCallbackFn &callback={})
std::function< std::optional< bool >(Operation *)> DynamicLegalityCallbackFn
The signature of the callback used to determine if an operation is dynamically legal on the target.
bool isIllegal(Operation *op) const
Returns true if the operation instance is illegal on this target.
virtual ~ConversionTarget()=default
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:215
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
Definition: Builders.h:324
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
void rewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Hook for derived classes to implement rewriting.
LogicalResult match(Operation *op) const final
Wrappers around the ConversionPattern methods that pass the derived op type.
typename SourceOp::Adaptor OpAdaptor
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
virtual LogicalResult match(SourceOp op) const
Rewrite and Match methods that operate on the SourceOp type.
OpConversionPattern(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit=1)
virtual LogicalResult matchAndRewrite(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const
virtual void rewrite(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const
LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Hook for derived classes to implement combined matching and rewriting.
OpInterfaceConversionPattern is a wrapper around ConversionPattern that allows for matching and rewri...
void rewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Wrappers around the ConversionPattern methods that pass the derived op type.
LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Hook for derived classes to implement combined matching and rewriting.
OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
OpInterfaceConversionPattern(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit=1)
virtual void rewrite(SourceOp op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Rewrite and Match methods that operate on the SourceOp type.
virtual LogicalResult matchAndRewrite(SourceOp op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
OpTraitConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting...
OpTraitConversionPattern(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit=1)
OpTraitConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
A PDL configuration that is used to support dialect conversion functionality.
void notifyRewriteEnd(PatternRewriter &rewriter) final
void notifyRewriteBegin(PatternRewriter &rewriter) final
Hooks that are invoked at the beginning and end of a rewrite of a matched pattern.
const TypeConverter * getTypeConverter() const
Return the type converter used by this configuration, which may be nullptr if no type conversions are...
PDLConversionConfig(const TypeConverter *converter)
~PDLConversionConfig() final=default
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:73
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:246
virtual LogicalResult match(Operation *op) const
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
virtual void rewrite(Operation *op, PatternRewriter &rewriter) const
Rewrite the IR rooted at the specified operation with the result of this pattern, generating any new ...
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
The general result of a type attribute conversion callback, allowing for early termination.
static AttributeConversionResult result(Attribute attr)
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
std::optional< InputMapping > getInputMapping(unsigned input) const
Get the input mapping for the given argument.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
void remapInput(unsigned origInputNo, Value replacement)
Remap an input of the original signature to another replacement value.
Type conversion class.
std::optional< Attribute > convertTypeAttribute(Type type, Attribute attr) const
Convert the attribute `attr`, present within the type `type`, using the registered conversion functions...
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
void addConversion(FnT &&callback)
Register a conversion function.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
void addArgumentMaterialization(FnT &&callback)
All of the following materializations require function objects that are convertible to the following ...
LogicalResult convertSignatureArgs(TypeRange types, SignatureConversion &result, unsigned origInputOffset=0) const
TypeConverter()=default
std::enable_if_t<!std::is_convertible< RangeT, Type >::value &&!std::is_convertible< RangeT, Operation * >::value, bool > isLegal(RangeT &&range) const
Return true if all of the given types are legal for this type converter.
TargetType convertType(Type t) const
Attempts a 1-1 type conversion, expecting the result type to be TargetType.
LogicalResult convertSignatureArg(unsigned inputNo, Type type, SignatureConversion &result) const
This method allows for converting a specific argument of a signature.
TypeConverter(const TypeConverter &other)
TypeConverter & operator=(const TypeConverter &other)
Value materializeArgumentConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
Materialize a conversion from a set of types into one result type by generating a cast sequence of so...
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
std::optional< SignatureConversion > convertBlockSignature(Block *block) const
This function converts the type signature of the given block, by invoking 'convertSignatureArg' for e...
void addSourceMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a legal replacement value...
virtual ~TypeConverter()=default
void addTypeAttributeConversion(FnT &&callback)
Register a conversion function for attributes within types.
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting an illegal (source) value...
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
bool isSignatureLegal(FunctionType ty) const
Return true if the inputs and outputs of the given function type are legal.
Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs, Type originalType={}) const
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:104
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
@ Type
An inlay hint that is for a type annotation.
Include the generated interface declarations.
void populateFunctionOpInterfaceTypeConversionPattern(StringRef functionLikeOpName, RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to convert the signature of a FunctionOpInterface op with the...
void populateAnyFunctionOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
FailureOr< Operation * > convertOpResultTypes(Operation *op, ValueRange operands, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Generic utility to convert op result types according to type converter without knowing exact op type.
LogicalResult applyAnalysisConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply an analysis conversion on the given operations, and all nested operations.
void reconcileUnrealizedCasts(ArrayRef< UnrealizedConversionCastOp > castOps, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps=nullptr)
Try to reconcile all given UnrealizedConversionCastOps and store the left-over ops in remainingCastOp...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void registerConversionPDLFunctions(RewritePatternSet &patterns)
Register the dialect conversion PDL functions with the given pattern set.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
Dialect conversion configuration.
RewriterBase::Listener * listener
An optional listener that is notified about all IR modifications in case dialect conversion succeeds.
function_ref< void(Diagnostic &)> notifyCallback
An optional callback used to notify about match failure diagnostics during the conversion.
DenseSet< Operation * > * legalizableOps
Analysis conversion only.
DenseSet< Operation * > * unlegalizedOps
Partial conversion only.
bool buildMaterializations
If set to "true", the dialect conversion attempts to build source/target/argument materializations t...
A structure containing additional information describing a specific legal operation instance.
bool isRecursivelyLegal
A flag that indicates if this operation is 'recursively' legal.
This class acts as a special tag that makes the desire to match any operation that implements a given...
Definition: PatternMatch.h:164
This class acts as a special tag that makes the desire to match any operation that implements a given...
Definition: PatternMatch.h:169
This struct represents a range of new types or a single value that remaps an existing signature input...