MLIR  15.0.0git
Pattern.h
Go to the documentation of this file.
1 //===- Pattern.h - Pattern for conversion to the LLVM dialect ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#ifndef MLIR_CONVERSION_LLVMCOMMON_PATTERN_H
#define MLIR_CONVERSION_LLVMCOMMON_PATTERN_H

#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Transforms/DialectConversion.h"

16 namespace mlir {
17 
18 namespace LLVM {
19 namespace detail {
20 /// Replaces the given operation "op" with a new operation of type "targetOp"
21 /// and given operands.
22 LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp,
23  ValueRange operands,
24  LLVMTypeConverter &typeConverter,
25  ConversionPatternRewriter &rewriter);
26 } // namespace detail
27 } // namespace LLVM
28 
29 /// Base class for operation conversions targeting the LLVM IR dialect. It
30 /// provides the conversion patterns with access to the LLVMTypeConverter and
31 /// the LowerToLLVMOptions. The class captures the LLVMTypeConverter and the
32 /// LowerToLLVMOptions by reference meaning the references have to remain alive
33 /// during the entire pattern lifetime.
35 public:
36  ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context,
37  LLVMTypeConverter &typeConverter,
38  PatternBenefit benefit = 1);
39 
40 protected:
41  /// Returns the LLVM dialect.
42  LLVM::LLVMDialect &getDialect() const;
43 
44  LLVMTypeConverter *getTypeConverter() const;
45 
46  /// Gets the MLIR type wrapping the LLVM integer type whose bit width is
47  /// defined by the used type converter.
48  Type getIndexType() const;
49 
50  /// Gets the MLIR type wrapping the LLVM integer type whose bit width
51  /// corresponds to that of a LLVM pointer type.
52  Type getIntPtrType(unsigned addressSpace = 0) const;
53 
54  /// Gets the MLIR type wrapping the LLVM void type.
55  Type getVoidType() const;
56 
57  /// Get the MLIR type wrapping the LLVM i8* type.
58  Type getVoidPtrType() const;
59 
60  /// Create a constant Op producing a value of `resultType` from an index-typed
61  /// integer attribute.
62  static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
63  Type resultType, int64_t value);
64 
65  /// Create an LLVM dialect operation defining the given index constant.
66  Value createIndexConstant(ConversionPatternRewriter &builder, Location loc,
67  uint64_t value) const;
68 
69  // This is a strided getElementPtr variant that linearizes subscripts as:
70  // `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
71  Value getStridedElementPtr(Location loc, MemRefType type, Value memRefDesc,
72  ValueRange indices,
73  ConversionPatternRewriter &rewriter) const;
74 
75  /// Returns if the given memref has identity maps and the element type is
76  /// convertible to LLVM.
77  bool isConvertibleAndHasIdentityMaps(MemRefType type) const;
78 
79  /// Returns the type of a pointer to an element of the memref.
80  Type getElementPtrType(MemRefType type) const;
81 
82  /// Computes sizes, strides and buffer size in bytes of `memRefType` with
83  /// identity layout. Emits constant ops for the static sizes of `memRefType`,
84  /// and uses `dynamicSizes` for the others. Emits instructions to compute
85  /// strides and buffer size from these sizes.
86  ///
87  /// For example, memref<4x?xf32> emits:
88  /// `sizes[0]` = llvm.mlir.constant(4 : index) : i64
89  /// `sizes[1]` = `dynamicSizes[0]`
90  /// `strides[1]` = llvm.mlir.constant(1 : index) : i64
91  /// `strides[0]` = `sizes[0]`
92  /// %size = llvm.mul `sizes[0]`, `sizes[1]` : i64
93  /// %nullptr = llvm.mlir.null : !llvm.ptr<f32>
94  /// %gep = llvm.getelementptr %nullptr[%size]
95  /// : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
96  /// `sizeBytes` = llvm.ptrtoint %gep : !llvm.ptr<f32> to i64
97  void getMemRefDescriptorSizes(Location loc, MemRefType memRefType,
98  ValueRange dynamicSizes,
99  ConversionPatternRewriter &rewriter,
100  SmallVectorImpl<Value> &sizes,
101  SmallVectorImpl<Value> &strides,
102  Value &sizeBytes) const;
103 
104  /// Computes the size of type in bytes.
105  Value getSizeInBytes(Location loc, Type type,
106  ConversionPatternRewriter &rewriter) const;
107 
108  /// Computes total number of elements for the given shape.
110  ConversionPatternRewriter &rewriter) const;
111 
112  /// Creates and populates a canonical memref descriptor struct.
114  createMemRefDescriptor(Location loc, MemRefType memRefType,
115  Value allocatedPtr, Value alignedPtr,
116  ArrayRef<Value> sizes, ArrayRef<Value> strides,
117  ConversionPatternRewriter &rewriter) const;
118 
119  /// Copies the memory descriptor for any operands that were unranked
120  /// descriptors originally to heap-allocated memory (if toDynamic is true) or
121  /// to stack-allocated memory (otherwise). Also frees the previously used
122  /// memory (that is assumed to be heap-allocated) if toDynamic is false.
123  LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc,
124  TypeRange origTypes,
125  SmallVectorImpl<Value> &operands,
126  bool toDynamic) const;
127 };
128 
129 /// Utility class for operation conversions targeting the LLVM dialect that
130 /// match exactly one source operation.
131 template <typename SourceOp>
133 public:
134  using OpAdaptor = typename SourceOp::Adaptor;
135 
136  explicit ConvertOpToLLVMPattern(LLVMTypeConverter &typeConverter,
137  PatternBenefit benefit = 1)
138  : ConvertToLLVMPattern(SourceOp::getOperationName(),
139  &typeConverter.getContext(), typeConverter,
140  benefit) {}
141 
142  /// Wrappers around the RewritePattern methods that pass the derived op type.
143  void rewrite(Operation *op, ArrayRef<Value> operands,
144  ConversionPatternRewriter &rewriter) const final {
145  rewrite(cast<SourceOp>(op), OpAdaptor(operands, op->getAttrDictionary()),
146  rewriter);
147  }
148  LogicalResult match(Operation *op) const final {
149  return match(cast<SourceOp>(op));
150  }
153  ConversionPatternRewriter &rewriter) const final {
154  return matchAndRewrite(cast<SourceOp>(op),
155  OpAdaptor(operands, op->getAttrDictionary()),
156  rewriter);
157  }
158 
159  /// Rewrite and Match methods that operate on the SourceOp type. These must be
160  /// overridden by the derived pattern class.
161  virtual LogicalResult match(SourceOp op) const {
162  llvm_unreachable("must override match or matchAndRewrite");
163  }
164  virtual void rewrite(SourceOp op, OpAdaptor adaptor,
165  ConversionPatternRewriter &rewriter) const {
166  llvm_unreachable("must override rewrite or matchAndRewrite");
167  }
168  virtual LogicalResult
169  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
170  ConversionPatternRewriter &rewriter) const {
171  if (failed(match(op)))
172  return failure();
173  rewrite(op, adaptor, rewriter);
174  return success();
175  }
176 
177 private:
180 };
181 
182 /// Generic implementation of one-to-one conversion from "SourceOp" to
183 /// "TargetOp" where the latter belongs to the LLVM dialect or an equivalent.
184 /// Upholds a convention that multi-result operations get converted into an
185 /// operation returning the LLVM IR structure type, in which case individual
186 /// values must be extracted from using LLVM::ExtractValueOp before being used.
187 template <typename SourceOp, typename TargetOp>
189 public:
192 
193  /// Converts the type of the result to an LLVM type, pass operands as is,
194  /// preserve attributes.
196  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
197  ConversionPatternRewriter &rewriter) const override {
198  return LLVM::detail::oneToOneRewrite(op, TargetOp::getOperationName(),
199  adaptor.getOperands(),
200  *this->getTypeConverter(), rewriter);
201  }
202 };
203 
204 } // namespace mlir
205 
206 #endif // MLIR_CONVERSION_LLVMCOMMON_PATTERN_H
Include the generated interface declarations.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
Definition: Pattern.h:132
void rewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Wrappers around the RewritePattern methods that pass the derived op type.
Definition: Pattern.h:143
virtual LogicalResult match(SourceOp op) const
Rewrite and Match methods that operate on the SourceOp type.
Definition: Pattern.h:161
typename SourceOp::Adaptor OpAdaptor
Definition: Pattern.h:134
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
Generic implementation of one-to-one conversion from "SourceOp" to "TargetOp" where the latter belongs to the LLVM dialect or an equivalent.
Definition: Pattern.h:188
ConvertOpToLLVMPattern(LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition: Pattern.h:136
Base class for the conversion patterns.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value)
LogicalResult matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override
Converts the type of the result to an LLVM type, pass operands as is, preserve attributes.
Definition: Pattern.h:196
virtual LogicalResult match(Operation *op) const
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
static constexpr const bool value
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
virtual void rewrite(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const
Definition: Pattern.h:164
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Definition: MemRefBuilder.h:33
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:32
static llvm::Value * getSizeInBytes(llvm::IRBuilderBase &builder, llvm::Value *basePtr)
Computes the size of type in bytes.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:38
static void rewrite(SCCPAnalysis &analysis, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
Definition: SCCP.cpp:195
static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc)
Definition: SPIRVOps.cpp:972
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
virtual LogicalResult matchAndRewrite(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const
Definition: Pattern.h:169
virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement combined matching and rewriting.
static int64_t getNumElements(ShapedType type)
Definition: TensorOps.cpp:751
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:30
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
This class implements a pattern rewriter for use with ConversionPatterns.
LogicalResult match(Operation *op) const final
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
Definition: Pattern.h:148
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
Definition: Pattern.cpp:310
Base class for operation conversions targeting the LLVM IR dialect.
Definition: Pattern.h:34
LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Hook for derived classes to implement combined matching and rewriting.
Definition: Pattern.h:152
This class helps build Operations.
Definition: Builders.h:177
This class provides an abstraction over the different types of ranges over Values.