MLIR  21.0.0git
Pattern.h
Go to the documentation of this file.
1 //===- Pattern.h - Pattern for conversion to the LLVM dialect ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_CONVERSION_LLVMCOMMON_PATTERN_H
10 #define MLIR_CONVERSION_LLVMCOMMON_PATTERN_H
11 
16 
17 namespace mlir {
18 class CallOpInterface;
19 
20 namespace LLVM {
21 namespace detail {
22 /// Generically sets `overflowFlags` as native properties on the LLVM operation `op`.
23 void setNativeProperties(Operation *op, IntegerOverflowFlags overflowFlags);
24 
25 /// Replaces the given operation "op" with a new operation of type "targetOp"
26 /// built from `operands` and the attributes `targetAttrs`; `overflowFlags` defaults to none.
27 LogicalResult oneToOneRewrite(
28  Operation *op, StringRef targetOp, ValueRange operands,
29  ArrayRef<NamedAttribute> targetAttrs,
30  const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
31  IntegerOverflowFlags overflowFlags = IntegerOverflowFlags::none);
32 
33 } // namespace detail
34 } // namespace LLVM
35 
36 /// Base class for operation conversions targeting the LLVM IR dialect. It
37 /// provides the conversion patterns with access to the LLVMTypeConverter and
38 /// the LowerToLLVMOptions. The class captures the LLVMTypeConverter and the
39 /// LowerToLLVMOptions by reference, meaning the referenced objects have to remain
40 /// alive during the entire pattern lifetime.
42 public:
43  /// `SplitMatchAndRewrite` is deprecated. Use `matchAndRewrite` instead of
44  /// separate `match` and `rewrite`.
47 
48  ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context,
50  PatternBenefit benefit = 1);
51 
52 protected:
53  /// Returns the LLVM dialect.
54  LLVM::LLVMDialect &getDialect() const;
55 
56  const LLVMTypeConverter *getTypeConverter() const;
57 
58  /// Gets the MLIR type wrapping the LLVM integer type whose bit width is
59  /// defined by the used type converter.
60  Type getIndexType() const;
61 
62  /// Gets the MLIR type wrapping the LLVM integer type whose bit width
63  /// corresponds to that of a LLVM pointer type.
64  Type getIntPtrType(unsigned addressSpace = 0) const;
65 
66  /// Gets the MLIR type wrapping the LLVM void type.
67  Type getVoidType() const;
68 
69  /// Gets the MLIR type wrapping the (opaque) LLVM pointer type `!llvm.ptr`.
70  Type getVoidPtrType() const;
71 
72  /// Create a constant Op producing a value of `resultType` from an index-typed
73  /// integer attribute.
74  static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
75  Type resultType, int64_t value);
76 
77  /// This is a strided getElementPtr variant that linearizes subscripts as:
78  /// `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
79  Value getStridedElementPtr(Location loc, MemRefType type, Value memRefDesc,
80  ValueRange indices,
81  ConversionPatternRewriter &rewriter) const;
82 
83  /// Returns true if the given memref has identity maps and the element type is
84  /// convertible to LLVM.
85  bool isConvertibleAndHasIdentityMaps(MemRefType type) const;
86 
87  /// Returns the type of a pointer to an element of the memref.
88  Type getElementPtrType(MemRefType type) const;
89 
90  /// Computes sizes, strides and buffer size of `memRefType` with identity
91  /// layout. Emits constant ops for the static sizes of `memRefType`, and uses
92  /// `dynamicSizes` for the others. Emits instructions to compute strides and
93  /// buffer size from these sizes.
94  ///
95  /// For example, memref<4x?xf32> with `sizeInBytes = true` emits:
96  /// `sizes[0]` = llvm.mlir.constant(4 : index) : i64
97  /// `sizes[1]` = `dynamicSizes[0]`
98  /// `strides[1]` = llvm.mlir.constant(1 : index) : i64
99  /// `strides[0]` = `sizes[1]`
100  /// %size = llvm.mul `sizes[0]`, `sizes[1]` : i64
101  /// %nullptr = llvm.mlir.zero : !llvm.ptr
102  /// %gep = llvm.getelementptr %nullptr[%size]
103  /// : (!llvm.ptr, i64) -> !llvm.ptr, f32
104  /// `sizeBytes` = llvm.ptrtoint %gep : !llvm.ptr to i64
105  ///
106  /// If `sizeInBytes = false`, memref<4x?xf32> emits:
107  /// `sizes[0]` = llvm.mlir.constant(4 : index) : i64
108  /// `sizes[1]` = `dynamicSizes[0]`
109  /// `strides[1]` = llvm.mlir.constant(1 : index) : i64
110  /// `strides[0]` = `sizes[1]`
111  /// %size = llvm.mul `sizes[0]`, `sizes[1]` : i64
112  void getMemRefDescriptorSizes(Location loc, MemRefType memRefType,
113  ValueRange dynamicSizes,
114  ConversionPatternRewriter &rewriter,
115  SmallVectorImpl<Value> &sizes,
116  SmallVectorImpl<Value> &strides, Value &size,
117  bool sizeInBytes = true) const;
118 
119  /// Computes the size of type in bytes.
120  Value getSizeInBytes(Location loc, Type type,
121  ConversionPatternRewriter &rewriter) const;
122 
123  /// Computes total number of elements for the given MemRef and dynamicSizes.
124  Value getNumElements(Location loc, MemRefType memRefType,
125  ValueRange dynamicSizes,
126  ConversionPatternRewriter &rewriter) const;
127 
128  /// Creates and populates a canonical memref descriptor struct.
130  createMemRefDescriptor(Location loc, MemRefType memRefType,
131  Value allocatedPtr, Value alignedPtr,
132  ArrayRef<Value> sizes, ArrayRef<Value> strides,
133  ConversionPatternRewriter &rewriter) const;
134 
135  /// Copies the memory descriptor for any operands that were unranked
136  /// descriptors originally to heap-allocated memory (if toDynamic is true) or
137  /// to stack-allocated memory (otherwise). Also frees the previously used
138  /// memory (that is assumed to be heap-allocated) if toDynamic is false.
139  LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc,
140  TypeRange origTypes,
141  SmallVectorImpl<Value> &operands,
142  bool toDynamic) const;
143 };
144 
145 /// Utility class for operation conversions targeting the LLVM dialect that
146 /// match exactly one source operation.
147 template <typename SourceOp>
149 public:
150  using OperationT = SourceOp;
151  using OpAdaptor = typename SourceOp::Adaptor;
153  typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
154 
155  /// `SplitMatchAndRewrite` is deprecated. Use `matchAndRewrite` instead of
156  /// separate `match` and `rewrite`.
159 
161  PatternBenefit benefit = 1)
162  : ConvertToLLVMPattern(SourceOp::getOperationName(),
164  benefit) {}
165 
166  /// Wrappers around the RewritePattern methods that pass the derived op type.
167  LogicalResult
169  ConversionPatternRewriter &rewriter) const final {
170  auto sourceOp = cast<SourceOp>(op);
171  return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
172  }
173  LogicalResult
175  ConversionPatternRewriter &rewriter) const final {
176  auto sourceOp = cast<SourceOp>(op);
177  return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
178  rewriter);
179  }
180 
181  /// Methods that operate on the SourceOp type. One of these must be
182  /// overridden by the derived pattern class.
183  virtual LogicalResult
184  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
185  ConversionPatternRewriter &rewriter) const {
186  llvm_unreachable("matchAndRewrite is not implemented");
187  }
188  virtual LogicalResult
189  matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
190  ConversionPatternRewriter &rewriter) const {
191  SmallVector<Value> oneToOneOperands =
192  getOneToOneAdaptorOperands(adaptor.getOperands());
193  return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
194  }
195 
196 private:
198 };
199 
200 /// Generic implementation of one-to-one conversion from "SourceOp" to
201 /// "TargetOp" where the latter belongs to the LLVM dialect or an equivalent.
202 /// Upholds a convention that multi-result operations get converted into an
203 /// operation returning the LLVM IR structure type, in which case individual
204 /// values must be extracted from using LLVM::ExtractValueOp before being used.
205 template <typename SourceOp, typename TargetOp>
207 public:
210 
211  /// Converts the type of the result to an LLVM type, pass operands as is,
212  /// preserve attributes.
213  LogicalResult
214  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
215  ConversionPatternRewriter &rewriter) const override {
216  return LLVM::detail::oneToOneRewrite(op, TargetOp::getOperationName(),
217  adaptor.getOperands(), op->getAttrs(),
218  *this->getTypeConverter(), rewriter);
219  }
220 };
221 
222 } // namespace mlir
223 
224 #endif // MLIR_CONVERSION_LLVMCOMMON_PATTERN_H
This class implements a pattern rewriter for use with ConversionPatterns.
Base class for the conversion patterns.
SmallVector< Value > getOneToOneAdaptorOperands(ArrayRef< ValueRange > operands) const
Given an array of value ranges, which are the inputs to a 1:N adaptor, try to extract the single value of each range to construct the inputs for a 1:1 adaptor.
const TypeConverter * typeConverter
An optional type converter for use by this pattern.
virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement combined matching and rewriting.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
Definition: Pattern.h:148
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition: Pattern.h:160
LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Wrappers around the RewritePattern methods that pass the derived op type.
Definition: Pattern.h:168
virtual LogicalResult matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const
Definition: Pattern.h:189
virtual LogicalResult matchAndRewrite(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const
Methods that operate on the SourceOp type.
Definition: Pattern.h:184
typename SourceOp::template GenericAdaptor< ArrayRef< ValueRange > > OneToNOpAdaptor
Definition: Pattern.h:153
LogicalResult matchAndRewrite(Operation *op, ArrayRef< ValueRange > operands, ConversionPatternRewriter &rewriter) const final
Hook for derived classes to implement combined matching and rewriting.
Definition: Pattern.h:174
typename SourceOp::Adaptor OpAdaptor
Definition: Pattern.h:151
Base class for operation conversions targeting the LLVM IR dialect.
Definition: Pattern.h:41
Type getVoidType() const
Gets the MLIR type wrapping the LLVM void type.
Definition: Pattern.cpp:45
MemRefDescriptor createMemRefDescriptor(Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr, ArrayRef< Value > sizes, ArrayRef< Value > strides, ConversionPatternRewriter &rewriter) const
Creates and populates a canonical memref descriptor struct.
Definition: Pattern.cpp:216
ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context, const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition: Pattern.cpp:22
void getMemRefDescriptorSizes(Location loc, MemRefType memRefType, ValueRange dynamicSizes, ConversionPatternRewriter &rewriter, SmallVectorImpl< Value > &sizes, SmallVectorImpl< Value > &strides, Value &size, bool sizeInBytes=true) const
Computes sizes, strides and buffer size of memRefType with identity layout.
Definition: Pattern.cpp:114
Type getIndexType() const
Gets the MLIR type wrapping the LLVM integer type whose bit width is defined by the used type converter.
Definition: Pattern.cpp:36
const LLVMTypeConverter * getTypeConverter() const
Definition: Pattern.cpp:27
Value getStridedElementPtr(Location loc, MemRefType type, Value memRefDesc, ValueRange indices, ConversionPatternRewriter &rewriter) const
Definition: Pattern.cpp:61
Value getNumElements(Location loc, MemRefType memRefType, ValueRange dynamicSizes, ConversionPatternRewriter &rewriter) const
Computes total number of elements for the given MemRef and dynamicSizes.
Definition: Pattern.cpp:184
LLVM::LLVMDialect & getDialect() const
Returns the LLVM dialect.
Definition: Pattern.cpp:32
Value getSizeInBytes(Location loc, Type type, ConversionPatternRewriter &rewriter) const
Computes the size of type in bytes.
Definition: Pattern.cpp:169
Type getIntPtrType(unsigned addressSpace=0) const
Gets the MLIR type wrapping the LLVM integer type whose bit width corresponds to that of a LLVM pointer type.
Definition: Pattern.cpp:40
LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc, TypeRange origTypes, SmallVectorImpl< Value > &operands, bool toDynamic) const
Copies the memory descriptor for any operands that were unranked descriptors originally to heap-allocated memory (if toDynamic is true) or to stack-allocated memory (otherwise).
Definition: Pattern.cpp:245
Type getElementPtrType(MemRefType type) const
Returns the type of a pointer to an element of the memref.
Definition: Pattern.cpp:107
static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value)
Create a constant Op producing a value of resultType from an index-typed integer attribute.
Definition: Pattern.cpp:53
bool isConvertibleAndHasIdentityMaps(MemRefType type) const
Returns if the given memref has identity maps and the element type is convertible to LLVM.
Definition: Pattern.cpp:100
Type getVoidPtrType() const
Get the MLIR type wrapping the LLVM i8* type.
Definition: Pattern.cpp:49
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descriptor.
Definition: MemRefBuilder.h:33
Generic implementation of one-to-one conversion from "SourceOp" to "TargetOp" where the latter belongs to the LLVM dialect or an equivalent.
Definition: Pattern.h:206
LogicalResult matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override
Converts the type of the result to an LLVM type, pass operands as is, preserve attributes.
Definition: Pattern.h:214
This class helps build Operations.
Definition: Builders.h:205
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65K.
Definition: PatternMatch.h:34
MLIRContext * getContext() const
Return the MLIRContext used to create this pattern.
Definition: PatternMatch.h:134
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Helper class that derives from a ConversionRewritePattern class and provides separate match and rewrite entry points.
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, IntegerOverflowFlags overflowFlags=IntegerOverflowFlags::none)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
Definition: Pattern.cpp:345
void setNativeProperties(Operation *op, IntegerOverflowFlags overflowFlags)
Handle generically setting flags as native properties on LLVM operations.
Definition: Pattern.cpp:337
Include the generated interface declarations.