MLIR  22.0.0git
PtrToLLVMIRTranslation.cpp
Go to the documentation of this file.
1 //===- PtrToLLVMIRTranslation.cpp - Translate `ptr` to LLVM IR ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a translation between the MLIR `ptr` dialect and
10 // LLVM IR.
11 //
12 //===----------------------------------------------------------------------===//
13 
17 #include "mlir/IR/Operation.h"
19 #include "llvm/ADT/TypeSwitch.h"
20 #include "llvm/IR/IRBuilder.h"
21 #include "llvm/IR/Instructions.h"
22 #include "llvm/IR/Type.h"
23 #include "llvm/IR/Value.h"
24 
25 using namespace mlir;
26 using namespace mlir::ptr;
27 
28 namespace {
29 
30 /// Converts ptr::AtomicOrdering to llvm::AtomicOrdering
31 static llvm::AtomicOrdering
32 translateAtomicOrdering(ptr::AtomicOrdering ordering) {
33  switch (ordering) {
34  case ptr::AtomicOrdering::not_atomic:
35  return llvm::AtomicOrdering::NotAtomic;
36  case ptr::AtomicOrdering::unordered:
37  return llvm::AtomicOrdering::Unordered;
38  case ptr::AtomicOrdering::monotonic:
39  return llvm::AtomicOrdering::Monotonic;
40  case ptr::AtomicOrdering::acquire:
41  return llvm::AtomicOrdering::Acquire;
42  case ptr::AtomicOrdering::release:
43  return llvm::AtomicOrdering::Release;
44  case ptr::AtomicOrdering::acq_rel:
45  return llvm::AtomicOrdering::AcquireRelease;
46  case ptr::AtomicOrdering::seq_cst:
47  return llvm::AtomicOrdering::SequentiallyConsistent;
48  }
49  llvm_unreachable("Unknown atomic ordering");
50 }
51 
52 /// Translate ptr.ptr_add operation to LLVM IR.
53 static LogicalResult
54 translatePtrAddOp(PtrAddOp ptrAddOp, llvm::IRBuilderBase &builder,
55  LLVM::ModuleTranslation &moduleTranslation) {
56  llvm::Value *basePtr = moduleTranslation.lookupValue(ptrAddOp.getBase());
57  llvm::Value *offset = moduleTranslation.lookupValue(ptrAddOp.getOffset());
58 
59  if (!basePtr || !offset)
60  return ptrAddOp.emitError("Failed to lookup operands");
61 
62  // Create the GEP flags
63  llvm::GEPNoWrapFlags gepFlags;
64  switch (ptrAddOp.getFlags()) {
65  case ptr::PtrAddFlags::none:
66  break;
67  case ptr::PtrAddFlags::nusw:
68  gepFlags = llvm::GEPNoWrapFlags::noUnsignedSignedWrap();
69  break;
70  case ptr::PtrAddFlags::nuw:
71  gepFlags = llvm::GEPNoWrapFlags::noUnsignedWrap();
72  break;
73  case ptr::PtrAddFlags::inbounds:
74  gepFlags = llvm::GEPNoWrapFlags::inBounds();
75  break;
76  }
77 
78  // Create GEP instruction for pointer arithmetic
79  llvm::Value *gep =
80  builder.CreateGEP(builder.getInt8Ty(), basePtr, {offset}, "", gepFlags);
81 
82  moduleTranslation.mapValue(ptrAddOp.getResult(), gep);
83  return success();
84 }
85 
86 /// Translate ptr.load operation to LLVM IR.
87 static LogicalResult
88 translateLoadOp(LoadOp loadOp, llvm::IRBuilderBase &builder,
89  LLVM::ModuleTranslation &moduleTranslation) {
90  llvm::Value *ptr = moduleTranslation.lookupValue(loadOp.getPtr());
91  if (!ptr)
92  return loadOp.emitError("Failed to lookup pointer operand");
93 
94  // Translate result type to LLVM type
95  llvm::Type *resultType =
96  moduleTranslation.convertType(loadOp.getValue().getType());
97  if (!resultType)
98  return loadOp.emitError("Failed to translate result type");
99 
100  // Create the load instruction.
101  llvm::MaybeAlign alignment(loadOp.getAlignment().value_or(0));
102  llvm::LoadInst *loadInst = builder.CreateAlignedLoad(
103  resultType, ptr, alignment, loadOp.getVolatile_());
104 
105  // Set op flags and metadata.
106  loadInst->setAtomic(translateAtomicOrdering(loadOp.getOrdering()));
107  // Set sync scope if specified
108  if (loadOp.getSyncscope().has_value()) {
109  llvm::LLVMContext &ctx = builder.getContext();
110  llvm::SyncScope::ID syncScope =
111  ctx.getOrInsertSyncScopeID(loadOp.getSyncscope().value());
112  loadInst->setSyncScopeID(syncScope);
113  }
114 
115  // Set metadata for nontemporal, invariant, and invariant_group
116  if (loadOp.getNontemporal()) {
117  llvm::MDNode *nontemporalMD =
118  llvm::MDNode::get(builder.getContext(),
119  llvm::ConstantAsMetadata::get(builder.getInt32(1)));
120  loadInst->setMetadata(llvm::LLVMContext::MD_nontemporal, nontemporalMD);
121  }
122 
123  if (loadOp.getInvariant()) {
124  llvm::MDNode *invariantMD = llvm::MDNode::get(builder.getContext(), {});
125  loadInst->setMetadata(llvm::LLVMContext::MD_invariant_load, invariantMD);
126  }
127 
128  if (loadOp.getInvariantGroup()) {
129  llvm::MDNode *invariantGroupMD =
130  llvm::MDNode::get(builder.getContext(), {});
131  loadInst->setMetadata(llvm::LLVMContext::MD_invariant_group,
132  invariantGroupMD);
133  }
134 
135  moduleTranslation.mapValue(loadOp.getResult(), loadInst);
136  return success();
137 }
138 
139 /// Translate ptr.store operation to LLVM IR.
140 static LogicalResult
141 translateStoreOp(StoreOp storeOp, llvm::IRBuilderBase &builder,
142  LLVM::ModuleTranslation &moduleTranslation) {
143  llvm::Value *value = moduleTranslation.lookupValue(storeOp.getValue());
144  llvm::Value *ptr = moduleTranslation.lookupValue(storeOp.getPtr());
145 
146  if (!value || !ptr)
147  return storeOp.emitError("Failed to lookup operands");
148 
149  // Create the store instruction.
150  llvm::MaybeAlign alignment(storeOp.getAlignment().value_or(0));
151  llvm::StoreInst *storeInst =
152  builder.CreateAlignedStore(value, ptr, alignment, storeOp.getVolatile_());
153 
154  // Set op flags and metadata.
155  storeInst->setAtomic(translateAtomicOrdering(storeOp.getOrdering()));
156  // Set sync scope if specified
157  if (storeOp.getSyncscope().has_value()) {
158  llvm::LLVMContext &ctx = builder.getContext();
159  llvm::SyncScope::ID syncScope =
160  ctx.getOrInsertSyncScopeID(storeOp.getSyncscope().value());
161  storeInst->setSyncScopeID(syncScope);
162  }
163 
164  // Set metadata for nontemporal and invariant_group
165  if (storeOp.getNontemporal()) {
166  llvm::MDNode *nontemporalMD =
167  llvm::MDNode::get(builder.getContext(),
168  llvm::ConstantAsMetadata::get(builder.getInt32(1)));
169  storeInst->setMetadata(llvm::LLVMContext::MD_nontemporal, nontemporalMD);
170  }
171 
172  if (storeOp.getInvariantGroup()) {
173  llvm::MDNode *invariantGroupMD =
174  llvm::MDNode::get(builder.getContext(), {});
175  storeInst->setMetadata(llvm::LLVMContext::MD_invariant_group,
176  invariantGroupMD);
177  }
178 
179  return success();
180 }
181 
182 /// Translate ptr.type_offset operation to LLVM IR.
183 static LogicalResult
184 translateTypeOffsetOp(TypeOffsetOp typeOffsetOp, llvm::IRBuilderBase &builder,
185  LLVM::ModuleTranslation &moduleTranslation) {
186  // Translate the element type to LLVM type
187  llvm::Type *elementType =
188  moduleTranslation.convertType(typeOffsetOp.getElementType());
189  if (!elementType)
190  return typeOffsetOp.emitError("Failed to translate the element type");
191 
192  // Translate result type
193  llvm::Type *resultType =
194  moduleTranslation.convertType(typeOffsetOp.getResult().getType());
195  if (!resultType)
196  return typeOffsetOp.emitError("Failed to translate the result type");
197 
198  // Use GEP with null pointer to compute type size/offset.
199  llvm::Value *nullPtr = llvm::Constant::getNullValue(builder.getPtrTy(0));
200  llvm::Value *offsetPtr =
201  builder.CreateGEP(elementType, nullPtr, {builder.getInt32(1)});
202  llvm::Value *offset = builder.CreatePtrToInt(offsetPtr, resultType);
203 
204  moduleTranslation.mapValue(typeOffsetOp.getResult(), offset);
205  return success();
206 }
207 
208 /// Translate ptr.gather operation to LLVM IR.
209 static LogicalResult
210 translateGatherOp(GatherOp gatherOp, llvm::IRBuilderBase &builder,
211  LLVM::ModuleTranslation &moduleTranslation) {
212  llvm::Value *ptrs = moduleTranslation.lookupValue(gatherOp.getPtrs());
213  llvm::Value *mask = moduleTranslation.lookupValue(gatherOp.getMask());
214  llvm::Value *passthrough =
215  moduleTranslation.lookupValue(gatherOp.getPassthrough());
216 
217  if (!ptrs || !mask || !passthrough)
218  return gatherOp.emitError("Failed to lookup operands");
219 
220  // Translate result type to LLVM type.
221  llvm::Type *resultType =
222  moduleTranslation.convertType(gatherOp.getResult().getType());
223  if (!resultType)
224  return gatherOp.emitError("Failed to translate result type");
225 
226  // Get the alignment.
227  llvm::MaybeAlign alignment(gatherOp.getAlignment().value_or(0));
228 
229  // Create the masked gather intrinsic call.
230  llvm::Value *result = builder.CreateMaskedGather(
231  resultType, ptrs, alignment.valueOrOne(), mask, passthrough);
232 
233  moduleTranslation.mapValue(gatherOp.getResult(), result);
234  return success();
235 }
236 
237 /// Translate ptr.masked_load operation to LLVM IR.
238 static LogicalResult
239 translateMaskedLoadOp(MaskedLoadOp maskedLoadOp, llvm::IRBuilderBase &builder,
240  LLVM::ModuleTranslation &moduleTranslation) {
241  llvm::Value *ptr = moduleTranslation.lookupValue(maskedLoadOp.getPtr());
242  llvm::Value *mask = moduleTranslation.lookupValue(maskedLoadOp.getMask());
243  llvm::Value *passthrough =
244  moduleTranslation.lookupValue(maskedLoadOp.getPassthrough());
245 
246  if (!ptr || !mask || !passthrough)
247  return maskedLoadOp.emitError("Failed to lookup operands");
248 
249  // Translate result type to LLVM type.
250  llvm::Type *resultType =
251  moduleTranslation.convertType(maskedLoadOp.getResult().getType());
252  if (!resultType)
253  return maskedLoadOp.emitError("Failed to translate result type");
254 
255  // Get the alignment.
256  llvm::MaybeAlign alignment(maskedLoadOp.getAlignment().value_or(0));
257 
258  // Create the masked load intrinsic call.
259  llvm::Value *result = builder.CreateMaskedLoad(
260  resultType, ptr, alignment.valueOrOne(), mask, passthrough);
261 
262  moduleTranslation.mapValue(maskedLoadOp.getResult(), result);
263  return success();
264 }
265 
266 /// Translate ptr.masked_store operation to LLVM IR.
267 static LogicalResult
268 translateMaskedStoreOp(MaskedStoreOp maskedStoreOp,
269  llvm::IRBuilderBase &builder,
270  LLVM::ModuleTranslation &moduleTranslation) {
271  llvm::Value *value = moduleTranslation.lookupValue(maskedStoreOp.getValue());
272  llvm::Value *ptr = moduleTranslation.lookupValue(maskedStoreOp.getPtr());
273  llvm::Value *mask = moduleTranslation.lookupValue(maskedStoreOp.getMask());
274 
275  if (!value || !ptr || !mask)
276  return maskedStoreOp.emitError("Failed to lookup operands");
277 
278  // Get the alignment.
279  llvm::MaybeAlign alignment(maskedStoreOp.getAlignment().value_or(0));
280 
281  // Create the masked store intrinsic call.
282  builder.CreateMaskedStore(value, ptr, alignment.valueOrOne(), mask);
283  return success();
284 }
285 
286 /// Translate ptr.scatter operation to LLVM IR.
287 static LogicalResult
288 translateScatterOp(ScatterOp scatterOp, llvm::IRBuilderBase &builder,
289  LLVM::ModuleTranslation &moduleTranslation) {
290  llvm::Value *value = moduleTranslation.lookupValue(scatterOp.getValue());
291  llvm::Value *ptrs = moduleTranslation.lookupValue(scatterOp.getPtrs());
292  llvm::Value *mask = moduleTranslation.lookupValue(scatterOp.getMask());
293 
294  if (!value || !ptrs || !mask)
295  return scatterOp.emitError("Failed to lookup operands");
296 
297  // Get the alignment.
298  llvm::MaybeAlign alignment(scatterOp.getAlignment().value_or(0));
299 
300  // Create the masked scatter intrinsic call.
301  builder.CreateMaskedScatter(value, ptrs, alignment.valueOrOne(), mask);
302  return success();
303 }
304 
305 /// Translate ptr.constant operation to LLVM IR.
306 static LogicalResult
307 translateConstantOp(ConstantOp constantOp, llvm::IRBuilderBase &builder,
308  LLVM::ModuleTranslation &moduleTranslation) {
309  // Translate result type to LLVM type
310  llvm::PointerType *resultType = dyn_cast_or_null<llvm::PointerType>(
311  moduleTranslation.convertType(constantOp.getResult().getType()));
312  if (!resultType)
313  return constantOp.emitError("Expected a valid pointer type");
314 
315  llvm::Value *result = nullptr;
316 
317  TypedAttr value = constantOp.getValue();
318  if (auto nullAttr = dyn_cast<ptr::NullAttr>(value)) {
319  // Create a null pointer constant
320  result = llvm::ConstantPointerNull::get(resultType);
321  } else if (auto addressAttr = dyn_cast<ptr::AddressAttr>(value)) {
322  // Create an integer constant and translate it to pointer
323  llvm::APInt addressValue = addressAttr.getValue();
324 
325  // Determine the integer type width based on the target's pointer size
326  llvm::DataLayout dataLayout =
327  moduleTranslation.getLLVMModule()->getDataLayout();
328  unsigned pointerSizeInBits =
329  dataLayout.getPointerSizeInBits(resultType->getAddressSpace());
330 
331  // Extend or truncate the address value to match pointer size if needed
332  if (addressValue.getBitWidth() != pointerSizeInBits) {
333  if (addressValue.getBitWidth() > pointerSizeInBits) {
334  constantOp.emitWarning()
335  << "Truncating address value to fit pointer size";
336  }
337  addressValue = addressValue.getBitWidth() < pointerSizeInBits
338  ? addressValue.zext(pointerSizeInBits)
339  : addressValue.trunc(pointerSizeInBits);
340  }
341 
342  // Create integer constant and translate to pointer
343  llvm::Type *intType = builder.getIntNTy(pointerSizeInBits);
344  llvm::Value *intValue = llvm::ConstantInt::get(intType, addressValue);
345  result = builder.CreateIntToPtr(intValue, resultType);
346  } else {
347  return constantOp.emitError("Unsupported constant attribute type");
348  }
349 
350  moduleTranslation.mapValue(constantOp.getResult(), result);
351  return success();
352 }
353 
354 /// Translate ptr.ptr_diff operation operation to LLVM IR.
355 static LogicalResult
356 translatePtrDiffOp(PtrDiffOp ptrDiffOp, llvm::IRBuilderBase &builder,
357  LLVM::ModuleTranslation &moduleTranslation) {
358  llvm::Value *lhs = moduleTranslation.lookupValue(ptrDiffOp.getLhs());
359  llvm::Value *rhs = moduleTranslation.lookupValue(ptrDiffOp.getRhs());
360 
361  if (!lhs || !rhs)
362  return ptrDiffOp.emitError("Failed to lookup operands");
363 
364  // Translate result type to LLVM type
365  llvm::Type *resultType =
366  moduleTranslation.convertType(ptrDiffOp.getResult().getType());
367  if (!resultType)
368  return ptrDiffOp.emitError("Failed to translate result type");
369 
370  PtrDiffFlags flags = ptrDiffOp.getFlags();
371 
372  // Convert both pointers to integers using ptrtoaddr, and compute the
373  // difference: lhs - rhs
374  llvm::Value *llLhs = builder.CreatePtrToAddr(lhs);
375  llvm::Value *llRhs = builder.CreatePtrToAddr(rhs);
376  llvm::Value *result = builder.CreateSub(
377  llLhs, llRhs, /*Name=*/"",
378  /*HasNUW=*/(flags & PtrDiffFlags::nuw) == PtrDiffFlags::nuw,
379  /*HasNSW=*/(flags & PtrDiffFlags::nsw) == PtrDiffFlags::nsw);
380 
381  // Convert the difference to the expected result type by truncating or
382  // extending.
383  if (result->getType() != resultType)
384  result = builder.CreateIntCast(result, resultType, /*isSigned=*/true);
385 
386  moduleTranslation.mapValue(ptrDiffOp.getResult(), result);
387  return success();
388 }
389 
390 /// Implementation of the dialect interface that translates operations belonging
391 /// to the `ptr` dialect to LLVM IR.
392 class PtrDialectLLVMIRTranslationInterface
394 public:
396 
397  /// Translates the given operation to LLVM IR using the provided IR builder
398  /// and saving the state in `moduleTranslation`.
399  LogicalResult
400  convertOperation(Operation *op, llvm::IRBuilderBase &builder,
401  LLVM::ModuleTranslation &moduleTranslation) const final {
402 
404  .Case([&](ConstantOp constantOp) {
405  return translateConstantOp(constantOp, builder, moduleTranslation);
406  })
407  .Case([&](PtrAddOp ptrAddOp) {
408  return translatePtrAddOp(ptrAddOp, builder, moduleTranslation);
409  })
410  .Case([&](PtrDiffOp ptrDiffOp) {
411  return translatePtrDiffOp(ptrDiffOp, builder, moduleTranslation);
412  })
413  .Case([&](LoadOp loadOp) {
414  return translateLoadOp(loadOp, builder, moduleTranslation);
415  })
416  .Case([&](StoreOp storeOp) {
417  return translateStoreOp(storeOp, builder, moduleTranslation);
418  })
419  .Case([&](TypeOffsetOp typeOffsetOp) {
420  return translateTypeOffsetOp(typeOffsetOp, builder,
421  moduleTranslation);
422  })
423  .Case<GatherOp>([&](GatherOp gatherOp) {
424  return translateGatherOp(gatherOp, builder, moduleTranslation);
425  })
426  .Case<MaskedLoadOp>([&](MaskedLoadOp maskedLoadOp) {
427  return translateMaskedLoadOp(maskedLoadOp, builder,
428  moduleTranslation);
429  })
430  .Case<MaskedStoreOp>([&](MaskedStoreOp maskedStoreOp) {
431  return translateMaskedStoreOp(maskedStoreOp, builder,
432  moduleTranslation);
433  })
434  .Case<ScatterOp>([&](ScatterOp scatterOp) {
435  return translateScatterOp(scatterOp, builder, moduleTranslation);
436  })
437  .Default([&](Operation *op) {
438  return op->emitError("Translation for operation '")
439  << op->getName() << "' is not implemented.";
440  });
441  }
442 
443  /// Attaches module-level metadata for functions marked as kernels.
444  LogicalResult
445  amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
446  NamedAttribute attribute,
447  LLVM::ModuleTranslation &moduleTranslation) const final {
448  // No special amendments needed for ptr dialect operations
449  return success();
450  }
451 };
452 } // namespace
453 
455  registry.insert<ptr::PtrDialect>();
456  registry.addExtension(+[](MLIRContext *ctx, ptr::PtrDialect *dialect) {
457  dialect->addInterfaces<PtrDialectLLVMIRTranslationInterface>();
458  });
459 }
460 
462  DialectRegistry registry;
464  context.appendDialectRegistry(registry);
465 }
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Base class for dialect interfaces providing translation to LLVM IR.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
void appendDialectRegistry(const DialectRegistry &registry)
Append the contents of the given dialect registry to the registry associated with this context.
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:164
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
Definition: Operation.cpp:268
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
Include the generated interface declarations.
void registerPtrDialectTranslation(DialectRegistry &registry)
Register the ptr dialect and the translation from it to the LLVM IR in the given registry.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction code.