MLIR  17.0.0git
DialectConversion.h
Go to the documentation of this file.
1 //===- DialectConversion.h - MLIR dialect conversion pass -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file declares a generic pass for converting between MLIR dialects.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_TRANSFORMS_DIALECTCONVERSION_H_
14 #define MLIR_TRANSFORMS_DIALECTCONVERSION_H_
15 
17 #include "llvm/ADT/MapVector.h"
18 #include "llvm/ADT/StringMap.h"
19 #include <type_traits>
20 
21 namespace mlir {
22 
23 // Forward declarations.
24 class Attribute;
25 class Block;
26 class ConversionPatternRewriter;
27 class MLIRContext;
28 class Operation;
29 class Type;
30 class Value;
31 
32 //===----------------------------------------------------------------------===//
33 // Type Conversion
34 //===----------------------------------------------------------------------===//
35 
36 /// Type conversion class. Specific conversions and materializations can be
37 /// registered using addConversion and addMaterialization, respectively.
39 public:
40  /// This class provides all of the information necessary to convert a type
41  /// signature.
43  public:
44  SignatureConversion(unsigned numOrigInputs)
45  : remappedInputs(numOrigInputs) {}
46 
47  /// This struct represents a range of new types or a single value that
48  /// remaps an existing signature input.
49  struct InputMapping {
50  size_t inputNo, size;
52  };
53 
54  /// Return the argument types for the new signature.
55  ArrayRef<Type> getConvertedTypes() const { return argTypes; }
56 
57  /// Get the input mapping for the given argument.
58  std::optional<InputMapping> getInputMapping(unsigned input) const {
59  return remappedInputs[input];
60  }
61 
62  //===------------------------------------------------------------------===//
63  // Conversion Hooks
64  //===------------------------------------------------------------------===//
65 
66  /// Remap an input of the original signature with a new set of types. The
67  /// new types are appended to the new signature conversion.
68  void addInputs(unsigned origInputNo, ArrayRef<Type> types);
69 
70  /// Append new input types to the signature conversion, this should only be
71  /// used if the new types are not intended to remap an existing input.
72  void addInputs(ArrayRef<Type> types);
73 
74  /// Remap an input of the original signature to another `replacement`
75  /// value. This drops the original argument.
76  void remapInput(unsigned origInputNo, Value replacement);
77 
78  private:
79  /// Remap an input of the original signature with a range of types in the
80  /// new signature.
81  void remapInput(unsigned origInputNo, unsigned newInputNo,
82  unsigned newInputCount = 1);
83 
84  /// The remapping information for each of the original arguments.
85  SmallVector<std::optional<InputMapping>, 4> remappedInputs;
86 
87  /// The set of new argument types.
88  SmallVector<Type, 4> argTypes;
89  };
90 
91  /// The general result of a type attribute conversion callback, allowing
92  /// for early termination. The default constructor creates the na case.
94  public:
95  constexpr AttributeConversionResult() : impl() {}
96  AttributeConversionResult(Attribute attr) : impl(attr, resultTag) {}
97 
101 
102  bool hasResult() const;
103  bool isNa() const;
104  bool isAbort() const;
105 
106  Attribute getResult() const;
107 
108  private:
109  AttributeConversionResult(Attribute attr, unsigned tag) : impl(attr, tag) {}
110 
111  llvm::PointerIntPair<Attribute, 2> impl;
112  // Note that na is 0 so that we can use PointerIntPair's default
113  // constructor.
114  static constexpr unsigned naTag = 0;
115  static constexpr unsigned resultTag = 1;
116  static constexpr unsigned abortTag = 2;
117  };
118 
119  /// Register a conversion function. A conversion function must be convertible
120  /// to any of the following forms(where `T` is a class derived from `Type`:
121  /// * std::optional<Type>(T)
122  /// - This form represents a 1-1 type conversion. It should return nullptr
123  /// or `std::nullopt` to signify failure. If `std::nullopt` is returned,
124  /// the converter is allowed to try another conversion function to
125  /// perform the conversion.
126  /// * std::optional<LogicalResult>(T, SmallVectorImpl<Type> &)
127  /// - This form represents a 1-N type conversion. It should return
128  /// `failure` or `std::nullopt` to signify a failed conversion. If the
129  /// new set of types is empty, the type is removed and any usages of the
130  /// existing value are expected to be removed during conversion. If
131  /// `std::nullopt` is returned, the converter is allowed to try another
132  /// conversion function to perform the conversion.
133  /// * std::optional<LogicalResult>(T, SmallVectorImpl<Type> &,
134  /// ArrayRef<Type>)
135  /// - This form represents a 1-N type conversion supporting recursive
136  /// types. The first two arguments and the return value are the same as
137  /// for the regular 1-N form. The third argument is contains is the
138  /// "call stack" of the recursive conversion: it contains the list of
139  /// types currently being converted, with the current type being the
140  /// last one. If it is present more than once in the list, the
141  /// conversion concerns a recursive type.
142  /// Note: When attempting to convert a type, e.g. via 'convertType', the
143  /// mostly recently added conversions will be invoked first.
144  template <typename FnT, typename T = typename llvm::function_traits<
145  std::decay_t<FnT>>::template arg_t<0>>
146  void addConversion(FnT &&callback) {
147  registerConversion(wrapCallback<T>(std::forward<FnT>(callback)));
148  }
149 
150  /// Register a materialization function, which must be convertible to the
151  /// following form:
152  /// `std::optional<Value>(OpBuilder &, T, ValueRange, Location)`,
153  /// where `T` is any subclass of `Type`. This function is responsible for
154  /// creating an operation, using the OpBuilder and Location provided, that
155  /// "casts" a range of values into a single value of the given type `T`. It
156  /// must return a Value of the converted type on success, an `std::nullopt` if
157  /// it failed but other materialization can be attempted, and `nullptr` on
158  /// unrecoverable failure. It will only be called for (sub)types of `T`.
159  /// Materialization functions must be provided when a type conversion may
160  /// persist after the conversion has finished.
161  ///
162  /// This method registers a materialization that will be called when
163  /// converting an illegal block argument type, to a legal type.
164  template <typename FnT, typename T = typename llvm::function_traits<
165  std::decay_t<FnT>>::template arg_t<1>>
166  void addArgumentMaterialization(FnT &&callback) {
167  argumentMaterializations.emplace_back(
168  wrapMaterialization<T>(std::forward<FnT>(callback)));
169  }
170  /// This method registers a materialization that will be called when
171  /// converting a legal type to an illegal source type. This is used when
172  /// conversions to an illegal type must persist beyond the main conversion.
173  template <typename FnT, typename T = typename llvm::function_traits<
174  std::decay_t<FnT>>::template arg_t<1>>
175  void addSourceMaterialization(FnT &&callback) {
176  sourceMaterializations.emplace_back(
177  wrapMaterialization<T>(std::forward<FnT>(callback)));
178  }
179  /// This method registers a materialization that will be called when
180  /// converting type from an illegal, or source, type to a legal type.
181  template <typename FnT, typename T = typename llvm::function_traits<
182  std::decay_t<FnT>>::template arg_t<1>>
183  void addTargetMaterialization(FnT &&callback) {
184  targetMaterializations.emplace_back(
185  wrapMaterialization<T>(std::forward<FnT>(callback)));
186  }
187 
188  /// Register a conversion function for attributes within types. Type
189  /// converters may call this function in order to allow hoking into the
190  /// translation of attributes that exist within types. For example, a type
191  /// converter for the `memref` type could use these conversions to convert
192  /// memory spaces or layouts in an extensible way.
193  ///
194  /// The conversion functions take a non-null Type or subclass of Type and a
195  /// non-null Attribute (or subclass of Attribute), and returns a
196  /// `AttributeConversionResult`. This result can either contan an `Attribute`,
197  /// which may be `nullptr`, representing the conversion's success,
198  /// `AttributeConversionResult::na()` (the default empty value), indicating
199  /// that the conversion function did not apply and that further conversion
200  /// functions should be checked, or `AttributeConversionResult::abort()`
201  /// indicating that the conversion process should be aborted.
202  ///
203  /// Registered conversion functions are callled in the reverse of the order in
204  /// which they were registered.
205  template <
206  typename FnT,
207  typename T =
208  typename llvm::function_traits<std::decay_t<FnT>>::template arg_t<0>,
209  typename A =
210  typename llvm::function_traits<std::decay_t<FnT>>::template arg_t<1>>
211  void addTypeAttributeConversion(FnT &&callback) {
212  registerTypeAttributeConversion(
213  wrapTypeAttributeConversion<T, A>(std::forward<FnT>(callback)));
214  }
215 
216  /// Convert the given type. This function should return failure if no valid
217  /// conversion exists, success otherwise. If the new set of types is empty,
218  /// the type is removed and any usages of the existing value are expected to
219  /// be removed during conversion.
221 
222  /// This hook simplifies defining 1-1 type conversions. This function returns
223  /// the type to convert to on success, and a null type on failure.
224  Type convertType(Type t);
225 
226  /// Convert the given set of types, filling 'results' as necessary. This
227  /// returns failure if the conversion of any of the types fails, success
228  /// otherwise.
230 
231  /// Return true if the given type is legal for this type converter, i.e. the
232  /// type converts to itself.
233  bool isLegal(Type type);
234  /// Return true if all of the given types are legal for this type converter.
235  template <typename RangeT>
236  std::enable_if_t<!std::is_convertible<RangeT, Type>::value &&
237  !std::is_convertible<RangeT, Operation *>::value,
238  bool>
239  isLegal(RangeT &&range) {
240  return llvm::all_of(range, [this](Type type) { return isLegal(type); });
241  }
242  /// Return true if the given operation has legal operand and result types.
243  bool isLegal(Operation *op);
244 
245  /// Return true if the types of block arguments within the region are legal.
246  bool isLegal(Region *region);
247 
248  /// Return true if the inputs and outputs of the given function type are
249  /// legal.
250  bool isSignatureLegal(FunctionType ty);
251 
252  /// This method allows for converting a specific argument of a signature. It
253  /// takes as inputs the original argument's input number and its type.
254  /// On success, it populates 'result' with any new mappings.
255  LogicalResult convertSignatureArg(unsigned inputNo, Type type,
256  SignatureConversion &result);
258  SignatureConversion &result,
259  unsigned origInputOffset = 0);
260 
261  /// This function converts the type signature of the given block, by invoking
262  /// 'convertSignatureArg' for each argument. This function should return a
263  /// valid conversion for the signature on success, std::nullopt otherwise.
264  std::optional<SignatureConversion> convertBlockSignature(Block *block);
265 
266  /// Materialize a conversion from a set of types into one result type by
267  /// generating a cast sequence of some kind. See the respective
268  /// `add*Materialization` for more information on the context for these
269  /// methods.
271  Type resultType, ValueRange inputs) {
272  return materializeConversion(argumentMaterializations, builder, loc,
273  resultType, inputs);
274  }
276  Type resultType, ValueRange inputs) {
277  return materializeConversion(sourceMaterializations, builder, loc,
278  resultType, inputs);
279  }
281  Type resultType, ValueRange inputs) {
282  return materializeConversion(targetMaterializations, builder, loc,
283  resultType, inputs);
284  }
285 
286  /// Convert the attribute `attr`, which appears within the type `type`, using
287  /// the registered conversion functions. If no applicable conversion has been
288  /// registered, return std::nullopt. Note that the empty attribute/`nullptr`
289  /// is a valid return value for this function.
290  std::optional<Attribute> convertTypeAttribute(Type type, Attribute attr);
291 
292 private:
293  /// The signature of the callback used to convert a type. If the new set of
294  /// types is empty, the type is removed and any usages of the existing value
295  /// are expected to be removed during conversion.
296  using ConversionCallbackFn = std::function<std::optional<LogicalResult>(
298 
299  /// The signature of the callback used to materialize a conversion.
300  using MaterializationCallbackFn = std::function<std::optional<Value>(
302 
303  /// The signature of the callback used to convert a type attribute.
304  using TypeAttributeConversionCallbackFn =
305  std::function<AttributeConversionResult(Type, Attribute)>;
306 
307  /// Attempt to materialize a conversion using one of the provided
308  /// materialization functions.
309  Value materializeConversion(
311  OpBuilder &builder, Location loc, Type resultType, ValueRange inputs);
312 
313  /// Generate a wrapper for the given callback. This allows for accepting
314  /// different callback forms, that all compose into a single version.
315  /// With callback of form: `std::optional<Type>(T)`
316  template <typename T, typename FnT>
317  std::enable_if_t<std::is_invocable_v<FnT, T>, ConversionCallbackFn>
318  wrapCallback(FnT &&callback) {
319  return wrapCallback<T>(
320  [callback = std::forward<FnT>(callback)](
321  T type, SmallVectorImpl<Type> &results, ArrayRef<Type>) {
322  if (std::optional<Type> resultOpt = callback(type)) {
323  bool wasSuccess = static_cast<bool>(*resultOpt);
324  if (wasSuccess)
325  results.push_back(*resultOpt);
326  return std::optional<LogicalResult>(success(wasSuccess));
327  }
328  return std::optional<LogicalResult>();
329  });
330  }
331  /// With callback of form: `std::optional<LogicalResult>(
332  /// T, SmallVectorImpl<Type> &)`.
333  template <typename T, typename FnT>
334  std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &>,
335  ConversionCallbackFn>
336  wrapCallback(FnT &&callback) {
337  return wrapCallback<T>(
338  [callback = std::forward<FnT>(callback)](
339  T type, SmallVectorImpl<Type> &results, ArrayRef<Type>) {
340  return callback(type, results);
341  });
342  }
343  /// With callback of form: `std::optional<LogicalResult>(
344  /// T, SmallVectorImpl<Type> &, ArrayRef<Type>)`.
345  template <typename T, typename FnT>
346  std::enable_if_t<
347  std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &, ArrayRef<Type>>,
348  ConversionCallbackFn>
349  wrapCallback(FnT &&callback) {
350  return [callback = std::forward<FnT>(callback)](
351  Type type, SmallVectorImpl<Type> &results,
352  ArrayRef<Type> callStack) -> std::optional<LogicalResult> {
353  T derivedType = type.dyn_cast<T>();
354  if (!derivedType)
355  return std::nullopt;
356  return callback(derivedType, results, callStack);
357  };
358  }
359 
360  /// Register a type conversion.
361  void registerConversion(ConversionCallbackFn callback) {
362  conversions.emplace_back(std::move(callback));
363  cachedDirectConversions.clear();
364  cachedMultiConversions.clear();
365  }
366 
367  /// Generate a wrapper for the given materialization callback. The callback
368  /// may take any subclass of `Type` and the wrapper will check for the target
369  /// type to be of the expected class before calling the callback.
370  template <typename T, typename FnT>
371  MaterializationCallbackFn wrapMaterialization(FnT &&callback) {
372  return [callback = std::forward<FnT>(callback)](
373  OpBuilder &builder, Type resultType, ValueRange inputs,
374  Location loc) -> std::optional<Value> {
375  if (T derivedType = resultType.dyn_cast<T>())
376  return callback(builder, derivedType, inputs, loc);
377  return std::nullopt;
378  };
379  }
380 
381  /// Generate a wrapper for the given memory space conversion callback. The
382  /// callback may take any subclass of `Attribute` and the wrapper will check
383  /// for the target attribute to be of the expected class before calling the
384  /// callback.
385  template <typename T, typename A, typename FnT>
386  TypeAttributeConversionCallbackFn
387  wrapTypeAttributeConversion(FnT &&callback) {
388  return [callback = std::forward<FnT>(callback)](
389  Type type, Attribute attr) -> AttributeConversionResult {
390  if (T derivedType = type.dyn_cast<T>()) {
391  if (A derivedAttr = attr.dyn_cast_or_null<A>())
392  return callback(derivedType, derivedAttr);
393  }
395  };
396  }
397 
398  /// Register a memory space conversion, clearing caches.
399  void
400  registerTypeAttributeConversion(TypeAttributeConversionCallbackFn callback) {
401  typeAttributeConversions.emplace_back(std::move(callback));
402  // Clear type conversions in case a memory space is lingering inside.
403  cachedDirectConversions.clear();
404  cachedMultiConversions.clear();
405  }
406 
407  /// The set of registered conversion functions.
408  SmallVector<ConversionCallbackFn, 4> conversions;
409 
410  /// The list of registered materialization functions.
411  SmallVector<MaterializationCallbackFn, 2> argumentMaterializations;
412  SmallVector<MaterializationCallbackFn, 2> sourceMaterializations;
413  SmallVector<MaterializationCallbackFn, 2> targetMaterializations;
414 
415  /// The list of registered type attribute conversion functions.
416  SmallVector<TypeAttributeConversionCallbackFn, 2> typeAttributeConversions;
417 
418  /// A set of cached conversions to avoid recomputing in the common case.
419  /// Direct 1-1 conversions are the most common, so this cache stores the
420  /// successful 1-1 conversions as well as all failed conversions.
421  DenseMap<Type, Type> cachedDirectConversions;
422  /// This cache stores the successful 1->N conversions, where N != 1.
423  DenseMap<Type, SmallVector<Type, 2>> cachedMultiConversions;
424 
425  /// Stores the types that are being converted in the case when convertType
426  /// is being called recursively to convert nested types.
427  SmallVector<Type, 2> conversionCallStack;
428 };
429 
430 //===----------------------------------------------------------------------===//
431 // Conversion Patterns
432 //===----------------------------------------------------------------------===//
433 
434 /// Base class for the conversion patterns. This pattern class enables type
435 /// conversions, and other uses specific to the conversion framework. As such,
436 /// patterns of this type can only be used with the 'apply*' methods below.
438 public:
439  /// Hook for derived classes to implement rewriting. `op` is the (first)
440  /// operation matched by the pattern, `operands` is a list of the rewritten
441  /// operand values that are passed to `op`, `rewriter` can be used to emit the
442  /// new operations. This function should not fail. If some specific cases of
443  /// the operation are not supported, these cases should not be matched.
444  virtual void rewrite(Operation *op, ArrayRef<Value> operands,
445  ConversionPatternRewriter &rewriter) const {
446  llvm_unreachable("unimplemented rewrite");
447  }
448 
449  /// Hook for derived classes to implement combined matching and rewriting.
450  virtual LogicalResult
452  ConversionPatternRewriter &rewriter) const {
453  if (failed(match(op)))
454  return failure();
455  rewrite(op, operands, rewriter);
456  return success();
457  }
458 
459  /// Attempt to match and rewrite the IR root at the specified operation.
461  PatternRewriter &rewriter) const final;
462 
463  /// Return the type converter held by this pattern, or nullptr if the pattern
464  /// does not require type conversion.
466 
467  template <typename ConverterTy>
468  std::enable_if_t<std::is_base_of<TypeConverter, ConverterTy>::value,
469  ConverterTy *>
471  return static_cast<ConverterTy *>(typeConverter);
472  }
473 
474 protected:
475  /// See `RewritePattern::RewritePattern` for information on the other
476  /// available constructors.
477  using RewritePattern::RewritePattern;
478  /// Construct a conversion pattern with the given converter, and forward the
479  /// remaining arguments to RewritePattern.
480  template <typename... Args>
482  : RewritePattern(std::forward<Args>(args)...),
484 
485 protected:
486  /// An optional type converter for use by this pattern.
488 
489 private:
491 };
492 
493 /// OpConversionPattern is a wrapper around ConversionPattern that allows for
494 /// matching and rewriting against an instance of a derived operation class as
495 /// opposed to a raw Operation.
496 template <typename SourceOp>
498 public:
499  using OpAdaptor = typename SourceOp::Adaptor;
500 
502  : ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
504  PatternBenefit benefit = 1)
505  : ConversionPattern(typeConverter, SourceOp::getOperationName(), benefit,
506  context) {}
507 
508  /// Wrappers around the ConversionPattern methods that pass the derived op
509  /// type.
510  LogicalResult match(Operation *op) const final {
511  return match(cast<SourceOp>(op));
512  }
513  void rewrite(Operation *op, ArrayRef<Value> operands,
514  ConversionPatternRewriter &rewriter) const final {
515  rewrite(cast<SourceOp>(op), OpAdaptor(operands, op->getAttrDictionary()),
516  rewriter);
517  }
520  ConversionPatternRewriter &rewriter) const final {
521  return matchAndRewrite(cast<SourceOp>(op),
522  OpAdaptor(operands, op->getAttrDictionary()),
523  rewriter);
524  }
525 
526  /// Rewrite and Match methods that operate on the SourceOp type. These must be
527  /// overridden by the derived pattern class.
528  virtual LogicalResult match(SourceOp op) const {
529  llvm_unreachable("must override match or matchAndRewrite");
530  }
531  virtual void rewrite(SourceOp op, OpAdaptor adaptor,
532  ConversionPatternRewriter &rewriter) const {
533  llvm_unreachable("must override matchAndRewrite or a rewrite method");
534  }
535  virtual LogicalResult
536  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
537  ConversionPatternRewriter &rewriter) const {
538  if (failed(match(op)))
539  return failure();
540  rewrite(op, adaptor, rewriter);
541  return success();
542  }
543 
544 private:
546 };
547 
548 /// OpInterfaceConversionPattern is a wrapper around ConversionPattern that
549 /// allows for matching and rewriting against an instance of an OpInterface
550 /// class as opposed to a raw Operation.
551 template <typename SourceOp>
553 public:
556  SourceOp::getInterfaceID(), benefit, context) {}
558  MLIRContext *context, PatternBenefit benefit = 1)
560  SourceOp::getInterfaceID(), benefit, context) {}
561 
562  /// Wrappers around the ConversionPattern methods that pass the derived op
563  /// type.
564  void rewrite(Operation *op, ArrayRef<Value> operands,
565  ConversionPatternRewriter &rewriter) const final {
566  rewrite(cast<SourceOp>(op), operands, rewriter);
567  }
570  ConversionPatternRewriter &rewriter) const final {
571  return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
572  }
573 
574  /// Rewrite and Match methods that operate on the SourceOp type. These must be
575  /// overridden by the derived pattern class.
576  virtual void rewrite(SourceOp op, ArrayRef<Value> operands,
577  ConversionPatternRewriter &rewriter) const {
578  llvm_unreachable("must override matchAndRewrite or a rewrite method");
579  }
580  virtual LogicalResult
581  matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
582  ConversionPatternRewriter &rewriter) const {
583  if (failed(match(op)))
584  return failure();
585  rewrite(op, operands, rewriter);
586  return success();
587  }
588 
589 private:
591 };
592 
593 /// Add a pattern to the given pattern list to convert the signature of a
594 /// FunctionOpInterface op with the given type converter. This only supports
595 /// ops which use FunctionType to represent their type.
597  StringRef functionLikeOpName, RewritePatternSet &patterns,
598  TypeConverter &converter);
599 
600 template <typename FuncOpT>
602  RewritePatternSet &patterns, TypeConverter &converter) {
603  populateFunctionOpInterfaceTypeConversionPattern(FuncOpT::getOperationName(),
604  patterns, converter);
605 }
606 
608  RewritePatternSet &patterns, TypeConverter &converter);
609 
610 //===----------------------------------------------------------------------===//
611 // Conversion PatternRewriter
612 //===----------------------------------------------------------------------===//
613 
614 namespace detail {
615 struct ConversionPatternRewriterImpl;
616 } // namespace detail
617 
618 /// This class implements a pattern rewriter for use with ConversionPatterns. It
619 /// extends the base PatternRewriter and provides special conversion specific
620 /// hooks.
622  public RewriterBase::Listener {
623 public:
624  explicit ConversionPatternRewriter(MLIRContext *ctx);
626 
627  /// Apply a signature conversion to the entry block of the given region. This
628  /// replaces the entry block with a new block containing the updated
629  /// signature. The new entry block to the region is returned for convenience.
630  ///
631  /// If provided, `converter` will be used for any materializations.
632  Block *
635  TypeConverter *converter = nullptr);
636 
637  /// Convert the types of block arguments within the given region. This
638  /// replaces each block with a new block containing the updated signature. The
639  /// entry block may have a special conversion if `entryConversion` is
640  /// provided. On success, the new entry block to the region is returned for
641  /// convenience. Otherwise, failure is returned.
643  Region *region, TypeConverter &converter,
644  TypeConverter::SignatureConversion *entryConversion = nullptr);
645 
646  /// Convert the types of block arguments within the given region except for
647  /// the entry block. This replaces each non-entry block with a new block
648  /// containing the updated signature.
649  ///
650  /// If special conversion behavior is needed for the non-entry blocks (for
651  /// example, we need to convert only a subset of a block's arguments), such
652  /// behavior can be specified in blockConversions.
654  Region *region, TypeConverter &converter,
656 
657  /// Replace all the uses of the block argument `from` with value `to`.
659 
660  /// Return the converted value of 'key' with a type defined by the type
661  /// converter of the currently executing pattern. Return nullptr in the case
662  /// of failure, the remapped value otherwise.
664 
665  /// Return the converted values that replace 'keys' with types defined by the
666  /// type converter of the currently executing pattern. Returns failure if the
667  /// remap failed, success otherwise.
669  SmallVectorImpl<Value> &results);
670 
671  //===--------------------------------------------------------------------===//
672  // PatternRewriter Hooks
673  //===--------------------------------------------------------------------===//
674 
675  /// Indicate that the conversion rewriter can recover from rewrite failure.
676  /// Recovery is supported via rollback, allowing for continued processing of
677  /// patterns even if a failure is encountered during the rewrite step.
678  bool canRecoverFromRewriteFailure() const override { return true; }
679 
680  /// PatternRewriter hook for replacing the results of an operation when the
681  /// given functor returns true.
682  void replaceOpWithIf(
683  Operation *op, ValueRange newValues, bool *allUsesReplaced,
684  llvm::unique_function<bool(OpOperand &) const> functor) override;
685 
686  /// PatternRewriter hook for replacing the results of an operation.
687  void replaceOp(Operation *op, ValueRange newValues) override;
689 
690  /// PatternRewriter hook for erasing a dead operation. The uses of this
691  /// operation *must* be made dead by the end of the conversion process,
692  /// otherwise an assert will be issued.
693  void eraseOp(Operation *op) override;
694 
695  /// PatternRewriter hook for erasing all operations in a block. This is not yet
696  /// implemented for dialect conversion.
697  void eraseBlock(Block *block) override;
698 
699  /// PatternRewriter hook creating a new block.
700  void notifyBlockCreated(Block *block) override;
701 
702  /// PatternRewriter hook for splitting a block into two parts.
703  Block *splitBlock(Block *block, Block::iterator before) override;
704 
705  /// PatternRewriter hook for inlining the ops of a block into another block.
706  void inlineBlockBefore(Block *source, Block *dest, Block::iterator before,
707  ValueRange argValues = std::nullopt) override;
709 
710  /// PatternRewriter hook for moving blocks out of a region.
711  void inlineRegionBefore(Region &region, Region &parent,
712  Region::iterator before) override;
714 
715  /// PatternRewriter hook for cloning blocks of one region into another. The
716  /// given region to clone *must* not have been modified as part of conversion
717  /// yet, i.e. it must be within an operation that is either in the process of
718  /// conversion, or has not yet been converted.
719  void cloneRegionBefore(Region &region, Region &parent,
720  Region::iterator before, IRMapping &mapping) override;
722 
723  /// PatternRewriter hook for inserting a new operation.
724  void notifyOperationInserted(Operation *op) override;
725 
726  /// PatternRewriter hook for updating the root operation in-place.
727  /// Note: These methods only track updates to the top-level operation itself,
728  /// and not nested regions. Updates to regions will still require notification
729  /// through other more specific hooks above.
730  void startRootUpdate(Operation *op) override;
731 
732  /// PatternRewriter hook for updating the root operation in-place.
733  void finalizeRootUpdate(Operation *op) override;
734 
735  /// PatternRewriter hook for updating the root operation in-place.
736  void cancelRootUpdate(Operation *op) override;
737 
738  /// PatternRewriter hook for notifying match failure reasons.
741  function_ref<void(Diagnostic &)> reasonCallback) override;
743 
744  /// Return a reference to the internal implementation.
746 
747 private:
750 
751  std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
752 };
753 
754 //===----------------------------------------------------------------------===//
755 // ConversionTarget
756 //===----------------------------------------------------------------------===//
757 
758 /// This class describes a specific conversion target.
760 public:
/// The action to take when deciding whether an operation is legal for this
/// conversion target.
enum class LegalizationAction {
  /// The operation is supported by the target.
  Legal,
  /// Legality is decided by dynamic constraints that the target checks.
  Dynamic,
  /// The operation is explicitly not supported by the target.
  Illegal,
};
774 
/// Extra information describing one specific legal operation instance.
struct LegalOpDetails {
  /// When set, the operation is 'recursively' legal. If the operation is
  /// legal (statically or dynamically), every operation nested inside it is
  /// also considered legal.
  bool isRecursivelyLegal{false};
};
783 
784  /// The signature of the callback used to determine if an operation is
785  /// dynamically legal on the target.
787  std::function<std::optional<bool>(Operation *)>;
788 
789  ConversionTarget(MLIRContext &ctx) : ctx(ctx) {}
790  virtual ~ConversionTarget() = default;
791 
792  //===--------------------------------------------------------------------===//
793  // Legality Registration
794  //===--------------------------------------------------------------------===//
795 
796  /// Register a legality action for the given operation.
798  template <typename OpT>
800  setOpAction(OperationName(OpT::getOperationName(), &ctx), action);
801  }
802 
803  /// Register the given operations as legal.
806  }
807  template <typename OpT>
808  void addLegalOp() {
809  addLegalOp(OperationName(OpT::getOperationName(), &ctx));
810  }
811  template <typename OpT, typename OpT2, typename... OpTs>
812  void addLegalOp() {
813  addLegalOp<OpT>();
814  addLegalOp<OpT2, OpTs...>();
815  }
816 
817  /// Register the given operation as dynamically legal and set the dynamic
818  /// legalization callback to the one provided.
820  const DynamicLegalityCallbackFn &callback) {
822  setLegalityCallback(op, callback);
823  }
824  template <typename OpT>
826  addDynamicallyLegalOp(OperationName(OpT::getOperationName(), &ctx),
827  callback);
828  }
829  template <typename OpT, typename OpT2, typename... OpTs>
831  addDynamicallyLegalOp<OpT>(callback);
832  addDynamicallyLegalOp<OpT2, OpTs...>(callback);
833  }
834  template <typename OpT, class Callable>
835  std::enable_if_t<!std::is_invocable_v<Callable, Operation *>>
836  addDynamicallyLegalOp(Callable &&callback) {
837  addDynamicallyLegalOp<OpT>(
838  [=](Operation *op) { return callback(cast<OpT>(op)); });
839  }
840 
841  /// Register the given operation as illegal, i.e. this operation is known to
842  /// not be supported by this target.
845  }
846  template <typename OpT>
847  void addIllegalOp() {
848  addIllegalOp(OperationName(OpT::getOperationName(), &ctx));
849  }
850  template <typename OpT, typename OpT2, typename... OpTs>
851  void addIllegalOp() {
852  addIllegalOp<OpT>();
853  addIllegalOp<OpT2, OpTs...>();
854  }
855 
856  /// Mark an operation, that *must* have either been set as `Legal` or
857  /// `DynamicallyLegal`, as being recursively legal. This means that in
858  /// addition to the operation itself, all of the operations nested within are
859  /// also considered legal. An optional dynamic legality callback may be
860  /// provided to mark subsets of legal instances as recursively legal.
862  const DynamicLegalityCallbackFn &callback);
863  template <typename OpT>
865  OperationName opName(OpT::getOperationName(), &ctx);
866  markOpRecursivelyLegal(opName, callback);
867  }
868  template <typename OpT, typename OpT2, typename... OpTs>
870  markOpRecursivelyLegal<OpT>(callback);
871  markOpRecursivelyLegal<OpT2, OpTs...>(callback);
872  }
873  template <typename OpT, class Callable>
874  std::enable_if_t<!std::is_invocable_v<Callable, Operation *>>
875  markOpRecursivelyLegal(Callable &&callback) {
876  markOpRecursivelyLegal<OpT>(
877  [=](Operation *op) { return callback(cast<OpT>(op)); });
878  }
879 
880  /// Register a legality action for the given dialects.
881  void setDialectAction(ArrayRef<StringRef> dialectNames,
882  LegalizationAction action);
883 
884  /// Register the operations of the given dialects as legal.
885  template <typename... Names>
886  void addLegalDialect(StringRef name, Names... names) {
887  SmallVector<StringRef, 2> dialectNames({name, names...});
889  }
890  template <typename... Args>
892  SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
894  }
895 
896  /// Register the operations of the given dialects as dynamically legal, i.e.
897  /// requiring custom handling by the callback.
898  template <typename... Names>
900  StringRef name, Names... names) {
901  SmallVector<StringRef, 2> dialectNames({name, names...});
903  setLegalityCallback(dialectNames, callback);
904  }
905  template <typename... Args>
907  addDynamicallyLegalDialect(std::move(callback),
908  Args::getDialectNamespace()...);
909  }
910 
911  /// Register unknown operations as dynamically legal. For operations(and
912  /// dialects) that do not have a set legalization action, treat them as
913  /// dynamically legal and invoke the given callback.
915  setLegalityCallback(fn);
916  }
917 
918  /// Register the operations of the given dialects as illegal, i.e.
919  /// operations of this dialect are not supported by the target.
920  template <typename... Names>
921  void addIllegalDialect(StringRef name, Names... names) {
922  SmallVector<StringRef, 2> dialectNames({name, names...});
924  }
925  template <typename... Args>
927  SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
929  }
930 
931  //===--------------------------------------------------------------------===//
932  // Legality Querying
933  //===--------------------------------------------------------------------===//
934 
935  /// Get the legality action for the given operation.
936  std::optional<LegalizationAction> getOpAction(OperationName op) const;
937 
938  /// If the given operation instance is legal on this target, a structure
939  /// containing legality information is returned. If the operation is not
940  /// legal, std::nullopt is returned. Also returns std::nullopt if the
941  /// operation's legality wasn't registered by the user or its dynamic legality
942  /// callback returned std::nullopt.
943  ///
944  /// Note: Legality is actually a 4-state: Legal(recursive=true),
945  /// Legal(recursive=false), Illegal or Unknown, where Unknown is treated
946  /// either as Legal or Illegal depending on context.
947  std::optional<LegalOpDetails> isLegal(Operation *op) const;
948 
949  /// Returns true if the operation instance is illegal on this target. Returns
950  /// false if the operation is legal, its legality wasn't registered by the user,
951  /// or its dynamic legality callback returned std::nullopt.
952  bool isIllegal(Operation *op) const;
953 
954 private:
955  /// Set the dynamic legality callback for the given operation.
956  void setLegalityCallback(OperationName name,
957  const DynamicLegalityCallbackFn &callback);
958 
959  /// Set the dynamic legality callback for the given dialects.
960  void setLegalityCallback(ArrayRef<StringRef> dialects,
961  const DynamicLegalityCallbackFn &callback);
962 
963  /// Set the dynamic legality callback for the unknown ops.
964  void setLegalityCallback(const DynamicLegalityCallbackFn &callback);
965 
966  /// The set of information that configures the legalization of an operation.
967  struct LegalizationInfo {
968  /// The legality action this operation was given.
970 
971  /// If some legal instances of this operation may also be recursively legal.
972  bool isRecursivelyLegal = false;
973 
974  /// The legality callback if this operation is dynamically legal.
975  DynamicLegalityCallbackFn legalityFn;
976  };
977 
978  /// Get the legalization information for the given operation.
979  std::optional<LegalizationInfo> getOpInfo(OperationName op) const;
980 
981  /// A deterministic mapping of operation name and its respective legality
982  /// information.
983  llvm::MapVector<OperationName, LegalizationInfo> legalOperations;
984 
985  /// A set of legality callbacks for given operation names that are used to
986  /// check if an operation instance is recursively legal.
987  DenseMap<OperationName, DynamicLegalityCallbackFn> opRecursiveLegalityFns;
988 
989  /// A deterministic mapping of dialect name to the specific legality action to
990  /// take.
991  llvm::StringMap<LegalizationAction> legalDialects;
992 
993  /// A set of dynamic legality callbacks for given dialect names.
994  llvm::StringMap<DynamicLegalityCallbackFn> dialectLegalityFns;
995 
996  /// An optional legality callback for unknown operations.
997  DynamicLegalityCallbackFn unknownLegalityFn;
998 
999  /// The current context this target applies to.
1000  MLIRContext &ctx;
1001 };
1002 
1003 //===----------------------------------------------------------------------===//
1004 // PDL Configuration
1005 //===----------------------------------------------------------------------===//
1006 
1007 /// A PDL configuration that is used to support dialect conversion
1008 /// functionality.
1010  : public PDLPatternConfigBase<PDLConversionConfig> {
1011 public:
1012  PDLConversionConfig(TypeConverter *converter) : converter(converter) {}
1013  ~PDLConversionConfig() final = default;
1014 
1015  /// Return the type converter used by this configuration, which may be nullptr
1016  /// if no type conversions are expected.
1017  TypeConverter *getTypeConverter() const { return converter; }
1018 
1019  /// Hooks that are invoked at the beginning and end of a rewrite of a matched
1020  /// pattern.
1021  void notifyRewriteBegin(PatternRewriter &rewriter) final;
1022  void notifyRewriteEnd(PatternRewriter &rewriter) final;
1023 
1024 private:
1025  /// An optional type converter to use for the pattern.
1026  TypeConverter *converter;
1027 };
1028 
1029 /// Register the dialect conversion PDL functions with the given pattern set.
1030 void registerConversionPDLFunctions(RewritePatternSet &patterns);
1031 
1032 //===----------------------------------------------------------------------===//
1033 // Op Conversion Entry Points
1034 //===----------------------------------------------------------------------===//
1035 
1036 /// Below we define several entry points for operation conversion. It is
1037 /// important to note that the patterns provided to the conversion framework may
1038 /// have additional constraints. See the `PatternRewriter Hooks` section of the
1039 /// ConversionPatternRewriter, to see what additional constraints are imposed on
1040 /// the use of the PatternRewriter.
1041 
1042 /// Apply a partial conversion on the given operations and all nested
1043 /// operations. This method converts as many operations to the target as
1044 /// possible, ignoring operations that failed to legalize. This method only
1045 /// returns failure if there are ops explicitly marked as illegal. If an
1046 /// `unconvertedOps` set is provided, all operations that are found not to be
1047 /// legalizable to the given `target` are placed within that set. (Note that if
1048 /// there is an op explicitly marked as illegal, the conversion terminates and
1049 /// the `unconvertedOps` set will not necessarily be complete.)
1050 LogicalResult
1051 applyPartialConversion(ArrayRef<Operation *> ops, ConversionTarget &target,
1052  const FrozenRewritePatternSet &patterns,
1053  DenseSet<Operation *> *unconvertedOps = nullptr);
1054 LogicalResult
1055 applyPartialConversion(Operation *op, ConversionTarget &target,
1056  const FrozenRewritePatternSet &patterns,
1057  DenseSet<Operation *> *unconvertedOps = nullptr);
1058 
1059 /// Apply a complete conversion on the given operations, and all nested
1060 /// operations. This method returns failure if the conversion of any operation
1061 /// fails, or if there are unreachable blocks in any of the regions nested
1062 /// within 'ops'.
1063 LogicalResult applyFullConversion(ArrayRef<Operation *> ops,
1064  ConversionTarget &target,
1065  const FrozenRewritePatternSet &patterns);
1066 LogicalResult applyFullConversion(Operation *op, ConversionTarget &target,
1067  const FrozenRewritePatternSet &patterns);
1068 
1069 /// Apply an analysis conversion on the given operations, and all nested
1070 /// operations. This method analyzes which operations would be successfully
1071 /// converted to the target if a conversion was applied. All operations that
1072 /// were found to be legalizable to the given 'target' are placed within the
1073 /// provided 'convertedOps' set; note that no actual rewrites are applied to the
1074 /// operations on success and only pre-existing operations are added to the set.
1075 /// This method only returns failure if there are unreachable blocks in any of
1076 /// the regions nested within 'ops'. There's an additional argument
1077 /// `notifyCallback` which is used for collecting match failure diagnostics
1078 /// generated during the conversion. Diagnostics reported to this callback
1079 /// may only be available in debug mode.
1080 LogicalResult applyAnalysisConversion(
1081  ArrayRef<Operation *> ops, ConversionTarget &target,
1082  const FrozenRewritePatternSet &patterns,
1083  DenseSet<Operation *> &convertedOps,
1084  function_ref<void(Diagnostic &)> notifyCallback = nullptr);
1085 LogicalResult applyAnalysisConversion(
1086  Operation *op, ConversionTarget &target,
1087  const FrozenRewritePatternSet &patterns,
1088  DenseSet<Operation *> &convertedOps,
1089  function_ref<void(Diagnostic &)> notifyCallback = nullptr);
1090 } // namespace mlir
1091 
1092 #endif // MLIR_TRANSFORMS_DIALECTCONVERSION_H_
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:304
Block represents an ordered list of Operations.
Definition: Block.h:30
OpListType::iterator iterator
Definition: Block.h:129
This class implements a pattern rewriter for use with ConversionPatterns.
Block * applySignatureConversion(Region *region, TypeConverter::SignatureConversion &conversion, TypeConverter *converter=nullptr)
Apply a signature conversion to the entry block of the given region.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing the results of an operation.
LogicalResult convertNonEntryRegionTypes(Region *region, TypeConverter &converter, ArrayRef< TypeConverter::SignatureConversion > blockConversions)
Convert the types of block arguments within the given region except for the entry region.
LogicalResult getRemappedValues(ValueRange keys, SmallVectorImpl< Value > &results)
Return the converted values that replace 'keys' with types defined by the type converter of the curre...
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt) override
PatternRewriter hook for inlining the ops of a block into another block.
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before) override
PatternRewriter hook for moving blocks out of a region.
void finalizeRootUpdate(Operation *op) override
PatternRewriter hook for finalizing an in-place update of the root operation.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void replaceOpWithIf(Operation *op, ValueRange newValues, bool *allUsesReplaced, llvm::unique_function< bool(OpOperand &) const > functor) override
PatternRewriter hook for replacing the results of an operation when the given functor returns true.
void notifyBlockCreated(Block *block) override
PatternRewriter hook creating a new block.
FailureOr< Block * > convertRegionTypes(Region *region, TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Convert the types of block arguments within the given region.
Block * splitBlock(Block *block, Block::iterator before) override
PatternRewriter hook for splitting a block into two parts.
detail::ConversionPatternRewriterImpl & getImpl()
Return a reference to the internal implementation.
void notifyOperationInserted(Operation *op) override
PatternRewriter hook for inserting a new operation.
void cancelRootUpdate(Operation *op) override
PatternRewriter hook for cancelling an in-place update of the root operation.
void eraseBlock(Block *block) override
PatternRewriter hook for erasing all operations in a block.
bool canRecoverFromRewriteFailure() const override
Indicate that the conversion rewriter can recover from rewrite failure.
Value getRemappedValue(Value key)
Return the converted value of 'key' with a type defined by the type converter of the currently execut...
void replaceUsesOfBlockArgument(BlockArgument from, Value to)
Replace all the uses of the block argument from with value to.
void startRootUpdate(Operation *op) override
PatternRewriter hook for updating the root operation in-place.
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping) override
PatternRewriter hook for cloning blocks of one region into another.
Base class for the conversion patterns.
std::enable_if_t< std::is_base_of< TypeConverter, ConverterTy >::value, ConverterTy * > getTypeConverter() const
TypeConverter * typeConverter
An optional type converter for use by this pattern.
TypeConverter * getTypeConverter() const
Return the type converter held by this pattern, or nullptr if the pattern does not require type conve...
ConversionPattern(TypeConverter &typeConverter, Args &&...args)
Construct a conversion pattern with the given converter, and forward the remaining arguments to Rewri...
virtual void rewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement rewriting.
virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement combined matching and rewriting.
This class describes a specific conversion target.
void setDialectAction(ArrayRef< StringRef > dialectNames, LegalizationAction action)
Register a legality action for the given dialects.
void setOpAction(LegalizationAction action)
void addLegalOp(OperationName op)
Register the given operations as legal.
void addDynamicallyLegalDialect(DynamicLegalityCallbackFn callback)
void setOpAction(OperationName op, LegalizationAction action)
Register a legality action for the given operation.
void addDynamicallyLegalDialect(const DynamicLegalityCallbackFn &callback, StringRef name, Names... names)
Register the operations of the given dialects as dynamically legal, i.e.
void addDynamicallyLegalOp(const DynamicLegalityCallbackFn &callback)
void addLegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as legal.
std::optional< LegalOpDetails > isLegal(Operation *op) const
If the given operation instance is legal on this target, a structure containing legality information ...
std::enable_if_t<!std::is_invocable_v< Callable, Operation * > > markOpRecursivelyLegal(Callable &&callback)
void markUnknownOpDynamicallyLegal(const DynamicLegalityCallbackFn &fn)
Register unknown operations as dynamically legal.
std::optional< LegalizationAction > getOpAction(OperationName op) const
Get the legality action for the given operation.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the on...
void addIllegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as illegal, i.e.
std::enable_if_t<!std::is_invocable_v< Callable, Operation * > > addDynamicallyLegalOp(Callable &&callback)
void addDynamicallyLegalOp(const DynamicLegalityCallbackFn &callback)
void markOpRecursivelyLegal(const DynamicLegalityCallbackFn &callback={})
void addIllegalOp(OperationName op)
Register the given operation as illegal, i.e.
LegalizationAction
This enumeration corresponds to the specific action to take when considering an operation legal for t...
@ Illegal
The target explicitly does not support this operation.
@ Dynamic
This operation has dynamic legalization constraints that must be checked by the target.
@ Legal
The target supports this operation.
ConversionTarget(MLIRContext &ctx)
void markOpRecursivelyLegal(OperationName name, const DynamicLegalityCallbackFn &callback)
Mark an operation, that must have either been set as Legal or DynamicallyLegal, as being recursively ...
void markOpRecursivelyLegal(const DynamicLegalityCallbackFn &callback={})
std::function< std::optional< bool >(Operation *)> DynamicLegalityCallbackFn
The signature of the callback used to determine if an operation is dynamically legal on the target.
bool isIllegal(Operation *op) const
Returns true if the operation instance is illegal on this target.
virtual ~ConversionTarget()=default
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:156
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:202
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
Definition: Builders.h:297
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:301
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
void rewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Hook for derived classes to implement rewriting.
LogicalResult match(Operation *op) const final
Wrappers around the ConversionPattern methods that pass the derived op type.
typename SourceOp::Adaptor OpAdaptor
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
virtual LogicalResult match(SourceOp op) const
Rewrite and Match methods that operate on the SourceOp type.
OpConversionPattern(TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit=1)
virtual LogicalResult matchAndRewrite(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const
virtual void rewrite(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const
LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Hook for derived classes to implement combined matching and rewriting.
OpInterfaceConversionPattern is a wrapper around ConversionPattern that allows for matching and rewri...
void rewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Wrappers around the ConversionPattern methods that pass the derived op type.
LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Hook for derived classes to implement combined matching and rewriting.
OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
OpInterfaceConversionPattern(TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit=1)
virtual void rewrite(SourceOp op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Rewrite and Match methods that operate on the SourceOp type.
virtual LogicalResult matchAndRewrite(SourceOp op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
This class represents an operand of an operation.
Definition: Value.h:255
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:75
A PDL configuration that is used to support dialect conversion functionality.
void notifyRewriteEnd(PatternRewriter &rewriter) final
void notifyRewriteBegin(PatternRewriter &rewriter) final
Hooks that are invoked at the beginning and end of a rewrite of a matched pattern.
PDLConversionConfig(TypeConverter *converter)
TypeConverter * getTypeConverter() const
Return the type converter used by this configuration, which may be nullptr if no type conversions are...
~PDLConversionConfig() final=default
This class provides a base class for users implementing a type of pattern configuration.
Definition: PatternMatch.h:903
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:33
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:668
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:72
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
BlockListType::iterator iterator
Definition: Region.h:52
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:245
virtual LogicalResult match(Operation *op) const
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
virtual void rewrite(Operation *op, PatternRewriter &rewriter) const
Rewrite the IR rooted at the specified operation with the result of this pattern, generating any new ...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:597
virtual void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
virtual void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
The general result of a type attribute conversion callback, allowing for early termination.
static AttributeConversionResult result(Attribute attr)
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
std::optional< InputMapping > getInputMapping(unsigned input) const
Get the input mapping for the given argument.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
void remapInput(unsigned origInputNo, Value replacement)
Remap an input of the original signature to another replacement value.
Type conversion class.
std::enable_if_t<!std::is_convertible< RangeT, Type >::value &&!std::is_convertible< RangeT, Operation * >::value, bool > isLegal(RangeT &&range)
Return true if all of the given types are legal for this type converter.
void addConversion(FnT &&callback)
Register a conversion function.
bool isSignatureLegal(FunctionType ty)
Return true if the inputs and outputs of the given function type are legal.
LogicalResult convertSignatureArg(unsigned inputNo, Type type, SignatureConversion &result)
This method allows for converting a specific argument of a signature.
void addArgumentMaterialization(FnT &&callback)
Register a materialization function, which must be convertible to the following form: std::optional<V...
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results)
Convert the given set of types, filling 'results' as necessary.
LogicalResult convertSignatureArgs(TypeRange types, SignatureConversion &result, unsigned origInputOffset=0)
std::optional< Attribute > convertTypeAttribute(Type type, Attribute attr)
Convert an attribute `attr` present within the type `type` using the registered conversion functions...
bool isLegal(Type type)
Return true if the given type is legal for this type converter, i.e.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results)
Convert the given type.
std::optional< SignatureConversion > convertBlockSignature(Block *block)
This function converts the type signature of the given block, by invoking 'convertSignatureArg' for e...
void addSourceMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a legal type to an illega...
Value materializeArgumentConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs)
Materialize a conversion from a set of types into one result type by generating a cast sequence of so...
void addTypeAttributeConversion(FnT &&callback)
Register a conversion function for attributes within types.
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting type from an illegal,...
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs)
Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs)
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:370
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:93
@ Type
An inlay hint for a type annotation.
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
llvm::function_ref< Fn > function_ref
Definition: LLVM.h:150
void populateFunctionOpInterfaceTypeConversionPattern(StringRef functionLikeOpName, RewritePatternSet &patterns, TypeConverter &converter)
Add a pattern to the given pattern list to convert the signature of a FunctionOpInterface op with the...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation * > *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns)
Apply a complete conversion on the given operations, and all nested operations.
void populateAnyFunctionOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, TypeConverter &converter)
void registerConversionPDLFunctions(RewritePatternSet &patterns)
Register the dialect conversion PDL functions with the given pattern set.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
LogicalResult applyAnalysisConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation * > &convertedOps, function_ref< void(Diagnostic &)> notifyCallback=nullptr)
Apply an analysis conversion on the given operations, and all nested operations.
A structure containing additional information describing a specific legal operation instance.
bool isRecursivelyLegal
A flag that indicates if this operation is 'recursively' legal.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
This class acts as a special tag that makes the desire to match any operation that implements a given...
Definition: PatternMatch.h:163
This struct represents a range of new types or a single value that remaps an existing signature input...