MLIR  18.0.0git
DialectConversion.h
Go to the documentation of this file.
1 //===- DialectConversion.h - MLIR dialect conversion pass -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file declares a generic pass for converting between MLIR dialects.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_TRANSFORMS_DIALECTCONVERSION_H_
14 #define MLIR_TRANSFORMS_DIALECTCONVERSION_H_
15 
17 #include "llvm/ADT/MapVector.h"
18 #include "llvm/ADT/StringMap.h"
19 #include <type_traits>
20 
21 namespace mlir {
22 
23 // Forward declarations.
24 class Attribute;
25 class Block;
26 class ConversionPatternRewriter;
27 class MLIRContext;
28 class Operation;
29 class Type;
30 class Value;
31 
32 //===----------------------------------------------------------------------===//
33 // Type Conversion
34 //===----------------------------------------------------------------------===//
35 
36 /// Type conversion class. Specific conversions and materializations can be
37 /// registered using addConversion and addMaterialization, respectively.
39 public:
40  virtual ~TypeConverter() = default;
41  TypeConverter() = default;
42  // Copy the registered conversions, but not the caches
44  : conversions(other.conversions),
45  argumentMaterializations(other.argumentMaterializations),
46  sourceMaterializations(other.sourceMaterializations),
47  targetMaterializations(other.targetMaterializations),
48  typeAttributeConversions(other.typeAttributeConversions) {}
50  conversions = other.conversions;
51  argumentMaterializations = other.argumentMaterializations;
52  sourceMaterializations = other.sourceMaterializations;
53  targetMaterializations = other.targetMaterializations;
54  typeAttributeConversions = other.typeAttributeConversions;
55  return *this;
56  }
57 
58  /// This class provides all of the information necessary to convert a type
59  /// signature.
61  public:
62  SignatureConversion(unsigned numOrigInputs)
63  : remappedInputs(numOrigInputs) {}
64 
65  /// This struct represents a range of new types or a single value that
66  /// remaps an existing signature input.
67  struct InputMapping {
68  size_t inputNo, size;
70  };
71 
72  /// Return the argument types for the new signature.
73  ArrayRef<Type> getConvertedTypes() const { return argTypes; }
74 
75  /// Get the input mapping for the given argument.
76  std::optional<InputMapping> getInputMapping(unsigned input) const {
77  return remappedInputs[input];
78  }
79 
80  //===------------------------------------------------------------------===//
81  // Conversion Hooks
82  //===------------------------------------------------------------------===//
83 
84  /// Remap an input of the original signature with a new set of types. The
85  /// new types are appended to the new signature conversion.
86  void addInputs(unsigned origInputNo, ArrayRef<Type> types);
87 
88  /// Append new input types to the signature conversion, this should only be
89  /// used if the new types are not intended to remap an existing input.
90  void addInputs(ArrayRef<Type> types);
91 
92  /// Remap an input of the original signature to another `replacement`
93  /// value. This drops the original argument.
94  void remapInput(unsigned origInputNo, Value replacement);
95 
96  private:
97  /// Remap an input of the original signature with a range of types in the
98  /// new signature.
99  void remapInput(unsigned origInputNo, unsigned newInputNo,
100  unsigned newInputCount = 1);
101 
102  /// The remapping information for each of the original arguments.
103  SmallVector<std::optional<InputMapping>, 4> remappedInputs;
104 
105  /// The set of new argument types.
106  SmallVector<Type, 4> argTypes;
107  };
108 
109  /// The general result of a type attribute conversion callback, allowing
110  /// for early termination. The default constructor creates the na case.
112  public:
113  constexpr AttributeConversionResult() : impl() {}
114  AttributeConversionResult(Attribute attr) : impl(attr, resultTag) {}
115 
117  static AttributeConversionResult na();
119 
120  bool hasResult() const;
121  bool isNa() const;
122  bool isAbort() const;
123 
124  Attribute getResult() const;
125 
126  private:
127  AttributeConversionResult(Attribute attr, unsigned tag) : impl(attr, tag) {}
128 
129  llvm::PointerIntPair<Attribute, 2> impl;
130  // Note that na is 0 so that we can use PointerIntPair's default
131  // constructor.
132  static constexpr unsigned naTag = 0;
133  static constexpr unsigned resultTag = 1;
134  static constexpr unsigned abortTag = 2;
135  };
136 
137  /// Register a conversion function. A conversion function must be convertible
138  /// to any of the following forms(where `T` is a class derived from `Type`:
139  /// * std::optional<Type>(T)
140  /// - This form represents a 1-1 type conversion. It should return nullptr
141  /// or `std::nullopt` to signify failure. If `std::nullopt` is returned,
142  /// the converter is allowed to try another conversion function to
143  /// perform the conversion.
144  /// * std::optional<LogicalResult>(T, SmallVectorImpl<Type> &)
145  /// - This form represents a 1-N type conversion. It should return
146  /// `failure` or `std::nullopt` to signify a failed conversion. If the
147  /// new set of types is empty, the type is removed and any usages of the
148  /// existing value are expected to be removed during conversion. If
149  /// `std::nullopt` is returned, the converter is allowed to try another
150  /// conversion function to perform the conversion.
151  /// * std::optional<LogicalResult>(T, SmallVectorImpl<Type> &,
152  /// ArrayRef<Type>)
153  /// - This form represents a 1-N type conversion supporting recursive
154  /// types. The first two arguments and the return value are the same as
155  /// for the regular 1-N form. The third argument is contains is the
156  /// "call stack" of the recursive conversion: it contains the list of
157  /// types currently being converted, with the current type being the
158  /// last one. If it is present more than once in the list, the
159  /// conversion concerns a recursive type.
160  /// Note: When attempting to convert a type, e.g. via 'convertType', the
161  /// mostly recently added conversions will be invoked first.
162  template <typename FnT, typename T = typename llvm::function_traits<
163  std::decay_t<FnT>>::template arg_t<0>>
164  void addConversion(FnT &&callback) {
165  registerConversion(wrapCallback<T>(std::forward<FnT>(callback)));
166  }
167 
168  /// Register a materialization function, which must be convertible to the
169  /// following form:
170  /// `std::optional<Value>(OpBuilder &, T, ValueRange, Location)`,
171  /// where `T` is any subclass of `Type`. This function is responsible for
172  /// creating an operation, using the OpBuilder and Location provided, that
173  /// "casts" a range of values into a single value of the given type `T`. It
174  /// must return a Value of the converted type on success, an `std::nullopt` if
175  /// it failed but other materialization can be attempted, and `nullptr` on
176  /// unrecoverable failure. It will only be called for (sub)types of `T`.
177  /// Materialization functions must be provided when a type conversion may
178  /// persist after the conversion has finished.
179  ///
180  /// This method registers a materialization that will be called when
181  /// converting an illegal block argument type, to a legal type.
182  template <typename FnT, typename T = typename llvm::function_traits<
183  std::decay_t<FnT>>::template arg_t<1>>
184  void addArgumentMaterialization(FnT &&callback) {
185  argumentMaterializations.emplace_back(
186  wrapMaterialization<T>(std::forward<FnT>(callback)));
187  }
188  /// This method registers a materialization that will be called when
189  /// converting a legal type to an illegal source type. This is used when
190  /// conversions to an illegal type must persist beyond the main conversion.
191  template <typename FnT, typename T = typename llvm::function_traits<
192  std::decay_t<FnT>>::template arg_t<1>>
193  void addSourceMaterialization(FnT &&callback) {
194  sourceMaterializations.emplace_back(
195  wrapMaterialization<T>(std::forward<FnT>(callback)));
196  }
197  /// This method registers a materialization that will be called when
198  /// converting type from an illegal, or source, type to a legal type.
199  template <typename FnT, typename T = typename llvm::function_traits<
200  std::decay_t<FnT>>::template arg_t<1>>
201  void addTargetMaterialization(FnT &&callback) {
202  targetMaterializations.emplace_back(
203  wrapMaterialization<T>(std::forward<FnT>(callback)));
204  }
205 
206  /// Register a conversion function for attributes within types. Type
207  /// converters may call this function in order to allow hoking into the
208  /// translation of attributes that exist within types. For example, a type
209  /// converter for the `memref` type could use these conversions to convert
210  /// memory spaces or layouts in an extensible way.
211  ///
212  /// The conversion functions take a non-null Type or subclass of Type and a
213  /// non-null Attribute (or subclass of Attribute), and returns a
214  /// `AttributeConversionResult`. This result can either contan an `Attribute`,
215  /// which may be `nullptr`, representing the conversion's success,
216  /// `AttributeConversionResult::na()` (the default empty value), indicating
217  /// that the conversion function did not apply and that further conversion
218  /// functions should be checked, or `AttributeConversionResult::abort()`
219  /// indicating that the conversion process should be aborted.
220  ///
221  /// Registered conversion functions are callled in the reverse of the order in
222  /// which they were registered.
223  template <
224  typename FnT,
225  typename T =
226  typename llvm::function_traits<std::decay_t<FnT>>::template arg_t<0>,
227  typename A =
228  typename llvm::function_traits<std::decay_t<FnT>>::template arg_t<1>>
229  void addTypeAttributeConversion(FnT &&callback) {
230  registerTypeAttributeConversion(
231  wrapTypeAttributeConversion<T, A>(std::forward<FnT>(callback)));
232  }
233 
234  /// Convert the given type. This function should return failure if no valid
235  /// conversion exists, success otherwise. If the new set of types is empty,
236  /// the type is removed and any usages of the existing value are expected to
237  /// be removed during conversion.
239 
240  /// This hook simplifies defining 1-1 type conversions. This function returns
241  /// the type to convert to on success, and a null type on failure.
242  Type convertType(Type t) const;
243 
244  /// Attempts a 1-1 type conversion, expecting the result type to be
245  /// `TargetType`. Returns the converted type cast to `TargetType` on success,
246  /// and a null type on conversion or cast failure.
247  template <typename TargetType> TargetType convertType(Type t) const {
248  return dyn_cast_or_null<TargetType>(convertType(t));
249  }
250 
251  /// Convert the given set of types, filling 'results' as necessary. This
252  /// returns failure if the conversion of any of the types fails, success
253  /// otherwise.
255  SmallVectorImpl<Type> &results) const;
256 
257  /// Return true if the given type is legal for this type converter, i.e. the
258  /// type converts to itself.
259  bool isLegal(Type type) const;
260 
261  /// Return true if all of the given types are legal for this type converter.
262  template <typename RangeT>
263  std::enable_if_t<!std::is_convertible<RangeT, Type>::value &&
264  !std::is_convertible<RangeT, Operation *>::value,
265  bool>
266  isLegal(RangeT &&range) const {
267  return llvm::all_of(range, [this](Type type) { return isLegal(type); });
268  }
269  /// Return true if the given operation has legal operand and result types.
270  bool isLegal(Operation *op) const;
271 
272  /// Return true if the types of block arguments within the region are legal.
273  bool isLegal(Region *region) const;
274 
275  /// Return true if the inputs and outputs of the given function type are
276  /// legal.
277  bool isSignatureLegal(FunctionType ty) const;
278 
279  /// This method allows for converting a specific argument of a signature. It
280  /// takes as inputs the original argument input number, type.
281  /// On success, it populates 'result' with any new mappings.
282  LogicalResult convertSignatureArg(unsigned inputNo, Type type,
283  SignatureConversion &result) const;
285  SignatureConversion &result,
286  unsigned origInputOffset = 0) const;
287 
288  /// This function converts the type signature of the given block, by invoking
289  /// 'convertSignatureArg' for each argument. This function should return a
290  /// valid conversion for the signature on success, std::nullopt otherwise.
291  std::optional<SignatureConversion> convertBlockSignature(Block *block) const;
292 
293  /// Materialize a conversion from a set of types into one result type by
294  /// generating a cast sequence of some kind. See the respective
295  /// `add*Materialization` for more information on the context for these
296  /// methods.
298  Type resultType,
299  ValueRange inputs) const {
300  return materializeConversion(argumentMaterializations, builder, loc,
301  resultType, inputs);
302  }
304  Type resultType, ValueRange inputs) const {
305  return materializeConversion(sourceMaterializations, builder, loc,
306  resultType, inputs);
307  }
309  Type resultType, ValueRange inputs) const {
310  return materializeConversion(targetMaterializations, builder, loc,
311  resultType, inputs);
312  }
313 
/// Convert an attribute `attr` that is present within the type `type` using
/// the registered conversion functions. If no applicable conversion has been
/// registered, return std::nullopt. Note that the empty attribute/`nullptr`
/// is a valid return value for this function.
318  std::optional<Attribute> convertTypeAttribute(Type type,
319  Attribute attr) const;
320 
321 private:
322  /// The signature of the callback used to convert a type. If the new set of
323  /// types is empty, the type is removed and any usages of the existing value
324  /// are expected to be removed during conversion.
325  using ConversionCallbackFn = std::function<std::optional<LogicalResult>(
327 
328  /// The signature of the callback used to materialize a conversion.
329  using MaterializationCallbackFn = std::function<std::optional<Value>(
331 
332  /// The signature of the callback used to convert a type attribute.
333  using TypeAttributeConversionCallbackFn =
334  std::function<AttributeConversionResult(Type, Attribute)>;
335 
336  /// Attempt to materialize a conversion using one of the provided
337  /// materialization functions.
338  Value
339  materializeConversion(ArrayRef<MaterializationCallbackFn> materializations,
340  OpBuilder &builder, Location loc, Type resultType,
341  ValueRange inputs) const;
342 
343  /// Generate a wrapper for the given callback. This allows for accepting
344  /// different callback forms, that all compose into a single version.
345  /// With callback of form: `std::optional<Type>(T)`
346  template <typename T, typename FnT>
347  std::enable_if_t<std::is_invocable_v<FnT, T>, ConversionCallbackFn>
348  wrapCallback(FnT &&callback) const {
349  return wrapCallback<T>([callback = std::forward<FnT>(callback)](
350  T type, SmallVectorImpl<Type> &results) {
351  if (std::optional<Type> resultOpt = callback(type)) {
352  bool wasSuccess = static_cast<bool>(*resultOpt);
353  if (wasSuccess)
354  results.push_back(*resultOpt);
355  return std::optional<LogicalResult>(success(wasSuccess));
356  }
357  return std::optional<LogicalResult>();
358  });
359  }
360  /// With callback of form: `std::optional<LogicalResult>(
361  /// T, SmallVectorImpl<Type> &, ArrayRef<Type>)`.
362  template <typename T, typename FnT>
363  std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &>,
364  ConversionCallbackFn>
365  wrapCallback(FnT &&callback) const {
366  return [callback = std::forward<FnT>(callback)](
367  Type type,
368  SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
369  T derivedType = dyn_cast<T>(type);
370  if (!derivedType)
371  return std::nullopt;
372  return callback(derivedType, results);
373  };
374  }
375 
376  /// Register a type conversion.
377  void registerConversion(ConversionCallbackFn callback) {
378  conversions.emplace_back(std::move(callback));
379  cachedDirectConversions.clear();
380  cachedMultiConversions.clear();
381  }
382 
383  /// Generate a wrapper for the given materialization callback. The callback
384  /// may take any subclass of `Type` and the wrapper will check for the target
385  /// type to be of the expected class before calling the callback.
386  template <typename T, typename FnT>
387  MaterializationCallbackFn wrapMaterialization(FnT &&callback) const {
388  return [callback = std::forward<FnT>(callback)](
389  OpBuilder &builder, Type resultType, ValueRange inputs,
390  Location loc) -> std::optional<Value> {
391  if (T derivedType = dyn_cast<T>(resultType))
392  return callback(builder, derivedType, inputs, loc);
393  return std::nullopt;
394  };
395  }
396 
397  /// Generate a wrapper for the given memory space conversion callback. The
398  /// callback may take any subclass of `Attribute` and the wrapper will check
399  /// for the target attribute to be of the expected class before calling the
400  /// callback.
401  template <typename T, typename A, typename FnT>
402  TypeAttributeConversionCallbackFn
403  wrapTypeAttributeConversion(FnT &&callback) const {
404  return [callback = std::forward<FnT>(callback)](
405  Type type, Attribute attr) -> AttributeConversionResult {
406  if (T derivedType = dyn_cast<T>(type)) {
407  if (A derivedAttr = dyn_cast_or_null<A>(attr))
408  return callback(derivedType, derivedAttr);
409  }
411  };
412  }
413 
414  /// Register a memory space conversion, clearing caches.
415  void
416  registerTypeAttributeConversion(TypeAttributeConversionCallbackFn callback) {
417  typeAttributeConversions.emplace_back(std::move(callback));
418  // Clear type conversions in case a memory space is lingering inside.
419  cachedDirectConversions.clear();
420  cachedMultiConversions.clear();
421  }
422 
423  /// The set of registered conversion functions.
424  SmallVector<ConversionCallbackFn, 4> conversions;
425 
426  /// The list of registered materialization functions.
427  SmallVector<MaterializationCallbackFn, 2> argumentMaterializations;
428  SmallVector<MaterializationCallbackFn, 2> sourceMaterializations;
429  SmallVector<MaterializationCallbackFn, 2> targetMaterializations;
430 
431  /// The list of registered type attribute conversion functions.
432  SmallVector<TypeAttributeConversionCallbackFn, 2> typeAttributeConversions;
433 
434  /// A set of cached conversions to avoid recomputing in the common case.
435  /// Direct 1-1 conversions are the most common, so this cache stores the
436  /// successful 1-1 conversions as well as all failed conversions.
437  mutable DenseMap<Type, Type> cachedDirectConversions;
438  /// This cache stores the successful 1->N conversions, where N != 1.
439  mutable DenseMap<Type, SmallVector<Type, 2>> cachedMultiConversions;
440  /// A mutex used for cache access
441  mutable llvm::sys::SmartRWMutex<true> cacheMutex;
442 };
443 
444 //===----------------------------------------------------------------------===//
445 // Conversion Patterns
446 //===----------------------------------------------------------------------===//
447 
448 /// Base class for the conversion patterns. This pattern class enables type
449 /// conversions, and other uses specific to the conversion framework. As such,
450 /// patterns of this type can only be used with the 'apply*' methods below.
452 public:
453  /// Hook for derived classes to implement rewriting. `op` is the (first)
454  /// operation matched by the pattern, `operands` is a list of the rewritten
455  /// operand values that are passed to `op`, `rewriter` can be used to emit the
456  /// new operations. This function should not fail. If some specific cases of
457  /// the operation are not supported, these cases should not be matched.
458  virtual void rewrite(Operation *op, ArrayRef<Value> operands,
459  ConversionPatternRewriter &rewriter) const {
460  llvm_unreachable("unimplemented rewrite");
461  }
462 
463  /// Hook for derived classes to implement combined matching and rewriting.
464  virtual LogicalResult
466  ConversionPatternRewriter &rewriter) const {
467  if (failed(match(op)))
468  return failure();
469  rewrite(op, operands, rewriter);
470  return success();
471  }
472 
473  /// Attempt to match and rewrite the IR root at the specified operation.
475  PatternRewriter &rewriter) const final;
476 
477  /// Return the type converter held by this pattern, or nullptr if the pattern
478  /// does not require type conversion.
479  const TypeConverter *getTypeConverter() const { return typeConverter; }
480 
481  template <typename ConverterTy>
482  std::enable_if_t<std::is_base_of<TypeConverter, ConverterTy>::value,
483  const ConverterTy *>
485  return static_cast<const ConverterTy *>(typeConverter);
486  }
487 
488 protected:
489  /// See `RewritePattern::RewritePattern` for information on the other
490  /// available constructors.
491  using RewritePattern::RewritePattern;
492  /// Construct a conversion pattern with the given converter, and forward the
493  /// remaining arguments to RewritePattern.
494  template <typename... Args>
496  : RewritePattern(std::forward<Args>(args)...),
498 
499 protected:
500  /// An optional type converter for use by this pattern.
501  const TypeConverter *typeConverter = nullptr;
502 
503 private:
505 };
506 
507 /// OpConversionPattern is a wrapper around ConversionPattern that allows for
508 /// matching and rewriting against an instance of a derived operation class as
509 /// opposed to a raw Operation.
510 template <typename SourceOp>
512 public:
513  using OpAdaptor = typename SourceOp::Adaptor;
514 
516  : ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
518  PatternBenefit benefit = 1)
519  : ConversionPattern(typeConverter, SourceOp::getOperationName(), benefit,
520  context) {}
521 
522  /// Wrappers around the ConversionPattern methods that pass the derived op
523  /// type.
524  LogicalResult match(Operation *op) const final {
525  return match(cast<SourceOp>(op));
526  }
527  void rewrite(Operation *op, ArrayRef<Value> operands,
528  ConversionPatternRewriter &rewriter) const final {
529  auto sourceOp = cast<SourceOp>(op);
530  rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
531  }
534  ConversionPatternRewriter &rewriter) const final {
535  auto sourceOp = cast<SourceOp>(op);
536  return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
537  }
538 
539  /// Rewrite and Match methods that operate on the SourceOp type. These must be
540  /// overridden by the derived pattern class.
541  virtual LogicalResult match(SourceOp op) const {
542  llvm_unreachable("must override match or matchAndRewrite");
543  }
544  virtual void rewrite(SourceOp op, OpAdaptor adaptor,
545  ConversionPatternRewriter &rewriter) const {
546  llvm_unreachable("must override matchAndRewrite or a rewrite method");
547  }
548  virtual LogicalResult
549  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
550  ConversionPatternRewriter &rewriter) const {
551  if (failed(match(op)))
552  return failure();
553  rewrite(op, adaptor, rewriter);
554  return success();
555  }
556 
557 private:
559 };
560 
561 /// OpInterfaceConversionPattern is a wrapper around ConversionPattern that
562 /// allows for matching and rewriting against an instance of an OpInterface
563 /// class as opposed to a raw Operation.
564 template <typename SourceOp>
566 public:
569  SourceOp::getInterfaceID(), benefit, context) {}
571  MLIRContext *context, PatternBenefit benefit = 1)
573  SourceOp::getInterfaceID(), benefit, context) {}
574 
575  /// Wrappers around the ConversionPattern methods that pass the derived op
576  /// type.
577  void rewrite(Operation *op, ArrayRef<Value> operands,
578  ConversionPatternRewriter &rewriter) const final {
579  rewrite(cast<SourceOp>(op), operands, rewriter);
580  }
583  ConversionPatternRewriter &rewriter) const final {
584  return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
585  }
586 
587  /// Rewrite and Match methods that operate on the SourceOp type. These must be
588  /// overridden by the derived pattern class.
589  virtual void rewrite(SourceOp op, ArrayRef<Value> operands,
590  ConversionPatternRewriter &rewriter) const {
591  llvm_unreachable("must override matchAndRewrite or a rewrite method");
592  }
593  virtual LogicalResult
594  matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
595  ConversionPatternRewriter &rewriter) const {
596  if (failed(match(op)))
597  return failure();
598  rewrite(op, operands, rewriter);
599  return success();
600  }
601 
602 private:
604 };
605 
606 /// Add a pattern to the given pattern list to convert the signature of a
607 /// FunctionOpInterface op with the given type converter. This only supports
608 /// ops which use FunctionType to represent their type.
610  StringRef functionLikeOpName, RewritePatternSet &patterns,
611  const TypeConverter &converter);
612 
613 template <typename FuncOpT>
615  RewritePatternSet &patterns, const TypeConverter &converter) {
616  populateFunctionOpInterfaceTypeConversionPattern(FuncOpT::getOperationName(),
617  patterns, converter);
618 }
619 
621  RewritePatternSet &patterns, const TypeConverter &converter);
622 
623 //===----------------------------------------------------------------------===//
624 // Conversion PatternRewriter
625 //===----------------------------------------------------------------------===//
626 
627 namespace detail {
628 struct ConversionPatternRewriterImpl;
629 } // namespace detail
630 
631 /// This class implements a pattern rewriter for use with ConversionPatterns. It
632 /// extends the base PatternRewriter and provides special conversion specific
633 /// hooks.
635  public RewriterBase::Listener {
636 public:
637  explicit ConversionPatternRewriter(MLIRContext *ctx);
639 
640  /// Apply a signature conversion to the entry block of the given region. This
641  /// replaces the entry block with a new block containing the updated
642  /// signature. The new entry block to the region is returned for convenience.
643  ///
644  /// If provided, `converter` will be used for any materializations.
645  Block *
648  const TypeConverter *converter = nullptr);
649 
650  /// Convert the types of block arguments within the given region. This
651  /// replaces each block with a new block containing the updated signature. The
652  /// entry block may have a special conversion if `entryConversion` is
653  /// provided. On success, the new entry block to the region is returned for
654  /// convenience. Otherwise, failure is returned.
656  Region *region, const TypeConverter &converter,
657  TypeConverter::SignatureConversion *entryConversion = nullptr);
658 
/// Convert the types of block arguments within the given region except for
/// the entry block. This replaces each non-entry block with a new block
/// containing the updated signature.
662  ///
663  /// If special conversion behavior is needed for the non-entry blocks (for
664  /// example, we need to convert only a subset of a BB arguments), such
665  /// behavior can be specified in blockConversions.
667  Region *region, const TypeConverter &converter,
669 
670  /// Replace all the uses of the block argument `from` with value `to`.
672 
673  /// Return the converted value of 'key' with a type defined by the type
674  /// converter of the currently executing pattern. Return nullptr in the case
675  /// of failure, the remapped value otherwise.
677 
678  /// Return the converted values that replace 'keys' with types defined by the
679  /// type converter of the currently executing pattern. Returns failure if the
680  /// remap failed, success otherwise.
682  SmallVectorImpl<Value> &results);
683 
684  //===--------------------------------------------------------------------===//
685  // PatternRewriter Hooks
686  //===--------------------------------------------------------------------===//
687 
688  /// Indicate that the conversion rewriter can recover from rewrite failure.
689  /// Recovery is supported via rollback, allowing for continued processing of
690  /// patterns even if a failure is encountered during the rewrite step.
691  bool canRecoverFromRewriteFailure() const override { return true; }
692 
693  /// PatternRewriter hook for replacing an operation when the given functor
694  /// returns "true".
695  void replaceOpWithIf(
696  Operation *op, ValueRange newValues, bool *allUsesReplaced,
697  llvm::unique_function<bool(OpOperand &) const> functor) override;
698 
699  /// PatternRewriter hook for replacing an operation.
700  void replaceOp(Operation *op, ValueRange newValues) override;
701 
702  /// PatternRewriter hook for replacing an operation.
703  void replaceOp(Operation *op, Operation *newOp) override;
704 
705  /// PatternRewriter hook for erasing a dead operation. The uses of this
706  /// operation *must* be made dead by the end of the conversion process,
707  /// otherwise an assert will be issued.
708  void eraseOp(Operation *op) override;
709 
/// PatternRewriter hook for erasing a block and all of the operations it
/// contains. This is not yet implemented for dialect conversion.
712  void eraseBlock(Block *block) override;
713 
714  /// PatternRewriter hook creating a new block.
715  void notifyBlockCreated(Block *block) override;
716 
717  /// PatternRewriter hook for splitting a block into two parts.
718  Block *splitBlock(Block *block, Block::iterator before) override;
719 
720  /// PatternRewriter hook for inlining the ops of a block into another block.
721  void inlineBlockBefore(Block *source, Block *dest, Block::iterator before,
722  ValueRange argValues = std::nullopt) override;
724 
725  /// PatternRewriter hook for moving blocks out of a region.
726  void inlineRegionBefore(Region &region, Region &parent,
727  Region::iterator before) override;
729 
730  /// PatternRewriter hook for cloning blocks of one region into another. The
731  /// given region to clone *must* not have been modified as part of conversion
732  /// yet, i.e. it must be within an operation that is either in the process of
733  /// conversion, or has not yet been converted.
734  void cloneRegionBefore(Region &region, Region &parent,
735  Region::iterator before, IRMapping &mapping) override;
737 
738  /// PatternRewriter hook for inserting a new operation.
739  void notifyOperationInserted(Operation *op) override;
740 
741  /// PatternRewriter hook for updating the root operation in-place.
742  /// Note: These methods only track updates to the top-level operation itself,
743  /// and not nested regions. Updates to regions will still require notification
744  /// through other more specific hooks above.
745  void startRootUpdate(Operation *op) override;
746 
747  /// PatternRewriter hook for updating the root operation in-place.
748  void finalizeRootUpdate(Operation *op) override;
749 
750  /// PatternRewriter hook for updating the root operation in-place.
751  void cancelRootUpdate(Operation *op) override;
752 
753  /// PatternRewriter hook for notifying match failure reasons.
756  function_ref<void(Diagnostic &)> reasonCallback) override;
758 
759  /// Return a reference to the internal implementation.
761 
762 private:
765 
766  std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
767 };
768 
769 //===----------------------------------------------------------------------===//
770 // ConversionTarget
771 //===----------------------------------------------------------------------===//
772 
773 /// This class describes a specific conversion target.
775 public:
776  /// This enumeration corresponds to the specific action to take when
777  /// considering an operation legal for this conversion target.
778  enum class LegalizationAction {
779  /// The target supports this operation.
780  Legal,
781 
782  /// This operation has dynamic legalization constraints that must be checked
783  /// by the target.
784  Dynamic,
785 
786  /// The target explicitly does not support this operation.
787  Illegal,
788  };
789 
790  /// A structure containing additional information describing a specific legal
791  /// operation instance.
792  struct LegalOpDetails {
793  /// A flag that indicates if this operation is 'recursively' legal. This
794  /// means that if an operation is legal, either statically or dynamically,
795  /// all of the operations nested within are also considered legal.
796  bool isRecursivelyLegal = false;
797  };
798 
799  /// The signature of the callback used to determine if an operation is
800  /// dynamically legal on the target.
802  std::function<std::optional<bool>(Operation *)>;
803 
804  ConversionTarget(MLIRContext &ctx) : ctx(ctx) {}
805  virtual ~ConversionTarget() = default;
806 
807  //===--------------------------------------------------------------------===//
808  // Legality Registration
809  //===--------------------------------------------------------------------===//
810 
811  /// Register a legality action for the given operation.
813  template <typename OpT>
815  setOpAction(OperationName(OpT::getOperationName(), &ctx), action);
816  }
817 
818  /// Register the given operations as legal.
821  }
822  template <typename OpT>
823  void addLegalOp() {
824  addLegalOp(OperationName(OpT::getOperationName(), &ctx));
825  }
826  template <typename OpT, typename OpT2, typename... OpTs>
827  void addLegalOp() {
828  addLegalOp<OpT>();
829  addLegalOp<OpT2, OpTs...>();
830  }
831 
832  /// Register the given operation as dynamically legal and set the dynamic
833  /// legalization callback to the one provided.
835  const DynamicLegalityCallbackFn &callback) {
837  setLegalityCallback(op, callback);
838  }
839  template <typename OpT>
841  addDynamicallyLegalOp(OperationName(OpT::getOperationName(), &ctx),
842  callback);
843  }
844  template <typename OpT, typename OpT2, typename... OpTs>
846  addDynamicallyLegalOp<OpT>(callback);
847  addDynamicallyLegalOp<OpT2, OpTs...>(callback);
848  }
849  template <typename OpT, class Callable>
850  std::enable_if_t<!std::is_invocable_v<Callable, Operation *>>
851  addDynamicallyLegalOp(Callable &&callback) {
852  addDynamicallyLegalOp<OpT>(
853  [=](Operation *op) { return callback(cast<OpT>(op)); });
854  }
855 
856  /// Register the given operation as illegal, i.e. this operation is known to
857  /// not be supported by this target.
860  }
861  template <typename OpT>
862  void addIllegalOp() {
863  addIllegalOp(OperationName(OpT::getOperationName(), &ctx));
864  }
865  template <typename OpT, typename OpT2, typename... OpTs>
866  void addIllegalOp() {
867  addIllegalOp<OpT>();
868  addIllegalOp<OpT2, OpTs...>();
869  }
870 
871  /// Mark an operation, that *must* have either been set as `Legal` or
872  /// `Dynamic`, as being recursively legal. This means that in
873  /// addition to the operation itself, all of the operations nested within are
874  /// also considered legal. An optional dynamic legality callback may be
875  /// provided to mark subsets of legal instances as recursively legal.
877  const DynamicLegalityCallbackFn &callback);
878  template <typename OpT>
880  OperationName opName(OpT::getOperationName(), &ctx);
881  markOpRecursivelyLegal(opName, callback);
882  }
883  template <typename OpT, typename OpT2, typename... OpTs>
885  markOpRecursivelyLegal<OpT>(callback);
886  markOpRecursivelyLegal<OpT2, OpTs...>(callback);
887  }
888  template <typename OpT, class Callable>
889  std::enable_if_t<!std::is_invocable_v<Callable, Operation *>>
890  markOpRecursivelyLegal(Callable &&callback) {
891  markOpRecursivelyLegal<OpT>(
892  [=](Operation *op) { return callback(cast<OpT>(op)); });
893  }
894 
895  /// Register a legality action for the given dialects.
896  void setDialectAction(ArrayRef<StringRef> dialectNames,
897  LegalizationAction action);
898 
899  /// Register the operations of the given dialects as legal.
900  template <typename... Names>
901  void addLegalDialect(StringRef name, Names... names) {
902  SmallVector<StringRef, 2> dialectNames({name, names...});
904  }
905  template <typename... Args>
907  SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
909  }
910 
911  /// Register the operations of the given dialects as dynamically legal, i.e.
912  /// requiring custom handling by the callback.
913  template <typename... Names>
915  StringRef name, Names... names) {
916  SmallVector<StringRef, 2> dialectNames({name, names...});
918  setLegalityCallback(dialectNames, callback);
919  }
920  template <typename... Args>
922  addDynamicallyLegalDialect(std::move(callback),
923  Args::getDialectNamespace()...);
924  }
925 
926  /// Register unknown operations as dynamically legal. For operations (and
927  /// dialects) that do not have a set legalization action, treat them as
928  /// dynamically legal and invoke the given callback.
930  setLegalityCallback(fn);
931  }
932 
933  /// Register the operations of the given dialects as illegal, i.e.
934  /// operations of this dialect are not supported by the target.
935  template <typename... Names>
936  void addIllegalDialect(StringRef name, Names... names) {
937  SmallVector<StringRef, 2> dialectNames({name, names...});
939  }
940  template <typename... Args>
942  SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
944  }
945 
946  //===--------------------------------------------------------------------===//
947  // Legality Querying
948  //===--------------------------------------------------------------------===//
949 
950  /// Get the legality action for the given operation.
951  std::optional<LegalizationAction> getOpAction(OperationName op) const;
952 
953  /// If the given operation instance is legal on this target, a structure
954  /// containing legality information is returned. If the operation is not
955  /// legal, std::nullopt is returned. Also returns std::nullopt if operation
956  /// legality wasn't registered by user or dynamic legality callbacks returned
957  /// std::nullopt.
958  ///
959  /// Note: Legality is actually a 4-state: Legal(recursive=true),
960  /// Legal(recursive=false), Illegal or Unknown, where Unknown is treated
961  /// either as Legal or Illegal depending on context.
962  std::optional<LegalOpDetails> isLegal(Operation *op) const;
963 
964  /// Returns true if the operation instance is illegal on this target. Returns
965  /// false if the operation is legal, its legality wasn't registered by the
966  /// user, or the dynamic legality callbacks returned std::nullopt.
967  bool isIllegal(Operation *op) const;
968 
969 private:
970  /// Set the dynamic legality callback for the given operation.
971  void setLegalityCallback(OperationName name,
972  const DynamicLegalityCallbackFn &callback);
973 
974  /// Set the dynamic legality callback for the given dialects.
975  void setLegalityCallback(ArrayRef<StringRef> dialects,
976  const DynamicLegalityCallbackFn &callback);
977 
978  /// Set the dynamic legality callback for the unknown ops.
979  void setLegalityCallback(const DynamicLegalityCallbackFn &callback);
980 
981  /// The set of information that configures the legalization of an operation.
982  struct LegalizationInfo {
983  /// The legality action this operation was given.
985 
986  /// If some legal instances of this operation may also be recursively legal.
987  bool isRecursivelyLegal = false;
988 
989  /// The legality callback if this operation is dynamically legal.
990  DynamicLegalityCallbackFn legalityFn;
991  };
992 
993  /// Get the legalization information for the given operation.
994  std::optional<LegalizationInfo> getOpInfo(OperationName op) const;
995 
996  /// A deterministic mapping of operation name and its respective legality
997  /// information.
998  llvm::MapVector<OperationName, LegalizationInfo> legalOperations;
999 
1000  /// A set of legality callbacks for given operation names that are used to
1001  /// check if an operation instance is recursively legal.
1002  DenseMap<OperationName, DynamicLegalityCallbackFn> opRecursiveLegalityFns;
1003 
1004  /// A deterministic mapping of dialect name to the specific legality action to
1005  /// take.
1006  llvm::StringMap<LegalizationAction> legalDialects;
1007 
1008  /// A set of dynamic legality callbacks for given dialect names.
1009  llvm::StringMap<DynamicLegalityCallbackFn> dialectLegalityFns;
1010 
1011  /// An optional legality callback for unknown operations.
1012  DynamicLegalityCallbackFn unknownLegalityFn;
1013 
1014  /// The current context this target applies to.
1015  MLIRContext &ctx;
1016 };
1017 
1018 //===----------------------------------------------------------------------===//
1019 // PDL Configuration
1020 //===----------------------------------------------------------------------===//
1021 
1022 /// A PDL configuration that is used to support dialect conversion
1023 /// functionality.
1025  : public PDLPatternConfigBase<PDLConversionConfig> {
1026 public:
1027  PDLConversionConfig(const TypeConverter *converter) : converter(converter) {}
1028  ~PDLConversionConfig() final = default;
1029 
1030  /// Return the type converter used by this configuration, which may be nullptr
1031  /// if no type conversions are expected.
1032  const TypeConverter *getTypeConverter() const { return converter; }
1033 
1034  /// Hooks that are invoked at the beginning and end of a rewrite of a matched
1035  /// pattern.
1036  void notifyRewriteBegin(PatternRewriter &rewriter) final;
1037  void notifyRewriteEnd(PatternRewriter &rewriter) final;
1038 
1039 private:
1040  /// An optional type converter to use for the pattern.
1041  const TypeConverter *converter;
1042 };
1043 
1044 /// Register the dialect conversion PDL functions with the given pattern set.
1045 void registerConversionPDLFunctions(RewritePatternSet &patterns);
1046 
1047 //===----------------------------------------------------------------------===//
1048 // Op Conversion Entry Points
1049 //===----------------------------------------------------------------------===//
1050 
1051 /// Below we define several entry points for operation conversion. It is
1052 /// important to note that the patterns provided to the conversion framework may
1053 /// have additional constraints. See the `PatternRewriter Hooks` section of the
1054 /// ConversionPatternRewriter, to see what additional constraints are imposed on
1055 /// the use of the PatternRewriter.
1056 
1057 /// Apply a partial conversion on the given operations and all nested
1058 /// operations. This method converts as many operations to the target as
1059 /// possible, ignoring operations that failed to legalize. This method only
1060 /// returns failure if there are ops explicitly marked as illegal. If an
1061 /// `unconvertedOps` set is provided, all operations that are found not to be
1062 /// legalizable to the given `target` are placed within that set. (Note that if
1063 /// there is an op explicitly marked as illegal, the conversion terminates and
1064 /// the `unconvertedOps` set will not necessarily be complete.)
1065 LogicalResult
1066 applyPartialConversion(ArrayRef<Operation *> ops, const ConversionTarget &target,
1067  const FrozenRewritePatternSet &patterns,
1068  DenseSet<Operation *> *unconvertedOps = nullptr);
1069 LogicalResult
1070 applyPartialConversion(Operation *op, const ConversionTarget &target,
1071  const FrozenRewritePatternSet &patterns,
1072  DenseSet<Operation *> *unconvertedOps = nullptr);
1073 
1074 /// Apply a complete conversion on the given operations, and all nested
1075 /// operations. This method returns failure if the conversion of any operation
1076 /// fails, or if there are unreachable blocks in any of the regions nested
1077 /// within 'ops'.
1078 LogicalResult applyFullConversion(ArrayRef<Operation *> ops,
1079  const ConversionTarget &target,
1080  const FrozenRewritePatternSet &patterns);
1081 LogicalResult applyFullConversion(Operation *op, const ConversionTarget &target,
1082  const FrozenRewritePatternSet &patterns);
1083 
1084 /// Apply an analysis conversion on the given operations, and all nested
1085 /// operations. This method analyzes which operations would be successfully
1086 /// converted to the target if a conversion was applied. All operations that
1087 /// were found to be legalizable to the given 'target' are placed within the
1088 /// provided 'convertedOps' set; note that no actual rewrites are applied to the
1089 /// operations on success and only pre-existing operations are added to the set.
1090 /// This method only returns failure if there are unreachable blocks in any of
1091 /// the regions nested within 'ops'. There's an additional argument
1092 /// `notifyCallback` which is used for collecting match failure diagnostics
1093 /// generated during the conversion. Diagnostics reported to this callback
1094 /// may only be available in debug mode.
1095 LogicalResult applyAnalysisConversion(
1096  ArrayRef<Operation *> ops, ConversionTarget &target,
1097  const FrozenRewritePatternSet &patterns,
1098  DenseSet<Operation *> &convertedOps,
1099  function_ref<void(Diagnostic &)> notifyCallback = nullptr);
1100 LogicalResult applyAnalysisConversion(
1101  Operation *op, ConversionTarget &target,
1102  const FrozenRewritePatternSet &patterns,
1103  DenseSet<Operation *> &convertedOps,
1104  function_ref<void(Diagnostic &)> notifyCallback = nullptr);
1105 } // namespace mlir
1106 
1107 #endif // MLIR_TRANSFORMS_DIALECTCONVERSION_H_
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:310
Block represents an ordered list of Operations.
Definition: Block.h:30
OpListType::iterator iterator
Definition: Block.h:133
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
LogicalResult getRemappedValues(ValueRange keys, SmallVectorImpl< Value > &results)
Return the converted values that replace 'keys' with types defined by the type converter of the curre...
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Convert the types of block arguments within the given region.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt) override
PatternRewriter hook for inlining the ops of a block into another block.
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before) override
PatternRewriter hook for moving blocks out of a region.
void finalizeRootUpdate(Operation *op) override
PatternRewriter hook for updating the root operation in-place.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void replaceOpWithIf(Operation *op, ValueRange newValues, bool *allUsesReplaced, llvm::unique_function< bool(OpOperand &) const > functor) override
PatternRewriter hook for replacing an operation when the given functor returns "true".
void notifyBlockCreated(Block *block) override
PatternRewriter hook creating a new block.
Block * splitBlock(Block *block, Block::iterator before) override
PatternRewriter hook for splitting a block into two parts.
detail::ConversionPatternRewriterImpl & getImpl()
Return a reference to the internal implementation.
void notifyOperationInserted(Operation *op) override
PatternRewriter hook for inserting a new operation.
void cancelRootUpdate(Operation *op) override
PatternRewriter hook for updating the root operation in-place.
void eraseBlock(Block *block) override
PatternRewriter hook for erasing all operations in a block.
bool canRecoverFromRewriteFailure() const override
Indicate that the conversion rewriter can recover from rewrite failure.
Value getRemappedValue(Value key)
Return the converted value of 'key' with a type defined by the type converter of the currently execut...
Block * applySignatureConversion(Region *region, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter=nullptr)
Apply a signature conversion to the entry block of the given region.
void replaceUsesOfBlockArgument(BlockArgument from, Value to)
Replace all the uses of the block argument from with value to.
void startRootUpdate(Operation *op) override
PatternRewriter hook for updating the root operation in-place.
LogicalResult convertNonEntryRegionTypes(Region *region, const TypeConverter &converter, ArrayRef< TypeConverter::SignatureConversion > blockConversions)
Convert the types of block arguments within the given region except for the entry block.
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping) override
PatternRewriter hook for cloning blocks of one region into another.
Base class for the conversion patterns.
const TypeConverter * typeConverter
An optional type converter for use by this pattern.
ConversionPattern(const TypeConverter &typeConverter, Args &&...args)
Construct a conversion pattern with the given converter, and forward the remaining arguments to Rewri...
virtual void rewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement rewriting.
const TypeConverter * getTypeConverter() const
Return the type converter held by this pattern, or nullptr if the pattern does not require type conve...
std::enable_if_t< std::is_base_of< TypeConverter, ConverterTy >::value, const ConverterTy * > getTypeConverter() const
virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement combined matching and rewriting.
This class describes a specific conversion target.
void setDialectAction(ArrayRef< StringRef > dialectNames, LegalizationAction action)
Register a legality action for the given dialects.
void setOpAction(LegalizationAction action)
void addLegalOp(OperationName op)
Register the given operations as legal.
void addDynamicallyLegalDialect(DynamicLegalityCallbackFn callback)
void setOpAction(OperationName op, LegalizationAction action)
Register a legality action for the given operation.
void addDynamicallyLegalDialect(const DynamicLegalityCallbackFn &callback, StringRef name, Names... names)
Register the operations of the given dialects as dynamically legal, i.e.
void addDynamicallyLegalOp(const DynamicLegalityCallbackFn &callback)
void addLegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as legal.
std::optional< LegalOpDetails > isLegal(Operation *op) const
If the given operation instance is legal on this target, a structure containing legality information ...
std::enable_if_t<!std::is_invocable_v< Callable, Operation * > > markOpRecursivelyLegal(Callable &&callback)
void markUnknownOpDynamicallyLegal(const DynamicLegalityCallbackFn &fn)
Register unknown operations as dynamically legal.
std::optional< LegalizationAction > getOpAction(OperationName op) const
Get the legality action for the given operation.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the on...
void addIllegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as illegal, i.e.
std::enable_if_t<!std::is_invocable_v< Callable, Operation * > > addDynamicallyLegalOp(Callable &&callback)
void addDynamicallyLegalOp(const DynamicLegalityCallbackFn &callback)
void markOpRecursivelyLegal(const DynamicLegalityCallbackFn &callback={})
void addIllegalOp(OperationName op)
Register the given operation as illegal, i.e.
LegalizationAction
This enumeration corresponds to the specific action to take when considering an operation legal for t...
@ Illegal
The target explicitly does not support this operation.
@ Dynamic
This operation has dynamic legalization constraints that must be checked by the target.
@ Legal
The target supports this operation.
ConversionTarget(MLIRContext &ctx)
void markOpRecursivelyLegal(OperationName name, const DynamicLegalityCallbackFn &callback)
Mark an operation, that must have either been set as Legal or Dynamic, as being recursively ...
void markOpRecursivelyLegal(const DynamicLegalityCallbackFn &callback={})
std::function< std::optional< bool >(Operation *)> DynamicLegalityCallbackFn
The signature of the callback used to determine if an operation is dynamically legal on the target.
bool isIllegal(Operation *op) const
Returns true if the operation instance is illegal on this target.
virtual ~ConversionTarget()=default
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:156
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:206
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
Definition: Builders.h:301
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:305
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
void rewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Hook for derived classes to implement rewriting.
LogicalResult match(Operation *op) const final
Wrappers around the ConversionPattern methods that pass the derived op type.
typename SourceOp::Adaptor OpAdaptor
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
virtual LogicalResult match(SourceOp op) const
Rewrite and Match methods that operate on the SourceOp type.
OpConversionPattern(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit=1)
virtual LogicalResult matchAndRewrite(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const
virtual void rewrite(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const
LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Hook for derived classes to implement combined matching and rewriting.
OpInterfaceConversionPattern is a wrapper around ConversionPattern that allows for matching and rewri...
void rewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Wrappers around the ConversionPattern methods that pass the derived op type.
LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Hook for derived classes to implement combined matching and rewriting.
OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
OpInterfaceConversionPattern(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit=1)
virtual void rewrite(SourceOp op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Rewrite and Match methods that operate on the SourceOp type.
virtual LogicalResult matchAndRewrite(SourceOp op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
This class represents an operand of an operation.
Definition: Value.h:261
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
A PDL configuration that is used to support dialect conversion functionality.
void notifyRewriteEnd(PatternRewriter &rewriter) final
void notifyRewriteBegin(PatternRewriter &rewriter) final
Hooks that are invoked at the beginning and end of a rewrite of a matched pattern.
const TypeConverter * getTypeConverter() const
Return the type converter used by this configuration, which may be nullptr if no type conversions are...
PDLConversionConfig(const TypeConverter *converter)
~PDLConversionConfig() final=default
This class provides a base class for users implementing a type of pattern configuration.
Definition: PatternMatch.h:962
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:33
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:727
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:72
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
BlockListType::iterator iterator
Definition: Region.h:52
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:245
virtual LogicalResult match(Operation *op) const
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
virtual void rewrite(Operation *op, PatternRewriter &rewriter) const
Rewrite the IR rooted at the specified operation with the result of this pattern, generating any new ...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:660
virtual void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
virtual void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
The general result of a type attribute conversion callback, allowing for early termination.
static AttributeConversionResult result(Attribute attr)
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
std::optional< InputMapping > getInputMapping(unsigned input) const
Get the input mapping for the given argument.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
void remapInput(unsigned origInputNo, Value replacement)
Remap an input of the original signature to another replacement value.
Type conversion class.
std::optional< Attribute > convertTypeAttribute(Type type, Attribute attr) const
Convert an attribute present attr from within the type type using the registered conversion functions...
void addConversion(FnT &&callback)
Register a conversion function.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
void addArgumentMaterialization(FnT &&callback)
Register a materialization function, which must be convertible to the following form: std::optional<V...
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
Value materializeArgumentConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
Materialize a conversion from a set of types into one result type by generating a cast sequence of so...
LogicalResult convertSignatureArgs(TypeRange types, SignatureConversion &result, unsigned origInputOffset=0) const
TypeConverter()=default
std::enable_if_t<!std::is_convertible< RangeT, Type >::value &&!std::is_convertible< RangeT, Operation * >::value, bool > isLegal(RangeT &&range) const
Return true if all of the given types are legal for this type converter.
TargetType convertType(Type t) const
Attempts a 1-1 type conversion, expecting the result type to be TargetType.
LogicalResult convertSignatureArg(unsigned inputNo, Type type, SignatureConversion &result) const
This method allows for converting a specific argument of a signature.
TypeConverter(const TypeConverter &other)
TypeConverter & operator=(const TypeConverter &other)
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
std::optional< SignatureConversion > convertBlockSignature(Block *block) const
This function converts the type signature of the given block, by invoking 'convertSignatureArg' for e...
void addSourceMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a legal type to an illega...
virtual ~TypeConverter()=default
void addTypeAttributeConversion(FnT &&callback)
Register a conversion function for attributes within types.
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting type from an illegal,...
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
bool isSignatureLegal(FunctionType ty) const
Return true if the inputs and outputs of the given function type are legal.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:372
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:93
@ Type
An inlay hint for a type annotation.
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
llvm::function_ref< Fn > function_ref
Definition: LLVM.h:147
void populateFunctionOpInterfaceTypeConversionPattern(StringRef functionLikeOpName, RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to convert the signature of a FunctionOpInterface op with the...
void populateAnyFunctionOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation * > *unconvertedOps=nullptr)
Apply a partial conversion on the given operations and all nested operations.
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns)
Apply a complete conversion on the given operations, and all nested operations.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
void registerConversionPDLFunctions(RewritePatternSet &patterns)
Register the dialect conversion PDL functions with the given pattern set.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
LogicalResult applyAnalysisConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation * > &convertedOps, function_ref< void(Diagnostic &)> notifyCallback=nullptr)
Apply an analysis conversion on the given operations, and all nested operations.
A structure containing additional information describing a specific legal operation instance.
bool isRecursivelyLegal
A flag that indicates if this operation is 'recursively' legal.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
This class acts as a special tag that makes the desire to match any operation that implements a given...
Definition: PatternMatch.h:163
This struct represents a range of new types or a single value that remaps an existing signature input...