MLIR  19.0.0git
Pattern.h
Go to the documentation of this file.
1 //===- Pattern.h - Pattern for conversion to the LLVM dialect ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_CONVERSION_LLVMCOMMON_PATTERN_H
10 #define MLIR_CONVERSION_LLVMCOMMON_PATTERN_H
11 
16 
17 namespace mlir {
18 class CallOpInterface;
19 
20 namespace LLVM {
21 namespace detail {
22 /// Handle generically setting flags as native properties on LLVM operations.
23 void setNativeProperties(Operation *op, IntegerOverflowFlags overflowFlags);
24 
25 /// Replaces the given operation "op" with a new operation of type "targetOp"
26 /// and given operands.
28  Operation *op, StringRef targetOp, ValueRange operands,
29  ArrayRef<NamedAttribute> targetAttrs,
30  const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
31  IntegerOverflowFlags overflowFlags = IntegerOverflowFlags::none);
32 
33 } // namespace detail
34 } // namespace LLVM
35 
36 /// Base class for operation conversions targeting the LLVM IR dialect. It
37 /// provides the conversion patterns with access to the LLVMTypeConverter and
38 /// the LowerToLLVMOptions. The class captures the LLVMTypeConverter and the
39 /// LowerToLLVMOptions by reference meaning the references have to remain alive
40 /// during the entire pattern lifetime.
42 public:
43  ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context,
45  PatternBenefit benefit = 1);
46 
47 protected:
48  /// Returns the LLVM dialect.
49  LLVM::LLVMDialect &getDialect() const;
50 
51  const LLVMTypeConverter *getTypeConverter() const;
52 
53  /// Gets the MLIR type wrapping the LLVM integer type whose bit width is
54  /// defined by the used type converter.
55  Type getIndexType() const;
56 
57  /// Gets the MLIR type wrapping the LLVM integer type whose bit width
58  /// corresponds to that of a LLVM pointer type.
59  Type getIntPtrType(unsigned addressSpace = 0) const;
60 
61  /// Gets the MLIR type wrapping the LLVM void type.
62  Type getVoidType() const;
63 
64  /// Get the MLIR type wrapping the LLVM i8* type.
65  Type getVoidPtrType() const;
66 
67  /// Create a constant Op producing a value of `resultType` from an index-typed
68  /// integer attribute.
69  static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
70  Type resultType, int64_t value);
71 
72  /// This is a strided getElementPtr variant that linearizes subscripts as:
73  /// `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
74  Value getStridedElementPtr(Location loc, MemRefType type, Value memRefDesc,
75  ValueRange indices,
76  ConversionPatternRewriter &rewriter) const;
77 
78  /// Returns whether the given memref has identity maps and the element type
79  /// is convertible to LLVM.
80  bool isConvertibleAndHasIdentityMaps(MemRefType type) const;
81 
82  /// Returns the type of a pointer to an element of the memref.
83  Type getElementPtrType(MemRefType type) const;
84 
85  /// Computes sizes, strides and buffer size of `memRefType` with identity
86  /// layout. Emits constant ops for the static sizes of `memRefType`, and uses
87  /// `dynamicSizes` for the others. Emits instructions to compute strides and
88  /// buffer size from these sizes.
89  ///
90  /// For example, memref<4x?xf32> with `sizeInBytes = true` emits:
91  /// `sizes[0]` = llvm.mlir.constant(4 : index) : i64
92  /// `sizes[1]` = `dynamicSizes[0]`
93  /// `strides[1]` = llvm.mlir.constant(1 : index) : i64
94  /// `strides[0]` = `sizes[0]`
95  /// %size = llvm.mul `sizes[0]`, `sizes[1]` : i64
96  /// %nullptr = llvm.mlir.zero : !llvm.ptr
97  /// %gep = llvm.getelementptr %nullptr[%size]
98  /// : (!llvm.ptr, i64) -> !llvm.ptr, f32
99  /// `sizeBytes` = llvm.ptrtoint %gep : !llvm.ptr to i64
100  ///
101  /// If `sizeInBytes = false`, memref<4x?xf32> emits:
102  /// `sizes[0]` = llvm.mlir.constant(4 : index) : i64
103  /// `sizes[1]` = `dynamicSizes[0]`
104  /// `strides[1]` = llvm.mlir.constant(1 : index) : i64
105  /// `strides[0]` = `sizes[0]`
106  /// %size = llvm.mul `sizes[0]`, `sizes[1]` : i64
107  void getMemRefDescriptorSizes(Location loc, MemRefType memRefType,
108  ValueRange dynamicSizes,
109  ConversionPatternRewriter &rewriter,
110  SmallVectorImpl<Value> &sizes,
111  SmallVectorImpl<Value> &strides, Value &size,
112  bool sizeInBytes = true) const;
113 
114  /// Computes the size of type in bytes.
115  Value getSizeInBytes(Location loc, Type type,
116  ConversionPatternRewriter &rewriter) const;
117 
118  /// Computes total number of elements for the given MemRef and dynamicSizes.
119  Value getNumElements(Location loc, MemRefType memRefType,
120  ValueRange dynamicSizes,
121  ConversionPatternRewriter &rewriter) const;
122 
123  /// Creates and populates a canonical memref descriptor struct.
125  createMemRefDescriptor(Location loc, MemRefType memRefType,
126  Value allocatedPtr, Value alignedPtr,
127  ArrayRef<Value> sizes, ArrayRef<Value> strides,
128  ConversionPatternRewriter &rewriter) const;
129 
130  /// Copies the memory descriptor for any operands that were unranked
131  /// descriptors originally to heap-allocated memory (if toDynamic is true) or
132  /// to stack-allocated memory (otherwise). Also frees the previously used
133  /// memory (that is assumed to be heap-allocated) if toDynamic is false.
135  TypeRange origTypes,
136  SmallVectorImpl<Value> &operands,
137  bool toDynamic) const;
138 };
139 
140 /// Utility class for operation conversions targeting the LLVM dialect that
141 /// match exactly one source operation.
142 template <typename SourceOp>
144 public:
145  using OpAdaptor = typename SourceOp::Adaptor;
146 
148  PatternBenefit benefit = 1)
149  : ConvertToLLVMPattern(SourceOp::getOperationName(),
151  benefit) {}
152 
153  /// Wrappers around the RewritePattern methods that pass the derived op type.
     /// This overload casts `op` to SourceOp, builds an OpAdaptor from `operands`
     /// and that op, and forwards to the typed rewrite(SourceOp, OpAdaptor, ...)
     /// overload. It is marked final; derived patterns customize the typed
     /// overload instead.
154  void rewrite(Operation *op, ArrayRef<Value> operands,
155  ConversionPatternRewriter &rewriter) const final {
156  rewrite(cast<SourceOp>(op), OpAdaptor(operands, cast<SourceOp>(op)),
157  rewriter);
158  }
159  LogicalResult match(Operation *op) const final {
     // Cast the generic operation to SourceOp and defer to the typed match().
160  return match(cast<SourceOp>(op));
161  }
164  ConversionPatternRewriter &rewriter) const final {
165  return matchAndRewrite(cast<SourceOp>(op),
166  OpAdaptor(operands, cast<SourceOp>(op)), rewriter);
167  }
168 
169  /// Rewrite and Match methods that operate on the SourceOp type. A derived
170  /// pattern class must override either matchAndRewrite, or both match and
     /// rewrite: the defaults of match and rewrite are llvm_unreachable, and the
     /// default matchAndRewrite calls match and, on success, rewrite.
171  virtual LogicalResult match(SourceOp op) const {
172  llvm_unreachable("must override match or matchAndRewrite");
173  }
174  virtual void rewrite(SourceOp op, OpAdaptor adaptor,
175  ConversionPatternRewriter &rewriter) const {
     // Default body is marked unreachable: a derived pattern is expected to
     // override either this method or matchAndRewrite.
176  llvm_unreachable("must override rewrite or matchAndRewrite");
177  }
178  virtual LogicalResult
179  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
180  ConversionPatternRewriter &rewriter) const {
     // Default: run the typed match() first; bail out with failure() without
     // touching the IR if it does not match.
181  if (failed(match(op)))
182  return failure();
     // The match succeeded, so perform the typed rewrite and report success.
183  rewrite(op, adaptor, rewriter);
184  return success();
185  }
186 
187 private:
190 };
191 
192 /// Generic implementation of one-to-one conversion from "SourceOp" to
193 /// "TargetOp" where the latter belongs to the LLVM dialect or an equivalent.
194 /// Upholds a convention that multi-result operations get converted into an
195 /// operation returning the LLVM IR structure type, in which case individual
196 /// values must be extracted from using LLVM::ExtractValueOp before being used.
197 template <typename SourceOp, typename TargetOp>
199 public:
202 
203  /// Converts the type of the result to an LLVM type, pass operands as is,
204  /// preserve attributes.
206  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
207  ConversionPatternRewriter &rewriter) const override {
208  return LLVM::detail::oneToOneRewrite(op, TargetOp::getOperationName(),
209  adaptor.getOperands(), op->getAttrs(),
210  *this->getTypeConverter(), rewriter);
211  }
212 };
213 
214 } // namespace mlir
215 
216 #endif // MLIR_CONVERSION_LLVMCOMMON_PATTERN_H
This class implements a pattern rewriter for use with ConversionPatterns.
Base class for the conversion patterns.
const TypeConverter * typeConverter
An optional type converter for use by this pattern.
virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement combined matching and rewriting.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:143
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition: Pattern.h:147
LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Hook for derived classes to implement combined matching and rewriting.
Definition: Pattern.h:163
virtual LogicalResult matchAndRewrite(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const
Definition: Pattern.h:179
void rewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Wrappers around the RewritePattern methods that pass the derived op type.
Definition: Pattern.h:154
virtual void rewrite(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const
Definition: Pattern.h:174
typename SourceOp::Adaptor OpAdaptor
Definition: Pattern.h:145
virtual LogicalResult match(SourceOp op) const
Rewrite and Match methods that operate on the SourceOp type.
Definition: Pattern.h:171
LogicalResult match(Operation *op) const final
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
Definition: Pattern.h:159
Base class for operation conversions targeting the LLVM IR dialect.
Definition: Pattern.h:41
Type getVoidType() const
Gets the MLIR type wrapping the LLVM void type.
Definition: Pattern.cpp:45
MemRefDescriptor createMemRefDescriptor(Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr, ArrayRef< Value > sizes, ArrayRef< Value > strides, ConversionPatternRewriter &rewriter) const
Creates and populates a canonical memref descriptor struct.
Definition: Pattern.cpp:218
ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context, const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition: Pattern.cpp:22
void getMemRefDescriptorSizes(Location loc, MemRefType memRefType, ValueRange dynamicSizes, ConversionPatternRewriter &rewriter, SmallVectorImpl< Value > &sizes, SmallVectorImpl< Value > &strides, Value &size, bool sizeInBytes=true) const
Computes sizes, strides and buffer size of memRefType with identity layout.
Definition: Pattern.cpp:114
Type getIndexType() const
Gets the MLIR type wrapping the LLVM integer type whose bit width is defined by the used type convert...
Definition: Pattern.cpp:36
const LLVMTypeConverter * getTypeConverter() const
Definition: Pattern.cpp:27
Value getStridedElementPtr(Location loc, MemRefType type, Value memRefDesc, ValueRange indices, ConversionPatternRewriter &rewriter) const
Definition: Pattern.cpp:61
Value getNumElements(Location loc, MemRefType memRefType, ValueRange dynamicSizes, ConversionPatternRewriter &rewriter) const
Computes total number of elements for the given MemRef and dynamicSizes.
Definition: Pattern.cpp:186
LLVM::LLVMDialect & getDialect() const
Returns the LLVM dialect.
Definition: Pattern.cpp:32
Value getSizeInBytes(Location loc, Type type, ConversionPatternRewriter &rewriter) const
Computes the size of type in bytes.
Definition: Pattern.cpp:171
Type getIntPtrType(unsigned addressSpace=0) const
Gets the MLIR type wrapping the LLVM integer type whose bit width corresponds to that of a LLVM point...
Definition: Pattern.cpp:40
LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc, TypeRange origTypes, SmallVectorImpl< Value > &operands, bool toDynamic) const
Copies the memory descriptor for any operands that were unranked descriptors originally to heap-alloc...
Definition: Pattern.cpp:247
Type getElementPtrType(MemRefType type) const
Returns the type of a pointer to an element of the memref.
Definition: Pattern.cpp:107
static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value)
Create a constant Op producing a value of resultType from an index-typed integer attribute.
Definition: Pattern.cpp:53
bool isConvertibleAndHasIdentityMaps(MemRefType type) const
Returns whether the given memref has identity maps and the element type is convertible to LLVM.
Definition: Pattern.cpp:100
Type getVoidPtrType() const
Get the MLIR type wrapping the LLVM i8* type.
Definition: Pattern.cpp:49
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:34
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Definition: MemRefBuilder.h:33
Generic implementation of one-to-one conversion from "SourceOp" to "TargetOp" where the latter belong...
Definition: Pattern.h:198
LogicalResult matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override
Converts the type of the result to an LLVM type, pass operands as is, preserve attributes.
Definition: Pattern.h:206
This class helps build Operations.
Definition: Builders.h:209
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
MLIRContext * getContext() const
Return the MLIRContext used to create this pattern.
Definition: PatternMatch.h:134
virtual LogicalResult match(Operation *op) const
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, IntegerOverflowFlags overflowFlags=IntegerOverflowFlags::none)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
Definition: Pattern.cpp:340
void setNativeProperties(Operation *op, IntegerOverflowFlags overflowFlags)
Handle generically setting flags as native properties on LLVM operations.
Definition: Pattern.cpp:332
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26