MLIR  17.0.0git
Pattern.h
Go to the documentation of this file.
1 //===- Pattern.h - Pattern for conversion to the LLVM dialect ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_CONVERSION_LLVMCOMMON_PATTERN_H
10 #define MLIR_CONVERSION_LLVMCOMMON_PATTERN_H
11 
15 
16 namespace mlir {
// Forward declaration.
// NOTE(review): none of the declarations in this header's listing refer to
// CallOpInterface; confirm whether this forward declaration is still needed.
class CallOpInterface;
18 
19 namespace LLVM {
20 namespace detail {
21 /// Replaces the given operation "op" with a new operation of type "targetOp"
22 /// and given operands.
23 LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp,
24  ValueRange operands,
25  ArrayRef<NamedAttribute> targetAttrs,
26  LLVMTypeConverter &typeConverter,
27  ConversionPatternRewriter &rewriter);
28 
29 } // namespace detail
30 } // namespace LLVM
31 
32 /// Base class for operation conversions targeting the LLVM IR dialect. It
33 /// provides the conversion patterns with access to the LLVMTypeConverter and
34 /// the LowerToLLVMOptions. The class captures the LLVMTypeConverter and the
35 /// LowerToLLVMOptions by reference meaning the references have to remain alive
36 /// during the entire pattern lifetime.
38 public:
39  ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context,
41  PatternBenefit benefit = 1);
42 
43 protected:
44  /// Returns the LLVM dialect.
45  LLVM::LLVMDialect &getDialect() const;
46 
48 
49  /// Gets the MLIR type wrapping the LLVM integer type whose bit width is
50  /// defined by the used type converter.
51  Type getIndexType() const;
52 
53  /// Gets the MLIR type wrapping the LLVM integer type whose bit width
54  /// corresponds to that of a LLVM pointer type.
55  Type getIntPtrType(unsigned addressSpace = 0) const;
56 
57  /// Gets the MLIR type wrapping the LLVM void type.
58  Type getVoidType() const;
59 
60  /// Get the MLIR type wrapping the LLVM i8* type.
61  Type getVoidPtrType() const;
62 
63  /// Create a constant Op producing a value of `resultType` from an index-typed
64  /// integer attribute.
65  static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
66  Type resultType, int64_t value);
67 
68  /// Create an LLVM dialect operation defining the given index constant.
70  uint64_t value) const;
71 
72  // This is a strided getElementPtr variant that linearizes subscripts as:
73  // `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
74  Value getStridedElementPtr(Location loc, MemRefType type, Value memRefDesc,
75  ValueRange indices,
76  ConversionPatternRewriter &rewriter) const;
77 
78  /// Returns if the given memref has identity maps and the element type is
79  /// convertible to LLVM.
80  bool isConvertibleAndHasIdentityMaps(MemRefType type) const;
81 
82  /// Returns the type of a pointer to an element of the memref.
83  Type getElementPtrType(MemRefType type) const;
84 
85  /// Computes sizes, strides and buffer size in bytes of `memRefType` with
86  /// identity layout. Emits constant ops for the static sizes of `memRefType`,
87  /// and uses `dynamicSizes` for the others. Emits instructions to compute
88  /// strides and buffer size from these sizes.
89  ///
90  /// For example, memref<4x?xf32> emits:
91  /// `sizes[0]` = llvm.mlir.constant(4 : index) : i64
92  /// `sizes[1]` = `dynamicSizes[0]`
93  /// `strides[1]` = llvm.mlir.constant(1 : index) : i64
94  /// `strides[0]` = `sizes[0]`
95  /// %size = llvm.mul `sizes[0]`, `sizes[1]` : i64
96  /// %nullptr = llvm.mlir.null : !llvm.ptr<f32>
97  /// %gep = llvm.getelementptr %nullptr[%size]
98  /// : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
99  /// `sizeBytes` = llvm.ptrtoint %gep : !llvm.ptr<f32> to i64
100  void getMemRefDescriptorSizes(Location loc, MemRefType memRefType,
101  ValueRange dynamicSizes,
102  ConversionPatternRewriter &rewriter,
103  SmallVectorImpl<Value> &sizes,
104  SmallVectorImpl<Value> &strides,
105  Value &sizeBytes) const;
106 
107  /// Computes the size of type in bytes.
108  Value getSizeInBytes(Location loc, Type type,
109  ConversionPatternRewriter &rewriter) const;
110 
111  /// Computes total number of elements for the given shape.
113  ConversionPatternRewriter &rewriter) const;
114 
115  /// Creates and populates a canonical memref descriptor struct.
117  createMemRefDescriptor(Location loc, MemRefType memRefType,
118  Value allocatedPtr, Value alignedPtr,
119  ArrayRef<Value> sizes, ArrayRef<Value> strides,
120  ConversionPatternRewriter &rewriter) const;
121 
122  /// Copies the memory descriptor for any operands that were unranked
123  /// descriptors originally to heap-allocated memory (if toDynamic is true) or
124  /// to stack-allocated memory (otherwise). Also frees the previously used
125  /// memory (that is assumed to be heap-allocated) if toDynamic is false.
127  TypeRange origTypes,
128  SmallVectorImpl<Value> &operands,
129  bool toDynamic) const;
130 };
131 
132 /// Utility class for operation conversions targeting the LLVM dialect that
133 /// match exactly one source operation.
134 template <typename SourceOp>
136 public:
137  using OpAdaptor = typename SourceOp::Adaptor;
138 
140  PatternBenefit benefit = 1)
141  : ConvertToLLVMPattern(SourceOp::getOperationName(),
143  benefit) {}
144 
145  /// Wrappers around the RewritePattern methods that pass the derived op type.
146  void rewrite(Operation *op, ArrayRef<Value> operands,
147  ConversionPatternRewriter &rewriter) const final {
148  rewrite(cast<SourceOp>(op), OpAdaptor(operands, op->getAttrDictionary()),
149  rewriter);
150  }
151  LogicalResult match(Operation *op) const final {
152  return match(cast<SourceOp>(op));
153  }
156  ConversionPatternRewriter &rewriter) const final {
157  return matchAndRewrite(cast<SourceOp>(op),
158  OpAdaptor(operands, op->getAttrDictionary()),
159  rewriter);
160  }
161 
162  /// Rewrite and Match methods that operate on the SourceOp type. These must be
163  /// overridden by the derived pattern class.
164  virtual LogicalResult match(SourceOp op) const {
165  llvm_unreachable("must override match or matchAndRewrite");
166  }
167  virtual void rewrite(SourceOp op, OpAdaptor adaptor,
168  ConversionPatternRewriter &rewriter) const {
169  llvm_unreachable("must override rewrite or matchAndRewrite");
170  }
171  virtual LogicalResult
172  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
173  ConversionPatternRewriter &rewriter) const {
174  if (failed(match(op)))
175  return failure();
176  rewrite(op, adaptor, rewriter);
177  return success();
178  }
179 
180 private:
183 };
184 
185 /// Generic implementation of one-to-one conversion from "SourceOp" to
186 /// "TargetOp" where the latter belongs to the LLVM dialect or an equivalent.
187 /// Upholds a convention that multi-result operations get converted into an
188 /// operation returning the LLVM IR structure type, in which case individual
189 /// values must be extracted from using LLVM::ExtractValueOp before being used.
190 template <typename SourceOp, typename TargetOp>
192 public:
195 
196  /// Converts the type of the result to an LLVM type, pass operands as is,
197  /// preserve attributes.
199  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
200  ConversionPatternRewriter &rewriter) const override {
201  return LLVM::detail::oneToOneRewrite(op, TargetOp::getOperationName(),
202  adaptor.getOperands(), op->getAttrs(),
203  *this->getTypeConverter(), rewriter);
204  }
205 };
206 
207 } // namespace mlir
208 
209 #endif // MLIR_CONVERSION_LLVMCOMMON_PATTERN_H
This class implements a pattern rewriter for use with ConversionPatterns.
Base class for the conversion patterns.
TypeConverter * typeConverter
An optional type converter for use by this pattern.
virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement combined matching and rewriting.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
Definition: Pattern.h:135
LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Hook for derived classes to implement combined matching and rewriting.
Definition: Pattern.h:155
ConvertOpToLLVMPattern(LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition: Pattern.h:139
virtual LogicalResult matchAndRewrite(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const
Definition: Pattern.h:172
void rewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Wrappers around the RewritePattern methods that pass the derived op type.
Definition: Pattern.h:146
virtual void rewrite(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const
Definition: Pattern.h:167
typename SourceOp::Adaptor OpAdaptor
Definition: Pattern.h:137
virtual LogicalResult match(SourceOp op) const
Rewrite and Match methods that operate on the SourceOp type.
Definition: Pattern.h:164
LogicalResult match(Operation *op) const final
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
Definition: Pattern.h:151
Base class for operation conversions targeting the LLVM IR dialect.
Definition: Pattern.h:37
ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context, LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition: Pattern.cpp:22
Type getVoidType() const
Gets the MLIR type wrapping the LLVM void type.
Definition: Pattern.cpp:46
MemRefDescriptor createMemRefDescriptor(Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr, ArrayRef< Value > sizes, ArrayRef< Value > strides, ConversionPatternRewriter &rewriter) const
Creates and populates a canonical memref descriptor struct.
Definition: Pattern.cpp:194
LLVMTypeConverter * getTypeConverter() const
Definition: Pattern.cpp:28
Type getIndexType() const
Gets the MLIR type wrapping the LLVM integer type whose bit width is defined by the used type converter.
Definition: Pattern.cpp:37
Value getNumElements(Location loc, ArrayRef< Value > shape, ConversionPatternRewriter &rewriter) const
Computes total number of elements for the given shape.
Definition: Pattern.cpp:182
Value getStridedElementPtr(Location loc, MemRefType type, Value memRefDesc, ValueRange indices, ConversionPatternRewriter &rewriter) const
Definition: Pattern.cpp:68
void getMemRefDescriptorSizes(Location loc, MemRefType memRefType, ValueRange dynamicSizes, ConversionPatternRewriter &rewriter, SmallVectorImpl< Value > &sizes, SmallVectorImpl< Value > &strides, Value &sizeBytes) const
Computes sizes, strides and buffer size in bytes of memRefType with identity layout.
Definition: Pattern.cpp:116
LLVM::LLVMDialect & getDialect() const
Returns the LLVM dialect.
Definition: Pattern.cpp:33
Value getSizeInBytes(Location loc, Type type, ConversionPatternRewriter &rewriter) const
Computes the size of type in bytes.
Definition: Pattern.cpp:167
Type getIntPtrType(unsigned addressSpace=0) const
Gets the MLIR type wrapping the LLVM integer type whose bit width corresponds to that of a LLVM pointer type.
Definition: Pattern.cpp:41
Value createIndexConstant(ConversionPatternRewriter &builder, Location loc, uint64_t value) const
Create an LLVM dialect operation defining the given index constant.
Definition: Pattern.cpp:63
LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc, TypeRange origTypes, SmallVectorImpl< Value > &operands, bool toDynamic) const
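Copies the memory descriptor for any operands that were unranked descriptors originally to heap-allocated memory (if toDynamic is true) or to stack-allocated memory (otherwise). Also frees the previously used memory (that is assumed to be heap-allocated) if toDynamic is false.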
Definition: Pattern.cpp:222
Type getElementPtrType(MemRefType type) const
Returns the type of a pointer to an element of the memref.
Definition: Pattern.cpp:109
static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value)
Create a constant Op producing a value of resultType from an index-typed integer attribute.
Definition: Pattern.cpp:55
bool isConvertibleAndHasIdentityMaps(MemRefType type) const
Returns true if the given memref has identity maps and the element type is convertible to LLVM.
Definition: Pattern.cpp:102
Type getVoidPtrType() const
Get the MLIR type wrapping the LLVM i8* type.
Definition: Pattern.cpp:50
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:64
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:56
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descriptor.
Definition: MemRefBuilder.h:33
Generic implementation of one-to-one conversion from "SourceOp" to "TargetOp" where the latter belongs to the LLVM dialect or an equivalent.
Definition: Pattern.h:191
LogicalResult matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override
Converts the type of the result to an LLVM type, pass operands as is, preserve attributes.
Definition: Pattern.h:199
This class helps build Operations.
Definition: Builders.h:199
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:75
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:33
MLIRContext * getContext() const
Return the MLIRContext used to create this pattern.
Definition: PatternMatch.h:133
virtual LogicalResult match(Operation *op) const
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:350
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:93
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
Definition: Pattern.cpp:305
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26