19 #include "llvm/ADT/TypeSwitch.h"
20 #include "llvm/IR/IRBuilder.h"
21 #include "llvm/IR/Instructions.h"
22 #include "llvm/IR/Type.h"
23 #include "llvm/IR/Value.h"
31 static llvm::AtomicOrdering
34 case ptr::AtomicOrdering::not_atomic:
35 return llvm::AtomicOrdering::NotAtomic;
36 case ptr::AtomicOrdering::unordered:
37 return llvm::AtomicOrdering::Unordered;
38 case ptr::AtomicOrdering::monotonic:
39 return llvm::AtomicOrdering::Monotonic;
40 case ptr::AtomicOrdering::acquire:
41 return llvm::AtomicOrdering::Acquire;
42 case ptr::AtomicOrdering::release:
43 return llvm::AtomicOrdering::Release;
44 case ptr::AtomicOrdering::acq_rel:
45 return llvm::AtomicOrdering::AcquireRelease;
46 case ptr::AtomicOrdering::seq_cst:
47 return llvm::AtomicOrdering::SequentiallyConsistent;
49 llvm_unreachable(
"Unknown atomic ordering");
54 convertPtrAddOp(PtrAddOp ptrAddOp, llvm::IRBuilderBase &builder,
55 LLVM::ModuleTranslation &moduleTranslation) {
56 llvm::Value *basePtr = moduleTranslation.lookupValue(ptrAddOp.getBase());
57 llvm::Value *offset = moduleTranslation.lookupValue(ptrAddOp.getOffset());
59 if (!basePtr || !offset)
60 return ptrAddOp.emitError(
"Failed to lookup operands");
63 llvm::GEPNoWrapFlags gepFlags;
64 switch (ptrAddOp.getFlags()) {
65 case ptr::PtrAddFlags::none:
67 case ptr::PtrAddFlags::nusw:
68 gepFlags = llvm::GEPNoWrapFlags::noUnsignedSignedWrap();
70 case ptr::PtrAddFlags::nuw:
71 gepFlags = llvm::GEPNoWrapFlags::noUnsignedWrap();
73 case ptr::PtrAddFlags::inbounds:
74 gepFlags = llvm::GEPNoWrapFlags::inBounds();
80 builder.CreateGEP(builder.getInt8Ty(), basePtr, {offset},
"", gepFlags);
82 moduleTranslation.mapValue(ptrAddOp.getResult(), gep);
87 static LogicalResult convertLoadOp(LoadOp loadOp, llvm::IRBuilderBase &builder,
88 LLVM::ModuleTranslation &moduleTranslation) {
89 llvm::Value *ptr = moduleTranslation.lookupValue(loadOp.getPtr());
91 return loadOp.emitError(
"Failed to lookup pointer operand");
94 llvm::Type *resultType =
95 moduleTranslation.convertType(loadOp.getValue().getType());
97 return loadOp.emitError(
"Failed to convert result type");
100 llvm::MaybeAlign alignment(loadOp.getAlignment().value_or(0));
101 llvm::LoadInst *loadInst = builder.CreateAlignedLoad(
102 resultType, ptr, alignment, loadOp.getVolatile_());
107 if (loadOp.getSyncscope().has_value()) {
108 llvm::LLVMContext &ctx = builder.getContext();
110 ctx.getOrInsertSyncScopeID(loadOp.getSyncscope().value());
111 loadInst->setSyncScopeID(syncScope);
115 if (loadOp.getNontemporal()) {
116 llvm::MDNode *nontemporalMD =
119 loadInst->setMetadata(llvm::LLVMContext::MD_nontemporal, nontemporalMD);
122 if (loadOp.getInvariant()) {
124 loadInst->setMetadata(llvm::LLVMContext::MD_invariant_load, invariantMD);
127 if (loadOp.getInvariantGroup()) {
128 llvm::MDNode *invariantGroupMD =
130 loadInst->setMetadata(llvm::LLVMContext::MD_invariant_group,
134 moduleTranslation.mapValue(loadOp.getResult(), loadInst);
140 convertStoreOp(StoreOp storeOp, llvm::IRBuilderBase &builder,
141 LLVM::ModuleTranslation &moduleTranslation) {
142 llvm::Value *value = moduleTranslation.lookupValue(storeOp.getValue());
143 llvm::Value *ptr = moduleTranslation.lookupValue(storeOp.getPtr());
146 return storeOp.emitError(
"Failed to lookup operands");
149 llvm::MaybeAlign alignment(storeOp.getAlignment().value_or(0));
150 llvm::StoreInst *storeInst =
151 builder.CreateAlignedStore(value, ptr, alignment, storeOp.getVolatile_());
156 if (storeOp.getSyncscope().has_value()) {
157 llvm::LLVMContext &ctx = builder.getContext();
159 ctx.getOrInsertSyncScopeID(storeOp.getSyncscope().value());
160 storeInst->setSyncScopeID(syncScope);
164 if (storeOp.getNontemporal()) {
165 llvm::MDNode *nontemporalMD =
168 storeInst->setMetadata(llvm::LLVMContext::MD_nontemporal, nontemporalMD);
171 if (storeOp.getInvariantGroup()) {
172 llvm::MDNode *invariantGroupMD =
174 storeInst->setMetadata(llvm::LLVMContext::MD_invariant_group,
183 convertTypeOffsetOp(TypeOffsetOp typeOffsetOp, llvm::IRBuilderBase &builder,
184 LLVM::ModuleTranslation &moduleTranslation) {
186 llvm::Type *elementType =
187 moduleTranslation.convertType(typeOffsetOp.getElementType());
189 return typeOffsetOp.emitError(
"Failed to convert the element type");
192 llvm::Type *resultType =
193 moduleTranslation.convertType(typeOffsetOp.getResult().getType());
195 return typeOffsetOp.emitError(
"Failed to convert the result type");
198 llvm::Value *nullPtr = llvm::Constant::getNullValue(builder.getPtrTy(0));
199 llvm::Value *offsetPtr =
200 builder.CreateGEP(elementType, nullPtr, {builder.getInt32(1)});
201 llvm::Value *offset = builder.CreatePtrToInt(offsetPtr, resultType);
203 moduleTranslation.mapValue(typeOffsetOp.getResult(), offset);
209 convertGatherOp(GatherOp gatherOp, llvm::IRBuilderBase &builder,
210 LLVM::ModuleTranslation &moduleTranslation) {
211 llvm::Value *ptrs = moduleTranslation.lookupValue(gatherOp.getPtrs());
212 llvm::Value *mask = moduleTranslation.lookupValue(gatherOp.getMask());
213 llvm::Value *passthrough =
214 moduleTranslation.lookupValue(gatherOp.getPassthrough());
216 if (!ptrs || !mask || !passthrough)
217 return gatherOp.emitError(
"Failed to lookup operands");
220 llvm::Type *resultType =
221 moduleTranslation.convertType(gatherOp.getResult().getType());
223 return gatherOp.emitError(
"Failed to convert result type");
226 llvm::MaybeAlign alignment(gatherOp.getAlignment().value_or(0));
229 llvm::Value *result = builder.CreateMaskedGather(
230 resultType, ptrs, alignment.valueOrOne(), mask, passthrough);
232 moduleTranslation.mapValue(gatherOp.getResult(), result);
238 convertMaskedLoadOp(MaskedLoadOp maskedLoadOp, llvm::IRBuilderBase &builder,
239 LLVM::ModuleTranslation &moduleTranslation) {
240 llvm::Value *ptr = moduleTranslation.lookupValue(maskedLoadOp.getPtr());
241 llvm::Value *mask = moduleTranslation.lookupValue(maskedLoadOp.getMask());
242 llvm::Value *passthrough =
243 moduleTranslation.lookupValue(maskedLoadOp.getPassthrough());
245 if (!ptr || !mask || !passthrough)
246 return maskedLoadOp.emitError(
"Failed to lookup operands");
249 llvm::Type *resultType =
250 moduleTranslation.convertType(maskedLoadOp.getResult().getType());
252 return maskedLoadOp.emitError(
"Failed to convert result type");
255 llvm::MaybeAlign alignment(maskedLoadOp.getAlignment().value_or(0));
258 llvm::Value *result = builder.CreateMaskedLoad(
259 resultType, ptr, alignment.valueOrOne(), mask, passthrough);
261 moduleTranslation.mapValue(maskedLoadOp.getResult(), result);
267 convertMaskedStoreOp(MaskedStoreOp maskedStoreOp, llvm::IRBuilderBase &builder,
268 LLVM::ModuleTranslation &moduleTranslation) {
269 llvm::Value *value = moduleTranslation.lookupValue(maskedStoreOp.getValue());
270 llvm::Value *ptr = moduleTranslation.lookupValue(maskedStoreOp.getPtr());
271 llvm::Value *mask = moduleTranslation.lookupValue(maskedStoreOp.getMask());
273 if (!value || !ptr || !mask)
274 return maskedStoreOp.emitError(
"Failed to lookup operands");
277 llvm::MaybeAlign alignment(maskedStoreOp.getAlignment().value_or(0));
280 builder.CreateMaskedStore(value, ptr, alignment.valueOrOne(), mask);
286 convertScatterOp(ScatterOp scatterOp, llvm::IRBuilderBase &builder,
287 LLVM::ModuleTranslation &moduleTranslation) {
288 llvm::Value *value = moduleTranslation.lookupValue(scatterOp.getValue());
289 llvm::Value *ptrs = moduleTranslation.lookupValue(scatterOp.getPtrs());
290 llvm::Value *mask = moduleTranslation.lookupValue(scatterOp.getMask());
292 if (!value || !ptrs || !mask)
293 return scatterOp.emitError(
"Failed to lookup operands");
296 llvm::MaybeAlign alignment(scatterOp.getAlignment().value_or(0));
299 builder.CreateMaskedScatter(value, ptrs, alignment.valueOrOne(), mask);
305 class PtrDialectLLVMIRTranslationInterface
313 convertOperation(
Operation *op, llvm::IRBuilderBase &builder,
314 LLVM::ModuleTranslation &moduleTranslation)
const final {
317 .Case([&](PtrAddOp ptrAddOp) {
318 return convertPtrAddOp(ptrAddOp, builder, moduleTranslation);
320 .Case([&](LoadOp loadOp) {
321 return convertLoadOp(loadOp, builder, moduleTranslation);
323 .Case([&](StoreOp storeOp) {
324 return convertStoreOp(storeOp, builder, moduleTranslation);
326 .Case([&](TypeOffsetOp typeOffsetOp) {
327 return convertTypeOffsetOp(typeOffsetOp, builder, moduleTranslation);
329 .Case<GatherOp>([&](GatherOp gatherOp) {
330 return convertGatherOp(gatherOp, builder, moduleTranslation);
332 .Case<MaskedLoadOp>([&](MaskedLoadOp maskedLoadOp) {
333 return convertMaskedLoadOp(maskedLoadOp, builder, moduleTranslation);
335 .Case<MaskedStoreOp>([&](MaskedStoreOp maskedStoreOp) {
336 return convertMaskedStoreOp(maskedStoreOp, builder,
339 .Case<ScatterOp>([&](ScatterOp scatterOp) {
340 return convertScatterOp(scatterOp, builder, moduleTranslation);
343 return op->
emitError(
"Translation for operation '")
344 << op->
getName() <<
"' is not implemented.";
352 LLVM::ModuleTranslation &moduleTranslation)
const final {
360 registry.
insert<ptr::PtrDialect>();
362 dialect->addInterfaces<PtrDialectLLVMIRTranslationInterface>();
static llvm::AtomicOrdering convertAtomicOrdering(std::optional< omp::ClauseMemoryOrderKind > ao)
Convert an Atomic Ordering attribute to llvm::AtomicOrdering.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Base class for dialect interfaces providing translation to LLVM IR.
LLVMTranslationDialectInterface(Dialect *dialect)
MLIRContext is the top-level object for a collection of MLIR operations.
void appendDialectRegistry(const DialectRegistry &registry)
Append the contents of the given dialect registry to the registry associated with this context.
NamedAttribute represents a combination of a name and an Attribute value.
Operation is the basic unit of execution within MLIR.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
OperationName getName()
The name of an operation is the key identifier for it.
Include the generated interface declarations.
void registerPtrDialectTranslation(DialectRegistry &registry)
Register the ptr dialect and the translation from it to the LLVM IR in the given registry.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects the context only if needed; this helps unify some of the attribute construction methods.