MLIR  19.0.0git
Pattern.cpp
Go to the documentation of this file.
1 //===- Pattern.cpp - Conversion pattern to the LLVM dialect ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
13 #include "mlir/IR/AffineMap.h"
15 
16 using namespace mlir;
17 
18 //===----------------------------------------------------------------------===//
19 // ConvertToLLVMPattern
20 //===----------------------------------------------------------------------===//
21 
23  StringRef rootOpName, MLIRContext *context,
24  const LLVMTypeConverter &typeConverter, PatternBenefit benefit)
25  : ConversionPattern(typeConverter, rootOpName, benefit, context) {}
26 
28  return static_cast<const LLVMTypeConverter *>(
30 }
31 
32 LLVM::LLVMDialect &ConvertToLLVMPattern::getDialect() const {
33  return *getTypeConverter()->getDialect();
34 }
35 
37  return getTypeConverter()->getIndexType();
38 }
39 
40 Type ConvertToLLVMPattern::getIntPtrType(unsigned addressSpace) const {
42  getTypeConverter()->getPointerBitwidth(addressSpace));
43 }
44 
47 }
48 
51 }
52 
54  Location loc,
55  Type resultType,
56  int64_t value) {
57  return builder.create<LLVM::ConstantOp>(loc, resultType,
58  builder.getIndexAttr(value));
59 }
60 
62  Location loc, MemRefType type, Value memRefDesc, ValueRange indices,
63  ConversionPatternRewriter &rewriter) const {
64 
65  auto [strides, offset] = getStridesAndOffset(type);
66 
67  MemRefDescriptor memRefDescriptor(memRefDesc);
68  // Use a canonical representation of the start address so that later
69  // optimizations have a longer sequence of instructions to CSE.
70  // If we don't do that we would sprinkle the memref.offset in various
71  // position of the different address computations.
72  Value base =
73  memRefDescriptor.bufferPtr(rewriter, loc, *getTypeConverter(), type);
74 
75  Type indexType = getIndexType();
76  Value index;
77  for (int i = 0, e = indices.size(); i < e; ++i) {
78  Value increment = indices[i];
79  if (strides[i] != 1) { // Skip if stride is 1.
80  Value stride =
81  ShapedType::isDynamic(strides[i])
82  ? memRefDescriptor.stride(rewriter, loc, i)
83  : createIndexAttrConstant(rewriter, loc, indexType, strides[i]);
84  increment = rewriter.create<LLVM::MulOp>(loc, increment, stride);
85  }
86  index =
87  index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment;
88  }
89 
90  Type elementPtrType = memRefDescriptor.getElementPtrType();
91  return index ? rewriter.create<LLVM::GEPOp>(
92  loc, elementPtrType,
93  getTypeConverter()->convertType(type.getElementType()),
94  base, index)
95  : base;
96 }
97 
98 // Check if the MemRefType `type` is supported by the lowering. We currently
99 // only support memrefs with identity maps.
101  MemRefType type) const {
102  if (!typeConverter->convertType(type.getElementType()))
103  return false;
104  return type.getLayout().isIdentity();
105 }
106 
108  auto addressSpace = getTypeConverter()->getMemRefAddressSpace(type);
109  if (failed(addressSpace))
110  return {};
111  return LLVM::LLVMPointerType::get(type.getContext(), *addressSpace);
112 }
113 
115  Location loc, MemRefType memRefType, ValueRange dynamicSizes,
117  SmallVectorImpl<Value> &strides, Value &size, bool sizeInBytes) const {
118  assert(isConvertibleAndHasIdentityMaps(memRefType) &&
119  "layout maps must have been normalized away");
120  assert(count(memRefType.getShape(), ShapedType::kDynamic) ==
121  static_cast<ssize_t>(dynamicSizes.size()) &&
122  "dynamicSizes size doesn't match dynamic sizes count in memref shape");
123 
124  sizes.reserve(memRefType.getRank());
125  unsigned dynamicIndex = 0;
126  Type indexType = getIndexType();
127  for (int64_t size : memRefType.getShape()) {
128  sizes.push_back(
129  size == ShapedType::kDynamic
130  ? dynamicSizes[dynamicIndex++]
131  : createIndexAttrConstant(rewriter, loc, indexType, size));
132  }
133 
134  // Strides: iterate sizes in reverse order and multiply.
135  int64_t stride = 1;
136  Value runningStride = createIndexAttrConstant(rewriter, loc, indexType, 1);
137  strides.resize(memRefType.getRank());
138  for (auto i = memRefType.getRank(); i-- > 0;) {
139  strides[i] = runningStride;
140 
141  int64_t staticSize = memRefType.getShape()[i];
142  if (staticSize == 0)
143  continue;
144  bool useSizeAsStride = stride == 1;
145  if (staticSize == ShapedType::kDynamic)
146  stride = ShapedType::kDynamic;
147  if (stride != ShapedType::kDynamic)
148  stride *= staticSize;
149 
150  if (useSizeAsStride)
151  runningStride = sizes[i];
152  else if (stride == ShapedType::kDynamic)
153  runningStride =
154  rewriter.create<LLVM::MulOp>(loc, runningStride, sizes[i]);
155  else
156  runningStride = createIndexAttrConstant(rewriter, loc, indexType, stride);
157  }
158  if (sizeInBytes) {
159  // Buffer size in bytes.
160  Type elementType = typeConverter->convertType(memRefType.getElementType());
161  auto elementPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
162  Value nullPtr = rewriter.create<LLVM::ZeroOp>(loc, elementPtrType);
163  Value gepPtr = rewriter.create<LLVM::GEPOp>(
164  loc, elementPtrType, elementType, nullPtr, runningStride);
165  size = rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
166  } else {
167  size = runningStride;
168  }
169 }
170 
172  Location loc, Type type, ConversionPatternRewriter &rewriter) const {
173  // Compute the size of an individual element. This emits the MLIR equivalent
174  // of the following sizeof(...) implementation in LLVM IR:
175  // %0 = getelementptr %elementType* null, %indexType 1
176  // %1 = ptrtoint %elementType* %0 to %indexType
177  // which is a common pattern of getting the size of a type in bytes.
178  Type llvmType = typeConverter->convertType(type);
179  auto convertedPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
180  auto nullPtr = rewriter.create<LLVM::ZeroOp>(loc, convertedPtrType);
181  auto gep = rewriter.create<LLVM::GEPOp>(loc, convertedPtrType, llvmType,
182  nullPtr, ArrayRef<LLVM::GEPArg>{1});
183  return rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep);
184 }
185 
187  Location loc, MemRefType memRefType, ValueRange dynamicSizes,
188  ConversionPatternRewriter &rewriter) const {
189  assert(count(memRefType.getShape(), ShapedType::kDynamic) ==
190  static_cast<ssize_t>(dynamicSizes.size()) &&
191  "dynamicSizes size doesn't match dynamic sizes count in memref shape");
192 
193  Type indexType = getIndexType();
194  Value numElements = memRefType.getRank() == 0
195  ? createIndexAttrConstant(rewriter, loc, indexType, 1)
196  : nullptr;
197  unsigned dynamicIndex = 0;
198 
199  // Compute the total number of memref elements.
200  for (int64_t staticSize : memRefType.getShape()) {
201  if (numElements) {
202  Value size =
203  staticSize == ShapedType::kDynamic
204  ? dynamicSizes[dynamicIndex++]
205  : createIndexAttrConstant(rewriter, loc, indexType, staticSize);
206  numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
207  } else {
208  numElements =
209  staticSize == ShapedType::kDynamic
210  ? dynamicSizes[dynamicIndex++]
211  : createIndexAttrConstant(rewriter, loc, indexType, staticSize);
212  }
213  }
214  return numElements;
215 }
216 
217 /// Creates and populates the memref descriptor struct given all its fields.
219  Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr,
220  ArrayRef<Value> sizes, ArrayRef<Value> strides,
221  ConversionPatternRewriter &rewriter) const {
222  auto structType = typeConverter->convertType(memRefType);
223  auto memRefDescriptor = MemRefDescriptor::undef(rewriter, loc, structType);
224 
225  // Field 1: Allocated pointer, used for malloc/free.
226  memRefDescriptor.setAllocatedPtr(rewriter, loc, allocatedPtr);
227 
228  // Field 2: Actual aligned pointer to payload.
229  memRefDescriptor.setAlignedPtr(rewriter, loc, alignedPtr);
230 
231  // Field 3: Offset in aligned pointer.
232  Type indexType = getIndexType();
233  memRefDescriptor.setOffset(
234  rewriter, loc, createIndexAttrConstant(rewriter, loc, indexType, 0));
235 
236  // Fields 4: Sizes.
237  for (const auto &en : llvm::enumerate(sizes))
238  memRefDescriptor.setSize(rewriter, loc, en.index(), en.value());
239 
240  // Field 5: Strides.
241  for (const auto &en : llvm::enumerate(strides))
242  memRefDescriptor.setStride(rewriter, loc, en.index(), en.value());
243 
244  return memRefDescriptor;
245 }
246 
248  OpBuilder &builder, Location loc, TypeRange origTypes,
249  SmallVectorImpl<Value> &operands, bool toDynamic) const {
250  assert(origTypes.size() == operands.size() &&
251  "expected as may original types as operands");
252 
253  // Find operands of unranked memref type and store them.
255  SmallVector<unsigned> unrankedAddressSpaces;
256  for (unsigned i = 0, e = operands.size(); i < e; ++i) {
257  if (auto memRefType = dyn_cast<UnrankedMemRefType>(origTypes[i])) {
258  unrankedMemrefs.emplace_back(operands[i]);
259  FailureOr<unsigned> addressSpace =
261  if (failed(addressSpace))
262  return failure();
263  unrankedAddressSpaces.emplace_back(*addressSpace);
264  }
265  }
266 
267  if (unrankedMemrefs.empty())
268  return success();
269 
270  // Compute allocation sizes.
271  SmallVector<Value> sizes;
273  unrankedMemrefs, unrankedAddressSpaces,
274  sizes);
275 
276  // Get frequently used types.
277  Type indexType = getTypeConverter()->getIndexType();
278 
279  // Find the malloc and free, or declare them if necessary.
280  auto module = builder.getInsertionPoint()->getParentOfType<ModuleOp>();
281  LLVM::LLVMFuncOp freeFunc, mallocFunc;
282  if (toDynamic)
283  mallocFunc = LLVM::lookupOrCreateMallocFn(module, indexType);
284  if (!toDynamic)
285  freeFunc = LLVM::lookupOrCreateFreeFn(module);
286 
287  unsigned unrankedMemrefPos = 0;
288  for (unsigned i = 0, e = operands.size(); i < e; ++i) {
289  Type type = origTypes[i];
290  if (!isa<UnrankedMemRefType>(type))
291  continue;
292  Value allocationSize = sizes[unrankedMemrefPos++];
293  UnrankedMemRefDescriptor desc(operands[i]);
294 
295  // Allocate memory, copy, and free the source if necessary.
296  Value memory =
297  toDynamic
298  ? builder.create<LLVM::CallOp>(loc, mallocFunc, allocationSize)
299  .getResult()
300  : builder.create<LLVM::AllocaOp>(loc, getVoidPtrType(),
302  allocationSize,
303  /*alignment=*/0);
304  Value source = desc.memRefDescPtr(builder, loc);
305  builder.create<LLVM::MemcpyOp>(loc, memory, source, allocationSize, false);
306  if (!toDynamic)
307  builder.create<LLVM::CallOp>(loc, freeFunc, source);
308 
309  // Create a new descriptor. The same descriptor can be returned multiple
310  // times, attempting to modify its pointer can lead to memory leaks
311  // (allocated twice and overwritten) or double frees (the caller does not
312  // know if the descriptor points to the same memory).
313  Type descriptorType = getTypeConverter()->convertType(type);
314  if (!descriptorType)
315  return failure();
316  auto updatedDesc =
317  UnrankedMemRefDescriptor::undef(builder, loc, descriptorType);
318  Value rank = desc.rank(builder, loc);
319  updatedDesc.setRank(builder, loc, rank);
320  updatedDesc.setMemRefDescPtr(builder, loc, memory);
321 
322  operands[i] = updatedDesc;
323  }
324 
325  return success();
326 }
327 
328 //===----------------------------------------------------------------------===//
329 // Detail methods
330 //===----------------------------------------------------------------------===//
331 
333  IntegerOverflowFlags overflowFlags) {
334  if (auto iface = dyn_cast<IntegerOverflowFlagsInterface>(op))
335  iface.setOverflowFlags(overflowFlags);
336 }
337 
338 /// Replaces the given operation "op" with a new operation of type "targetOp"
339 /// and given operands.
341  Operation *op, StringRef targetOp, ValueRange operands,
342  ArrayRef<NamedAttribute> targetAttrs,
343  const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
344  IntegerOverflowFlags overflowFlags) {
345  unsigned numResults = op->getNumResults();
346 
347  SmallVector<Type> resultTypes;
348  if (numResults != 0) {
349  resultTypes.push_back(
350  typeConverter.packOperationResults(op->getResultTypes()));
351  if (!resultTypes.back())
352  return failure();
353  }
354 
355  // Create the operation through state since we don't know its C++ type.
356  Operation *newOp =
357  rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
358  resultTypes, targetAttrs);
359 
360  setNativeProperties(newOp, overflowFlags);
361 
362  // If the operation produced 0 or 1 result, return them immediately.
363  if (numResults == 0)
364  return rewriter.eraseOp(op), success();
365  if (numResults == 1)
366  return rewriter.replaceOp(op, newOp->getResult(0)), success();
367 
368  // Otherwise, it had been converted to an operation producing a structure.
369  // Extract individual results from the structure and return them as list.
370  SmallVector<Value, 4> results;
371  results.reserve(numResults);
372  for (unsigned i = 0; i < numResults; ++i) {
373  results.push_back(rewriter.create<LLVM::ExtractValueOp>(
374  op->getLoc(), newOp->getResult(0), i));
375  }
376  rewriter.replaceOp(op, results);
377  return success();
378 }
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:124
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:269
MLIRContext * getContext() const
Definition: Builders.h:55
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
Base class for the conversion patterns.
const TypeConverter * typeConverter
An optional type converter for use by this pattern.
const TypeConverter * getTypeConverter() const
Return the type converter held by this pattern, or nullptr if the pattern does not require type conve...
Type getVoidType() const
Gets the MLIR type wrapping the LLVM void type.
Definition: Pattern.cpp:45
MemRefDescriptor createMemRefDescriptor(Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr, ArrayRef< Value > sizes, ArrayRef< Value > strides, ConversionPatternRewriter &rewriter) const
Creates and populates a canonical memref descriptor struct.
Definition: Pattern.cpp:218
ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context, const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition: Pattern.cpp:22
void getMemRefDescriptorSizes(Location loc, MemRefType memRefType, ValueRange dynamicSizes, ConversionPatternRewriter &rewriter, SmallVectorImpl< Value > &sizes, SmallVectorImpl< Value > &strides, Value &size, bool sizeInBytes=true) const
Computes sizes, strides and buffer size of memRefType with identity layout.
Definition: Pattern.cpp:114
Type getIndexType() const
Gets the MLIR type wrapping the LLVM integer type whose bit width is defined by the used type convert...
Definition: Pattern.cpp:36
const LLVMTypeConverter * getTypeConverter() const
Definition: Pattern.cpp:27
Value getStridedElementPtr(Location loc, MemRefType type, Value memRefDesc, ValueRange indices, ConversionPatternRewriter &rewriter) const
Definition: Pattern.cpp:61
Value getNumElements(Location loc, MemRefType memRefType, ValueRange dynamicSizes, ConversionPatternRewriter &rewriter) const
Computes total number of elements for the given MemRef and dynamicSizes.
Definition: Pattern.cpp:186
LLVM::LLVMDialect & getDialect() const
Returns the LLVM dialect.
Definition: Pattern.cpp:32
Value getSizeInBytes(Location loc, Type type, ConversionPatternRewriter &rewriter) const
Computes the size of type in bytes.
Definition: Pattern.cpp:171
Type getIntPtrType(unsigned addressSpace=0) const
Gets the MLIR type wrapping the LLVM integer type whose bit width corresponds to that of a LLVM point...
Definition: Pattern.cpp:40
LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc, TypeRange origTypes, SmallVectorImpl< Value > &operands, bool toDynamic) const
Copies the memory descriptor for any operands that were unranked descriptors originally to heap-alloc...
Definition: Pattern.cpp:247
Type getElementPtrType(MemRefType type) const
Returns the type of a pointer to an element of the memref.
Definition: Pattern.cpp:107
static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value)
Create a constant Op producing a value of resultType from an index-typed integer attribute.
Definition: Pattern.cpp:53
bool isConvertibleAndHasIdentityMaps(MemRefType type) const
Returns if the given memref has identity maps and the element type is convertible to LLVM.
Definition: Pattern.cpp:100
Type getVoidPtrType() const
Get the MLIR type wrapping the LLVM i8* type.
Definition: Pattern.cpp:49
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:34
Type packOperationResults(TypeRange types) const
Convert a non-empty list of types of values produced by an operation into an LLVM-compatible type.
LLVM::LLVMDialect * getDialect() const
Returns the LLVM dialect.
Definition: TypeConverter.h:92
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type or failure if...
Type getIndexType() const
Gets the LLVM representation of the index type.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Definition: MemRefBuilder.h:33
static MemRefDescriptor undef(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating an undef value of the descriptor type.
This class helps build Operations.
Definition: Builders.h:209
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:447
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
result_type_range getResultTypes()
Definition: Operation.h:423
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
MLIRContext * getContext() const
Return the MLIRContext used to create this pattern.
Definition: PatternMatch.h:134
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
Value memRefDescPtr(OpBuilder &builder, Location loc) const
Builds IR extracting ranked memref descriptor ptr.
static UnrankedMemRefDescriptor undef(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating an undef value of the descriptor type.
static void computeSizes(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, ArrayRef< UnrankedMemRefDescriptor > values, ArrayRef< unsigned > addressSpaces, SmallVectorImpl< Value > &sizes)
Builds IR computing the sizes in bytes (suitable for opaque allocation) and appends the corresponding...
Value rank(OpBuilder &builder, Location loc) const
Builds IR extracting the rank from the descriptor.
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, IntegerOverflowFlags overflowFlags=IntegerOverflowFlags::none)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
Definition: Pattern.cpp:340
void setNativeProperties(Operation *op, IntegerOverflowFlags overflowFlags)
Handle generically setting flags as native properties on LLVM operations.
Definition: Pattern.cpp:332
LLVM::LLVMFuncOp lookupOrCreateFreeFn(ModuleOp moduleOp)
LLVM::LLVMFuncOp lookupOrCreateMallocFn(ModuleOp moduleOp, Type indexType)
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26