MLIR  20.0.0git
OneToNTypeConversion.h
//===-- OneToNTypeConversion.h - Utils for 1:N type conversion --*- C++ -*-===//
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Note: The 1:N dialect conversion is deprecated and will be removed soon.
// 1:N support has been added to the regular dialect conversion driver.
//
// This file provides utils for implementing (poor-man's) dialect conversion
// passes with 1:N type conversions.
//
// The main function, `applyPartialOneToNConversion`, first applies a set of
// `RewritePattern`s, which produce unrealized casts to convert the operands and
// results from and to the source types, and then replaces all newly added
// unrealized casts by user-provided materializations. For this to work, the
// main function requires a special `TypeConverter`, a special
// `PatternRewriter`, and special `RewritePattern`s, which extend their
// respective base classes for 1:N type conversions.
//
// Note that this is much more simple-minded than the "real" dialect conversion,
// which checks for legality before applying patterns and probably does many
// other things. Ideally, some of the extensions here could be integrated there.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_TRANSFORMS_ONETONTYPECONVERSION_H
#define MLIR_TRANSFORMS_ONETONTYPECONVERSION_H

#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"

namespace mlir {

/// Stores a 1:N mapping of types and provides several useful accessors. This
/// class extends `SignatureConversion`, which already supports 1:N type
/// mappings but lacks some accessors into the mapping as well as access to the
/// original types.
class OneToNTypeMapping : public TypeConverter::SignatureConversion {
public:
  OneToNTypeMapping(TypeRange originalTypes)
      : TypeConverter::SignatureConversion(originalTypes.size()),
        originalTypes(originalTypes) {}

  using TypeConverter::SignatureConversion::getConvertedTypes;

  /// Returns the list of types that corresponds to the original type at the
  /// given index.
  TypeRange getConvertedTypes(unsigned originalTypeNo) const;

  /// Returns the list of original types.
  TypeRange getOriginalTypes() const { return originalTypes; }

  /// Returns the slice of converted values that corresponds to the original
  /// value at the given index.
  ValueRange getConvertedValues(ValueRange convertedValues,
                                unsigned originalValueNo) const;

  /// Fills the given result vector with as many copies of the location of the
  /// original value as the number of values it is converted to.
  void convertLocation(Value originalValue, unsigned originalValueNo,
                       llvm::SmallVectorImpl<Location> &result) const;

  /// Fills the given result vector with as many copies of the location of each
  /// original value as the number of values they are respectively converted to.
  void convertLocations(ValueRange originalValues,
                        llvm::SmallVectorImpl<Location> &result) const;

  /// Returns true iff at least one type conversion maps an input type to a type
  /// that is different from itself.
  bool hasNonIdentityConversion() const;

private:
  llvm::SmallVector<Type> originalTypes;
};
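The accessors on `OneToNTypeMapping` all reduce to index arithmetic over a 1:N mapping. The sketch below is a minimal, MLIR-free model of that arithmetic. It assumes the converted types are kept as one flat list with a per-original offset and count. That layout is an illustrative assumption, not a statement about how `SignatureConversion` stores its data. `TypeMappingModel` and its methods are hypothetical, and types, values, and locations are plain strings or integers.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// Hypothetical, MLIR-free model of the bookkeeping behind OneToNTypeMapping.
// Types and locations are plain strings; this is NOT the MLIR API.
class TypeMappingModel {
public:
  explicit TypeMappingModel(std::vector<std::string> originalTypes)
      : originalTypes(std::move(originalTypes)) {}

  // Records the converted types of the next original type. Assumes it is
  // called exactly once per original type, in order.
  void addInputs(const std::vector<std::string> &types) {
    offsets.push_back(convertedTypes.size());
    sizes.push_back(types.size());
    convertedTypes.insert(convertedTypes.end(), types.begin(), types.end());
  }

  // Converted types of original type `i` (a copy of a slice of the flat list,
  // whereas the real accessor returns a non-owning TypeRange).
  std::vector<std::string> getConvertedTypes(std::size_t i) const {
    auto first = convertedTypes.begin() + offsets[i];
    return std::vector<std::string>(first, first + sizes[i]);
  }

  // Slice of a flat list of converted values belonging to original value `i`.
  template <typename T>
  std::vector<T> getConvertedValues(const std::vector<T> &flat,
                                    std::size_t i) const {
    auto first = flat.begin() + offsets[i];
    return std::vector<T>(first, first + sizes[i]);
  }

  // One copy of each original location per value it is converted to.
  std::vector<std::string>
  convertLocations(const std::vector<std::string> &locs) const {
    std::vector<std::string> result;
    for (std::size_t i = 0; i < locs.size(); ++i)
      result.insert(result.end(), sizes[i], locs[i]);
    return result;
  }

  // True iff some original type maps to anything other than exactly itself.
  bool hasNonIdentityConversion() const {
    for (std::size_t i = 0; i < originalTypes.size(); ++i)
      if (sizes[i] != 1 || convertedTypes[offsets[i]] != originalTypes[i])
        return true;
    return false;
  }

private:
  std::vector<std::string> originalTypes;
  std::vector<std::string> convertedTypes; // all converted types, flattened
  std::vector<std::size_t> offsets, sizes; // per original type
};
```

With `i32 -> i32` and `tuple<i1,f32> -> (i1, f32)`, original value 1 owns flat positions 1 and 2. Its location is therefore repeated twice, and `hasNonIdentityConversion()` returns true.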

/// Extends the basic `RewritePattern` class with a type converter member and
/// some accessors to it. This is useful for patterns that are not
/// `ConversionPattern`s but still require access to a type converter.
class RewritePatternWithConverter : public mlir::RewritePattern {
public:
  /// Construct a conversion pattern with the given converter, and forward the
  /// remaining arguments to RewritePattern.
  template <typename... Args>
  RewritePatternWithConverter(const TypeConverter &typeConverter,
                              Args &&...args)
      : RewritePattern(std::forward<Args>(args)...),
        typeConverter(&typeConverter) {}

  /// Return the type converter held by this pattern, or nullptr if the pattern
  /// does not require type conversion.
  const TypeConverter *getTypeConverter() const { return typeConverter; }

  template <typename ConverterTy>
  std::enable_if_t<std::is_base_of<TypeConverter, ConverterTy>::value,
                   const ConverterTy *>
  getTypeConverter() const {
    return static_cast<const ConverterTy *>(typeConverter);
  }

protected:
  /// A type converter for use by this pattern.
  const TypeConverter *const typeConverter;
};
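The second `getTypeConverter` overload uses `std::enable_if_t` so that it can be instantiated only with a type derived from `TypeConverter`. It then performs an unchecked `static_cast`. The sketch below reproduces that pattern in standard C++ with hypothetical stand-in classes (`ConverterBase`, `MyConverter`, `PatternWithConverter`) instead of MLIR types.

```cpp
#include <cassert>
#include <string>
#include <type_traits>

// Hypothetical stand-ins for a converter hierarchy; not MLIR classes.
struct ConverterBase {
  virtual ~ConverterBase() = default;
};
struct MyConverter : ConverterBase {
  std::string name = "my-converter";
};

// Holds a non-owning pointer to a converter, like RewritePatternWithConverter.
class PatternWithConverter {
public:
  explicit PatternWithConverter(const ConverterBase &converter)
      : converter(&converter) {}

  // Untyped accessor.
  const ConverterBase *getConverter() const { return converter; }

  // Typed accessor, only instantiable for classes derived from ConverterBase.
  // The static_cast is unchecked: the caller must know the dynamic type.
  template <typename ConverterTy>
  std::enable_if_t<std::is_base_of<ConverterBase, ConverterTy>::value,
                   const ConverterTy *>
  getConverter() const {
    return static_cast<const ConverterTy *>(converter);
  }

protected:
  const ConverterBase *const converter;
};
```

The constraint makes a call such as `getConverter<int>()` fail to compile rather than silently cast. The cast itself is still unchecked, so the caller remains responsible for knowing the converter's dynamic type.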

/// Specialization of `PatternRewriter` that `OneToNConversionPattern`s use. The
/// class provides additional rewrite methods that are specific to 1:N type
/// conversions.
class OneToNPatternRewriter : public PatternRewriter {
public:
  OneToNPatternRewriter(MLIRContext *context,
                        OpBuilder::Listener *listener = nullptr)
      : PatternRewriter(context, listener) {}

  /// Replaces the results of the operation with the specified list of values
  /// mapped back to the original types as specified in the provided type
  /// mapping. That type mapping must match the replaced op (i.e., the original
  /// types must be the same as the result types of the op) and the new values
  /// (i.e., the converted types must be the same as the types of the new
  /// values).
  /// FIXME: The 1:N dialect conversion is deprecated and will be removed soon.
  /// Use replaceOpWithMultiple() instead.
  void replaceOp(Operation *op, ValueRange newValues,
                 const OneToNTypeMapping &resultMapping);
  using PatternRewriter::replaceOp;

  /// Applies the given argument conversion to the given block. This consists of
  /// replacing each original argument with N arguments as specified in the
  /// argument conversion and inserting unrealized casts from the converted
  /// values to the original types, which are then used in lieu of the original
  /// ones. (Eventually, `applyPartialOneToNConversion` replaces these casts
  /// with a user-provided argument materialization if necessary.) This is
  /// similar to `ArgConverter::applySignatureConversion` but (1) handles 1:N
  /// type conversion properly and (2) probably doesn't handle many other edge
  /// cases.
  Block *applySignatureConversion(Block *block,
                                  OneToNTypeMapping &argumentConversion);
};
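`replaceOp(op, newValues, resultMapping)` expects the flat `newValues` to be grouped by `resultMapping`, with each group standing in for one original result of its original type. The sketch below models that regrouping in plain C++. `Val` and `mapBackToOriginal` are hypothetical, and the `cast(...)` name only stands in for an unrealized cast. Reusing a value directly when its conversion is 1:1 and type-preserving is a simplification made by this sketch.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical model of an SSA value: a name plus a type spelled as a string.
struct Val {
  std::string name, type;
};

// `groupSizes[i]` consecutive entries of `newValues` belong to original result
// i. Returns one replacement per original result: the new value itself when
// the conversion is 1:1 and type-preserving, otherwise a placeholder value of
// the original type that stands in for an unrealized cast of the whole group.
std::vector<Val> mapBackToOriginal(const std::vector<std::string> &originalTypes,
                                   const std::vector<std::size_t> &groupSizes,
                                   const std::vector<Val> &newValues) {
  assert(originalTypes.size() == groupSizes.size());
  std::vector<Val> replacements;
  std::size_t offset = 0;
  for (std::size_t i = 0; i < originalTypes.size(); ++i) {
    std::size_t n = groupSizes[i];
    assert(offset + n <= newValues.size() && "too few new values");
    if (n == 1 && newValues[offset].type == originalTypes[i]) {
      replacements.push_back(newValues[offset]); // identity: reuse directly
    } else {
      std::string name = "cast(";
      for (std::size_t j = 0; j < n; ++j)
        name += (j == 0 ? "" : ",") + newValues[offset + j].name;
      replacements.push_back({name + ")", originalTypes[i]});
    }
    offset += n;
  }
  assert(offset == newValues.size() && "mapping must cover every new value");
  return replacements;
}
```

Take results `(i32, tuple<i1,f32>)` converted to `(i32)` and `(i1, f32)`. The first result is replaced by `%a` directly. The second is replaced by a single placeholder built from `%b` and `%c`.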

/// Base class for patterns with 1:N type conversions. Derived classes have to
/// override the `matchAndRewrite` overload that provides additional
/// information for 1:N type conversions.
class OneToNConversionPattern : public RewritePatternWithConverter {
public:
  using RewritePatternWithConverter::RewritePatternWithConverter;

  /// This function has to be implemented by derived classes and is called from
  /// the usual overloads. Like in "normal" `DialectConversion`, the function is
  /// provided with the converted operands (which thus have target types). Since
  /// 1:N conversions are supported, there is usually no 1:1 relationship
  /// between the original and the converted operands. Instead, the provided
  /// `operandMapping` can be used to access the converted operands that
  /// correspond to a particular original operand. Similarly, `resultMapping`
  /// is provided to help with assembling the result values, which may have 1:N
  /// correspondences as well. In that case, the original op should be replaced
  /// with the overload of `replaceOp` that takes the provided `resultMapping`
  /// in order to deal with the mapping of converted result values to their
  /// usages in the original types correctly.
  virtual LogicalResult matchAndRewrite(Operation *op,
                                        OneToNPatternRewriter &rewriter,
                                        const OneToNTypeMapping &operandMapping,
                                        const OneToNTypeMapping &resultMapping,
                                        ValueRange convertedOperands) const = 0;

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const final;
};
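The `final` overload that takes a plain `PatternRewriter` and the pure virtual 1:N overload together form a template-method design. The driver only ever calls the generic entry point, which prepares the converted operands and delegates to the hook that concrete patterns implement. The sketch below models that split with hypothetical classes (`PatternBase`, `OneToNPatternModel`, `SplitsSomethingPattern`) and uses a `std::function` as a stand-in type converter. It reproduces only this structure, not the full behavior of the real `matchAndRewrite`.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for a type converter: one type -> zero or more types.
using ConvertFn =
    std::function<std::vector<std::string>(const std::string &)>;

// Hypothetical generic pattern interface that a driver calls.
struct PatternBase {
  virtual ~PatternBase() = default;
  virtual bool
  matchAndRewrite(const std::vector<std::string> &operandTypes) const = 0;
};

class OneToNPatternModel : public PatternBase {
public:
  explicit OneToNPatternModel(ConvertFn convert)
      : convert(std::move(convert)) {}

  // Generic entry point, sealed with `final`: converts every operand type,
  // records how many types each original operand became, and delegates.
  bool matchAndRewrite(
      const std::vector<std::string> &operandTypes) const final {
    std::vector<std::string> flat;
    std::vector<std::size_t> groupSizes;
    for (const std::string &type : operandTypes) {
      std::vector<std::string> converted = convert(type);
      groupSizes.push_back(converted.size());
      flat.insert(flat.end(), converted.begin(), converted.end());
    }
    return matchAndRewriteOneToN(flat, groupSizes);
  }

protected:
  // 1:N hook that concrete patterns implement.
  virtual bool
  matchAndRewriteOneToN(const std::vector<std::string> &convertedTypes,
                        const std::vector<std::size_t> &groupSizes) const = 0;

private:
  ConvertFn convert;
};

// Example concrete pattern: "matches" only if the conversion was not 1:1 for
// at least one operand (some operand was split or dropped).
class SplitsSomethingPattern : public OneToNPatternModel {
public:
  using OneToNPatternModel::OneToNPatternModel;

protected:
  bool matchAndRewriteOneToN(
      const std::vector<std::string> & /*convertedTypes*/,
      const std::vector<std::size_t> &groupSizes) const override {
    for (std::size_t n : groupSizes)
      if (n != 1)
        return true;
    return false;
  }
};
```

Because the generic overload is `final`, a concrete pattern cannot bypass the preparation step. It can only customize the hook.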

/// This class is a wrapper around `OneToNConversionPattern` for matching
/// against instances of a particular op class.
template <typename SourceOp>
class OneToNOpConversionPattern : public OneToNConversionPattern {
public:
  OneToNOpConversionPattern(const TypeConverter &typeConverter,
                            MLIRContext *context, PatternBenefit benefit = 1,
                            ArrayRef<StringRef> generatedNames = {})
      : OneToNConversionPattern(typeConverter, SourceOp::getOperationName(),
                                benefit, context, generatedNames) {}
  /// Generic adaptor around the root op of this pattern using the converted
  /// operands. Importantly, each operand is represented as a *range* of values,
  /// namely the N values each original operand gets converted to. Concretely,
  /// this makes the result type of the accessor functions of the adaptor class
  /// be a `ValueRange`.
  class OpAdaptor
      : public SourceOp::template GenericAdaptor<ArrayRef<ValueRange>> {
  public:
    using RangeT = ArrayRef<ValueRange>;
    using BaseT = typename SourceOp::template GenericAdaptor<RangeT>;
    using Properties = typename SourceOp::template InferredProperties<SourceOp>;

    OpAdaptor(const OneToNTypeMapping *operandMapping,
              const OneToNTypeMapping *resultMapping,
              const ValueRange *convertedOperands, RangeT values, SourceOp op)
        : BaseT(values, op), operandMapping(operandMapping),
          resultMapping(resultMapping), convertedOperands(convertedOperands) {}

    /// Get the type mapping of the original operands to the converted operands.
    const OneToNTypeMapping &getOperandMapping() const {
      return *operandMapping;
    }

    /// Get the type mapping of the original results to the converted results.
    const OneToNTypeMapping &getResultMapping() const { return *resultMapping; }

    /// Get a flat range of all converted operands. Unlike `getOperands`, which
    /// returns an `ArrayRef` with one `ValueRange` for each original operand,
    /// this function returns a `ValueRange` that contains all converted
    /// operands irrespective of which operand they originated from.
    ValueRange getFlatOperands() const { return *convertedOperands; }

  private:
    const OneToNTypeMapping *operandMapping;
    const OneToNTypeMapping *resultMapping;
    const ValueRange *convertedOperands;
  };

  using OneToNConversionPattern::matchAndRewrite;

  /// Overload that derived classes have to override for their op type.
  virtual LogicalResult
  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
                  OneToNPatternRewriter &rewriter) const = 0;

  LogicalResult matchAndRewrite(Operation *op, OneToNPatternRewriter &rewriter,
                                const OneToNTypeMapping &operandMapping,
                                const OneToNTypeMapping &resultMapping,
                                ValueRange convertedOperands) const final {
    // Wrap converted operands and type mappings into an adaptor.
    SmallVector<ValueRange> valueRanges;
    for (unsigned i = 0, e = op->getNumOperands(); i < e; ++i) {
      auto values = operandMapping.getConvertedValues(convertedOperands, i);
      valueRanges.push_back(values);
    }
    OpAdaptor adaptor(&operandMapping, &resultMapping, &convertedOperands,
                      valueRanges, cast<SourceOp>(op));

    // Call overload implemented by the derived class.
    return matchAndRewrite(cast<SourceOp>(op), adaptor, rewriter);
  }
};
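The loop in the `final` overload slices the flat converted operands into one range per original operand. `OpAdaptor` then exposes those ranges through the op's generated accessors. The sketch below mimics this for a hypothetical two-operand op `my.add(lhs, rhs)`. `RangeView`, `groupOperands`, and `AddOpAdaptorModel` are illustrative names, not MLIR API.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical non-owning view, standing in for mlir::ValueRange.
struct RangeView {
  const std::string *data;
  std::size_t size;
  const std::string &operator[](std::size_t i) const { return data[i]; }
};

// Splits the flat converted operands into one view per original operand,
// where original operand i was converted into groupSizes[i] values.
std::vector<RangeView> groupOperands(const std::vector<std::string> &flat,
                                     const std::vector<std::size_t> &groupSizes) {
  std::vector<RangeView> groups;
  std::size_t offset = 0;
  for (std::size_t n : groupSizes) {
    assert(offset + n <= flat.size() && "group sizes exceed operand count");
    groups.push_back({flat.data() + offset, n});
    offset += n;
  }
  return groups;
}

// Hypothetical adaptor for a two-operand op "my.add"(lhs, rhs): each accessor
// returns the *range* of converted values for that original operand.
class AddOpAdaptorModel {
public:
  AddOpAdaptorModel(const std::vector<std::string> &flat,
                    const std::vector<std::size_t> &groupSizes)
      : flat(&flat), groups(groupOperands(flat, groupSizes)) {
    assert(groups.size() == 2 && "my.add has exactly two operands");
  }
  RangeView getLhs() const { return groups[0]; }
  RangeView getRhs() const { return groups[1]; }
  // All converted operands, regardless of which original operand they came from.
  const std::vector<std::string> &getFlatOperands() const { return *flat; }

private:
  const std::vector<std::string> *flat; // borrowed; must outlive the adaptor
  std::vector<RangeView> groups;
};
```

`OpAdaptor` stores pointers to the mappings and to `convertedOperands`. In the same way, the views here borrow from the flat vector, so that vector must outlive the adaptor.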

/// Applies the given set of patterns recursively on the given op and adds user
/// materializations where necessary. The patterns are expected to be
/// `OneToNConversionPattern`s, which help convert the types of the operands
/// and results of the matched ops. The provided type converter is used to
/// convert the operands of matched ops from their original types to operands
/// with different types. Unlike in `DialectConversion`, this supports 1:N type
/// conversions. For conversions at the "boundary" of the pattern application,
/// where converted results are not consumed by replaced ops that expect the
/// converted operands or vice versa, the function inserts user materializations
/// from the type converter. Also unlike `DialectConversion`, there are no legal
/// or illegal types; the function simply applies the given patterns and does
/// not fail if some ops or types remain unconverted (i.e., the conversion is
/// only "partial").
/// FIXME: The 1:N dialect conversion is deprecated and will be removed soon.
/// 1:N support has been added to the regular dialect conversion driver.
/// Use applyPartialConversion() instead.
LogicalResult
applyPartialOneToNConversion(Operation *op, TypeConverter &typeConverter,
                             const FrozenRewritePatternSet &patterns);
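The comment on `applyPartialOneToNConversion` describes where user materializations are needed. They are inserted only on edges where a converted value meets an unconverted op, in either direction. The sketch below states that rule as a decision over one producer/user edge and applies it to a straight-line chain of ops. `Fixup`, `classifyEdge`, and `fixupsForChain` are hypothetical. Mapping the two boundary cases to the type converter's source and target materializations is an assumption based on the usual MLIR meaning of those terms.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical classification of one producer -> user edge after the patterns
// have run, following the boundary rule described above.
enum class Fixup {
  None,                  // neither op was converted: nothing to do
  Interior,              // both converted: no user materialization needed
  SourceMaterialization, // converted producer, unconverted user: N -> 1
  TargetMaterialization  // unconverted producer, converted user: 1 -> N
};

Fixup classifyEdge(bool producerConverted, bool userConverted) {
  if (producerConverted && userConverted)
    return Fixup::Interior;
  if (producerConverted)
    return Fixup::SourceMaterialization;
  if (userConverted)
    return Fixup::TargetMaterialization;
  return Fixup::None;
}

// Applies the rule to a straight-line chain op0 -> op1 -> ..., where
// converted[k] says whether a pattern rewrote op k.
std::vector<Fixup> fixupsForChain(const std::vector<bool> &converted) {
  std::vector<Fixup> fixups;
  for (std::size_t k = 0; k + 1 < converted.size(); ++k)
    fixups.push_back(classifyEdge(converted[k], converted[k + 1]));
  return fixups;
}
```

Suppose only the two middle ops of a four-op chain were rewritten. The first edge needs a target materialization, the middle edge needs none, and the last edge needs a source materialization.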

/// Add a pattern to the given pattern list to convert the signature of a
/// FunctionOpInterface op with the given type converter. This only supports
/// ops which use FunctionType to represent their type. This is intended to be
/// used with the 1:N dialect conversion.
/// FIXME: The 1:N dialect conversion is deprecated and will be removed soon.
/// 1:N support has been added to the regular dialect conversion driver.
/// Use populateFunctionOpInterfaceTypeConversionPattern() instead.
void populateOneToNFunctionOpInterfaceTypeConversionPattern(
    StringRef functionLikeOpName, const TypeConverter &converter,
    RewritePatternSet &patterns);
template <typename FuncOpT>
void populateOneToNFunctionOpInterfaceTypeConversionPattern(
    const TypeConverter &converter, RewritePatternSet &patterns) {
  populateOneToNFunctionOpInterfaceTypeConversionPattern(
      FuncOpT::getOperationName(), converter, patterns);
}
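The templated overload lets callers name the op class instead of spelling its operation name. It simply forwards `FuncOpT::getOperationName()` to the string-based overload. The sketch below shows the same forwarding shape with hypothetical stand-ins (`MyFuncOp`, `PatternList`, `populateSignaturePattern`) rather than MLIR types.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-ins for an op class and a pattern list; not MLIR types.
struct MyFuncOp {
  static std::string getOperationName() { return "my.func"; }
};
using PatternList = std::vector<std::string>;

// Name-based overload: records which op the signature pattern targets.
void populateSignaturePattern(const std::string &opName, PatternList &patterns) {
  patterns.push_back("convert-signature:" + opName);
}

// Typed convenience overload: forwards the op's static name.
template <typename FuncOpT>
void populateSignaturePattern(PatternList &patterns) {
  populateSignaturePattern(FuncOpT::getOperationName(), patterns);
}
```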

} // namespace mlir

#endif // MLIR_TRANSFORMS_ONETONTYPECONVERSION_H