MLIR 22.0.0git
PtrToLLVM.cpp
Go to the documentation of this file.
1//===- PtrToLLVM.cpp - Ptr to LLVM dialect conversion ---------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
20#include <type_traits>
21
22using namespace mlir;
23
24namespace {
25//===----------------------------------------------------------------------===//
26// FromPtrOpConversion
27//===----------------------------------------------------------------------===//
28struct FromPtrOpConversion : public ConvertOpToLLVMPattern<ptr::FromPtrOp> {
30 LogicalResult
31 matchAndRewrite(ptr::FromPtrOp op, OpAdaptor adaptor,
32 ConversionPatternRewriter &rewriter) const override;
33};
34
35//===----------------------------------------------------------------------===//
36// GetMetadataOpConversion
37//===----------------------------------------------------------------------===//
38struct GetMetadataOpConversion
39 : public ConvertOpToLLVMPattern<ptr::GetMetadataOp> {
41 LogicalResult
42 matchAndRewrite(ptr::GetMetadataOp op, OpAdaptor adaptor,
43 ConversionPatternRewriter &rewriter) const override;
44};
45
46//===----------------------------------------------------------------------===//
47// PtrAddOpConversion
48//===----------------------------------------------------------------------===//
49struct PtrAddOpConversion : public ConvertOpToLLVMPattern<ptr::PtrAddOp> {
51 LogicalResult
52 matchAndRewrite(ptr::PtrAddOp op, OpAdaptor adaptor,
53 ConversionPatternRewriter &rewriter) const override;
54};
55
56//===----------------------------------------------------------------------===//
57// ToPtrOpConversion
58//===----------------------------------------------------------------------===//
59struct ToPtrOpConversion : public ConvertOpToLLVMPattern<ptr::ToPtrOp> {
61 LogicalResult
62 matchAndRewrite(ptr::ToPtrOp op, OpAdaptor adaptor,
63 ConversionPatternRewriter &rewriter) const override;
64};
65
66//===----------------------------------------------------------------------===//
67// TypeOffsetOpConversion
68//===----------------------------------------------------------------------===//
69struct TypeOffsetOpConversion
70 : public ConvertOpToLLVMPattern<ptr::TypeOffsetOp> {
72 LogicalResult
73 matchAndRewrite(ptr::TypeOffsetOp op, OpAdaptor adaptor,
74 ConversionPatternRewriter &rewriter) const override;
75};
76} // namespace
77
78//===----------------------------------------------------------------------===//
79// Internal functions
80//===----------------------------------------------------------------------===//
81
82// Function to create an LLVM struct type representing a memref metadata.
83static FailureOr<LLVM::LLVMStructType>
85 const LLVMTypeConverter &typeConverter) {
86 MLIRContext *context = type.getContext();
87 // Get the address space.
88 FailureOr<unsigned> addressSpace = typeConverter.getMemRefAddressSpace(type);
89 if (failed(addressSpace))
90 return failure();
91
92 // Get pointer type (using address space 0 by default)
93 auto ptrType = LLVM::LLVMPointerType::get(context, *addressSpace);
94
95 // Get the strides offsets and shape.
97 int64_t offset;
98 if (failed(type.getStridesAndOffset(strides, offset)))
99 return failure();
100 ArrayRef<int64_t> shape = type.getShape();
101
102 // Use index type from the type converter for the descriptor elements
103 Type indexType = typeConverter.getIndexType();
104
105 // For a ranked memref, the descriptor contains:
106 // 1. The pointer to the allocated data
107 // 2. The pointer to the aligned data
108 // 3. The dynamic offset?
109 // 4. The dynamic sizes?
110 // 5. The dynamic strides?
111 SmallVector<Type, 5> elements;
112
113 // Allocated pointer.
114 elements.push_back(ptrType);
115
116 // Potentially add the dynamic offset.
117 if (offset == ShapedType::kDynamic)
118 elements.push_back(indexType);
119
120 // Potentially add the dynamic sizes.
121 for (int64_t dim : shape) {
122 if (dim == ShapedType::kDynamic)
123 elements.push_back(indexType);
124 }
125
126 // Potentially add the dynamic strides.
127 for (int64_t stride : strides) {
128 if (stride == ShapedType::kDynamic)
129 elements.push_back(indexType);
130 }
131 return LLVM::LLVMStructType::getLiteral(context, elements);
132}
133
134//===----------------------------------------------------------------------===//
135// FromPtrOpConversion
136//===----------------------------------------------------------------------===//
137
138LogicalResult FromPtrOpConversion::matchAndRewrite(
139 ptr::FromPtrOp op, OpAdaptor adaptor,
140 ConversionPatternRewriter &rewriter) const {
141 // Get the target memref type
142 auto mTy = dyn_cast<MemRefType>(op.getResult().getType());
143 if (!mTy)
144 return rewriter.notifyMatchFailure(op, "Expected memref result type");
145
146 if (!op.getMetadata() && op.getType().hasPtrMetadata()) {
147 return rewriter.notifyMatchFailure(
148 op, "Can convert only memrefs with metadata");
149 }
150
151 // Convert the result type
152 Type descriptorTy = getTypeConverter()->convertType(mTy);
153 if (!descriptorTy)
154 return rewriter.notifyMatchFailure(op, "Failed to convert result type");
155
156 // Get the strides, offsets and shape.
157 SmallVector<int64_t> strides;
158 int64_t offset;
159 if (failed(mTy.getStridesAndOffset(strides, offset))) {
160 return rewriter.notifyMatchFailure(op,
161 "Failed to get the strides and offset");
162 }
163 ArrayRef<int64_t> shape = mTy.getShape();
164
165 // Create a new memref descriptor
166 Location loc = op.getLoc();
167 auto desc = MemRefDescriptor::poison(rewriter, loc, descriptorTy);
168
169 // Set the allocated and aligned pointers.
170 desc.setAllocatedPtr(
171 rewriter, loc,
172 LLVM::ExtractValueOp::create(rewriter, loc, adaptor.getMetadata(), 0));
173 desc.setAlignedPtr(rewriter, loc, adaptor.getPtr());
174
175 // Extract metadata from the passed struct.
176 unsigned fieldIdx = 1;
177
178 // Set dynamic offset if needed.
179 if (offset == ShapedType::kDynamic) {
180 Value offsetValue = LLVM::ExtractValueOp::create(
181 rewriter, loc, adaptor.getMetadata(), fieldIdx++);
182 desc.setOffset(rewriter, loc, offsetValue);
183 } else {
184 desc.setConstantOffset(rewriter, loc, offset);
185 }
186
187 // Set dynamic sizes if needed.
188 for (auto [i, dim] : llvm::enumerate(shape)) {
189 if (dim == ShapedType::kDynamic) {
190 Value sizeValue = LLVM::ExtractValueOp::create(
191 rewriter, loc, adaptor.getMetadata(), fieldIdx++);
192 desc.setSize(rewriter, loc, i, sizeValue);
193 } else {
194 desc.setConstantSize(rewriter, loc, i, dim);
195 }
196 }
197
198 // Set dynamic strides if needed.
199 for (auto [i, stride] : llvm::enumerate(strides)) {
200 if (stride == ShapedType::kDynamic) {
201 Value strideValue = LLVM::ExtractValueOp::create(
202 rewriter, loc, adaptor.getMetadata(), fieldIdx++);
203 desc.setStride(rewriter, loc, i, strideValue);
204 } else {
205 desc.setConstantStride(rewriter, loc, i, stride);
206 }
207 }
208
209 rewriter.replaceOp(op, static_cast<Value>(desc));
210 return success();
211}
212
213//===----------------------------------------------------------------------===//
214// GetMetadataOpConversion
215//===----------------------------------------------------------------------===//
216
217LogicalResult GetMetadataOpConversion::matchAndRewrite(
218 ptr::GetMetadataOp op, OpAdaptor adaptor,
219 ConversionPatternRewriter &rewriter) const {
220 auto mTy = dyn_cast<MemRefType>(op.getPtr().getType());
221 if (!mTy)
222 return rewriter.notifyMatchFailure(op, "Only memref metadata is supported");
223
224 // Get the metadata type.
225 FailureOr<LLVM::LLVMStructType> mdTy =
226 createMemRefMetadataType(mTy, *getTypeConverter());
227 if (failed(mdTy)) {
228 return rewriter.notifyMatchFailure(op,
229 "Failed to create the metadata type");
230 }
231
232 // Get the memref descriptor.
233 MemRefDescriptor descriptor(adaptor.getPtr());
234
235 // Get the strides offsets and shape.
236 SmallVector<int64_t> strides;
237 int64_t offset;
238 if (failed(mTy.getStridesAndOffset(strides, offset))) {
239 return rewriter.notifyMatchFailure(op,
240 "Failed to get the strides and offset");
241 }
242 ArrayRef<int64_t> shape = mTy.getShape();
243
244 // Create a new LLVM struct to hold the metadata
245 Location loc = op.getLoc();
246 Value sV = LLVM::UndefOp::create(rewriter, loc, *mdTy);
247
248 // First element is the allocated pointer.
249 SmallVector<int64_t> pos{0};
250 sV = LLVM::InsertValueOp::create(rewriter, loc, sV,
251 descriptor.allocatedPtr(rewriter, loc), pos);
252
253 // Track the current field index.
254 unsigned fieldIdx = 1;
255
256 // Add dynamic offset if needed.
257 if (offset == ShapedType::kDynamic) {
258 sV = LLVM::InsertValueOp::create(
259 rewriter, loc, sV, descriptor.offset(rewriter, loc), fieldIdx++);
260 }
261
262 // Add dynamic sizes if needed.
263 for (auto [i, dim] : llvm::enumerate(shape)) {
264 if (dim != ShapedType::kDynamic)
265 continue;
266 sV = LLVM::InsertValueOp::create(
267 rewriter, loc, sV, descriptor.size(rewriter, loc, i), fieldIdx++);
268 }
269
270 // Add dynamic strides if needed
271 for (auto [i, stride] : llvm::enumerate(strides)) {
272 if (stride != ShapedType::kDynamic)
273 continue;
274 sV = LLVM::InsertValueOp::create(
275 rewriter, loc, sV, descriptor.stride(rewriter, loc, i), fieldIdx++);
276 }
277 rewriter.replaceOp(op, sV);
278 return success();
279}
280
281//===----------------------------------------------------------------------===//
282// PtrAddOpConversion
283//===----------------------------------------------------------------------===//
284
285LogicalResult
286PtrAddOpConversion::matchAndRewrite(ptr::PtrAddOp op, OpAdaptor adaptor,
287 ConversionPatternRewriter &rewriter) const {
288 // Get and check the base.
289 Value base = adaptor.getBase();
290 if (!isa<LLVM::LLVMPointerType>(base.getType()))
291 return rewriter.notifyMatchFailure(op, "Incompatible pointer type");
292
293 // Get the offset.
294 Value offset = adaptor.getOffset();
295
296 // Ptr assumes the offset is in bytes.
297 Type elementType = IntegerType::get(rewriter.getContext(), 8);
298
299 // Convert the `ptradd` flags.
300 LLVM::GEPNoWrapFlags flags;
301 switch (op.getFlags()) {
302 case ptr::PtrAddFlags::none:
303 flags = LLVM::GEPNoWrapFlags::none;
304 break;
305 case ptr::PtrAddFlags::nusw:
306 flags = LLVM::GEPNoWrapFlags::nusw;
307 break;
308 case ptr::PtrAddFlags::nuw:
309 flags = LLVM::GEPNoWrapFlags::nuw;
310 break;
311 case ptr::PtrAddFlags::inbounds:
312 flags = LLVM::GEPNoWrapFlags::inbounds;
313 break;
314 }
315
316 // Create the GEP operation with appropriate arguments
317 rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, base.getType(), elementType,
318 base, ValueRange{offset}, flags);
319 return success();
320}
321
322//===----------------------------------------------------------------------===//
323// ToPtrOpConversion
324//===----------------------------------------------------------------------===//
325
326LogicalResult
327ToPtrOpConversion::matchAndRewrite(ptr::ToPtrOp op, OpAdaptor adaptor,
328 ConversionPatternRewriter &rewriter) const {
329 // Bail if it's not a memref.
330 if (!isa<MemRefType>(op.getPtr().getType()))
331 return rewriter.notifyMatchFailure(op, "Expected a memref input");
332
333 // Extract the aligned pointer from the memref descriptor.
334 rewriter.replaceOp(
335 op, MemRefDescriptor(adaptor.getPtr()).alignedPtr(rewriter, op.getLoc()));
336 return success();
337}
338
339//===----------------------------------------------------------------------===//
340// TypeOffsetOpConversion
341//===----------------------------------------------------------------------===//
342
343LogicalResult TypeOffsetOpConversion::matchAndRewrite(
344 ptr::TypeOffsetOp op, OpAdaptor adaptor,
345 ConversionPatternRewriter &rewriter) const {
346 // Convert the type attribute.
347 Type type = getTypeConverter()->convertType(op.getElementType());
348 if (!type)
349 return rewriter.notifyMatchFailure(op, "Couldn't convert the type");
350
351 // Convert the result type.
352 Type rTy = getTypeConverter()->convertType(op.getResult().getType());
353 if (!rTy)
354 return rewriter.notifyMatchFailure(op, "Couldn't convert the result type");
355
356 // TODO: Use MLIR's data layout. We don't use it because overall support is
357 // still flaky.
358
359 // Create an LLVM pointer type for the GEP operation.
360 auto ptrTy = LLVM::LLVMPointerType::get(getContext());
361
362 // Create a GEP operation to compute the offset of the type.
363 auto offset =
364 LLVM::GEPOp::create(rewriter, op.getLoc(), ptrTy, type,
365 LLVM::ZeroOp::create(rewriter, op.getLoc(), ptrTy),
366 ArrayRef<LLVM::GEPArg>({LLVM::GEPArg(1)}));
367
368 // Replace the original op with a PtrToIntOp using the computed offset.
369 rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(op, rTy, offset.getRes());
370 return success();
371}
372
373//===----------------------------------------------------------------------===//
374// ConvertToLLVMPatternInterface implementation
375//===----------------------------------------------------------------------===//
376
377namespace {
378/// Implement the interface to convert Ptr to LLVM.
379struct PtrToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
381 void loadDependentDialects(MLIRContext *context) const final {
382 context->loadDialect<LLVM::LLVMDialect>();
383 }
384
385 /// Hook for derived dialect interface to provide conversion patterns
386 /// and mark dialect legal for the conversion target.
387 void populateConvertToLLVMConversionPatterns(
388 ConversionTarget &target, LLVMTypeConverter &converter,
389 RewritePatternSet &patterns) const final {
391 }
392};
393} // namespace
394
395//===----------------------------------------------------------------------===//
396// API
397//===----------------------------------------------------------------------===//
398
401 // Add address space conversions.
402 converter.addTypeAttributeConversion(
403 [&](PtrLikeTypeInterface type, ptr::GenericSpaceAttr memorySpace)
404 -> TypeConverter::AttributeConversionResult {
405 if (type.getMemorySpace() != memorySpace)
406 return TypeConverter::AttributeConversionResult::na();
407 return IntegerAttr::get(IntegerType::get(type.getContext(), 32), 0);
408 });
409
410 // Add type conversions.
411 converter.addConversion([&](ptr::PtrType type) -> Type {
412 std::optional<Attribute> maybeAttr =
413 converter.convertTypeAttribute(type, type.getMemorySpace());
414 auto memSpace =
415 maybeAttr ? dyn_cast_or_null<IntegerAttr>(*maybeAttr) : IntegerAttr();
416 if (!memSpace)
417 return {};
418 return LLVM::LLVMPointerType::get(type.getContext(),
419 memSpace.getValue().getSExtValue());
420 });
421
422 // Convert ptr metadata of memref type.
423 converter.addConversion([&](ptr::PtrMetadataType type) -> Type {
424 auto mTy = dyn_cast<MemRefType>(type.getType());
425 if (!mTy)
426 return {};
427 FailureOr<LLVM::LLVMStructType> res =
428 createMemRefMetadataType(mTy, converter);
429 return failed(res) ? Type() : res.value();
430 });
431
432 // Add conversion patterns.
433 patterns.add<FromPtrOpConversion, GetMetadataOpConversion, PtrAddOpConversion,
434 ToPtrOpConversion, TypeOffsetOpConversion>(converter);
435}
436
438 registry.addExtension(+[](MLIRContext *ctx, ptr::PtrDialect *dialect) {
439 dialect->addInterfaces<PtrToLLVMDialectInterface>();
440 });
441}
return success()
b getContext())
static FailureOr< LLVM::LLVMStructType > createMemRefMetadataType(MemRefType type, const LLVMTypeConverter &typeConverter)
Definition PtrToLLVM.cpp:84
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition Pattern.h:207
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition Pattern.h:213
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type or failure if...
Type getIndexType() const
Gets the LLVM representation of the index type.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
Type getType() const
Return the type of this value.
Definition Value.h:105
void populatePtrToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Populate the convert to LLVM patterns for the ptr dialect.
void registerConvertPtrToLLVMInterface(DialectRegistry &registry)
Register the convert to LLVM interface for the ptr dialect.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns