MLIR  18.0.0git
Pattern.h
Go to the documentation of this file.
1 //===- Pattern.h - Pattern for conversion to the LLVM dialect ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_CONVERSION_LLVMCOMMON_PATTERN_H
10 #define MLIR_CONVERSION_LLVMCOMMON_PATTERN_H
11 
15 
16 namespace mlir {
17 class CallOpInterface;
18 
19 namespace LLVM {
20 namespace detail {
21 /// Replaces the given operation "op" with a new operation of type "targetOp"
22 /// built from the given operands and carrying the attributes in "targetAttrs".
/// OneToOneConvertToLLVMPattern calls this with op->getAttrs() as "targetAttrs"
/// (i.e. attributes are preserved) and its own type converter; presumably
/// "typeConverter" is used to convert the result types of "op" to LLVM types
/// (NOTE(review): confirm in Pattern.cpp). The LogicalResult is returned
/// directly from that pattern's matchAndRewrite.
23 LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp,
24  ValueRange operands,
25  ArrayRef<NamedAttribute> targetAttrs,
26  const LLVMTypeConverter &typeConverter,
27  ConversionPatternRewriter &rewriter);
28 
29 } // namespace detail
30 } // namespace LLVM
31 
32 /// Base class for operation conversions targeting the LLVM IR dialect. It
33 /// provides the conversion patterns with access to the LLVMTypeConverter and
34 /// the LowerToLLVMOptions. The class captures the LLVMTypeConverter and the
35 /// LowerToLLVMOptions by reference, meaning the references have to remain alive
36 /// during the entire pattern lifetime.
38 public:
39  ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context,
41  PatternBenefit benefit = 1);
42 
43 protected:
44  /// Returns the LLVM dialect.
45  LLVM::LLVMDialect &getDialect() const;
46 
    /// Returns the LLVMTypeConverter this pattern was constructed with. It is
    /// held by reference, so it must outlive the pattern.
47  const LLVMTypeConverter *getTypeConverter() const;
48 
49  /// Gets the MLIR type wrapping the LLVM integer type whose bit width is
50  /// defined by the used type converter.
51  Type getIndexType() const;
52 
53  /// Gets the MLIR type wrapping the LLVM integer type whose bit width
54  /// corresponds to that of an LLVM pointer type in address space
    /// `addressSpace` (address space 0 by default).
55  Type getIntPtrType(unsigned addressSpace = 0) const;
56 
57  /// Gets the MLIR type wrapping the LLVM void type.
58  Type getVoidType() const;
59 
60  /// Gets the MLIR type wrapping the LLVM i8* type.
    /// NOTE(review): the IR examples below use the opaque `!llvm.ptr`; presumably
    /// this now returns that generic pointer type - confirm in Pattern.cpp.
61  Type getVoidPtrType() const;
62 
63  /// Creates a constant Op producing a value of `resultType` from an index-typed
64  /// integer attribute holding `value`.
65  static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
66  Type resultType, int64_t value);
67 
68  /// This is a strided getElementPtr variant that linearizes subscripts as:
69  /// `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
70  Value getStridedElementPtr(Location loc, MemRefType type, Value memRefDesc,
71  ValueRange indices,
72  ConversionPatternRewriter &rewriter) const;
73 
74  /// Returns true if the given memref has identity maps and the element type is
75  /// convertible to LLVM.
76  bool isConvertibleAndHasIdentityMaps(MemRefType type) const;
77 
78  /// Returns the type of a pointer to an element of the memref.
79  Type getElementPtrType(MemRefType type) const;
80 
81  /// Computes sizes, strides and buffer size of `memRefType` with identity
82  /// layout. Emits constant ops for the static sizes of `memRefType`, and uses
83  /// `dynamicSizes` for the others. Emits instructions to compute strides and
84  /// buffer size from these sizes.
    /// The results are returned through the out-parameters `sizes`, `strides`
    /// and `size`. Because the layout is the row-major identity, `strides[i]`
    /// is the product of `sizes[i+1]` ... `sizes[rank-1]`, and the innermost
    /// stride is 1.
