MLIR  19.0.0git
AllocLikeConversion.cpp
Go to the documentation of this file.
1 //===- AllocLikeConversion.cpp - LLVM conversion for alloc operations -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
13 
14 using namespace mlir;
15 
16 namespace {
17 // TODO: Fix the LLVM utilities for looking up functions to take Operation*
18 // with SymbolTable trait instead of ModuleOp and make similar change here. This
19 // allows call sites to use getParentWithTrait<OpTrait::SymbolTable> instead
20 // of getParentOfType<ModuleOp> to pass down the operation.
21 LLVM::LLVMFuncOp getNotalignedAllocFn(const LLVMTypeConverter *typeConverter,
22  ModuleOp module, Type indexType) {
23  bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
24 
25  if (useGenericFn)
26  return LLVM::lookupOrCreateGenericAllocFn(module, indexType);
27 
28  return LLVM::lookupOrCreateMallocFn(module, indexType);
29 }
30 
31 LLVM::LLVMFuncOp getAlignedAllocFn(const LLVMTypeConverter *typeConverter,
32  ModuleOp module, Type indexType) {
33  bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
34 
35  if (useGenericFn)
36  return LLVM::lookupOrCreateGenericAlignedAllocFn(module, indexType);
37 
38  return LLVM::lookupOrCreateAlignedAllocFn(module, indexType);
39 }
40 
41 } // end namespace
42 
44  ConversionPatternRewriter &rewriter, Location loc, Value input,
45  Value alignment) {
46  Value one = createIndexAttrConstant(rewriter, loc, alignment.getType(), 1);
47  Value bump = rewriter.create<LLVM::SubOp>(loc, alignment, one);
48  Value bumped = rewriter.create<LLVM::AddOp>(loc, input, bump);
49  Value mod = rewriter.create<LLVM::URemOp>(loc, bumped, alignment);
50  return rewriter.create<LLVM::SubOp>(loc, bumped, mod);
51 }
52 
54  Location loc, Value allocatedPtr,
55  MemRefType memRefType, Type elementPtrType,
56  const LLVMTypeConverter &typeConverter) {
57  auto allocatedPtrTy = cast<LLVM::LLVMPointerType>(allocatedPtr.getType());
58  FailureOr<unsigned> maybeMemrefAddrSpace =
59  typeConverter.getMemRefAddressSpace(memRefType);
60  if (failed(maybeMemrefAddrSpace))
61  return Value();
62  unsigned memrefAddrSpace = *maybeMemrefAddrSpace;
63  if (allocatedPtrTy.getAddressSpace() != memrefAddrSpace)
64  allocatedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
65  loc, LLVM::LLVMPointerType::get(rewriter.getContext(), memrefAddrSpace),
66  allocatedPtr);
67  return allocatedPtr;
68 }
69 
71  ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes,
72  Operation *op, Value alignment) const {
73  if (alignment) {
74  // Adjust the allocation size to consider alignment.
75  sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
76  }
77 
78  MemRefType memRefType = getMemRefResultType(op);
79  // Allocate the underlying buffer.
80  Type elementPtrType = this->getElementPtrType(memRefType);
81  LLVM::LLVMFuncOp allocFuncOp = getNotalignedAllocFn(
82  getTypeConverter(), op->getParentOfType<ModuleOp>(), getIndexType());
83  auto results = rewriter.create<LLVM::CallOp>(loc, allocFuncOp, sizeBytes);
84 
85  Value allocatedPtr =
86  castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
87  elementPtrType, *getTypeConverter());
88  if (!allocatedPtr)
89  return std::make_tuple(Value(), Value());
90  Value alignedPtr = allocatedPtr;
91  if (alignment) {
92  // Compute the aligned pointer.
93  Value allocatedInt =
94  rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
95  Value alignmentInt = createAligned(rewriter, loc, allocatedInt, alignment);
96  alignedPtr =
97  rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
98  }
99 
100  return std::make_tuple(allocatedPtr, alignedPtr);
101 }
102 
103 unsigned AllocationOpLLVMLowering::getMemRefEltSizeInBytes(
104  MemRefType memRefType, Operation *op,
105  const DataLayout *defaultLayout) const {
106  const DataLayout *layout = defaultLayout;
107  if (const DataLayoutAnalysis *analysis =
108  getTypeConverter()->getDataLayoutAnalysis()) {
109  layout = &analysis->getAbove(op);
110  }
111  Type elementType = memRefType.getElementType();
112  if (auto memRefElementType = dyn_cast<MemRefType>(elementType))
113  return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
114  *layout);
115  if (auto memRefElementType = dyn_cast<UnrankedMemRefType>(elementType))
117  memRefElementType, *layout);
118  return layout->getTypeSize(elementType);
119 }
120 
121 bool AllocationOpLLVMLowering::isMemRefSizeMultipleOf(
122  MemRefType type, uint64_t factor, Operation *op,
123  const DataLayout *defaultLayout) const {
124  uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op, defaultLayout);
125  for (unsigned i = 0, e = type.getRank(); i < e; i++) {
126  if (type.isDynamicDim(i))
127  continue;
128  sizeDivisor = sizeDivisor * type.getDimSize(i);
129  }
130  return sizeDivisor % factor == 0;
131 }
132 
134  ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes,
135  Operation *op, const DataLayout *defaultLayout, int64_t alignment) const {
136  Value allocAlignment =
137  createIndexAttrConstant(rewriter, loc, getIndexType(), alignment);
138 
139  MemRefType memRefType = getMemRefResultType(op);
140  // Function aligned_alloc requires size to be a multiple of alignment; we pad
141  // the size to the next multiple if necessary.
142  if (!isMemRefSizeMultipleOf(memRefType, alignment, op, defaultLayout))
143  sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
144 
145  Type elementPtrType = this->getElementPtrType(memRefType);
146  LLVM::LLVMFuncOp allocFuncOp = getAlignedAllocFn(
147  getTypeConverter(), op->getParentOfType<ModuleOp>(), getIndexType());
148  auto results = rewriter.create<LLVM::CallOp>(
149  loc, allocFuncOp, ValueRange({allocAlignment, sizeBytes}));
150 
151  return castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
152  elementPtrType, *getTypeConverter());
153 }
154 
156  requiresNumElements = true;
157 }
158 
159 LogicalResult AllocLikeOpLLVMLowering::matchAndRewrite(
160  Operation *op, ArrayRef<Value> operands,
161  ConversionPatternRewriter &rewriter) const {
162  MemRefType memRefType = getMemRefResultType(op);
163  if (!isConvertibleAndHasIdentityMaps(memRefType))
164  return rewriter.notifyMatchFailure(op, "incompatible memref type");
165  auto loc = op->getLoc();
166 
167  // Get actual sizes of the memref as values: static sizes are constant
168  // values and dynamic sizes are passed to 'alloc' as operands. In case of
169  // zero-dimensional memref, assume a scalar (size 1).
170  SmallVector<Value, 4> sizes;
171  SmallVector<Value, 4> strides;
172  Value size;
173 
174  this->getMemRefDescriptorSizes(loc, memRefType, operands, rewriter, sizes,
175  strides, size, !requiresNumElements);
176 
177  // Allocate the underlying buffer.
178  auto [allocatedPtr, alignedPtr] =
179  this->allocateBuffer(rewriter, loc, size, op);
180 
181  if (!allocatedPtr || !alignedPtr)
182  return rewriter.notifyMatchFailure(loc,
183  "underlying buffer allocation failed");
184 
185  // Create the MemRef descriptor.
186  auto memRefDescriptor = this->createMemRefDescriptor(
187  loc, memRefType, allocatedPtr, alignedPtr, sizes, strides, rewriter);
188 
189  // Return the final value of the descriptor.
190  rewriter.replaceOp(op, {memRefDescriptor});
191  return success();
192 }
static Value castAllocFuncResult(ConversionPatternRewriter &rewriter, Location loc, Value allocatedPtr, MemRefType memRefType, Type elementPtrType, const LLVMTypeConverter &typeConverter)
MLIRContext * getContext() const
Definition: Builders.h:55
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
MemRefDescriptor createMemRefDescriptor(Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr, ArrayRef< Value > sizes, ArrayRef< Value > strides, ConversionPatternRewriter &rewriter) const
Creates and populates a canonical memref descriptor struct.
Definition: Pattern.cpp:218
void getMemRefDescriptorSizes(Location loc, MemRefType memRefType, ValueRange dynamicSizes, ConversionPatternRewriter &rewriter, SmallVectorImpl< Value > &sizes, SmallVectorImpl< Value > &strides, Value &size, bool sizeInBytes=true) const
Computes sizes, strides and buffer size of memRefType with identity layout.
Definition: Pattern.cpp:114
const LLVMTypeConverter * getTypeConverter() const
Definition: Pattern.cpp:27
Type getElementPtrType(MemRefType type) const
Returns the type of a pointer to an element of the memref.
Definition: Pattern.cpp:107
bool isConvertibleAndHasIdentityMaps(MemRefType type) const
Returns if the given memref has identity maps and the element type is convertible to LLVM.
Definition: Pattern.cpp:100
Stores data layout objects for each operation that specifies the data layout above and below the give...
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:34
const LowerToLLVMOptions & getOptions() const
Definition: TypeConverter.h:94
unsigned getUnrankedMemRefDescriptorSize(UnrankedMemRefType type, const DataLayout &layout) const
Returns the size of the unranked memref descriptor object in bytes.
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type or failure if...
unsigned getMemRefDescriptorSize(MemRefType type, const DataLayout &layout) const
Returns the size of the memref descriptor object in bytes.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
LLVM::LLVMFuncOp lookupOrCreateAlignedAllocFn(ModuleOp moduleOp, Type indexType)
LLVM::LLVMFuncOp lookupOrCreateGenericAlignedAllocFn(ModuleOp moduleOp, Type indexType)
LLVM::LLVMFuncOp lookupOrCreateMallocFn(ModuleOp moduleOp, Type indexType)
LLVM::LLVMFuncOp lookupOrCreateGenericAllocFn(ModuleOp moduleOp, Type indexType)
Include the generated interface declarations.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
int64_t mod(int64_t lhs, int64_t rhs)
Returns MLIR's mod operation on constants.
Definition: MathExtras.h:45
void setRequiresNumElements()
Sets the flag 'requiresNumElements', specifying the Op requires the number of elements instead of the...
virtual std::tuple< Value, Value > allocateBuffer(ConversionPatternRewriter &rewriter, Location loc, Value size, Operation *op) const =0
Allocates the underlying buffer.
Type getIndexType() const
Gets the MLIR type wrapping the LLVM integer type whose bit width is defined by the used type convert...
Definition: Pattern.cpp:36
static Value createAligned(ConversionPatternRewriter &rewriter, Location loc, Value input, Value alignment)
Computes the aligned value for 'input' as follows: bumped = input + alignment - 1 aligned = bumped -...
Value allocateBufferAutoAlign(ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes, Operation *op, const DataLayout *defaultLayout, int64_t alignment) const
Allocates a memory buffer using an aligned allocation method.
static MemRefType getMemRefResultType(Operation *op)
std::tuple< Value, Value > allocateBufferManuallyAlign(ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes, Operation *op, Value alignment) const
Allocates a memory buffer using an allocation method that doesn't guarantee alignment.
static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value)
Create a constant Op producing a value of resultType from an index-typed integer attribute.
Definition: Pattern.cpp:53
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26