MLIR 22.0.0git
PtrToLLVMIRTranslation.cpp
Go to the documentation of this file.
1//===- PtrToLLVMIRTranslation.cpp - Translate `ptr` to LLVM IR ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a translation between the MLIR `ptr` dialect and
10// LLVM IR.
11//
12//===----------------------------------------------------------------------===//
13
17#include "mlir/IR/Operation.h"
19#include "llvm/ADT/TypeSwitch.h"
20#include "llvm/IR/IRBuilder.h"
21#include "llvm/IR/Instructions.h"
22#include "llvm/IR/Type.h"
23#include "llvm/IR/Value.h"
24
25using namespace mlir;
26using namespace mlir::ptr;
27
28namespace {
29
30/// Converts ptr::AtomicOrdering to llvm::AtomicOrdering
31static llvm::AtomicOrdering
32translateAtomicOrdering(ptr::AtomicOrdering ordering) {
33 switch (ordering) {
34 case ptr::AtomicOrdering::not_atomic:
35 return llvm::AtomicOrdering::NotAtomic;
36 case ptr::AtomicOrdering::unordered:
37 return llvm::AtomicOrdering::Unordered;
38 case ptr::AtomicOrdering::monotonic:
39 return llvm::AtomicOrdering::Monotonic;
40 case ptr::AtomicOrdering::acquire:
41 return llvm::AtomicOrdering::Acquire;
42 case ptr::AtomicOrdering::release:
43 return llvm::AtomicOrdering::Release;
44 case ptr::AtomicOrdering::acq_rel:
45 return llvm::AtomicOrdering::AcquireRelease;
46 case ptr::AtomicOrdering::seq_cst:
47 return llvm::AtomicOrdering::SequentiallyConsistent;
48 }
49 llvm_unreachable("Unknown atomic ordering");
50}
51
52/// Translate ptr.ptr_add operation to LLVM IR.
53static LogicalResult
54translatePtrAddOp(PtrAddOp ptrAddOp, llvm::IRBuilderBase &builder,
55 LLVM::ModuleTranslation &moduleTranslation) {
56 llvm::Value *basePtr = moduleTranslation.lookupValue(ptrAddOp.getBase());
57 llvm::Value *offset = moduleTranslation.lookupValue(ptrAddOp.getOffset());
58
59 if (!basePtr || !offset)
60 return ptrAddOp.emitError("Failed to lookup operands");
61
62 // Create the GEP flags
63 llvm::GEPNoWrapFlags gepFlags;
64 switch (ptrAddOp.getFlags()) {
65 case ptr::PtrAddFlags::none:
66 break;
67 case ptr::PtrAddFlags::nusw:
68 gepFlags = llvm::GEPNoWrapFlags::noUnsignedSignedWrap();
69 break;
70 case ptr::PtrAddFlags::nuw:
71 gepFlags = llvm::GEPNoWrapFlags::noUnsignedWrap();
72 break;
73 case ptr::PtrAddFlags::inbounds:
74 gepFlags = llvm::GEPNoWrapFlags::inBounds();
75 break;
76 }
77
78 // Create GEP instruction for pointer arithmetic
79 llvm::Value *gep =
80 builder.CreateGEP(builder.getInt8Ty(), basePtr, {offset}, "", gepFlags);
81
82 moduleTranslation.mapValue(ptrAddOp.getResult(), gep);
83 return success();
84}
85
86/// Translate ptr.load operation to LLVM IR.
87static LogicalResult
88translateLoadOp(LoadOp loadOp, llvm::IRBuilderBase &builder,
89 LLVM::ModuleTranslation &moduleTranslation) {
90 llvm::Value *ptr = moduleTranslation.lookupValue(loadOp.getPtr());
91 if (!ptr)
92 return loadOp.emitError("Failed to lookup pointer operand");
93
94 // Translate result type to LLVM type
95 llvm::Type *resultType =
96 moduleTranslation.convertType(loadOp.getValue().getType());
97 if (!resultType)
98 return loadOp.emitError("Failed to translate result type");
99
100 // Create the load instruction.
101 llvm::MaybeAlign alignment(loadOp.getAlignment().value_or(0));
102 llvm::LoadInst *loadInst = builder.CreateAlignedLoad(
103 resultType, ptr, alignment, loadOp.getVolatile_());
104
105 // Set op flags and metadata.
106 loadInst->setAtomic(translateAtomicOrdering(loadOp.getOrdering()));
107 // Set sync scope if specified
108 if (loadOp.getSyncscope().has_value()) {
109 llvm::LLVMContext &ctx = builder.getContext();
110 llvm::SyncScope::ID syncScope =
111 ctx.getOrInsertSyncScopeID(loadOp.getSyncscope().value());
112 loadInst->setSyncScopeID(syncScope);
113 }
114
115 // Set metadata for nontemporal, invariant, and invariant_group
116 if (loadOp.getNontemporal()) {
117 llvm::MDNode *nontemporalMD =
118 llvm::MDNode::get(builder.getContext(),
119 llvm::ConstantAsMetadata::get(builder.getInt32(1)));
120 loadInst->setMetadata(llvm::LLVMContext::MD_nontemporal, nontemporalMD);
121 }
122
123 if (loadOp.getInvariant()) {
124 llvm::MDNode *invariantMD = llvm::MDNode::get(builder.getContext(), {});
125 loadInst->setMetadata(llvm::LLVMContext::MD_invariant_load, invariantMD);
126 }
127
128 if (loadOp.getInvariantGroup()) {
129 llvm::MDNode *invariantGroupMD =
130 llvm::MDNode::get(builder.getContext(), {});
131 loadInst->setMetadata(llvm::LLVMContext::MD_invariant_group,
132 invariantGroupMD);
133 }
134
135 moduleTranslation.mapValue(loadOp.getResult(), loadInst);
136 return success();
137}
138
139/// Translate ptr.store operation to LLVM IR.
140static LogicalResult
141translateStoreOp(StoreOp storeOp, llvm::IRBuilderBase &builder,
142 LLVM::ModuleTranslation &moduleTranslation) {
143 llvm::Value *value = moduleTranslation.lookupValue(storeOp.getValue());
144 llvm::Value *ptr = moduleTranslation.lookupValue(storeOp.getPtr());
145
146 if (!value || !ptr)
147 return storeOp.emitError("Failed to lookup operands");
148
149 // Create the store instruction.
150 llvm::MaybeAlign alignment(storeOp.getAlignment().value_or(0));
151 llvm::StoreInst *storeInst =
152 builder.CreateAlignedStore(value, ptr, alignment, storeOp.getVolatile_());
153
154 // Set op flags and metadata.
155 storeInst->setAtomic(translateAtomicOrdering(storeOp.getOrdering()));
156 // Set sync scope if specified
157 if (storeOp.getSyncscope().has_value()) {
158 llvm::LLVMContext &ctx = builder.getContext();
159 llvm::SyncScope::ID syncScope =
160 ctx.getOrInsertSyncScopeID(storeOp.getSyncscope().value());
161 storeInst->setSyncScopeID(syncScope);
162 }
163
164 // Set metadata for nontemporal and invariant_group
165 if (storeOp.getNontemporal()) {
166 llvm::MDNode *nontemporalMD =
167 llvm::MDNode::get(builder.getContext(),
168 llvm::ConstantAsMetadata::get(builder.getInt32(1)));
169 storeInst->setMetadata(llvm::LLVMContext::MD_nontemporal, nontemporalMD);
170 }
171
172 if (storeOp.getInvariantGroup()) {
173 llvm::MDNode *invariantGroupMD =
174 llvm::MDNode::get(builder.getContext(), {});
175 storeInst->setMetadata(llvm::LLVMContext::MD_invariant_group,
176 invariantGroupMD);
177 }
178
179 return success();
180}
181
182/// Translate ptr.type_offset operation to LLVM IR.
183static LogicalResult
184translateTypeOffsetOp(TypeOffsetOp typeOffsetOp, llvm::IRBuilderBase &builder,
185 LLVM::ModuleTranslation &moduleTranslation) {
186 // Translate the element type to LLVM type
187 llvm::Type *elementType =
188 moduleTranslation.convertType(typeOffsetOp.getElementType());
189 if (!elementType)
190 return typeOffsetOp.emitError("Failed to translate the element type");
191
192 // Translate result type
193 llvm::Type *resultType =
194 moduleTranslation.convertType(typeOffsetOp.getResult().getType());
195 if (!resultType)
196 return typeOffsetOp.emitError("Failed to translate the result type");
197
198 // Use GEP with null pointer to compute type size/offset.
199 llvm::Value *nullPtr = llvm::Constant::getNullValue(builder.getPtrTy(0));
200 llvm::Value *offsetPtr =
201 builder.CreateGEP(elementType, nullPtr, {builder.getInt32(1)});
202 llvm::Value *offset = builder.CreatePtrToInt(offsetPtr, resultType);
203
204 moduleTranslation.mapValue(typeOffsetOp.getResult(), offset);
205 return success();
206}
207
208/// Translate ptr.gather operation to LLVM IR.
209static LogicalResult
210translateGatherOp(GatherOp gatherOp, llvm::IRBuilderBase &builder,
211 LLVM::ModuleTranslation &moduleTranslation) {
212 llvm::Value *ptrs = moduleTranslation.lookupValue(gatherOp.getPtrs());
213 llvm::Value *mask = moduleTranslation.lookupValue(gatherOp.getMask());
214 llvm::Value *passthrough =
215 moduleTranslation.lookupValue(gatherOp.getPassthrough());
216
217 if (!ptrs || !mask || !passthrough)
218 return gatherOp.emitError("Failed to lookup operands");
219
220 // Translate result type to LLVM type.
221 llvm::Type *resultType =
222 moduleTranslation.convertType(gatherOp.getResult().getType());
223 if (!resultType)
224 return gatherOp.emitError("Failed to translate result type");
225
226 // Get the alignment.
227 llvm::MaybeAlign alignment(gatherOp.getAlignment().value_or(0));
228
229 // Create the masked gather intrinsic call.
230 llvm::Value *result = builder.CreateMaskedGather(
231 resultType, ptrs, alignment.valueOrOne(), mask, passthrough);
232
233 moduleTranslation.mapValue(gatherOp.getResult(), result);
234 return success();
235}
236
237/// Translate ptr.masked_load operation to LLVM IR.
238static LogicalResult
239translateMaskedLoadOp(MaskedLoadOp maskedLoadOp, llvm::IRBuilderBase &builder,
240 LLVM::ModuleTranslation &moduleTranslation) {
241 llvm::Value *ptr = moduleTranslation.lookupValue(maskedLoadOp.getPtr());
242 llvm::Value *mask = moduleTranslation.lookupValue(maskedLoadOp.getMask());
243 llvm::Value *passthrough =
244 moduleTranslation.lookupValue(maskedLoadOp.getPassthrough());
245
246 if (!ptr || !mask || !passthrough)
247 return maskedLoadOp.emitError("Failed to lookup operands");
248
249 // Translate result type to LLVM type.
250 llvm::Type *resultType =
251 moduleTranslation.convertType(maskedLoadOp.getResult().getType());
252 if (!resultType)
253 return maskedLoadOp.emitError("Failed to translate result type");
254
255 // Get the alignment.
256 llvm::MaybeAlign alignment(maskedLoadOp.getAlignment().value_or(0));
257
258 // Create the masked load intrinsic call.
259 llvm::Value *result = builder.CreateMaskedLoad(
260 resultType, ptr, alignment.valueOrOne(), mask, passthrough);
261
262 moduleTranslation.mapValue(maskedLoadOp.getResult(), result);
263 return success();
264}
265
266/// Translate ptr.masked_store operation to LLVM IR.
267static LogicalResult
268translateMaskedStoreOp(MaskedStoreOp maskedStoreOp,
269 llvm::IRBuilderBase &builder,
270 LLVM::ModuleTranslation &moduleTranslation) {
271 llvm::Value *value = moduleTranslation.lookupValue(maskedStoreOp.getValue());
272 llvm::Value *ptr = moduleTranslation.lookupValue(maskedStoreOp.getPtr());
273 llvm::Value *mask = moduleTranslation.lookupValue(maskedStoreOp.getMask());
274
275 if (!value || !ptr || !mask)
276 return maskedStoreOp.emitError("Failed to lookup operands");
277
278 // Get the alignment.
279 llvm::MaybeAlign alignment(maskedStoreOp.getAlignment().value_or(0));
280
281 // Create the masked store intrinsic call.
282 builder.CreateMaskedStore(value, ptr, alignment.valueOrOne(), mask);
283 return success();
284}
285
286/// Translate ptr.scatter operation to LLVM IR.
287static LogicalResult
288translateScatterOp(ScatterOp scatterOp, llvm::IRBuilderBase &builder,
289 LLVM::ModuleTranslation &moduleTranslation) {
290 llvm::Value *value = moduleTranslation.lookupValue(scatterOp.getValue());
291 llvm::Value *ptrs = moduleTranslation.lookupValue(scatterOp.getPtrs());
292 llvm::Value *mask = moduleTranslation.lookupValue(scatterOp.getMask());
293
294 if (!value || !ptrs || !mask)
295 return scatterOp.emitError("Failed to lookup operands");
296
297 // Get the alignment.
298 llvm::MaybeAlign alignment(scatterOp.getAlignment().value_or(0));
299
300 // Create the masked scatter intrinsic call.
301 builder.CreateMaskedScatter(value, ptrs, alignment.valueOrOne(), mask);
302 return success();
303}
304
305/// Translate ptr.constant operation to LLVM IR.
306static LogicalResult
307translateConstantOp(ConstantOp constantOp, llvm::IRBuilderBase &builder,
308 LLVM::ModuleTranslation &moduleTranslation) {
309 // Translate result type to LLVM type
310 llvm::PointerType *resultType = dyn_cast_or_null<llvm::PointerType>(
311 moduleTranslation.convertType(constantOp.getResult().getType()));
312 if (!resultType)
313 return constantOp.emitError("Expected a valid pointer type");
314
315 llvm::Value *result = nullptr;
316
317 TypedAttr value = constantOp.getValue();
318 if (auto nullAttr = dyn_cast<ptr::NullAttr>(value)) {
319 // Create a null pointer constant
320 result = llvm::ConstantPointerNull::get(resultType);
321 } else if (auto addressAttr = dyn_cast<ptr::AddressAttr>(value)) {
322 // Create an integer constant and translate it to pointer
323 llvm::APInt addressValue = addressAttr.getValue();
324
325 // Determine the integer type width based on the target's pointer size
326 llvm::DataLayout dataLayout =
327 moduleTranslation.getLLVMModule()->getDataLayout();
328 unsigned pointerSizeInBits =
329 dataLayout.getPointerSizeInBits(resultType->getAddressSpace());
330
331 // Extend or truncate the address value to match pointer size if needed
332 if (addressValue.getBitWidth() != pointerSizeInBits) {
333 if (addressValue.getBitWidth() > pointerSizeInBits) {
334 constantOp.emitWarning()
335 << "Truncating address value to fit pointer size";
336 }
337 addressValue = addressValue.getBitWidth() < pointerSizeInBits
338 ? addressValue.zext(pointerSizeInBits)
339 : addressValue.trunc(pointerSizeInBits);
340 }
341
342 // Create integer constant and translate to pointer
343 llvm::Type *intType = builder.getIntNTy(pointerSizeInBits);
344 llvm::Value *intValue = llvm::ConstantInt::get(intType, addressValue);
345 result = builder.CreateIntToPtr(intValue, resultType);
346 } else {
347 return constantOp.emitError("Unsupported constant attribute type");
348 }
349
350 moduleTranslation.mapValue(constantOp.getResult(), result);
351 return success();
352}
353
354/// Translate ptr.ptr_diff operation operation to LLVM IR.
355static LogicalResult
356translatePtrDiffOp(PtrDiffOp ptrDiffOp, llvm::IRBuilderBase &builder,
357 LLVM::ModuleTranslation &moduleTranslation) {
358 llvm::Value *lhs = moduleTranslation.lookupValue(ptrDiffOp.getLhs());
359 llvm::Value *rhs = moduleTranslation.lookupValue(ptrDiffOp.getRhs());
360
361 if (!lhs || !rhs)
362 return ptrDiffOp.emitError("Failed to lookup operands");
363
364 // Translate result type to LLVM type
365 llvm::Type *resultType =
366 moduleTranslation.convertType(ptrDiffOp.getResult().getType());
367 if (!resultType)
368 return ptrDiffOp.emitError("Failed to translate result type");
369
370 PtrDiffFlags flags = ptrDiffOp.getFlags();
371
372 // Convert both pointers to integers using ptrtoaddr, and compute the
373 // difference: lhs - rhs
374 llvm::Value *llLhs = builder.CreatePtrToAddr(lhs);
375 llvm::Value *llRhs = builder.CreatePtrToAddr(rhs);
376 llvm::Value *result = builder.CreateSub(
377 llLhs, llRhs, /*Name=*/"",
378 /*HasNUW=*/(flags & PtrDiffFlags::nuw) == PtrDiffFlags::nuw,
379 /*HasNSW=*/(flags & PtrDiffFlags::nsw) == PtrDiffFlags::nsw);
380
381 // Convert the difference to the expected result type by truncating or
382 // extending.
383 if (result->getType() != resultType)
384 result = builder.CreateIntCast(result, resultType, /*isSigned=*/true);
385
386 moduleTranslation.mapValue(ptrDiffOp.getResult(), result);
387 return success();
388}
389
390/// Implementation of the dialect interface that translates operations belonging
391/// to the `ptr` dialect to LLVM IR.
392class PtrDialectLLVMIRTranslationInterface
394public:
396
397 /// Translates the given operation to LLVM IR using the provided IR builder
398 /// and saving the state in `moduleTranslation`.
399 LogicalResult
400 convertOperation(Operation *op, llvm::IRBuilderBase &builder,
401 LLVM::ModuleTranslation &moduleTranslation) const final {
402
403 return llvm::TypeSwitch<Operation *, LogicalResult>(op)
404 .Case([&](ConstantOp constantOp) {
405 return translateConstantOp(constantOp, builder, moduleTranslation);
406 })
407 .Case([&](PtrAddOp ptrAddOp) {
408 return translatePtrAddOp(ptrAddOp, builder, moduleTranslation);
409 })
410 .Case([&](PtrDiffOp ptrDiffOp) {
411 return translatePtrDiffOp(ptrDiffOp, builder, moduleTranslation);
412 })
413 .Case([&](LoadOp loadOp) {
414 return translateLoadOp(loadOp, builder, moduleTranslation);
415 })
416 .Case([&](StoreOp storeOp) {
417 return translateStoreOp(storeOp, builder, moduleTranslation);
418 })
419 .Case([&](TypeOffsetOp typeOffsetOp) {
420 return translateTypeOffsetOp(typeOffsetOp, builder,
421 moduleTranslation);
422 })
423 .Case<GatherOp>([&](GatherOp gatherOp) {
424 return translateGatherOp(gatherOp, builder, moduleTranslation);
425 })
426 .Case<MaskedLoadOp>([&](MaskedLoadOp maskedLoadOp) {
427 return translateMaskedLoadOp(maskedLoadOp, builder,
428 moduleTranslation);
429 })
430 .Case<MaskedStoreOp>([&](MaskedStoreOp maskedStoreOp) {
431 return translateMaskedStoreOp(maskedStoreOp, builder,
432 moduleTranslation);
433 })
434 .Case<ScatterOp>([&](ScatterOp scatterOp) {
435 return translateScatterOp(scatterOp, builder, moduleTranslation);
436 })
437 .Default([&](Operation *op) {
438 return op->emitError("Translation for operation '")
439 << op->getName() << "' is not implemented.";
440 });
441 }
442
443 /// Attaches module-level metadata for functions marked as kernels.
444 LogicalResult
445 amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
446 NamedAttribute attribute,
447 LLVM::ModuleTranslation &moduleTranslation) const final {
448 // No special amendments needed for ptr dialect operations
449 return success();
450 }
451};
452} // namespace
453
455 registry.insert<ptr::PtrDialect>();
456 registry.addExtension(+[](MLIRContext *ctx, ptr::PtrDialect *dialect) {
457 dialect->addInterfaces<PtrDialectLLVMIRTranslationInterface>();
458 });
459}
460
462 DialectRegistry registry;
464 context.appendDialectRegistry(registry);
465}
return success()
lhs
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Base class for dialect interfaces providing translation to LLVM IR.
Implementation class for module translation.
llvm::Value * lookupValue(Value value) const
Finds an LLVM IR value corresponding to the given MLIR value.
llvm::Type * convertType(Type type)
Converts the type from MLIR LLVM dialect to LLVM.
llvm::Module * getLLVMModule()
Returns the LLVM module in which the IR is being constructed.
void mapValue(Value mlir, llvm::Value *llvm)
Stores the mapping between an MLIR value and its LLVM IR counterpart.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
void appendDialectRegistry(const DialectRegistry &registry)
Append the contents of the given dialect registry to the registry associated with this context.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
Include the generated interface declarations.
void registerPtrDialectTranslation(DialectRegistry &registry)
Register the ptr dialect and the translation from it to the LLVM IR in the given registry;.