1 //===-- OneToNTypeConversion.h - Utils for 1:N type conversion --*- C++ -*-===//
2 //
3 // Licensed under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file provides utils for implementing (poor-man's) dialect conversion
10 // passes with 1:N type conversions.
11 //
12 // The main function, `applyPartialOneToNConversion`, first applies a set of
13 // `RewritePattern`s, which produce unrealized casts to convert the operands and
14 // results from and to the source types, and then replaces all newly added
15 // unrealized casts by user-provided materializations. For this to work, the
16 // main function requires a special `TypeConverter`, a special
17 // `PatternRewriter`, and special `RewritePattern`s, which extend their
18 // respective base classes for 1:N type conversions.
19 //
20 // Note that this is much more simple-minded than the "real" dialect conversion,
21 // which checks for legality before applying patterns and probably does many
22 // other things. Ideally, some of the extensions here could be
23 // integrated there.
24 //
25 //===----------------------------------------------------------------------===//
30 #include "mlir/IR/PatternMatch.h"
32 #include "llvm/ADT/SmallVector.h"
34 namespace mlir {
36 /// Extends `TypeConverter` with 1:N target materializations. Such
37 /// materializations are needed for 1:N type conversions: they have to
38 /// materialize one value with a source type as N values with target types
39 /// (which isn't possible in the base class currently).
41 public:
42  /// Callback that expresses user-provided materialization logic from the given
43  /// value to N values of the given types. This is useful for expressing target
44  /// materializations for 1:N type conversions, which materialize one value in
45  /// a source type as N values in target types.
47  std::function<std::optional<SmallVector<Value>>(OpBuilder &, TypeRange,
48  Value, Location)>;
50  /// Creates the mapping of the given range of original types to target types
51  /// of the conversion and stores that mapping in the given (signature)
52  /// conversion. This function simply calls
53  /// `TypeConverter::convertSignatureArgs` and exists here with a different
54  /// name to reflect the broader semantic.
55  LogicalResult computeTypeMapping(TypeRange types,
56  SignatureConversion &result) const {
57  return convertSignatureArgs(types, result);
58  }
60  /// Applies one of the user-provided 1:N target materializations. If several
61  /// exist, they are tried out in the reverse order in which they have been
62  /// added until the first one succeeds. If none succeeds, the function
63  /// returns `std::nullopt`.
64  std::optional<SmallVector<Value>>
66  TypeRange resultTypes, Value input) const;
68  /// Adds a 1:N target materialization to the converter. Such materializations
69  /// build IR that converts 1 value of the source type into N values with
70  /// target types.
72  oneToNTargetMaterializations.emplace_back(std::move(callback));
73  }
75 private:
76  SmallVector<OneToNMaterializationCallbackFn> oneToNTargetMaterializations;
77 };
79 /// Stores a 1:N mapping of types and provides several useful accessors. This
80 /// class extends `SignatureConversion`, which already supports 1:N type
81 /// mappings but lacks some accessors into the mapping as well as access to the
82 /// original types.
84 public:
86  : TypeConverter::SignatureConversion(originalTypes.size()),
87  originalTypes(originalTypes) {}
91  /// Returns the list of types that corresponds to the original type at the
92  /// given index.
93  TypeRange getConvertedTypes(unsigned originalTypeNo) const;
95  /// Returns the list of original types.
96  TypeRange getOriginalTypes() const { return originalTypes; }
98  /// Returns the slice of converted values that corresponds the original value
99  /// at the given index.
100  ValueRange getConvertedValues(ValueRange convertedValues,
101  unsigned originalValueNo) const;
103  /// Fills the given result vector with as many copies of the location of the
104  /// original value as the number of values it is converted to.
105  void convertLocation(Value originalValue, unsigned originalValueNo,
106  llvm::SmallVectorImpl<Location> &result) const;
108  /// Fills the given result vector with as many copies of the lociation of each
109  /// original value as the number of values they are respectively converted to.
110  void convertLocations(ValueRange originalValues,
111  llvm::SmallVectorImpl<Location> &result) const;
113  /// Returns true iff at least one type conversion maps an input type to a type
114  /// that is different from itself.
115  bool hasNonIdentityConversion() const;
117 private:
118  llvm::SmallVector<Type> originalTypes;
119 };
121 /// Extends the basic `RewritePattern` class with a type converter member and
122 /// some accessors to it. This is useful for patterns that are not
123 /// `ConversionPattern`s but still require access to a type converter.
125 public:
126  /// Construct a conversion pattern with the given converter, and forward the
127  /// remaining arguments to RewritePattern.
128  template <typename... Args>
130  Args &&...args)
131  : RewritePattern(std::forward<Args>(args)...),
134  /// Return the type converter held by this pattern, or nullptr if the pattern
135  /// does not require type conversion.
136  const TypeConverter *getTypeConverter() const { return typeConverter; }
138  template <typename ConverterTy>
139  std::enable_if_t<std::is_base_of<TypeConverter, ConverterTy>::value,
140  const ConverterTy *>
142  return static_cast<const ConverterTy *>(typeConverter);
143  }
145 protected:
146  /// A type converter for use by this pattern.
148 };
150 /// Specialization of `PatternRewriter` that `OneToNConversionPattern`s use. The
151 /// class provides additional rewrite methods that are specific to 1:N type
152 /// conversions.
154 public:
156  OpBuilder::Listener *listener = nullptr)
159  /// Replaces the results of the operation with the specified list of values
160  /// mapped back to the original types as specified in the provided type
161  /// mapping. That type mapping must match the replaced op (i.e., the original
162  /// types must be the same as the result types of the op) and the new values
163  /// (i.e., the converted types must be the same as the types of the new
164  /// values).
165  void replaceOp(Operation *op, ValueRange newValues,
166  const OneToNTypeMapping &resultMapping);
169  /// Applies the given argument conversion to the given block. This consists of
170  /// replacing each original argument with N arguments as specified in the
171  /// argument conversion and inserting unrealized casts from the converted
172  /// values to the original types, which are then used in lieu of the original
173  /// ones. (Eventually, `applyPartialOneToNConversion` replaces these casts
174  /// with a user-provided argument materialization if necessary.) This is
175  /// similar to `ArgConverter::applySignatureConversion` but (1) handles 1:N
176  /// type conversion properly and probably (2) doesn't handle many other edge
177  /// cases.
179  OneToNTypeMapping &argumentConversion);
180 };
182 /// Base class for patterns with 1:N type conversions. Derived classes have to
183 /// override the `matchAndRewrite` overload that provides additional
184 /// information for 1:N type conversions.
186 public:
189  /// This function has to be implemented by derived classes and is called from
190  /// the usual overloads. Like in "normal" `DialectConversion`, the function is
191  /// provided with the converted operands (which thus have target types). Since
192  /// 1:N conversions are supported, there is usually no 1:1 relationship
193  /// between the original and the converted operands. Instead, the provided
194  /// `operandMapping` can be used to access the converted operands that
195  /// correspond to a particular original operand. Similarly, `resultMapping`
196  /// is provided to help with assembling the result values, which may have 1:N
197  /// correspondences as well. In that case, the original op should be replaced
198  /// with the overload of `replaceOp` that takes the provided `resultMapping`
199  /// in order to deal with the mapping of converted result values to their
200  /// usages in the original types correctly.
201  virtual LogicalResult matchAndRewrite(Operation *op,
202  OneToNPatternRewriter &rewriter,
203  const OneToNTypeMapping &operandMapping,
204  const OneToNTypeMapping &resultMapping,
205  ValueRange convertedOperands) const = 0;
207  LogicalResult matchAndRewrite(Operation *op,
208  PatternRewriter &rewriter) const final;
209 };
211 /// This class is a wrapper around `OneToNConversionPattern` for matching
212 /// against instances of a particular op class.
213 template <typename SourceOp>
215 public:
217  MLIRContext *context, PatternBenefit benefit = 1,
218  ArrayRef<StringRef> generatedNames = {})
219  : OneToNConversionPattern(typeConverter, SourceOp::getOperationName(),
220  benefit, context, generatedNames) {}
221  /// Generic adaptor around the root op of this pattern using the converted
222  /// operands. Importantly, each operand is represented as a *range* of values,
223  /// namely the N values each original operand gets converted to. Concretely,
224  /// this makes the result type of the accessor functions of the adaptor class
225  /// be a `ValueRange`.
226  class OpAdaptor
227  : public SourceOp::template GenericAdaptor<ArrayRef<ValueRange>> {
228  public:
230  using BaseT = typename SourceOp::template GenericAdaptor<RangeT>;
231  using Properties = typename SourceOp::template InferredProperties<SourceOp>;
233  OpAdaptor(const OneToNTypeMapping *operandMapping,
234  const OneToNTypeMapping *resultMapping,
235  const ValueRange *convertedOperands, RangeT values, SourceOp op)
236  : BaseT(values, op), operandMapping(operandMapping),
237  resultMapping(resultMapping), convertedOperands(convertedOperands) {}
239  /// Get the type mapping of the original operands to the converted operands.
241  return *operandMapping;
242  }
244  /// Get the type mapping of the original results to the converted results.
245  const OneToNTypeMapping &getResultMapping() const { return *resultMapping; }
247  /// Get a flat range of all converted operands. Unlike `getOperands`, which
248  /// returns an `ArrayRef` with one `ValueRange` for each original operand,
249  /// this function returns a `ValueRange` that contains all converted
250  /// operands irrespectively of which operand they originated from.
251  ValueRange getFlatOperands() const { return *convertedOperands; }
253  private:
254  const OneToNTypeMapping *operandMapping;
255  const OneToNTypeMapping *resultMapping;
256  const ValueRange *convertedOperands;
257  };
261  /// Overload that derived classes have to override for their op type.
262  virtual LogicalResult
263  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
264  OneToNPatternRewriter &rewriter) const = 0;
266  LogicalResult matchAndRewrite(Operation *op, OneToNPatternRewriter &rewriter,
267  const OneToNTypeMapping &operandMapping,
268  const OneToNTypeMapping &resultMapping,
269  ValueRange convertedOperands) const final {
270  // Wrap converted operands and type mappings into an adaptor.
271  SmallVector<ValueRange> valueRanges;
272  for (int64_t i = 0; i < op->getNumOperands(); i++) {
273  auto values = operandMapping.getConvertedValues(convertedOperands, i);
274  valueRanges.push_back(values);
275  }
276  OpAdaptor adaptor(&operandMapping, &resultMapping, &convertedOperands,
277  valueRanges, cast<SourceOp>(op));
279  // Call overload implemented by the derived class.
280  return matchAndRewrite(cast<SourceOp>(op), adaptor, rewriter);
281  }
282 };
284 /// Applies the given set of patterns recursively on the given op and adds user
285 /// materializations where necessary. The patterns are expected to be
286 /// `OneToNConversionPattern`, which help converting the types of the operands
287 /// and results of the matched ops. The provided type converter is used to
288 /// convert the operands of matched ops from their original types to operands
289 /// with different types. Unlike in `DialectConversion`, this supports 1:N type
290 /// conversions. Those conversions at the "boundary" of the pattern application,
291 /// where converted results are not consumed by replaced ops that expect the
292 /// converted operands or vice versa, the function inserts user materializations
293 /// from the type converter. Also unlike `DialectConversion`, there are no legal
294 /// or illegal types; the function simply applies the given patterns and does
295 /// not fail if some ops or types remain unconverted (i.e., the conversion is
296 /// only "partial").
297 LogicalResult
298 applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
299  const FrozenRewritePatternSet &patterns);
301 /// Add a pattern to the given pattern list to convert the signature of a
302 /// FunctionOpInterface op with the given type converter. This only supports
303 /// ops which use FunctionType to represent their type. This is intended to be
304 /// used with the 1:N dialect conversion.
306  StringRef functionLikeOpName, const TypeConverter &converter,
307  RewritePatternSet &patterns);
308 template <typename FuncOpT>
310  const TypeConverter &converter, RewritePatternSet &patterns) {
312  FuncOpT::getOperationName(), converter, patterns);
313 }
315 } // namespace mlir
