MLIR  21.0.0git
Pattern.cpp
Go to the documentation of this file.
1 //===- Pattern.cpp - Conversion pattern to the LLVM dialect ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
13 #include "mlir/IR/AffineMap.h"
15 
16 using namespace mlir;
17 
18 //===----------------------------------------------------------------------===//
19 // ConvertToLLVMPattern
20 //===----------------------------------------------------------------------===//
21 
23  StringRef rootOpName, MLIRContext *context,
24  const LLVMTypeConverter &typeConverter, PatternBenefit benefit)
25  : ConversionPattern(typeConverter, rootOpName, benefit, context) {}
26 
28  return static_cast<const LLVMTypeConverter *>(
30 }
31 
32 LLVM::LLVMDialect &ConvertToLLVMPattern::getDialect() const {
33  return *getTypeConverter()->getDialect();
34 }
35 
37  return getTypeConverter()->getIndexType();
38 }
39 
40 Type ConvertToLLVMPattern::getIntPtrType(unsigned addressSpace) const {
42  getTypeConverter()->getPointerBitwidth(addressSpace));
43 }
44 
47 }
48 
51 }
52 
54  Location loc,
55  Type resultType,
56  int64_t value) {
57  return builder.create<LLVM::ConstantOp>(loc, resultType,
58  builder.getIndexAttr(value));
59 }
60 
62  ConversionPatternRewriter &rewriter, Location loc, MemRefType type,
63  Value memRefDesc, ValueRange indices,
64  LLVM::GEPNoWrapFlags noWrapFlags) const {
65  return LLVM::getStridedElementPtr(rewriter, loc, *getTypeConverter(), type,
66  memRefDesc, indices, noWrapFlags);
67 }
68 
69 // Check if the MemRefType `type` is supported by the lowering. We currently
70 // only support memrefs with identity maps.
72  MemRefType type) const {
73  if (!type.getLayout().isIdentity())
74  return false;
75  return static_cast<bool>(typeConverter->convertType(type));
76 }
77 
79  auto addressSpace = getTypeConverter()->getMemRefAddressSpace(type);
80  if (failed(addressSpace))
81  return {};
82  return LLVM::LLVMPointerType::get(type.getContext(), *addressSpace);
83 }
84 
86  Location loc, MemRefType memRefType, ValueRange dynamicSizes,
88  SmallVectorImpl<Value> &strides, Value &size, bool sizeInBytes) const {
89  assert(isConvertibleAndHasIdentityMaps(memRefType) &&
90  "layout maps must have been normalized away");
91  assert(count(memRefType.getShape(), ShapedType::kDynamic) ==
92  static_cast<ssize_t>(dynamicSizes.size()) &&
93  "dynamicSizes size doesn't match dynamic sizes count in memref shape");
94 
95  sizes.reserve(memRefType.getRank());
96  unsigned dynamicIndex = 0;
97  Type indexType = getIndexType();
98  for (int64_t size : memRefType.getShape()) {
99  sizes.push_back(
100  size == ShapedType::kDynamic
101  ? dynamicSizes[dynamicIndex++]
102  : createIndexAttrConstant(rewriter, loc, indexType, size));
103  }
104 
105  // Strides: iterate sizes in reverse order and multiply.
106  int64_t stride = 1;
107  Value runningStride = createIndexAttrConstant(rewriter, loc, indexType, 1);
108  strides.resize(memRefType.getRank());
109  for (auto i = memRefType.getRank(); i-- > 0;) {
110  strides[i] = runningStride;
111 
112  int64_t staticSize = memRefType.getShape()[i];
113  bool useSizeAsStride = stride == 1;
114  if (staticSize == ShapedType::kDynamic)
115  stride = ShapedType::kDynamic;
116  if (stride != ShapedType::kDynamic)
117  stride *= staticSize;
118 
119  if (useSizeAsStride)
120  runningStride = sizes[i];
121  else if (stride == ShapedType::kDynamic)
122  runningStride =
123  rewriter.create<LLVM::MulOp>(loc, runningStride, sizes[i]);
124  else
125  runningStride = createIndexAttrConstant(rewriter, loc, indexType, stride);
126  }
127  if (sizeInBytes) {
128  // Buffer size in bytes.
129  Type elementType = typeConverter->convertType(memRefType.getElementType());
130  auto elementPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
131  Value nullPtr = rewriter.create<LLVM::ZeroOp>(loc, elementPtrType);
132  Value gepPtr = rewriter.create<LLVM::GEPOp>(
133  loc, elementPtrType, elementType, nullPtr, runningStride);
134  size = rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
135  } else {
136  size = runningStride;
137  }
138 }
139 
141  Location loc, Type type, ConversionPatternRewriter &rewriter) const {
142  // Compute the size of an individual element. This emits the MLIR equivalent
143  // of the following sizeof(...) implementation in LLVM IR:
144  // %0 = getelementptr %elementType* null, %indexType 1
145  // %1 = ptrtoint %elementType* %0 to %indexType
146  // which is a common pattern of getting the size of a type in bytes.
147  Type llvmType = typeConverter->convertType(type);
148  auto convertedPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
149  auto nullPtr = rewriter.create<LLVM::ZeroOp>(loc, convertedPtrType);
150  auto gep = rewriter.create<LLVM::GEPOp>(loc, convertedPtrType, llvmType,
151  nullPtr, ArrayRef<LLVM::GEPArg>{1});
152  return rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep);
153 }
154 
156  Location loc, MemRefType memRefType, ValueRange dynamicSizes,
157  ConversionPatternRewriter &rewriter) const {
158  assert(count(memRefType.getShape(), ShapedType::kDynamic) ==
159  static_cast<ssize_t>(dynamicSizes.size()) &&
160  "dynamicSizes size doesn't match dynamic sizes count in memref shape");
161 
162  Type indexType = getIndexType();
163  Value numElements = memRefType.getRank() == 0
164  ? createIndexAttrConstant(rewriter, loc, indexType, 1)
165  : nullptr;
166  unsigned dynamicIndex = 0;
167 
168  // Compute the total number of memref elements.
169  for (int64_t staticSize : memRefType.getShape()) {
170  if (numElements) {
171  Value size =
172  staticSize == ShapedType::kDynamic
173  ? dynamicSizes[dynamicIndex++]
174  : createIndexAttrConstant(rewriter, loc, indexType, staticSize);
175  numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
176  } else {
177  numElements =
178  staticSize == ShapedType::kDynamic
179  ? dynamicSizes[dynamicIndex++]
180  : createIndexAttrConstant(rewriter, loc, indexType, staticSize);
181  }
182  }
183  return numElements;
184 }
185 
186 /// Creates and populates the memref descriptor struct given all its fields.
188  Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr,
189  ArrayRef<Value> sizes, ArrayRef<Value> strides,
190  ConversionPatternRewriter &rewriter) const {
191  auto structType = typeConverter->convertType(memRefType);
192  auto memRefDescriptor = MemRefDescriptor::poison(rewriter, loc, structType);
193 
194  // Field 1: Allocated pointer, used for malloc/free.
195  memRefDescriptor.setAllocatedPtr(rewriter, loc, allocatedPtr);
196 
197  // Field 2: Actual aligned pointer to payload.
198  memRefDescriptor.setAlignedPtr(rewriter, loc, alignedPtr);
199 
200  // Field 3: Offset in aligned pointer.
201  Type indexType = getIndexType();
202  memRefDescriptor.setOffset(
203  rewriter, loc, createIndexAttrConstant(rewriter, loc, indexType, 0));
204 
205  // Fields 4: Sizes.
206  for (const auto &en : llvm::enumerate(sizes))
207  memRefDescriptor.setSize(rewriter, loc, en.index(), en.value());
208 
209  // Field 5: Strides.
210  for (const auto &en : llvm::enumerate(strides))
211  memRefDescriptor.setStride(rewriter, loc, en.index(), en.value());
212 
213  return memRefDescriptor;
214 }
215 
217  OpBuilder &builder, Location loc, TypeRange origTypes,
218  SmallVectorImpl<Value> &operands, bool toDynamic) const {
219  assert(origTypes.size() == operands.size() &&
220  "expected as may original types as operands");
221 
222  // Find operands of unranked memref type and store them.
224  SmallVector<unsigned> unrankedAddressSpaces;
225  for (unsigned i = 0, e = operands.size(); i < e; ++i) {
226  if (auto memRefType = dyn_cast<UnrankedMemRefType>(origTypes[i])) {
227  unrankedMemrefs.emplace_back(operands[i]);
228  FailureOr<unsigned> addressSpace =
230  if (failed(addressSpace))
231  return failure();
232  unrankedAddressSpaces.emplace_back(*addressSpace);
233  }
234  }
235 
236  if (unrankedMemrefs.empty())
237  return success();
238 
239  // Compute allocation sizes.
240  SmallVector<Value> sizes;
242  unrankedMemrefs, unrankedAddressSpaces,
243  sizes);
244 
245  // Get frequently used types.
246  Type indexType = getTypeConverter()->getIndexType();
247 
248  // Find the malloc and free, or declare them if necessary.
249  auto module = builder.getInsertionPoint()->getParentOfType<ModuleOp>();
250  FailureOr<LLVM::LLVMFuncOp> freeFunc, mallocFunc;
251  if (toDynamic) {
252  mallocFunc = LLVM::lookupOrCreateMallocFn(builder, module, indexType);
253  if (failed(mallocFunc))
254  return failure();
255  }
256  if (!toDynamic) {
257  freeFunc = LLVM::lookupOrCreateFreeFn(builder, module);
258  if (failed(freeFunc))
259  return failure();
260  }
261 
262  unsigned unrankedMemrefPos = 0;
263  for (unsigned i = 0, e = operands.size(); i < e; ++i) {
264  Type type = origTypes[i];
265  if (!isa<UnrankedMemRefType>(type))
266  continue;
267  Value allocationSize = sizes[unrankedMemrefPos++];
268  UnrankedMemRefDescriptor desc(operands[i]);
269 
270  // Allocate memory, copy, and free the source if necessary.
271  Value memory =
272  toDynamic
273  ? builder
274  .create<LLVM::CallOp>(loc, mallocFunc.value(), allocationSize)
275  .getResult()
276  : builder.create<LLVM::AllocaOp>(loc, getVoidPtrType(),
278  allocationSize,
279  /*alignment=*/0);
280  Value source = desc.memRefDescPtr(builder, loc);
281  builder.create<LLVM::MemcpyOp>(loc, memory, source, allocationSize, false);
282  if (!toDynamic)
283  builder.create<LLVM::CallOp>(loc, freeFunc.value(), source);
284 
285  // Create a new descriptor. The same descriptor can be returned multiple
286  // times, attempting to modify its pointer can lead to memory leaks
287  // (allocated twice and overwritten) or double frees (the caller does not
288  // know if the descriptor points to the same memory).
289  Type descriptorType = getTypeConverter()->convertType(type);
290  if (!descriptorType)
291  return failure();
292  auto updatedDesc =
293  UnrankedMemRefDescriptor::poison(builder, loc, descriptorType);
294  Value rank = desc.rank(builder, loc);
295  updatedDesc.setRank(builder, loc, rank);
296  updatedDesc.setMemRefDescPtr(builder, loc, memory);
297 
298  operands[i] = updatedDesc;
299  }
300 
301  return success();
302 }
303 
304 //===----------------------------------------------------------------------===//
305 // Detail methods
306 //===----------------------------------------------------------------------===//
307 
309  IntegerOverflowFlags overflowFlags) {
310  if (auto iface = dyn_cast<IntegerOverflowFlagsInterface>(op))
311  iface.setOverflowFlags(overflowFlags);
312 }
313 
314 /// Replaces the given operation "op" with a new operation of type "targetOp"
315 /// and given operands.
317  Operation *op, StringRef targetOp, ValueRange operands,
318  ArrayRef<NamedAttribute> targetAttrs,
319  const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
320  IntegerOverflowFlags overflowFlags) {
321  unsigned numResults = op->getNumResults();
322 
323  SmallVector<Type> resultTypes;
324  if (numResults != 0) {
325  resultTypes.push_back(
326  typeConverter.packOperationResults(op->getResultTypes()));
327  if (!resultTypes.back())
328  return failure();
329  }
330 
331  // Create the operation through state since we don't know its C++ type.
332  Operation *newOp =
333  rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
334  resultTypes, targetAttrs);
335 
336  setNativeProperties(newOp, overflowFlags);
337 
338  // If the operation produced 0 or 1 result, return them immediately.
339  if (numResults == 0)
340  return rewriter.eraseOp(op), success();
341  if (numResults == 1)
342  return rewriter.replaceOp(op, newOp->getResult(0)), success();
343 
344  // Otherwise, it had been converted to an operation producing a structure.
345  // Extract individual results from the structure and return them as list.
346  SmallVector<Value, 4> results;
347  results.reserve(numResults);
348  for (unsigned i = 0; i < numResults; ++i) {
349  results.push_back(rewriter.create<LLVM::ExtractValueOp>(
350  op->getLoc(), newOp->getResult(0), i));
351  }
352  rewriter.replaceOp(op, results);
353  return success();
354 }
355 
357  Operation *op, StringRef intrinsic, ValueRange operands,
358  const LLVMTypeConverter &typeConverter, RewriterBase &rewriter) {
359  auto loc = op->getLoc();
360 
361  if (!llvm::all_of(operands, [](Value value) {
362  return LLVM::isCompatibleType(value.getType());
363  }))
364  return failure();
365 
366  unsigned numResults = op->getNumResults();
367  Type resType;
368  if (numResults != 0)
369  resType = typeConverter.packOperationResults(op->getResultTypes());
370 
371  auto callIntrOp = rewriter.create<LLVM::CallIntrinsicOp>(
372  loc, resType, rewriter.getStringAttr(intrinsic), operands);
373  // Propagate attributes.
374  callIntrOp->setAttrs(op->getAttrDictionary());
375 
376  if (numResults <= 1) {
377  // Directly replace the original op.
378  rewriter.replaceOp(op, callIntrOp);
379  return success();
380  }
381 
382  // Extract individual results from packed structure and use them as
383  // replacements.
384  SmallVector<Value, 4> results;
385  results.reserve(numResults);
386  Value intrRes = callIntrOp.getResults();
387  for (unsigned i = 0; i < numResults; ++i)
388  results.push_back(rewriter.create<LLVM::ExtractValueOp>(loc, intrRes, i));
389  rewriter.replaceOp(op, results);
390 
391  return success();
392 }
393 
394 static unsigned getBitWidth(Type type) {
395  if (type.isIntOrFloat())
396  return type.getIntOrFloatBitWidth();
397 
398  auto vec = cast<VectorType>(type);
399  assert(!vec.isScalable() && "scalable vectors are not supported");
400  return vec.getNumElements() * getBitWidth(vec.getElementType());
401 }
402 
404  int32_t value) {
405  Type i32 = builder.getI32Type();
406  return builder.create<LLVM::ConstantOp>(loc, i32, value);
407 }
408 
410  Value src, Type dstType) {
411  Type srcType = src.getType();
412  if (srcType == dstType)
413  return {src};
414 
415  unsigned srcBitWidth = getBitWidth(srcType);
416  unsigned dstBitWidth = getBitWidth(dstType);
417  if (srcBitWidth == dstBitWidth) {
418  Value cast = builder.create<LLVM::BitcastOp>(loc, dstType, src);
419  return {cast};
420  }
421 
422  if (dstBitWidth > srcBitWidth) {
423  auto smallerInt = builder.getIntegerType(srcBitWidth);
424  if (srcType != smallerInt)
425  src = builder.create<LLVM::BitcastOp>(loc, smallerInt, src);
426 
427  auto largerInt = builder.getIntegerType(dstBitWidth);
428  Value res = builder.create<LLVM::ZExtOp>(loc, largerInt, src);
429  return {res};
430  }
431  assert(srcBitWidth % dstBitWidth == 0 &&
432  "src bit width must be a multiple of dst bit width");
433  int64_t numElements = srcBitWidth / dstBitWidth;
434  auto vecType = VectorType::get(numElements, dstType);
435 
436  src = builder.create<LLVM::BitcastOp>(loc, vecType, src);
437 
438  SmallVector<Value> res;
439  for (auto i : llvm::seq(numElements)) {
440  Value idx = createI32Constant(builder, loc, i);
441  Value elem = builder.create<LLVM::ExtractElementOp>(loc, src, idx);
442  res.emplace_back(elem);
443  }
444 
445  return res;
446 }
447 
449  Type dstType) {
450  assert(!src.empty() && "src range must not be empty");
451  if (src.size() == 1) {
452  Value res = src.front();
453  if (res.getType() == dstType)
454  return res;
455 
456  unsigned srcBitWidth = getBitWidth(res.getType());
457  unsigned dstBitWidth = getBitWidth(dstType);
458  if (dstBitWidth < srcBitWidth) {
459  auto largerInt = builder.getIntegerType(srcBitWidth);
460  if (res.getType() != largerInt)
461  res = builder.create<LLVM::BitcastOp>(loc, largerInt, res);
462 
463  auto smallerInt = builder.getIntegerType(dstBitWidth);
464  res = builder.create<LLVM::TruncOp>(loc, smallerInt, res);
465  }
466 
467  if (res.getType() != dstType)
468  res = builder.create<LLVM::BitcastOp>(loc, dstType, res);
469 
470  return res;
471  }
472 
473  int64_t numElements = src.size();
474  auto srcType = VectorType::get(numElements, src.front().getType());
475  Value res = builder.create<LLVM::PoisonOp>(loc, srcType);
476  for (auto &&[i, elem] : llvm::enumerate(src)) {
477  Value idx = createI32Constant(builder, loc, i);
478  res = builder.create<LLVM::InsertElementOp>(loc, srcType, res, elem, idx);
479  }
480 
481  if (res.getType() != dstType)
482  res = builder.create<LLVM::BitcastOp>(loc, dstType, res);
483 
484  return res;
485 }
486 
488  const LLVMTypeConverter &converter,
489  MemRefType type, Value memRefDesc,
490  ValueRange indices,
491  LLVM::GEPNoWrapFlags noWrapFlags) {
492  auto [strides, offset] = type.getStridesAndOffset();
493 
494  MemRefDescriptor memRefDescriptor(memRefDesc);
495  // Use a canonical representation of the start address so that later
496  // optimizations have a longer sequence of instructions to CSE.
497  // If we don't do that we would sprinkle the memref.offset in various
498  // position of the different address computations.
499  Value base = memRefDescriptor.bufferPtr(builder, loc, converter, type);
500 
501  LLVM::IntegerOverflowFlags intOverflowFlags =
502  LLVM::IntegerOverflowFlags::none;
503  if (LLVM::bitEnumContainsAny(noWrapFlags, LLVM::GEPNoWrapFlags::nusw)) {
504  intOverflowFlags = intOverflowFlags | LLVM::IntegerOverflowFlags::nsw;
505  }
506  if (LLVM::bitEnumContainsAny(noWrapFlags, LLVM::GEPNoWrapFlags::nuw)) {
507  intOverflowFlags = intOverflowFlags | LLVM::IntegerOverflowFlags::nuw;
508  }
509 
510  Type indexType = converter.getIndexType();
511  Value index;
512  for (int i = 0, e = indices.size(); i < e; ++i) {
513  Value increment = indices[i];
514  if (strides[i] != 1) { // Skip if stride is 1.
515  Value stride =
516  ShapedType::isDynamic(strides[i])
517  ? memRefDescriptor.stride(builder, loc, i)
518  : builder.create<LLVM::ConstantOp>(
519  loc, indexType, builder.getIndexAttr(strides[i]));
520  increment =
521  builder.create<LLVM::MulOp>(loc, increment, stride, intOverflowFlags);
522  }
523  index = index ? builder.create<LLVM::AddOp>(loc, index, increment,
524  intOverflowFlags)
525  : increment;
526  }
527 
528  Type elementPtrType = memRefDescriptor.getElementPtrType();
529  return index ? builder.create<LLVM::GEPOp>(
530  loc, elementPtrType,
531  converter.convertType(type.getElementType()), base, index,
532  noWrapFlags)
533  : base;
534 }
static Value createI32Constant(OpBuilder &builder, Location loc, int32_t value)
Definition: Pattern.cpp:403
static unsigned getBitWidth(Type type)
Definition: Pattern.cpp:394
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:106
IntegerType getI32Type()
Definition: Builders.cpp:65
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:69
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:260
MLIRContext * getContext() const
Definition: Builders.h:55
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
Base class for the conversion patterns.
const TypeConverter * typeConverter
An optional type converter for use by this pattern.
const TypeConverter * getTypeConverter() const
Return the type converter held by this pattern, or nullptr if the pattern does not require type conve...
Type getVoidType() const
Gets the MLIR type wrapping the LLVM void type.
Definition: Pattern.cpp:45
MemRefDescriptor createMemRefDescriptor(Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr, ArrayRef< Value > sizes, ArrayRef< Value > strides, ConversionPatternRewriter &rewriter) const
Creates and populates a canonical memref descriptor struct.
Definition: Pattern.cpp:187
ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context, const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition: Pattern.cpp:22
Value getStridedElementPtr(ConversionPatternRewriter &rewriter, Location loc, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none) const
Convenience wrapper for the corresponding helper utility.
Definition: Pattern.cpp:61
void getMemRefDescriptorSizes(Location loc, MemRefType memRefType, ValueRange dynamicSizes, ConversionPatternRewriter &rewriter, SmallVectorImpl< Value > &sizes, SmallVectorImpl< Value > &strides, Value &size, bool sizeInBytes=true) const
Computes sizes, strides and buffer size of memRefType with identity layout.
Definition: Pattern.cpp:85
Type getIndexType() const
Gets the MLIR type wrapping the LLVM integer type whose bit width is defined by the used type convert...
Definition: Pattern.cpp:36
const LLVMTypeConverter * getTypeConverter() const
Definition: Pattern.cpp:27
Value getNumElements(Location loc, MemRefType memRefType, ValueRange dynamicSizes, ConversionPatternRewriter &rewriter) const
Computes total number of elements for the given MemRef and dynamicSizes.
Definition: Pattern.cpp:155
LLVM::LLVMDialect & getDialect() const
Returns the LLVM dialect.
Definition: Pattern.cpp:32
Value getSizeInBytes(Location loc, Type type, ConversionPatternRewriter &rewriter) const
Computes the size of type in bytes.
Definition: Pattern.cpp:140
Type getIntPtrType(unsigned addressSpace=0) const
Gets the MLIR type wrapping the LLVM integer type whose bit width corresponds to that of a LLVM point...
Definition: Pattern.cpp:40
LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc, TypeRange origTypes, SmallVectorImpl< Value > &operands, bool toDynamic) const
Copies the memory descriptor for any operands that were unranked descriptors originally to heap-alloc...
Definition: Pattern.cpp:216
Type getElementPtrType(MemRefType type) const
Returns the type of a pointer to an element of the memref.
Definition: Pattern.cpp:78
static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value)
Create a constant Op producing a value of resultType from an index-typed integer attribute.
Definition: Pattern.cpp:53
bool isConvertibleAndHasIdentityMaps(MemRefType type) const
Returns if the given memref type is convertible to LLVM and has an identity layout map.
Definition: Pattern.cpp:71
Type getVoidPtrType() const
Get the MLIR type wrapping the LLVM i8* type.
Definition: Pattern.cpp:49
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
Type packOperationResults(TypeRange types) const
Convert a non-empty list of types of values produced by an operation into an LLVM-compatible type.
LLVM::LLVMDialect * getDialect() const
Returns the LLVM dialect.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type or failure if...
Type getIndexType() const
Gets the LLVM representation of the index type.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Definition: MemRefBuilder.h:33
Value bufferPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type)
Builds IR for getting the start address of the buffer represented by this memref: memref....
LLVM::LLVMPointerType getElementPtrType()
Returns the (LLVM) pointer type this descriptor contains.
Value stride(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th stride from the descriptor.
static MemRefDescriptor poison(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating a poison value of the descriptor type.
This class helps build Operations.
Definition: Builders.h:205
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:443
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:455
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
DictionaryAttr getAttrDictionary()
Return all of the attributes on this operation as a DictionaryAttr.
Definition: Operation.cpp:296
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
Definition: Operation.cpp:305
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
result_type_range getResultTypes()
Definition: Operation.h:428
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
MLIRContext * getContext() const
Return the MLIRContext used to create this pattern.
Definition: PatternMatch.h:134
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:358
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:116
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
Value memRefDescPtr(OpBuilder &builder, Location loc) const
Builds IR extracting ranked memref descriptor ptr.
static UnrankedMemRefDescriptor poison(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating a poison value of the descriptor type.
static void computeSizes(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, ArrayRef< UnrankedMemRefDescriptor > values, ArrayRef< unsigned > addressSpaces, SmallVectorImpl< Value > &sizes)
Builds IR computing the sizes in bytes (suitable for opaque allocation) and appends the corresponding...
Value rank(OpBuilder &builder, Location loc) const
Builds IR extracting the rank from the descriptor.
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, IntegerOverflowFlags overflowFlags=IntegerOverflowFlags::none)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
Definition: Pattern.cpp:316
void setNativeProperties(Operation *op, IntegerOverflowFlags overflowFlags)
Handle generically setting flags as native properties on LLVM operations.
Definition: Pattern.cpp:308
LogicalResult intrinsicRewrite(Operation *op, StringRef intrinsic, ValueRange operands, const LLVMTypeConverter &typeConverter, RewriterBase &rewriter)
Replaces the given operation "op" with a call to an LLVM intrinsic with the specified name "intrinsic...
Definition: Pattern.cpp:356
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDes...
Definition: Pattern.cpp:487
Value composeValue(OpBuilder &builder, Location loc, ValueRange src, Type dstType)
Composes a set of src values into a single value of type dstType through series of bitcasts and vecto...
Definition: Pattern.cpp:448
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateFreeFn(OpBuilder &b, Operation *moduleOp)
SmallVector< Value > decomposeValue(OpBuilder &builder, Location loc, Value src, Type dstType)
Decomposes a src value into a set of values of type dstType through series of bitcasts and vector ops...
Definition: Pattern.cpp:409
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:796
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateMallocFn(OpBuilder &b, Operation *moduleOp, Type indexType)
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...