MLIR  22.0.0git
Pattern.h
Go to the documentation of this file.
1 //===- Pattern.h - Pattern for conversion to the LLVM dialect ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_CONVERSION_LLVMCOMMON_PATTERN_H
10 #define MLIR_CONVERSION_LLVMCOMMON_PATTERN_H
11 
16 
17 namespace mlir {
18 class CallOpInterface;
19 
20 namespace LLVM {
21 namespace detail {
22 /// Generically sets `overflowFlags` as native properties on the LLVM op `op`.
23 void setNativeProperties(Operation *op, IntegerOverflowFlags overflowFlags);
24 
25 /// Replaces the given operation "op" with a new operation of type "targetOp"
26 /// built from the given operands and the attributes in "targetAttrs".
27 LogicalResult oneToOneRewrite(
28  Operation *op, StringRef targetOp, ValueRange operands,
29  ArrayRef<NamedAttribute> targetAttrs,
30  const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
31  IntegerOverflowFlags overflowFlags = IntegerOverflowFlags::none);
32 
33 /// Replaces the given operation "op" with a call to an LLVM intrinsic with the
34 /// specified name "intrinsic", passing "operands" as the call arguments.
35 ///
36 /// The rewrite performs a simple one-to-one matching between the op and LLVM
37 /// intrinsic. For example:
38 ///
39 /// ```mlir
40 /// %res = intr.op %val : vector<16xf32>
41 /// ```
42 ///
43 /// can be converted to
44 ///
45 /// ```mlir
46 /// %res = llvm.call_intrinsic "intrinsic"(%val)
47 /// ```
48 ///
49 /// The provided operands must already be LLVM-compatible.
50 ///
51 /// Upholds a convention that multi-result operations get converted into an
52 /// operation returning the LLVM IR structure type, in which case individual
53 /// values are first extracted before replacing the original results.
54 LogicalResult intrinsicRewrite(Operation *op, StringRef intrinsic,
55  ValueRange operands,
56  const LLVMTypeConverter &typeConverter,
57  RewriterBase &rewriter);
58 
59 } // namespace detail
60 
61 /// Decomposes a `src` value into a set of values of type `dstType` through
62 /// series of bitcasts and vector ops. Src and dst types are expected to be int
63 /// or float types or vector types of them.
65  Type dstType);
66 
67 /// Composes a set of `src` values into a single value of type `dstType` through
68 /// series of bitcasts and vector ops. Inversely to `decomposeValue`, this
69 /// function is used to combine multiple values into a single value.
71  Type dstType);
72 
73 /// Performs the index computation to get to the element at `indices` of the
74 /// memory pointed to by `memRefDesc`, using the layout map of `type`.
75 /// The indices are linearized as:
76 /// `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
78  OpBuilder &builder, Location loc, const LLVMTypeConverter &converter,
79  MemRefType type, Value memRefDesc, ValueRange indices,
80  LLVM::GEPNoWrapFlags noWrapFlags = LLVM::GEPNoWrapFlags::none);
81 } // namespace LLVM
82 
83 /// Base class for operation conversions targeting the LLVM IR dialect. It
84 /// provides the conversion patterns with access to the LLVMTypeConverter and
85 /// the LowerToLLVMOptions. The class captures the LLVMTypeConverter and the
86 /// LowerToLLVMOptions by reference meaning the references have to remain alive
87 /// during the entire pattern lifetime.
89 public:
90  ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context,
92  PatternBenefit benefit = 1);
93 
94 protected:
95  /// See `ConversionPattern::ConversionPattern` for information on the other
96  /// available constructors.
98 
99  /// Returns the LLVM dialect.
100  LLVM::LLVMDialect &getDialect() const;
101 
102  const LLVMTypeConverter *getTypeConverter() const;
103 
104  /// Gets the MLIR type wrapping the LLVM integer type whose bit width is
105  /// defined by the used type converter.
106  Type getIndexType() const;
107 
108  /// Gets the MLIR type wrapping the LLVM integer type whose bit width
109  /// corresponds to that of a LLVM pointer type.
110  Type getIntPtrType(unsigned addressSpace = 0) const;
111 
112  /// Gets the MLIR type wrapping the LLVM void type.
113  Type getVoidType() const;
114 
115  /// Get the MLIR type wrapping the LLVM i8* type.
116  [[deprecated("Use getPtrType() instead!")]]
117  Type getVoidPtrType() const;
118 
119  /// Get the MLIR type wrapping the LLVM ptr type.
120  Type getPtrType(unsigned addressSpace = 0) const;
121 
122  /// Create a constant Op producing a value of `resultType` from an index-typed
123  /// integer attribute.
124  static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
125  Type resultType, int64_t value);
126 
127  /// Convenience wrapper for the corresponding helper utility.
128  /// This is a strided getElementPtr variant with linearized subscripts.
130  ConversionPatternRewriter &rewriter, Location loc, MemRefType type,
131  Value memRefDesc, ValueRange indices,
132  LLVM::GEPNoWrapFlags noWrapFlags = LLVM::GEPNoWrapFlags::none) const;
133 
135  /// Returns whether the given memref type is convertible to LLVM and has an
135  /// identity layout map.
136  bool isConvertibleAndHasIdentityMaps(MemRefType type) const;
137 
138  /// Returns the type of a pointer to an element of the memref.
139  Type getElementPtrType(MemRefType type) const;
140 
141  /// Computes sizes, strides and buffer size of `memRefType` with identity
142  /// layout. Emits constant ops for the static sizes of `memRefType`, and uses
143  /// `dynamicSizes` for the others. Emits instructions to compute strides and
144  /// buffer size from these sizes.
145  ///
146  /// For example, memref<4x?xf32> with `sizeInBytes = true` emits:
147  /// `sizes[0]` = llvm.mlir.constant(4 : index) : i64
148  /// `sizes[1]` = `dynamicSizes[0]`
149  /// `strides[1]` = llvm.mlir.constant(1 : index) : i64
150  /// `strides[0]` = `sizes[0]`
151  /// %size = llvm.mul `sizes[0]`, `sizes[1]` : i64
152  /// %nullptr = llvm.mlir.zero : !llvm.ptr
153  /// %gep = llvm.getelementptr %nullptr[%size]
154  /// : (!llvm.ptr, i64) -> !llvm.ptr, f32
155  /// `sizeBytes` = llvm.ptrtoint %gep : !llvm.ptr to i64
156  ///
157  /// If `sizeInBytes = false`, memref<4x?xf32> emits:
158  /// `sizes[0]` = llvm.mlir.constant(4 : index) : i64
159  /// `sizes[1]` = `dynamicSizes[0]`
160  /// `strides[1]` = llvm.mlir.constant(1 : index) : i64
161  /// `strides[0]` = `sizes[0]`
162  /// %size = llvm.mul `sizes[0]`, `sizes[1]` : i64
163  void getMemRefDescriptorSizes(Location loc, MemRefType memRefType,
164  ValueRange dynamicSizes,
165  ConversionPatternRewriter &rewriter,
166  SmallVectorImpl<Value> &sizes,
167  SmallVectorImpl<Value> &strides, Value &size,
168  bool sizeInBytes = true) const;
169 
170  /// Computes the size of type in bytes.
171  Value getSizeInBytes(Location loc, Type type,
172  ConversionPatternRewriter &rewriter) const;
173 
174  /// Computes total number of elements for the given MemRef and dynamicSizes.
175  Value getNumElements(Location loc, MemRefType memRefType,
176  ValueRange dynamicSizes,
177  ConversionPatternRewriter &rewriter) const;
178 
179  /// Creates and populates a canonical memref descriptor struct.
181  createMemRefDescriptor(Location loc, MemRefType memRefType,
182  Value allocatedPtr, Value alignedPtr,
183  ArrayRef<Value> sizes, ArrayRef<Value> strides,
184  ConversionPatternRewriter &rewriter) const;
185 
186  /// Copies the memory descriptor for any operands that were unranked
187  /// descriptors originally to heap-allocated memory (if toDynamic is true) or
188  /// to stack-allocated memory (otherwise). Also frees the previously used
189  /// memory (that is assumed to be heap-allocated) if toDynamic is false.
190  LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc,
191  TypeRange origTypes,
192  SmallVectorImpl<Value> &operands,
193  bool toDynamic) const;
194 };
195 
196 /// Utility class for operation conversions targeting the LLVM dialect that
197 /// match exactly one source operation.
198 template <typename SourceOp>
200 public:
201  using OpAdaptor = typename SourceOp::Adaptor;
203  typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
204 
206  PatternBenefit benefit = 1)
207  : ConvertToLLVMPattern(SourceOp::getOperationName(),
209  benefit) {}
210 
211  /// Wrappers around the RewritePattern methods that pass the derived op type.
212  LogicalResult
214  ConversionPatternRewriter &rewriter) const final {
215  auto sourceOp = cast<SourceOp>(op);
216  return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
217  }
218  LogicalResult
220  ConversionPatternRewriter &rewriter) const final {
221  auto sourceOp = cast<SourceOp>(op);
222  return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
223  rewriter);
224  }
225 
226  /// Methods that operate on the SourceOp type. One of these must be
227  /// overridden by the derived pattern class.
228  virtual LogicalResult
229  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
230  ConversionPatternRewriter &rewriter) const {
231  llvm_unreachable("matchAndRewrite is not implemented");
232  }
233  virtual LogicalResult
234  matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
235  ConversionPatternRewriter &rewriter) const {
236  SmallVector<Value> oneToOneOperands =
237  getOneToOneAdaptorOperands(adaptor.getOperands());
238  return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
239  }
240 
241 private:
243 };
244 
245 /// Utility class for operation conversions targeting the LLVM dialect that
246 /// allows for matching and rewriting against an instance of an OpInterface
247 /// class.
248 template <typename SourceOp>
250 public:
252  const LLVMTypeConverter &typeConverter, PatternBenefit benefit = 1)
254  SourceOp::getInterfaceID(), benefit,
256 
257  /// Wrappers around the RewritePattern methods that pass the derived op type.
258  LogicalResult
260  ConversionPatternRewriter &rewriter) const final {
261  return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
262  }
263  LogicalResult
265  ConversionPatternRewriter &rewriter) const final {
266  return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
267  }
268 
269  /// Methods that operate on the SourceOp type. One of these must be
270  /// overridden by the derived pattern class.
271  virtual LogicalResult
272  matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
273  ConversionPatternRewriter &rewriter) const {
274  llvm_unreachable("matchAndRewrite is not implemented");
275  }
276  virtual LogicalResult
277  matchAndRewrite(SourceOp op, ArrayRef<ValueRange> operands,
278  ConversionPatternRewriter &rewriter) const {
279  return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
280  }
281 
282 private:
284 };
285 
286 /// Generic implementation of one-to-one conversion from "SourceOp" to
287 /// "TargetOp" where the latter belongs to the LLVM dialect or an equivalent.
288 /// Upholds a convention that multi-result operations get converted into an
289 /// operation returning the LLVM IR structure type, in which case individual
290 /// values must be extracted from using LLVM::ExtractValueOp before being used.
291 template <typename SourceOp, typename TargetOp>
293 public:
296 
297  /// Converts the type of the result to an LLVM type, pass operands as is,
298  /// preserve attributes.
299  LogicalResult
300  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
301  ConversionPatternRewriter &rewriter) const override {
302  return LLVM::detail::oneToOneRewrite(op, TargetOp::getOperationName(),
303  adaptor.getOperands(), op->getAttrs(),
304  *this->getTypeConverter(), rewriter);
305  }
306 };
307 
308 } // namespace mlir
309 
310 #endif // MLIR_CONVERSION_LLVMCOMMON_PATTERN_H
This class implements a pattern rewriter for use with ConversionPatterns.
Base class for the conversion patterns.
SmallVector< Value > getOneToOneAdaptorOperands(ArrayRef< ValueRange > operands) const
Given an array of value ranges, which are the inputs to a 1:N adaptor, try to extract the single valu...
const TypeConverter * typeConverter
An optional type converter for use by this pattern.
ConversionPattern(const TypeConverter &typeConverter, Args &&...args)
Construct a conversion pattern with the given converter, and forward the remaining arguments to Rewri...
virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement combined matching and rewriting.
Utility class for operation conversions targeting the LLVM dialect that allows for matching and rewriting against an instance of an OpInterface class.
Definition: Pattern.h:249
LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Wrappers around the RewritePattern methods that pass the derived op type.
Definition: Pattern.h:259
virtual LogicalResult matchAndRewrite(SourceOp op, ArrayRef< ValueRange > operands, ConversionPatternRewriter &rewriter) const
Definition: Pattern.h:277
ConvertOpInterfaceToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition: Pattern.h:251
virtual LogicalResult matchAndRewrite(SourceOp op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Methods that operate on the SourceOp type.
Definition: Pattern.h:272
LogicalResult matchAndRewrite(Operation *op, ArrayRef< ValueRange > operands, ConversionPatternRewriter &rewriter) const final
Hook for derived classes to implement combined matching and rewriting.
Definition: Pattern.h:264
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
Definition: Pattern.h:199
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition: Pattern.h:205
LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Wrappers around the RewritePattern methods that pass the derived op type.
Definition: Pattern.h:213
virtual LogicalResult matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const
Definition: Pattern.h:234
virtual LogicalResult matchAndRewrite(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const
Methods that operate on the SourceOp type.
Definition: Pattern.h:229
typename SourceOp::template GenericAdaptor< ArrayRef< ValueRange > > OneToNOpAdaptor
Definition: Pattern.h:203
LogicalResult matchAndRewrite(Operation *op, ArrayRef< ValueRange > operands, ConversionPatternRewriter &rewriter) const final
Hook for derived classes to implement combined matching and rewriting.
Definition: Pattern.h:219
typename SourceOp::Adaptor OpAdaptor
Definition: Pattern.h:201
Base class for operation conversions targeting the LLVM IR dialect.
Definition: Pattern.h:88
Type getVoidType() const
Gets the MLIR type wrapping the LLVM void type.
Definition: Pattern.cpp:45
MemRefDescriptor createMemRefDescriptor(Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr, ArrayRef< Value > sizes, ArrayRef< Value > strides, ConversionPatternRewriter &rewriter) const
Creates and populates a canonical memref descriptor struct.
Definition: Pattern.cpp:190
ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context, const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition: Pattern.cpp:22
Value getStridedElementPtr(ConversionPatternRewriter &rewriter, Location loc, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none) const
Convenience wrapper for the corresponding helper utility.
Definition: Pattern.cpp:64
void getMemRefDescriptorSizes(Location loc, MemRefType memRefType, ValueRange dynamicSizes, ConversionPatternRewriter &rewriter, SmallVectorImpl< Value > &sizes, SmallVectorImpl< Value > &strides, Value &size, bool sizeInBytes=true) const
Computes sizes, strides and buffer size of memRefType with identity layout.
Definition: Pattern.cpp:88
Type getPtrType(unsigned addressSpace=0) const
Get the MLIR type wrapping the LLVM ptr type.
Definition: Pattern.cpp:49
Type getIndexType() const
Gets the MLIR type wrapping the LLVM integer type whose bit width is defined by the used type converter.
Definition: Pattern.cpp:36
const LLVMTypeConverter * getTypeConverter() const
Definition: Pattern.cpp:27
Value getNumElements(Location loc, MemRefType memRefType, ValueRange dynamicSizes, ConversionPatternRewriter &rewriter) const
Computes total number of elements for the given MemRef and dynamicSizes.
Definition: Pattern.cpp:158
LLVM::LLVMDialect & getDialect() const
Returns the LLVM dialect.
Definition: Pattern.cpp:32
Value getSizeInBytes(Location loc, Type type, ConversionPatternRewriter &rewriter) const
Computes the size of type in bytes.
Definition: Pattern.cpp:143
Type getIntPtrType(unsigned addressSpace=0) const
Gets the MLIR type wrapping the LLVM integer type whose bit width corresponds to that of an LLVM pointer type.
Definition: Pattern.cpp:40
LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc, TypeRange origTypes, SmallVectorImpl< Value > &operands, bool toDynamic) const
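Copies the memory descriptor for any operands that were unranked descriptors originally to heap-allocated memory (if toDynamic is true) or to stack-allocated memory (otherwise). Also frees the previously used memory (that is assumed to be heap-allocated) if toDynamic is false.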
Definition: Pattern.cpp:219
Type getElementPtrType(MemRefType type) const
Returns the type of a pointer to an element of the memref.
Definition: Pattern.cpp:81
static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value)
Create a constant Op producing a value of resultType from an index-typed integer attribute.
Definition: Pattern.cpp:56
bool isConvertibleAndHasIdentityMaps(MemRefType type) const
Returns whether the given memref type is convertible to LLVM and has an identity layout map.
Definition: Pattern.cpp:74
Type getVoidPtrType() const
Get the MLIR type wrapping the LLVM i8* type.
Definition: Pattern.cpp:54
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Definition: MemRefBuilder.h:33
Generic implementation of one-to-one conversion from "SourceOp" to "TargetOp" where the latter belongs to the LLVM dialect or an equivalent.
Definition: Pattern.h:292
LogicalResult matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override
Converts the type of the result to an LLVM type, pass operands as is, preserve attributes.
Definition: Pattern.h:300
This class helps build Operations.
Definition: Builders.h:205
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:73
MLIRContext * getContext() const
Return the MLIRContext used to create this pattern.
Definition: PatternMatch.h:134
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:358
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, IntegerOverflowFlags overflowFlags=IntegerOverflowFlags::none)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
Definition: Pattern.cpp:318
void setNativeProperties(Operation *op, IntegerOverflowFlags overflowFlags)
Handle generically setting flags as native properties on LLVM operations.
Definition: Pattern.cpp:310
LogicalResult intrinsicRewrite(Operation *op, StringRef intrinsic, ValueRange operands, const LLVMTypeConverter &typeConverter, RewriterBase &rewriter)
Replaces the given operation "op" with a call to an LLVM intrinsic with the specified name "intrinsic" and operands.
Definition: Pattern.cpp:358
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDesc, using the layout map of type.
Definition: Pattern.cpp:489
Value composeValue(OpBuilder &builder, Location loc, ValueRange src, Type dstType)
Composes a set of src values into a single value of type dstType through a series of bitcasts and vector ops.
Definition: Pattern.cpp:450
SmallVector< Value > decomposeValue(OpBuilder &builder, Location loc, Value src, Type dstType)
Decomposes a src value into a set of values of type dstType through a series of bitcasts and vector ops.
Definition: Pattern.cpp:411
Include the generated interface declarations.
This class acts as a special tag that makes the desire to match any operation that implements a given...
Definition: PatternMatch.h:164