19 #include "llvm/ADT/TypeSwitch.h"
20 #include "llvm/IR/IRBuilder.h"
21 #include "llvm/IR/Instructions.h"
22 #include "llvm/IR/Type.h"
23 #include "llvm/IR/Value.h"
// Converts a ptr dialect atomic ordering into the corresponding
// llvm::AtomicOrdering. Each ptr::AtomicOrdering case maps one-to-one onto an
// LLVM ordering and returns it directly.
31 static llvm::AtomicOrdering
32 translateAtomicOrdering(ptr::AtomicOrdering ordering) {
34 case ptr::AtomicOrdering::not_atomic:
35 return llvm::AtomicOrdering::NotAtomic;
36 case ptr::AtomicOrdering::unordered:
37 return llvm::AtomicOrdering::Unordered;
38 case ptr::AtomicOrdering::monotonic:
39 return llvm::AtomicOrdering::Monotonic;
40 case ptr::AtomicOrdering::acquire:
41 return llvm::AtomicOrdering::Acquire;
42 case ptr::AtomicOrdering::release:
43 return llvm::AtomicOrdering::Release;
44 case ptr::AtomicOrdering::acq_rel:
45 return llvm::AtomicOrdering::AcquireRelease;
46 case ptr::AtomicOrdering::seq_cst:
47 return llvm::AtomicOrdering::SequentiallyConsistent;
// Only reached if `ordering` matches none of the cases above, which is
// treated as a programming error.
49 llvm_unreachable(
"Unknown atomic ordering");
// Translates ptr.ptr_add into an LLVM `getelementptr` whose source element
// type is i8, so the offset operand is applied as a byte offset to the base
// pointer. The op's PtrAddFlags are carried over as GEP no-wrap flags, and the
// resulting value is registered as the translation of the op's result.
// Emits an error if either operand has no LLVM value in `moduleTranslation`.
54 translatePtrAddOp(PtrAddOp ptrAddOp, llvm::IRBuilderBase &builder,
55 LLVM::ModuleTranslation &moduleTranslation) {
56 llvm::Value *basePtr = moduleTranslation.lookupValue(ptrAddOp.getBase());
57 llvm::Value *offset = moduleTranslation.lookupValue(ptrAddOp.getOffset());
59 if (!basePtr || !offset)
60 return ptrAddOp.emitError(
"Failed to lookup operands");
// gepFlags starts default-constructed (no no-wrap flags). The cases below
// select the matching llvm::GEPNoWrapFlags: nusw -> noUnsignedSignedWrap,
// nuw -> noUnsignedWrap, inbounds -> inBounds.
63 llvm::GEPNoWrapFlags gepFlags;
64 switch (ptrAddOp.getFlags()) {
65 case ptr::PtrAddFlags::none:
67 case ptr::PtrAddFlags::nusw:
68 gepFlags = llvm::GEPNoWrapFlags::noUnsignedSignedWrap();
70 case ptr::PtrAddFlags::nuw:
71 gepFlags = llvm::GEPNoWrapFlags::noUnsignedWrap();
73 case ptr::PtrAddFlags::inbounds:
74 gepFlags = llvm::GEPNoWrapFlags::inBounds();
// An i8 element type makes the single GEP index a raw byte count.
80 builder.CreateGEP(builder.getInt8Ty(), basePtr, {offset},
"", gepFlags);
82 moduleTranslation.mapValue(ptrAddOp.getResult(), gep);
// Translates ptr.load into an LLVM `load` of the converted result type.
// The load is given:
//   - the op's alignment and volatility;
//   - its atomic ordering;
//   - an optional sync scope;
//   - optional !nontemporal, !invariant.load and !invariant.group metadata.
// Emits an error when the pointer operand cannot be looked up or the result
// type cannot be converted.
88 translateLoadOp(LoadOp loadOp, llvm::IRBuilderBase &builder,
89 LLVM::ModuleTranslation &moduleTranslation) {
90 llvm::Value *ptr = moduleTranslation.lookupValue(loadOp.getPtr());
92 return loadOp.emitError(
"Failed to lookup pointer operand");
95 llvm::Type *resultType =
96 moduleTranslation.convertType(loadOp.getValue().getType());
98 return loadOp.emitError(
"Failed to translate result type");
// A missing alignment attribute becomes MaybeAlign(0), meaning "unspecified".
// IRBuilder then uses the data layout's ABI alignment for resultType.
101 llvm::MaybeAlign alignment(loadOp.getAlignment().value_or(0));
102 llvm::LoadInst *loadInst = builder.CreateAlignedLoad(
103 resultType, ptr, alignment, loadOp.getVolatile_());
// The ordering is set unconditionally. not_atomic maps to NotAtomic, which
// leaves an ordinary load.
106 loadInst->setAtomic(translateAtomicOrdering(loadOp.getOrdering()));
// A sync scope is applied only when the op specifies one. Its name is
// interned in the LLVMContext.
108 if (loadOp.getSyncscope().has_value()) {
109 llvm::LLVMContext &ctx = builder.getContext();
111 ctx.getOrInsertSyncScopeID(loadOp.getSyncscope().value());
112 loadInst->setSyncScopeID(syncScope);
// Optional memory hints, each attached as instruction metadata when the
// corresponding op flag is set.
116 if (loadOp.getNontemporal()) {
117 llvm::MDNode *nontemporalMD =
120 loadInst->setMetadata(llvm::LLVMContext::MD_nontemporal, nontemporalMD);
123 if (loadOp.getInvariant()) {
125 loadInst->setMetadata(llvm::LLVMContext::MD_invariant_load, invariantMD);
128 if (loadOp.getInvariantGroup()) {
129 llvm::MDNode *invariantGroupMD =
131 loadInst->setMetadata(llvm::LLVMContext::MD_invariant_group,
135 moduleTranslation.mapValue(loadOp.getResult(), loadInst);
// Translates ptr.store into an LLVM `store` of `value` through `ptr`.
// The store is given:
//   - the op's alignment and volatility;
//   - its atomic ordering;
//   - an optional sync scope;
//   - optional !nontemporal and !invariant.group metadata.
// Emits an error when the operands cannot be looked up.
141 translateStoreOp(StoreOp storeOp, llvm::IRBuilderBase &builder,
142 LLVM::ModuleTranslation &moduleTranslation) {
143 llvm::Value *value = moduleTranslation.lookupValue(storeOp.getValue());
144 llvm::Value *ptr = moduleTranslation.lookupValue(storeOp.getPtr());
147 return storeOp.emitError(
"Failed to lookup operands");
// A missing alignment attribute becomes MaybeAlign(0), meaning "unspecified".
150 llvm::MaybeAlign alignment(storeOp.getAlignment().value_or(0));
151 llvm::StoreInst *storeInst =
152 builder.CreateAlignedStore(value, ptr, alignment, storeOp.getVolatile_());
// The ordering is set unconditionally. not_atomic maps to NotAtomic.
155 storeInst->setAtomic(translateAtomicOrdering(storeOp.getOrdering()));
// A sync scope is applied only when the op specifies one. Its name is
// interned in the LLVMContext.
157 if (storeOp.getSyncscope().has_value()) {
158 llvm::LLVMContext &ctx = builder.getContext();
160 ctx.getOrInsertSyncScopeID(storeOp.getSyncscope().value());
161 storeInst->setSyncScopeID(syncScope);
// Optional memory hints, each attached as instruction metadata when the
// corresponding op flag is set.
165 if (storeOp.getNontemporal()) {
166 llvm::MDNode *nontemporalMD =
169 storeInst->setMetadata(llvm::LLVMContext::MD_nontemporal, nontemporalMD);
172 if (storeOp.getInvariantGroup()) {
173 llvm::MDNode *invariantGroupMD =
175 storeInst->setMetadata(llvm::LLVMContext::MD_invariant_group,
// Translates ptr.type_offset into the byte stride of `elementType`. It uses
// the "GEP from null" idiom: `getelementptr elementType, ptr null, i32 1`,
// followed by ptrtoint to the converted result integer type.
// Emits an error if either the element type or the result type cannot be
// converted.
184 translateTypeOffsetOp(TypeOffsetOp typeOffsetOp, llvm::IRBuilderBase &builder,
185 LLVM::ModuleTranslation &moduleTranslation) {
187 llvm::Type *elementType =
188 moduleTranslation.convertType(typeOffsetOp.getElementType());
190 return typeOffsetOp.emitError(
"Failed to translate the element type");
193 llvm::Type *resultType =
194 moduleTranslation.convertType(typeOffsetOp.getResult().getType());
196 return typeOffsetOp.emitError(
"Failed to translate the result type");
// NOTE(review): the null base pointer is always created in address space 0
// (getPtrTy(0)).
199 llvm::Value *nullPtr = llvm::Constant::getNullValue(builder.getPtrTy(0));
// The address of element 1 past null equals the element's stride in bytes.
200 llvm::Value *offsetPtr =
201 builder.CreateGEP(elementType, nullPtr, {builder.getInt32(1)});
202 llvm::Value *offset = builder.CreatePtrToInt(offsetPtr, resultType);
204 moduleTranslation.mapValue(typeOffsetOp.getResult(), offset);
// Translates ptr.gather into an llvm.masked.gather intrinsic call built with
// IRBuilder::CreateMaskedGather. It reads through the vector of pointers
// `ptrs`; `mask` selects the active lanes, and `passthrough` supplies values
// for the disabled lanes.
// Emits an error when an operand cannot be looked up or the result type
// cannot be converted.
210 translateGatherOp(GatherOp gatherOp, llvm::IRBuilderBase &builder,
211 LLVM::ModuleTranslation &moduleTranslation) {
212 llvm::Value *ptrs = moduleTranslation.lookupValue(gatherOp.getPtrs());
213 llvm::Value *mask = moduleTranslation.lookupValue(gatherOp.getMask());
214 llvm::Value *passthrough =
215 moduleTranslation.lookupValue(gatherOp.getPassthrough());
217 if (!ptrs || !mask || !passthrough)
218 return gatherOp.emitError(
"Failed to lookup operands");
221 llvm::Type *resultType =
222 moduleTranslation.convertType(gatherOp.getResult().getType());
224 return gatherOp.emitError(
"Failed to translate result type");
// Without an explicit alignment the gather uses Align(1) (valueOrOne).
// NOTE(review): the plain load/store translations instead pass an empty
// MaybeAlign, which falls back to ABI alignment.
227 llvm::MaybeAlign alignment(gatherOp.getAlignment().value_or(0));
230 llvm::Value *result = builder.CreateMaskedGather(
231 resultType, ptrs, alignment.valueOrOne(), mask, passthrough);
233 moduleTranslation.mapValue(gatherOp.getResult(), result);
// Translates ptr.masked_load into an llvm.masked.load intrinsic call built with
// IRBuilder::CreateMaskedLoad. It reads through `ptr`; `mask` selects the
// active lanes, and `passthrough` supplies values for the disabled lanes.
// Emits an error when an operand cannot be looked up or the result type
// cannot be converted.
239 translateMaskedLoadOp(MaskedLoadOp maskedLoadOp, llvm::IRBuilderBase &builder,
240 LLVM::ModuleTranslation &moduleTranslation) {
241 llvm::Value *ptr = moduleTranslation.lookupValue(maskedLoadOp.getPtr());
242 llvm::Value *mask = moduleTranslation.lookupValue(maskedLoadOp.getMask());
243 llvm::Value *passthrough =
244 moduleTranslation.lookupValue(maskedLoadOp.getPassthrough());
246 if (!ptr || !mask || !passthrough)
247 return maskedLoadOp.emitError(
"Failed to lookup operands");
250 llvm::Type *resultType =
251 moduleTranslation.convertType(maskedLoadOp.getResult().getType());
253 return maskedLoadOp.emitError(
"Failed to translate result type");
// Without an explicit alignment the masked load uses Align(1) (valueOrOne).
256 llvm::MaybeAlign alignment(maskedLoadOp.getAlignment().value_or(0));
259 llvm::Value *result = builder.CreateMaskedLoad(
260 resultType, ptr, alignment.valueOrOne(), mask, passthrough);
262 moduleTranslation.mapValue(maskedLoadOp.getResult(), result);
// Translates ptr.masked_store into an llvm.masked.store intrinsic call built
// with IRBuilder::CreateMaskedStore. It writes `value` through `ptr` for the
// lanes enabled in `mask`.
// Emits an error when an operand cannot be looked up.
268 translateMaskedStoreOp(MaskedStoreOp maskedStoreOp,
269 llvm::IRBuilderBase &builder,
270 LLVM::ModuleTranslation &moduleTranslation) {
271 llvm::Value *value = moduleTranslation.lookupValue(maskedStoreOp.getValue());
272 llvm::Value *ptr = moduleTranslation.lookupValue(maskedStoreOp.getPtr());
273 llvm::Value *mask = moduleTranslation.lookupValue(maskedStoreOp.getMask());
275 if (!value || !ptr || !mask)
276 return maskedStoreOp.emitError(
"Failed to lookup operands");
// Without an explicit alignment the masked store uses Align(1) (valueOrOne).
279 llvm::MaybeAlign alignment(maskedStoreOp.getAlignment().value_or(0));
282 builder.CreateMaskedStore(value, ptr, alignment.valueOrOne(), mask);
// Translates ptr.scatter into an llvm.masked.scatter intrinsic call built with
// IRBuilder::CreateMaskedScatter. It writes the lanes of `value` through the
// vector of pointers `ptrs` for the lanes enabled in `mask`.
// Emits an error when an operand cannot be looked up.
288 translateScatterOp(ScatterOp scatterOp, llvm::IRBuilderBase &builder,
289 LLVM::ModuleTranslation &moduleTranslation) {
290 llvm::Value *value = moduleTranslation.lookupValue(scatterOp.getValue());
291 llvm::Value *ptrs = moduleTranslation.lookupValue(scatterOp.getPtrs());
292 llvm::Value *mask = moduleTranslation.lookupValue(scatterOp.getMask());
294 if (!value || !ptrs || !mask)
295 return scatterOp.emitError(
"Failed to lookup operands");
// Without an explicit alignment the scatter uses Align(1) (valueOrOne).
298 llvm::MaybeAlign alignment(scatterOp.getAlignment().value_or(0));
301 builder.CreateMaskedScatter(value, ptrs, alignment.valueOrOne(), mask);
// Translates ptr.constant into an LLVM pointer value of the converted result
// type, which must be an llvm::PointerType. Supports ptr::NullAttr and
// ptr::AddressAttr values; any other attribute kind is rejected with an error.
307 translateConstantOp(ConstantOp constantOp, llvm::IRBuilderBase &builder,
308 LLVM::ModuleTranslation &moduleTranslation) {
310 llvm::PointerType *resultType = dyn_cast_or_null<llvm::PointerType>(
311 moduleTranslation.convertType(constantOp.getResult().getType()));
313 return constantOp.emitError(
"Expected a valid pointer type");
315 llvm::Value *result =
nullptr;
317 TypedAttr value = constantOp.getValue();
318 if (
auto nullAttr = dyn_cast<ptr::NullAttr>(value)) {
321 }
else if (
auto addressAttr = dyn_cast<ptr::AddressAttr>(value)) {
323 llvm::APInt addressValue = addressAttr.getValue();
// The address integer is resized to the pointer width of the result's
// address space: zero-extended when narrower, truncated (with a warning)
// when wider.
// NOTE(review): this copies the module's DataLayout by value; binding a
// const reference would avoid the copy.
326 llvm::DataLayout dataLayout =
327 moduleTranslation.getLLVMModule()->getDataLayout();
328 unsigned pointerSizeInBits =
329 dataLayout.getPointerSizeInBits(resultType->getAddressSpace());
332 if (addressValue.getBitWidth() != pointerSizeInBits) {
333 if (addressValue.getBitWidth() > pointerSizeInBits) {
334 constantOp.emitWarning()
335 <<
"Truncating address value to fit pointer size";
337 addressValue = addressValue.getBitWidth() < pointerSizeInBits
338 ? addressValue.zext(pointerSizeInBits)
339 : addressValue.trunc(pointerSizeInBits);
// The pointer-width integer is converted to the result pointer type with
// inttoptr. `intValue` is presumably the intType constant holding
// addressValue — TODO confirm.
343 llvm::Type *intType = builder.getIntNTy(pointerSizeInBits);
345 result = builder.CreateIntToPtr(intValue, resultType);
347 return constantOp.emitError(
"Unsupported constant attribute type");
350 moduleTranslation.mapValue(constantOp.getResult(), result);
// Translates ptr.ptr_diff into an integer subtraction of the two pointers'
// addresses. Each address is obtained with ptrtoaddr. The op's nuw/nsw flags
// become the corresponding `sub` flags. The difference is sign-extended or
// truncated to the converted result type when its type differs.
// Emits an error when the operands cannot be looked up or the result type
// cannot be converted.
356 translatePtrDiffOp(PtrDiffOp ptrDiffOp, llvm::IRBuilderBase &builder,
357 LLVM::ModuleTranslation &moduleTranslation) {
358 llvm::Value *lhs = moduleTranslation.lookupValue(ptrDiffOp.getLhs());
359 llvm::Value *rhs = moduleTranslation.lookupValue(ptrDiffOp.getRhs());
362 return ptrDiffOp.emitError(
"Failed to lookup operands");
365 llvm::Type *resultType =
366 moduleTranslation.convertType(ptrDiffOp.getResult().getType());
368 return ptrDiffOp.emitError(
"Failed to translate result type");
370 PtrDiffFlags flags = ptrDiffOp.getFlags();
// Address difference (expected to be llLhs - llRhs). The two trailing
// arguments are the nuw and nsw flags, taken from the op.
374 llvm::Value *llLhs = builder.CreatePtrToAddr(lhs);
375 llvm::Value *llRhs = builder.CreatePtrToAddr(rhs);
376 llvm::Value *result = builder.CreateSub(
378 (flags & PtrDiffFlags::nuw) == PtrDiffFlags::nuw,
379 (flags & PtrDiffFlags::nsw) == PtrDiffFlags::nsw);
// The difference is signed, so a widening cast sign-extends.
383 if (result->getType() != resultType)
384 result = builder.CreateIntCast(result, resultType,
true);
386 moduleTranslation.mapValue(ptrDiffOp.getResult(), result);
// Dialect interface that translates ptr dialect operations to LLVM IR.
// convertOperation dispatches on the concrete op type with llvm::TypeSwitch
// to the matching translate* helper. Any other op produces an error that
// names the op.
392 class PtrDialectLLVMIRTranslationInterface
400 convertOperation(
Operation *op, llvm::IRBuilderBase &builder,
401 LLVM::ModuleTranslation &moduleTranslation)
const final {
404 .Case([&](ConstantOp constantOp) {
405 return translateConstantOp(constantOp, builder, moduleTranslation);
407 .Case([&](PtrAddOp ptrAddOp) {
408 return translatePtrAddOp(ptrAddOp, builder, moduleTranslation);
410 .Case([&](PtrDiffOp ptrDiffOp) {
411 return translatePtrDiffOp(ptrDiffOp, builder, moduleTranslation);
413 .Case([&](LoadOp loadOp) {
414 return translateLoadOp(loadOp, builder, moduleTranslation);
416 .Case([&](StoreOp storeOp) {
417 return translateStoreOp(storeOp, builder, moduleTranslation);
419 .Case([&](TypeOffsetOp typeOffsetOp) {
420 return translateTypeOffsetOp(typeOffsetOp, builder,
// NOTE(review): the cases below spell the op type twice,
// e.g. .Case<GatherOp>([&](GatherOp ...)). TypeSwitch can deduce it from the
// lambda parameter, as the cases above do.
423 .Case<GatherOp>([&](GatherOp gatherOp) {
424 return translateGatherOp(gatherOp, builder, moduleTranslation);
426 .Case<MaskedLoadOp>([&](MaskedLoadOp maskedLoadOp) {
427 return translateMaskedLoadOp(maskedLoadOp, builder,
430 .Case<MaskedStoreOp>([&](MaskedStoreOp maskedStoreOp) {
431 return translateMaskedStoreOp(maskedStoreOp, builder,
434 .Case<ScatterOp>([&](ScatterOp scatterOp) {
435 return translateScatterOp(scatterOp, builder, moduleTranslation);
// Fallback: an op with no translation case is reported as an error.
438 return op->
emitError(
"Translation for operation '")
439 << op->
getName() <<
"' is not implemented.";
447 LLVM::ModuleTranslation &moduleTranslation)
const final {
// Registers the ptr dialect in the registry and attaches
// PtrDialectLLVMIRTranslationInterface to the dialect, so ptr ops can be
// translated to LLVM IR.
455 registry.
insert<ptr::PtrDialect>();
457 dialect->addInterfaces<PtrDialectLLVMIRTranslationInterface>();
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Base class for dialect interfaces providing translation to LLVM IR.
LLVMTranslationDialectInterface(Dialect *dialect)
MLIRContext is the top-level object for a collection of MLIR operations.
void appendDialectRegistry(const DialectRegistry &registry)
Append the contents of the given dialect registry to the registry associated with this context.
NamedAttribute represents a combination of a name and an Attribute value.
Operation is the basic unit of execution within MLIR.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
OperationName getName()
The name of an operation is the key identifier for it.
Include the generated interface declarations.
void registerPtrDialectTranslation(DialectRegistry &registry)
Register the ptr dialect and the translation from it to LLVM IR in the given registry.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...