31 matchAndRewrite(ptr::FromPtrOp op, OpAdaptor adaptor,
32 ConversionPatternRewriter &rewriter)
const override;
38struct GetMetadataOpConversion
42 matchAndRewrite(ptr::GetMetadataOp op, OpAdaptor adaptor,
43 ConversionPatternRewriter &rewriter)
const override;
52 matchAndRewrite(ptr::PtrAddOp op, OpAdaptor adaptor,
53 ConversionPatternRewriter &rewriter)
const override;
62 matchAndRewrite(ptr::ToPtrOp op, OpAdaptor adaptor,
63 ConversionPatternRewriter &rewriter)
const override;
69struct TypeOffsetOpConversion
73 matchAndRewrite(ptr::TypeOffsetOp op, OpAdaptor adaptor,
74 ConversionPatternRewriter &rewriter)
const override;
83static FailureOr<LLVM::LLVMStructType>
89 if (failed(addressSpace))
93 auto ptrType = LLVM::LLVMPointerType::get(context, *addressSpace);
98 if (failed(type.getStridesAndOffset(strides, offset)))
114 elements.push_back(ptrType);
117 if (offset == ShapedType::kDynamic)
118 elements.push_back(indexType);
122 if (dim == ShapedType::kDynamic)
123 elements.push_back(indexType);
127 for (
int64_t stride : strides) {
128 if (stride == ShapedType::kDynamic)
129 elements.push_back(indexType);
131 return LLVM::LLVMStructType::getLiteral(context, elements);
138LogicalResult FromPtrOpConversion::matchAndRewrite(
139 ptr::FromPtrOp op, OpAdaptor adaptor,
140 ConversionPatternRewriter &rewriter)
const {
142 auto mTy = dyn_cast<MemRefType>(op.getResult().getType());
144 return rewriter.notifyMatchFailure(op,
"Expected memref result type");
146 if (!op.getMetadata() && op.getType().hasPtrMetadata()) {
147 return rewriter.notifyMatchFailure(
148 op,
"Can convert only memrefs with metadata");
152 Type descriptorTy = getTypeConverter()->convertType(mTy);
154 return rewriter.notifyMatchFailure(op,
"Failed to convert result type");
157 SmallVector<int64_t> strides;
159 if (
failed(mTy.getStridesAndOffset(strides, offset))) {
160 return rewriter.notifyMatchFailure(op,
161 "Failed to get the strides and offset");
163 ArrayRef<int64_t> shape = mTy.getShape();
166 Location loc = op.getLoc();
167 auto desc = MemRefDescriptor::poison(rewriter, loc, descriptorTy);
170 desc.setAllocatedPtr(
172 LLVM::ExtractValueOp::create(rewriter, loc, adaptor.getMetadata(), 0));
173 desc.setAlignedPtr(rewriter, loc, adaptor.getPtr());
176 unsigned fieldIdx = 1;
179 if (offset == ShapedType::kDynamic) {
180 Value offsetValue = LLVM::ExtractValueOp::create(
181 rewriter, loc, adaptor.getMetadata(), fieldIdx++);
182 desc.setOffset(rewriter, loc, offsetValue);
184 desc.setConstantOffset(rewriter, loc, offset);
188 for (
auto [i, dim] : llvm::enumerate(shape)) {
189 if (dim == ShapedType::kDynamic) {
190 Value sizeValue = LLVM::ExtractValueOp::create(
191 rewriter, loc, adaptor.getMetadata(), fieldIdx++);
192 desc.setSize(rewriter, loc, i, sizeValue);
194 desc.setConstantSize(rewriter, loc, i, dim);
199 for (
auto [i, stride] : llvm::enumerate(strides)) {
200 if (stride == ShapedType::kDynamic) {
201 Value strideValue = LLVM::ExtractValueOp::create(
202 rewriter, loc, adaptor.getMetadata(), fieldIdx++);
203 desc.setStride(rewriter, loc, i, strideValue);
205 desc.setConstantStride(rewriter, loc, i, stride);
209 rewriter.replaceOp(op,
static_cast<Value
>(desc));
217LogicalResult GetMetadataOpConversion::matchAndRewrite(
218 ptr::GetMetadataOp op, OpAdaptor adaptor,
219 ConversionPatternRewriter &rewriter)
const {
220 auto mTy = dyn_cast<MemRefType>(op.getPtr().getType());
222 return rewriter.notifyMatchFailure(op,
"Only memref metadata is supported");
225 FailureOr<LLVM::LLVMStructType> mdTy =
228 return rewriter.notifyMatchFailure(op,
229 "Failed to create the metadata type");
233 MemRefDescriptor descriptor(adaptor.getPtr());
236 SmallVector<int64_t> strides;
238 if (
failed(mTy.getStridesAndOffset(strides, offset))) {
239 return rewriter.notifyMatchFailure(op,
240 "Failed to get the strides and offset");
242 ArrayRef<int64_t> shape = mTy.getShape();
245 Location loc = op.getLoc();
246 Value sV = LLVM::UndefOp::create(rewriter, loc, *mdTy);
249 SmallVector<int64_t> pos{0};
250 sV = LLVM::InsertValueOp::create(rewriter, loc, sV,
251 descriptor.allocatedPtr(rewriter, loc), pos);
254 unsigned fieldIdx = 1;
257 if (offset == ShapedType::kDynamic) {
258 sV = LLVM::InsertValueOp::create(
259 rewriter, loc, sV, descriptor.offset(rewriter, loc), fieldIdx++);
263 for (
auto [i, dim] : llvm::enumerate(shape)) {
264 if (dim != ShapedType::kDynamic)
266 sV = LLVM::InsertValueOp::create(
267 rewriter, loc, sV, descriptor.size(rewriter, loc, i), fieldIdx++);
271 for (
auto [i, stride] : llvm::enumerate(strides)) {
272 if (stride != ShapedType::kDynamic)
274 sV = LLVM::InsertValueOp::create(
275 rewriter, loc, sV, descriptor.stride(rewriter, loc, i), fieldIdx++);
277 rewriter.replaceOp(op, sV);
286PtrAddOpConversion::matchAndRewrite(ptr::PtrAddOp op, OpAdaptor adaptor,
287 ConversionPatternRewriter &rewriter)
const {
289 Value base = adaptor.getBase();
290 if (!isa<LLVM::LLVMPointerType>(base.
getType()))
291 return rewriter.notifyMatchFailure(op,
"Incompatible pointer type");
294 Value offset = adaptor.getOffset();
297 Type elementType = IntegerType::get(rewriter.getContext(), 8);
300 LLVM::GEPNoWrapFlags flags;
301 switch (op.getFlags()) {
302 case ptr::PtrAddFlags::none:
303 flags = LLVM::GEPNoWrapFlags::none;
305 case ptr::PtrAddFlags::nusw:
306 flags = LLVM::GEPNoWrapFlags::nusw;
308 case ptr::PtrAddFlags::nuw:
309 flags = LLVM::GEPNoWrapFlags::nuw;
311 case ptr::PtrAddFlags::inbounds:
312 flags = LLVM::GEPNoWrapFlags::inbounds;
317 rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, base.
getType(), elementType,
327ToPtrOpConversion::matchAndRewrite(ptr::ToPtrOp op, OpAdaptor adaptor,
328 ConversionPatternRewriter &rewriter)
const {
330 if (!isa<MemRefType>(op.getPtr().getType()))
331 return rewriter.notifyMatchFailure(op,
"Expected a memref input");
335 op, MemRefDescriptor(adaptor.getPtr()).alignedPtr(rewriter, op.getLoc()));
343LogicalResult TypeOffsetOpConversion::matchAndRewrite(
344 ptr::TypeOffsetOp op, OpAdaptor adaptor,
345 ConversionPatternRewriter &rewriter)
const {
347 Type type = getTypeConverter()->convertType(op.getElementType());
349 return rewriter.notifyMatchFailure(op,
"Couldn't convert the type");
352 Type rTy = getTypeConverter()->convertType(op.getResult().getType());
354 return rewriter.notifyMatchFailure(op,
"Couldn't convert the result type");
360 auto ptrTy = LLVM::LLVMPointerType::get(
getContext());
364 LLVM::GEPOp::create(rewriter, op.getLoc(), ptrTy, type,
365 LLVM::ZeroOp::create(rewriter, op.getLoc(), ptrTy),
366 ArrayRef<LLVM::GEPArg>({LLVM::GEPArg(1)}));
369 rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(op, rTy, offset.getRes());
379struct PtrToLLVMDialectInterface :
public ConvertToLLVMPatternInterface {
381 void loadDependentDialects(MLIRContext *context)
const final {
382 context->loadDialect<LLVM::LLVMDialect>();
387 void populateConvertToLLVMConversionPatterns(
388 ConversionTarget &
target, LLVMTypeConverter &converter,
389 RewritePatternSet &
patterns)
const final {
402 converter.addTypeAttributeConversion(
403 [&](PtrLikeTypeInterface type, ptr::GenericSpaceAttr memorySpace)
404 -> TypeConverter::AttributeConversionResult {
405 if (type.getMemorySpace() != memorySpace)
406 return TypeConverter::AttributeConversionResult::na();
407 return IntegerAttr::get(IntegerType::get(type.getContext(), 32), 0);
411 converter.addConversion([&](ptr::PtrType type) ->
Type {
412 std::optional<Attribute> maybeAttr =
413 converter.convertTypeAttribute(type, type.getMemorySpace());
415 maybeAttr ? dyn_cast_or_null<IntegerAttr>(*maybeAttr) : IntegerAttr();
418 return LLVM::LLVMPointerType::get(type.getContext(),
419 memSpace.getValue().getSExtValue());
423 converter.addConversion([&](ptr::PtrMetadataType type) ->
Type {
424 auto mTy = dyn_cast<MemRefType>(type.getType());
427 FailureOr<LLVM::LLVMStructType> res =
429 return failed(res) ?
Type() : res.value();
433 patterns.add<FromPtrOpConversion, GetMetadataOpConversion, PtrAddOpConversion,
434 ToPtrOpConversion, TypeOffsetOpConversion>(converter);
439 dialect->addInterfaces<PtrToLLVMDialectInterface>();
static FailureOr< LLVM::LLVMStructType > createMemRefMetadataType(MemRefType type, const LLVMTypeConverter &typeConverter)
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type `type`, or failure if the memory space cannot be converted to an integer.
Type getIndexType() const
Gets the LLVM representation of the index type.
MLIRContext is the top-level object for a collection of MLIR operations.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Type getType() const
Return the type of this value.
void populatePtrToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Populate the convert to LLVM patterns for the ptr dialect.
void registerConvertPtrToLLVMInterface(DialectRegistry &registry)
Register the convert to LLVM interface for the ptr dialect.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns