29 ConversionPattern::getTypeConverter());
60 return LLVM::ConstantOp::create(builder, loc, resultType,
65 ConversionPatternRewriter &rewriter,
Location loc, MemRefType type,
67 LLVM::GEPNoWrapFlags noWrapFlags)
const {
69 memRefDesc,
indices, noWrapFlags);
75 MemRefType type)
const {
76 if (!type.getLayout().isIdentity())
78 return static_cast<bool>(typeConverter->convertType(type));
83 if (failed(addressSpace))
85 return LLVM::LLVMPointerType::get(type.getContext(), *addressSpace);
93 "layout maps must have been normalized away");
94 assert(count(memRefType.getShape(), ShapedType::kDynamic) ==
95 static_cast<ssize_t
>(dynamicSizes.size()) &&
96 "dynamicSizes size doesn't match dynamic sizes count in memref shape");
98 sizes.reserve(memRefType.getRank());
99 unsigned dynamicIndex = 0;
101 for (
int64_t size : memRefType.getShape()) {
103 size == ShapedType::kDynamic
104 ? dynamicSizes[dynamicIndex++]
111 strides.resize(memRefType.getRank());
112 for (
auto i = memRefType.getRank(); i-- > 0;) {
113 strides[i] = runningStride;
115 int64_t staticSize = memRefType.getShape()[i];
116 bool useSizeAsStride = stride == 1;
117 if (staticSize == ShapedType::kDynamic)
118 stride = ShapedType::kDynamic;
119 if (stride != ShapedType::kDynamic)
120 stride *= staticSize;
123 runningStride = sizes[i];
124 else if (stride == ShapedType::kDynamic)
126 LLVM::MulOp::create(rewriter, loc, runningStride, sizes[i]);
132 Type elementType = typeConverter->convertType(memRefType.getElementType());
133 auto elementPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
134 Value nullPtr = LLVM::ZeroOp::create(rewriter, loc, elementPtrType);
135 Value gepPtr = LLVM::GEPOp::create(rewriter, loc, elementPtrType,
136 elementType, nullPtr, runningStride);
137 size = LLVM::PtrToIntOp::create(rewriter, loc,
getIndexType(), gepPtr);
139 size = runningStride;
144 Location loc,
Type type, ConversionPatternRewriter &rewriter)
const {
150 Type llvmType = typeConverter->convertType(type);
151 auto convertedPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
152 auto nullPtr = LLVM::ZeroOp::create(rewriter, loc, convertedPtrType);
153 auto gep = LLVM::GEPOp::create(rewriter, loc, convertedPtrType, llvmType,
155 return LLVM::PtrToIntOp::create(rewriter, loc,
getIndexType(), gep);
160 ConversionPatternRewriter &rewriter)
const {
161 assert(count(memRefType.getShape(), ShapedType::kDynamic) ==
162 static_cast<ssize_t
>(dynamicSizes.size()) &&
163 "dynamicSizes size doesn't match dynamic sizes count in memref shape");
166 Value numElements = memRefType.getRank() == 0
169 unsigned dynamicIndex = 0;
172 for (
int64_t staticSize : memRefType.getShape()) {
175 staticSize == ShapedType::kDynamic
176 ? dynamicSizes[dynamicIndex++]
178 numElements = LLVM::MulOp::create(rewriter, loc, numElements, size);
181 staticSize == ShapedType::kDynamic
182 ? dynamicSizes[dynamicIndex++]
193 ConversionPatternRewriter &rewriter)
const {
194 auto structType = typeConverter->convertType(memRefType);
198 memRefDescriptor.setAllocatedPtr(rewriter, loc, allocatedPtr);
201 memRefDescriptor.setAlignedPtr(rewriter, loc, alignedPtr);
205 memRefDescriptor.setOffset(
209 for (
const auto &en : llvm::enumerate(sizes))
210 memRefDescriptor.setSize(rewriter, loc, en.index(), en.value());
213 for (
const auto &en : llvm::enumerate(strides))
214 memRefDescriptor.setStride(rewriter, loc, en.index(), en.value());
216 return memRefDescriptor;
221 Value operand,
bool toDynamic)
const {
223 FailureOr<unsigned> addressSpace =
225 if (failed(addressSpace))
232 auto module = builder.getInsertionPoint()->getParentOfType<ModuleOp>();
233 FailureOr<LLVM::LLVMFuncOp> freeFunc, mallocFunc;
236 if (failed(mallocFunc))
241 if (failed(freeFunc))
250 Value memory = toDynamic
251 ? LLVM::CallOp::create(builder, loc, mallocFunc.value(),
254 : LLVM::AllocaOp::create(builder, loc,
getPtrType(),
259 LLVM::MemcpyOp::create(builder, loc, memory, source, allocationSize,
false);
261 LLVM::CallOp::create(builder, loc, freeFunc.value(), source);
273 updatedDesc.setRank(builder, loc, rank);
274 updatedDesc.setMemRefDescPtr(builder, loc, memory);
281 assert(origTypes.size() == operands.size() &&
282 "expected as may original types as operands");
283 for (
unsigned i = 0, e = operands.size(); i < e; ++i) {
284 if (
auto memRefType = dyn_cast<UnrankedMemRefType>(origTypes[i])) {
286 operands[i], toDynamic);
289 operands[i] = updatedDesc;
305 ConversionPatternRewriter &rewriter) {
309 if (numResults != 0) {
310 resultTypes.push_back(
312 if (!resultTypes.back())
318 resultTypes, targetAttrs);
320 Operation *newOp = rewriter.create(state);
324 return rewriter.eraseOp(op),
success();
331 results.reserve(numResults);
332 for (
unsigned i = 0; i < numResults; ++i) {
333 results.push_back(LLVM::ExtractValueOp::create(rewriter, op->
getLoc(),
336 rewriter.replaceOp(op, results);
345 if (!llvm::all_of(operands, [](
Value value) {
355 auto callIntrOp = LLVM::CallIntrinsicOp::create(
356 rewriter, loc, resType, rewriter.
getStringAttr(intrinsic), operands);
360 if (numResults <= 1) {
369 results.reserve(numResults);
370 Value intrRes = callIntrOp.getResults();
371 for (
unsigned i = 0; i < numResults; ++i)
372 results.push_back(LLVM::ExtractValueOp::create(rewriter, loc, intrRes, i));
382 auto vec = cast<VectorType>(type);
383 assert(!vec.isScalable() &&
"scalable vectors are not supported");
384 return vec.getNumElements() *
getBitWidth(vec.getElementType());
390 return LLVM::ConstantOp::create(builder, loc, i32, value);
396 if (srcType == dstType)
401 if (srcBitWidth == dstBitWidth) {
402 Value cast = LLVM::BitcastOp::create(builder, loc, dstType, src);
406 if (dstBitWidth > srcBitWidth) {
408 if (srcType != smallerInt)
409 src = LLVM::BitcastOp::create(builder, loc, smallerInt, src);
412 Value res = LLVM::ZExtOp::create(builder, loc, largerInt, src);
415 assert(srcBitWidth % dstBitWidth == 0 &&
416 "src bit width must be a multiple of dst bit width");
417 int64_t numElements = srcBitWidth / dstBitWidth;
418 auto vecType = VectorType::get(numElements, dstType);
420 src = LLVM::BitcastOp::create(builder, loc, vecType, src);
423 for (
auto i : llvm::seq(numElements)) {
425 Value elem = LLVM::ExtractElementOp::create(builder, loc, src, idx);
426 res.emplace_back(elem);
434 assert(!src.empty() &&
"src range must not be empty");
435 if (src.size() == 1) {
436 Value res = src.front();
442 if (dstBitWidth < srcBitWidth) {
444 if (res.
getType() != largerInt)
445 res = LLVM::BitcastOp::create(builder, loc, largerInt, res);
448 res = LLVM::TruncOp::create(builder, loc, smallerInt, res);
452 res = LLVM::BitcastOp::create(builder, loc, dstType, res);
457 int64_t numElements = src.size();
458 auto srcType = VectorType::get(numElements, src.front().
getType());
459 Value res = LLVM::PoisonOp::create(builder, loc, srcType);
460 for (
auto &&[i, elem] : llvm::enumerate(src)) {
462 res = LLVM::InsertElementOp::create(builder, loc, srcType, res, elem, idx);
466 res = LLVM::BitcastOp::create(builder, loc, dstType, res);
473 MemRefType type,
Value memRefDesc,
475 LLVM::GEPNoWrapFlags noWrapFlags) {
476 auto [strides, offset] = type.getStridesAndOffset();
483 Value base = memRefDescriptor.
bufferPtr(builder, loc, converter, type);
485 LLVM::IntegerOverflowFlags intOverflowFlags =
486 LLVM::IntegerOverflowFlags::none;
487 if (LLVM::bitEnumContainsAny(noWrapFlags, LLVM::GEPNoWrapFlags::nusw)) {
488 intOverflowFlags = intOverflowFlags | LLVM::IntegerOverflowFlags::nsw;
490 if (LLVM::bitEnumContainsAny(noWrapFlags, LLVM::GEPNoWrapFlags::nuw)) {
491 intOverflowFlags = intOverflowFlags | LLVM::IntegerOverflowFlags::nuw;
496 for (
int i = 0, e =
indices.size(); i < e; ++i) {
498 if (strides[i] != 1) {
500 ShapedType::isDynamic(strides[i])
501 ? memRefDescriptor.
stride(builder, loc, i)
502 : LLVM::ConstantOp::create(builder, loc, indexType,
504 increment = LLVM::MulOp::create(builder, loc, increment, stride,
514 ? LLVM::GEPOp::create(builder, loc, elementPtrType,
515 converter.convertType(type.getElementType()),
516 base,
index, noWrapFlags)
static Value createI32Constant(ConversionPatternRewriter &rewriter, Location loc, int32_t value)
static unsigned getBitWidth(Type type)
Attributes are known-constant values of operations.
IntegerAttr getIndexAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
StringAttr getStringAttr(const Twine &bytes)
Type getVoidType() const
Gets the MLIR type wrapping the LLVM void type.
MemRefDescriptor createMemRefDescriptor(Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr, ArrayRef< Value > sizes, ArrayRef< Value > strides, ConversionPatternRewriter &rewriter) const
Creates and populates a canonical memref descriptor struct.
ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context, const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Value getStridedElementPtr(ConversionPatternRewriter &rewriter, Location loc, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none) const
Convenience wrapper for the corresponding helper utility.
void getMemRefDescriptorSizes(Location loc, MemRefType memRefType, ValueRange dynamicSizes, ConversionPatternRewriter &rewriter, SmallVectorImpl< Value > &sizes, SmallVectorImpl< Value > &strides, Value &size, bool sizeInBytes=true) const
Computes sizes, strides and buffer size of memRefType with identity layout.
Type getPtrType(unsigned addressSpace=0) const
Get the MLIR type wrapping the LLVM ptr type.
Type getIndexType() const
Gets the MLIR type wrapping the LLVM integer type whose bit width is defined by the type converter in use.
const LLVMTypeConverter * getTypeConverter() const
Value getNumElements(Location loc, MemRefType memRefType, ValueRange dynamicSizes, ConversionPatternRewriter &rewriter) const
Computes total number of elements for the given MemRef and dynamicSizes.
LLVM::LLVMDialect & getDialect() const
Returns the LLVM dialect.
Value getSizeInBytes(Location loc, Type type, ConversionPatternRewriter &rewriter) const
Computes the size of type in bytes.
Type getIntPtrType(unsigned addressSpace=0) const
Gets the MLIR type wrapping the LLVM integer type whose bit width corresponds to that of an LLVM pointer in the given address space.
Value copyUnrankedDescriptor(OpBuilder &builder, Location loc, UnrankedMemRefType memRefType, Value operand, bool toDynamic) const
Copies the given unranked memory descriptor to heap-allocated memory (if toDynamic is true) or to stack-allocated memory (otherwise), and returns a descriptor pointing to the copy.
LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc, TypeRange origTypes, SmallVectorImpl< Value > &operands, bool toDynamic) const
Copies the memory descriptor of every operand that was originally an unranked descriptor to heap-allocated memory (if toDynamic is true) or to stack-allocated memory (otherwise).
Type getElementPtrType(MemRefType type) const
Returns the type of a pointer to an element of the memref.
static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value)
Create a constant Op producing a value of resultType from an index-typed integer attribute.
bool isConvertibleAndHasIdentityMaps(MemRefType type) const
Returns whether the given memref type is convertible to LLVM and has an identity layout map.
Type getVoidPtrType() const
Get the MLIR type wrapping the LLVM i8* type.
Conversion from types to the LLVM IR dialect.
Type packOperationResults(TypeRange types) const
Convert a non-empty list of types of values produced by an operation into an LLVM-compatible type.
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type or failure if...
LLVM::LLVMDialect * getDialect() const
Returns the LLVM dialect.
Type getIndexType() const
Gets the LLVM representation of the index type.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Value bufferPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type)
Builds IR for getting the start address of the buffer represented by this memref: memref....
LLVM::LLVMPointerType getElementPtrType()
Returns the (LLVM) pointer type this descriptor contains.
Value stride(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th stride from the descriptor.
static MemRefDescriptor poison(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating a poison value of the descriptor type.
This class helps build Operations.
Operation is the basic unit of execution within MLIR.
DictionaryAttr getAttrDictionary()
Return all of the attributes on this operation as a DictionaryAttr.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
result_type_range getResultTypes()
unsigned getNumResults()
Return the number of results held by this operation.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
static Value computeSize(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, UnrankedMemRefDescriptor desc, unsigned addressSpace)
Builds and returns IR computing the size in bytes (suitable for opaque allocation).
Value memRefDescPtr(OpBuilder &builder, Location loc) const
Builds IR extracting ranked memref descriptor ptr.
static UnrankedMemRefDescriptor poison(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating a poison value of the descriptor type.
Value rank(OpBuilder &builder, Location loc) const
Builds IR extracting the rank from the descriptor.
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, Attribute propertiesAttr, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
LogicalResult intrinsicRewrite(Operation *op, StringRef intrinsic, ValueRange operands, const LLVMTypeConverter &typeConverter, RewriterBase &rewriter)
Replaces the given operation "op" with a call to an LLVM intrinsic with the specified name "intrinsic" and the given operands.
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateFreeFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDesc.
Value composeValue(OpBuilder &builder, Location loc, ValueRange src, Type dstType)
Composes a set of src values into a single value of type dstType through a series of bitcasts and vector ops.
SmallVector< Value > decomposeValue(OpBuilder &builder, Location loc, Value src, Type dstType)
Decomposes a src value into a set of values of type dstType through a series of bitcasts and vector ops.
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateMallocFn(OpBuilder &b, Operation *moduleOp, Type indexType, SymbolTableCollection *symbolTables=nullptr)
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Include the generated interface declarations.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Attribute propertiesAttr
This Attribute is used to opaquely construct the properties of the operation.