MLIR 23.0.0git
Pattern.h
Go to the documentation of this file.
1//===- Pattern.h - Pattern for conversion to the LLVM dialect ---*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_CONVERSION_LLVMCOMMON_PATTERN_H
10#define MLIR_CONVERSION_LLVMCOMMON_PATTERN_H
11
16
17namespace mlir {
18class CallOpInterface;
19
20namespace LLVM {
21namespace detail {
22/// Replaces the given operation "op" with a new operation of type "targetOp"
23/// and given operands.
24LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp,
25 ValueRange operands,
26 ArrayRef<NamedAttribute> targetAttrs,
27 Attribute propertiesAttr,
28 const LLVMTypeConverter &typeConverter,
29 ConversionPatternRewriter &rewriter);
30
31/// Replaces the given operation "op" with a call to an LLVM intrinsic with the
32/// specified name "intrinsic" and operands.
33///
34/// The rewrite performs a simple one-to-one matching between the op and LLVM
35/// intrinsic. For example:
36///
37/// ```mlir
38/// %res = intr.op %val : vector<16xf32>
39/// ```
40///
41/// can be converted to
42///
43/// ```mlir
44/// %res = llvm.call_intrinsic "intrinsic"(%val)
45/// ```
46///
47/// The provided operands must be LLVM-compatible.
48///
49/// Upholds a convention that multi-result operations get converted into an
50/// operation returning the LLVM IR structure type, in which case individual
51/// values are first extracted before replacing the original results.
52LogicalResult intrinsicRewrite(Operation *op, StringRef intrinsic,
53 ValueRange operands,
54 const LLVMTypeConverter &typeConverter,
55 RewriterBase &rewriter);
56
57/// Return "true" if the given type is an unsupported floating point type.
58/// In case of a vector type, return "true" if the element type is an
59/// unsupported floating point type.
60bool isUnsupportedFloatingPointType(const TypeConverter &typeConverter,
61 Type type);
62/// Return "true" if the given op has any unsupported floating point
63/// types (either operands or results).
65 const TypeConverter &typeConverter);
66} // namespace detail
67
68/// Decomposes a `src` value into a set of values of type `dstType` through
69/// series of bitcasts and vector ops. Handles int, float, vector types as well
70/// as LLVM aggregate types (LLVMArrayType, LLVMStructType) by recursively
71/// extracting elements.
72///
73/// When a non-aggregate's bitwidth is not evenly divisible by the bitwidth of
74/// `dstType` width, the source value will be zero-extended to the next
75/// (multiple of) that bitwidth before decomposition.
76///
77/// When `permitVariablySizedScalars` is true, leaf types that have no fixed
78/// bit width (e.g., `!llvm.ptr`) are passed through as-is (1 element in
79/// result). When false (default), encountering such a type returns failure.
80LogicalResult decomposeValue(OpBuilder &builder, Location loc, Value src,
82 bool permitVariablySizedScalars = false);
83
84/// Composes a set of `src` values into a single value of type `dstType` through
85/// series of bitcasts and vector ops, and aggregate builders. This is the
86/// inverse of `decomposeValue` and expects the values in `src` to have the
87/// order and padding bits that that function would produce.
89 Type dstType);
90
91/// Performs the index computation to get to the element at `indices` of the
92/// memory pointed to by `memRefDesc`, using the layout map of `type`.
93/// The indices are linearized as:
94/// `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
96 OpBuilder &builder, Location loc, const LLVMTypeConverter &converter,
97 MemRefType type, Value memRefDesc, ValueRange indices,
98 LLVM::GEPNoWrapFlags noWrapFlags = LLVM::GEPNoWrapFlags::none);
99} // namespace LLVM
100
101/// Base class for operation conversions targeting the LLVM IR dialect. It
102/// provides the conversion patterns with access to the LLVMTypeConverter and
103/// the LowerToLLVMOptions. The class captures the LLVMTypeConverter and the
104/// LowerToLLVMOptions by reference meaning the references have to remain alive
105/// during the entire pattern lifetime.
106class ConvertToLLVMPattern : public ConversionPattern {
107public:
108 ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context,
109 const LLVMTypeConverter &typeConverter,
110 PatternBenefit benefit = 1);
111
112protected:
113 /// See `ConversionPattern::ConversionPattern` for information on the other
114 /// available constructors.
115 using ConversionPattern::ConversionPattern;
116
117 /// Returns the LLVM dialect.
118 LLVM::LLVMDialect &getDialect() const;
119
120 const LLVMTypeConverter *getTypeConverter() const;
121
122 /// Gets the MLIR type wrapping the LLVM integer type whose bit width is
123 /// defined by the used type converter.
124 Type getIndexType() const;
125
126 /// Gets the MLIR type wrapping the LLVM integer type whose bit width
127 /// corresponds to that of a LLVM pointer type.
128 Type getIntPtrType(unsigned addressSpace = 0) const;
129
130 /// Gets the MLIR type wrapping the LLVM void type.
131 Type getVoidType() const;
132
133 /// Get the MLIR type wrapping the LLVM i8* type.
134 [[deprecated("Use getPtrType() instead!")]]
135 Type getVoidPtrType() const;
136
137 /// Get the MLIR type wrapping the LLVM ptr type.
138 Type getPtrType(unsigned addressSpace = 0) const;
139
140 /// Create a constant Op producing a value of `resultType` from an index-typed
141 /// integer attribute.
142 static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
143 Type resultType, int64_t value);
144
145 /// Convenience wrapper for the corresponding helper utility.
146 /// This is a strided getElementPtr variant with linearized subscripts.
148 ConversionPatternRewriter &rewriter, Location loc, MemRefType type,
149 Value memRefDesc, ValueRange indices,
150 LLVM::GEPNoWrapFlags noWrapFlags = LLVM::GEPNoWrapFlags::none) const;
151
152 /// Returns if the given memref type is convertible to LLVM and has an
153 /// identity layout map.
154 bool isConvertibleAndHasIdentityMaps(MemRefType type) const;
155
156 /// Returns the type of a pointer to an element of the memref.
157 Type getElementPtrType(MemRefType type) const;
158
159 /// Computes sizes, strides and buffer size of `memRefType` with identity
160 /// layout. Emits constant ops for the static sizes of `memRefType`, and uses
161 /// `dynamicSizes` for the others. Emits instructions to compute strides and
162 /// buffer size from these sizes.
163 ///
164 /// For example, memref<4x?xf32> with `sizeInBytes = true` emits:
165 /// `sizes[0]` = llvm.mlir.constant(4 : index) : i64
166 /// `sizes[1]` = `dynamicSizes[0]`
167 /// `strides[1]` = llvm.mlir.constant(1 : index) : i64
168 /// `strides[0]` = `sizes[0]`
169 /// %size = llvm.mul `sizes[0]`, `sizes[1]` : i64
170 /// %nullptr = llvm.mlir.zero : !llvm.ptr
171 /// %gep = llvm.getelementptr %nullptr[%size]
172 /// : (!llvm.ptr, i64) -> !llvm.ptr, f32
173 /// `sizeBytes` = llvm.ptrtoint %gep : !llvm.ptr to i64
174 ///
175 /// If `sizeInBytes = false`, memref<4x?xf32> emits:
176 /// `sizes[0]` = llvm.mlir.constant(4 : index) : i64
177 /// `sizes[1]` = `dynamicSizes[0]`
178 /// `strides[1]` = llvm.mlir.constant(1 : index) : i64
179 /// `strides[0]` = `sizes[0]`
180 /// %size = llvm.mul `sizes[0]`, `sizes[1]` : i64
181 void getMemRefDescriptorSizes(Location loc, MemRefType memRefType,
182 ValueRange dynamicSizes,
183 ConversionPatternRewriter &rewriter,
185 SmallVectorImpl<Value> &strides, Value &size,
186 bool sizeInBytes = true) const;
187
188 /// Computes the size of type in bytes.
190 ConversionPatternRewriter &rewriter) const;
191
192 /// Computes total number of elements for the given MemRef and dynamicSizes.
193 Value getNumElements(Location loc, MemRefType memRefType,
194 ValueRange dynamicSizes,
195 ConversionPatternRewriter &rewriter) const;
196
197 /// Creates and populates a canonical memref descriptor struct.
199 createMemRefDescriptor(Location loc, MemRefType memRefType,
200 Value allocatedPtr, Value alignedPtr,
201 ArrayRef<Value> sizes, ArrayRef<Value> strides,
202 ConversionPatternRewriter &rewriter) const;
203
204 /// Copies the given unranked memory descriptor to heap-allocated memory (if
205 /// toDynamic is true) or to stack-allocated memory (otherwise) and returns
206 /// the new descriptor. Also frees the previously used memory (that is assumed
207 /// to be heap-allocated) if toDynamic is false. Returns a "null" SSA value
208 /// on failure.
210 UnrankedMemRefType memRefType, Value operand,
211 bool toDynamic) const;
212
213 /// Copies the memory descriptor for any operands that were unranked
214 /// descriptors originally to heap-allocated memory (if toDynamic is true) or
215 /// to stack-allocated memory (otherwise). The vector of descriptors is
216 /// updated in place. Also frees the previously used memory (that is assumed
217 /// to be heap-allocated) if toDynamic is false.
218 LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc,
219 TypeRange origTypes,
220 SmallVectorImpl<Value> &operands,
221 bool toDynamic) const;
222};
223
224/// Utility class for operation conversions targeting the LLVM dialect that
225/// match exactly one source operation.
226template <typename SourceOp, bool FailOnUnsupportedFP = false>
228public:
229 using OpAdaptor = typename SourceOp::Adaptor;
231 typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
232
233 explicit ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter,
234 PatternBenefit benefit = 1)
235 : ConvertToLLVMPattern(SourceOp::getOperationName(),
236 &typeConverter.getContext(), typeConverter,
237 benefit) {}
238
239 /// Wrappers around the RewritePattern methods that pass the derived op type.
240 LogicalResult
242 ConversionPatternRewriter &rewriter) const final {
243 // Bail on unsupported floating point types. (These are type-converted to
244 // integer types.)
245 if (FailOnUnsupportedFP && LLVM::detail::opHasUnsupportedFloatingPointTypes(
246 op, *this->typeConverter)) {
247 return rewriter.notifyMatchFailure(op, "unsupported floating point type");
248 }
249 auto sourceOp = cast<SourceOp>(op);
250 return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
251 }
252 LogicalResult
254 ConversionPatternRewriter &rewriter) const final {
255 // Bail on unsupported floating point types. (These are type-converted to
256 // integer types.)
257 if (FailOnUnsupportedFP && LLVM::detail::opHasUnsupportedFloatingPointTypes(
258 op, *this->typeConverter)) {
259 return rewriter.notifyMatchFailure(op, "unsupported floating point type");
260 }
261 auto sourceOp = cast<SourceOp>(op);
262 return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
263 rewriter);
264 }
265
266 /// Methods that operate on the SourceOp type. One of these must be
267 /// overridden by the derived pattern class.
268 virtual LogicalResult
269 matchAndRewrite(SourceOp op, OpAdaptor adaptor,
270 ConversionPatternRewriter &rewriter) const {
271 llvm_unreachable("matchAndRewrite is not implemented");
272 }
273 virtual LogicalResult
274 matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
275 ConversionPatternRewriter &rewriter) const {
276 return dispatchTo1To1(*this, op, adaptor, rewriter);
277 }
278
279private:
280 using ConvertToLLVMPattern::matchAndRewrite;
281};
282
283/// Utility class for operation conversions targeting the LLVM dialect that
284/// allows for matching and rewriting against an instance of an OpInterface
285/// class.
286template <typename SourceOp>
288public:
290 const LLVMTypeConverter &typeConverter, PatternBenefit benefit = 1)
291 : ConvertToLLVMPattern(typeConverter, Pattern::MatchInterfaceOpTypeTag(),
292 SourceOp::getInterfaceID(), benefit,
293 &typeConverter.getContext()) {}
294
295 /// Wrappers around the RewritePattern methods that pass the derived op type.
296 LogicalResult
298 ConversionPatternRewriter &rewriter) const final {
299 return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
300 }
301 LogicalResult
303 ConversionPatternRewriter &rewriter) const final {
304 return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
305 }
306
307 /// Methods that operate on the SourceOp type. One of these must be
308 /// overridden by the derived pattern class.
309 virtual LogicalResult
310 matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
311 ConversionPatternRewriter &rewriter) const {
312 llvm_unreachable("matchAndRewrite is not implemented");
313 }
314 virtual LogicalResult
316 ConversionPatternRewriter &rewriter) const {
317 return dispatchTo1To1(*this, op, operands, rewriter);
318 }
319
320private:
321 using ConvertToLLVMPattern::matchAndRewrite;
322};
323
324/// Generic implementation of one-to-one conversion from "SourceOp" to
325/// "TargetOp" where the latter belongs to the LLVM dialect or an equivalent.
326/// Upholds a convention that multi-result operations get converted into an
327/// operation returning the LLVM IR structure type, in which case individual
328/// values must be extracted from using LLVM::ExtractValueOp before being used.
329template <typename SourceOp, typename TargetOp>
331public:
334
335 /// Converts the type of the result to an LLVM type, pass operands as is,
336 /// preserve attributes.
337 LogicalResult
338 matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
339 ConversionPatternRewriter &rewriter) const override {
341 op, TargetOp::getOperationName(), adaptor.getOperands(), op->getAttrs(),
342 /*propertiesAttr=*/Attribute{}, *this->getTypeConverter(), rewriter);
343 }
344};
345
346} // namespace mlir
347
348#endif // MLIR_CONVERSION_LLVMCOMMON_PATTERN_H
b getContext())
Attributes are known-constant values of operations.
Definition Attributes.h:25
LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Wrappers around the RewritePattern methods that pass the derived op type.
Definition Pattern.h:297
virtual LogicalResult matchAndRewrite(SourceOp op, ArrayRef< ValueRange > operands, ConversionPatternRewriter &rewriter) const
Definition Pattern.h:315
ConvertOpInterfaceToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition Pattern.h:289
virtual LogicalResult matchAndRewrite(SourceOp op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Methods that operate on the SourceOp type.
Definition Pattern.h:310
LogicalResult matchAndRewrite(Operation *op, ArrayRef< ValueRange > operands, ConversionPatternRewriter &rewriter) const final
Definition Pattern.h:302
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition Pattern.h:233
LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Wrappers around the RewritePattern methods that pass the derived op type.
Definition Pattern.h:241
typename SourceOp::template GenericAdaptor< ArrayRef< ValueRange > > OneToNOpAdaptor
Definition Pattern.h:230
virtual LogicalResult matchAndRewrite(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const
Methods that operate on the SourceOp type.
Definition Pattern.h:269
virtual LogicalResult matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const
Definition Pattern.h:274
LogicalResult matchAndRewrite(Operation *op, ArrayRef< ValueRange > operands, ConversionPatternRewriter &rewriter) const final
Definition Pattern.h:253
typename SourceOp::Adaptor OpAdaptor
Definition Pattern.h:229
Type getVoidType() const
Gets the MLIR type wrapping the LLVM void type.
Definition Pattern.cpp:47
MemRefDescriptor createMemRefDescriptor(Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr, ArrayRef< Value > sizes, ArrayRef< Value > strides, ConversionPatternRewriter &rewriter) const
Creates and populates a canonical memref descriptor struct.
Definition Pattern.cpp:202
ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context, const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition Pattern.cpp:24
Value getStridedElementPtr(ConversionPatternRewriter &rewriter, Location loc, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none) const
Convenience wrapper for the corresponding helper utility.
Definition Pattern.cpp:66
void getMemRefDescriptorSizes(Location loc, MemRefType memRefType, ValueRange dynamicSizes, ConversionPatternRewriter &rewriter, SmallVectorImpl< Value > &sizes, SmallVectorImpl< Value > &strides, Value &size, bool sizeInBytes=true) const
Computes sizes, strides and buffer size of memRefType with identity layout.
Definition Pattern.cpp:90
Type getPtrType(unsigned addressSpace=0) const
Get the MLIR type wrapping the LLVM ptr type.
Definition Pattern.cpp:51
Type getIndexType() const
Gets the MLIR type wrapping the LLVM integer type whose bit width is defined by the used type converter.
Definition Pattern.cpp:38
const LLVMTypeConverter * getTypeConverter() const
Definition Pattern.cpp:29
Value getNumElements(Location loc, MemRefType memRefType, ValueRange dynamicSizes, ConversionPatternRewriter &rewriter) const
Computes total number of elements for the given MemRef and dynamicSizes.
Definition Pattern.cpp:170
LLVM::LLVMDialect & getDialect() const
Returns the LLVM dialect.
Definition Pattern.cpp:34
Value getSizeInBytes(Location loc, Type type, ConversionPatternRewriter &rewriter) const
Computes the size of type in bytes.
Definition Pattern.cpp:155
Type getIntPtrType(unsigned addressSpace=0) const
Gets the MLIR type wrapping the LLVM integer type whose bit width corresponds to that of an LLVM pointer type.
Definition Pattern.cpp:42
Value copyUnrankedDescriptor(OpBuilder &builder, Location loc, UnrankedMemRefType memRefType, Value operand, bool toDynamic) const
Copies the given unranked memory descriptor to heap-allocated memory (if toDynamic is true) or to stack-allocated memory (otherwise) and returns the new descriptor.
Definition Pattern.cpp:231
LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc, TypeRange origTypes, SmallVectorImpl< Value > &operands, bool toDynamic) const
Copies the memory descriptor for any operands that were unranked descriptors originally to heap-allocated memory (if toDynamic is true) or to stack-allocated memory (otherwise).
Definition Pattern.cpp:290
Type getElementPtrType(MemRefType type) const
Returns the type of a pointer to an element of the memref.
Definition Pattern.cpp:83
static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value)
Create a constant Op producing a value of resultType from an index-typed integer attribute.
Definition Pattern.cpp:58
bool isConvertibleAndHasIdentityMaps(MemRefType type) const
Returns true if the given memref type is convertible to LLVM and has an identity layout map.
Definition Pattern.cpp:76
Type getVoidPtrType() const
Get the MLIR type wrapping the LLVM i8* type.
Definition Pattern.cpp:56
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descriptor.
Generic implementation of one-to-one conversion from "SourceOp" to "TargetOp" where the latter belongs to the LLVM dialect or an equivalent.
Definition Pattern.h:330
OneToOneConvertToLLVMPattern< SourceOp, TargetOp > Super
Definition Pattern.h:333
LogicalResult matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override
Converts the type of the result to an LLVM type, pass operands as is, preserve attributes.
Definition Pattern.h:338
This class helps build Operations.
Definition Builders.h:209
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
bool isUnsupportedFloatingPointType(const TypeConverter &typeConverter, Type type)
Return "true" if the given type is an unsupported floating point type.
Definition Pattern.cpp:662
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, Attribute propertiesAttr, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
Definition Pattern.cpp:313
bool opHasUnsupportedFloatingPointTypes(Operation *op, const TypeConverter &typeConverter)
Return "true" if the given op has any unsupported floating point types (either operands or results).
Definition Pattern.cpp:673
LogicalResult intrinsicRewrite(Operation *op, StringRef intrinsic, ValueRange operands, const LLVMTypeConverter &typeConverter, RewriterBase &rewriter)
Replaces the given operation "op" with a call to an LLVM intrinsic with the specified name "intrinsic" and operands.
Definition Pattern.cpp:352
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDesc, using the layout map of type.
Definition Pattern.cpp:603
LogicalResult decomposeValue(OpBuilder &builder, Location loc, Value src, Type dstType, SmallVectorImpl< Value > &result, bool permitVariablySizedScalars=false)
Decomposes a src value into a set of values of type dstType through a series of bitcasts and vector ops.
Definition Pattern.cpp:495
Value composeValue(OpBuilder &builder, Location loc, ValueRange src, Type dstType)
Composes a set of src values into a single value of type dstType through a series of bitcasts and vector ops, and aggregate builders.
Definition Pattern.cpp:594
Include the generated interface declarations.