MLIR  20.0.0git
AllocLikeConversion.cpp
Go to the documentation of this file.
1 //===- AllocLikeConversion.cpp - LLVM conversion for alloc operations -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
13 #include "mlir/IR/SymbolTable.h"
14 
15 using namespace mlir;
16 
17 namespace {
18 LLVM::LLVMFuncOp getNotalignedAllocFn(const LLVMTypeConverter *typeConverter,
19  Operation *module, Type indexType) {
20  bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
21  if (useGenericFn)
22  return LLVM::lookupOrCreateGenericAllocFn(module, indexType);
23 
24  return LLVM::lookupOrCreateMallocFn(module, indexType);
25 }
26 
27 LLVM::LLVMFuncOp getAlignedAllocFn(const LLVMTypeConverter *typeConverter,
28  Operation *module, Type indexType) {
29  bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
30 
31  if (useGenericFn)
32  return LLVM::lookupOrCreateGenericAlignedAllocFn(module, indexType);
33 
34  return LLVM::lookupOrCreateAlignedAllocFn(module, indexType);
35 }
36 
37 } // end namespace
38 
40  ConversionPatternRewriter &rewriter, Location loc, Value input,
41  Value alignment) {
  // Emits LLVM integer ops that round `input` up to a multiple of `alignment`
  // (a value that is already a multiple is left unchanged, ignoring
  // wrap-around):
  //   bumped  = input + (alignment - 1)
  //   aligned = bumped - (bumped urem alignment)
  // The constant 1 is created with alignment's type so that the sub matches.
  // NOTE(review): the URemOp has no defined result for a zero alignment;
  // callers presumably pass a non-zero value — confirm.
42  Value one = createIndexAttrConstant(rewriter, loc, alignment.getType(), 1);
43  Value bump = rewriter.create<LLVM::SubOp>(loc, alignment, one);
44  Value bumped = rewriter.create<LLVM::AddOp>(loc, input, bump);
45  Value mod = rewriter.create<LLVM::URemOp>(loc, bumped, alignment);
46  return rewriter.create<LLVM::SubOp>(loc, bumped, mod);
47 }
48 
50  Location loc, Value allocatedPtr,
51  MemRefType memRefType, Type elementPtrType,
52  const LLVMTypeConverter &typeConverter) {
  // Makes the pointer returned by an allocation call live in the address space
  // of `memRefType`. An LLVM::AddrSpaceCastOp is inserted only when the spaces
  // differ; otherwise `allocatedPtr` is returned unchanged. The result is a
  // null Value when the type converter cannot map the memref's memory space
  // to an LLVM address space.
  // NOTE(review): `elementPtrType` is not used in this body.
  // cast<> (not dyn_cast<>): the allocation result must already be an LLVM
  // pointer.
53  auto allocatedPtrTy = cast<LLVM::LLVMPointerType>(allocatedPtr.getType());
54  FailureOr<unsigned> maybeMemrefAddrSpace =
55  typeConverter.getMemRefAddressSpace(memRefType);
56  if (failed(maybeMemrefAddrSpace))
57  return Value();
58  unsigned memrefAddrSpace = *maybeMemrefAddrSpace;
59  if (allocatedPtrTy.getAddressSpace() != memrefAddrSpace)
60  allocatedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
61  loc, LLVM::LLVMPointerType::get(rewriter.getContext(), memrefAddrSpace),
62  allocatedPtr);
63  return allocatedPtr;
64 }
65 
67  ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes,
68  Operation *op, Value alignment) const {
  // Allocates the buffer with the allocation function chosen by
  // getNotalignedAllocFn, which takes only a size. When `alignment` is set,
  // the requested size is grown by `alignment` bytes, and the returned address
  // is then rounded up to that alignment.
  // Returns (allocatedPtr, alignedPtr). They are the same value when no
  // alignment is requested. Both are null Values if castAllocFuncResult fails.
69  if (alignment) {
70  // Adjust the allocation size to consider alignment.
  // Padding by the full `alignment` (not alignment - 1) leaves room for the
  // rounded-up pointer plus `sizeBytes`, at the cost of at most one spare byte.
71  sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
72  }
73 
74  MemRefType memRefType = getMemRefResultType(op);
75  // Allocate the underlying buffer.
76  Type elementPtrType = this->getElementPtrType(memRefType);
77  LLVM::LLVMFuncOp allocFuncOp = getNotalignedAllocFn(
79  getIndexType());
80  auto results = rewriter.create<LLVM::CallOp>(loc, allocFuncOp, sizeBytes);
81 
  // Move the returned pointer into the memref's address space if needed.
82  Value allocatedPtr =
83  castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
84  elementPtrType, *getTypeConverter());
85  if (!allocatedPtr)
86  return std::make_tuple(Value(), Value());
87  Value alignedPtr = allocatedPtr;
88  if (alignment) {
89  // Compute the aligned pointer.
  // Round-trip through the index integer type: ptrtoint, round up with
  // createAligned, then inttoptr back to the element pointer type.
90  Value allocatedInt =
91  rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
92  Value alignmentInt = createAligned(rewriter, loc, allocatedInt, alignment);
93  alignedPtr =
94  rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
95  }
96 
97  return std::make_tuple(allocatedPtr, alignedPtr);
98 }
99 
/// Returns the size in bytes of one element of `memRefType`.
/// The data layout is the one above `op`, taken from the type converter's
/// DataLayoutAnalysis when one is available; otherwise `defaultLayout` is used.
/// An element that is itself a ranked memref is sized as its memref
/// descriptor (getMemRefDescriptorSize). Unranked memref elements are likewise
/// sized through a descriptor-size query on `*layout` (presumably
/// getUnrankedMemRefDescriptorSize — confirm). Any other element type is sized
/// with DataLayout::getTypeSize.
/// NOTE(review): `defaultLayout` is dereferenced when no analysis is
/// available, so callers presumably must pass a non-null layout in that case.
100 unsigned AllocationOpLLVMLowering::getMemRefEltSizeInBytes(
101  MemRefType memRefType, Operation *op,
102  const DataLayout *defaultLayout) const {
103  const DataLayout *layout = defaultLayout;
104  if (const DataLayoutAnalysis *analysis =
105  getTypeConverter()->getDataLayoutAnalysis()) {
106  layout = &analysis->getAbove(op);
107  }
108  Type elementType = memRefType.getElementType();
  // Nested memrefs are stored as descriptors, not as their underlying data.
109  if (auto memRefElementType = dyn_cast<MemRefType>(elementType))
110  return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
111  *layout);
112  if (auto memRefElementType = dyn_cast<UnrankedMemRefType>(elementType))
114  memRefElementType, *layout);
115  return layout->getTypeSize(elementType);
116 }
117 
118 bool AllocationOpLLVMLowering::isMemRefSizeMultipleOf(
119  MemRefType type, uint64_t factor, Operation *op,
120  const DataLayout *defaultLayout) const {
121  uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op, defaultLayout);
122  for (unsigned i = 0, e = type.getRank(); i < e; i++) {
123  if (type.isDynamicDim(i))
124  continue;
125  sizeDivisor = sizeDivisor * type.getDimSize(i);
126  }
127  return sizeDivisor % factor == 0;
128 }
129 
131  ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes,
132  Operation *op, const DataLayout *defaultLayout, int64_t alignment) const {
  // Allocates the buffer with the function chosen by getAlignedAllocFn, which
  // is called with the operands (alignment, size). Unlike
  // allocateBufferManuallyAlign, no pointer rounding is emitted here.
  // The size is first padded up to a multiple of `alignment` unless
  // isMemRefSizeMultipleOf shows it already is one. The call result is passed
  // through castAllocFuncResult, so this returns a null Value when that fails.
  // NOTE(review): `alignment` is int64_t but is forwarded to a uint64_t
  // `factor`, so it is presumably expected to be positive — confirm.
133  Value allocAlignment =
134  createIndexAttrConstant(rewriter, loc, getIndexType(), alignment);
135 
136  MemRefType memRefType = getMemRefResultType(op);
137  // Function aligned_alloc requires size to be a multiple of alignment; we pad
138  // the size to the next multiple if necessary.
139  if (!isMemRefSizeMultipleOf(memRefType, alignment, op, defaultLayout))
140  sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
141 
142  Type elementPtrType = this->getElementPtrType(memRefType);
143  LLVM::LLVMFuncOp allocFuncOp = getAlignedAllocFn(
145  getIndexType());
146  auto results = rewriter.create<LLVM::CallOp>(
147  loc, allocFuncOp, ValueRange({allocAlignment, sizeBytes}));
148 
149  return castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
150  elementPtrType, *getTypeConverter());
151 }
152 
// Makes matchAndRewrite request an element count instead of a byte size: it
// passes `!requiresNumElements` as the `sizeInBytes` argument of
// getMemRefDescriptorSizes.
154  requiresNumElements = true;
155 }
156 
157 LogicalResult AllocLikeOpLLVMLowering::matchAndRewrite(
158  Operation *op, ArrayRef<Value> operands,
159  ConversionPatternRewriter &rewriter) const {
160  MemRefType memRefType = getMemRefResultType(op);
161  if (!isConvertibleAndHasIdentityMaps(memRefType))
162  return rewriter.notifyMatchFailure(op, "incompatible memref type");
163  auto loc = op->getLoc();
164 
165  // Get actual sizes of the memref as values: static sizes are constant
166  // values and dynamic sizes are passed to 'alloc' as operands. In case of
167  // zero-dimensional memref, assume a scalar (size 1).
168  SmallVector<Value, 4> sizes;
169  SmallVector<Value, 4> strides;
170  Value size;
171 
172  this->getMemRefDescriptorSizes(loc, memRefType, operands, rewriter, sizes,
173  strides, size, !requiresNumElements);
174 
175  // Allocate the underlying buffer.
176  auto [allocatedPtr, alignedPtr] =
177  this->allocateBuffer(rewriter, loc, size, op);
178 
179  if (!allocatedPtr || !alignedPtr)
180  return rewriter.notifyMatchFailure(loc,
181  "underlying buffer allocation failed");
182 
183  // Create the MemRef descriptor.
184  auto memRefDescriptor = this->createMemRefDescriptor(
185  loc, memRefType, allocatedPtr, alignedPtr, sizes, strides, rewriter);
186 
187  // Return the final value of the descriptor.
188  rewriter.replaceOp(op, {memRefDescriptor});
189  return success();
190 }
static Value castAllocFuncResult(ConversionPatternRewriter &rewriter, Location loc, Value allocatedPtr, MemRefType memRefType, Type elementPtrType, const LLVMTypeConverter &typeConverter)
MLIRContext * getContext() const
Definition: Builders.h:55
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
MemRefDescriptor createMemRefDescriptor(Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr, ArrayRef< Value > sizes, ArrayRef< Value > strides, ConversionPatternRewriter &rewriter) const
Creates and populates a canonical memref descriptor struct.
Definition: Pattern.cpp:218
void getMemRefDescriptorSizes(Location loc, MemRefType memRefType, ValueRange dynamicSizes, ConversionPatternRewriter &rewriter, SmallVectorImpl< Value > &sizes, SmallVectorImpl< Value > &strides, Value &size, bool sizeInBytes=true) const
Computes sizes, strides and buffer size of memRefType with identity layout.
Definition: Pattern.cpp:114
const LLVMTypeConverter * getTypeConverter() const
Definition: Pattern.cpp:27
Type getElementPtrType(MemRefType type) const
Returns the type of a pointer to an element of the memref.
Definition: Pattern.cpp:107
bool isConvertibleAndHasIdentityMaps(MemRefType type) const
Returns true if the given memref has identity maps and its element type is convertible to LLVM.
Definition: Pattern.cpp:100
Stores data layout objects for each operation that specifies the data layout above and below the give...
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:34
const LowerToLLVMOptions & getOptions() const
Definition: TypeConverter.h:94
unsigned getUnrankedMemRefDescriptorSize(UnrankedMemRefType type, const DataLayout &layout) const
Returns the size of the unranked memref descriptor object in bytes.
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type or failure if...
unsigned getMemRefDescriptorSize(MemRefType type, const DataLayout &layout) const
Returns the size of the memref descriptor object in bytes.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:468
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:435
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition: Operation.h:248
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
LLVM::LLVMFuncOp lookupOrCreateAlignedAllocFn(Operation *moduleOp, Type indexType)
LLVM::LLVMFuncOp lookupOrCreateGenericAllocFn(Operation *moduleOp, Type indexType)
LLVM::LLVMFuncOp lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp, Type indexType)
LLVM::LLVMFuncOp lookupOrCreateMallocFn(Operation *moduleOp, Type indexType)
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void setRequiresNumElements()
Sets the flag 'requiresNumElements', specifying that the op requires the number of elements instead of the buffer size in bytes.
virtual std::tuple< Value, Value > allocateBuffer(ConversionPatternRewriter &rewriter, Location loc, Value size, Operation *op) const =0
Allocates the underlying buffer.
Type getIndexType() const
Gets the MLIR type wrapping the LLVM integer type whose bit width is defined by the used type convert...
Definition: Pattern.cpp:36
static Value createAligned(ConversionPatternRewriter &rewriter, Location loc, Value input, Value alignment)
Computes the aligned value for 'input' as follows: bumped = input + alignment - 1; aligned = bumped - bumped % alignment.
Value allocateBufferAutoAlign(ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes, Operation *op, const DataLayout *defaultLayout, int64_t alignment) const
Allocates a memory buffer using an aligned allocation method.
static MemRefType getMemRefResultType(Operation *op)
std::tuple< Value, Value > allocateBufferManuallyAlign(ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes, Operation *op, Value alignment) const
Allocates a memory buffer using an allocation method that doesn't guarantee alignment.
static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value)
Create a constant Op producing a value of resultType from an index-typed integer attribute.
Definition: Pattern.cpp:53