MLIR  15.0.0git
Pattern.cpp
Go to the documentation of this file.
1 //===- Pattern.cpp - Conversion pattern to the LLVM dialect ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
13 #include "mlir/IR/AffineMap.h"
15 
16 using namespace mlir;
17 
18 //===----------------------------------------------------------------------===//
19 // ConvertToLLVMPattern
20 //===----------------------------------------------------------------------===//
21 
23  MLIRContext *context,
24  LLVMTypeConverter &typeConverter,
25  PatternBenefit benefit)
26  : ConversionPattern(typeConverter, rootOpName, benefit, context) {}
27 
29  return static_cast<LLVMTypeConverter *>(
31 }
32 
33 LLVM::LLVMDialect &ConvertToLLVMPattern::getDialect() const {
34  return *getTypeConverter()->getDialect();
35 }
36 
38  return getTypeConverter()->getIndexType();
39 }
40 
41 Type ConvertToLLVMPattern::getIntPtrType(unsigned addressSpace) const {
42  return IntegerType::get(&getTypeConverter()->getContext(),
43  getTypeConverter()->getPointerBitwidth(addressSpace));
44 }
45 
47  return LLVM::LLVMVoidType::get(&getTypeConverter()->getContext());
48 }
49 
52  IntegerType::get(&getTypeConverter()->getContext(), 8));
53 }
54 
56  Location loc,
57  Type resultType,
58  int64_t value) {
59  return builder.create<LLVM::ConstantOp>(
60  loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value));
61 }
62 
64  ConversionPatternRewriter &builder, Location loc, uint64_t value) const {
65  return createIndexAttrConstant(builder, loc, getIndexType(), value);
66 }
67 
69  Location loc, MemRefType type, Value memRefDesc, ValueRange indices,
70  ConversionPatternRewriter &rewriter) const {
71 
72  int64_t offset;
74  auto successStrides = getStridesAndOffset(type, strides, offset);
75  assert(succeeded(successStrides) && "unexpected non-strided memref");
76  (void)successStrides;
77 
78  MemRefDescriptor memRefDescriptor(memRefDesc);
79  Value base = memRefDescriptor.alignedPtr(rewriter, loc);
80 
81  Value index;
82  if (offset != 0) // Skip if offset is zero.
83  index = ShapedType::isDynamicStrideOrOffset(offset)
84  ? memRefDescriptor.offset(rewriter, loc)
85  : createIndexConstant(rewriter, loc, offset);
86 
87  for (int i = 0, e = indices.size(); i < e; ++i) {
88  Value increment = indices[i];
89  if (strides[i] != 1) { // Skip if stride is 1.
90  Value stride = ShapedType::isDynamicStrideOrOffset(strides[i])
91  ? memRefDescriptor.stride(rewriter, loc, i)
92  : createIndexConstant(rewriter, loc, strides[i]);
93  increment = rewriter.create<LLVM::MulOp>(loc, increment, stride);
94  }
95  index =
96  index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment;
97  }
98 
99  Type elementPtrType = memRefDescriptor.getElementPtrType();
100  return index ? rewriter.create<LLVM::GEPOp>(loc, elementPtrType, base, index)
101  : base;
102 }
103 
104 // Check if the MemRefType `type` is supported by the lowering. We currently
105 // only support memrefs with identity maps.
107  MemRefType type) const {
108  if (!typeConverter->convertType(type.getElementType()))
109  return false;
110  return type.getLayout().isIdentity();
111 }
112 
114  auto elementType = type.getElementType();
115  auto structElementType = typeConverter->convertType(elementType);
116  return LLVM::LLVMPointerType::get(structElementType,
117  type.getMemorySpaceAsInt());
118 }
119 
121  Location loc, MemRefType memRefType, ValueRange dynamicSizes,
123  SmallVectorImpl<Value> &strides, Value &sizeBytes) const {
124  assert(isConvertibleAndHasIdentityMaps(memRefType) &&
125  "layout maps must have been normalized away");
126  assert(count(memRefType.getShape(), ShapedType::kDynamicSize) ==
127  static_cast<ssize_t>(dynamicSizes.size()) &&
128  "dynamicSizes size doesn't match dynamic sizes count in memref shape");
129 
130  sizes.reserve(memRefType.getRank());
131  unsigned dynamicIndex = 0;
132  for (int64_t size : memRefType.getShape()) {
133  sizes.push_back(size == ShapedType::kDynamicSize
134  ? dynamicSizes[dynamicIndex++]
135  : createIndexConstant(rewriter, loc, size));
136  }
137 
138  // Strides: iterate sizes in reverse order and multiply.
139  int64_t stride = 1;
140  Value runningStride = createIndexConstant(rewriter, loc, 1);
141  strides.resize(memRefType.getRank());
142  for (auto i = memRefType.getRank(); i-- > 0;) {
143  strides[i] = runningStride;
144 
145  int64_t size = memRefType.getShape()[i];
146  if (size == 0)
147  continue;
148  bool useSizeAsStride = stride == 1;
149  if (size == ShapedType::kDynamicSize)
150  stride = ShapedType::kDynamicSize;
151  if (stride != ShapedType::kDynamicSize)
152  stride *= size;
153 
154  if (useSizeAsStride)
155  runningStride = sizes[i];
156  else if (stride == ShapedType::kDynamicSize)
157  runningStride =
158  rewriter.create<LLVM::MulOp>(loc, runningStride, sizes[i]);
159  else
160  runningStride = createIndexConstant(rewriter, loc, stride);
161  }
162 
163  // Buffer size in bytes.
164  Type elementPtrType = getElementPtrType(memRefType);
165  Value nullPtr = rewriter.create<LLVM::NullOp>(loc, elementPtrType);
166  Value gepPtr = rewriter.create<LLVM::GEPOp>(loc, elementPtrType, nullPtr,
167  ArrayRef<Value>{runningStride});
168  sizeBytes = rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
169 }
170 
172  Location loc, Type type, ConversionPatternRewriter &rewriter) const {
173  // Compute the size of an individual element. This emits the MLIR equivalent
174  // of the following sizeof(...) implementation in LLVM IR:
175  // %0 = getelementptr %elementType* null, %indexType 1
176  // %1 = ptrtoint %elementType* %0 to %indexType
177  // which is a common pattern of getting the size of a type in bytes.
178  auto convertedPtrType =
180  auto nullPtr = rewriter.create<LLVM::NullOp>(loc, convertedPtrType);
181  auto gep = rewriter.create<LLVM::GEPOp>(
182  loc, convertedPtrType, nullPtr,
183  ArrayRef<Value>{createIndexConstant(rewriter, loc, 1)});
184  return rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep);
185 }
186 
188  Location loc, ArrayRef<Value> shape,
189  ConversionPatternRewriter &rewriter) const {
190  // Compute the total number of memref elements.
191  Value numElements =
192  shape.empty() ? createIndexConstant(rewriter, loc, 1) : shape.front();
193  for (unsigned i = 1, e = shape.size(); i < e; ++i)
194  numElements = rewriter.create<LLVM::MulOp>(loc, numElements, shape[i]);
195  return numElements;
196 }
197 
198 /// Creates and populates the memref descriptor struct given all its fields.
200  Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr,
201  ArrayRef<Value> sizes, ArrayRef<Value> strides,
202  ConversionPatternRewriter &rewriter) const {
203  auto structType = typeConverter->convertType(memRefType);
204  auto memRefDescriptor = MemRefDescriptor::undef(rewriter, loc, structType);
205 
206  // Field 1: Allocated pointer, used for malloc/free.
207  memRefDescriptor.setAllocatedPtr(rewriter, loc, allocatedPtr);
208 
209  // Field 2: Actual aligned pointer to payload.
210  memRefDescriptor.setAlignedPtr(rewriter, loc, alignedPtr);
211 
212  // Field 3: Offset in aligned pointer.
213  memRefDescriptor.setOffset(rewriter, loc,
214  createIndexConstant(rewriter, loc, 0));
215 
216  // Fields 4: Sizes.
217  for (const auto &en : llvm::enumerate(sizes))
218  memRefDescriptor.setSize(rewriter, loc, en.index(), en.value());
219 
220  // Field 5: Strides.
221  for (const auto &en : llvm::enumerate(strides))
222  memRefDescriptor.setStride(rewriter, loc, en.index(), en.value());
223 
224  return memRefDescriptor;
225 }
226 
228  OpBuilder &builder, Location loc, TypeRange origTypes,
229  SmallVectorImpl<Value> &operands, bool toDynamic) const {
230  assert(origTypes.size() == operands.size() &&
231  "expected as may original types as operands");
232 
233  // Find operands of unranked memref type and store them.
235  for (unsigned i = 0, e = operands.size(); i < e; ++i)
236  if (origTypes[i].isa<UnrankedMemRefType>())
237  unrankedMemrefs.emplace_back(operands[i]);
238 
239  if (unrankedMemrefs.empty())
240  return success();
241 
242  // Compute allocation sizes.
243  SmallVector<Value, 4> sizes;
245  unrankedMemrefs, sizes);
246 
247  // Get frequently used types.
248  MLIRContext *context = builder.getContext();
249  Type voidPtrType = LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
250  auto i1Type = IntegerType::get(context, 1);
251  Type indexType = getTypeConverter()->getIndexType();
252 
253  // Find the malloc and free, or declare them if necessary.
254  auto module = builder.getInsertionPoint()->getParentOfType<ModuleOp>();
255  LLVM::LLVMFuncOp freeFunc, mallocFunc;
256  if (toDynamic)
257  mallocFunc = LLVM::lookupOrCreateMallocFn(module, indexType);
258  if (!toDynamic)
259  freeFunc = LLVM::lookupOrCreateFreeFn(module);
260 
261  // Initialize shared constants.
262  Value zero =
263  builder.create<LLVM::ConstantOp>(loc, i1Type, builder.getBoolAttr(false));
264 
265  unsigned unrankedMemrefPos = 0;
266  for (unsigned i = 0, e = operands.size(); i < e; ++i) {
267  Type type = origTypes[i];
268  if (!type.isa<UnrankedMemRefType>())
269  continue;
270  Value allocationSize = sizes[unrankedMemrefPos++];
271  UnrankedMemRefDescriptor desc(operands[i]);
272 
273  // Allocate memory, copy, and free the source if necessary.
274  Value memory =
275  toDynamic
276  ? builder.create<LLVM::CallOp>(loc, mallocFunc, allocationSize)
277  .getResult(0)
278  : builder.create<LLVM::AllocaOp>(loc, voidPtrType, allocationSize,
279  /*alignment=*/0);
280  Value source = desc.memRefDescPtr(builder, loc);
281  builder.create<LLVM::MemcpyOp>(loc, memory, source, allocationSize, zero);
282  if (!toDynamic)
283  builder.create<LLVM::CallOp>(loc, freeFunc, source);
284 
285  // Create a new descriptor. The same descriptor can be returned multiple
286  // times, attempting to modify its pointer can lead to memory leaks
287  // (allocated twice and overwritten) or double frees (the caller does not
288  // know if the descriptor points to the same memory).
289  Type descriptorType = getTypeConverter()->convertType(type);
290  if (!descriptorType)
291  return failure();
292  auto updatedDesc =
293  UnrankedMemRefDescriptor::undef(builder, loc, descriptorType);
294  Value rank = desc.rank(builder, loc);
295  updatedDesc.setRank(builder, loc, rank);
296  updatedDesc.setMemRefDescPtr(builder, loc, memory);
297 
298  operands[i] = updatedDesc;
299  }
300 
301  return success();
302 }
303 
304 //===----------------------------------------------------------------------===//
305 // Detail methods
306 //===----------------------------------------------------------------------===//
307 
308 /// Replaces the given operation "op" with a new operation of type "targetOp"
309 /// and given operands.
311  Operation *op, StringRef targetOp, ValueRange operands,
313  unsigned numResults = op->getNumResults();
314 
315  Type packedType;
316  if (numResults != 0) {
317  packedType = typeConverter.packFunctionResults(op->getResultTypes());
318  if (!packedType)
319  return failure();
320  }
321 
322  // Create the operation through state since we don't know its C++ type.
323  Operation *newOp =
324  rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
325  packedType, op->getAttrs());
326 
327  // If the operation produced 0 or 1 result, return them immediately.
328  if (numResults == 0)
329  return rewriter.eraseOp(op), success();
330  if (numResults == 1)
331  return rewriter.replaceOp(op, newOp->getResult(0)), success();
332 
333  // Otherwise, it had been converted to an operation producing a structure.
334  // Extract individual results from the structure and return them as list.
335  SmallVector<Value, 4> results;
336  results.reserve(numResults);
337  for (unsigned i = 0; i < numResults; ++i) {
338  auto type = typeConverter.convertType(op->getResult(i).getType());
339  results.push_back(rewriter.create<LLVM::ExtractValueOp>(
340  op->getLoc(), type, newOp->getResult(0), rewriter.getI64ArrayAttr(i)));
341  }
342  rewriter.replaceOp(op, results);
343  return success();
344 }
TODO: Remove this file when SCCP and integer range analysis have been ported to the new framework...
MLIRContext * getContext() const
Definition: Builders.h:54
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
Type getElementPtrType(MemRefType type) const
Returns the type of a pointer to an element of the memref.
Definition: Pattern.cpp:113
TypeConverter * getTypeConverter() const
Return the type converter held by this pattern, or nullptr if the pattern does not require type conve...
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results)
Convert the given type.
Type getIndexType() const
Gets the MLIR type wrapping the LLVM integer type whose bit width is defined by the used type convert...
Definition: Pattern.cpp:37
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:363
Base class for the conversion patterns.
Value getStridedElementPtr(Location loc, MemRefType type, Value memRefDesc, ValueRange indices, ConversionPatternRewriter &rewriter) const
Definition: Pattern.cpp:68
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value...
Definition: LogicalResult.h:68
LLVM::LLVMDialect & getDialect() const
Returns the LLVM dialect.
Definition: Pattern.cpp:33
LLVM::LLVMPointerType getElementPtrType()
Returns the (LLVM) pointer type this descriptor contains.
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:220
Value offset(OpBuilder &builder, Location loc)
Builds IR extracting the offset from the descriptor.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing the results of an operation.
static constexpr const bool value
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:380
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:170
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Definition: MemRefBuilder.h:33
Value stride(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th stride from the descriptor.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:32
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:234
Value createIndexConstant(ConversionPatternRewriter &builder, Location loc, uint64_t value) const
Create an LLVM dialect operation defining the given index constant.
Definition: Pattern.cpp:63
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:331
static LLVMPointerType get(MLIRContext *context, unsigned addressSpace=0)
Gets or creates an instance of LLVM dialect pointer type pointing to an object of pointee type in the...
Definition: LLVMTypes.cpp:188
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:38
static MemRefDescriptor undef(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating an undef value of the descriptor type.
LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc, TypeRange origTypes, SmallVectorImpl< Value > &operands, bool toDynamic) const
Copies the memory descriptor for any operands that were unranked descriptors originally to heap-alloc...
Definition: Pattern.cpp:227
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:161
Value getNumElements(Location loc, ArrayRef< Value > shape, ConversionPatternRewriter &rewriter) const
Computes total number of elements for the given shape.
Definition: Pattern.cpp:187
Type getIntPtrType(unsigned addressSpace=0) const
Gets the MLIR type wrapping the LLVM integer type whose bit width corresponds to that of a LLVM point...
Definition: Pattern.cpp:41
Value getSizeInBytes(Location loc, Type type, ConversionPatternRewriter &rewriter) const
Computes the size of type in bytes.
Definition: Pattern.cpp:171
Type getVoidType() const
Gets the MLIR type wrapping the LLVM void type.
Definition: Pattern.cpp:46
static UnrankedMemRefDescriptor undef(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating an undef value of the descriptor type.
Type packFunctionResults(TypeRange types)
Convert a non-empty list of types to be returned from a function into a supported LLVM IR type...
MemRefDescriptor createMemRefDescriptor(Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr, ArrayRef< Value > sizes, ArrayRef< Value > strides, ConversionPatternRewriter &rewriter) const
Creates and populates a canonical memref descriptor struct.
Definition: Pattern.cpp:199
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
Value memRefDescPtr(OpBuilder &builder, Location loc)
Builds IR extracting ranked memref descriptor ptr.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
bool isConvertibleAndHasIdentityMaps(MemRefType type) const
Returns if the given memref has identity maps and the element type is convertible to LLVM...
Definition: Pattern.cpp:106
TypeConverter * typeConverter
An optional type converter for use by this pattern.
static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value)
Create a constant Op producing a value of resultType from an index-typed integer attribute.
Definition: Pattern.cpp:55
Value alignedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the aligned pointer from the descriptor.
LLVM::LLVMFuncOp lookupOrCreateMallocFn(ModuleOp moduleOp, Type indexType)
Type getType() const
Return the type of this value.
Definition: Value.h:118
IndexType getIndexType()
Definition: Builders.cpp:48
LLVMTypeConverter * getTypeConverter() const
Definition: Pattern.cpp:28
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:30
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:87
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
This class implements a pattern rewriter for use with ConversionPatterns.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:328
void getMemRefDescriptorSizes(Location loc, MemRefType memRefType, ValueRange dynamicSizes, ConversionPatternRewriter &rewriter, SmallVectorImpl< Value > &sizes, SmallVectorImpl< Value > &strides, Value &sizeBytes) const
Computes sizes, strides and buffer size in bytes of memRefType with identity layout.
Definition: Pattern.cpp:120
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands...
Definition: Pattern.cpp:310
Type getIndexType()
Gets the LLVM representation of the index type.
MLIRContext * getContext() const
Return the MLIRContext used to create this pattern.
Definition: PatternMatch.h:132
Type getVoidPtrType() const
Get the MLIR type wrapping the LLVM i8* type.
Definition: Pattern.cpp:50
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:383
bool isa() const
Definition: Types.h:246
Value rank(OpBuilder &builder, Location loc)
Builds IR extracting the rank from the descriptor.
This class helps build Operations.
Definition: Builders.h:184
This class provides an abstraction over the different types of ranges over Values.
result_type_range getResultTypes()
Definition: Operation.h:352
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:201
static void computeSizes(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, ArrayRef< UnrankedMemRefDescriptor > values, SmallVectorImpl< Value > &sizes)
Builds IR computing the sizes in bytes (suitable for opaque allocation) and appends the corresponding...
LLVM::LLVMDialect * getDialect()
Returns the LLVM dialect.
Definition: TypeConverter.h:79
LLVM::LLVMFuncOp lookupOrCreateFreeFn(ModuleOp moduleOp)
ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context, LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition: Pattern.cpp:22