MLIR 22.0.0git
Pattern.h
Go to the documentation of this file.
1//===- Pattern.h - Pattern for conversion to the LLVM dialect ---*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_CONVERSION_LLVMCOMMON_PATTERN_H
10#define MLIR_CONVERSION_LLVMCOMMON_PATTERN_H
11
16
17namespace mlir {
18class CallOpInterface;
19
20namespace LLVM {
21namespace detail {
22/// Handle generically setting flags as native properties on LLVM operations.
23void setNativeProperties(Operation *op, IntegerOverflowFlags overflowFlags);
24
25/// Replaces the given operation "op" with a new operation of type "targetOp"
26/// and given operands.
27LogicalResult oneToOneRewrite(
28 Operation *op, StringRef targetOp, ValueRange operands,
29 ArrayRef<NamedAttribute> targetAttrs,
30 const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
31 IntegerOverflowFlags overflowFlags = IntegerOverflowFlags::none);
32
33/// Replaces the given operation "op" with a call to an LLVM intrinsic with the
34/// specified name "intrinsic" and operands.
35///
36/// The rewrite performs a simple one-to-one matching between the op and LLVM
37/// intrinsic. For example:
38///
39/// ```mlir
40/// %res = intr.op %val : vector<16xf32>
41/// ```
42///
43/// can be converted to
44///
45/// ```mlir
46/// %res = llvm.call_intrinsic "intrinsic"(%val)
47/// ```
48///
49/// The provided operands must be LLVM-compatible.
50///
51/// Upholds a convention that multi-result operations get converted into an
52/// operation returning the LLVM IR structure type, in which case individual
53/// values are first extracted before replacing the original results.
54LogicalResult intrinsicRewrite(Operation *op, StringRef intrinsic,
55 ValueRange operands,
56 const LLVMTypeConverter &typeConverter,
57 RewriterBase &rewriter);
58
59} // namespace detail
60
61/// Decomposes a `src` value into a set of values of type `dstType` through
62/// series of bitcasts and vector ops. Src and dst types are expected to be int
63/// or float types or vector types of them.
65 Type dstType);
66
67/// Composes a set of `src` values into a single value of type `dstType` through
68/// series of bitcasts and vector ops. Inversely to `decomposeValue`, this
69/// function is used to combine multiple values into a single value.
71 Type dstType);
72
73/// Performs the index computation to get to the element at `indices` of the
74/// memory pointed to by `memRefDesc`, using the layout map of `type`.
75/// The indices are linearized as:
76/// `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
78 OpBuilder &builder, Location loc, const LLVMTypeConverter &converter,
79 MemRefType type, Value memRefDesc, ValueRange indices,
80 LLVM::GEPNoWrapFlags noWrapFlags = LLVM::GEPNoWrapFlags::none);
81} // namespace LLVM
82
83/// Base class for operation conversions targeting the LLVM IR dialect. It
84/// provides the conversion patterns with access to the LLVMTypeConverter and
85/// the LowerToLLVMOptions. The class captures the LLVMTypeConverter and the
86/// LowerToLLVMOptions by reference meaning the references have to remain alive
87/// during the entire pattern lifetime.
89public:
90 ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context,
91 const LLVMTypeConverter &typeConverter,
92 PatternBenefit benefit = 1);
93
94protected:
95 /// See `ConversionPattern::ConversionPattern` for information on the other
96 /// available constructors.
97 using ConversionPattern::ConversionPattern;
98
99 /// Returns the LLVM dialect.
100 LLVM::LLVMDialect &getDialect() const;
101
102 const LLVMTypeConverter *getTypeConverter() const;
103
104 /// Gets the MLIR type wrapping the LLVM integer type whose bit width is
105 /// defined by the used type converter.
106 Type getIndexType() const;
107
108 /// Gets the MLIR type wrapping the LLVM integer type whose bit width
109 /// corresponds to that of a LLVM pointer type.
110 Type getIntPtrType(unsigned addressSpace = 0) const;
111
112 /// Gets the MLIR type wrapping the LLVM void type.
113 Type getVoidType() const;
114
115 /// Get the MLIR type wrapping the LLVM i8* type.
116 [[deprecated("Use getPtrType() instead!")]]
117 Type getVoidPtrType() const;
118
119 /// Get the MLIR type wrapping the LLVM ptr type.
120 Type getPtrType(unsigned addressSpace = 0) const;
121
122 /// Create a constant Op producing a value of `resultType` from an index-typed
123 /// integer attribute.
124 static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
125 Type resultType, int64_t value);
126
127 /// Convenience wrapper for the corresponding helper utility.
128 /// This is a strided getElementPtr variant with linearized subscripts.
130 ConversionPatternRewriter &rewriter, Location loc, MemRefType type,
131 Value memRefDesc, ValueRange indices,
132 LLVM::GEPNoWrapFlags noWrapFlags = LLVM::GEPNoWrapFlags::none) const;
133
134 /// Returns if the given memref type is convertible to LLVM and has an
135 /// identity layout map.
136 bool isConvertibleAndHasIdentityMaps(MemRefType type) const;
137
138 /// Returns the type of a pointer to an element of the memref.
139 Type getElementPtrType(MemRefType type) const;
140
141 /// Computes sizes, strides and buffer size of `memRefType` with identity
142 /// layout. Emits constant ops for the static sizes of `memRefType`, and uses
143 /// `dynamicSizes` for the others. Emits instructions to compute strides and
144 /// buffer size from these sizes.
145 ///
146 /// For example, memref<4x?xf32> with `sizeInBytes = true` emits:
147 /// `sizes[0]` = llvm.mlir.constant(4 : index) : i64
148 /// `sizes[1]` = `dynamicSizes[0]`
149 /// `strides[1]` = llvm.mlir.constant(1 : index) : i64
150 /// `strides[0]` = `sizes[0]`
151 /// %size = llvm.mul `sizes[0]`, `sizes[1]` : i64
152 /// %nullptr = llvm.mlir.zero : !llvm.ptr
153 /// %gep = llvm.getelementptr %nullptr[%size]
154 /// : (!llvm.ptr, i64) -> !llvm.ptr, f32
155 /// `sizeBytes` = llvm.ptrtoint %gep : !llvm.ptr to i64
156 ///
157 /// If `sizeInBytes = false`, memref<4x?xf32> emits:
158 /// `sizes[0]` = llvm.mlir.constant(4 : index) : i64
159 /// `sizes[1]` = `dynamicSizes[0]`
160 /// `strides[1]` = llvm.mlir.constant(1 : index) : i64
161 /// `strides[0]` = `sizes[0]`
162 /// %size = llvm.mul `sizes[0]`, `sizes[1]` : i64
163 void getMemRefDescriptorSizes(Location loc, MemRefType memRefType,
164 ValueRange dynamicSizes,
165 ConversionPatternRewriter &rewriter,
167 SmallVectorImpl<Value> &strides, Value &size,
168 bool sizeInBytes = true) const;
169
170 /// Computes the size of type in bytes.
172 ConversionPatternRewriter &rewriter) const;
173
174 /// Computes total number of elements for the given MemRef and dynamicSizes.
175 Value getNumElements(Location loc, MemRefType memRefType,
176 ValueRange dynamicSizes,
177 ConversionPatternRewriter &rewriter) const;
178
179 /// Creates and populates a canonical memref descriptor struct.
181 createMemRefDescriptor(Location loc, MemRefType memRefType,
182 Value allocatedPtr, Value alignedPtr,
183 ArrayRef<Value> sizes, ArrayRef<Value> strides,
184 ConversionPatternRewriter &rewriter) const;
185
186 /// Copies the given unranked memory descriptor to heap-allocated memory (if
187 /// toDynamic is true) or to stack-allocated memory (otherwise) and returns
188 /// the new descriptor. Also frees the previously used memory (that is assumed
189 /// to be heap-allocated) if toDynamic is false. Returns a "null" SSA value
190 /// on failure.
192 UnrankedMemRefType memRefType, Value operand,
193 bool toDynamic) const;
194
195 /// Copies the memory descriptor for any operands that were unranked
196 /// descriptors originally to heap-allocated memory (if toDynamic is true) or
197 /// to stack-allocated memory (otherwise). The vector of descriptors is
198 /// updated in place. Also frees the previously used memory (that is assumed
199 /// to be heap-allocated) if toDynamic is false.
200 LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc,
201 TypeRange origTypes,
202 SmallVectorImpl<Value> &operands,
203 bool toDynamic) const;
204};
205
206/// Utility class for operation conversions targeting the LLVM dialect that
207/// match exactly one source operation.
208template <typename SourceOp>
210public:
211 using OpAdaptor = typename SourceOp::Adaptor;
213 typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
214
215 explicit ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter,
216 PatternBenefit benefit = 1)
217 : ConvertToLLVMPattern(SourceOp::getOperationName(),
218 &typeConverter.getContext(), typeConverter,
219 benefit) {}
220
221 /// Wrappers around the RewritePattern methods that pass the derived op type.
222 LogicalResult
224 ConversionPatternRewriter &rewriter) const final {
225 auto sourceOp = cast<SourceOp>(op);
226 return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
227 }
228 LogicalResult
230 ConversionPatternRewriter &rewriter) const final {
231 auto sourceOp = cast<SourceOp>(op);
232 return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
233 rewriter);
234 }
235
236 /// Methods that operate on the SourceOp type. One of these must be
237 /// overridden by the derived pattern class.
238 virtual LogicalResult
239 matchAndRewrite(SourceOp op, OpAdaptor adaptor,
240 ConversionPatternRewriter &rewriter) const {
241 llvm_unreachable("matchAndRewrite is not implemented");
242 }
243 virtual LogicalResult
244 matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
245 ConversionPatternRewriter &rewriter) const {
246 return dispatchTo1To1(*this, op, adaptor, rewriter);
247 }
248
249private:
250 using ConvertToLLVMPattern::matchAndRewrite;
251};
252
253/// Utility class for operation conversions targeting the LLVM dialect that
254/// allows for matching and rewriting against an instance of an OpInterface
255/// class.
256template <typename SourceOp>
258public:
260 const LLVMTypeConverter &typeConverter, PatternBenefit benefit = 1)
261 : ConvertToLLVMPattern(typeConverter, Pattern::MatchInterfaceOpTypeTag(),
262 SourceOp::getInterfaceID(), benefit,
263 &typeConverter.getContext()) {}
264
265 /// Wrappers around the RewritePattern methods that pass the derived op type.
266 LogicalResult
268 ConversionPatternRewriter &rewriter) const final {
269 return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
270 }
271 LogicalResult
273 ConversionPatternRewriter &rewriter) const final {
274 return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
275 }
276
277 /// Methods that operate on the SourceOp type. One of these must be
278 /// overridden by the derived pattern class.
279 virtual LogicalResult
280 matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
281 ConversionPatternRewriter &rewriter) const {
282 llvm_unreachable("matchAndRewrite is not implemented");
283 }
284 virtual LogicalResult
286 ConversionPatternRewriter &rewriter) const {
287 return dispatchTo1To1(*this, op, operands, rewriter);
288 }
289
290private:
291 using ConvertToLLVMPattern::matchAndRewrite;
292};
293
294/// Generic implementation of one-to-one conversion from "SourceOp" to
295/// "TargetOp" where the latter belongs to the LLVM dialect or an equivalent.
296/// Upholds a convention that multi-result operations get converted into an
297/// operation returning the LLVM IR structure type, in which case individual
298/// values must be extracted from using LLVM::ExtractValueOp before being used.
299template <typename SourceOp, typename TargetOp>
301public:
304
305 /// Converts the type of the result to an LLVM type, pass operands as is,
306 /// preserve attributes.
307 LogicalResult
308 matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
309 ConversionPatternRewriter &rewriter) const override {
310 return LLVM::detail::oneToOneRewrite(op, TargetOp::getOperationName(),
311 adaptor.getOperands(), op->getAttrs(),
312 *this->getTypeConverter(), rewriter);
313 }
314};
315
316} // namespace mlir
317
318#endif // MLIR_CONVERSION_LLVMCOMMON_PATTERN_H
b getContext())
LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Wrappers around the RewritePattern methods that pass the derived op type.
Definition Pattern.h:267
virtual LogicalResult matchAndRewrite(SourceOp op, ArrayRef< ValueRange > operands, ConversionPatternRewriter &rewriter) const
Definition Pattern.h:285
ConvertOpInterfaceToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition Pattern.h:259
virtual LogicalResult matchAndRewrite(SourceOp op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Methods that operate on the SourceOp type.
Definition Pattern.h:280
LogicalResult matchAndRewrite(Operation *op, ArrayRef< ValueRange > operands, ConversionPatternRewriter &rewriter) const final
Definition Pattern.h:272
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition Pattern.h:215
LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Wrappers around the RewritePattern methods that pass the derived op type.
Definition Pattern.h:223
virtual LogicalResult matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const
Definition Pattern.h:244
virtual LogicalResult matchAndRewrite(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const
Methods that operate on the SourceOp type.
Definition Pattern.h:239
typename SourceOp::template GenericAdaptor< ArrayRef< ValueRange > > OneToNOpAdaptor
Definition Pattern.h:212
LogicalResult matchAndRewrite(Operation *op, ArrayRef< ValueRange > operands, ConversionPatternRewriter &rewriter) const final
Definition Pattern.h:229
typename SourceOp::Adaptor OpAdaptor
Definition Pattern.h:211
Type getVoidType() const
Gets the MLIR type wrapping the LLVM void type.
Definition Pattern.cpp:45
MemRefDescriptor createMemRefDescriptor(Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr, ArrayRef< Value > sizes, ArrayRef< Value > strides, ConversionPatternRewriter &rewriter) const
Creates and populates a canonical memref descriptor struct.
Definition Pattern.cpp:190
ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context, const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition Pattern.cpp:22
Value getStridedElementPtr(ConversionPatternRewriter &rewriter, Location loc, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none) const
Convenience wrapper for the corresponding helper utility.
Definition Pattern.cpp:64
void getMemRefDescriptorSizes(Location loc, MemRefType memRefType, ValueRange dynamicSizes, ConversionPatternRewriter &rewriter, SmallVectorImpl< Value > &sizes, SmallVectorImpl< Value > &strides, Value &size, bool sizeInBytes=true) const
Computes sizes, strides and buffer size of memRefType with identity layout.
Definition Pattern.cpp:88
Type getPtrType(unsigned addressSpace=0) const
Get the MLIR type wrapping the LLVM ptr type.
Definition Pattern.cpp:49
Type getIndexType() const
Gets the MLIR type wrapping the LLVM integer type whose bit width is defined by the used type converter.
Definition Pattern.cpp:36
const LLVMTypeConverter * getTypeConverter() const
Definition Pattern.cpp:27
Value getNumElements(Location loc, MemRefType memRefType, ValueRange dynamicSizes, ConversionPatternRewriter &rewriter) const
Computes total number of elements for the given MemRef and dynamicSizes.
Definition Pattern.cpp:158
LLVM::LLVMDialect & getDialect() const
Returns the LLVM dialect.
Definition Pattern.cpp:32
Value getSizeInBytes(Location loc, Type type, ConversionPatternRewriter &rewriter) const
Computes the size of type in bytes.
Definition Pattern.cpp:143
Type getIntPtrType(unsigned addressSpace=0) const
Gets the MLIR type wrapping the LLVM integer type whose bit width corresponds to that of an LLVM pointer type.
Definition Pattern.cpp:40
Value copyUnrankedDescriptor(OpBuilder &builder, Location loc, UnrankedMemRefType memRefType, Value operand, bool toDynamic) const
Copies the given unranked memory descriptor to heap-allocated memory (if toDynamic is true) or to stack-allocated memory (otherwise) and returns the new descriptor.
Definition Pattern.cpp:219
LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc, TypeRange origTypes, SmallVectorImpl< Value > &operands, bool toDynamic) const
Copies the memory descriptor for any operands that were unranked descriptors originally to heap-allocated memory (if toDynamic is true) or to stack-allocated memory (otherwise).
Definition Pattern.cpp:278
Type getElementPtrType(MemRefType type) const
Returns the type of a pointer to an element of the memref.
Definition Pattern.cpp:81
static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value)
Create a constant Op producing a value of resultType from an index-typed integer attribute.
Definition Pattern.cpp:56
bool isConvertibleAndHasIdentityMaps(MemRefType type) const
Returns true if the given memref type is convertible to LLVM and has an identity layout map.
Definition Pattern.cpp:74
Type getVoidPtrType() const
Get the MLIR type wrapping the LLVM i8* type.
Definition Pattern.cpp:54
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Generic implementation of one-to-one conversion from "SourceOp" to "TargetOp" where the latter belongs to the LLVM dialect or an equivalent.
Definition Pattern.h:300
OneToOneConvertToLLVMPattern< SourceOp, TargetOp > Super
Definition Pattern.h:303
LogicalResult matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override
Converts the type of the result to an LLVM type, pass operands as is, preserve attributes.
Definition Pattern.h:308
This class helps build Operations.
Definition Builders.h:207
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, IntegerOverflowFlags overflowFlags=IntegerOverflowFlags::none)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
Definition Pattern.cpp:307
void setNativeProperties(Operation *op, IntegerOverflowFlags overflowFlags)
Handle generically setting flags as native properties on LLVM operations.
Definition Pattern.cpp:299
LogicalResult intrinsicRewrite(Operation *op, StringRef intrinsic, ValueRange operands, const LLVMTypeConverter &typeConverter, RewriterBase &rewriter)
Replaces the given operation "op" with a call to an LLVM intrinsic with the specified name "intrinsic" and operands.
Definition Pattern.cpp:347
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDesc, using the layout map of type.
Definition Pattern.cpp:478
Value composeValue(OpBuilder &builder, Location loc, ValueRange src, Type dstType)
Composes a set of src values into a single value of type dstType through a series of bitcasts and vector ops.
Definition Pattern.cpp:439
SmallVector< Value > decomposeValue(OpBuilder &builder, Location loc, Value src, Type dstType)
Decomposes a src value into a set of values of type dstType through a series of bitcasts and vector ops.
Definition Pattern.cpp:400
Include the generated interface declarations.