MLIR  22.0.0git
Pattern.cpp
Go to the documentation of this file.
1 //===- Pattern.cpp - Conversion pattern to the LLVM dialect ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
15 
16 using namespace mlir;
17 
18 //===----------------------------------------------------------------------===//
19 // ConvertToLLVMPattern
20 //===----------------------------------------------------------------------===//
21 
23  StringRef rootOpName, MLIRContext *context,
24  const LLVMTypeConverter &typeConverter, PatternBenefit benefit)
25  : ConversionPattern(typeConverter, rootOpName, benefit, context) {}
26 
28  return static_cast<const LLVMTypeConverter *>(
30 }
31 
32 LLVM::LLVMDialect &ConvertToLLVMPattern::getDialect() const {
33  return *getTypeConverter()->getDialect();
34 }
35 
37  return getTypeConverter()->getIndexType();
38 }
39 
40 Type ConvertToLLVMPattern::getIntPtrType(unsigned addressSpace) const {
42  getTypeConverter()->getPointerBitwidth(addressSpace));
43 }
44 
47 }
48 
49 Type ConvertToLLVMPattern::getPtrType(unsigned addressSpace) const {
51  addressSpace);
52 }
53 
55 
57  Location loc,
58  Type resultType,
59  int64_t value) {
60  return LLVM::ConstantOp::create(builder, loc, resultType,
61  builder.getIndexAttr(value));
62 }
63 
65  ConversionPatternRewriter &rewriter, Location loc, MemRefType type,
66  Value memRefDesc, ValueRange indices,
67  LLVM::GEPNoWrapFlags noWrapFlags) const {
68  return LLVM::getStridedElementPtr(rewriter, loc, *getTypeConverter(), type,
69  memRefDesc, indices, noWrapFlags);
70 }
71 
72 // Check if the MemRefType `type` is supported by the lowering. We currently
73 // only support memrefs with identity maps.
75  MemRefType type) const {
76  if (!type.getLayout().isIdentity())
77  return false;
78  return static_cast<bool>(typeConverter->convertType(type));
79 }
80 
82  auto addressSpace = getTypeConverter()->getMemRefAddressSpace(type);
83  if (failed(addressSpace))
84  return {};
85  return LLVM::LLVMPointerType::get(type.getContext(), *addressSpace);
86 }
87 
89  Location loc, MemRefType memRefType, ValueRange dynamicSizes,
91  SmallVectorImpl<Value> &strides, Value &size, bool sizeInBytes) const {
92  assert(isConvertibleAndHasIdentityMaps(memRefType) &&
93  "layout maps must have been normalized away");
94  assert(count(memRefType.getShape(), ShapedType::kDynamic) ==
95  static_cast<ssize_t>(dynamicSizes.size()) &&
96  "dynamicSizes size doesn't match dynamic sizes count in memref shape");
97 
98  sizes.reserve(memRefType.getRank());
99  unsigned dynamicIndex = 0;
100  Type indexType = getIndexType();
101  for (int64_t size : memRefType.getShape()) {
102  sizes.push_back(
103  size == ShapedType::kDynamic
104  ? dynamicSizes[dynamicIndex++]
105  : createIndexAttrConstant(rewriter, loc, indexType, size));
106  }
107 
108  // Strides: iterate sizes in reverse order and multiply.
109  int64_t stride = 1;
110  Value runningStride = createIndexAttrConstant(rewriter, loc, indexType, 1);
111  strides.resize(memRefType.getRank());
112  for (auto i = memRefType.getRank(); i-- > 0;) {
113  strides[i] = runningStride;
114 
115  int64_t staticSize = memRefType.getShape()[i];
116  bool useSizeAsStride = stride == 1;
117  if (staticSize == ShapedType::kDynamic)
118  stride = ShapedType::kDynamic;
119  if (stride != ShapedType::kDynamic)
120  stride *= staticSize;
121 
122  if (useSizeAsStride)
123  runningStride = sizes[i];
124  else if (stride == ShapedType::kDynamic)
125  runningStride =
126  LLVM::MulOp::create(rewriter, loc, runningStride, sizes[i]);
127  else
128  runningStride = createIndexAttrConstant(rewriter, loc, indexType, stride);
129  }
130  if (sizeInBytes) {
131  // Buffer size in bytes.
132  Type elementType = typeConverter->convertType(memRefType.getElementType());
133  auto elementPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
134  Value nullPtr = LLVM::ZeroOp::create(rewriter, loc, elementPtrType);
135  Value gepPtr = LLVM::GEPOp::create(rewriter, loc, elementPtrType,
136  elementType, nullPtr, runningStride);
137  size = LLVM::PtrToIntOp::create(rewriter, loc, getIndexType(), gepPtr);
138  } else {
139  size = runningStride;
140  }
141 }
142 
144  Location loc, Type type, ConversionPatternRewriter &rewriter) const {
145  // Compute the size of an individual element. This emits the MLIR equivalent
146  // of the following sizeof(...) implementation in LLVM IR:
147  // %0 = getelementptr %elementType* null, %indexType 1
148  // %1 = ptrtoint %elementType* %0 to %indexType
149  // which is a common pattern of getting the size of a type in bytes.
150  Type llvmType = typeConverter->convertType(type);
151  auto convertedPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
152  auto nullPtr = LLVM::ZeroOp::create(rewriter, loc, convertedPtrType);
153  auto gep = LLVM::GEPOp::create(rewriter, loc, convertedPtrType, llvmType,
154  nullPtr, ArrayRef<LLVM::GEPArg>{1});
155  return LLVM::PtrToIntOp::create(rewriter, loc, getIndexType(), gep);
156 }
157 
159  Location loc, MemRefType memRefType, ValueRange dynamicSizes,
160  ConversionPatternRewriter &rewriter) const {
161  assert(count(memRefType.getShape(), ShapedType::kDynamic) ==
162  static_cast<ssize_t>(dynamicSizes.size()) &&
163  "dynamicSizes size doesn't match dynamic sizes count in memref shape");
164 
165  Type indexType = getIndexType();
166  Value numElements = memRefType.getRank() == 0
167  ? createIndexAttrConstant(rewriter, loc, indexType, 1)
168  : nullptr;
169  unsigned dynamicIndex = 0;
170 
171  // Compute the total number of memref elements.
172  for (int64_t staticSize : memRefType.getShape()) {
173  if (numElements) {
174  Value size =
175  staticSize == ShapedType::kDynamic
176  ? dynamicSizes[dynamicIndex++]
177  : createIndexAttrConstant(rewriter, loc, indexType, staticSize);
178  numElements = LLVM::MulOp::create(rewriter, loc, numElements, size);
179  } else {
180  numElements =
181  staticSize == ShapedType::kDynamic
182  ? dynamicSizes[dynamicIndex++]
183  : createIndexAttrConstant(rewriter, loc, indexType, staticSize);
184  }
185  }
186  return numElements;
187 }
188 
189 /// Creates and populates the memref descriptor struct given all its fields.
191  Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr,
192  ArrayRef<Value> sizes, ArrayRef<Value> strides,
193  ConversionPatternRewriter &rewriter) const {
194  auto structType = typeConverter->convertType(memRefType);
195  auto memRefDescriptor = MemRefDescriptor::poison(rewriter, loc, structType);
196 
197  // Field 1: Allocated pointer, used for malloc/free.
198  memRefDescriptor.setAllocatedPtr(rewriter, loc, allocatedPtr);
199 
200  // Field 2: Actual aligned pointer to payload.
201  memRefDescriptor.setAlignedPtr(rewriter, loc, alignedPtr);
202 
203  // Field 3: Offset in aligned pointer.
204  Type indexType = getIndexType();
205  memRefDescriptor.setOffset(
206  rewriter, loc, createIndexAttrConstant(rewriter, loc, indexType, 0));
207 
208  // Fields 4: Sizes.
209  for (const auto &en : llvm::enumerate(sizes))
210  memRefDescriptor.setSize(rewriter, loc, en.index(), en.value());
211 
212  // Field 5: Strides.
213  for (const auto &en : llvm::enumerate(strides))
214  memRefDescriptor.setStride(rewriter, loc, en.index(), en.value());
215 
216  return memRefDescriptor;
217 }
218 
220  OpBuilder &builder, Location loc, UnrankedMemRefType memRefType,
221  Value operand, bool toDynamic) const {
222  // Convert memory space.
223  FailureOr<unsigned> addressSpace =
225  if (failed(addressSpace))
226  return {};
227 
228  // Get frequently used types.
229  Type indexType = getTypeConverter()->getIndexType();
230 
231  // Find the malloc and free, or declare them if necessary.
232  auto module = builder.getInsertionPoint()->getParentOfType<ModuleOp>();
233  FailureOr<LLVM::LLVMFuncOp> freeFunc, mallocFunc;
234  if (toDynamic) {
235  mallocFunc = LLVM::lookupOrCreateMallocFn(builder, module, indexType);
236  if (failed(mallocFunc))
237  return {};
238  }
239  if (!toDynamic) {
240  freeFunc = LLVM::lookupOrCreateFreeFn(builder, module);
241  if (failed(freeFunc))
242  return {};
243  }
244 
245  UnrankedMemRefDescriptor desc(operand);
247  builder, loc, *getTypeConverter(), desc, *addressSpace);
248 
249  // Allocate memory, copy, and free the source if necessary.
250  Value memory = toDynamic
251  ? LLVM::CallOp::create(builder, loc, mallocFunc.value(),
252  allocationSize)
253  .getResult()
254  : LLVM::AllocaOp::create(builder, loc, getPtrType(),
256  allocationSize,
257  /*alignment=*/0);
258  Value source = desc.memRefDescPtr(builder, loc);
259  LLVM::MemcpyOp::create(builder, loc, memory, source, allocationSize, false);
260  if (!toDynamic)
261  LLVM::CallOp::create(builder, loc, freeFunc.value(), source);
262 
263  // Create a new descriptor. The same descriptor can be returned multiple
264  // times, attempting to modify its pointer can lead to memory leaks
265  // (allocated twice and overwritten) or double frees (the caller does not
266  // know if the descriptor points to the same memory).
267  Type descriptorType = getTypeConverter()->convertType(memRefType);
268  if (!descriptorType)
269  return {};
270  auto updatedDesc =
271  UnrankedMemRefDescriptor::poison(builder, loc, descriptorType);
272  Value rank = desc.rank(builder, loc);
273  updatedDesc.setRank(builder, loc, rank);
274  updatedDesc.setMemRefDescPtr(builder, loc, memory);
275  return updatedDesc;
276 }
277 
279  OpBuilder &builder, Location loc, TypeRange origTypes,
280  SmallVectorImpl<Value> &operands, bool toDynamic) const {
281  assert(origTypes.size() == operands.size() &&
282  "expected as may original types as operands");
283  for (unsigned i = 0, e = operands.size(); i < e; ++i) {
284  if (auto memRefType = dyn_cast<UnrankedMemRefType>(origTypes[i])) {
285  Value updatedDesc = copyUnrankedDescriptor(builder, loc, memRefType,
286  operands[i], toDynamic);
287  if (!updatedDesc)
288  return failure();
289  operands[i] = updatedDesc;
290  }
291  }
292  return success();
293 }
294 
295 //===----------------------------------------------------------------------===//
296 // Detail methods
297 //===----------------------------------------------------------------------===//
298 
300  IntegerOverflowFlags overflowFlags) {
301  if (auto iface = dyn_cast<IntegerOverflowFlagsInterface>(op))
302  iface.setOverflowFlags(overflowFlags);
303 }
304 
305 /// Replaces the given operation "op" with a new operation of type "targetOp"
306 /// and given operands.
308  Operation *op, StringRef targetOp, ValueRange operands,
309  ArrayRef<NamedAttribute> targetAttrs,
310  const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
311  IntegerOverflowFlags overflowFlags) {
312  unsigned numResults = op->getNumResults();
313 
314  SmallVector<Type> resultTypes;
315  if (numResults != 0) {
316  resultTypes.push_back(
317  typeConverter.packOperationResults(op->getResultTypes()));
318  if (!resultTypes.back())
319  return failure();
320  }
321 
322  // Create the operation through state since we don't know its C++ type.
323  Operation *newOp =
324  rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
325  resultTypes, targetAttrs);
326 
327  setNativeProperties(newOp, overflowFlags);
328 
329  // If the operation produced 0 or 1 result, return them immediately.
330  if (numResults == 0)
331  return rewriter.eraseOp(op), success();
332  if (numResults == 1)
333  return rewriter.replaceOp(op, newOp->getResult(0)), success();
334 
335  // Otherwise, it had been converted to an operation producing a structure.
336  // Extract individual results from the structure and return them as list.
337  SmallVector<Value, 4> results;
338  results.reserve(numResults);
339  for (unsigned i = 0; i < numResults; ++i) {
340  results.push_back(LLVM::ExtractValueOp::create(rewriter, op->getLoc(),
341  newOp->getResult(0), i));
342  }
343  rewriter.replaceOp(op, results);
344  return success();
345 }
346 
348  Operation *op, StringRef intrinsic, ValueRange operands,
349  const LLVMTypeConverter &typeConverter, RewriterBase &rewriter) {
350  auto loc = op->getLoc();
351 
352  if (!llvm::all_of(operands, [](Value value) {
353  return LLVM::isCompatibleType(value.getType());
354  }))
355  return failure();
356 
357  unsigned numResults = op->getNumResults();
358  Type resType;
359  if (numResults != 0)
360  resType = typeConverter.packOperationResults(op->getResultTypes());
361 
362  auto callIntrOp = LLVM::CallIntrinsicOp::create(
363  rewriter, loc, resType, rewriter.getStringAttr(intrinsic), operands);
364  // Propagate attributes.
365  callIntrOp->setAttrs(op->getAttrDictionary());
366 
367  if (numResults <= 1) {
368  // Directly replace the original op.
369  rewriter.replaceOp(op, callIntrOp);
370  return success();
371  }
372 
373  // Extract individual results from packed structure and use them as
374  // replacements.
375  SmallVector<Value, 4> results;
376  results.reserve(numResults);
377  Value intrRes = callIntrOp.getResults();
378  for (unsigned i = 0; i < numResults; ++i)
379  results.push_back(LLVM::ExtractValueOp::create(rewriter, loc, intrRes, i));
380  rewriter.replaceOp(op, results);
381 
382  return success();
383 }
384 
385 static unsigned getBitWidth(Type type) {
386  if (type.isIntOrFloat())
387  return type.getIntOrFloatBitWidth();
388 
389  auto vec = cast<VectorType>(type);
390  assert(!vec.isScalable() && "scalable vectors are not supported");
391  return vec.getNumElements() * getBitWidth(vec.getElementType());
392 }
393 
395  int32_t value) {
396  Type i32 = builder.getI32Type();
397  return LLVM::ConstantOp::create(builder, loc, i32, value);
398 }
399 
401  Value src, Type dstType) {
402  Type srcType = src.getType();
403  if (srcType == dstType)
404  return {src};
405 
406  unsigned srcBitWidth = getBitWidth(srcType);
407  unsigned dstBitWidth = getBitWidth(dstType);
408  if (srcBitWidth == dstBitWidth) {
409  Value cast = LLVM::BitcastOp::create(builder, loc, dstType, src);
410  return {cast};
411  }
412 
413  if (dstBitWidth > srcBitWidth) {
414  auto smallerInt = builder.getIntegerType(srcBitWidth);
415  if (srcType != smallerInt)
416  src = LLVM::BitcastOp::create(builder, loc, smallerInt, src);
417 
418  auto largerInt = builder.getIntegerType(dstBitWidth);
419  Value res = LLVM::ZExtOp::create(builder, loc, largerInt, src);
420  return {res};
421  }
422  assert(srcBitWidth % dstBitWidth == 0 &&
423  "src bit width must be a multiple of dst bit width");
424  int64_t numElements = srcBitWidth / dstBitWidth;
425  auto vecType = VectorType::get(numElements, dstType);
426 
427  src = LLVM::BitcastOp::create(builder, loc, vecType, src);
428 
429  SmallVector<Value> res;
430  for (auto i : llvm::seq(numElements)) {
431  Value idx = createI32Constant(builder, loc, i);
432  Value elem = LLVM::ExtractElementOp::create(builder, loc, src, idx);
433  res.emplace_back(elem);
434  }
435 
436  return res;
437 }
438 
440  Type dstType) {
441  assert(!src.empty() && "src range must not be empty");
442  if (src.size() == 1) {
443  Value res = src.front();
444  if (res.getType() == dstType)
445  return res;
446 
447  unsigned srcBitWidth = getBitWidth(res.getType());
448  unsigned dstBitWidth = getBitWidth(dstType);
449  if (dstBitWidth < srcBitWidth) {
450  auto largerInt = builder.getIntegerType(srcBitWidth);
451  if (res.getType() != largerInt)
452  res = LLVM::BitcastOp::create(builder, loc, largerInt, res);
453 
454  auto smallerInt = builder.getIntegerType(dstBitWidth);
455  res = LLVM::TruncOp::create(builder, loc, smallerInt, res);
456  }
457 
458  if (res.getType() != dstType)
459  res = LLVM::BitcastOp::create(builder, loc, dstType, res);
460 
461  return res;
462  }
463 
464  int64_t numElements = src.size();
465  auto srcType = VectorType::get(numElements, src.front().getType());
466  Value res = LLVM::PoisonOp::create(builder, loc, srcType);
467  for (auto &&[i, elem] : llvm::enumerate(src)) {
468  Value idx = createI32Constant(builder, loc, i);
469  res = LLVM::InsertElementOp::create(builder, loc, srcType, res, elem, idx);
470  }
471 
472  if (res.getType() != dstType)
473  res = LLVM::BitcastOp::create(builder, loc, dstType, res);
474 
475  return res;
476 }
477 
479  const LLVMTypeConverter &converter,
480  MemRefType type, Value memRefDesc,
481  ValueRange indices,
482  LLVM::GEPNoWrapFlags noWrapFlags) {
483  auto [strides, offset] = type.getStridesAndOffset();
484 
485  MemRefDescriptor memRefDescriptor(memRefDesc);
486  // Use a canonical representation of the start address so that later
487  // optimizations have a longer sequence of instructions to CSE.
488  // If we don't do that we would sprinkle the memref.offset in various
489  // position of the different address computations.
490  Value base = memRefDescriptor.bufferPtr(builder, loc, converter, type);
491 
492  LLVM::IntegerOverflowFlags intOverflowFlags =
493  LLVM::IntegerOverflowFlags::none;
494  if (LLVM::bitEnumContainsAny(noWrapFlags, LLVM::GEPNoWrapFlags::nusw)) {
495  intOverflowFlags = intOverflowFlags | LLVM::IntegerOverflowFlags::nsw;
496  }
497  if (LLVM::bitEnumContainsAny(noWrapFlags, LLVM::GEPNoWrapFlags::nuw)) {
498  intOverflowFlags = intOverflowFlags | LLVM::IntegerOverflowFlags::nuw;
499  }
500 
501  Type indexType = converter.getIndexType();
502  Value index;
503  for (int i = 0, e = indices.size(); i < e; ++i) {
504  Value increment = indices[i];
505  if (strides[i] != 1) { // Skip if stride is 1.
506  Value stride =
507  ShapedType::isDynamic(strides[i])
508  ? memRefDescriptor.stride(builder, loc, i)
509  : LLVM::ConstantOp::create(builder, loc, indexType,
510  builder.getIndexAttr(strides[i]));
511  increment = LLVM::MulOp::create(builder, loc, increment, stride,
512  intOverflowFlags);
513  }
514  index = index ? LLVM::AddOp::create(builder, loc, index, increment,
515  intOverflowFlags)
516  : increment;
517  }
518 
519  Type elementPtrType = memRefDescriptor.getElementPtrType();
520  return index
521  ? LLVM::GEPOp::create(builder, loc, elementPtrType,
522  converter.convertType(type.getElementType()),
523  base, index, noWrapFlags)
524  : base;
525 }
static Value createI32Constant(OpBuilder &builder, Location loc, int32_t value)
Definition: Pattern.cpp:394
static unsigned getBitWidth(Type type)
Definition: Pattern.cpp:385
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:107
IntegerType getI32Type()
Definition: Builders.cpp:62
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:66
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:261
MLIRContext * getContext() const
Definition: Builders.h:56
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
Base class for the conversion patterns.
const TypeConverter * typeConverter
An optional type converter for use by this pattern.
const TypeConverter * getTypeConverter() const
Return the type converter held by this pattern, or nullptr if the pattern does not require type conversion.
Type getVoidType() const
Gets the MLIR type wrapping the LLVM void type.
Definition: Pattern.cpp:45
MemRefDescriptor createMemRefDescriptor(Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr, ArrayRef< Value > sizes, ArrayRef< Value > strides, ConversionPatternRewriter &rewriter) const
Creates and populates a canonical memref descriptor struct.
Definition: Pattern.cpp:190
ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context, const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition: Pattern.cpp:22
Value getStridedElementPtr(ConversionPatternRewriter &rewriter, Location loc, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none) const
Convenience wrapper for the corresponding helper utility.
Definition: Pattern.cpp:64
void getMemRefDescriptorSizes(Location loc, MemRefType memRefType, ValueRange dynamicSizes, ConversionPatternRewriter &rewriter, SmallVectorImpl< Value > &sizes, SmallVectorImpl< Value > &strides, Value &size, bool sizeInBytes=true) const
Computes sizes, strides and buffer size of memRefType with identity layout.
Definition: Pattern.cpp:88
Type getPtrType(unsigned addressSpace=0) const
Get the MLIR type wrapping the LLVM ptr type.
Definition: Pattern.cpp:49
Type getIndexType() const
Gets the MLIR type wrapping the LLVM integer type whose bit width is defined by the used type converter.
Definition: Pattern.cpp:36
const LLVMTypeConverter * getTypeConverter() const
Definition: Pattern.cpp:27
Value getNumElements(Location loc, MemRefType memRefType, ValueRange dynamicSizes, ConversionPatternRewriter &rewriter) const
Computes total number of elements for the given MemRef and dynamicSizes.
Definition: Pattern.cpp:158
LLVM::LLVMDialect & getDialect() const
Returns the LLVM dialect.
Definition: Pattern.cpp:32
Value getSizeInBytes(Location loc, Type type, ConversionPatternRewriter &rewriter) const
Computes the size of type in bytes.
Definition: Pattern.cpp:143
Type getIntPtrType(unsigned addressSpace=0) const
Gets the MLIR type wrapping the LLVM integer type whose bit width corresponds to that of an LLVM pointer type in the given address space.
Definition: Pattern.cpp:40
Value copyUnrankedDescriptor(OpBuilder &builder, Location loc, UnrankedMemRefType memRefType, Value operand, bool toDynamic) const
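Copies the given unranked memory descriptor to heap-allocated memory (if toDynamic is true) or to stack-allocated memory (otherwise) and returns the new descriptor; when toDynamic is false, the source memory is also freed.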
Definition: Pattern.cpp:219
LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc, TypeRange origTypes, SmallVectorImpl< Value > &operands, bool toDynamic) const
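Copies the memory descriptor for any operands that were originally unranked descriptors to heap-allocated memory (if toDynamic is true) or to stack-allocated memory (otherwise), updating the operands in place.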
Definition: Pattern.cpp:278
Type getElementPtrType(MemRefType type) const
Returns the type of a pointer to an element of the memref.
Definition: Pattern.cpp:81
static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value)
Create a constant Op producing a value of resultType from an index-typed integer attribute.
Definition: Pattern.cpp:56
bool isConvertibleAndHasIdentityMaps(MemRefType type) const
Returns if the given memref type is convertible to LLVM and has an identity layout map.
Definition: Pattern.cpp:74
Type getVoidPtrType() const
Get the MLIR type wrapping the LLVM opaque pointer type (`ptr`, the counterpart of the former `i8*`) in address space 0.
Definition: Pattern.cpp:54
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
Type packOperationResults(TypeRange types) const
Convert a non-empty list of types of values produced by an operation into an LLVM-compatible type.
LLVM::LLVMDialect * getDialect() const
Returns the LLVM dialect.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type, or failure if that memory space cannot be converted.
Type getIndexType() const
Gets the LLVM representation of the index type.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descriptor.
Definition: MemRefBuilder.h:33
Value bufferPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type)
Builds IR for getting the start address of the buffer represented by this memref: memref....
LLVM::LLVMPointerType getElementPtrType()
Returns the (LLVM) pointer type this descriptor contains.
Value stride(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th stride from the descriptor.
static MemRefDescriptor poison(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating a poison value of the descriptor type.
This class helps build Operations.
Definition: Builders.h:207
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:445
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:456
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
DictionaryAttr getAttrDictionary()
Return all of the attributes on this operation as a DictionaryAttr.
Definition: Operation.cpp:295
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
result_type_range getResultTypes()
Definition: Operation.h:428
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
MLIRContext * getContext() const
Return the MLIRContext used to create this pattern.
Definition: PatternMatch.h:134
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:358
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:116
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
static Value computeSize(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, UnrankedMemRefDescriptor desc, unsigned addressSpace)
Builds and returns IR computing the size in bytes (suitable for opaque allocation).
Value memRefDescPtr(OpBuilder &builder, Location loc) const
Builds IR extracting ranked memref descriptor ptr.
static UnrankedMemRefDescriptor poison(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating a poison value of the descriptor type.
Value rank(OpBuilder &builder, Location loc) const
Builds IR extracting the rank from the descriptor.
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, IntegerOverflowFlags overflowFlags=IntegerOverflowFlags::none)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
Definition: Pattern.cpp:307
void setNativeProperties(Operation *op, IntegerOverflowFlags overflowFlags)
Handle generically setting flags as native properties on LLVM operations.
Definition: Pattern.cpp:299
LogicalResult intrinsicRewrite(Operation *op, StringRef intrinsic, ValueRange operands, const LLVMTypeConverter &typeConverter, RewriterBase &rewriter)
Replaces the given operation "op" with a call to an LLVM intrinsic with the specified name "intrinsic" and given operands.
Definition: Pattern.cpp:347
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateFreeFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDesc, using the strides of type.
Definition: Pattern.cpp:478
Value composeValue(OpBuilder &builder, Location loc, ValueRange src, Type dstType)
Composes a set of src values into a single value of type dstType through a series of bitcasts and vector ops.
Definition: Pattern.cpp:439
SmallVector< Value > decomposeValue(OpBuilder &builder, Location loc, Value src, Type dstType)
Decomposes a src value into a set of values of type dstType through a series of bitcasts and vector ops.
Definition: Pattern.cpp:400
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateMallocFn(OpBuilder &b, Operation *moduleOp, Type indexType, SymbolTableCollection *symbolTables=nullptr)
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:809
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...