MLIR  20.0.0git
AllocLikeConversion.cpp
Go to the documentation of this file.
1 //===- AllocLikeConversion.cpp - LLVM conversion for alloc operations -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
13 #include "mlir/IR/SymbolTable.h"
14 
15 using namespace mlir;
16 
17 namespace {
18 LLVM::LLVMFuncOp getNotalignedAllocFn(const LLVMTypeConverter *typeConverter,
19  Operation *module, Type indexType) {
20  bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
21  if (useGenericFn)
22  return LLVM::lookupOrCreateGenericAllocFn(module, indexType);
23 
24  return LLVM::lookupOrCreateMallocFn(module, indexType);
25 }
26 
27 LLVM::LLVMFuncOp getAlignedAllocFn(const LLVMTypeConverter *typeConverter,
28  Operation *module, Type indexType) {
29  bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
30 
31  if (useGenericFn)
32  return LLVM::lookupOrCreateGenericAlignedAllocFn(module, indexType);
33 
34  return LLVM::lookupOrCreateAlignedAllocFn(module, indexType);
35 }
36 
37 } // end namespace
38 
40  ConversionPatternRewriter &rewriter, Location loc, Value input,
41  Value alignment) {
42  Value one = createIndexAttrConstant(rewriter, loc, alignment.getType(), 1);
43  Value bump = rewriter.create<LLVM::SubOp>(loc, alignment, one);
44  Value bumped = rewriter.create<LLVM::AddOp>(loc, input, bump);
45  Value mod = rewriter.create<LLVM::URemOp>(loc, bumped, alignment);
46  return rewriter.create<LLVM::SubOp>(loc, bumped, mod);
47 }
48 
50  Location loc, Value allocatedPtr,
51  MemRefType memRefType, Type elementPtrType,
52  const LLVMTypeConverter &typeConverter) {
53  auto allocatedPtrTy = cast<LLVM::LLVMPointerType>(allocatedPtr.getType());
54  FailureOr<unsigned> maybeMemrefAddrSpace =
55  typeConverter.getMemRefAddressSpace(memRefType);
56  if (failed(maybeMemrefAddrSpace))
57  return Value();
58  unsigned memrefAddrSpace = *maybeMemrefAddrSpace;
59  if (allocatedPtrTy.getAddressSpace() != memrefAddrSpace)
60  allocatedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
61  loc, LLVM::LLVMPointerType::get(rewriter.getContext(), memrefAddrSpace),
62  allocatedPtr);
63  return allocatedPtr;
64 }
65 
67  ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes,
68  Operation *op, Value alignment) const {
69  if (alignment) {
70  // Adjust the allocation size to consider alignment.
71  sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
72  }
73 
74  MemRefType memRefType = getMemRefResultType(op);
75  // Allocate the underlying buffer.
76  Type elementPtrType = this->getElementPtrType(memRefType);
77  if (!elementPtrType) {
78  emitError(loc, "conversion of memref memory space ")
79  << memRefType.getMemorySpace()
80  << " to integer address space "
81  "failed. Consider adding memory space conversions.";
82  }
83  LLVM::LLVMFuncOp allocFuncOp = getNotalignedAllocFn(
85  getIndexType());
86  auto results = rewriter.create<LLVM::CallOp>(loc, allocFuncOp, sizeBytes);
87 
88  Value allocatedPtr =
89  castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
90  elementPtrType, *getTypeConverter());
91  if (!allocatedPtr)
92  return std::make_tuple(Value(), Value());
93  Value alignedPtr = allocatedPtr;
94  if (alignment) {
95  // Compute the aligned pointer.
96  Value allocatedInt =
97  rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
98  Value alignmentInt = createAligned(rewriter, loc, allocatedInt, alignment);
99  alignedPtr =
100  rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
101  }
102 
103  return std::make_tuple(allocatedPtr, alignedPtr);
104 }
105 
106 unsigned AllocationOpLLVMLowering::getMemRefEltSizeInBytes(
107  MemRefType memRefType, Operation *op,
108  const DataLayout *defaultLayout) const {
109  const DataLayout *layout = defaultLayout;
110  if (const DataLayoutAnalysis *analysis =
111  getTypeConverter()->getDataLayoutAnalysis()) {
112  layout = &analysis->getAbove(op);
113  }
114  Type elementType = memRefType.getElementType();
115  if (auto memRefElementType = dyn_cast<MemRefType>(elementType))
116  return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
117  *layout);
118  if (auto memRefElementType = dyn_cast<UnrankedMemRefType>(elementType))
120  memRefElementType, *layout);
121  return layout->getTypeSize(elementType);
122 }
123 
124 bool AllocationOpLLVMLowering::isMemRefSizeMultipleOf(
125  MemRefType type, uint64_t factor, Operation *op,
126  const DataLayout *defaultLayout) const {
127  uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op, defaultLayout);
128  for (unsigned i = 0, e = type.getRank(); i < e; i++) {
129  if (type.isDynamicDim(i))
130  continue;
131  sizeDivisor = sizeDivisor * type.getDimSize(i);
132  }
133  return sizeDivisor % factor == 0;
134 }
135 
137  ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes,
138  Operation *op, const DataLayout *defaultLayout, int64_t alignment) const {
139  Value allocAlignment =
140  createIndexAttrConstant(rewriter, loc, getIndexType(), alignment);
141 
142  MemRefType memRefType = getMemRefResultType(op);
143  // Function aligned_alloc requires size to be a multiple of alignment; we pad
144  // the size to the next multiple if necessary.
145  if (!isMemRefSizeMultipleOf(memRefType, alignment, op, defaultLayout))
146  sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
147 
148  Type elementPtrType = this->getElementPtrType(memRefType);
149  LLVM::LLVMFuncOp allocFuncOp = getAlignedAllocFn(
151  getIndexType());
152  auto results = rewriter.create<LLVM::CallOp>(
153  loc, allocFuncOp, ValueRange({allocAlignment, sizeBytes}));
154 
155  return castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
156  elementPtrType, *getTypeConverter());
157 }
158 
160  requiresNumElements = true;
161 }
162 
163 LogicalResult AllocLikeOpLLVMLowering::matchAndRewrite(
164  Operation *op, ArrayRef<Value> operands,
165  ConversionPatternRewriter &rewriter) const {
166  MemRefType memRefType = getMemRefResultType(op);
167  if (!isConvertibleAndHasIdentityMaps(memRefType))
168  return rewriter.notifyMatchFailure(op, "incompatible memref type");
169  auto loc = op->getLoc();
170 
171  // Get actual sizes of the memref as values: static sizes are constant
172  // values and dynamic sizes are passed to 'alloc' as operands. In case of
173  // zero-dimensional memref, assume a scalar (size 1).
174  SmallVector<Value, 4> sizes;
175  SmallVector<Value, 4> strides;
176  Value size;
177 
178  this->getMemRefDescriptorSizes(loc, memRefType, operands, rewriter, sizes,
179  strides, size, !requiresNumElements);
180 
181  // Allocate the underlying buffer.
182  auto [allocatedPtr, alignedPtr] =
183  this->allocateBuffer(rewriter, loc, size, op);
184 
185  if (!allocatedPtr || !alignedPtr)
186  return rewriter.notifyMatchFailure(loc,
187  "underlying buffer allocation failed");
188 
189  // Create the MemRef descriptor.
190  auto memRefDescriptor = this->createMemRefDescriptor(
191  loc, memRefType, allocatedPtr, alignedPtr, sizes, strides, rewriter);
192 
193  // Return the final value of the descriptor.
194  rewriter.replaceOp(op, {memRefDescriptor});
195  return success();
196 }
static Value castAllocFuncResult(ConversionPatternRewriter &rewriter, Location loc, Value allocatedPtr, MemRefType memRefType, Type elementPtrType, const LLVMTypeConverter &typeConverter)
MLIRContext * getContext() const
Definition: Builders.h:55
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
MemRefDescriptor createMemRefDescriptor(Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr, ArrayRef< Value > sizes, ArrayRef< Value > strides, ConversionPatternRewriter &rewriter) const
Creates and populates a canonical memref descriptor struct.
Definition: Pattern.cpp:216
void getMemRefDescriptorSizes(Location loc, MemRefType memRefType, ValueRange dynamicSizes, ConversionPatternRewriter &rewriter, SmallVectorImpl< Value > &sizes, SmallVectorImpl< Value > &strides, Value &size, bool sizeInBytes=true) const
Computes sizes, strides and buffer size of memRefType with identity layout.
Definition: Pattern.cpp:114
const LLVMTypeConverter * getTypeConverter() const
Definition: Pattern.cpp:27
Type getElementPtrType(MemRefType type) const
Returns the type of a pointer to an element of the memref.
Definition: Pattern.cpp:107
bool isConvertibleAndHasIdentityMaps(MemRefType type) const
Returns if the given memref has identity maps and the element type is convertible to LLVM.
Definition: Pattern.cpp:100
Stores data layout objects for each operation that specifies the data layout above and below the give...
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
const LowerToLLVMOptions & getOptions() const
unsigned getUnrankedMemRefDescriptorSize(UnrankedMemRefType type, const DataLayout &layout) const
Returns the size of the unranked memref descriptor object in bytes.
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type or failure if...
unsigned getMemRefDescriptorSize(MemRefType type, const DataLayout &layout) const
Returns the size of the memref descriptor object in bytes.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:435
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition: Operation.h:248
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
LLVM::LLVMFuncOp lookupOrCreateAlignedAllocFn(Operation *moduleOp, Type indexType)
LLVM::LLVMFuncOp lookupOrCreateGenericAllocFn(Operation *moduleOp, Type indexType)
LLVM::LLVMFuncOp lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp, Type indexType)
LLVM::LLVMFuncOp lookupOrCreateMallocFn(Operation *moduleOp, Type indexType)
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void setRequiresNumElements()
Sets the flag 'requiresNumElements', specifying the Op requires the number of elements instead of the...
virtual std::tuple< Value, Value > allocateBuffer(ConversionPatternRewriter &rewriter, Location loc, Value size, Operation *op) const =0
Allocates the underlying buffer.
Type getIndexType() const
Gets the MLIR type wrapping the LLVM integer type whose bit width is defined by the used type convert...
Definition: Pattern.cpp:36
static Value createAligned(ConversionPatternRewriter &rewriter, Location loc, Value input, Value alignment)
Computes the aligned value for 'input' as follows: bumped = input + alignment - 1; aligned = bumped - bumped % alignment.
Value allocateBufferAutoAlign(ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes, Operation *op, const DataLayout *defaultLayout, int64_t alignment) const
Allocates a memory buffer using an aligned allocation method.
static MemRefType getMemRefResultType(Operation *op)
std::tuple< Value, Value > allocateBufferManuallyAlign(ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes, Operation *op, Value alignment) const
Allocates a memory buffer using an allocation method that doesn't guarantee alignment.
static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value)
Create a constant Op producing a value of resultType from an index-typed integer attribute.
Definition: Pattern.cpp:53