MLIR  20.0.0git
OneToNTypeConversion.h
Go to the documentation of this file.
1 //===-- OneToNTypeConversion.h - Utils for 1:N type conversion --*- C++ -*-===//
2 //
3 // Licensed under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file provides utils for implementing (poor man's) dialect conversion
10 // passes with 1:N type conversions.
11 //
12 // The main function, `applyPartialOneToNConversion`, first applies a set of
13 // `RewritePattern`s, which produce unrealized casts to convert the operands and
14 // results from and to the source types, and then replaces all newly added
15 // unrealized casts with user-provided materializations. For this to work, the
16 // main function requires a special `TypeConverter`, a special
17 // `PatternRewriter`, and special `RewritePattern`s, which extend their
18 // respective base classes for 1:N type conversions.
19 //
20 // Note that this is much simpler than the "real" dialect conversion,
21 // which checks for legality before applying patterns and probably does many
22 // other things as well. Ideally, some of the extensions here could be
23 // integrated there.
24 //
25 //===----------------------------------------------------------------------===//
26 
27 #ifndef MLIR_TRANSFORMS_ONETONTYPECONVERSION_H
28 #define MLIR_TRANSFORMS_ONETONTYPECONVERSION_H
29 
30 #include "mlir/IR/PatternMatch.h"
32 #include "llvm/ADT/SmallVector.h"
33 
34 namespace mlir {
35 
36 /// Stores a 1:N mapping of types and provides several useful accessors. This
37 /// class extends `SignatureConversion`, which already supports 1:N type
38 /// mappings but lacks some accessors into the mapping as well as access to the
39 /// original types.
41 public:
43  : TypeConverter::SignatureConversion(originalTypes.size()),
44  originalTypes(originalTypes) {}
45 
47 
48  /// Returns the list of types that corresponds to the original type at the
49  /// given index.
50  TypeRange getConvertedTypes(unsigned originalTypeNo) const;
51 
52  /// Returns the list of original types.
53  TypeRange getOriginalTypes() const { return originalTypes; }
54 
55  /// Returns the slice of converted values that corresponds to the original
56  /// value at the given index.
58  unsigned originalValueNo) const;
59 
60  /// Fills the given result vector with as many copies of the location of the
61  /// original value as the number of values it is converted to.
62  void convertLocation(Value originalValue, unsigned originalValueNo,
63  llvm::SmallVectorImpl<Location> &result) const;
64 
65  /// Fills the given result vector with as many copies of the location of each
66  /// original value as the number of values they are respectively converted to.
67  void convertLocations(ValueRange originalValues,
68  llvm::SmallVectorImpl<Location> &result) const;
69 
70  /// Returns true iff at least one type conversion maps an input type to a type
71  /// that is different from itself.
72  bool hasNonIdentityConversion() const;
73 
74 private:
75  llvm::SmallVector<Type> originalTypes;
76 };
77 
78 /// Extends the basic `RewritePattern` class with a type converter member and
79 /// some accessors to it. This is useful for patterns that are not
80 /// `ConversionPattern`s but still require access to a type converter.
82 public:
83  /// Construct a conversion pattern with the given converter, and forward the
84  /// remaining arguments to RewritePattern.
85  template <typename... Args>
87  Args &&...args)
88  : RewritePattern(std::forward<Args>(args)...),
90 
91  /// Return the type converter held by this pattern, or nullptr if the pattern
92  /// does not require type conversion.
93  const TypeConverter *getTypeConverter() const { return typeConverter; }
94 
95  template <typename ConverterTy>
96  std::enable_if_t<std::is_base_of<TypeConverter, ConverterTy>::value,
97  const ConverterTy *>
98  getTypeConverter() const {
99  return static_cast<const ConverterTy *>(typeConverter);
100  }
101 
102 protected:
103  /// A type converter for use by this pattern.
105 };
106 
107 /// Specialization of `PatternRewriter` that `OneToNConversionPattern`s use. The
108 /// class provides additional rewrite methods that are specific to 1:N type
109 /// conversions.
111 public:
113  OpBuilder::Listener *listener = nullptr)
115 
116  /// Replaces the results of the operation with the specified list of values
117  /// mapped back to the original types as specified in the provided type
118  /// mapping. That type mapping must match the replaced op (i.e., the original
119  /// types must be the same as the result types of the op) and the new values
120  /// (i.e., the converted types must be the same as the types of the new
121  /// values).
122  void replaceOp(Operation *op, ValueRange newValues,
123  const OneToNTypeMapping &resultMapping);
125 
126  /// Applies the given argument conversion to the given block. This consists of
127  /// replacing each original argument with N arguments as specified in the
128  /// argument conversion and inserting unrealized casts from the converted
129  /// values to the original types, which are then used in lieu of the original
130  /// ones. (Eventually, `applyPartialOneToNConversion` replaces these casts
131  /// with a user-provided argument materialization if necessary.) This is
132  /// similar to `ArgConverter::applySignatureConversion` but (1) handles 1:N
133  /// type conversion properly and probably (2) doesn't handle many other edge
134  /// cases.
136  OneToNTypeMapping &argumentConversion);
137 };
138 
139 /// Base class for patterns with 1:N type conversions. Derived classes have to
140 /// override the `matchAndRewrite` overload that provides additional
141 /// information for 1:N type conversions.
143 public:
145 
146  /// This function has to be implemented by derived classes and is called from
147  /// the usual overloads. Like in "normal" `DialectConversion`, the function is
148  /// provided with the converted operands (which thus have target types). Since
149  /// 1:N conversions are supported, there is usually no 1:1 relationship
150  /// between the original and the converted operands. Instead, the provided
151  /// `operandMapping` can be used to access the converted operands that
152  /// correspond to a particular original operand. Similarly, `resultMapping`
153  /// is provided to help with assembling the result values, which may have 1:N
154  /// correspondences as well. In that case, the original op should be replaced
155  /// with the overload of `replaceOp` that takes the provided `resultMapping`
156  /// in order to deal with the mapping of converted result values to their
157  /// usages in the original types correctly.
158  virtual LogicalResult matchAndRewrite(Operation *op,
159  OneToNPatternRewriter &rewriter,
160  const OneToNTypeMapping &operandMapping,
161  const OneToNTypeMapping &resultMapping,
162  ValueRange convertedOperands) const = 0;
163 
164  LogicalResult matchAndRewrite(Operation *op,
165  PatternRewriter &rewriter) const final;
166 };
167 
168 /// This class is a wrapper around `OneToNConversionPattern` for matching
169 /// against instances of a particular op class.
170 template <typename SourceOp>
172 public:
174  MLIRContext *context, PatternBenefit benefit = 1,
175  ArrayRef<StringRef> generatedNames = {})
176  : OneToNConversionPattern(typeConverter, SourceOp::getOperationName(),
177  benefit, context, generatedNames) {}
178  /// Generic adaptor around the root op of this pattern using the converted
179  /// operands. Importantly, each operand is represented as a *range* of values,
180  /// namely the N values each original operand gets converted to. Concretely,
181  /// this makes the result type of the accessor functions of the adaptor class
182  /// be a `ValueRange`.
183  class OpAdaptor
184  : public SourceOp::template GenericAdaptor<ArrayRef<ValueRange>> {
185  public:
187  using BaseT = typename SourceOp::template GenericAdaptor<RangeT>;
188  using Properties = typename SourceOp::template InferredProperties<SourceOp>;
189 
190  OpAdaptor(const OneToNTypeMapping *operandMapping,
191  const OneToNTypeMapping *resultMapping,
192  const ValueRange *convertedOperands, RangeT values, SourceOp op)
193  : BaseT(values, op), operandMapping(operandMapping),
194  resultMapping(resultMapping), convertedOperands(convertedOperands) {}
195 
196  /// Get the type mapping of the original operands to the converted operands.
198  return *operandMapping;
199  }
200 
201  /// Get the type mapping of the original results to the converted results.
202  const OneToNTypeMapping &getResultMapping() const { return *resultMapping; }
203 
204  /// Get a flat range of all converted operands. Unlike `getOperands`, which
205  /// returns an `ArrayRef` with one `ValueRange` for each original operand,
206  /// this function returns a `ValueRange` that contains all converted
207  /// operands irrespective of which operand they originated from.
208  ValueRange getFlatOperands() const { return *convertedOperands; }
209 
210  private:
211  const OneToNTypeMapping *operandMapping;
212  const OneToNTypeMapping *resultMapping;
213  const ValueRange *convertedOperands;
214  };
215 
217 
218  /// Overload that derived classes have to override for their op type.
219  virtual LogicalResult
220  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
221  OneToNPatternRewriter &rewriter) const = 0;
222 
223  LogicalResult matchAndRewrite(Operation *op, OneToNPatternRewriter &rewriter,
224  const OneToNTypeMapping &operandMapping,
225  const OneToNTypeMapping &resultMapping,
226  ValueRange convertedOperands) const final {
227  // Wrap converted operands and type mappings into an adaptor: original operand i gets the slice of `convertedOperands` that `operandMapping` assigns to it.
228  SmallVector<ValueRange> valueRanges;
229  for (int64_t i = 0; i < op->getNumOperands(); i++) {
230  auto values = operandMapping.getConvertedValues(convertedOperands, i);
231  valueRanges.push_back(values);
232  }
233  OpAdaptor adaptor(&operandMapping, &resultMapping, &convertedOperands,
234  valueRanges, cast<SourceOp>(op));
235 
236  // Call the typed overload implemented by the derived class. `adaptor` keeps the address of the by-value parameter `convertedOperands`, so it must not outlive this call.
237  return matchAndRewrite(cast<SourceOp>(op), adaptor, rewriter);
238  }
239 };
240 
241 /// Applies the given set of patterns recursively on the given op and adds user
242 /// materializations where necessary. The patterns are expected to be
243 /// `OneToNConversionPattern`s, which help convert the types of the operands
244 /// and results of the matched ops. The provided type converter is used to
245 /// convert the operands of matched ops from their original types to operands
246 /// with different types. Unlike in `DialectConversion`, this supports 1:N type
247 /// conversions. At the "boundary" of the pattern application, i.e.,
248 /// where converted results are not consumed by replaced ops that expect the
249 /// converted operands or vice versa, the function inserts user materializations
250 /// from the type converter. Also unlike `DialectConversion`, there are no legal
251 /// or illegal types; the function simply applies the given patterns and does
252 /// not fail if some ops or types remain unconverted (i.e., the conversion is
253 /// only "partial").
254 LogicalResult
255 applyPartialOneToNConversion(Operation *op, TypeConverter &typeConverter,
256  const FrozenRewritePatternSet &patterns);
257 
258 /// Add a pattern to the given pattern list to convert the signature of a
259 /// FunctionOpInterface op with the given type converter. This only supports
260 /// ops which use FunctionType to represent their type. This is intended to be
261 /// used with the 1:N dialect conversion.
263  StringRef functionLikeOpName, const TypeConverter &converter,
264  RewritePatternSet &patterns);
265 template <typename FuncOpT>
267  const TypeConverter &converter, RewritePatternSet &patterns) {
269  FuncOpT::getOperationName(), converter, patterns);
270 }
271 
272 } // namespace mlir
273 
274 #endif // MLIR_TRANSFORMS_ONETONTYPECONVERSION_H
Block represents an ordered list of Operations.
Definition: Block.h:33
MLIRContext * context
Definition: Builders.h:211
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Base class for patterns with 1:N type conversions.
virtual LogicalResult matchAndRewrite(Operation *op, OneToNPatternRewriter &rewriter, const OneToNTypeMapping &operandMapping, const OneToNTypeMapping &resultMapping, ValueRange convertedOperands) const =0
This function has to be implemented by derived classes and is called from the usual overloads.
Generic adaptor around the root op of this pattern using the converted operands.
const OneToNTypeMapping & getOperandMapping() const
Get the type mapping of the original operands to the converted operands.
ValueRange getFlatOperands() const
Get a flat range of all converted operands.
typename SourceOp::template InferredProperties< SourceOp > Properties
OpAdaptor(const OneToNTypeMapping *operandMapping, const OneToNTypeMapping *resultMapping, const ValueRange *convertedOperands, RangeT values, SourceOp op)
const OneToNTypeMapping & getResultMapping() const
Get the type mapping of the original results to the converted results.
This class is a wrapper around OneToNConversionPattern for matching against instances of a particular...
virtual LogicalResult matchAndRewrite(Operation *op, OneToNPatternRewriter &rewriter, const OneToNTypeMapping &operandMapping, const OneToNTypeMapping &resultMapping, ValueRange convertedOperands) const=0
This function has to be implemented by derived classes and is called from the usual overloads.
OneToNOpConversionPattern(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
virtual LogicalResult matchAndRewrite(SourceOp op, OpAdaptor adaptor, OneToNPatternRewriter &rewriter) const =0
Overload that derived classes have to override for their op type.
LogicalResult matchAndRewrite(Operation *op, OneToNPatternRewriter &rewriter, const OneToNTypeMapping &operandMapping, const OneToNTypeMapping &resultMapping, ValueRange convertedOperands) const final
This function has to be implemented by derived classes and is called from the usual overloads.
Specialization of PatternRewriter that OneToNConversionPatterns use.
Block * applySignatureConversion(Block *block, OneToNTypeMapping &argumentConversion)
Applies the given argument conversion to the given block.
void replaceOp(Operation *op, ValueRange newValues, const OneToNTypeMapping &resultMapping)
Replaces the results of the operation with the specified list of values mapped back to the original t...
OneToNPatternRewriter(MLIRContext *context, OpBuilder::Listener *listener=nullptr)
Stores a 1:N mapping of types and provides several useful accessors.
OneToNTypeMapping(TypeRange originalTypes)
void convertLocations(ValueRange originalValues, llvm::SmallVectorImpl< Location > &result) const
Fills the given result vector with as many copies of the location of each original value as the numb...
TypeRange getOriginalTypes() const
Returns the list of original types.
bool hasNonIdentityConversion() const
Returns true iff at least one type conversion maps an input type to a type that is different from its...
ValueRange getConvertedValues(ValueRange convertedValues, unsigned originalValueNo) const
Returns the slice of converted values that corresponds to the original value at the given index.
void convertLocation(Value originalValue, unsigned originalValueNo, llvm::SmallVectorImpl< Location > &result) const
Fills the given result vector with as many copies of the location of the original value as the number...
Listener * listener
The optional listener for events of this builder.
Definition: Builders.h:616
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
Extends the basic RewritePattern class with a type converter member and some accessors to it.
const TypeConverter * getTypeConverter() const
Return the type converter held by this pattern, or nullptr if the pattern does not require type conve...
const TypeConverter *const typeConverter
A type converter for use by this pattern.
std::enable_if_t< std::is_base_of< TypeConverter, ConverterTy >::value, const ConverterTy * > getTypeConverter() const
RewritePatternWithConverter(const TypeConverter &typeConverter, Args &&...args)
Construct a conversion pattern with the given converter, and forward the remaining arguments to Rewri...
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:246
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class provides all of the information necessary to convert a type signature.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
Type conversion class.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
LogicalResult applyPartialOneToNConversion(Operation *op, TypeConverter &typeConverter, const FrozenRewritePatternSet &patterns)
Applies the given set of patterns recursively on the given op and adds user materializations where ne...
void populateOneToNFunctionOpInterfaceTypeConversionPattern(StringRef functionLikeOpName, const TypeConverter &converter, RewritePatternSet &patterns)
Add a pattern to the given pattern list to convert the signature of a FunctionOpInterface op with the...
This class represents a listener that may be used to hook into various actions within an OpBuilder.
Definition: Builders.h:294