MLIR  20.0.0git
OneToNTypeConversion.h
Go to the documentation of this file.
1 //===-- OneToNTypeConversion.h - Utils for 1:N type conversion --*- C++ -*-===//
2 //
3 // Licensed under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file provides utils for implementing (poor-man's) dialect conversion
10 // passes with 1:N type conversions.
11 //
12 // The main function, `applyPartialOneToNConversion`, first applies a set of
13 // `RewritePattern`s, which produce unrealized casts to convert the operands and
14 // results from and to the source types, and then replaces all newly added
15 // unrealized casts with user-provided materializations. For this to work, the
16 // main function requires a special `TypeConverter`, a special
17 // `PatternRewriter`, and special `RewritePattern`s, which extend their
18 // respective base classes for 1:N type conversions.
19 //
20 // Note that this is much more simple-minded than the "real" dialect conversion,
21 // which checks for legality before applying patterns and probably does many
22 // other things. Ideally, some of the extensions here could be
23 // integrated there.
24 //
25 //===----------------------------------------------------------------------===//
26 
27 #ifndef MLIR_TRANSFORMS_ONETONTYPECONVERSION_H
28 #define MLIR_TRANSFORMS_ONETONTYPECONVERSION_H
29 
30 #include "mlir/IR/PatternMatch.h"
32 #include "llvm/ADT/SmallVector.h"
33 
34 namespace mlir {
35 
36 /// Extends `TypeConverter` with 1:N target materializations. Such
37 /// materializations take one value with a source type and materialize it as
38 /// N values with target types, i.e., they apply a 1:N type conversion to a
39 /// value (which isn't possible in the base class currently).
41 public:
42  /// Callback that expresses user-provided materialization logic from the given
43  /// value to N values of the given types. This is useful for expressing target
44  /// materializations for 1:N type conversions, which materialize one value in
45  /// a source type as N values in target types.
47  std::function<std::optional<SmallVector<Value>>(OpBuilder &, TypeRange,
48  Value, Location)>;
49 
50  /// Creates the mapping of the given range of original types to target types
51  /// of the conversion and stores that mapping in the given (signature)
52  /// conversion. This function simply calls
53  /// `TypeConverter::convertSignatureArgs` and exists here with a different
54  /// name to reflect the broader semantic.
55  LogicalResult computeTypeMapping(TypeRange types,
56  SignatureConversion &result) {
57  return convertSignatureArgs(types, result);
58  }
59 
60  /// Applies one of the user-provided 1:N target materializations. If several
61  /// exist, they are tried out in the reverse order in which they have been
62  /// added until the first one succeeds. If none succeeds, the function
63  /// returns `std::nullopt`.
64  std::optional<SmallVector<Value>>
66  TypeRange resultTypes, Value input) const;
67 
68  /// Adds a 1:N target materialization to the converter. Such materializations
69  /// build IR that converts 1 value of the source type into N values with
70  /// target types.
72  oneToNTargetMaterializations.emplace_back(std::move(callback));
73  }
74 
75 private:
76  SmallVector<OneToNMaterializationCallbackFn> oneToNTargetMaterializations;
77 };
78 
79 /// Stores a 1:N mapping of types and provides several useful accessors. This
80 /// class extends `SignatureConversion`, which already supports 1:N type
81 /// mappings but lacks some accessors into the mapping as well as access to the
82 /// original types.
84 public:
86  : TypeConverter::SignatureConversion(originalTypes.size()),
87  originalTypes(originalTypes) {}
88 
90 
91  /// Returns the list of types that corresponds to the original type at the
92  /// given index.
93  TypeRange getConvertedTypes(unsigned originalTypeNo) const;
94 
95  /// Returns the list of original types.
96  TypeRange getOriginalTypes() const { return originalTypes; }
97 
98  /// Returns the slice of converted values that corresponds to the original value
99  /// at the given index.
100  ValueRange getConvertedValues(ValueRange convertedValues,
101  unsigned originalValueNo) const;
102 
103  /// Fills the given result vector with as many copies of the location of the
104  /// original value as the number of values it is converted to.
105  void convertLocation(Value originalValue, unsigned originalValueNo,
106  llvm::SmallVectorImpl<Location> &result) const;
107 
108  /// Fills the given result vector with as many copies of the location of each
109  /// original value as the number of values they are respectively converted to.
110  void convertLocations(ValueRange originalValues,
111  llvm::SmallVectorImpl<Location> &result) const;
112 
113  /// Returns true iff at least one type conversion maps an input type to a type
114  /// that is different from itself.
115  bool hasNonIdentityConversion() const;
116 
117 private:
118  llvm::SmallVector<Type> originalTypes;
119 };
120 
121 /// Extends the basic `RewritePattern` class with a type converter member and
122 /// some accessors to it. This is useful for patterns that are not
123 /// `ConversionPattern`s but still require access to a type converter.
125 public:
126  /// Construct a conversion pattern with the given converter, and forward the
127  /// remaining arguments to RewritePattern.
128  template <typename... Args>
130  : RewritePattern(std::forward<Args>(args)...),
132 
133  /// Return the type converter held by this pattern, or nullptr if the pattern
134  /// does not require type conversion.
136 
137  template <typename ConverterTy>
138  std::enable_if_t<std::is_base_of<TypeConverter, ConverterTy>::value,
139  ConverterTy *>
141  return static_cast<ConverterTy *>(typeConverter);
142  }
143 
144 protected:
145  /// A type converter for use by this pattern.
147 };
148 
149 /// Specialization of `PatternRewriter` that `OneToNConversionPattern`s use. The
150 /// class provides additional rewrite methods that are specific to 1:N type
151 /// conversions.
153 public:
155  OpBuilder::Listener *listener = nullptr)
157 
158  /// Replaces the results of the operation with the specified list of values
159  /// mapped back to the original types as specified in the provided type
160  /// mapping. That type mapping must match the replaced op (i.e., the original
161  /// types must be the same as the result types of the op) and the new values
162  /// (i.e., the converted types must be the same as the types of the new
163  /// values).
164  void replaceOp(Operation *op, ValueRange newValues,
165  const OneToNTypeMapping &resultMapping);
167 
168  /// Applies the given argument conversion to the given block. This consists of
169  /// replacing each original argument with N arguments as specified in the
170  /// argument conversion and inserting unrealized casts from the converted
171  /// values to the original types, which are then used in lieu of the original
172  /// ones. (Eventually, `applyPartialOneToNConversion` replaces these casts
173  /// with a user-provided argument materialization if necessary.) This is
174  /// similar to `ArgConverter::applySignatureConversion` but (1) handles 1:N
175  /// type conversion properly and probably (2) doesn't handle many other edge
176  /// cases.
178  OneToNTypeMapping &argumentConversion);
179 };
180 
181 /// Base class for patterns with 1:N type conversions. Derived classes have to
182 /// override the `matchAndRewrite` overload that provides additional
183 /// information for 1:N type conversions.
185 public:
187 
188  /// This function has to be implemented by derived classes and is called from
189  /// the usual overloads. Like in "normal" `DialectConversion`, the function is
190  /// provided with the converted operands (which thus have target types). Since
191  /// 1:N conversions are supported, there is usually no 1:1 relationship
192  /// between the original and the converted operands. Instead, the provided
193  /// `operandMapping` can be used to access the converted operands that
194  /// correspond to a particular original operand. Similarly, `resultMapping`
195  /// is provided to help with assembling the result values, which may have 1:N
196  /// correspondences as well. In that case, the original op should be replaced
197  /// with the overload of `replaceOp` that takes the provided `resultMapping`
198  /// in order to deal with the mapping of converted result values to their
199  /// usages in the original types correctly.
200  virtual LogicalResult matchAndRewrite(Operation *op,
201  OneToNPatternRewriter &rewriter,
202  const OneToNTypeMapping &operandMapping,
203  const OneToNTypeMapping &resultMapping,
204  ValueRange convertedOperands) const = 0;
205 
206  LogicalResult matchAndRewrite(Operation *op,
207  PatternRewriter &rewriter) const final;
208 };
209 
210 /// This class is a wrapper around `OneToNConversionPattern` for matching
211 /// against instances of a particular op class.
212 template <typename SourceOp>
214 public:
216  PatternBenefit benefit = 1,
217  ArrayRef<StringRef> generatedNames = {})
218  : OneToNConversionPattern(typeConverter, SourceOp::getOperationName(),
219  benefit, context, generatedNames) {}
220  /// Generic adaptor around the root op of this pattern using the converted
221  /// operands. Importantly, each operand is represented as a *range* of values,
222  /// namely the N values each original operand gets converted to. Concretely,
223  /// this makes the result type of the accessor functions of the adaptor class
224  /// be a `ValueRange`.
225  class OpAdaptor
226  : public SourceOp::template GenericAdaptor<ArrayRef<ValueRange>> {
227  public:
229  using BaseT = typename SourceOp::template GenericAdaptor<RangeT>;
230  using Properties = typename SourceOp::template InferredProperties<SourceOp>;
231 
232  OpAdaptor(const OneToNTypeMapping *operandMapping,
233  const OneToNTypeMapping *resultMapping,
234  const ValueRange *convertedOperands, RangeT values, SourceOp op)
235  : BaseT(values, op), operandMapping(operandMapping),
236  resultMapping(resultMapping), convertedOperands(convertedOperands) {}
237 
238  /// Get the type mapping of the original operands to the converted operands.
240  return *operandMapping;
241  }
242 
243  /// Get the type mapping of the original results to the converted results.
244  const OneToNTypeMapping &getResultMapping() const { return *resultMapping; }
245 
246  /// Get a flat range of all converted operands. Unlike `getOperands`, which
247  /// returns an `ArrayRef` with one `ValueRange` for each original operand,
248  /// this function returns a `ValueRange` that contains all converted
249  /// operands irrespective of which operand they originated from.
250  ValueRange getFlatOperands() const { return *convertedOperands; }
251 
252  private:
253  const OneToNTypeMapping *operandMapping;
254  const OneToNTypeMapping *resultMapping;
255  const ValueRange *convertedOperands;
256  };
257 
259 
260  /// Overload that derived classes have to override for their op type.
261  virtual LogicalResult
262  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
263  OneToNPatternRewriter &rewriter) const = 0;
264 
265  LogicalResult matchAndRewrite(Operation *op, OneToNPatternRewriter &rewriter,
266  const OneToNTypeMapping &operandMapping,
267  const OneToNTypeMapping &resultMapping,
268  ValueRange convertedOperands) const final { // Adapts the generic 1:N entry point to the typed `matchAndRewrite(SourceOp, OpAdaptor, OneToNPatternRewriter &)` overload.
269  // Wrap converted operands and type mappings into an adaptor. `valueRanges` gets one slice of `convertedOperands` per original operand, as given by `operandMapping`.
270  SmallVector<ValueRange> valueRanges;
271  for (int64_t i = 0; i < op->getNumOperands(); i++) { // NOTE(review): `getConvertedValues` takes an `unsigned` index, so `i` is implicitly narrowed; an `unsigned` counter would match.
272  auto values = operandMapping.getConvertedValues(convertedOperands, i);
273  valueRanges.push_back(values);
274  }
275  OpAdaptor adaptor(&operandMapping, &resultMapping, &convertedOperands, // The adaptor keeps pointers to these arguments and a view of `valueRanges`, so it must not outlive this call.
276  valueRanges, cast<SourceOp>(op));
277 
278  // Call overload implemented by the derived class.
279  return matchAndRewrite(cast<SourceOp>(op), adaptor, rewriter);
280  }
281 };
282 
283 /// Applies the given set of patterns recursively on the given op and adds user
284 /// materializations where necessary. The patterns are expected to be
285 /// `OneToNConversionPattern`, which help convert the types of the operands
286 /// and results of the matched ops. The provided type converter is used to
287 /// convert the operands of matched ops from their original types to operands
288 /// with different types. Unlike in `DialectConversion`, this supports 1:N type
289 /// conversions. At the "boundary" of the pattern application,
290 /// where converted results are not consumed by replaced ops that expect the
291 /// converted operands or vice versa, the function inserts user materializations
292 /// from the type converter. Also unlike `DialectConversion`, there are no legal
293 /// or illegal types; the function simply applies the given patterns and does
294 /// not fail if some ops or types remain unconverted (i.e., the conversion is
295 /// only "partial").
296 LogicalResult
297 applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
298  const FrozenRewritePatternSet &patterns);
299 
300 /// Add a pattern to the given pattern list to convert the signature of a
301 /// FunctionOpInterface op with the given type converter. This only supports
302 /// ops which use FunctionType to represent their type. This is intended to be
303 /// used with the 1:N dialect conversion.
305  StringRef functionLikeOpName, TypeConverter &converter,
306  RewritePatternSet &patterns);
307 template <typename FuncOpT>
309  TypeConverter &converter, RewritePatternSet &patterns) {
311  FuncOpT::getOperationName(), converter, patterns);
312 }
313 
314 } // namespace mlir
315 
316 #endif // MLIR_TRANSFORMS_ONETONTYPECONVERSION_H
Block represents an ordered list of Operations.
Definition: Block.h:31
MLIRContext * context
Definition: Builders.h:207
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Base class for patterns with 1:N type conversions.
virtual LogicalResult matchAndRewrite(Operation *op, OneToNPatternRewriter &rewriter, const OneToNTypeMapping &operandMapping, const OneToNTypeMapping &resultMapping, ValueRange convertedOperands) const =0
This function has to be implemented by derived classes and is called from the usual overloads.
Generic adaptor around the root op of this pattern using the converted operands.
const OneToNTypeMapping & getOperandMapping() const
Get the type mapping of the original operands to the converted operands.
ValueRange getFlatOperands() const
Get a flat range of all converted operands.
typename SourceOp::template InferredProperties< SourceOp > Properties
OpAdaptor(const OneToNTypeMapping *operandMapping, const OneToNTypeMapping *resultMapping, const ValueRange *convertedOperands, RangeT values, SourceOp op)
const OneToNTypeMapping & getResultMapping() const
Get the type mapping of the original results to the converted results.
This class is a wrapper around OneToNConversionPattern for matching against instances of a particular...
OneToNOpConversionPattern(TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
virtual LogicalResult matchAndRewrite(Operation *op, OneToNPatternRewriter &rewriter, const OneToNTypeMapping &operandMapping, const OneToNTypeMapping &resultMapping, ValueRange convertedOperands) const=0
This function has to be implemented by derived classes and is called from the usual overloads.
virtual LogicalResult matchAndRewrite(SourceOp op, OpAdaptor adaptor, OneToNPatternRewriter &rewriter) const =0
Overload that derived classes have to override for their op type.
LogicalResult matchAndRewrite(Operation *op, OneToNPatternRewriter &rewriter, const OneToNTypeMapping &operandMapping, const OneToNTypeMapping &resultMapping, ValueRange convertedOperands) const final
This function has to be implemented by derived classes and is called from the usual overloads.
Specialization of PatternRewriter that OneToNConversionPatterns use.
Block * applySignatureConversion(Block *block, OneToNTypeMapping &argumentConversion)
Applies the given argument conversion to the given block.
void replaceOp(Operation *op, ValueRange newValues, const OneToNTypeMapping &resultMapping)
Replaces the results of the operation with the specified list of values mapped back to the original t...
OneToNPatternRewriter(MLIRContext *context, OpBuilder::Listener *listener=nullptr)
Extends TypeConverter with 1:N target materializations.
std::optional< SmallVector< Value > > materializeTargetConversion(OpBuilder &builder, Location loc, TypeRange resultTypes, Value input) const
Applies one of the user-provided 1:N target materializations.
LogicalResult computeTypeMapping(TypeRange types, SignatureConversion &result)
Creates the mapping of the given range of original types to target types of the conversion and stores...
void addTargetMaterialization(OneToNMaterializationCallbackFn &&callback)
Adds a 1:N target materialization to the converter.
std::function< std::optional< SmallVector< Value > >(OpBuilder &, TypeRange, Value, Location)> OneToNMaterializationCallbackFn
Callback that expresses user-provided materialization logic from the given value to N values of the g...
Stores a 1:N mapping of types and provides several useful accessors.
OneToNTypeMapping(TypeRange originalTypes)
void convertLocations(ValueRange originalValues, llvm::SmallVectorImpl< Location > &result) const
Fills the given result vector with as many copies of the location of each original value as the numb...
TypeRange getOriginalTypes() const
Returns the list of original types.
bool hasNonIdentityConversion() const
Returns true iff at least one type conversion maps an input type to a type that is different from its...
ValueRange getConvertedValues(ValueRange convertedValues, unsigned originalValueNo) const
Returns the slice of converted values that corresponds to the original value at the given index.
void convertLocation(Value originalValue, unsigned originalValueNo, llvm::SmallVectorImpl< Location > &result) const
Fills the given result vector with as many copies of the location of the original value as the number...
This class helps build Operations.
Definition: Builders.h:212
Listener * listener
The optional listener for events of this builder.
Definition: Builders.h:612
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
Extends the basic RewritePattern class with a type converter member and some accessors to it.
TypeConverter *const typeConverter
A type converter for use by this pattern.
RewritePatternWithConverter(TypeConverter &typeConverter, Args &&...args)
Construct a conversion pattern with the given converter, and forward the remaining arguments to Rewri...
std::enable_if_t< std::is_base_of< TypeConverter, ConverterTy >::value, ConverterTy * > getTypeConverter() const
TypeConverter * getTypeConverter() const
Return the type converter held by this pattern, or nullptr if the pattern does not require type conve...
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:246
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class provides all of the information necessary to convert a type signature.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
Type conversion class.
LogicalResult convertSignatureArgs(TypeRange types, SignatureConversion &result, unsigned origInputOffset=0) const
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Include the generated interface declarations.
LogicalResult applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter, const FrozenRewritePatternSet &patterns)
Applies the given set of patterns recursively on the given op and adds user materializations where ne...
void populateOneToNFunctionOpInterfaceTypeConversionPattern(StringRef functionLikeOpName, TypeConverter &converter, RewritePatternSet &patterns)
Add a pattern to the given pattern list to convert the signature of a FunctionOpInterface op with the...
This class represents a listener that may be used to hook into various actions within an OpBuilder.
Definition: Builders.h:290