MLIR  22.0.0git
PtrToLLVMIRTranslation.cpp
Go to the documentation of this file.
1 //===- PtrToLLVMIRTranslation.cpp - Translate `ptr` to LLVM IR ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a translation between the MLIR `ptr` dialect and
10 // LLVM IR.
11 //
12 //===----------------------------------------------------------------------===//
13 
17 #include "mlir/IR/Operation.h"
19 #include "llvm/ADT/TypeSwitch.h"
20 #include "llvm/IR/IRBuilder.h"
21 #include "llvm/IR/Instructions.h"
22 #include "llvm/IR/Type.h"
23 #include "llvm/IR/Value.h"
24 
25 using namespace mlir;
26 using namespace mlir::ptr;
27 
28 namespace {
29 
30 /// Converts ptr::AtomicOrdering to llvm::AtomicOrdering
31 static llvm::AtomicOrdering
32 convertAtomicOrdering(ptr::AtomicOrdering ordering) {
33  switch (ordering) {
34  case ptr::AtomicOrdering::not_atomic:
35  return llvm::AtomicOrdering::NotAtomic;
36  case ptr::AtomicOrdering::unordered:
37  return llvm::AtomicOrdering::Unordered;
38  case ptr::AtomicOrdering::monotonic:
39  return llvm::AtomicOrdering::Monotonic;
40  case ptr::AtomicOrdering::acquire:
41  return llvm::AtomicOrdering::Acquire;
42  case ptr::AtomicOrdering::release:
43  return llvm::AtomicOrdering::Release;
44  case ptr::AtomicOrdering::acq_rel:
45  return llvm::AtomicOrdering::AcquireRelease;
46  case ptr::AtomicOrdering::seq_cst:
47  return llvm::AtomicOrdering::SequentiallyConsistent;
48  }
49  llvm_unreachable("Unknown atomic ordering");
50 }
51 
52 /// Convert ptr.ptr_add operation
53 static LogicalResult
54 convertPtrAddOp(PtrAddOp ptrAddOp, llvm::IRBuilderBase &builder,
55  LLVM::ModuleTranslation &moduleTranslation) {
56  llvm::Value *basePtr = moduleTranslation.lookupValue(ptrAddOp.getBase());
57  llvm::Value *offset = moduleTranslation.lookupValue(ptrAddOp.getOffset());
58 
59  if (!basePtr || !offset)
60  return ptrAddOp.emitError("Failed to lookup operands");
61 
62  // Create the GEP flags
63  llvm::GEPNoWrapFlags gepFlags;
64  switch (ptrAddOp.getFlags()) {
65  case ptr::PtrAddFlags::none:
66  break;
67  case ptr::PtrAddFlags::nusw:
68  gepFlags = llvm::GEPNoWrapFlags::noUnsignedSignedWrap();
69  break;
70  case ptr::PtrAddFlags::nuw:
71  gepFlags = llvm::GEPNoWrapFlags::noUnsignedWrap();
72  break;
73  case ptr::PtrAddFlags::inbounds:
74  gepFlags = llvm::GEPNoWrapFlags::inBounds();
75  break;
76  }
77 
78  // Create GEP instruction for pointer arithmetic
79  llvm::Value *gep =
80  builder.CreateGEP(builder.getInt8Ty(), basePtr, {offset}, "", gepFlags);
81 
82  moduleTranslation.mapValue(ptrAddOp.getResult(), gep);
83  return success();
84 }
85 
86 /// Convert ptr.load operation
87 static LogicalResult convertLoadOp(LoadOp loadOp, llvm::IRBuilderBase &builder,
88  LLVM::ModuleTranslation &moduleTranslation) {
89  llvm::Value *ptr = moduleTranslation.lookupValue(loadOp.getPtr());
90  if (!ptr)
91  return loadOp.emitError("Failed to lookup pointer operand");
92 
93  // Convert result type to LLVM type
94  llvm::Type *resultType =
95  moduleTranslation.convertType(loadOp.getValue().getType());
96  if (!resultType)
97  return loadOp.emitError("Failed to convert result type");
98 
99  // Create the load instruction.
100  llvm::MaybeAlign alignment(loadOp.getAlignment().value_or(0));
101  llvm::LoadInst *loadInst = builder.CreateAlignedLoad(
102  resultType, ptr, alignment, loadOp.getVolatile_());
103 
104  // Set op flags and metadata.
105  loadInst->setAtomic(convertAtomicOrdering(loadOp.getOrdering()));
106  // Set sync scope if specified
107  if (loadOp.getSyncscope().has_value()) {
108  llvm::LLVMContext &ctx = builder.getContext();
109  llvm::SyncScope::ID syncScope =
110  ctx.getOrInsertSyncScopeID(loadOp.getSyncscope().value());
111  loadInst->setSyncScopeID(syncScope);
112  }
113 
114  // Set metadata for nontemporal, invariant, and invariant_group
115  if (loadOp.getNontemporal()) {
116  llvm::MDNode *nontemporalMD =
117  llvm::MDNode::get(builder.getContext(),
118  llvm::ConstantAsMetadata::get(builder.getInt32(1)));
119  loadInst->setMetadata(llvm::LLVMContext::MD_nontemporal, nontemporalMD);
120  }
121 
122  if (loadOp.getInvariant()) {
123  llvm::MDNode *invariantMD = llvm::MDNode::get(builder.getContext(), {});
124  loadInst->setMetadata(llvm::LLVMContext::MD_invariant_load, invariantMD);
125  }
126 
127  if (loadOp.getInvariantGroup()) {
128  llvm::MDNode *invariantGroupMD =
129  llvm::MDNode::get(builder.getContext(), {});
130  loadInst->setMetadata(llvm::LLVMContext::MD_invariant_group,
131  invariantGroupMD);
132  }
133 
134  moduleTranslation.mapValue(loadOp.getResult(), loadInst);
135  return success();
136 }
137 
138 /// Convert ptr.store operation
139 static LogicalResult
140 convertStoreOp(StoreOp storeOp, llvm::IRBuilderBase &builder,
141  LLVM::ModuleTranslation &moduleTranslation) {
142  llvm::Value *value = moduleTranslation.lookupValue(storeOp.getValue());
143  llvm::Value *ptr = moduleTranslation.lookupValue(storeOp.getPtr());
144 
145  if (!value || !ptr)
146  return storeOp.emitError("Failed to lookup operands");
147 
148  // Create the store instruction.
149  llvm::MaybeAlign alignment(storeOp.getAlignment().value_or(0));
150  llvm::StoreInst *storeInst =
151  builder.CreateAlignedStore(value, ptr, alignment, storeOp.getVolatile_());
152 
153  // Set op flags and metadata.
154  storeInst->setAtomic(convertAtomicOrdering(storeOp.getOrdering()));
155  // Set sync scope if specified
156  if (storeOp.getSyncscope().has_value()) {
157  llvm::LLVMContext &ctx = builder.getContext();
158  llvm::SyncScope::ID syncScope =
159  ctx.getOrInsertSyncScopeID(storeOp.getSyncscope().value());
160  storeInst->setSyncScopeID(syncScope);
161  }
162 
163  // Set metadata for nontemporal and invariant_group
164  if (storeOp.getNontemporal()) {
165  llvm::MDNode *nontemporalMD =
166  llvm::MDNode::get(builder.getContext(),
167  llvm::ConstantAsMetadata::get(builder.getInt32(1)));
168  storeInst->setMetadata(llvm::LLVMContext::MD_nontemporal, nontemporalMD);
169  }
170 
171  if (storeOp.getInvariantGroup()) {
172  llvm::MDNode *invariantGroupMD =
173  llvm::MDNode::get(builder.getContext(), {});
174  storeInst->setMetadata(llvm::LLVMContext::MD_invariant_group,
175  invariantGroupMD);
176  }
177 
178  return success();
179 }
180 
181 /// Convert ptr.type_offset operation
182 static LogicalResult
183 convertTypeOffsetOp(TypeOffsetOp typeOffsetOp, llvm::IRBuilderBase &builder,
184  LLVM::ModuleTranslation &moduleTranslation) {
185  // Convert the element type to LLVM type
186  llvm::Type *elementType =
187  moduleTranslation.convertType(typeOffsetOp.getElementType());
188  if (!elementType)
189  return typeOffsetOp.emitError("Failed to convert the element type");
190 
191  // Convert result type
192  llvm::Type *resultType =
193  moduleTranslation.convertType(typeOffsetOp.getResult().getType());
194  if (!resultType)
195  return typeOffsetOp.emitError("Failed to convert the result type");
196 
197  // Use GEP with null pointer to compute type size/offset.
198  llvm::Value *nullPtr = llvm::Constant::getNullValue(builder.getPtrTy(0));
199  llvm::Value *offsetPtr =
200  builder.CreateGEP(elementType, nullPtr, {builder.getInt32(1)});
201  llvm::Value *offset = builder.CreatePtrToInt(offsetPtr, resultType);
202 
203  moduleTranslation.mapValue(typeOffsetOp.getResult(), offset);
204  return success();
205 }
206 
207 /// Convert ptr.gather operation
208 static LogicalResult
209 convertGatherOp(GatherOp gatherOp, llvm::IRBuilderBase &builder,
210  LLVM::ModuleTranslation &moduleTranslation) {
211  llvm::Value *ptrs = moduleTranslation.lookupValue(gatherOp.getPtrs());
212  llvm::Value *mask = moduleTranslation.lookupValue(gatherOp.getMask());
213  llvm::Value *passthrough =
214  moduleTranslation.lookupValue(gatherOp.getPassthrough());
215 
216  if (!ptrs || !mask || !passthrough)
217  return gatherOp.emitError("Failed to lookup operands");
218 
219  // Convert result type to LLVM type.
220  llvm::Type *resultType =
221  moduleTranslation.convertType(gatherOp.getResult().getType());
222  if (!resultType)
223  return gatherOp.emitError("Failed to convert result type");
224 
225  // Get the alignment.
226  llvm::MaybeAlign alignment(gatherOp.getAlignment().value_or(0));
227 
228  // Create the masked gather intrinsic call.
229  llvm::Value *result = builder.CreateMaskedGather(
230  resultType, ptrs, alignment.valueOrOne(), mask, passthrough);
231 
232  moduleTranslation.mapValue(gatherOp.getResult(), result);
233  return success();
234 }
235 
236 /// Convert ptr.masked_load operation
237 static LogicalResult
238 convertMaskedLoadOp(MaskedLoadOp maskedLoadOp, llvm::IRBuilderBase &builder,
239  LLVM::ModuleTranslation &moduleTranslation) {
240  llvm::Value *ptr = moduleTranslation.lookupValue(maskedLoadOp.getPtr());
241  llvm::Value *mask = moduleTranslation.lookupValue(maskedLoadOp.getMask());
242  llvm::Value *passthrough =
243  moduleTranslation.lookupValue(maskedLoadOp.getPassthrough());
244 
245  if (!ptr || !mask || !passthrough)
246  return maskedLoadOp.emitError("Failed to lookup operands");
247 
248  // Convert result type to LLVM type.
249  llvm::Type *resultType =
250  moduleTranslation.convertType(maskedLoadOp.getResult().getType());
251  if (!resultType)
252  return maskedLoadOp.emitError("Failed to convert result type");
253 
254  // Get the alignment.
255  llvm::MaybeAlign alignment(maskedLoadOp.getAlignment().value_or(0));
256 
257  // Create the masked load intrinsic call.
258  llvm::Value *result = builder.CreateMaskedLoad(
259  resultType, ptr, alignment.valueOrOne(), mask, passthrough);
260 
261  moduleTranslation.mapValue(maskedLoadOp.getResult(), result);
262  return success();
263 }
264 
265 /// Convert ptr.masked_store operation
266 static LogicalResult
267 convertMaskedStoreOp(MaskedStoreOp maskedStoreOp, llvm::IRBuilderBase &builder,
268  LLVM::ModuleTranslation &moduleTranslation) {
269  llvm::Value *value = moduleTranslation.lookupValue(maskedStoreOp.getValue());
270  llvm::Value *ptr = moduleTranslation.lookupValue(maskedStoreOp.getPtr());
271  llvm::Value *mask = moduleTranslation.lookupValue(maskedStoreOp.getMask());
272 
273  if (!value || !ptr || !mask)
274  return maskedStoreOp.emitError("Failed to lookup operands");
275 
276  // Get the alignment.
277  llvm::MaybeAlign alignment(maskedStoreOp.getAlignment().value_or(0));
278 
279  // Create the masked store intrinsic call.
280  builder.CreateMaskedStore(value, ptr, alignment.valueOrOne(), mask);
281  return success();
282 }
283 
284 /// Convert ptr.scatter operation
285 static LogicalResult
286 convertScatterOp(ScatterOp scatterOp, llvm::IRBuilderBase &builder,
287  LLVM::ModuleTranslation &moduleTranslation) {
288  llvm::Value *value = moduleTranslation.lookupValue(scatterOp.getValue());
289  llvm::Value *ptrs = moduleTranslation.lookupValue(scatterOp.getPtrs());
290  llvm::Value *mask = moduleTranslation.lookupValue(scatterOp.getMask());
291 
292  if (!value || !ptrs || !mask)
293  return scatterOp.emitError("Failed to lookup operands");
294 
295  // Get the alignment.
296  llvm::MaybeAlign alignment(scatterOp.getAlignment().value_or(0));
297 
298  // Create the masked scatter intrinsic call.
299  builder.CreateMaskedScatter(value, ptrs, alignment.valueOrOne(), mask);
300  return success();
301 }
302 
303 /// Implementation of the dialect interface that converts operations belonging
304 /// to the `ptr` dialect to LLVM IR.
/// Each op handled below is forwarded to its dedicated convert*Op helper in
/// this file; any other op is rejected with an error diagnostic.
305 class PtrDialectLLVMIRTranslationInterface
307 public:
309 
310  /// Translates the given operation to LLVM IR using the provided IR builder
311  /// and saving the state in `moduleTranslation`.
  /// Ops without a dedicated case are reported via `emitError` and fail.
312  LogicalResult
313  convertOperation(Operation *op, llvm::IRBuilderBase &builder,
314  LLVM::ModuleTranslation &moduleTranslation) const final {
315 
317  .Case([&](PtrAddOp ptrAddOp) {
318  return convertPtrAddOp(ptrAddOp, builder, moduleTranslation);
319  })
320  .Case([&](LoadOp loadOp) {
321  return convertLoadOp(loadOp, builder, moduleTranslation);
322  })
323  .Case([&](StoreOp storeOp) {
324  return convertStoreOp(storeOp, builder, moduleTranslation);
325  })
326  .Case([&](TypeOffsetOp typeOffsetOp) {
327  return convertTypeOffsetOp(typeOffsetOp, builder, moduleTranslation);
328  })
329  .Case<GatherOp>([&](GatherOp gatherOp) {
330  return convertGatherOp(gatherOp, builder, moduleTranslation);
331  })
332  .Case<MaskedLoadOp>([&](MaskedLoadOp maskedLoadOp) {
333  return convertMaskedLoadOp(maskedLoadOp, builder, moduleTranslation);
334  })
335  .Case<MaskedStoreOp>([&](MaskedStoreOp maskedStoreOp) {
336  return convertMaskedStoreOp(maskedStoreOp, builder,
337  moduleTranslation);
338  })
339  .Case<ScatterOp>([&](ScatterOp scatterOp) {
340  return convertScatterOp(scatterOp, builder, moduleTranslation);
341  })
342  .Default([&](Operation *op) {
343  return op->emitError("Translation for operation '")
344  << op->getName() << "' is not implemented.";
345  });
346  }
347 
348  /// Hook for amending already-translated instructions based on a dialect
  /// attribute. The ptr dialect needs no such amendments, so this leaves
  /// `instructions` untouched and always returns success().
349  LogicalResult
350  amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
351  NamedAttribute attribute,
352  LLVM::ModuleTranslation &moduleTranslation) const final {
353  // No special amendments needed for ptr dialect operations
354  return success();
355  }
356 };
357 } // namespace
358 
360  registry.insert<ptr::PtrDialect>();
361  registry.addExtension(+[](MLIRContext *ctx, ptr::PtrDialect *dialect) {
362  dialect->addInterfaces<PtrDialectLLVMIRTranslationInterface>();
363  });
364 }
365 
367  DialectRegistry registry;
369  context.appendDialectRegistry(registry);
370 }
static llvm::AtomicOrdering convertAtomicOrdering(std::optional< omp::ClauseMemoryOrderKind > ao)
Convert an Atomic Ordering attribute to llvm::AtomicOrdering.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Base class for dialect interfaces providing translation to LLVM IR.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
void appendDialectRegistry(const DialectRegistry &registry)
Append the contents of the given dialect registry to the registry associated with this context.
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:164
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:267
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
Include the generated interface declarations.
void registerPtrDialectTranslation(DialectRegistry &registry)
Register the ptr dialect and the translation from it to LLVM IR in the given registry.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...