20 #include <type_traits>
31 matchAndRewrite(ptr::FromPtrOp op, OpAdaptor adaptor,
38 struct GetMetadataOpConversion
42 matchAndRewrite(ptr::GetMetadataOp op, OpAdaptor adaptor,
52 matchAndRewrite(ptr::PtrAddOp op, OpAdaptor adaptor,
62 matchAndRewrite(ptr::ToPtrOp op, OpAdaptor adaptor,
69 struct TypeOffsetOpConversion
73 matchAndRewrite(ptr::TypeOffsetOp op, OpAdaptor adaptor,
83 static FailureOr<LLVM::LLVMStructType>
98 if (
failed(type.getStridesAndOffset(strides, offset)))
114 elements.push_back(ptrType);
117 if (offset == ShapedType::kDynamic)
118 elements.push_back(indexType);
121 for (int64_t dim : shape) {
122 if (dim == ShapedType::kDynamic)
123 elements.push_back(indexType);
127 for (int64_t stride : strides) {
128 if (stride == ShapedType::kDynamic)
129 elements.push_back(indexType);
131 return LLVM::LLVMStructType::getLiteral(context, elements);
138 LogicalResult FromPtrOpConversion::matchAndRewrite(
139 ptr::FromPtrOp op, OpAdaptor adaptor,
142 auto mTy = dyn_cast<MemRefType>(op.getResult().getType());
146 if (!op.getMetadata() && op.getType().hasPtrMetadata()) {
148 op,
"Can convert only memrefs with metadata");
152 Type descriptorTy = getTypeConverter()->convertType(mTy);
159 if (
failed(mTy.getStridesAndOffset(strides, offset))) {
161 "Failed to get the strides and offset");
170 desc.setAllocatedPtr(
172 rewriter.
create<LLVM::ExtractValueOp>(loc, adaptor.getMetadata(), 0));
173 desc.setAlignedPtr(rewriter, loc, adaptor.getPtr());
176 unsigned fieldIdx = 1;
179 if (offset == ShapedType::kDynamic) {
180 Value offsetValue = rewriter.
create<LLVM::ExtractValueOp>(
181 loc, adaptor.getMetadata(), fieldIdx++);
182 desc.setOffset(rewriter, loc, offsetValue);
184 desc.setConstantOffset(rewriter, loc, offset);
189 if (dim == ShapedType::kDynamic) {
190 Value sizeValue = rewriter.
create<LLVM::ExtractValueOp>(
191 loc, adaptor.getMetadata(), fieldIdx++);
192 desc.setSize(rewriter, loc, i, sizeValue);
194 desc.setConstantSize(rewriter, loc, i, dim);
200 if (stride == ShapedType::kDynamic) {
201 Value strideValue = rewriter.
create<LLVM::ExtractValueOp>(
202 loc, adaptor.getMetadata(), fieldIdx++);
203 desc.setStride(rewriter, loc, i, strideValue);
205 desc.setConstantStride(rewriter, loc, i, stride);
217 LogicalResult GetMetadataOpConversion::matchAndRewrite(
218 ptr::GetMetadataOp op, OpAdaptor adaptor,
220 auto mTy = dyn_cast<MemRefType>(op.getPtr().getType());
225 FailureOr<LLVM::LLVMStructType> mdTy =
229 "Failed to create the metadata type");
238 if (
failed(mTy.getStridesAndOffset(strides, offset))) {
240 "Failed to get the strides and offset");
246 Value sV = rewriter.
create<LLVM::UndefOp>(loc, *mdTy);
249 sV = rewriter.
create<LLVM::InsertValueOp>(
250 loc, sV, descriptor.allocatedPtr(rewriter, loc), 0);
253 unsigned fieldIdx = 1;
256 if (offset == ShapedType::kDynamic) {
257 sV = rewriter.
create<LLVM::InsertValueOp>(
258 loc, sV, descriptor.offset(rewriter, loc), fieldIdx++);
263 if (dim != ShapedType::kDynamic)
265 sV = rewriter.
create<LLVM::InsertValueOp>(
266 loc, sV, descriptor.size(rewriter, loc, i), fieldIdx++);
271 if (stride != ShapedType::kDynamic)
273 sV = rewriter.
create<LLVM::InsertValueOp>(
274 loc, sV, descriptor.stride(rewriter, loc, i), fieldIdx++);
285 PtrAddOpConversion::matchAndRewrite(ptr::PtrAddOp op, OpAdaptor adaptor,
288 Value base = adaptor.getBase();
289 if (!isa<LLVM::LLVMPointerType>(base.
getType()))
293 Value offset = adaptor.getOffset();
299 LLVM::GEPNoWrapFlags flags;
300 switch (op.getFlags()) {
301 case ptr::PtrAddFlags::none:
302 flags = LLVM::GEPNoWrapFlags::none;
304 case ptr::PtrAddFlags::nusw:
305 flags = LLVM::GEPNoWrapFlags::nusw;
307 case ptr::PtrAddFlags::nuw:
308 flags = LLVM::GEPNoWrapFlags::nuw;
310 case ptr::PtrAddFlags::inbounds:
311 flags = LLVM::GEPNoWrapFlags::inbounds;
326 ToPtrOpConversion::matchAndRewrite(ptr::ToPtrOp op, OpAdaptor adaptor,
329 if (!isa<MemRefType>(op.getPtr().getType()))
342 LogicalResult TypeOffsetOpConversion::matchAndRewrite(
343 ptr::TypeOffsetOp op, OpAdaptor adaptor,
346 Type type = getTypeConverter()->convertType(op.getElementType());
351 Type rTy = getTypeConverter()->convertType(op.getResult().getType());
363 LLVM::GEPOp::create(rewriter, op.getLoc(), ptrTy, type,
364 LLVM::ZeroOp::create(rewriter, op.getLoc(), ptrTy),
380 void loadDependentDialects(
MLIRContext *context)
const final {
381 context->loadDialect<LLVM::LLVMDialect>();
386 void populateConvertToLLVMConversionPatterns(
402 [&](PtrLikeTypeInterface type, ptr::GenericSpaceAttr memorySpace)
404 if (type.getMemorySpace() != memorySpace)
411 std::optional<Attribute> maybeAttr =
414 maybeAttr ? dyn_cast_or_null<IntegerAttr>(*maybeAttr) : IntegerAttr();
418 memSpace.getValue().getSExtValue());
423 auto mTy = dyn_cast<MemRefType>(type.getType());
426 FailureOr<LLVM::LLVMStructType> res =
432 patterns.add<FromPtrOpConversion, GetMetadataOpConversion, PtrAddOpConversion,
433 ToPtrOpConversion, TypeOffsetOpConversion>(converter);
438 dialect->addInterfaces<PtrToLLVMDialectInterface>();
static MLIRContext * getContext(OpFoldResult val)
static FailureOr< LLVM::LLVMStructType > createMemRefMetadataType(MemRefType type, const LLVMTypeConverter &typeConverter)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type or failure if...
Type getIndexType() const
Gets the LLVM representation of the index type.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Value alignedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the aligned pointer from the descriptor.
static MemRefDescriptor poison(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating a poison value of the descriptor type.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
The general result of a type attribute conversion callback, allowing for early termination.
static AttributeConversionResult na()
std::optional< Attribute > convertTypeAttribute(Type type, Attribute attr) const
Convert an attribute present attr from within the type type using the registered conversion functions...
void addConversion(FnT &&callback)
Register a conversion function.
void addTypeAttributeConversion(FnT &&callback)
Register a conversion function for attributes within types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void populatePtrToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Populate the convert to LLVM patterns for the ptr dialect.
void registerConvertPtrToLLVMInterface(DialectRegistry ®istry)
Register the convert to LLVM interface for the ptr dialect.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...