MLIR  20.0.0git
OneToNTypeConversion.h
Go to the documentation of this file.
1 //===-- OneToNTypeConversion.h - Utils for 1:N type conversion --*- C++ -*-===//
2 //
3 // Licensed under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file provides utils for implementing (poor-man's) dialect conversion
10 // passes with 1:N type conversions.
11 //
12 // The main function, `applyPartialOneToNConversion`, first applies a set of
13 // `RewritePattern`s, which produce unrealized casts to convert the operands and
14 // results from and to the source types, and then replaces all newly added
15 // unrealized casts by user-provided materializations. For this to work, the
16 // main function requires a special `TypeConverter`, a special
17 // `PatternRewriter`, and special `RewritePattern`s, which extend their
18 // respective base classes for 1:N type conversions.
19 //
20 // Note that this is much more simple-minded than the "real" dialect conversion,
21 // which checks for legality before applying patterns and probably does many
22 // other things. Ideally, some of the extensions here could be
23 // integrated there.
24 //
25 //===----------------------------------------------------------------------===//
26 
27 #ifndef MLIR_TRANSFORMS_ONETONTYPECONVERSION_H
28 #define MLIR_TRANSFORMS_ONETONTYPECONVERSION_H
29 
30 #include "mlir/IR/PatternMatch.h"
32 #include "llvm/ADT/SmallVector.h"
33 
34 namespace mlir {
35 
36 /// Extends `TypeConverter` with 1:N target materializations. Such
37 /// materializations take one value of a source type and materialize it as N
38 /// values of the given target types, i.e., they follow the direction of 1:N
39 /// type conversions. (The base class currently cannot express this.)
41 public:
42  /// Callback that expresses user-provided materialization logic from the given
43  /// value to N values of the given types. This is useful for expressing target
44  /// materializations for 1:N type conversions, which materialize one value in
45  /// a source type as N values in target types.
47  std::function<std::optional<SmallVector<Value>>(OpBuilder &, TypeRange,
48  Value, Location)>;
49 
50  /// Creates the mapping of the given range of original types to target types
51  /// of the conversion and stores that mapping in the given (signature)
52  /// conversion. This function simply calls
53  /// `TypeConverter::convertSignatureArgs` and exists here with a different
54  /// name to reflect the broader semantic.
55  LogicalResult computeTypeMapping(TypeRange types,
56  SignatureConversion &result) const {
57  return convertSignatureArgs(types, result);
58  }
59 
60  /// Applies one of the user-provided 1:N target materializations. If several
61  /// exist, they are tried out in the reverse order in which they have been
62  /// added until the first one succeeds. If none succeeds, the function
63  /// returns `std::nullopt`.
64  std::optional<SmallVector<Value>>
66  TypeRange resultTypes, Value input) const;
67 
68  /// Adds a 1:N target materialization to the converter. Such materializations
69  /// build IR that converts 1 value of the source type into N values of the
70  /// target types.
72  oneToNTargetMaterializations.emplace_back(std::move(callback));
73  }
74 
75 private:
76  SmallVector<OneToNMaterializationCallbackFn> oneToNTargetMaterializations;
77 };
78 
79 /// Stores a 1:N mapping of types and provides several useful accessors. This
80 /// class extends `SignatureConversion`, which already supports 1:N type
81 /// mappings but lacks some accessors into the mapping as well as access to the
82 /// original types.
84 public:
86  : TypeConverter::SignatureConversion(originalTypes.size()),
87  originalTypes(originalTypes) {}
88 
90 
91  /// Returns the list of types that corresponds to the original type at the
92  /// given index.
93  TypeRange getConvertedTypes(unsigned originalTypeNo) const;
94 
95  /// Returns the list of original types.
96  TypeRange getOriginalTypes() const { return originalTypes; }
97 
98  /// Returns the slice of converted values that corresponds to the original value
99  /// at the given index.
100  ValueRange getConvertedValues(ValueRange convertedValues,
101  unsigned originalValueNo) const;
102 
103  /// Fills the given result vector with as many copies of the location of the
104  /// original value as the number of values it is converted to.
105  void convertLocation(Value originalValue, unsigned originalValueNo,
106  llvm::SmallVectorImpl<Location> &result) const;
107 
108  /// Fills the given result vector with as many copies of the location of each
109  /// original value as the number of values they are respectively converted to.
110  void convertLocations(ValueRange originalValues,
111  llvm::SmallVectorImpl<Location> &result) const;
112 
113  /// Returns true iff at least one type conversion maps an input type to a type
114  /// that is different from itself.
115  bool hasNonIdentityConversion() const;
116 
117 private:
118  llvm::SmallVector<Type> originalTypes;
119 };
120 
121 /// Extends the basic `RewritePattern` class with a type converter member and
122 /// some accessors to it. This is useful for patterns that are not
123 /// `ConversionPattern`s but still require access to a type converter.
125 public:
126  /// Construct a conversion pattern with the given converter, and forward the
127  /// remaining arguments to RewritePattern.
128  template <typename... Args>
130  Args &&...args)
131  : RewritePattern(std::forward<Args>(args)...),
133 
134  /// Return the type converter held by this pattern, or nullptr if the pattern
135  /// does not require type conversion.
136  const TypeConverter *getTypeConverter() const { return typeConverter; }
137 
138  template <typename ConverterTy>
139  std::enable_if_t<std::is_base_of<TypeConverter, ConverterTy>::value,
140  const ConverterTy *>
142  return static_cast<const ConverterTy *>(typeConverter);
143  }
144 
145 protected:
146  /// A type converter for use by this pattern.
148 };
149 
150 /// Specialization of `PatternRewriter` that `OneToNConversionPattern`s use. The
151 /// class provides additional rewrite methods that are specific to 1:N type
152 /// conversions.
154 public:
156  OpBuilder::Listener *listener = nullptr)
158 
159  /// Replaces the results of the operation with the specified list of values
160  /// mapped back to the original types as specified in the provided type
161  /// mapping. That type mapping must match the replaced op (i.e., the original
162  /// types must be the same as the result types of the op) and the new values
163  /// (i.e., the converted types must be the same as the types of the new
164  /// values).
165  void replaceOp(Operation *op, ValueRange newValues,
166  const OneToNTypeMapping &resultMapping);
168 
169  /// Applies the given argument conversion to the given block. This consists of
170  /// replacing each original argument with N arguments as specified in the
171  /// argument conversion and inserting unrealized casts from the converted
172  /// values to the original types, which are then used in lieu of the original
173  /// ones. (Eventually, `applyPartialOneToNConversion` replaces these casts
174  /// with a user-provided argument materialization if necessary.) This is
175  /// similar to `ArgConverter::applySignatureConversion` but (1) handles 1:N
176  /// type conversion properly and probably (2) doesn't handle many other edge
177  /// cases.
179  OneToNTypeMapping &argumentConversion);
180 };
181 
182 /// Base class for patterns with 1:N type conversions. Derived classes have to
183 /// override the `matchAndRewrite` overload that provides additional
184 /// information for 1:N type conversions.
186 public:
188 
189  /// This function has to be implemented by derived classes and is called from
190  /// the usual overloads. Like in "normal" `DialectConversion`, the function is
191  /// provided with the converted operands (which thus have target types). Since
192  /// 1:N conversions are supported, there is usually no 1:1 relationship
193  /// between the original and the converted operands. Instead, the provided
194  /// `operandMapping` can be used to access the converted operands that
195  /// correspond to a particular original operand. Similarly, `resultMapping`
196  /// is provided to help with assembling the result values, which may have 1:N
197  /// correspondences as well. In that case, the original op should be replaced
198  /// with the overload of `replaceOp` that takes the provided `resultMapping`
199  /// in order to deal with the mapping of converted result values to their
200  /// usages in the original types correctly.
201  virtual LogicalResult matchAndRewrite(Operation *op,
202  OneToNPatternRewriter &rewriter,
203  const OneToNTypeMapping &operandMapping,
204  const OneToNTypeMapping &resultMapping,
205  ValueRange convertedOperands) const = 0;
206 
207  LogicalResult matchAndRewrite(Operation *op,
208  PatternRewriter &rewriter) const final;
209 };
210 
211 /// This class is a wrapper around `OneToNConversionPattern` for matching
212 /// against instances of a particular op class.
213 template <typename SourceOp>
215 public:
217  MLIRContext *context, PatternBenefit benefit = 1,
218  ArrayRef<StringRef> generatedNames = {})
219  : OneToNConversionPattern(typeConverter, SourceOp::getOperationName(),
220  benefit, context, generatedNames) {}
221  /// Generic adaptor around the root op of this pattern using the converted
222  /// operands. Importantly, each operand is represented as a *range* of values,
223  /// namely the N values each original operand gets converted to. Concretely,
224  /// this makes the result type of the accessor functions of the adaptor class
225  /// be a `ValueRange`.
226  class OpAdaptor
227  : public SourceOp::template GenericAdaptor<ArrayRef<ValueRange>> {
228  public:
230  using BaseT = typename SourceOp::template GenericAdaptor<RangeT>;
231  using Properties = typename SourceOp::template InferredProperties<SourceOp>;
232 
233  OpAdaptor(const OneToNTypeMapping *operandMapping,
234  const OneToNTypeMapping *resultMapping,
235  const ValueRange *convertedOperands, RangeT values, SourceOp op)
236  : BaseT(values, op), operandMapping(operandMapping),
237  resultMapping(resultMapping), convertedOperands(convertedOperands) {}
238 
239  /// Get the type mapping of the original operands to the converted operands.
241  return *operandMapping;
242  }
243 
244  /// Get the type mapping of the original results to the converted results.
245  const OneToNTypeMapping &getResultMapping() const { return *resultMapping; }
246 
247  /// Get a flat range of all converted operands. Unlike `getOperands`, which
248  /// returns an `ArrayRef` with one `ValueRange` for each original operand,
249  /// this function returns a `ValueRange` that contains all converted
250  /// operands irrespectively of which operand they originated from.
251  ValueRange getFlatOperands() const { return *convertedOperands; }
252 
253  private:
254  const OneToNTypeMapping *operandMapping;
255  const OneToNTypeMapping *resultMapping;
256  const ValueRange *convertedOperands;
257  };
258 
260 
261  /// Overload that derived classes have to override for their op type.
262  virtual LogicalResult
263  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
264  OneToNPatternRewriter &rewriter) const = 0;
265 
266  LogicalResult matchAndRewrite(Operation *op, OneToNPatternRewriter &rewriter,
267  const OneToNTypeMapping &operandMapping,
268  const OneToNTypeMapping &resultMapping,
269  ValueRange convertedOperands) const final {
270  // Wrap converted operands and type mappings into an adaptor.
271  SmallVector<ValueRange> valueRanges;
272  for (int64_t i = 0; i < op->getNumOperands(); i++) {
273  auto values = operandMapping.getConvertedValues(convertedOperands, i);
274  valueRanges.push_back(values);
275  }
276  OpAdaptor adaptor(&operandMapping, &resultMapping, &convertedOperands,
277  valueRanges, cast<SourceOp>(op));
278 
279  // Call overload implemented by the derived class.
280  return matchAndRewrite(cast<SourceOp>(op), adaptor, rewriter);
281  }
282 };
283 
284 /// Applies the given set of patterns recursively on the given op and adds user
285 /// materializations where necessary. The patterns are expected to be
286 /// `OneToNConversionPattern`s, which help convert the types of the operands
287 /// and results of the matched ops. The provided type converter is used to
288 /// convert the operands of matched ops from their original types to operands
289 /// with different types. Unlike in `DialectConversion`, this supports 1:N type
290 /// conversions. For conversions at the "boundary" of the pattern application,
291 /// where converted results are not consumed by replaced ops that expect the
292 /// converted operands or vice versa, the function inserts user materializations
293 /// from the type converter. Also unlike `DialectConversion`, there are no legal
294 /// or illegal types; the function simply applies the given patterns and does
295 /// not fail if some ops or types remain unconverted (i.e., the conversion is
296 /// only "partial").
297 LogicalResult
298 applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
299  const FrozenRewritePatternSet &patterns);
300 
301 /// Add a pattern to the given pattern list to convert the signature of a
302 /// FunctionOpInterface op with the given type converter. This only supports
303 /// ops which use FunctionType to represent their type. This is intended to be
304 /// used with the 1:N dialect conversion.
306  StringRef functionLikeOpName, const TypeConverter &converter,
307  RewritePatternSet &patterns);
308 template <typename FuncOpT>
310  const TypeConverter &converter, RewritePatternSet &patterns) {
312  FuncOpT::getOperationName(), converter, patterns);
313 }
314 
315 } // namespace mlir
316 
317 #endif // MLIR_TRANSFORMS_ONETONTYPECONVERSION_H
Block represents an ordered list of Operations.
Definition: Block.h:31
MLIRContext * context
Definition: Builders.h:210
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Base class for patterns with 1:N type conversions.
virtual LogicalResult matchAndRewrite(Operation *op, OneToNPatternRewriter &rewriter, const OneToNTypeMapping &operandMapping, const OneToNTypeMapping &resultMapping, ValueRange convertedOperands) const =0
This function has to be implemented by derived classes and is called from the usual overloads.
Generic adaptor around the root op of this pattern using the converted operands.
const OneToNTypeMapping & getOperandMapping() const
Get the type mapping of the original operands to the converted operands.
ValueRange getFlatOperands() const
Get a flat range of all converted operands.
typename SourceOp::template InferredProperties< SourceOp > Properties
OpAdaptor(const OneToNTypeMapping *operandMapping, const OneToNTypeMapping *resultMapping, const ValueRange *convertedOperands, RangeT values, SourceOp op)
const OneToNTypeMapping & getResultMapping() const
Get the type mapping of the original results to the converted results.
This class is a wrapper around OneToNConversionPattern for matching against instances of a particular...
virtual LogicalResult matchAndRewrite(Operation *op, OneToNPatternRewriter &rewriter, const OneToNTypeMapping &operandMapping, const OneToNTypeMapping &resultMapping, ValueRange convertedOperands) const=0
This function has to be implemented by derived classes and is called from the usual overloads.
OneToNOpConversionPattern(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
virtual LogicalResult matchAndRewrite(SourceOp op, OpAdaptor adaptor, OneToNPatternRewriter &rewriter) const =0
Overload that derived classes have to override for their op type.
LogicalResult matchAndRewrite(Operation *op, OneToNPatternRewriter &rewriter, const OneToNTypeMapping &operandMapping, const OneToNTypeMapping &resultMapping, ValueRange convertedOperands) const final
This function has to be implemented by derived classes and is called from the usual overloads.
Specialization of PatternRewriter that OneToNConversionPatterns use.
Block * applySignatureConversion(Block *block, OneToNTypeMapping &argumentConversion)
Applies the given argument conversion to the given block.
void replaceOp(Operation *op, ValueRange newValues, const OneToNTypeMapping &resultMapping)
Replaces the results of the operation with the specified list of values mapped back to the original t...
OneToNPatternRewriter(MLIRContext *context, OpBuilder::Listener *listener=nullptr)
Extends TypeConverter with 1:N target materializations.
std::optional< SmallVector< Value > > materializeTargetConversion(OpBuilder &builder, Location loc, TypeRange resultTypes, Value input) const
Applies one of the user-provided 1:N target materializations.
void addTargetMaterialization(OneToNMaterializationCallbackFn &&callback)
Adds a 1:N target materialization to the converter.
std::function< std::optional< SmallVector< Value > >(OpBuilder &, TypeRange, Value, Location)> OneToNMaterializationCallbackFn
Callback that expresses user-provided materialization logic from the given value to N values of the g...
LogicalResult computeTypeMapping(TypeRange types, SignatureConversion &result) const
Creates the mapping of the given range of original types to target types of the conversion and stores...
Stores a 1:N mapping of types and provides several useful accessors.
OneToNTypeMapping(TypeRange originalTypes)
void convertLocations(ValueRange originalValues, llvm::SmallVectorImpl< Location > &result) const
Fills the given result vector with as many copies of the location of each original value as the numb...
TypeRange getOriginalTypes() const
Returns the list of original types.
bool hasNonIdentityConversion() const
Returns true iff at least one type conversion maps an input type to a type that is different from its...
ValueRange getConvertedValues(ValueRange convertedValues, unsigned originalValueNo) const
Returns the slice of converted values that corresponds to the original value at the given index.
void convertLocation(Value originalValue, unsigned originalValueNo, llvm::SmallVectorImpl< Location > &result) const
Fills the given result vector with as many copies of the location of the original value as the number...
This class helps build Operations.
Definition: Builders.h:215
Listener * listener
The optional listener for events of this builder.
Definition: Builders.h:615
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
Extends the basic RewritePattern class with a type converter member and some accessors to it.
const TypeConverter * getTypeConverter() const
Return the type converter held by this pattern, or nullptr if the pattern does not require type conve...
const TypeConverter *const typeConverter
A type converter for use by this pattern.
std::enable_if_t< std::is_base_of< TypeConverter, ConverterTy >::value, const ConverterTy * > getTypeConverter() const
RewritePatternWithConverter(const TypeConverter &typeConverter, Args &&...args)
Construct a conversion pattern with the given converter, and forward the remaining arguments to Rewri...
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:246
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class provides all of the information necessary to convert a type signature.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
Type conversion class.
LogicalResult convertSignatureArgs(TypeRange types, SignatureConversion &result, unsigned origInputOffset=0) const
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Include the generated interface declarations.
LogicalResult applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter, const FrozenRewritePatternSet &patterns)
Applies the given set of patterns recursively on the given op and adds user materializations where ne...
void populateOneToNFunctionOpInterfaceTypeConversionPattern(StringRef functionLikeOpName, const TypeConverter &converter, RewritePatternSet &patterns)
Add a pattern to the given pattern list to convert the signature of a FunctionOpInterface op with the...
This class represents a listener that may be used to hook into various actions within an OpBuilder.
Definition: Builders.h:293