MLIR  16.0.0git
Pattern.cpp
Go to the documentation of this file.
1 //===- Pattern.cpp - Conversion pattern to the LLVM dialect ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
13 #include "mlir/IR/AffineMap.h"
15 
16 using namespace mlir;
17 
18 //===----------------------------------------------------------------------===//
19 // ConvertToLLVMPattern
20 //===----------------------------------------------------------------------===//
21 
23  MLIRContext *context,
24  LLVMTypeConverter &typeConverter,
25  PatternBenefit benefit)
26  : ConversionPattern(typeConverter, rootOpName, benefit, context) {}
27 
29  return static_cast<LLVMTypeConverter *>(
31 }
32 
/// Returns the LLVM dialect by dereferencing the dialect pointer held by this
/// pattern's LLVMTypeConverter. No null check is performed here, so the type
/// converter is expected to always carry a dialect.
33 LLVM::LLVMDialect &ConvertToLLVMPattern::getDialect() const {
34  return *getTypeConverter()->getDialect();
35 }
36 
38  return getTypeConverter()->getIndexType();
39 }
40 
/// Returns an integer type whose bit width equals the pointer bit width that
/// the LLVMTypeConverter reports for `addressSpace`. The type is created in the
/// type converter's MLIRContext.
41 Type ConvertToLLVMPattern::getIntPtrType(unsigned addressSpace) const {
42  return IntegerType::get(&getTypeConverter()->getContext(),
43  getTypeConverter()->getPointerBitwidth(addressSpace));
44 }
45 
47  return LLVM::LLVMVoidType::get(&getTypeConverter()->getContext());
48 }
49 
51  return LLVM::LLVMPointerType::get(
52  IntegerType::get(&getTypeConverter()->getContext(), 8));
53 }
54 
56  Location loc,
57  Type resultType,
58  int64_t value) {
59  return builder.create<LLVM::ConstantOp>(loc, resultType,
60  builder.getIndexAttr(value));
61 }
62 
64  ConversionPatternRewriter &builder, Location loc, uint64_t value) const {
65  return createIndexAttrConstant(builder, loc, getIndexType(), value);
66 }
67 
69  Location loc, MemRefType type, Value memRefDesc, ValueRange indices,
70  ConversionPatternRewriter &rewriter) const {
71 
72  int64_t offset;
74  auto successStrides = getStridesAndOffset(type, strides, offset);
75  assert(succeeded(successStrides) && "unexpected non-strided memref");
76  (void)successStrides;
77 
78  MemRefDescriptor memRefDescriptor(memRefDesc);
79  Value base = memRefDescriptor.alignedPtr(rewriter, loc);
80 
81  Value index;
82  if (offset != 0) // Skip if offset is zero.
83  index = ShapedType::isDynamic(offset)
84  ? memRefDescriptor.offset(rewriter, loc)
85  : createIndexConstant(rewriter, loc, offset);
86 
87  for (int i = 0, e = indices.size(); i < e; ++i) {
88  Value increment = indices[i];
89  if (strides[i] != 1) { // Skip if stride is 1.
90  Value stride = ShapedType::isDynamic(strides[i])
91  ? memRefDescriptor.stride(rewriter, loc, i)
92  : createIndexConstant(rewriter, loc, strides[i]);
93  increment = rewriter.create<LLVM::MulOp>(loc, increment, stride);
94  }
95  index =
96  index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment;
97  }
98 
99  Type elementPtrType = memRefDescriptor.getElementPtrType();
100  return index ? rewriter.create<LLVM::GEPOp>(loc, elementPtrType, base, index)
101  : base;
102 }
103 
104 // Check if the MemRefType `type` is supported by the lowering. We currently
105 // only support memrefs with identity maps.
107  MemRefType type) const {
108  if (!typeConverter->convertType(type.getElementType()))
109  return false;
110  return type.getLayout().isIdentity();
111 }
112 
114  auto elementType = type.getElementType();
115  auto structElementType = typeConverter->convertType(elementType);
116  return LLVM::LLVMPointerType::get(structElementType,
117  type.getMemorySpaceAsInt());
118 }
119 
121  Location loc, MemRefType memRefType, ValueRange dynamicSizes,
123  SmallVectorImpl<Value> &strides, Value &sizeBytes) const {
124  assert(isConvertibleAndHasIdentityMaps(memRefType) &&
125  "layout maps must have been normalized away");
126  assert(count(memRefType.getShape(), ShapedType::kDynamic) ==
127  static_cast<ssize_t>(dynamicSizes.size()) &&
128  "dynamicSizes size doesn't match dynamic sizes count in memref shape");
129 
130  sizes.reserve(memRefType.getRank());
131  unsigned dynamicIndex = 0;
132  for (int64_t size : memRefType.getShape()) {
133  sizes.push_back(size == ShapedType::kDynamic
134  ? dynamicSizes[dynamicIndex++]
135  : createIndexConstant(rewriter, loc, size));
136  }
137 
138  // Strides: iterate sizes in reverse order and multiply.
139  int64_t stride = 1;
140  Value runningStride = createIndexConstant(rewriter, loc, 1);
141  strides.resize(memRefType.getRank());
142  for (auto i = memRefType.getRank(); i-- > 0;) {
143  strides[i] = runningStride;
144 
145  int64_t size = memRefType.getShape()[i];
146  if (size == 0)
147  continue;
148  bool useSizeAsStride = stride == 1;
149  if (size == ShapedType::kDynamic)
150  stride = ShapedType::kDynamic;
151  if (stride != ShapedType::kDynamic)
152  stride *= size;
153 
154  if (useSizeAsStride)
155  runningStride = sizes[i];
156  else if (stride == ShapedType::kDynamic)
157  runningStride =
158  rewriter.create<LLVM::MulOp>(loc, runningStride, sizes[i]);
159  else
160  runningStride = createIndexConstant(rewriter, loc, stride);
161  }
162 
163  // Buffer size in bytes.
164  Type elementPtrType = getElementPtrType(memRefType);
165  Value nullPtr = rewriter.create<LLVM::NullOp>(loc, elementPtrType);
166  Value gepPtr =
167  rewriter.create<LLVM::GEPOp>(loc, elementPtrType, nullPtr, runningStride);
168  sizeBytes = rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
169 }
170 
172  Location loc, Type type, ConversionPatternRewriter &rewriter) const {
173  // Compute the size of an individual element. This emits the MLIR equivalent
174  // of the following sizeof(...) implementation in LLVM IR:
175  // %0 = getelementptr %elementType* null, %indexType 1
176  // %1 = ptrtoint %elementType* %0 to %indexType
177  // which is a common pattern of getting the size of a type in bytes.
178  auto convertedPtrType =
179  LLVM::LLVMPointerType::get(typeConverter->convertType(type));
180  auto nullPtr = rewriter.create<LLVM::NullOp>(loc, convertedPtrType);
181  auto gep = rewriter.create<LLVM::GEPOp>(loc, convertedPtrType, nullPtr,
183  return rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep);
184 }
185 
187  Location loc, ArrayRef<Value> shape,
188  ConversionPatternRewriter &rewriter) const {
189  // Compute the total number of memref elements.
190  Value numElements =
191  shape.empty() ? createIndexConstant(rewriter, loc, 1) : shape.front();
192  for (unsigned i = 1, e = shape.size(); i < e; ++i)
193  numElements = rewriter.create<LLVM::MulOp>(loc, numElements, shape[i]);
194  return numElements;
195 }
196 
197 /// Creates and populates the memref descriptor struct given all its fields.
199  Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr,
200  ArrayRef<Value> sizes, ArrayRef<Value> strides,
201  ConversionPatternRewriter &rewriter) const {
202  auto structType = typeConverter->convertType(memRefType);
203  auto memRefDescriptor = MemRefDescriptor::undef(rewriter, loc, structType);
204 
205  // Field 1: Allocated pointer, used for malloc/free.
206  memRefDescriptor.setAllocatedPtr(rewriter, loc, allocatedPtr);
207 
208  // Field 2: Actual aligned pointer to payload.
209  memRefDescriptor.setAlignedPtr(rewriter, loc, alignedPtr);
210 
211  // Field 3: Offset in aligned pointer.
212  memRefDescriptor.setOffset(rewriter, loc,
213  createIndexConstant(rewriter, loc, 0));
214 
215  // Fields 4: Sizes.
216  for (const auto &en : llvm::enumerate(sizes))
217  memRefDescriptor.setSize(rewriter, loc, en.index(), en.value());
218 
219  // Field 5: Strides.
220  for (const auto &en : llvm::enumerate(strides))
221  memRefDescriptor.setStride(rewriter, loc, en.index(), en.value());
222 
223  return memRefDescriptor;
224 }
225 
227  OpBuilder &builder, Location loc, TypeRange origTypes,
228  SmallVectorImpl<Value> &operands, bool toDynamic) const {
229  assert(origTypes.size() == operands.size() &&
230  "expected as may original types as operands");
231 
232  // Find operands of unranked memref type and store them.
234  for (unsigned i = 0, e = operands.size(); i < e; ++i)
235  if (origTypes[i].isa<UnrankedMemRefType>())
236  unrankedMemrefs.emplace_back(operands[i]);
237 
238  if (unrankedMemrefs.empty())
239  return success();
240 
241  // Compute allocation sizes.
242  SmallVector<Value, 4> sizes;
244  unrankedMemrefs, sizes);
245 
246  // Get frequently used types.
247  MLIRContext *context = builder.getContext();
248  Type voidPtrType = LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
249  auto i1Type = IntegerType::get(context, 1);
250  Type indexType = getTypeConverter()->getIndexType();
251 
252  // Find the malloc and free, or declare them if necessary.
253  auto module = builder.getInsertionPoint()->getParentOfType<ModuleOp>();
254  LLVM::LLVMFuncOp freeFunc, mallocFunc;
255  if (toDynamic)
256  mallocFunc = LLVM::lookupOrCreateMallocFn(module, indexType);
257  if (!toDynamic)
258  freeFunc = LLVM::lookupOrCreateFreeFn(module);
259 
260  // Initialize shared constants.
261  Value zero =
262  builder.create<LLVM::ConstantOp>(loc, i1Type, builder.getBoolAttr(false));
263 
264  unsigned unrankedMemrefPos = 0;
265  for (unsigned i = 0, e = operands.size(); i < e; ++i) {
266  Type type = origTypes[i];
267  if (!type.isa<UnrankedMemRefType>())
268  continue;
269  Value allocationSize = sizes[unrankedMemrefPos++];
270  UnrankedMemRefDescriptor desc(operands[i]);
271 
272  // Allocate memory, copy, and free the source if necessary.
273  Value memory =
274  toDynamic
275  ? builder.create<LLVM::CallOp>(loc, mallocFunc, allocationSize)
276  .getResult()
277  : builder.create<LLVM::AllocaOp>(loc, voidPtrType, allocationSize,
278  /*alignment=*/0);
279  Value source = desc.memRefDescPtr(builder, loc);
280  builder.create<LLVM::MemcpyOp>(loc, memory, source, allocationSize, zero);
281  if (!toDynamic)
282  builder.create<LLVM::CallOp>(loc, freeFunc, source);
283 
284  // Create a new descriptor. The same descriptor can be returned multiple
285  // times, attempting to modify its pointer can lead to memory leaks
286  // (allocated twice and overwritten) or double frees (the caller does not
287  // know if the descriptor points to the same memory).
288  Type descriptorType = getTypeConverter()->convertType(type);
289  if (!descriptorType)
290  return failure();
291  auto updatedDesc =
292  UnrankedMemRefDescriptor::undef(builder, loc, descriptorType);
293  Value rank = desc.rank(builder, loc);
294  updatedDesc.setRank(builder, loc, rank);
295  updatedDesc.setMemRefDescPtr(builder, loc, memory);
296 
297  operands[i] = updatedDesc;
298  }
299 
300  return success();
301 }
302 
303 //===----------------------------------------------------------------------===//
304 // Detail methods
305 //===----------------------------------------------------------------------===//
306 
307 /// Replaces the given operation "op" with a new operation of type "targetOp"
308 /// and given operands.
310  Operation *op, StringRef targetOp, ValueRange operands,
311  ArrayRef<NamedAttribute> targetAttrs, LLVMTypeConverter &typeConverter,
312  ConversionPatternRewriter &rewriter) {
313  unsigned numResults = op->getNumResults();
314 
315  SmallVector<Type> resultTypes;
316  if (numResults != 0) {
317  resultTypes.push_back(
318  typeConverter.packFunctionResults(op->getResultTypes()));
319  if (!resultTypes.back())
320  return failure();
321  }
322 
323  // Create the operation through state since we don't know its C++ type.
324  Operation *newOp =
325  rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
326  resultTypes, targetAttrs);
327 
328  // If the operation produced 0 or 1 result, return them immediately.
329  if (numResults == 0)
330  return rewriter.eraseOp(op), success();
331  if (numResults == 1)
332  return rewriter.replaceOp(op, newOp->getResult(0)), success();
333 
334  // Otherwise, it had been converted to an operation producing a structure.
335  // Extract individual results from the structure and return them as list.
336  SmallVector<Value, 4> results;
337  results.reserve(numResults);
338  for (unsigned i = 0; i < numResults; ++i) {
339  results.push_back(rewriter.create<LLVM::ExtractValueOp>(
340  op->getLoc(), newOp->getResult(0), i));
341  }
342  rewriter.replaceOp(op, results);
343  return success();
344 }
static constexpr const bool value
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:109
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:101
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:243
MLIRContext * getContext() const
Definition: Builders.h:54
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing the results of an operation.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
Base class for the conversion patterns.
TypeConverter * typeConverter
An optional type converter for use by this pattern.
TypeConverter * getTypeConverter() const
Return the type converter held by this pattern, or nullptr if the pattern does not require type conversion.
ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context, LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition: Pattern.cpp:22
Type getVoidType() const
Gets the MLIR type wrapping the LLVM void type.
Definition: Pattern.cpp:46
MemRefDescriptor createMemRefDescriptor(Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr, ArrayRef< Value > sizes, ArrayRef< Value > strides, ConversionPatternRewriter &rewriter) const
Creates and populates a canonical memref descriptor struct.
Definition: Pattern.cpp:198
LLVMTypeConverter * getTypeConverter() const
Definition: Pattern.cpp:28
Type getIndexType() const
Gets the MLIR type wrapping the LLVM integer type whose bit width is defined by the type converter in use.
Definition: Pattern.cpp:37
Value getNumElements(Location loc, ArrayRef< Value > shape, ConversionPatternRewriter &rewriter) const
Computes total number of elements for the given shape.
Definition: Pattern.cpp:186
Value getStridedElementPtr(Location loc, MemRefType type, Value memRefDesc, ValueRange indices, ConversionPatternRewriter &rewriter) const
Definition: Pattern.cpp:68
void getMemRefDescriptorSizes(Location loc, MemRefType memRefType, ValueRange dynamicSizes, ConversionPatternRewriter &rewriter, SmallVectorImpl< Value > &sizes, SmallVectorImpl< Value > &strides, Value &sizeBytes) const
Computes sizes, strides and buffer size in bytes of memRefType with identity layout.
Definition: Pattern.cpp:120
LLVM::LLVMDialect & getDialect() const
Returns the LLVM dialect.
Definition: Pattern.cpp:33
Value getSizeInBytes(Location loc, Type type, ConversionPatternRewriter &rewriter) const
Computes the size of type in bytes.
Definition: Pattern.cpp:171
Type getIntPtrType(unsigned addressSpace=0) const
Gets the MLIR type wrapping the LLVM integer type whose bit width corresponds to that of an LLVM pointer in the given address space.
Definition: Pattern.cpp:41
Value createIndexConstant(ConversionPatternRewriter &builder, Location loc, uint64_t value) const
Create an LLVM dialect operation defining the given index constant.
Definition: Pattern.cpp:63
LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc, TypeRange origTypes, SmallVectorImpl< Value > &operands, bool toDynamic) const
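Copies the memory descriptor for any operands that were originally unranked memref descriptors into newly allocated memory. This is heap memory obtained via malloc when toDynamic is set, or stack memory (alloca) otherwise, in which case the source descriptor memory is freed.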
Definition: Pattern.cpp:226
Type getElementPtrType(MemRefType type) const
Returns the type of a pointer to an element of the memref.
Definition: Pattern.cpp:113
static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value)
Create a constant Op producing a value of resultType from an index-typed integer attribute.
Definition: Pattern.cpp:55
bool isConvertibleAndHasIdentityMaps(MemRefType type) const
Returns true if the given memref has an identity layout map and its element type is convertible to LLVM.
Definition: Pattern.cpp:106
Type getVoidPtrType() const
Get the MLIR type wrapping the LLVM i8* type.
Definition: Pattern.cpp:50
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:30
Type packFunctionResults(TypeRange types)
Convert a non-empty list of types to be returned from a function into a supported LLVM IR type.
LLVM::LLVMDialect * getDialect()
Returns the LLVM dialect.
Definition: TypeConverter.h:79
Type getIndexType()
Gets the LLVM representation of the index type.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results)
Convert the given type.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:64
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:56
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descriptor.
Definition: MemRefBuilder.h:33
Value offset(OpBuilder &builder, Location loc)
Builds IR extracting the offset from the descriptor.
static MemRefDescriptor undef(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating an undef value of the descriptor type.
Value stride(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th stride from the descriptor.
LLVM::LLVMPointerType getElementPtrType()
Returns the (LLVM) pointer type this descriptor contains.
Value alignedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the aligned pointer from the descriptor.
This class helps build Operations.
Definition: Builders.h:198
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:397
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:422
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:31
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:324
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:154
result_type_range getResultTypes()
Definition: Operation.h:345
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:321
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:32
MLIRContext * getContext() const
Return the MLIRContext used to create this pattern.
Definition: PatternMatch.h:132
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results)
Convert the given type.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isa() const
Definition: Types.h:260
Value rank(OpBuilder &builder, Location loc)
Builds IR extracting the rank from the descriptor.
static UnrankedMemRefDescriptor undef(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating an undef value of the descriptor type.
Value memRefDescPtr(OpBuilder &builder, Location loc)
Builds IR extracting ranked memref descriptor ptr.
static void computeSizes(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, ArrayRef< UnrankedMemRefDescriptor > values, SmallVectorImpl< Value > &sizes)
Builds IR computing the sizes in bytes (suitable for opaque allocation) and appends the corresponding values to `sizes`.
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:349
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
Definition: Pattern.cpp:309
LLVM::LLVMFuncOp lookupOrCreateFreeFn(ModuleOp moduleOp)
LLVM::LLVMFuncOp lookupOrCreateMallocFn(ModuleOp moduleOp, Type indexType)
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:230
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26