MLIR 22.0.0git
Pattern.cpp
Go to the documentation of this file.
1//===- Pattern.cpp - Conversion pattern to the LLVM dialect ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
13#include "mlir/IR/AffineMap.h"
15
16using namespace mlir;
17
18//===----------------------------------------------------------------------===//
19// ConvertToLLVMPattern
20//===----------------------------------------------------------------------===//
21
23 StringRef rootOpName, MLIRContext *context,
24 const LLVMTypeConverter &typeConverter, PatternBenefit benefit)
25 : ConversionPattern(typeConverter, rootOpName, benefit, context) {}
26
28 return static_cast<const LLVMTypeConverter *>(
29 ConversionPattern::getTypeConverter());
30}
31
32LLVM::LLVMDialect &ConvertToLLVMPattern::getDialect() const {
33 return *getTypeConverter()->getDialect();
34}
35
39
40Type ConvertToLLVMPattern::getIntPtrType(unsigned addressSpace) const {
41 return IntegerType::get(&getTypeConverter()->getContext(),
42 getTypeConverter()->getPointerBitwidth(addressSpace));
43}
44
46 return LLVM::LLVMVoidType::get(&getTypeConverter()->getContext());
47}
48
49Type ConvertToLLVMPattern::getPtrType(unsigned addressSpace) const {
50 return LLVM::LLVMPointerType::get(&getTypeConverter()->getContext(),
51 addressSpace);
52}
53
55
57 Location loc,
58 Type resultType,
59 int64_t value) {
60 return LLVM::ConstantOp::create(builder, loc, resultType,
61 builder.getIndexAttr(value));
62}
63
65 ConversionPatternRewriter &rewriter, Location loc, MemRefType type,
66 Value memRefDesc, ValueRange indices,
67 LLVM::GEPNoWrapFlags noWrapFlags) const {
68 return LLVM::getStridedElementPtr(rewriter, loc, *getTypeConverter(), type,
69 memRefDesc, indices, noWrapFlags);
70}
71
72// Check if the MemRefType `type` is supported by the lowering. We currently
73// only support memrefs with identity maps.
75 MemRefType type) const {
76 if (!type.getLayout().isIdentity())
77 return false;
78 return static_cast<bool>(typeConverter->convertType(type));
79}
80
82 auto addressSpace = getTypeConverter()->getMemRefAddressSpace(type);
83 if (failed(addressSpace))
84 return {};
85 return LLVM::LLVMPointerType::get(type.getContext(), *addressSpace);
86}
87
89 Location loc, MemRefType memRefType, ValueRange dynamicSizes,
90 ConversionPatternRewriter &rewriter, SmallVectorImpl<Value> &sizes,
91 SmallVectorImpl<Value> &strides, Value &size, bool sizeInBytes) const {
92 assert(isConvertibleAndHasIdentityMaps(memRefType) &&
93 "layout maps must have been normalized away");
94 assert(count(memRefType.getShape(), ShapedType::kDynamic) ==
95 static_cast<ssize_t>(dynamicSizes.size()) &&
96 "dynamicSizes size doesn't match dynamic sizes count in memref shape");
97
98 sizes.reserve(memRefType.getRank());
99 unsigned dynamicIndex = 0;
100 Type indexType = getIndexType();
101 for (int64_t size : memRefType.getShape()) {
102 sizes.push_back(
103 size == ShapedType::kDynamic
104 ? dynamicSizes[dynamicIndex++]
105 : createIndexAttrConstant(rewriter, loc, indexType, size));
106 }
107
108 // Strides: iterate sizes in reverse order and multiply.
109 int64_t stride = 1;
110 Value runningStride = createIndexAttrConstant(rewriter, loc, indexType, 1);
111 strides.resize(memRefType.getRank());
112 for (auto i = memRefType.getRank(); i-- > 0;) {
113 strides[i] = runningStride;
114
115 int64_t staticSize = memRefType.getShape()[i];
116 bool useSizeAsStride = stride == 1;
117 if (staticSize == ShapedType::kDynamic)
118 stride = ShapedType::kDynamic;
119 if (stride != ShapedType::kDynamic)
120 stride *= staticSize;
121
122 if (useSizeAsStride)
123 runningStride = sizes[i];
124 else if (stride == ShapedType::kDynamic)
125 runningStride =
126 LLVM::MulOp::create(rewriter, loc, runningStride, sizes[i]);
127 else
128 runningStride = createIndexAttrConstant(rewriter, loc, indexType, stride);
129 }
130 if (sizeInBytes) {
131 // Buffer size in bytes.
132 Type elementType = typeConverter->convertType(memRefType.getElementType());
133 auto elementPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
134 Value nullPtr = LLVM::ZeroOp::create(rewriter, loc, elementPtrType);
135 Value gepPtr = LLVM::GEPOp::create(rewriter, loc, elementPtrType,
136 elementType, nullPtr, runningStride);
137 size = LLVM::PtrToIntOp::create(rewriter, loc, getIndexType(), gepPtr);
138 } else {
139 size = runningStride;
140 }
141}
142
144 Location loc, Type type, ConversionPatternRewriter &rewriter) const {
145 // Compute the size of an individual element. This emits the MLIR equivalent
146 // of the following sizeof(...) implementation in LLVM IR:
147 // %0 = getelementptr %elementType* null, %indexType 1
148 // %1 = ptrtoint %elementType* %0 to %indexType
149 // which is a common pattern of getting the size of a type in bytes.
150 Type llvmType = typeConverter->convertType(type);
151 auto convertedPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
152 auto nullPtr = LLVM::ZeroOp::create(rewriter, loc, convertedPtrType);
153 auto gep = LLVM::GEPOp::create(rewriter, loc, convertedPtrType, llvmType,
154 nullPtr, ArrayRef<LLVM::GEPArg>{1});
155 return LLVM::PtrToIntOp::create(rewriter, loc, getIndexType(), gep);
156}
157
159 Location loc, MemRefType memRefType, ValueRange dynamicSizes,
160 ConversionPatternRewriter &rewriter) const {
161 assert(count(memRefType.getShape(), ShapedType::kDynamic) ==
162 static_cast<ssize_t>(dynamicSizes.size()) &&
163 "dynamicSizes size doesn't match dynamic sizes count in memref shape");
164
165 Type indexType = getIndexType();
166 Value numElements = memRefType.getRank() == 0
167 ? createIndexAttrConstant(rewriter, loc, indexType, 1)
168 : nullptr;
169 unsigned dynamicIndex = 0;
170
171 // Compute the total number of memref elements.
172 for (int64_t staticSize : memRefType.getShape()) {
173 if (numElements) {
174 Value size =
175 staticSize == ShapedType::kDynamic
176 ? dynamicSizes[dynamicIndex++]
177 : createIndexAttrConstant(rewriter, loc, indexType, staticSize);
178 numElements = LLVM::MulOp::create(rewriter, loc, numElements, size);
179 } else {
180 numElements =
181 staticSize == ShapedType::kDynamic
182 ? dynamicSizes[dynamicIndex++]
183 : createIndexAttrConstant(rewriter, loc, indexType, staticSize);
184 }
185 }
186 return numElements;
187}
188
189/// Creates and populates the memref descriptor struct given all its fields.
191 Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr,
192 ArrayRef<Value> sizes, ArrayRef<Value> strides,
193 ConversionPatternRewriter &rewriter) const {
194 auto structType = typeConverter->convertType(memRefType);
195 auto memRefDescriptor = MemRefDescriptor::poison(rewriter, loc, structType);
196
197 // Field 1: Allocated pointer, used for malloc/free.
198 memRefDescriptor.setAllocatedPtr(rewriter, loc, allocatedPtr);
199
200 // Field 2: Actual aligned pointer to payload.
201 memRefDescriptor.setAlignedPtr(rewriter, loc, alignedPtr);
202
203 // Field 3: Offset in aligned pointer.
204 Type indexType = getIndexType();
205 memRefDescriptor.setOffset(
206 rewriter, loc, createIndexAttrConstant(rewriter, loc, indexType, 0));
207
208 // Fields 4: Sizes.
209 for (const auto &en : llvm::enumerate(sizes))
210 memRefDescriptor.setSize(rewriter, loc, en.index(), en.value());
211
212 // Field 5: Strides.
213 for (const auto &en : llvm::enumerate(strides))
214 memRefDescriptor.setStride(rewriter, loc, en.index(), en.value());
215
216 return memRefDescriptor;
217}
218
220 OpBuilder &builder, Location loc, UnrankedMemRefType memRefType,
221 Value operand, bool toDynamic) const {
222 // Convert memory space.
223 FailureOr<unsigned> addressSpace =
225 if (failed(addressSpace))
226 return {};
227
228 // Get frequently used types.
229 Type indexType = getTypeConverter()->getIndexType();
230
231 // Find the malloc and free, or declare them if necessary.
232 auto module = builder.getInsertionPoint()->getParentOfType<ModuleOp>();
233 FailureOr<LLVM::LLVMFuncOp> freeFunc, mallocFunc;
234 if (toDynamic) {
235 mallocFunc = LLVM::lookupOrCreateMallocFn(builder, module, indexType);
236 if (failed(mallocFunc))
237 return {};
238 }
239 if (!toDynamic) {
240 freeFunc = LLVM::lookupOrCreateFreeFn(builder, module);
241 if (failed(freeFunc))
242 return {};
243 }
244
245 UnrankedMemRefDescriptor desc(operand);
247 builder, loc, *getTypeConverter(), desc, *addressSpace);
248
249 // Allocate memory, copy, and free the source if necessary.
250 Value memory = toDynamic
251 ? LLVM::CallOp::create(builder, loc, mallocFunc.value(),
252 allocationSize)
253 .getResult()
254 : LLVM::AllocaOp::create(builder, loc, getPtrType(),
255 IntegerType::get(getContext(), 8),
256 allocationSize,
257 /*alignment=*/0);
258 Value source = desc.memRefDescPtr(builder, loc);
259 LLVM::MemcpyOp::create(builder, loc, memory, source, allocationSize, false);
260 if (!toDynamic)
261 LLVM::CallOp::create(builder, loc, freeFunc.value(), source);
262
263 // Create a new descriptor. The same descriptor can be returned multiple
264 // times, attempting to modify its pointer can lead to memory leaks
265 // (allocated twice and overwritten) or double frees (the caller does not
266 // know if the descriptor points to the same memory).
267 Type descriptorType = getTypeConverter()->convertType(memRefType);
268 if (!descriptorType)
269 return {};
270 auto updatedDesc =
271 UnrankedMemRefDescriptor::poison(builder, loc, descriptorType);
272 Value rank = desc.rank(builder, loc);
273 updatedDesc.setRank(builder, loc, rank);
274 updatedDesc.setMemRefDescPtr(builder, loc, memory);
275 return updatedDesc;
276}
277
279 OpBuilder &builder, Location loc, TypeRange origTypes,
280 SmallVectorImpl<Value> &operands, bool toDynamic) const {
281 assert(origTypes.size() == operands.size() &&
282 "expected as may original types as operands");
283 for (unsigned i = 0, e = operands.size(); i < e; ++i) {
284 if (auto memRefType = dyn_cast<UnrankedMemRefType>(origTypes[i])) {
285 Value updatedDesc = copyUnrankedDescriptor(builder, loc, memRefType,
286 operands[i], toDynamic);
287 if (!updatedDesc)
288 return failure();
289 operands[i] = updatedDesc;
290 }
291 }
292 return success();
293}
294
295//===----------------------------------------------------------------------===//
296// Detail methods
297//===----------------------------------------------------------------------===//
298
300 IntegerOverflowFlags overflowFlags) {
301 if (auto iface = dyn_cast<IntegerOverflowFlagsInterface>(op))
302 iface.setOverflowFlags(overflowFlags);
303}
304
305/// Replaces the given operation "op" with a new operation of type "targetOp"
306/// and given operands.
308 Operation *op, StringRef targetOp, ValueRange operands,
309 ArrayRef<NamedAttribute> targetAttrs,
310 const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
311 IntegerOverflowFlags overflowFlags) {
312 unsigned numResults = op->getNumResults();
313
314 SmallVector<Type> resultTypes;
315 if (numResults != 0) {
316 resultTypes.push_back(
317 typeConverter.packOperationResults(op->getResultTypes()));
318 if (!resultTypes.back())
319 return failure();
320 }
321
322 // Create the operation through state since we don't know its C++ type.
323 Operation *newOp =
324 rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
325 resultTypes, targetAttrs);
326
327 setNativeProperties(newOp, overflowFlags);
328
329 // If the operation produced 0 or 1 result, return them immediately.
330 if (numResults == 0)
331 return rewriter.eraseOp(op), success();
332 if (numResults == 1)
333 return rewriter.replaceOp(op, newOp->getResult(0)), success();
334
335 // Otherwise, it had been converted to an operation producing a structure.
336 // Extract individual results from the structure and return them as list.
337 SmallVector<Value, 4> results;
338 results.reserve(numResults);
339 for (unsigned i = 0; i < numResults; ++i) {
340 results.push_back(LLVM::ExtractValueOp::create(rewriter, op->getLoc(),
341 newOp->getResult(0), i));
342 }
343 rewriter.replaceOp(op, results);
344 return success();
345}
346
348 Operation *op, StringRef intrinsic, ValueRange operands,
349 const LLVMTypeConverter &typeConverter, RewriterBase &rewriter) {
350 auto loc = op->getLoc();
351
352 if (!llvm::all_of(operands, [](Value value) {
353 return LLVM::isCompatibleType(value.getType());
354 }))
355 return failure();
356
357 unsigned numResults = op->getNumResults();
358 Type resType;
359 if (numResults != 0)
360 resType = typeConverter.packOperationResults(op->getResultTypes());
361
362 auto callIntrOp = LLVM::CallIntrinsicOp::create(
363 rewriter, loc, resType, rewriter.getStringAttr(intrinsic), operands);
364 // Propagate attributes.
365 callIntrOp->setAttrs(op->getAttrDictionary());
366
367 if (numResults <= 1) {
368 // Directly replace the original op.
369 rewriter.replaceOp(op, callIntrOp);
370 return success();
371 }
372
373 // Extract individual results from packed structure and use them as
374 // replacements.
375 SmallVector<Value, 4> results;
376 results.reserve(numResults);
377 Value intrRes = callIntrOp.getResults();
378 for (unsigned i = 0; i < numResults; ++i)
379 results.push_back(LLVM::ExtractValueOp::create(rewriter, loc, intrRes, i));
380 rewriter.replaceOp(op, results);
381
382 return success();
383}
384
385static unsigned getBitWidth(Type type) {
386 if (type.isIntOrFloat())
387 return type.getIntOrFloatBitWidth();
388
389 auto vec = cast<VectorType>(type);
390 assert(!vec.isScalable() && "scalable vectors are not supported");
391 return vec.getNumElements() * getBitWidth(vec.getElementType());
392}
393
395 int32_t value) {
396 Type i32 = builder.getI32Type();
397 return LLVM::ConstantOp::create(builder, loc, i32, value);
398}
399
401 Value src, Type dstType) {
402 Type srcType = src.getType();
403 if (srcType == dstType)
404 return {src};
405
406 unsigned srcBitWidth = getBitWidth(srcType);
407 unsigned dstBitWidth = getBitWidth(dstType);
408 if (srcBitWidth == dstBitWidth) {
409 Value cast = LLVM::BitcastOp::create(builder, loc, dstType, src);
410 return {cast};
411 }
412
413 if (dstBitWidth > srcBitWidth) {
414 auto smallerInt = builder.getIntegerType(srcBitWidth);
415 if (srcType != smallerInt)
416 src = LLVM::BitcastOp::create(builder, loc, smallerInt, src);
417
418 auto largerInt = builder.getIntegerType(dstBitWidth);
419 Value res = LLVM::ZExtOp::create(builder, loc, largerInt, src);
420 return {res};
421 }
422 assert(srcBitWidth % dstBitWidth == 0 &&
423 "src bit width must be a multiple of dst bit width");
424 int64_t numElements = srcBitWidth / dstBitWidth;
425 auto vecType = VectorType::get(numElements, dstType);
426
427 src = LLVM::BitcastOp::create(builder, loc, vecType, src);
428
430 for (auto i : llvm::seq(numElements)) {
431 Value idx = createI32Constant(builder, loc, i);
432 Value elem = LLVM::ExtractElementOp::create(builder, loc, src, idx);
433 res.emplace_back(elem);
434 }
435
436 return res;
437}
438
440 Type dstType) {
441 assert(!src.empty() && "src range must not be empty");
442 if (src.size() == 1) {
443 Value res = src.front();
444 if (res.getType() == dstType)
445 return res;
446
447 unsigned srcBitWidth = getBitWidth(res.getType());
448 unsigned dstBitWidth = getBitWidth(dstType);
449 if (dstBitWidth < srcBitWidth) {
450 auto largerInt = builder.getIntegerType(srcBitWidth);
451 if (res.getType() != largerInt)
452 res = LLVM::BitcastOp::create(builder, loc, largerInt, res);
453
454 auto smallerInt = builder.getIntegerType(dstBitWidth);
455 res = LLVM::TruncOp::create(builder, loc, smallerInt, res);
456 }
457
458 if (res.getType() != dstType)
459 res = LLVM::BitcastOp::create(builder, loc, dstType, res);
460
461 return res;
462 }
463
464 int64_t numElements = src.size();
465 auto srcType = VectorType::get(numElements, src.front().getType());
466 Value res = LLVM::PoisonOp::create(builder, loc, srcType);
467 for (auto &&[i, elem] : llvm::enumerate(src)) {
468 Value idx = createI32Constant(builder, loc, i);
469 res = LLVM::InsertElementOp::create(builder, loc, srcType, res, elem, idx);
470 }
471
472 if (res.getType() != dstType)
473 res = LLVM::BitcastOp::create(builder, loc, dstType, res);
474
475 return res;
476}
477
479 const LLVMTypeConverter &converter,
480 MemRefType type, Value memRefDesc,
482 LLVM::GEPNoWrapFlags noWrapFlags) {
483 auto [strides, offset] = type.getStridesAndOffset();
484
485 MemRefDescriptor memRefDescriptor(memRefDesc);
486 // Use a canonical representation of the start address so that later
487 // optimizations have a longer sequence of instructions to CSE.
488 // If we don't do that we would sprinkle the memref.offset in various
489 // position of the different address computations.
490 Value base = memRefDescriptor.bufferPtr(builder, loc, converter, type);
491
492 LLVM::IntegerOverflowFlags intOverflowFlags =
493 LLVM::IntegerOverflowFlags::none;
494 if (LLVM::bitEnumContainsAny(noWrapFlags, LLVM::GEPNoWrapFlags::nusw)) {
495 intOverflowFlags = intOverflowFlags | LLVM::IntegerOverflowFlags::nsw;
496 }
497 if (LLVM::bitEnumContainsAny(noWrapFlags, LLVM::GEPNoWrapFlags::nuw)) {
498 intOverflowFlags = intOverflowFlags | LLVM::IntegerOverflowFlags::nuw;
499 }
500
501 Type indexType = converter.getIndexType();
502 Value index;
503 for (int i = 0, e = indices.size(); i < e; ++i) {
504 Value increment = indices[i];
505 if (strides[i] != 1) { // Skip if stride is 1.
506 Value stride =
507 ShapedType::isDynamic(strides[i])
508 ? memRefDescriptor.stride(builder, loc, i)
509 : LLVM::ConstantOp::create(builder, loc, indexType,
510 builder.getIndexAttr(strides[i]));
511 increment = LLVM::MulOp::create(builder, loc, increment, stride,
512 intOverflowFlags);
513 }
514 index = index ? LLVM::AddOp::create(builder, loc, index, increment,
515 intOverflowFlags)
516 : increment;
517 }
518
519 Type elementPtrType = memRefDescriptor.getElementPtrType();
520 return index
521 ? LLVM::GEPOp::create(builder, loc, elementPtrType,
522 converter.convertType(type.getElementType()),
523 base, index, noWrapFlags)
524 : base;
525}
static Value createI32Constant(ConversionPatternRewriter &rewriter, Location loc, int32_t value)
return success()
static unsigned getBitWidth(Type type)
Definition Pattern.cpp:385
b getContext())
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:108
IntegerType getI32Type()
Definition Builders.cpp:63
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:67
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:262
Type getVoidType() const
Gets the MLIR type wrapping the LLVM void type.
Definition Pattern.cpp:45
MemRefDescriptor createMemRefDescriptor(Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr, ArrayRef< Value > sizes, ArrayRef< Value > strides, ConversionPatternRewriter &rewriter) const
Creates and populates a canonical memref descriptor struct.
Definition Pattern.cpp:190
ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context, const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition Pattern.cpp:22
Value getStridedElementPtr(ConversionPatternRewriter &rewriter, Location loc, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none) const
Convenience wrapper for the corresponding helper utility.
Definition Pattern.cpp:64
void getMemRefDescriptorSizes(Location loc, MemRefType memRefType, ValueRange dynamicSizes, ConversionPatternRewriter &rewriter, SmallVectorImpl< Value > &sizes, SmallVectorImpl< Value > &strides, Value &size, bool sizeInBytes=true) const
Computes sizes, strides and buffer size of memRefType with identity layout.
Definition Pattern.cpp:88
Type getPtrType(unsigned addressSpace=0) const
Get the MLIR type wrapping the LLVM ptr type.
Definition Pattern.cpp:49
Type getIndexType() const
Gets the MLIR type wrapping the LLVM integer type whose bit width is defined by the used type convert...
Definition Pattern.cpp:36
const LLVMTypeConverter * getTypeConverter() const
Definition Pattern.cpp:27
Value getNumElements(Location loc, MemRefType memRefType, ValueRange dynamicSizes, ConversionPatternRewriter &rewriter) const
Computes total number of elements for the given MemRef and dynamicSizes.
Definition Pattern.cpp:158
LLVM::LLVMDialect & getDialect() const
Returns the LLVM dialect.
Definition Pattern.cpp:32
Value getSizeInBytes(Location loc, Type type, ConversionPatternRewriter &rewriter) const
Computes the size of type in bytes.
Definition Pattern.cpp:143
Type getIntPtrType(unsigned addressSpace=0) const
Gets the MLIR type wrapping the LLVM integer type whose bit width corresponds to that of a LLVM point...
Definition Pattern.cpp:40
Value copyUnrankedDescriptor(OpBuilder &builder, Location loc, UnrankedMemRefType memRefType, Value operand, bool toDynamic) const
Copies the given unranked memory descriptor to heap-allocated memory (if toDynamic is true) or to sta...
Definition Pattern.cpp:219
LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc, TypeRange origTypes, SmallVectorImpl< Value > &operands, bool toDynamic) const
Copies the memory descriptor for any operands that were unranked descriptors originally to heap-alloc...
Definition Pattern.cpp:278
Type getElementPtrType(MemRefType type) const
Returns the type of a pointer to an element of the memref.
Definition Pattern.cpp:81
static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value)
Create a constant Op producing a value of resultType from an index-typed integer attribute.
Definition Pattern.cpp:56
bool isConvertibleAndHasIdentityMaps(MemRefType type) const
Returns if the given memref type is convertible to LLVM and has an identity layout map.
Definition Pattern.cpp:74
Type getVoidPtrType() const
Get the MLIR type wrapping the LLVM i8* type.
Definition Pattern.cpp:54
Conversion from types to the LLVM IR dialect.
Type packOperationResults(TypeRange types) const
Convert a non-empty list of types of values produced by an operation into an LLVM-compatible type.
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type or failure if...
LLVM::LLVMDialect * getDialect() const
Returns the LLVM dialect.
Type getIndexType() const
Gets the LLVM representation of the index type.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Value bufferPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type)
Builds IR for getting the start address of the buffer represented by this memref: memref....
LLVM::LLVMPointerType getElementPtrType()
Returns the (LLVM) pointer type this descriptor contains.
Value stride(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th stride from the descriptor.
static MemRefDescriptor poison(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating a poison value of the descriptor type.
This class helps build Operations.
Definition Builders.h:207
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
DictionaryAttr getAttrDictionary()
Return all of the attributes on this operation as a DictionaryAttr.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
result_type_range getResultTypes()
Definition Operation.h:428
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:116
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:122
static Value computeSize(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, UnrankedMemRefDescriptor desc, unsigned addressSpace)
Builds and returns IR computing the size in bytes (suitable for opaque allocation).
Value memRefDescPtr(OpBuilder &builder, Location loc) const
Builds IR extracting ranked memref descriptor ptr.
static UnrankedMemRefDescriptor poison(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating a poison value of the descriptor type.
Value rank(OpBuilder &builder, Location loc) const
Builds IR extracting the rank from the descriptor.
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, IntegerOverflowFlags overflowFlags=IntegerOverflowFlags::none)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
Definition Pattern.cpp:307
void setNativeProperties(Operation *op, IntegerOverflowFlags overflowFlags)
Handle generically setting flags as native properties on LLVM operations.
Definition Pattern.cpp:299
LogicalResult intrinsicRewrite(Operation *op, StringRef intrinsic, ValueRange operands, const LLVMTypeConverter &typeConverter, RewriterBase &rewriter)
Replaces the given operation "op" with a call to an LLVM intrinsic with the specified name "intrinsic...
Definition Pattern.cpp:347
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateFreeFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDes...
Definition Pattern.cpp:478
Value composeValue(OpBuilder &builder, Location loc, ValueRange src, Type dstType)
Composes a set of src values into a single value of type dstType through series of bitcasts and vecto...
Definition Pattern.cpp:439
SmallVector< Value > decomposeValue(OpBuilder &builder, Location loc, Value src, Type dstType)
Decomposes a src value into a set of values of type dstType through series of bitcasts and vector ops...
Definition Pattern.cpp:400
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateMallocFn(OpBuilder &b, Operation *moduleOp, Type indexType, SymbolTableCollection *symbolTables=nullptr)
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Include the generated interface declarations.