20 #include <type_traits>
31 matchAndRewrite(ptr::FromPtrOp op, OpAdaptor adaptor,
38 struct GetMetadataOpConversion
42 matchAndRewrite(ptr::GetMetadataOp op, OpAdaptor adaptor,
52 matchAndRewrite(ptr::PtrAddOp op, OpAdaptor adaptor,
62 matchAndRewrite(ptr::ToPtrOp op, OpAdaptor adaptor,
69 struct TypeOffsetOpConversion
73 matchAndRewrite(ptr::TypeOffsetOp op, OpAdaptor adaptor,
83 static FailureOr<LLVM::LLVMStructType>
98 if (
failed(type.getStridesAndOffset(strides, offset)))
114 elements.push_back(ptrType);
117 if (offset == ShapedType::kDynamic)
118 elements.push_back(indexType);
121 for (int64_t dim : shape) {
122 if (dim == ShapedType::kDynamic)
123 elements.push_back(indexType);
127 for (int64_t stride : strides) {
128 if (stride == ShapedType::kDynamic)
129 elements.push_back(indexType);
131 return LLVM::LLVMStructType::getLiteral(context, elements);
138 LogicalResult FromPtrOpConversion::matchAndRewrite(
139 ptr::FromPtrOp op, OpAdaptor adaptor,
142 auto mTy = dyn_cast<MemRefType>(op.getResult().getType());
146 if (!op.getMetadata() && op.getType().hasPtrMetadata()) {
148 op,
"Can convert only memrefs with metadata");
152 Type descriptorTy = getTypeConverter()->convertType(mTy);
159 if (
failed(mTy.getStridesAndOffset(strides, offset))) {
161 "Failed to get the strides and offset");
170 desc.setAllocatedPtr(
172 LLVM::ExtractValueOp::create(rewriter, loc, adaptor.getMetadata(), 0));
173 desc.setAlignedPtr(rewriter, loc, adaptor.getPtr());
176 unsigned fieldIdx = 1;
179 if (offset == ShapedType::kDynamic) {
180 Value offsetValue = LLVM::ExtractValueOp::create(
181 rewriter, loc, adaptor.getMetadata(), fieldIdx++);
182 desc.setOffset(rewriter, loc, offsetValue);
184 desc.setConstantOffset(rewriter, loc, offset);
189 if (dim == ShapedType::kDynamic) {
190 Value sizeValue = LLVM::ExtractValueOp::create(
191 rewriter, loc, adaptor.getMetadata(), fieldIdx++);
192 desc.setSize(rewriter, loc, i, sizeValue);
194 desc.setConstantSize(rewriter, loc, i, dim);
200 if (stride == ShapedType::kDynamic) {
201 Value strideValue = LLVM::ExtractValueOp::create(
202 rewriter, loc, adaptor.getMetadata(), fieldIdx++);
203 desc.setStride(rewriter, loc, i, strideValue);
205 desc.setConstantStride(rewriter, loc, i, stride);
217 LogicalResult GetMetadataOpConversion::matchAndRewrite(
218 ptr::GetMetadataOp op, OpAdaptor adaptor,
220 auto mTy = dyn_cast<MemRefType>(op.getPtr().getType());
225 FailureOr<LLVM::LLVMStructType> mdTy =
229 "Failed to create the metadata type");
238 if (
failed(mTy.getStridesAndOffset(strides, offset))) {
240 "Failed to get the strides and offset");
246 Value sV = LLVM::UndefOp::create(rewriter, loc, *mdTy);
250 sV = LLVM::InsertValueOp::create(rewriter, loc, sV,
251 descriptor.allocatedPtr(rewriter, loc), pos);
254 unsigned fieldIdx = 1;
257 if (offset == ShapedType::kDynamic) {
258 sV = LLVM::InsertValueOp::create(
259 rewriter, loc, sV, descriptor.offset(rewriter, loc), fieldIdx++);
264 if (dim != ShapedType::kDynamic)
266 sV = LLVM::InsertValueOp::create(
267 rewriter, loc, sV, descriptor.size(rewriter, loc, i), fieldIdx++);
272 if (stride != ShapedType::kDynamic)
274 sV = LLVM::InsertValueOp::create(
275 rewriter, loc, sV, descriptor.stride(rewriter, loc, i), fieldIdx++);
286 PtrAddOpConversion::matchAndRewrite(ptr::PtrAddOp op, OpAdaptor adaptor,
289 Value base = adaptor.getBase();
290 if (!isa<LLVM::LLVMPointerType>(base.
getType()))
294 Value offset = adaptor.getOffset();
300 LLVM::GEPNoWrapFlags flags;
301 switch (op.getFlags()) {
302 case ptr::PtrAddFlags::none:
303 flags = LLVM::GEPNoWrapFlags::none;
305 case ptr::PtrAddFlags::nusw:
306 flags = LLVM::GEPNoWrapFlags::nusw;
308 case ptr::PtrAddFlags::nuw:
309 flags = LLVM::GEPNoWrapFlags::nuw;
311 case ptr::PtrAddFlags::inbounds:
312 flags = LLVM::GEPNoWrapFlags::inbounds;
327 ToPtrOpConversion::matchAndRewrite(ptr::ToPtrOp op, OpAdaptor adaptor,
330 if (!isa<MemRefType>(op.getPtr().getType()))
343 LogicalResult TypeOffsetOpConversion::matchAndRewrite(
344 ptr::TypeOffsetOp op, OpAdaptor adaptor,
347 Type type = getTypeConverter()->convertType(op.getElementType());
352 Type rTy = getTypeConverter()->convertType(op.getResult().getType());
364 LLVM::GEPOp::create(rewriter, op.getLoc(), ptrTy, type,
365 LLVM::ZeroOp::create(rewriter, op.getLoc(), ptrTy),
381 void loadDependentDialects(
MLIRContext *context)
const final {
382 context->loadDialect<LLVM::LLVMDialect>();
387 void populateConvertToLLVMConversionPatterns(
403 [&](PtrLikeTypeInterface type, ptr::GenericSpaceAttr memorySpace)
405 if (type.getMemorySpace() != memorySpace)
412 std::optional<Attribute> maybeAttr =
415 maybeAttr ? dyn_cast_or_null<IntegerAttr>(*maybeAttr) : IntegerAttr();
419 memSpace.getValue().getSExtValue());
424 auto mTy = dyn_cast<MemRefType>(type.getType());
427 FailureOr<LLVM::LLVMStructType> res =
433 patterns.add<FromPtrOpConversion, GetMetadataOpConversion, PtrAddOpConversion,
434 ToPtrOpConversion, TypeOffsetOpConversion>(converter);
439 dialect->addInterfaces<PtrToLLVMDialectInterface>();
static MLIRContext * getContext(OpFoldResult val)
static FailureOr< LLVM::LLVMStructType > createMemRefMetadataType(MemRefType type, const LLVMTypeConverter &typeConverter)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type or failure if...
Type getIndexType() const
Gets the LLVM representation of the index type.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Value alignedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the aligned pointer from the descriptor.
static MemRefDescriptor poison(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating a poison value of the descriptor type.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
The general result of a type attribute conversion callback, allowing for early termination.
static AttributeConversionResult na()
std::optional< Attribute > convertTypeAttribute(Type type, Attribute attr) const
Convert an attribute present attr from within the type type using the registered conversion functions...
void addConversion(FnT &&callback)
Register a conversion function.
void addTypeAttributeConversion(FnT &&callback)
Register a conversion function for attributes within types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void populatePtrToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Populate the convert to LLVM patterns for the ptr dialect.
void registerConvertPtrToLLVMInterface(DialectRegistry &registry)
Register the convert to LLVM interface for the ptr dialect.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...