MLIR  16.0.0git
Pattern.h
Go to the documentation of this file.
1 //===- Pattern.h - Pattern for conversion to the LLVM dialect ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_CONVERSION_LLVMCOMMON_PATTERN_H
10 #define MLIR_CONVERSION_LLVMCOMMON_PATTERN_H
11 
15 
16 namespace mlir {
17 class CallOpInterface;
18 
19 namespace LLVM {
20 namespace detail {
21 /// Replaces the given operation "op" with a new operation of type "targetOp"
22 /// and given operands.
23 LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp,
24  ValueRange operands,
25  LLVMTypeConverter &typeConverter,
26  ConversionPatternRewriter &rewriter);
27 } // namespace detail
28 } // namespace LLVM
29 
30 /// Base class for operation conversions targeting the LLVM IR dialect. It
31 /// provides the conversion patterns with access to the LLVMTypeConverter and
32 /// the LowerToLLVMOptions. The class captures the LLVMTypeConverter and the
33 /// LowerToLLVMOptions by reference meaning the references have to remain alive
34 /// during the entire pattern lifetime.
36 public:
37  ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context,
38  LLVMTypeConverter &typeConverter,
39  PatternBenefit benefit = 1);
40 
41 protected:
42  /// Returns the LLVM dialect.
43  LLVM::LLVMDialect &getDialect() const;
44 
45  LLVMTypeConverter *getTypeConverter() const;
46 
47  /// Gets the MLIR type wrapping the LLVM integer type whose bit width is
48  /// defined by the used type converter.
49  Type getIndexType() const;
50 
51  /// Gets the MLIR type wrapping the LLVM integer type whose bit width
52  /// corresponds to that of a LLVM pointer type.
53  Type getIntPtrType(unsigned addressSpace = 0) const;
54 
55  /// Gets the MLIR type wrapping the LLVM void type.
56  Type getVoidType() const;
57 
58  /// Gets the MLIR type wrapping the LLVM i8* type.
59  Type getVoidPtrType() const;
60 
61  /// Create a constant Op producing a value of `resultType` from an index-typed
62  /// integer attribute.
63  static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
64  Type resultType, int64_t value);
65 
66  /// Create an LLVM dialect operation defining the given index constant.
67  Value createIndexConstant(ConversionPatternRewriter &builder, Location loc,
68  uint64_t value) const;
69 
70  /// This is a strided getElementPtr variant that linearizes subscripts as:
71  /// `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
72  Value getStridedElementPtr(Location loc, MemRefType type, Value memRefDesc,
73  ValueRange indices,
74  ConversionPatternRewriter &rewriter) const;
75 
76  /// Returns true if the given memref has identity maps and its element type
77  /// is convertible to LLVM.
78  bool isConvertibleAndHasIdentityMaps(MemRefType type) const;
79 
80  /// Returns the type of a pointer to an element of the memref.
81  Type getElementPtrType(MemRefType type) const;
82 
83  /// Computes sizes, strides and buffer size in bytes of `memRefType` with
84  /// identity layout. Emits constant ops for the static sizes of `memRefType`,
85  /// and uses `dynamicSizes` for the others. Emits instructions to compute
86  /// strides and buffer size from these sizes.
87  ///
88  /// For example, memref<4x?xf32> emits:
89  /// `sizes[0]` = llvm.mlir.constant(4 : index) : i64
90  /// `sizes[1]` = `dynamicSizes[0]`
91  /// `strides[1]` = llvm.mlir.constant(1 : index) : i64
92  /// `strides[0]` = `sizes[1]`
93  /// %size = llvm.mul `sizes[0]`, `sizes[1]` : i64
94  /// %nullptr = llvm.mlir.null : !llvm.ptr<f32>
95  /// %gep = llvm.getelementptr %nullptr[%size]
96  /// : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
97  /// `sizeBytes` = llvm.ptrtoint %gep : !llvm.ptr<f32> to i64
98  void getMemRefDescriptorSizes(Location loc, MemRefType memRefType,
99  ValueRange dynamicSizes,
100  ConversionPatternRewriter &rewriter,
101  SmallVectorImpl<Value> &sizes,
102  SmallVectorImpl<Value> &strides,
103  Value &sizeBytes) const;
104 
105  /// Computes the size of type in bytes.
106  Value getSizeInBytes(Location loc, Type type,
107  ConversionPatternRewriter &rewriter) const;
108 
109  /// Computes total number of elements for the given shape.
111  ConversionPatternRewriter &rewriter) const;
112 
113  /// Creates and populates a canonical memref descriptor struct.
115  createMemRefDescriptor(Location loc, MemRefType memRefType,
116  Value allocatedPtr, Value alignedPtr,
117  ArrayRef<Value> sizes, ArrayRef<Value> strides,
118  ConversionPatternRewriter &rewriter) const;
119 
120  /// Copies the memory descriptor for any operands that were unranked
121  /// descriptors originally to heap-allocated memory (if toDynamic is true) or
122  /// to stack-allocated memory (otherwise). Also frees the previously used
123  /// memory (that is assumed to be heap-allocated) if toDynamic is false.
124  LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc,
125  TypeRange origTypes,
126  SmallVectorImpl<Value> &operands,
127  bool toDynamic) const;
128 };
129 
130 /// Utility class for operation conversions targeting the LLVM dialect that
131 /// match exactly one source operation.
132 template <typename SourceOp>
134 public:
135  using OpAdaptor = typename SourceOp::Adaptor;
136 
137  explicit ConvertOpToLLVMPattern(LLVMTypeConverter &typeConverter,
138  PatternBenefit benefit = 1)
139  : ConvertToLLVMPattern(SourceOp::getOperationName(),
140  &typeConverter.getContext(), typeConverter,
141  benefit) {}
142 
143  /// Wrappers around the RewritePattern methods that pass the derived op type.
144  void rewrite(Operation *op, ArrayRef<Value> operands,
145  ConversionPatternRewriter &rewriter) const final {
146  rewrite(cast<SourceOp>(op), OpAdaptor(operands, op->getAttrDictionary()),
147  rewriter);
148  }
149  LogicalResult match(Operation *op) const final {
150  return match(cast<SourceOp>(op));
151  }
154  ConversionPatternRewriter &rewriter) const final {
155  return matchAndRewrite(cast<SourceOp>(op),
156  OpAdaptor(operands, op->getAttrDictionary()),
157  rewriter);
158  }
159 
160  /// Rewrite and Match methods that operate on the SourceOp type. A derived
161  /// pattern must override either matchAndRewrite, or both match and rewrite.
162  virtual LogicalResult match(SourceOp op) const {
163  llvm_unreachable("must override match or matchAndRewrite");
164  }
165  virtual void rewrite(SourceOp op, OpAdaptor adaptor,
166  ConversionPatternRewriter &rewriter) const {
167  llvm_unreachable("must override rewrite or matchAndRewrite");
168  }
169  virtual LogicalResult
170  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
171  ConversionPatternRewriter &rewriter) const {
172  if (failed(match(op)))
173  return failure();
174  rewrite(op, adaptor, rewriter);
175  return success();
176  }
177 
178 private:
181 };
182 
183 /// Generic implementation of one-to-one conversion from "SourceOp" to
184 /// "TargetOp" where the latter belongs to the LLVM dialect or an equivalent.
185 /// Upholds a convention that multi-result operations get converted into an
186 /// operation returning the LLVM IR structure type, in which case individual
187 /// values must be extracted from it using LLVM::ExtractValueOp before use.
188 template <typename SourceOp, typename TargetOp>
190 public:
193 
194  /// Converts the type of the result to an LLVM type, passes operands as is,
195  /// and preserves attributes.
197  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
198  ConversionPatternRewriter &rewriter) const override {
199  return LLVM::detail::oneToOneRewrite(op, TargetOp::getOperationName(),
200  adaptor.getOperands(),
201  *this->getTypeConverter(), rewriter);
202  }
203 };
204 
205 } // namespace mlir
206 
207 #endif // MLIR_CONVERSION_LLVMCOMMON_PATTERN_H
Include the generated interface declarations.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
Definition: Pattern.h:133
void rewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Wrappers around the RewritePattern methods that pass the derived op type.
Definition: Pattern.h:144
virtual LogicalResult match(SourceOp op) const
Rewrite and Match methods that operate on the SourceOp type.
Definition: Pattern.h:162
typename cf::AssertOp ::Adaptor OpAdaptor
Definition: Pattern.h:135
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
Generic implementation of one-to-one conversion from "SourceOp" to "TargetOp" where the latter belongs to the LLVM dialect or an equivalent.
Definition: Pattern.h:189
ConvertOpToLLVMPattern(LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition: Pattern.h:137
Base class for the conversion patterns.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value)
LogicalResult matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override
Converts the type of the result to an LLVM type, pass operands as is, preserve attributes.
Definition: Pattern.h:197
virtual LogicalResult match(Operation *op) const
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
static constexpr const bool value
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
virtual void rewrite(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const
Definition: Pattern.h:165
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Definition: MemRefBuilder.h:33
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:32
static llvm::Value * getSizeInBytes(llvm::IRBuilderBase &builder, llvm::Value *basePtr)
Computes the size of type in bytes.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:32
static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc)
Definition: SPIRVOps.cpp:983
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
Definition: SCCP.cpp:67
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
virtual LogicalResult matchAndRewrite(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const
Definition: Pattern.h:170
virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement combined matching and rewriting.
static int64_t getNumElements(ShapedType type)
Definition: TensorOps.cpp:854
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:30
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:56
This class implements a pattern rewriter for use with ConversionPatterns.
LogicalResult match(Operation *op) const final
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
Definition: Pattern.h:149
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
Definition: Pattern.cpp:309
Base class for operation conversions targeting the LLVM IR dialect.
Definition: Pattern.h:35
LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Hook for derived classes to implement combined matching and rewriting.
Definition: Pattern.h:153
This class helps build Operations.
Definition: Builders.h:196
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:345