MLIR  18.0.0git
Pattern.cpp
Go to the documentation of this file.
1 //===- Pattern.cpp - Conversion pattern to the LLVM dialect ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
13 #include "mlir/IR/AffineMap.h"
15 
16 using namespace mlir;
17 
18 //===----------------------------------------------------------------------===//
19 // ConvertToLLVMPattern
20 //===----------------------------------------------------------------------===//
21 
/// Constructor: forwards the type converter, root operation name, benefit and
/// context to the ConversionPattern base class. The body itself is empty.
// NOTE(review): the declarator line (listing line 22) is absent from this
// listing; the index at the end of the page gives the declaration as
// ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context,
//                      const LLVMTypeConverter &typeConverter,
//                      PatternBenefit benefit = 1).
23  StringRef rootOpName, MLIRContext *context,
24  const LLVMTypeConverter &typeConverter, PatternBenefit benefit)
25  : ConversionPattern(typeConverter, rootOpName, benefit, context) {}
26 
/// Returns a `const LLVMTypeConverter *` obtained by a static_cast.
// NOTE(review): listing lines 27 (declarator) and 29 (operand of the cast) are
// absent from this listing. The index attributes Pattern.cpp:27 to
// `const LLVMTypeConverter *getTypeConverter() const`; the operand is
// presumably the base pattern's type converter -- confirm against the real
// file.
28  return static_cast<const LLVMTypeConverter *>(
30 }
31 
32 LLVM::LLVMDialect &ConvertToLLVMPattern::getDialect() const {
33  return *getTypeConverter()->getDialect();
34 }
35 
/// Returns the index type reported by this pattern's type converter.
// NOTE(review): the declarator line (listing line 36) is absent from this
// listing; per the index it is `Type getIndexType() const`.
37  return getTypeConverter()->getIndexType();
38 }
39 
/// Returns a type built from the type converter's pointer bit width for
/// `addressSpace`.
// NOTE(review): listing line 41 (the start of the return expression) is absent
// from this listing. The index describes the result as the integer type whose
// bit width corresponds to that of an LLVM pointer -- confirm the exact
// construction against the real file.
40 Type ConvertToLLVMPattern::getIntPtrType(unsigned addressSpace) const {
42  getTypeConverter()->getPointerBitwidth(addressSpace));
43 }
44 
47 }
48 
52 }
53 
/// Emits an LLVM::ConstantOp of type `resultType` at `loc` whose value is the
/// index attribute built from `value`, and returns the op's result.
// NOTE(review): the first declarator line (listing line 54) is absent from
// this listing; the index gives the declaration as
// `static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
//                                       Type resultType, int64_t value)`.
55  Location loc,
56  Type resultType,
57  int64_t value) {
58  return builder.create<LLVM::ConstantOp>(loc, resultType,
59  builder.getIndexAttr(value));
60 }
61 
/// Emits IR computing the address of the element of the memref described by
/// `memRefDesc` that is selected by `indices`.
///
/// Starting from the descriptor's buffer pointer, the linear index
/// sum(indices[i] * strides[i]) is accumulated and applied with a single GEP.
/// Statically known strides become index constants, dynamic strides are read
/// from the descriptor, and the multiplication is omitted for unit strides.
/// When `indices` is empty no arithmetic is emitted and the buffer pointer is
/// returned as is.
// NOTE(review): the declarator line (listing line 62) is absent from this
// listing; per the index it is
// `Value getStridedElementPtr(Location loc, MemRefType type, Value memRefDesc,
//                             ValueRange indices,
//                             ConversionPatternRewriter &rewriter) const`.
63  Location loc, MemRefType type, Value memRefDesc, ValueRange indices,
64  ConversionPatternRewriter &rewriter) const {
65 
// `offset` is unpacked here but not referenced again in this function; the
// start address comes from bufferPtr() below, which receives `type`.
66  auto [strides, offset] = getStridesAndOffset(type);
67 
68  MemRefDescriptor memRefDescriptor(memRefDesc);
69  // Use a canonical representation of the start address so that later
70  // optimizations have a longer sequence of instructions to CSE.
71  // If we don't do that we would sprinkle the memref.offset in various
72  // positions of the different address computations.
73  Value base =
74  memRefDescriptor.bufferPtr(rewriter, loc, *getTypeConverter(), type);
75 
76  Type indexType = getIndexType();
// Stays null until the first index has been processed, so that no add is
// emitted for the first term.
77  Value index;
78  for (int i = 0, e = indices.size(); i < e; ++i) {
79  Value increment = indices[i];
80  if (strides[i] != 1) { // Skip if stride is 1.
81  Value stride =
82  ShapedType::isDynamic(strides[i])
83  ? memRefDescriptor.stride(rewriter, loc, i)
84  : createIndexAttrConstant(rewriter, loc, indexType, strides[i]);
85  increment = rewriter.create<LLVM::MulOp>(loc, increment, stride);
86  }
87  index =
88  index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment;
89  }
90 
91  Type elementPtrType = memRefDescriptor.getElementPtrType();
// A null `index` means the loop never ran (no indices): return the base
// pointer without emitting a GEP.
92  return index ? rewriter.create<LLVM::GEPOp>(
93  loc, elementPtrType,
94  getTypeConverter()->convertType(type.getElementType()),
95  base, index)
96  : base;
97 }
98 
99 // Check if the MemRefType `type` is supported by the lowering. We currently
100 // only support memrefs with identity maps.
// Returns false when the element type cannot be converted (convertType yields
// a null Type); otherwise returns whether the memref layout is the identity.
// NOTE(review): the declarator line (listing line 101) is absent from this
// listing; per the index it is
// `bool isConvertibleAndHasIdentityMaps(MemRefType type) const`.
102  MemRefType type) const {
103  if (!typeConverter->convertType(type.getElementType()))
104  return false;
105  return type.getLayout().isIdentity();
106 }
107 
/// Returns the pointer type for elements of the memref `type`, using the LLVM
/// address space that the type converter reports for the memref. Returns a
/// null Type when the address-space lookup fails.
// NOTE(review): the declarator line (listing line 108) is absent from this
// listing; per the index it is `Type getElementPtrType(MemRefType type) const`.
109  auto elementType = type.getElementType();
// NOTE(review): the converted element type is forwarded to getPointerType()
// without a null check, whereas isConvertibleAndHasIdentityMaps() tests the
// same conversion for failure -- confirm callers guarantee convertibility.
110  auto structElementType = typeConverter->convertType(elementType);
111  auto addressSpace = getTypeConverter()->getMemRefAddressSpace(type);
112  if (failed(addressSpace))
113  return {};
114  return getTypeConverter()->getPointerType(structElementType, *addressSpace);
115 }
116 
/// Emits IR computing the per-dimension sizes and strides of `memRefType`
/// together with the total buffer size.
///
/// Outputs:
///   - `sizes`: one value appended per dimension; static extents become index
///     constants, dynamic extents are taken from `dynamicSizes` in shape order.
///   - `strides`: resized to the rank and filled from the innermost dimension
///     outwards with the running product of the sizes.
///   - `size`: the final running product, either as an element count or, when
///     `sizeInBytes` is set, converted to bytes.
///
/// Asserts that the memref is convertible with an identity layout and that
/// `dynamicSizes` has one entry per dynamic dimension.
// NOTE(review): listing lines 117 (declarator) and 119 (middle of the
// parameter list) are absent from this listing; per the index the full
// signature is
// `void getMemRefDescriptorSizes(Location loc, MemRefType memRefType,
//      ValueRange dynamicSizes, ConversionPatternRewriter &rewriter,
//      SmallVectorImpl<Value> &sizes, SmallVectorImpl<Value> &strides,
//      Value &size, bool sizeInBytes = true) const`.
118  Location loc, MemRefType memRefType, ValueRange dynamicSizes,
120  SmallVectorImpl<Value> &strides, Value &size, bool sizeInBytes) const {
121  assert(isConvertibleAndHasIdentityMaps(memRefType) &&
122  "layout maps must have been normalized away");
123  assert(count(memRefType.getShape(), ShapedType::kDynamic) ==
124  static_cast<ssize_t>(dynamicSizes.size()) &&
125  "dynamicSizes size doesn't match dynamic sizes count in memref shape");
126 
127  sizes.reserve(memRefType.getRank());
128  unsigned dynamicIndex = 0;
129  Type indexType = getIndexType();
// Note: the loop variable `size` shadows the `Value &size` out-parameter for
// the duration of this loop.
130  for (int64_t size : memRefType.getShape()) {
131  sizes.push_back(
132  size == ShapedType::kDynamic
133  ? dynamicSizes[dynamicIndex++]
134  : createIndexAttrConstant(rewriter, loc, indexType, size));
135  }
136 
137  // Strides: iterate sizes in reverse order and multiply.
// `stride` tracks the statically known running product (or kDynamic once a
// dynamic extent has been seen); `runningStride` is the matching IR value.
138  int64_t stride = 1;
139  Value runningStride = createIndexAttrConstant(rewriter, loc, indexType, 1);
140  strides.resize(memRefType.getRank());
141  for (auto i = memRefType.getRank(); i-- > 0;) {
142  strides[i] = runningStride;
143 
144  int64_t staticSize = memRefType.getShape()[i];
// Zero-sized dimensions leave the running stride untouched.
145  if (staticSize == 0)
146  continue;
// Sampled before `stride` is updated: while the static running product is
// still 1, the new running stride is just this dimension's size value, so no
// multiplication needs to be emitted.
147  bool useSizeAsStride = stride == 1;
148  if (staticSize == ShapedType::kDynamic)
149  stride = ShapedType::kDynamic;
150  if (stride != ShapedType::kDynamic)
151  stride *= staticSize;
152 
153  if (useSizeAsStride)
154  runningStride = sizes[i];
155  else if (stride == ShapedType::kDynamic)
156  runningStride =
157  rewriter.create<LLVM::MulOp>(loc, runningStride, sizes[i]);
158  else
159  runningStride = createIndexAttrConstant(rewriter, loc, indexType, stride);
160  }
161  if (sizeInBytes) {
162  // Buffer size in bytes.
// Computed as ptrtoint(gep(null, runningStride)) over the converted element
// type, i.e. element count scaled by the element size.
163  Type elementType = typeConverter->convertType(memRefType.getElementType());
164  Type elementPtrType = getTypeConverter()->getPointerType(elementType);
165  Value nullPtr = rewriter.create<LLVM::ZeroOp>(loc, elementPtrType);
166  Value gepPtr = rewriter.create<LLVM::GEPOp>(
167  loc, elementPtrType, elementType, nullPtr, runningStride);
168  size = rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
169  } else {
// Size expressed as a number of elements.
170  size = runningStride;
171  }
172 }
173 
/// Emits IR computing the size in bytes of `type` after conversion to the LLVM
/// dialect, and returns it as a value of the index type.
// NOTE(review): the declarator line (listing line 174) is absent from this
// listing; per the index it is
// `Value getSizeInBytes(Location loc, Type type,
//                       ConversionPatternRewriter &rewriter) const`.
175  Location loc, Type type, ConversionPatternRewriter &rewriter) const {
176  // Compute the size of an individual element. This emits the MLIR equivalent
177  // of the following sizeof(...) implementation in LLVM IR:
178  // %0 = getelementptr %elementType* null, %indexType 1
179  // %1 = ptrtoint %elementType* %0 to %indexType
180  // which is a common pattern of getting the size of a type in bytes.
181  Type llvmType = typeConverter->convertType(type);
182  auto convertedPtrType = getTypeConverter()->getPointerType(llvmType);
183  auto nullPtr = rewriter.create<LLVM::ZeroOp>(loc, convertedPtrType);
// The GEP uses the constant index 1, i.e. the address one element past null.
184  auto gep = rewriter.create<LLVM::GEPOp>(loc, convertedPtrType, llvmType,
185  nullPtr, ArrayRef<LLVM::GEPArg>{1});
186  return rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep);
187 }
188 
190  Location loc, MemRefType memRefType, ValueRange dynamicSizes,
191  ConversionPatternRewriter &rewriter) const {
192  assert(count(memRefType.getShape(), ShapedType::kDynamic) ==
193  static_cast<ssize_t>(dynamicSizes.size()) &&
194  "dynamicSizes size doesn't match dynamic sizes count in memref shape");
195 
196  Type indexType = getIndexType();
197  Value numElements = memRefType.getRank() == 0
198  ? createIndexAttrConstant(rewriter, loc, indexType, 1)
199  : nullptr;
200  unsigned dynamicIndex = 0;
201 
202  // Compute the total number of memref elements.
203  for (int64_t staticSize : memRefType.getShape()) {
204  if (numElements) {
205  Value size =
206  staticSize == ShapedType::kDynamic
207  ? dynamicSizes[dynamicIndex++]
208  : createIndexAttrConstant(rewriter, loc, indexType, staticSize);
209  numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
210  } else {
211  numElements =
212  staticSize == ShapedType::kDynamic
213  ? dynamicSizes[dynamicIndex++]
214  : createIndexAttrConstant(rewriter, loc, indexType, staticSize);
215  }
216  }
217  return numElements;
218 }
219 
220 /// Creates and populates the memref descriptor struct given all its fields.
/// The descriptor starts as an undef value of the converted memref type and is
/// filled field by field; the offset field is always set to the constant 0.
/// `sizes` and `strides` are stored at the position given by their index.
// NOTE(review): the declarator line (listing line 221) is absent from this
// listing; per the index it is
// `MemRefDescriptor createMemRefDescriptor(Location loc, MemRefType memRefType,
//      Value allocatedPtr, Value alignedPtr, ArrayRef<Value> sizes,
//      ArrayRef<Value> strides, ConversionPatternRewriter &rewriter) const`.
222  Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr,
223  ArrayRef<Value> sizes, ArrayRef<Value> strides,
224  ConversionPatternRewriter &rewriter) const {
225  auto structType = typeConverter->convertType(memRefType);
226  auto memRefDescriptor = MemRefDescriptor::undef(rewriter, loc, structType);
227 
228  // Field 1: Allocated pointer, used for malloc/free.
229  memRefDescriptor.setAllocatedPtr(rewriter, loc, allocatedPtr);
230 
231  // Field 2: Actual aligned pointer to payload.
232  memRefDescriptor.setAlignedPtr(rewriter, loc, alignedPtr);
233 
234  // Field 3: Offset in aligned pointer.
235  Type indexType = getIndexType();
236  memRefDescriptor.setOffset(
237  rewriter, loc, createIndexAttrConstant(rewriter, loc, indexType, 0));
238 
239  // Field 4: Sizes.
240  for (const auto &en : llvm::enumerate(sizes))
241  memRefDescriptor.setSize(rewriter, loc, en.index(), en.value());
242 
243  // Field 5: Strides.
244  for (const auto &en : llvm::enumerate(strides))
245  memRefDescriptor.setStride(rewriter, loc, en.index(), en.value());
246 
247  return memRefDescriptor;
248 }
249 
/// For every operand whose original type is an unranked memref, emits IR that
/// copies the descriptor it points to into freshly allocated memory and
/// replaces the operand in `operands` with a new unranked descriptor that has
/// the same rank and points to the copy.
///
/// When `toDynamic` is true the copy is allocated by calling malloc. Otherwise
/// it is allocated with an alloca and the source memory is released by calling
/// free after the copy.
///
/// Returns success immediately when no operand is an unranked memref. Returns
/// failure when the address space of an unranked memref cannot be determined
/// or when its type cannot be converted.
// NOTE(review): listing lines 250 (declarator), 257, 263, 275 and 306 are
// absent from this listing, so the declaration of `unrankedMemrefs`, the
// address-space query, the callee of the size computation and one AllocaOp
// argument are not visible here. Per the index the declaration is
// `LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc,
//      TypeRange origTypes, SmallVectorImpl<Value> &operands,
//      bool toDynamic) const`.
251  OpBuilder &builder, Location loc, TypeRange origTypes,
252  SmallVectorImpl<Value> &operands, bool toDynamic) const {
// NOTE(review): the assert message below reads "as may"; "as many" is
// presumably intended.
253  assert(origTypes.size() == operands.size() &&
254  "expected as may original types as operands");
255 
256  // Find operands of unranked memref type and store them.
258  SmallVector<unsigned> unrankedAddressSpaces;
259  for (unsigned i = 0, e = operands.size(); i < e; ++i) {
260  if (auto memRefType = dyn_cast<UnrankedMemRefType>(origTypes[i])) {
261  unrankedMemrefs.emplace_back(operands[i]);
262  FailureOr<unsigned> addressSpace =
264  if (failed(addressSpace))
265  return failure();
266  unrankedAddressSpaces.emplace_back(*addressSpace);
267  }
268  }
269 
// Nothing to copy: leave `operands` untouched.
270  if (unrankedMemrefs.empty())
271  return success();
272 
273  // Compute allocation sizes.
274  SmallVector<Value> sizes;
276  unrankedMemrefs, unrankedAddressSpaces,
277  sizes);
278 
279  // Get frequently used types.
280  Type indexType = getTypeConverter()->getIndexType();
281 
282  // Find the malloc and free, or declare them if necessary.
// Only the function needed for the requested direction is looked up; the
// other handle stays default-constructed and is not used below.
283  auto module = builder.getInsertionPoint()->getParentOfType<ModuleOp>();
284  LLVM::LLVMFuncOp freeFunc, mallocFunc;
285  if (toDynamic)
286  mallocFunc = LLVM::lookupOrCreateMallocFn(
287  module, indexType, getTypeConverter()->useOpaquePointers());
288  if (!toDynamic)
289  freeFunc = LLVM::lookupOrCreateFreeFn(
290  module, getTypeConverter()->useOpaquePointers());
291 
// `sizes` holds one entry per unranked memref only, so it is indexed with its
// own counter rather than with the operand index.
292  unsigned unrankedMemrefPos = 0;
293  for (unsigned i = 0, e = operands.size(); i < e; ++i) {
294  Type type = origTypes[i];
295  if (!isa<UnrankedMemRefType>(type))
296  continue;
297  Value allocationSize = sizes[unrankedMemrefPos++];
298  UnrankedMemRefDescriptor desc(operands[i]);
299 
300  // Allocate memory, copy, and free the source if necessary.
301  Value memory =
302  toDynamic
303  ? builder.create<LLVM::CallOp>(loc, mallocFunc, allocationSize)
304  .getResult()
305  : builder.create<LLVM::AllocaOp>(loc, getVoidPtrType(),
307  allocationSize,
308  /*alignment=*/0);
309  Value source = desc.memRefDescPtr(builder, loc);
310  builder.create<LLVM::MemcpyOp>(loc, memory, source, allocationSize, false);
311  if (!toDynamic)
312  builder.create<LLVM::CallOp>(loc, freeFunc, source);
313 
314  // Create a new descriptor. The same descriptor can be returned multiple
315  // times, attempting to modify its pointer can lead to memory leaks
316  // (allocated twice and overwritten) or double frees (the caller does not
317  // know if the descriptor points to the same memory).
318  Type descriptorType = getTypeConverter()->convertType(type);
// NOTE(review): this failure path is reached after the allocation and copy
// above have already been emitted for this operand.
319  if (!descriptorType)
320  return failure();
321  auto updatedDesc =
322  UnrankedMemRefDescriptor::undef(builder, loc, descriptorType);
323  Value rank = desc.rank(builder, loc);
324  updatedDesc.setRank(builder, loc, rank);
325  updatedDesc.setMemRefDescPtr(builder, loc, memory);
326 
327  operands[i] = updatedDesc;
328  }
329 
330  return success();
331 }
332 
333 //===----------------------------------------------------------------------===//
334 // Detail methods
335 //===----------------------------------------------------------------------===//
336 
337 /// Replaces the given operation "op" with a new operation of type "targetOp"
338 /// and given operands.
340 LLVM::detail::oneToOneRewrite(Operation *op, StringRef targetOp,
341  ValueRange operands,
342  ArrayRef<NamedAttribute> targetAttrs,
343  const LLVMTypeConverter &typeConverter,
344  ConversionPatternRewriter &rewriter) {
345  unsigned numResults = op->getNumResults();
346 
347  SmallVector<Type> resultTypes;
348  if (numResults != 0) {
349  resultTypes.push_back(
350  typeConverter.packOperationResults(op->getResultTypes()));
351  if (!resultTypes.back())
352  return failure();
353  }
354 
355  // Create the operation through state since we don't know its C++ type.
356  Operation *newOp =
357  rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
358  resultTypes, targetAttrs);
359 
360  // If the operation produced 0 or 1 result, return them immediately.
361  if (numResults == 0)
362  return rewriter.eraseOp(op), success();
363  if (numResults == 1)
364  return rewriter.replaceOp(op, newOp->getResult(0)), success();
365 
366  // Otherwise, it had been converted to an operation producing a structure.
367  // Extract individual results from the structure and return them as list.
368  SmallVector<Value, 4> results;
369  results.reserve(numResults);
370  for (unsigned i = 0; i < numResults; ++i) {
371  results.push_back(rewriter.create<LLVM::ExtractValueOp>(
372  op->getLoc(), newOp->getResult(0), i));
373  }
374  rewriter.replaceOp(op, results);
375  return success();
376 }
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:124
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:269
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
Base class for the conversion patterns.
const TypeConverter * typeConverter
An optional type converter for use by this pattern.
const TypeConverter * getTypeConverter() const
Return the type converter held by this pattern, or nullptr if the pattern does not require type conve...
Type getVoidType() const
Gets the MLIR type wrapping the LLVM void type.
Definition: Pattern.cpp:45
MemRefDescriptor createMemRefDescriptor(Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr, ArrayRef< Value > sizes, ArrayRef< Value > strides, ConversionPatternRewriter &rewriter) const
Creates and populates a canonical memref descriptor struct.
Definition: Pattern.cpp:221
ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context, const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition: Pattern.cpp:22
void getMemRefDescriptorSizes(Location loc, MemRefType memRefType, ValueRange dynamicSizes, ConversionPatternRewriter &rewriter, SmallVectorImpl< Value > &sizes, SmallVectorImpl< Value > &strides, Value &size, bool sizeInBytes=true) const
Computes sizes, strides and buffer size of memRefType with identity layout.
Definition: Pattern.cpp:117
Type getIndexType() const
Gets the MLIR type wrapping the LLVM integer type whose bit width is defined by the used type convert...
Definition: Pattern.cpp:36
const LLVMTypeConverter * getTypeConverter() const
Definition: Pattern.cpp:27
Value getStridedElementPtr(Location loc, MemRefType type, Value memRefDesc, ValueRange indices, ConversionPatternRewriter &rewriter) const
Definition: Pattern.cpp:62
Value getNumElements(Location loc, MemRefType memRefType, ValueRange dynamicSizes, ConversionPatternRewriter &rewriter) const
Computes total number of elements for the given MemRef and dynamicSizes.
Definition: Pattern.cpp:189
LLVM::LLVMDialect & getDialect() const
Returns the LLVM dialect.
Definition: Pattern.cpp:32
Value getSizeInBytes(Location loc, Type type, ConversionPatternRewriter &rewriter) const
Computes the size of type in bytes.
Definition: Pattern.cpp:174
Type getIntPtrType(unsigned addressSpace=0) const
Gets the MLIR type wrapping the LLVM integer type whose bit width corresponds to that of a LLVM point...
Definition: Pattern.cpp:40
LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc, TypeRange origTypes, SmallVectorImpl< Value > &operands, bool toDynamic) const
Copies the memory descriptor for any operands that were unranked descriptors originally to heap-alloc...
Definition: Pattern.cpp:250
Type getElementPtrType(MemRefType type) const
Returns the type of a pointer to an element of the memref.
Definition: Pattern.cpp:108
static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value)
Create a constant Op producing a value of resultType from an index-typed integer attribute.
Definition: Pattern.cpp:54
bool isConvertibleAndHasIdentityMaps(MemRefType type) const
Returns if the given memref has identity maps and the element type is convertible to LLVM.
Definition: Pattern.cpp:101
Type getVoidPtrType() const
Get the MLIR type wrapping the LLVM i8* type.
Definition: Pattern.cpp:49
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:33
Type packOperationResults(TypeRange types) const
Convert a non-empty list of types of values produced by an operation into an LLVM-compatible type.
LLVM::LLVMDialect * getDialect() const
Returns the LLVM dialect.
Definition: TypeConverter.h:91
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type or failure if...
LLVM::LLVMPointerType getPointerType(Type elementType, unsigned addressSpace=0) const
Creates an LLVM pointer type with the given element type and address space.
Type getIndexType() const
Gets the LLVM representation of the index type.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Definition: MemRefBuilder.h:33
static MemRefDescriptor undef(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating an undef value of the descriptor type.
This class helps build Operations.
Definition: Builders.h:206
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:430
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
result_type_range getResultTypes()
Definition: Operation.h:423
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:33
MLIRContext * getContext() const
Return the MLIRContext used to create this pattern.
Definition: PatternMatch.h:133
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
Value memRefDescPtr(OpBuilder &builder, Location loc) const
Builds IR extracting ranked memref descriptor ptr.
static UnrankedMemRefDescriptor undef(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating an undef value of the descriptor type.
static void computeSizes(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, ArrayRef< UnrankedMemRefDescriptor > values, ArrayRef< unsigned > addressSpaces, SmallVectorImpl< Value > &sizes)
Builds IR computing the sizes in bytes (suitable for opaque allocation) and appends the corresponding...
Value rank(OpBuilder &builder, Location loc) const
Builds IR extracting the rank from the descriptor.
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:372
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:93
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
Definition: Pattern.cpp:340
LLVM::LLVMFuncOp lookupOrCreateFreeFn(ModuleOp moduleOp, bool opaquePointers)
LLVM::LLVMFuncOp lookupOrCreateMallocFn(ModuleOp moduleOp, Type indexType, bool opaquePointers)
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26