MLIR  21.0.0git
AllocLikeConversion.cpp
Go to the documentation of this file.
1 //===- AllocLikeConversion.cpp - LLVM conversion for alloc operations -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
13 #include "mlir/IR/SymbolTable.h"
14 
15 using namespace mlir;
16 
17 static FailureOr<LLVM::LLVMFuncOp>
18 getNotalignedAllocFn(const LLVMTypeConverter *typeConverter, Operation *module,
19  Type indexType) {
20  bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
21  if (useGenericFn)
22  return LLVM::lookupOrCreateGenericAllocFn(module, indexType);
23 
24  return LLVM::lookupOrCreateMallocFn(module, indexType);
25 }
26 
27 static FailureOr<LLVM::LLVMFuncOp>
28 getAlignedAllocFn(const LLVMTypeConverter *typeConverter, Operation *module,
29  Type indexType) {
30  bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
31 
32  if (useGenericFn)
33  return LLVM::lookupOrCreateGenericAlignedAllocFn(module, indexType);
34 
35  return LLVM::lookupOrCreateAlignedAllocFn(module, indexType);
36 }
37 
39  ConversionPatternRewriter &rewriter, Location loc, Value input,
40  Value alignment) {
41  Value one = createIndexAttrConstant(rewriter, loc, alignment.getType(), 1);
42  Value bump = rewriter.create<LLVM::SubOp>(loc, alignment, one);
43  Value bumped = rewriter.create<LLVM::AddOp>(loc, input, bump);
44  Value mod = rewriter.create<LLVM::URemOp>(loc, bumped, alignment);
45  return rewriter.create<LLVM::SubOp>(loc, bumped, mod);
46 }
47 
49  Location loc, Value allocatedPtr,
50  MemRefType memRefType, Type elementPtrType,
51  const LLVMTypeConverter &typeConverter) {
52  auto allocatedPtrTy = cast<LLVM::LLVMPointerType>(allocatedPtr.getType());
53  FailureOr<unsigned> maybeMemrefAddrSpace =
54  typeConverter.getMemRefAddressSpace(memRefType);
55  if (failed(maybeMemrefAddrSpace))
56  return Value();
57  unsigned memrefAddrSpace = *maybeMemrefAddrSpace;
58  if (allocatedPtrTy.getAddressSpace() != memrefAddrSpace)
59  allocatedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
60  loc, LLVM::LLVMPointerType::get(rewriter.getContext(), memrefAddrSpace),
61  allocatedPtr);
62  return allocatedPtr;
63 }
64 
66  ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes,
67  Operation *op, Value alignment) const {
68  if (alignment) {
69  // Adjust the allocation size to consider alignment.
70  sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
71  }
72 
73  MemRefType memRefType = getMemRefResultType(op);
74  // Allocate the underlying buffer.
75  Type elementPtrType = this->getElementPtrType(memRefType);
76  if (!elementPtrType) {
77  emitError(loc, "conversion of memref memory space ")
78  << memRefType.getMemorySpace()
79  << " to integer address space "
80  "failed. Consider adding memory space conversions.";
81  }
82  FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getNotalignedAllocFn(
84  getIndexType());
85  if (failed(allocFuncOp))
86  return std::make_tuple(Value(), Value());
87  auto results =
88  rewriter.create<LLVM::CallOp>(loc, allocFuncOp.value(), sizeBytes);
89 
90  Value allocatedPtr =
91  castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
92  elementPtrType, *getTypeConverter());
93  if (!allocatedPtr)
94  return std::make_tuple(Value(), Value());
95  Value alignedPtr = allocatedPtr;
96  if (alignment) {
97  // Compute the aligned pointer.
98  Value allocatedInt =
99  rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
100  Value alignmentInt = createAligned(rewriter, loc, allocatedInt, alignment);
101  alignedPtr =
102  rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
103  }
104 
105  return std::make_tuple(allocatedPtr, alignedPtr);
106 }
107 
108 unsigned AllocationOpLLVMLowering::getMemRefEltSizeInBytes(
109  MemRefType memRefType, Operation *op,
110  const DataLayout *defaultLayout) const {
111  const DataLayout *layout = defaultLayout;
112  if (const DataLayoutAnalysis *analysis =
113  getTypeConverter()->getDataLayoutAnalysis()) {
114  layout = &analysis->getAbove(op);
115  }
116  Type elementType = memRefType.getElementType();
117  if (auto memRefElementType = dyn_cast<MemRefType>(elementType))
118  return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
119  *layout);
120  if (auto memRefElementType = dyn_cast<UnrankedMemRefType>(elementType))
122  memRefElementType, *layout);
123  return layout->getTypeSize(elementType);
124 }
125 
126 bool AllocationOpLLVMLowering::isMemRefSizeMultipleOf(
127  MemRefType type, uint64_t factor, Operation *op,
128  const DataLayout *defaultLayout) const {
129  uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op, defaultLayout);
130  for (unsigned i = 0, e = type.getRank(); i < e; i++) {
131  if (type.isDynamicDim(i))
132  continue;
133  sizeDivisor = sizeDivisor * type.getDimSize(i);
134  }
135  return sizeDivisor % factor == 0;
136 }
137 
139  ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes,
140  Operation *op, const DataLayout *defaultLayout, int64_t alignment) const {
141  Value allocAlignment =
142  createIndexAttrConstant(rewriter, loc, getIndexType(), alignment);
143 
144  MemRefType memRefType = getMemRefResultType(op);
145  // Function aligned_alloc requires size to be a multiple of alignment; we pad
146  // the size to the next multiple if necessary.
147  if (!isMemRefSizeMultipleOf(memRefType, alignment, op, defaultLayout))
148  sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
149 
150  Type elementPtrType = this->getElementPtrType(memRefType);
151  FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getAlignedAllocFn(
153  getIndexType());
154  if (failed(allocFuncOp))
155  return Value();
156  auto results = rewriter.create<LLVM::CallOp>(
157  loc, allocFuncOp.value(), ValueRange({allocAlignment, sizeBytes}));
158 
159  return castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
160  elementPtrType, *getTypeConverter());
161 }
162 
164  requiresNumElements = true;
165 }
166 
167 LogicalResult AllocLikeOpLLVMLowering::matchAndRewrite(
168  Operation *op, ArrayRef<Value> operands,
169  ConversionPatternRewriter &rewriter) const {
170  MemRefType memRefType = getMemRefResultType(op);
171  if (!isConvertibleAndHasIdentityMaps(memRefType))
172  return rewriter.notifyMatchFailure(op, "incompatible memref type");
173  auto loc = op->getLoc();
174 
175  // Get actual sizes of the memref as values: static sizes are constant
176  // values and dynamic sizes are passed to 'alloc' as operands. In case of
177  // zero-dimensional memref, assume a scalar (size 1).
178  SmallVector<Value, 4> sizes;
179  SmallVector<Value, 4> strides;
180  Value size;
181 
182  this->getMemRefDescriptorSizes(loc, memRefType, operands, rewriter, sizes,
183  strides, size, !requiresNumElements);
184 
185  // Allocate the underlying buffer.
186  auto [allocatedPtr, alignedPtr] =
187  this->allocateBuffer(rewriter, loc, size, op);
188 
189  if (!allocatedPtr || !alignedPtr)
190  return rewriter.notifyMatchFailure(loc,
191  "underlying buffer allocation failed");
192 
193  // Create the MemRef descriptor.
194  auto memRefDescriptor = this->createMemRefDescriptor(
195  loc, memRefType, allocatedPtr, alignedPtr, sizes, strides, rewriter);
196 
197  // Return the final value of the descriptor.
198  rewriter.replaceOp(op, {memRefDescriptor});
199  return success();
200 }
static FailureOr< LLVM::LLVMFuncOp > getAlignedAllocFn(const LLVMTypeConverter *typeConverter, Operation *module, Type indexType)
static Value castAllocFuncResult(ConversionPatternRewriter &rewriter, Location loc, Value allocatedPtr, MemRefType memRefType, Type elementPtrType, const LLVMTypeConverter &typeConverter)
static FailureOr< LLVM::LLVMFuncOp > getNotalignedAllocFn(const LLVMTypeConverter *typeConverter, Operation *module, Type indexType)
MLIRContext * getContext() const
Definition: Builders.h:56
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
MemRefDescriptor createMemRefDescriptor(Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr, ArrayRef< Value > sizes, ArrayRef< Value > strides, ConversionPatternRewriter &rewriter) const
Creates and populates a canonical memref descriptor struct.
Definition: Pattern.cpp:216
void getMemRefDescriptorSizes(Location loc, MemRefType memRefType, ValueRange dynamicSizes, ConversionPatternRewriter &rewriter, SmallVectorImpl< Value > &sizes, SmallVectorImpl< Value > &strides, Value &size, bool sizeInBytes=true) const
Computes sizes, strides and buffer size of memRefType with identity layout.
Definition: Pattern.cpp:114
const LLVMTypeConverter * getTypeConverter() const
Definition: Pattern.cpp:27
Type getElementPtrType(MemRefType type) const
Returns the type of a pointer to an element of the memref.
Definition: Pattern.cpp:107
bool isConvertibleAndHasIdentityMaps(MemRefType type) const
Returns if the given memref has identity maps and the element type is convertible to LLVM.
Definition: Pattern.cpp:100
Stores data layout objects for each operation that specifies the data layout above and below the give...
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
const LowerToLLVMOptions & getOptions() const
unsigned getUnrankedMemRefDescriptorSize(UnrankedMemRefType type, const DataLayout &layout) const
Returns the size of the unranked memref descriptor object in bytes.
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type or failure if...
unsigned getMemRefDescriptorSize(MemRefType type, const DataLayout &layout) const
Returns the size of the memref descriptor object in bytes.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:435
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition: Operation.h:248
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateAlignedAllocFn(Operation *moduleOp, Type indexType)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateMallocFn(Operation *moduleOp, Type indexType)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp, Type indexType)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateGenericAllocFn(Operation *moduleOp, Type indexType)
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void setRequiresNumElements()
Sets the flag 'requiresNumElements', specifying the Op requires the number of elements instead of the...
virtual std::tuple< Value, Value > allocateBuffer(ConversionPatternRewriter &rewriter, Location loc, Value size, Operation *op) const =0
Allocates the underlying buffer.
Type getIndexType() const
Gets the MLIR type wrapping the LLVM integer type whose bit width is defined by the used type convert...
Definition: Pattern.cpp:36
static Value createAligned(ConversionPatternRewriter &rewriter, Location loc, Value input, Value alignment)
Computes the aligned value for 'input' as follows: bumped = input + alignment - 1; aligned = bumped - bumped % alignment.
Value allocateBufferAutoAlign(ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes, Operation *op, const DataLayout *defaultLayout, int64_t alignment) const
Allocates a memory buffer using an aligned allocation method.
static MemRefType getMemRefResultType(Operation *op)
std::tuple< Value, Value > allocateBufferManuallyAlign(ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes, Operation *op, Value alignment) const
Allocates a memory buffer using an allocation method that doesn't guarantee alignment.
static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value)
Create a constant Op producing a value of resultType from an index-typed integer attribute.
Definition: Pattern.cpp:53