MLIR  14.0.0git
Pattern.cpp
Go to the documentation of this file.
1 //===- Pattern.cpp - Conversion pattern to the LLVM dialect ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
13 #include "mlir/IR/AffineMap.h"
14 
15 using namespace mlir;
16 
17 //===----------------------------------------------------------------------===//
18 // ConvertToLLVMPattern
19 //===----------------------------------------------------------------------===//
20 
22  MLIRContext *context,
23  LLVMTypeConverter &typeConverter,
24  PatternBenefit benefit)
25  : ConversionPattern(typeConverter, rootOpName, benefit, context) {}
26 
28  return static_cast<LLVMTypeConverter *>(
30 }
31 
32 LLVM::LLVMDialect &ConvertToLLVMPattern::getDialect() const {
33  return *getTypeConverter()->getDialect();
34 }
35 
37  return getTypeConverter()->getIndexType();
38 }
39 
40 Type ConvertToLLVMPattern::getIntPtrType(unsigned addressSpace) const {
41  return IntegerType::get(&getTypeConverter()->getContext(),
42  getTypeConverter()->getPointerBitwidth(addressSpace));
43 }
44 
46  return LLVM::LLVMVoidType::get(&getTypeConverter()->getContext());
47 }
48 
51  IntegerType::get(&getTypeConverter()->getContext(), 8));
52 }
53 
55  Location loc,
56  Type resultType,
57  int64_t value) {
58  return builder.create<LLVM::ConstantOp>(
59  loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value));
60 }
61 
63  ConversionPatternRewriter &builder, Location loc, uint64_t value) const {
64  return createIndexAttrConstant(builder, loc, getIndexType(), value);
65 }
66 
68  Location loc, MemRefType type, Value memRefDesc, ValueRange indices,
69  ConversionPatternRewriter &rewriter) const {
70 
71  int64_t offset;
73  auto successStrides = getStridesAndOffset(type, strides, offset);
74  assert(succeeded(successStrides) && "unexpected non-strided memref");
75  (void)successStrides;
76 
77  MemRefDescriptor memRefDescriptor(memRefDesc);
78  Value base = memRefDescriptor.alignedPtr(rewriter, loc);
79 
80  Value index;
81  if (offset != 0) // Skip if offset is zero.
82  index = ShapedType::isDynamicStrideOrOffset(offset)
83  ? memRefDescriptor.offset(rewriter, loc)
84  : createIndexConstant(rewriter, loc, offset);
85 
86  for (int i = 0, e = indices.size(); i < e; ++i) {
87  Value increment = indices[i];
88  if (strides[i] != 1) { // Skip if stride is 1.
89  Value stride = ShapedType::isDynamicStrideOrOffset(strides[i])
90  ? memRefDescriptor.stride(rewriter, loc, i)
91  : createIndexConstant(rewriter, loc, strides[i]);
92  increment = rewriter.create<LLVM::MulOp>(loc, increment, stride);
93  }
94  index =
95  index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment;
96  }
97 
98  Type elementPtrType = memRefDescriptor.getElementPtrType();
99  return index ? rewriter.create<LLVM::GEPOp>(loc, elementPtrType, base, index)
100  : base;
101 }
102 
103 // Check if the MemRefType `type` is supported by the lowering. We currently
104 // only support memrefs with identity maps.
106  MemRefType type) const {
107  if (!typeConverter->convertType(type.getElementType()))
108  return false;
109  return type.getLayout().isIdentity();
110 }
111 
113  auto elementType = type.getElementType();
114  auto structElementType = typeConverter->convertType(elementType);
115  return LLVM::LLVMPointerType::get(structElementType,
116  type.getMemorySpaceAsInt());
117 }
118 
120  Location loc, MemRefType memRefType, ValueRange dynamicSizes,
122  SmallVectorImpl<Value> &strides, Value &sizeBytes) const {
123  assert(isConvertibleAndHasIdentityMaps(memRefType) &&
124  "layout maps must have been normalized away");
125  assert(count(memRefType.getShape(), ShapedType::kDynamicSize) ==
126  static_cast<ssize_t>(dynamicSizes.size()) &&
127  "dynamicSizes size doesn't match dynamic sizes count in memref shape");
128 
129  sizes.reserve(memRefType.getRank());
130  unsigned dynamicIndex = 0;
131  for (int64_t size : memRefType.getShape()) {
132  sizes.push_back(size == ShapedType::kDynamicSize
133  ? dynamicSizes[dynamicIndex++]
134  : createIndexConstant(rewriter, loc, size));
135  }
136 
137  // Strides: iterate sizes in reverse order and multiply.
138  int64_t stride = 1;
139  Value runningStride = createIndexConstant(rewriter, loc, 1);
140  strides.resize(memRefType.getRank());
141  for (auto i = memRefType.getRank(); i-- > 0;) {
142  strides[i] = runningStride;
143 
144  int64_t size = memRefType.getShape()[i];
145  if (size == 0)
146  continue;
147  bool useSizeAsStride = stride == 1;
148  if (size == ShapedType::kDynamicSize)
149  stride = ShapedType::kDynamicSize;
150  if (stride != ShapedType::kDynamicSize)
151  stride *= size;
152 
153  if (useSizeAsStride)
154  runningStride = sizes[i];
155  else if (stride == ShapedType::kDynamicSize)
156  runningStride =
157  rewriter.create<LLVM::MulOp>(loc, runningStride, sizes[i]);
158  else
159  runningStride = createIndexConstant(rewriter, loc, stride);
160  }
161 
162  // Buffer size in bytes.
163  Type elementPtrType = getElementPtrType(memRefType);
164  Value nullPtr = rewriter.create<LLVM::NullOp>(loc, elementPtrType);
165  Value gepPtr = rewriter.create<LLVM::GEPOp>(loc, elementPtrType, nullPtr,
166  ArrayRef<Value>{runningStride});
167  sizeBytes = rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
168 }
169 
171  Location loc, Type type, ConversionPatternRewriter &rewriter) const {
172  // Compute the size of an individual element. This emits the MLIR equivalent
173  // of the following sizeof(...) implementation in LLVM IR:
174  // %0 = getelementptr %elementType* null, %indexType 1
175  // %1 = ptrtoint %elementType* %0 to %indexType
176  // which is a common pattern of getting the size of a type in bytes.
177  auto convertedPtrType =
179  auto nullPtr = rewriter.create<LLVM::NullOp>(loc, convertedPtrType);
180  auto gep = rewriter.create<LLVM::GEPOp>(
181  loc, convertedPtrType, nullPtr,
182  ArrayRef<Value>{createIndexConstant(rewriter, loc, 1)});
183  return rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep);
184 }
185 
187  Location loc, ArrayRef<Value> shape,
188  ConversionPatternRewriter &rewriter) const {
189  // Compute the total number of memref elements.
190  Value numElements =
191  shape.empty() ? createIndexConstant(rewriter, loc, 1) : shape.front();
192  for (unsigned i = 1, e = shape.size(); i < e; ++i)
193  numElements = rewriter.create<LLVM::MulOp>(loc, numElements, shape[i]);
194  return numElements;
195 }
196 
197 /// Creates and populates the memref descriptor struct given all its fields.
199  Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr,
200  ArrayRef<Value> sizes, ArrayRef<Value> strides,
201  ConversionPatternRewriter &rewriter) const {
202  auto structType = typeConverter->convertType(memRefType);
203  auto memRefDescriptor = MemRefDescriptor::undef(rewriter, loc, structType);
204 
205  // Field 1: Allocated pointer, used for malloc/free.
206  memRefDescriptor.setAllocatedPtr(rewriter, loc, allocatedPtr);
207 
208  // Field 2: Actual aligned pointer to payload.
209  memRefDescriptor.setAlignedPtr(rewriter, loc, alignedPtr);
210 
211  // Field 3: Offset in aligned pointer.
212  memRefDescriptor.setOffset(rewriter, loc,
213  createIndexConstant(rewriter, loc, 0));
214 
215  // Fields 4: Sizes.
216  for (const auto &en : llvm::enumerate(sizes))
217  memRefDescriptor.setSize(rewriter, loc, en.index(), en.value());
218 
219  // Field 5: Strides.
220  for (const auto &en : llvm::enumerate(strides))
221  memRefDescriptor.setStride(rewriter, loc, en.index(), en.value());
222 
223  return memRefDescriptor;
224 }
225 
227  OpBuilder &builder, Location loc, TypeRange origTypes,
228  SmallVectorImpl<Value> &operands, bool toDynamic) const {
229  assert(origTypes.size() == operands.size() &&
230  "expected as may original types as operands");
231 
232  // Find operands of unranked memref type and store them.
234  for (unsigned i = 0, e = operands.size(); i < e; ++i)
235  if (origTypes[i].isa<UnrankedMemRefType>())
236  unrankedMemrefs.emplace_back(operands[i]);
237 
238  if (unrankedMemrefs.empty())
239  return success();
240 
241  // Compute allocation sizes.
242  SmallVector<Value, 4> sizes;
244  unrankedMemrefs, sizes);
245 
246  // Get frequently used types.
247  MLIRContext *context = builder.getContext();
248  Type voidPtrType = LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
249  auto i1Type = IntegerType::get(context, 1);
250  Type indexType = getTypeConverter()->getIndexType();
251 
252  // Find the malloc and free, or declare them if necessary.
253  auto module = builder.getInsertionPoint()->getParentOfType<ModuleOp>();
254  LLVM::LLVMFuncOp freeFunc, mallocFunc;
255  if (toDynamic)
256  mallocFunc = LLVM::lookupOrCreateMallocFn(module, indexType);
257  if (!toDynamic)
258  freeFunc = LLVM::lookupOrCreateFreeFn(module);
259 
260  // Initialize shared constants.
261  Value zero =
262  builder.create<LLVM::ConstantOp>(loc, i1Type, builder.getBoolAttr(false));
263 
264  unsigned unrankedMemrefPos = 0;
265  for (unsigned i = 0, e = operands.size(); i < e; ++i) {
266  Type type = origTypes[i];
267  if (!type.isa<UnrankedMemRefType>())
268  continue;
269  Value allocationSize = sizes[unrankedMemrefPos++];
270  UnrankedMemRefDescriptor desc(operands[i]);
271 
272  // Allocate memory, copy, and free the source if necessary.
273  Value memory =
274  toDynamic
275  ? builder.create<LLVM::CallOp>(loc, mallocFunc, allocationSize)
276  .getResult(0)
277  : builder.create<LLVM::AllocaOp>(loc, voidPtrType, allocationSize,
278  /*alignment=*/0);
279  Value source = desc.memRefDescPtr(builder, loc);
280  builder.create<LLVM::MemcpyOp>(loc, memory, source, allocationSize, zero);
281  if (!toDynamic)
282  builder.create<LLVM::CallOp>(loc, freeFunc, source);
283 
284  // Create a new descriptor. The same descriptor can be returned multiple
285  // times, attempting to modify its pointer can lead to memory leaks
286  // (allocated twice and overwritten) or double frees (the caller does not
287  // know if the descriptor points to the same memory).
288  Type descriptorType = getTypeConverter()->convertType(type);
289  if (!descriptorType)
290  return failure();
291  auto updatedDesc =
292  UnrankedMemRefDescriptor::undef(builder, loc, descriptorType);
293  Value rank = desc.rank(builder, loc);
294  updatedDesc.setRank(builder, loc, rank);
295  updatedDesc.setMemRefDescPtr(builder, loc, memory);
296 
297  operands[i] = updatedDesc;
298  }
299 
300  return success();
301 }
302 
303 //===----------------------------------------------------------------------===//
304 // Detail methods
305 //===----------------------------------------------------------------------===//
306 
307 /// Replaces the given operation "op" with a new operation of type "targetOp"
308 /// and given operands.
310  Operation *op, StringRef targetOp, ValueRange operands,
312  unsigned numResults = op->getNumResults();
313 
314  Type packedType;
315  if (numResults != 0) {
316  packedType = typeConverter.packFunctionResults(op->getResultTypes());
317  if (!packedType)
318  return failure();
319  }
320 
321  // Create the operation through state since we don't know its C++ type.
322  OperationState state(op->getLoc(), targetOp);
323  state.addTypes(packedType);
324  state.addOperands(operands);
325  state.addAttributes(op->getAttrs());
326  Operation *newOp = rewriter.createOperation(state);
327 
328  // If the operation produced 0 or 1 result, return them immediately.
329  if (numResults == 0)
330  return rewriter.eraseOp(op), success();
331  if (numResults == 1)
332  return rewriter.replaceOp(op, newOp->getResult(0)), success();
333 
334  // Otherwise, it had been converted to an operation producing a structure.
335  // Extract individual results from the structure and return them as list.
336  SmallVector<Value, 4> results;
337  results.reserve(numResults);
338  for (unsigned i = 0; i < numResults; ++i) {
339  auto type = typeConverter.convertType(op->getResult(i).getType());
340  results.push_back(rewriter.create<LLVM::ExtractValueOp>(
341  op->getLoc(), type, newOp->getResult(0), rewriter.getI64ArrayAttr(i)));
342  }
343  rewriter.replaceOp(op, results);
344  return success();
345 }
Include the generated interface declarations.
OpTy create(Location location, Args &&...args)
Create an operation of specific op type at the current insertion point.
Definition: Builders.h:430
MLIRContext * getContext() const
Definition: Builders.h:54
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
Type getElementPtrType(MemRefType type) const
Returns the type of a pointer to an element of the memref.
Definition: Pattern.cpp:112
TypeConverter * getTypeConverter() const
Return the type converter held by this pattern, or nullptr if the pattern does not require type conve...
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results)
Convert the given type.
Type getIndexType() const
Gets the MLIR type wrapping the LLVM integer type whose bit width is defined by the used type convert...
Definition: Pattern.cpp:36
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:308
Base class for the conversion patterns.
Value getStridedElementPtr(Location loc, MemRefType type, Value memRefDesc, ValueRange indices, ConversionPatternRewriter &rewriter) const
Definition: Pattern.cpp:67
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value...
Definition: LogicalResult.h:68
LLVM::LLVMDialect & getDialect() const
Returns the LLVM dialect.
Definition: Pattern.cpp:32
LLVM::LLVMPointerType getElementPtrType()
Returns the (LLVM) pointer type this descriptor contains.
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:220
static LLVMPointerType get(Type pointee, unsigned addressSpace=0)
Gets or creates an instance of LLVM dialect pointer type pointing to an object of pointee type in the...
Definition: LLVMTypes.cpp:165
Value offset(OpBuilder &builder, Location loc)
Builds IR extracting the offset from the descriptor.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing the results of an operation.
static constexpr const bool value
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:170
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Definition: MemRefBuilder.h:33
Value stride(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th stride from the descriptor.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:32
Operation * createOperation(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:380
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:206
Value createIndexConstant(ConversionPatternRewriter &builder, Location loc, uint64_t value) const
Create an LLVM dialect operation defining the given index constant.
Definition: Pattern.cpp:62
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:276
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:38
void addTypes(ArrayRef< Type > newTypes)
static MemRefDescriptor undef(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating an undef value of the descriptor type.
LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc, TypeRange origTypes, SmallVectorImpl< Value > &operands, bool toDynamic) const
Copies the memory descriptor for any operands that were unranked descriptors originally to heap-alloc...
Definition: Pattern.cpp:226
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:106
This represents an operation in an abstracted form, suitable for use with the builder APIs...
Value getNumElements(Location loc, ArrayRef< Value > shape, ConversionPatternRewriter &rewriter) const
Computes total number of elements for the given shape.
Definition: Pattern.cpp:186
Type getIntPtrType(unsigned addressSpace=0) const
Gets the MLIR type wrapping the LLVM integer type whose bit width corresponds to that of a LLVM point...
Definition: Pattern.cpp:40
Value getSizeInBytes(Location loc, Type type, ConversionPatternRewriter &rewriter) const
Computes the size of type in bytes.
Definition: Pattern.cpp:170
Type getVoidType() const
Gets the MLIR type wrapping the LLVM void type.
Definition: Pattern.cpp:45
static UnrankedMemRefDescriptor undef(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating an undef value of the descriptor type.
Type packFunctionResults(TypeRange types)
Convert a non-empty list of types to be returned from a function into a supported LLVM IR type...
MemRefDescriptor createMemRefDescriptor(Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr, ArrayRef< Value > sizes, ArrayRef< Value > strides, ConversionPatternRewriter &rewriter) const
Creates and populates a canonical memref descriptor struct.
Definition: Pattern.cpp:198
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
Value memRefDescPtr(OpBuilder &builder, Location loc)
Builds IR extracting ranked memref descriptor ptr.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:84
bool isConvertibleAndHasIdentityMaps(MemRefType type) const
Returns if the given memref has identity maps and the element type is convertible to LLVM...
Definition: Pattern.cpp:105
TypeConverter * typeConverter
An optional type converter for use by this pattern.
static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value)
Create a constant Op producing a value of resultType from an index-typed integer attribute.
Definition: Pattern.cpp:54
Value alignedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the aligned pointer from the descriptor.
LLVM::LLVMFuncOp lookupOrCreateMallocFn(ModuleOp moduleOp, Type indexType)
Type getType() const
Return the type of this value.
Definition: Value.h:117
IndexType getIndexType()
Definition: Builders.cpp:48
LLVMTypeConverter * getTypeConverter() const
Definition: Pattern.cpp:27
Conversion from types in the Standard dialect to the LLVM IR dialect.
Definition: TypeConverter.h:30
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:87
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
This class implements a pattern rewriter for use with ConversionPatterns.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:273
void getMemRefDescriptorSizes(Location loc, MemRefType memRefType, ValueRange dynamicSizes, ConversionPatternRewriter &rewriter, SmallVectorImpl< Value > &sizes, SmallVectorImpl< Value > &strides, Value &sizeBytes) const
Computes sizes, strides and buffer size in bytes of memRefType with identity layout.
Definition: Pattern.cpp:119
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands...
Definition: Pattern.cpp:309
Type getIndexType()
Gets the LLVM representation of the index type.
MLIRContext * getContext() const
Return the MLIRContext used to create this pattern.
Definition: PatternMatch.h:132
Type getVoidPtrType() const
Get the MLIR type wrapping the LLVM i8* type.
Definition: Pattern.cpp:49
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:376
bool isa() const
Definition: Types.h:234
Value rank(OpBuilder &builder, Location loc)
Builds IR extracting the rank from the descriptor.
This class helps build Operations.
Definition: Builders.h:177
This class provides an abstraction over the different types of ranges over Values.
result_type_range getResultTypes()
Definition: Operation.h:297
static void computeSizes(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, ArrayRef< UnrankedMemRefDescriptor > values, SmallVectorImpl< Value > &sizes)
Builds IR computing the sizes in bytes (suitable for opaque allocation) and appends the corresponding...
LLVM::LLVMDialect * getDialect()
Returns the LLVM dialect.
Definition: TypeConverter.h:79
LLVM::LLVMFuncOp lookupOrCreateFreeFn(ModuleOp moduleOp)
ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context, LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition: Pattern.cpp:21