MLIR  17.0.0git
AllocLikeConversion.cpp
Go to the documentation of this file.
1 //===- AllocLikeConversion.cpp - LLVM conversion for alloc operations -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
13 
14 using namespace mlir;
15 
16 namespace {
17 // TODO: Fix the LLVM utilities for looking up functions to take Operation*
18 // with SymbolTable trait instead of ModuleOp and make similar change here. This
19 // allows call sites to use getParentWithTrait<OpTrait::SymbolTable> instead
20 // of getParentOfType<ModuleOp> to pass down the operation.
21 LLVM::LLVMFuncOp getNotalignedAllocFn(LLVMTypeConverter *typeConverter,
22  ModuleOp module, Type indexType) {
23  bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
24 
25  if (useGenericFn)
26  return LLVM::lookupOrCreateGenericAllocFn(module, indexType);
27 
28  return LLVM::lookupOrCreateMallocFn(module, indexType);
29 }
30 
31 LLVM::LLVMFuncOp getAlignedAllocFn(LLVMTypeConverter *typeConverter,
32  ModuleOp module, Type indexType) {
33  bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
34 
35  if (useGenericFn)
36  return LLVM::lookupOrCreateGenericAlignedAllocFn(module, indexType);
37 
38  return LLVM::lookupOrCreateAlignedAllocFn(module, indexType);
39 }
40 
41 } // end namespace
42 
44  ConversionPatternRewriter &rewriter, Location loc, Value input,
45  Value alignment) {
46  Value one = createIndexAttrConstant(rewriter, loc, alignment.getType(), 1);
47  Value bump = rewriter.create<LLVM::SubOp>(loc, alignment, one);
48  Value bumped = rewriter.create<LLVM::AddOp>(loc, input, bump);
49  Value mod = rewriter.create<LLVM::URemOp>(loc, bumped, alignment);
50  return rewriter.create<LLVM::SubOp>(loc, bumped, mod);
51 }
52 
54  ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes,
55  Operation *op, Value alignment) const {
56  if (alignment) {
57  // Adjust the allocation size to consider alignment.
58  sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
59  }
60 
61  MemRefType memRefType = getMemRefResultType(op);
62  // Allocate the underlying buffer.
63  Type elementPtrType = this->getElementPtrType(memRefType);
64  LLVM::LLVMFuncOp allocFuncOp = getNotalignedAllocFn(
65  getTypeConverter(), op->getParentOfType<ModuleOp>(), getIndexType());
66  auto results = rewriter.create<LLVM::CallOp>(loc, allocFuncOp, sizeBytes);
67  Value allocatedPtr = rewriter.create<LLVM::BitcastOp>(loc, elementPtrType,
68  results.getResult());
69 
70  Value alignedPtr = allocatedPtr;
71  if (alignment) {
72  // Compute the aligned pointer.
73  Value allocatedInt =
74  rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
75  Value alignmentInt = createAligned(rewriter, loc, allocatedInt, alignment);
76  alignedPtr =
77  rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
78  }
79 
80  return std::make_tuple(allocatedPtr, alignedPtr);
81 }
82 
83 unsigned AllocationOpLLVMLowering::getMemRefEltSizeInBytes(
84  MemRefType memRefType, Operation *op,
85  const DataLayout *defaultLayout) const {
86  const DataLayout *layout = defaultLayout;
87  if (const DataLayoutAnalysis *analysis =
88  getTypeConverter()->getDataLayoutAnalysis()) {
89  layout = &analysis->getAbove(op);
90  }
91  Type elementType = memRefType.getElementType();
92  if (auto memRefElementType = elementType.dyn_cast<MemRefType>())
93  return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
94  *layout);
95  if (auto memRefElementType = elementType.dyn_cast<UnrankedMemRefType>())
97  memRefElementType, *layout);
98  return layout->getTypeSize(elementType);
99 }
100 
101 bool AllocationOpLLVMLowering::isMemRefSizeMultipleOf(
102  MemRefType type, uint64_t factor, Operation *op,
103  const DataLayout *defaultLayout) const {
104  uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op, defaultLayout);
105  for (unsigned i = 0, e = type.getRank(); i < e; i++) {
106  if (ShapedType::isDynamic(type.getDimSize(i)))
107  continue;
108  sizeDivisor = sizeDivisor * type.getDimSize(i);
109  }
110  return sizeDivisor % factor == 0;
111 }
112 
114  ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes,
115  Operation *op, const DataLayout *defaultLayout, int64_t alignment) const {
116  Value allocAlignment = createIndexConstant(rewriter, loc, alignment);
117 
118  MemRefType memRefType = getMemRefResultType(op);
119  // Function aligned_alloc requires size to be a multiple of alignment; we pad
120  // the size to the next multiple if necessary.
121  if (!isMemRefSizeMultipleOf(memRefType, alignment, op, defaultLayout))
122  sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
123 
124  Type elementPtrType = this->getElementPtrType(memRefType);
125  LLVM::LLVMFuncOp allocFuncOp = getAlignedAllocFn(
126  getTypeConverter(), op->getParentOfType<ModuleOp>(), getIndexType());
127  auto results = rewriter.create<LLVM::CallOp>(
128  loc, allocFuncOp, ValueRange({allocAlignment, sizeBytes}));
129  Value allocatedPtr = rewriter.create<LLVM::BitcastOp>(loc, elementPtrType,
130  results.getResult());
131 
132  return allocatedPtr;
133 }
134 
135 LogicalResult AllocLikeOpLLVMLowering::matchAndRewrite(
136  Operation *op, ArrayRef<Value> operands,
137  ConversionPatternRewriter &rewriter) const {
138  MemRefType memRefType = getMemRefResultType(op);
139  if (!isConvertibleAndHasIdentityMaps(memRefType))
140  return rewriter.notifyMatchFailure(op, "incompatible memref type");
141  auto loc = op->getLoc();
142 
143  // Get actual sizes of the memref as values: static sizes are constant
144  // values and dynamic sizes are passed to 'alloc' as operands. In case of
145  // zero-dimensional memref, assume a scalar (size 1).
146  SmallVector<Value, 4> sizes;
147  SmallVector<Value, 4> strides;
148  Value sizeBytes;
149  this->getMemRefDescriptorSizes(loc, memRefType, operands, rewriter, sizes,
150  strides, sizeBytes);
151 
152  // Allocate the underlying buffer.
153  auto [allocatedPtr, alignedPtr] =
154  this->allocateBuffer(rewriter, loc, sizeBytes, op);
155 
156  // Create the MemRef descriptor.
157  auto memRefDescriptor = this->createMemRefDescriptor(
158  loc, memRefType, allocatedPtr, alignedPtr, sizes, strides, rewriter);
159 
160  // Return the final value of the descriptor.
161  rewriter.replaceOp(op, {memRefDescriptor});
162  return success();
163 }
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing the results of an operation.
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
MemRefDescriptor createMemRefDescriptor(Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr, ArrayRef< Value > sizes, ArrayRef< Value > strides, ConversionPatternRewriter &rewriter) const
Creates and populates a canonical memref descriptor struct.
Definition: Pattern.cpp:194
LLVMTypeConverter * getTypeConverter() const
Definition: Pattern.cpp:28
void getMemRefDescriptorSizes(Location loc, MemRefType memRefType, ValueRange dynamicSizes, ConversionPatternRewriter &rewriter, SmallVectorImpl< Value > &sizes, SmallVectorImpl< Value > &strides, Value &sizeBytes) const
Computes sizes, strides and buffer size in bytes of memRefType with identity layout.
Definition: Pattern.cpp:116
Type getElementPtrType(MemRefType type) const
Returns the type of a pointer to an element of the memref.
Definition: Pattern.cpp:109
static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value)
Create a constant Op producing a value of resultType from an index-typed integer attribute.
Definition: Pattern.cpp:55
bool isConvertibleAndHasIdentityMaps(MemRefType type) const
Returns if the given memref has identity maps and the element type is convertible to LLVM.
Definition: Pattern.cpp:102
Stores data layout objects for each operation that specifies the data layout above and below the given operation.
The main mechanism for performing data layout queries.
unsigned getTypeSize(Type t) const
Returns the size of the given type in the current scope.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:30
const LowerToLLVMOptions & getOptions() const
Definition: TypeConverter.h:81
unsigned getMemRefDescriptorSize(MemRefType type, const DataLayout &layout)
Returns the size of the memref descriptor object in bytes.
unsigned getUnrankedMemRefDescriptorSize(UnrankedMemRefType type, const DataLayout &layout)
Returns the size of the unranked memref descriptor object in bytes.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:63
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:422
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:75
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:198
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:213
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
U dyn_cast() const
Definition: Types.h:311
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:350
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:93
Type getType() const
Return the type of this value.
Definition: Value.h:122
LLVM::LLVMFuncOp lookupOrCreateAlignedAllocFn(ModuleOp moduleOp, Type indexType)
LLVM::LLVMFuncOp lookupOrCreateGenericAlignedAllocFn(ModuleOp moduleOp, Type indexType)
LLVM::LLVMFuncOp lookupOrCreateMallocFn(ModuleOp moduleOp, Type indexType)
LLVM::LLVMFuncOp lookupOrCreateGenericAllocFn(ModuleOp moduleOp, Type indexType)
Include the generated interface declarations.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
int64_t mod(int64_t lhs, int64_t rhs)
Returns MLIR's mod operation on constants.
Definition: MathExtras.h:45
virtual std::tuple< Value, Value > allocateBuffer(ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes, Operation *op) const =0
Allocates the underlying buffer.
Type getIndexType() const
Gets the MLIR type wrapping the LLVM integer type whose bit width is defined by the used type converter.
Definition: Pattern.cpp:37
static Value createAligned(ConversionPatternRewriter &rewriter, Location loc, Value input, Value alignment)
Computes the aligned value for 'input' as follows: bumped = input + alignment - 1; aligned = bumped - bumped % alignment.
Value allocateBufferAutoAlign(ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes, Operation *op, const DataLayout *defaultLayout, int64_t alignment) const
Allocates a memory buffer using an aligned allocation method.
static MemRefType getMemRefResultType(Operation *op)
std::tuple< Value, Value > allocateBufferManuallyAlign(ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes, Operation *op, Value alignment) const
Allocates a memory buffer using an allocation method that doesn't guarantee alignment.
Value createIndexConstant(ConversionPatternRewriter &builder, Location loc, uint64_t value) const
Create an LLVM dialect operation defining the given index constant.
Definition: Pattern.cpp:63
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26