85  ///
86  /// For example, memref<4x?xf32> with `sizeInBytes = true` emits:
87  /// `sizes[0]` = llvm.mlir.constant(4 : index) : i64
88  /// `sizes[1]` = `dynamicSizes[0]`
89  /// `strides[1]` = llvm.mlir.constant(1 : index) : i64
90  /// `strides[0]` = `sizes[1]`
91  /// %size = llvm.mul `sizes[0]`, `sizes[1]` : i64
92  /// %nullptr = llvm.mlir.zero : !llvm.ptr
93  /// %gep = llvm.getelementptr %nullptr[%size]
94  /// : (!llvm.ptr, i64) -> !llvm.ptr, f32
95  /// `sizeBytes` = llvm.ptrtoint %gep : !llvm.ptr to i64
96  ///
97  /// If `sizeInBytes = false`, memref<4x?xf32> emits:
98  /// `sizes[0]` = llvm.mlir.constant(4 : index) : i64
99  /// `sizes[1]` = `dynamicSizes[0]`
100  /// `strides[1]` = llvm.mlir.constant(1 : index) : i64
101  /// `strides[0]` = `sizes[1]`
102  /// %size = llvm.mul `sizes[0]`, `sizes[1]` : i64
103  void getMemRefDescriptorSizes(Location loc, MemRefType memRefType,
104  ValueRange dynamicSizes,
105  ConversionPatternRewriter &rewriter,
106  SmallVectorImpl<Value> &sizes,
107  SmallVectorImpl<Value> &strides, Value &size,
108  bool sizeInBytes = true) const;
109 
110  /// Emits IR computing the size of `type` in bytes and returns that value.
111  Value getSizeInBytes(Location loc, Type type,
112  ConversionPatternRewriter &rewriter) const;
113 
114  /// Emits IR computing the total number of elements of `memRefType`, using
    /// `dynamicSizes` for its dynamic dimensions, and returns that value.
115  Value getNumElements(Location loc, MemRefType memRefType,
116  ValueRange dynamicSizes,
117  ConversionPatternRewriter &rewriter) const;
118 
119  /// Creates and populates a canonical memref descriptor struct from the given
    /// allocated/aligned pointers, sizes and strides.
121  createMemRefDescriptor(Location loc, MemRefType memRefType,
122  Value allocatedPtr, Value alignedPtr,
123  ArrayRef<Value> sizes, ArrayRef<Value> strides,
124  ConversionPatternRewriter &rewriter) const;
125 
126  /// Copies the memory descriptor for any operands that were unranked
127  /// descriptors originally to heap-allocated memory (if toDynamic is true) or
128  /// to stack-allocated memory (otherwise). Also frees the previously used
129  /// memory (that is assumed to be heap-allocated) if toDynamic is false.
    /// `origTypes` are the pre-conversion operand types used to spot unranked
    /// memrefs; presumably `operands` is updated in place with the copies
    /// (NOTE(review): confirm in Pattern.cpp).
131  TypeRange origTypes,
132  SmallVectorImpl<Value> &operands,
133  bool toDynamic) const;
134 };
135 
136 /// Utility class for operation conversions targeting the LLVM dialect that
137 /// match exactly one source operation.
138 template <typename SourceOp>
140 public:
141  using OpAdaptor = typename SourceOp::Adaptor;
142 
144  PatternBenefit benefit = 1)
145  : ConvertToLLVMPattern(SourceOp::getOperationName(),
147  benefit) {}
148 
149  /// Wrappers around the RewritePattern methods that pass the derived op type.
    /// Each one casts `op` to SourceOp (the pattern is rooted at
    /// SourceOp::getOperationName(), so the cast is expected to succeed) and,
    /// where converted operands are supplied, wraps them in an OpAdaptor before
    /// forwarding to the typed overloads below.
150  void rewrite(Operation *op, ArrayRef<Value> operands,
151  ConversionPatternRewriter &rewriter) const final {
152  rewrite(cast<SourceOp>(op), OpAdaptor(operands, cast<SourceOp>(op)),
153  rewriter);
154  }
155  LogicalResult match(Operation *op) const final {
156  return match(cast<SourceOp>(op));
157  }
160  ConversionPatternRewriter &rewriter) const final {
161  return matchAndRewrite(cast<SourceOp>(op),
162  OpAdaptor(operands, cast<SourceOp>(op)), rewriter);
163  }
164 
165  /// Rewrite and Match methods that operate on the SourceOp type. These must be
166  /// overridden by the derived pattern class.
    /// Override either `matchAndRewrite`, or both `match` and `rewrite`. The
    /// default `match`/`rewrite` abort via llvm_unreachable. The default
    /// `matchAndRewrite` calls `match` and, on success, calls `rewrite`.
167  virtual LogicalResult match(SourceOp op) const {
168  llvm_unreachable("must override match or matchAndRewrite");
169  }
170  virtual void rewrite(SourceOp op, OpAdaptor adaptor,
171  ConversionPatternRewriter &rewriter) const {
172  llvm_unreachable("must override rewrite or matchAndRewrite");
173  }
174  virtual LogicalResult
175  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
176  ConversionPatternRewriter &rewriter) const {
177  if (failed(match(op)))
178  return failure();
179  rewrite(op, adaptor, rewriter);
180  return success();
181  }
182 
183 private:
186 };
187 
188 /// Generic implementation of one-to-one conversion from "SourceOp" to
189 /// "TargetOp" where the latter belongs to the LLVM dialect or an equivalent.
190 /// Upholds a convention that multi-result operations get converted into an
191 /// operation returning the LLVM IR structure type, in which case individual
192 /// values must be extracted from it using LLVM::ExtractValueOp before being
/// used.
193 template <typename SourceOp, typename TargetOp>
195 public:
198 
199  /// Converts the result types to LLVM types, passes the adaptor's (converted)
200  /// operands through as is, and preserves all attributes of `op`, by
    /// delegating to LLVM::detail::oneToOneRewrite with this pattern's type
    /// converter.
