31 matchAndRewrite(ptr::FromPtrOp op, OpAdaptor adaptor,
32 ConversionPatternRewriter &rewriter)
const override;
38struct GetMetadataOpConversion
42 matchAndRewrite(ptr::GetMetadataOp op, OpAdaptor adaptor,
43 ConversionPatternRewriter &rewriter)
const override;
52 matchAndRewrite(ptr::PtrAddOp op, OpAdaptor adaptor,
53 ConversionPatternRewriter &rewriter)
const override;
62 matchAndRewrite(ptr::ToPtrOp op, OpAdaptor adaptor,
63 ConversionPatternRewriter &rewriter)
const override;
69struct TypeOffsetOpConversion
73 matchAndRewrite(ptr::TypeOffsetOp op, OpAdaptor adaptor,
74 ConversionPatternRewriter &rewriter)
const override;
83static FailureOr<LLVM::LLVMStructType>
89 if (failed(addressSpace))
93 auto ptrType = LLVM::LLVMPointerType::get(context, *addressSpace);
98 if (failed(type.getStridesAndOffset(strides, offset)))
114 elements.push_back(ptrType);
117 if (offset == ShapedType::kDynamic)
118 elements.push_back(indexType);
122 if (dim == ShapedType::kDynamic)
123 elements.push_back(indexType);
127 for (
int64_t stride : strides) {
128 if (stride == ShapedType::kDynamic)
129 elements.push_back(indexType);
131 return LLVM::LLVMStructType::getLiteral(context, elements);
138LogicalResult FromPtrOpConversion::matchAndRewrite(
139 ptr::FromPtrOp op, OpAdaptor adaptor,
140 ConversionPatternRewriter &rewriter)
const {
142 auto mTy = dyn_cast<MemRefType>(op.getResult().getType());
144 return rewriter.notifyMatchFailure(op,
"Expected memref result type");
146 if (!op.getMetadata() && op.getType().hasPtrMetadata()) {
147 return rewriter.notifyMatchFailure(
148 op,
"Can convert only memrefs with metadata");
152 Type descriptorTy = getTypeConverter()->convertType(mTy);
154 return rewriter.notifyMatchFailure(op,
"Failed to convert result type");
157 SmallVector<int64_t> strides;
159 if (
failed(mTy.getStridesAndOffset(strides, offset))) {
160 return rewriter.notifyMatchFailure(op,
161 "Failed to get the strides and offset");
163 ArrayRef<int64_t> shape = mTy.getShape();
166 Location loc = op.getLoc();
167 auto desc = MemRefDescriptor::poison(rewriter, loc, descriptorTy);
170 desc.setAllocatedPtr(
172 LLVM::ExtractValueOp::create(rewriter, loc, adaptor.getMetadata(), 0));
173 desc.setAlignedPtr(rewriter, loc, adaptor.getPtr());
176 unsigned fieldIdx = 1;
179 if (offset == ShapedType::kDynamic) {
180 Value offsetValue = LLVM::ExtractValueOp::create(
181 rewriter, loc, adaptor.getMetadata(), fieldIdx++);
182 desc.setOffset(rewriter, loc, offsetValue);
184 desc.setConstantOffset(rewriter, loc, offset);
188 for (
auto [i, dim] : llvm::enumerate(shape)) {
189 if (dim == ShapedType::kDynamic) {
190 Value sizeValue = LLVM::ExtractValueOp::create(
191 rewriter, loc, adaptor.getMetadata(), fieldIdx++);
192 desc.setSize(rewriter, loc, i, sizeValue);
194 desc.setConstantSize(rewriter, loc, i, dim);
199 for (
auto [i, stride] : llvm::enumerate(strides)) {
200 if (stride == ShapedType::kDynamic) {
201 Value strideValue = LLVM::ExtractValueOp::create(
202 rewriter, loc, adaptor.getMetadata(), fieldIdx++);
203 desc.setStride(rewriter, loc, i, strideValue);
205 desc.setConstantStride(rewriter, loc, i, stride);
209 rewriter.replaceOp(op,
static_cast<Value
>(desc));
217LogicalResult GetMetadataOpConversion::matchAndRewrite(
218 ptr::GetMetadataOp op, OpAdaptor adaptor,
219 ConversionPatternRewriter &rewriter)
const {
220 auto mTy = dyn_cast<MemRefType>(op.getPtr().getType());
222 return rewriter.notifyMatchFailure(op,
"Only memref metadata is supported");
225 FailureOr<LLVM::LLVMStructType> mdTy =
228 return rewriter.notifyMatchFailure(op,
229 "Failed to create the metadata type");
233 MemRefDescriptor descriptor(adaptor.getPtr());
236 SmallVector<int64_t> strides;
238 if (
failed(mTy.getStridesAndOffset(strides, offset))) {
239 return rewriter.notifyMatchFailure(op,
240 "Failed to get the strides and offset");
242 ArrayRef<int64_t> shape = mTy.getShape();
245 Location loc = op.getLoc();
246 Value sV = LLVM::UndefOp::create(rewriter, loc, *mdTy);
249 SmallVector<int64_t> pos{0};
250 sV = LLVM::InsertValueOp::create(rewriter, loc, sV,
251 descriptor.allocatedPtr(rewriter, loc), pos);
254 unsigned fieldIdx = 1;
257 if (offset == ShapedType::kDynamic) {
258 sV = LLVM::InsertValueOp::create(
259 rewriter, loc, sV, descriptor.offset(rewriter, loc), fieldIdx++);
263 for (
auto [i, dim] : llvm::enumerate(shape)) {
264 if (dim != ShapedType::kDynamic)
266 sV = LLVM::InsertValueOp::create(
267 rewriter, loc, sV, descriptor.size(rewriter, loc, i), fieldIdx++);
271 for (
auto [i, stride] : llvm::enumerate(strides)) {
272 if (stride != ShapedType::kDynamic)
274 sV = LLVM::InsertValueOp::create(
275 rewriter, loc, sV, descriptor.stride(rewriter, loc, i), fieldIdx++);
277 rewriter.replaceOp(op, sV);
286PtrAddOpConversion::matchAndRewrite(ptr::PtrAddOp op, OpAdaptor adaptor,
287 ConversionPatternRewriter &rewriter)
const {
289 Value base = adaptor.getBase();
290 if (!isa<LLVM::LLVMPointerType>(base.
getType()))
291 return rewriter.notifyMatchFailure(op,
"Incompatible pointer type");
294 Value offset = adaptor.getOffset();
297 Type elementType = IntegerType::get(rewriter.getContext(), 8);
300 LLVM::GEPNoWrapFlags flags;
301 switch (op.getFlags()) {
302 case ptr::PtrAddFlags::none:
303 flags = LLVM::GEPNoWrapFlags::none;
305 case ptr::PtrAddFlags::nusw:
306 flags = LLVM::GEPNoWrapFlags::nusw;
308 case ptr::PtrAddFlags::nuw:
309 flags = LLVM::GEPNoWrapFlags::nuw;
311 case ptr::PtrAddFlags::inbounds:
312 flags = LLVM::GEPNoWrapFlags::inbounds;
317 rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, base.
getType(), elementType,
327ToPtrOpConversion::matchAndRewrite(ptr::ToPtrOp op, OpAdaptor adaptor,
328 ConversionPatternRewriter &rewriter)
const {
330 if (!isa<MemRefType>(op.getPtr().getType()))
331 return rewriter.notifyMatchFailure(op,
"Expected a memref input");
335 op, MemRefDescriptor(adaptor.getPtr()).alignedPtr(rewriter, op.getLoc()));
343LogicalResult TypeOffsetOpConversion::matchAndRewrite(
344 ptr::TypeOffsetOp op, OpAdaptor adaptor,
345 ConversionPatternRewriter &rewriter)
const {
347 Type type = getTypeConverter()->convertType(op.getElementType());
349 return rewriter.notifyMatchFailure(op,
"Couldn't convert the type");
352 Type rTy = getTypeConverter()->convertType(op.getResult().getType());
354 return rewriter.notifyMatchFailure(op,
"Couldn't convert the result type");
360 auto ptrTy = LLVM::LLVMPointerType::get(
getContext());
364 LLVM::GEPOp::create(rewriter, op.getLoc(), ptrTy, type,
365 LLVM::ZeroOp::create(rewriter, op.getLoc(), ptrTy),
366 ArrayRef<LLVM::GEPArg>({LLVM::GEPArg(1)}));
369 rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(op, rTy, offset.getRes());
379struct PtrToLLVMDialectInterface :
public ConvertToLLVMPatternInterface {
380 PtrToLLVMDialectInterface(Dialect *dialect)
381 : ConvertToLLVMPatternInterface(dialect) {}
383 void loadDependentDialects(MLIRContext *context)
const final {
384 context->loadDialect<LLVM::LLVMDialect>();
389 void populateConvertToLLVMConversionPatterns(
390 ConversionTarget &
target, LLVMTypeConverter &converter,
391 RewritePatternSet &patterns)
const final {
404 converter.addTypeAttributeConversion(
405 [&](PtrLikeTypeInterface type, ptr::GenericSpaceAttr memorySpace)
406 -> TypeConverter::AttributeConversionResult {
407 if (type.getMemorySpace() != memorySpace)
408 return TypeConverter::AttributeConversionResult::na();
409 return IntegerAttr::get(IntegerType::get(type.getContext(), 32), 0);
413 converter.addConversion([&](ptr::PtrType type) ->
Type {
414 std::optional<Attribute> maybeAttr =
415 converter.convertTypeAttribute(type, type.getMemorySpace());
417 maybeAttr ? dyn_cast_or_null<IntegerAttr>(*maybeAttr) : IntegerAttr();
420 return LLVM::LLVMPointerType::get(type.getContext(),
421 memSpace.getValue().getSExtValue());
425 converter.addConversion([&](ptr::PtrMetadataType type) ->
Type {
426 auto mTy = dyn_cast<MemRefType>(type.getType());
429 FailureOr<LLVM::LLVMStructType> res =
431 return failed(res) ?
Type() : res.value();
435 patterns.
add<FromPtrOpConversion, GetMetadataOpConversion, PtrAddOpConversion,
436 ToPtrOpConversion, TypeOffsetOpConversion>(converter);
441 dialect->addInterfaces<PtrToLLVMDialectInterface>();
static FailureOr< LLVM::LLVMStructType > createMemRefMetadataType(MemRefType type, const LLVMTypeConverter &typeConverter)
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type or failure if...
Type getIndexType() const
Gets the LLVM representation of the index type.
MLIRContext is the top-level object for a collection of MLIR operations.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Type getType() const
Return the type of this value.
void populatePtrToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Populate the convert to LLVM patterns for the ptr dialect.
void registerConvertPtrToLLVMInterface(DialectRegistry ®istry)
Register the convert to LLVM interface for the ptr dialect.
Include the generated interface declarations.