202  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
203  ConversionPatternRewriter &rewriter) const override {
204  return LLVM::detail::oneToOneRewrite(op, TargetOp::getOperationName(),
205  adaptor.getOperands(), op->getAttrs(),
206  *this->getTypeConverter(), rewriter);
207  }
208 };
209 
210 } // namespace mlir
211 
212 #endif // MLIR_CONVERSION_LLVMCOMMON_PATTERN_H
This class implements a pattern rewriter for use with ConversionPatterns.
Base class for the conversion patterns.
const TypeConverter * typeConverter
An optional type converter for use by this pattern.
virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement combined matching and rewriting.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
Definition: Pattern.h:139
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition: Pattern.h:143
LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Hook for derived classes to implement combined matching and rewriting.
Definition: Pattern.h:159
virtual LogicalResult matchAndRewrite(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const
Definition: Pattern.h:175
void rewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Wrappers around the RewritePattern methods that pass the derived op type.
Definition: Pattern.h:150
virtual void rewrite(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const
Definition: Pattern.h:170
typename SourceOp::Adaptor OpAdaptor
Definition: Pattern.h:141
virtual LogicalResult match(SourceOp op) const
Rewrite and Match methods that operate on the SourceOp type.
Definition: Pattern.h:167
LogicalResult match(Operation *op) const final
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
Definition: Pattern.h:155
Base class for operation conversions targeting the LLVM IR dialect.
Definition: Pattern.h:37
Type getVoidType() const
Gets the MLIR type wrapping the LLVM void type.
Definition: Pattern.cpp:45
MemRefDescriptor createMemRefDescriptor(Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr, ArrayRef< Value > sizes, ArrayRef< Value > strides, ConversionPatternRewriter &rewriter) const
Creates and populates a canonical memref descriptor struct.
Definition: Pattern.cpp:218
ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context, const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition: Pattern.cpp:22
void getMemRefDescriptorSizes(Location loc, MemRefType memRefType, ValueRange dynamicSizes, ConversionPatternRewriter &rewriter, SmallVectorImpl< Value > &sizes, SmallVectorImpl< Value > &strides, Value &size, bool sizeInBytes=true) const
Computes sizes, strides and buffer size of memRefType with identity layout.
Definition: Pattern.cpp:114
Type getIndexType() const
Gets the MLIR type wrapping the LLVM integer type whose bit width is defined by the used type converter.
Definition: Pattern.cpp:36
const LLVMTypeConverter * getTypeConverter() const
Definition: Pattern.cpp:27
Value getStridedElementPtr(Location loc, MemRefType type, Value memRefDesc, ValueRange indices, ConversionPatternRewriter &rewriter) const
Definition: Pattern.cpp:61
Value getNumElements(Location loc, MemRefType memRefType, ValueRange dynamicSizes, ConversionPatternRewriter &rewriter) const
Computes total number of elements for the given MemRef and dynamicSizes.
Definition: Pattern.cpp:186
LLVM::LLVMDialect & getDialect() const
Returns the LLVM dialect.
Definition: Pattern.cpp:32
Value getSizeInBytes(Location loc, Type type, ConversionPatternRewriter &rewriter) const
Computes the size of type in bytes.
Definition: Pattern.cpp:171
Type getIntPtrType(unsigned addressSpace=0) const
Gets the MLIR type wrapping the LLVM integer type whose bit width corresponds to that of an LLVM pointer type.
Definition: Pattern.cpp:40
LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc, TypeRange origTypes, SmallVectorImpl< Value > &operands, bool toDynamic) const
Copies the memory descriptor for any operands that were unranked descriptors originally to heap-allocated memory (if toDynamic is true) or to stack-allocated memory (otherwise).
Definition: Pattern.cpp:247
Type getElementPtrType(MemRefType type) const
Returns the type of a pointer to an element of the memref.
Definition: Pattern.cpp:107
static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value)
Creates a constant Op producing a value of resultType from an index-typed integer attribute.
Definition: Pattern.cpp:53
bool isConvertibleAndHasIdentityMaps(MemRefType type) const
Returns true if the given memref has identity maps and the element type is convertible to LLVM.
Definition: Pattern.cpp:100
Type getVoidPtrType() const
Gets the MLIR type wrapping the LLVM i8* type.
Definition: Pattern.cpp:49
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:33
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descriptor.
Definition: MemRefBuilder.h:33
Generic implementation of one-to-one conversion from "SourceOp" to "TargetOp" where the latter belongs to the LLVM dialect or an equivalent.
Definition: Pattern.h:194
LogicalResult matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override
Converts the type of the result to an LLVM type, passes operands as is, and preserves attributes.
Definition: Pattern.h:202
This class helps build Operations.
Definition: Builders.h:206
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:33
MLIRContext * getContext() const
Return the MLIRContext used to create this pattern.
Definition: PatternMatch.h:133
virtual LogicalResult match(Operation *op) const
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:378
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
Definition: Pattern.cpp:335
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26