MLIR  22.0.0git
NVVMToLLVMIRTranslation.cpp
Go to the documentation of this file.
1 //===- NVVMToLLVMIRTranslation.cpp - Translate NVVM to LLVM IR ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a translation between the MLIR NVVM dialect and
10 // LLVM IR.
11 //
12 //===----------------------------------------------------------------------===//
13 
16 #include "mlir/IR/Operation.h"
18 
19 #include "llvm/ADT/StringExtras.h"
20 #include "llvm/ADT/iterator_range.h"
21 #include "llvm/IR/IRBuilder.h"
22 #include "llvm/IR/IntrinsicsNVPTX.h"
23 #include "llvm/Support/FormatVariadic.h"
24 
25 using namespace mlir;
26 using namespace mlir::LLVM;
28 
// Expands to the f32 redux intrinsic ID for `op`: the `_NaN` variant is used
// when NaN propagation is requested. `abs` is either `_abs` or empty and is
// token-pasted into the intrinsic name.
#define REDUX_F32_ID_IMPL(op, abs, hasNaN) \
  hasNaN ? llvm::Intrinsic::nvvm_redux_sync_f##op##abs##_NaN \
         : llvm::Intrinsic::nvvm_redux_sync_f##op##abs

// Selects between the `_abs` and plain f32 redux variants for `op` (min/max).
#define GET_REDUX_F32_ID(op, hasAbs, hasNaN) \
  hasAbs ? REDUX_F32_ID_IMPL(op, _abs, hasNaN) : REDUX_F32_ID_IMPL(op, , hasNaN)

/// Returns the llvm.nvvm.redux.sync.* intrinsic matching the NVVM redux
/// `kind`. Integer kinds map directly; FMIN/FMAX additionally select the
/// abs/NaN variants via the macros above. `resultType` must be i32 or f32.
static llvm::Intrinsic::ID getReduxIntrinsicId(llvm::Type *resultType,
                                               NVVM::ReduxKind kind,
                                               bool hasAbs, bool hasNaN) {
  // redux.sync only operates on 32-bit integer or 32-bit float values.
  if (!(resultType->isIntegerTy(32) || resultType->isFloatTy()))
    llvm_unreachable("unsupported data type for redux");

  switch (kind) {
  case NVVM::ReduxKind::ADD:
    return llvm::Intrinsic::nvvm_redux_sync_add;
  case NVVM::ReduxKind::UMAX:
    return llvm::Intrinsic::nvvm_redux_sync_umax;
  case NVVM::ReduxKind::UMIN:
    return llvm::Intrinsic::nvvm_redux_sync_umin;
  case NVVM::ReduxKind::AND:
    return llvm::Intrinsic::nvvm_redux_sync_and;
  case NVVM::ReduxKind::OR:
    return llvm::Intrinsic::nvvm_redux_sync_or;
  case NVVM::ReduxKind::XOR:
    return llvm::Intrinsic::nvvm_redux_sync_xor;
  case NVVM::ReduxKind::MAX:
    return llvm::Intrinsic::nvvm_redux_sync_max;
  case NVVM::ReduxKind::MIN:
    return llvm::Intrinsic::nvvm_redux_sync_min;
  case NVVM::ReduxKind::FMIN:
    return GET_REDUX_F32_ID(min, hasAbs, hasNaN);
  case NVVM::ReduxKind::FMAX:
    return GET_REDUX_F32_ID(max, hasAbs, hasNaN);
  }
  llvm_unreachable("unknown redux kind");
}
66 
67 static llvm::Intrinsic::ID getShflIntrinsicId(llvm::Type *resultType,
68  NVVM::ShflKind kind,
69  bool withPredicate) {
70 
71  if (withPredicate) {
72  resultType = cast<llvm::StructType>(resultType)->getElementType(0);
73  switch (kind) {
74  case NVVM::ShflKind::bfly:
75  return resultType->isFloatTy()
76  ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32p
77  : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32p;
78  case NVVM::ShflKind::up:
79  return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32p
80  : llvm::Intrinsic::nvvm_shfl_sync_up_i32p;
81  case NVVM::ShflKind::down:
82  return resultType->isFloatTy()
83  ? llvm::Intrinsic::nvvm_shfl_sync_down_f32p
84  : llvm::Intrinsic::nvvm_shfl_sync_down_i32p;
85  case NVVM::ShflKind::idx:
86  return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32p
87  : llvm::Intrinsic::nvvm_shfl_sync_idx_i32p;
88  }
89  } else {
90  switch (kind) {
91  case NVVM::ShflKind::bfly:
92  return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32
93  : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32;
94  case NVVM::ShflKind::up:
95  return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32
96  : llvm::Intrinsic::nvvm_shfl_sync_up_i32;
97  case NVVM::ShflKind::down:
98  return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_down_f32
99  : llvm::Intrinsic::nvvm_shfl_sync_down_i32;
100  case NVVM::ShflKind::idx:
101  return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32
102  : llvm::Intrinsic::nvvm_shfl_sync_idx_i32;
103  }
104  }
105  llvm_unreachable("unknown shuffle kind");
106 }
107 
109  NVVM::MatchSyncKind kind) {
110  switch (kind) {
111  case NVVM::MatchSyncKind::any:
112  return valType.isInteger(32) ? llvm::Intrinsic::nvvm_match_any_sync_i32
113  : llvm::Intrinsic::nvvm_match_any_sync_i64;
114  case NVVM::MatchSyncKind::all:
115  // match.all instruction has two variants -- one returns a single value,
116  // another returns a pair {value, predicate}. We currently only implement
117  // the latter as that's the variant exposed by CUDA API.
118  return valType.isInteger(32) ? llvm::Intrinsic::nvvm_match_all_sync_i32p
119  : llvm::Intrinsic::nvvm_match_all_sync_i64p;
120  }
121  llvm_unreachable("unsupported match sync kind");
122 }
123 
124 static llvm::Intrinsic::ID getVoteSyncIntrinsicId(NVVM::VoteSyncKind kind) {
125  switch (kind) {
126  case NVVM::VoteSyncKind::any:
127  return llvm::Intrinsic::nvvm_vote_any_sync;
128  case NVVM::VoteSyncKind::all:
129  return llvm::Intrinsic::nvvm_vote_all_sync;
130  case NVVM::VoteSyncKind::ballot:
131  return llvm::Intrinsic::nvvm_vote_ballot_sync;
132  case NVVM::VoteSyncKind::uni:
133  return llvm::Intrinsic::nvvm_vote_uni_sync;
134  }
135  llvm_unreachable("unsupported vote kind");
136 }
137 
/// Returns the intrinsic ID associated with ldmatrix for the given
/// parameters: tile shape (MxN), element type, number of matrices `num`,
/// and, for 8x8 tiles, the layout (column-major selects the *_trans form).
/// Unsupported combinations are unreachable; callers are expected to pass
/// verified operands (presumably enforced by the op verifier -- TODO confirm).
static llvm::Intrinsic::ID
getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num,
                       NVVM::LdStMatrixShapeAttr shape,
                       NVVM::LdStMatrixEltType eltType) {
  // 8x8 tiles: b16 elements; layout picks the transposed variant.
  if (shape.getM() == 8 && shape.getN() == 8) {
    switch (num) {
    case 1:
      return (layout == NVVM::MMALayout::row)
                 ? llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16
                 : llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16;
    case 2:
      return (layout == NVVM::MMALayout::row)
                 ? llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16
                 : llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16;
    case 4:
      return (layout == NVVM::MMALayout::row)
                 ? llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16
                 : llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16;
    }
  } else if (shape.getM() == 8 && shape.getN() == 16) {
    // 8x16 tiles: only the packed b8x16 element encodings, no trans variants.
    if (eltType == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) {
      switch (num) {
      case 1:
        return llvm::Intrinsic::
            nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b6x16_p32;
      case 2:
        return llvm::Intrinsic::
            nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b6x16_p32;
      case 4:
        return llvm::Intrinsic::
            nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b6x16_p32;
      }
    } else if (eltType == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) {
      switch (num) {
      case 1:
        return llvm::Intrinsic::
            nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b4x16_p64;
      case 2:
        return llvm::Intrinsic::
            nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b4x16_p64;
      case 4:
        return llvm::Intrinsic::
            nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b4x16_p64;
      }
    }
  } else if (shape.getM() == 16 && shape.getN() == 16) {
    // 16x16 tiles: only transposed variants exist, and only x1/x2.
    if (eltType == NVVM::LdStMatrixEltType::B8) {
      switch (num) {
      case 1:
        return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8;
      case 2:
        return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8;
      }
    } else if (eltType == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) {
      switch (num) {
      case 1:
        return llvm::Intrinsic::
            nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b6x16_p32;
      case 2:
        return llvm::Intrinsic::
            nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b6x16_p32;
      }
    } else if (eltType == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) {
      switch (num) {
      case 1:
        return llvm::Intrinsic::
            nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b4x16_p64;
      case 2:
        return llvm::Intrinsic::
            nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b4x16_p64;
      }
    }
  }
  llvm_unreachable("unknown ldmatrix kind");
}
216 
/// Return the intrinsic ID associated with stmatrix for the given parameters:
/// tile shape, number of matrices `num`, and (for 8x8 b16 tiles) the layout,
/// where column-major selects the *_trans variant. 16x8 b8 tiles only exist
/// in transposed form. Unsupported combinations are unreachable.
static llvm::Intrinsic::ID
getStMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num,
                       NVVM::LdStMatrixShapeAttr shape,
                       NVVM::LdStMatrixEltType eltType) {
  if (shape.getM() == 8 && shape.getN() == 8) {
    switch (num) {
    case 1:
      return (layout == NVVM::MMALayout::row)
                 ? llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_b16
                 : llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_trans_b16;
    case 2:
      return (layout == NVVM::MMALayout::row)
                 ? llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_b16
                 : llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_trans_b16;
    case 4:
      return (layout == NVVM::MMALayout::row)
                 ? llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_b16
                 : llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_trans_b16;
    }
  } else if (shape.getM() == 16 && shape.getN() == 8) {
    // 16x8 b8 tiles: transposed variants only; `layout` is not consulted.
    switch (num) {
    case 1:
      return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x1_trans_b8;
    case 2:
      return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x2_trans_b8;
    case 4:
      return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x4_trans_b8;
    }
  }
  llvm_unreachable("unknown stmatrix kind");
}
252 
253 /// Return the intrinsic ID associated with st.bulk for the given address type.
254 static llvm::Intrinsic::ID
255 getStBulkIntrinsicId(LLVM::LLVMPointerType addrType) {
256  bool isSharedMemory = addrType.getAddressSpace() ==
257  static_cast<unsigned>(NVVM::NVVMMemorySpace::Shared);
258  return isSharedMemory ? llvm::Intrinsic::nvvm_st_bulk_shared_cta
259  : llvm::Intrinsic::nvvm_st_bulk;
260 }
261 
/// Returns the uni-directional fence.proxy intrinsic ID for the given
/// from/to proxy kinds, memory `scope`, and direction (`isRelease` selects
/// the release variant, otherwise acquire). Only the generic -> tensormap
/// proxy pair is currently supported; anything else is unreachable.
static unsigned getUnidirectionalFenceProxyID(NVVM::ProxyKind fromProxy,
                                              NVVM::ProxyKind toProxy,
                                              NVVM::MemScopeKind scope,
                                              bool isRelease) {
  if (fromProxy == NVVM::ProxyKind::GENERIC &&
      toProxy == NVVM::ProxyKind::TENSORMAP) {
    // Each scope has a matching acquire/release pair of intrinsics.
    switch (scope) {
    case NVVM::MemScopeKind::CTA: {
      if (isRelease)
        return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_cta;
      return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_cta;
    }
    case NVVM::MemScopeKind::CLUSTER: {
      if (isRelease)
        return llvm::Intrinsic::
            nvvm_fence_proxy_tensormap_generic_release_cluster;
      return llvm::Intrinsic::
          nvvm_fence_proxy_tensormap_generic_acquire_cluster;
    }
    case NVVM::MemScopeKind::GPU: {
      if (isRelease)
        return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_gpu;
      return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_gpu;
    }
    case NVVM::MemScopeKind::SYS: {
      if (isRelease)
        return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_sys;
      return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_sys;
    }
    }
    llvm_unreachable("Unknown scope for uni-directional fence.proxy operation");
  }
  llvm_unreachable("Unsupported proxy kinds");
}
296 
// Expands to the tcgen05.ld intrinsic ID for a shape / vector-length pair.
#define TCGEN05LD(SHAPE, NUM) llvm::Intrinsic::nvvm_tcgen05_ld_##SHAPE##_##NUM

/// Returns the tcgen05.ld intrinsic for `shape` and vector length `num`.
/// Each table below is ordered by increasing vector length (x1, x2, ...).
static llvm::Intrinsic::ID
getTcgen05LdIntrinsicID(mlir::NVVM::Tcgen05LdStShape shape, uint32_t num) {
  llvm::Intrinsic::ID Shape16x64b[] = {
      TCGEN05LD(16x64b, x1), TCGEN05LD(16x64b, x2), TCGEN05LD(16x64b, x4),
      TCGEN05LD(16x64b, x8), TCGEN05LD(16x64b, x16), TCGEN05LD(16x64b, x32),
      TCGEN05LD(16x64b, x64), TCGEN05LD(16x64b, x128),
  };

  llvm::Intrinsic::ID Shape16x128b[] = {
      TCGEN05LD(16x128b, x1), TCGEN05LD(16x128b, x2), TCGEN05LD(16x128b, x4),
      TCGEN05LD(16x128b, x8), TCGEN05LD(16x128b, x16), TCGEN05LD(16x128b, x32),
      TCGEN05LD(16x128b, x64),
  };

  llvm::Intrinsic::ID Shape16x256b[] = {
      TCGEN05LD(16x256b, x1), TCGEN05LD(16x256b, x2), TCGEN05LD(16x256b, x4),
      TCGEN05LD(16x256b, x8), TCGEN05LD(16x256b, x16), TCGEN05LD(16x256b, x32),
  };

  llvm::Intrinsic::ID Shape16x32bx2[] = {
      TCGEN05LD(16x32bx2, x1), TCGEN05LD(16x32bx2, x2),
      TCGEN05LD(16x32bx2, x4), TCGEN05LD(16x32bx2, x8),
      TCGEN05LD(16x32bx2, x16), TCGEN05LD(16x32bx2, x32),
      TCGEN05LD(16x32bx2, x64), TCGEN05LD(16x32bx2, x128),
  };

  llvm::Intrinsic::ID Shape32x32b[] = {
      TCGEN05LD(32x32b, x1), TCGEN05LD(32x32b, x2), TCGEN05LD(32x32b, x4),
      TCGEN05LD(32x32b, x8), TCGEN05LD(32x32b, x16), TCGEN05LD(32x32b, x32),
      TCGEN05LD(32x32b, x64), TCGEN05LD(32x32b, x128),
  };

  // `num` contains the length of vector and log2 of `num` returns the index
  // into the shape array.
  // NOTE(review): assumes `num` is a nonzero power of two within the range
  // supported by the selected shape -- presumably enforced by the op
  // verifier; confirm before relying on out-of-range behavior.
  unsigned Idx = std::log2(num);

  switch (shape) {
  case NVVM::Tcgen05LdStShape::SHAPE_16X64B:
    return Shape16x64b[Idx];
  case NVVM::Tcgen05LdStShape::SHAPE_16X128B:
    // Offset by one: table slot 0 (x1) corresponds to num == 2.
    return Shape16x128b[Idx - 1];
  case NVVM::Tcgen05LdStShape::SHAPE_16X256B:
    // Offset by two: table slot 0 (x1) corresponds to num == 4.
    return Shape16x256b[Idx - 2];
  case NVVM::Tcgen05LdStShape::SHAPE_32X32B:
    return Shape32x32b[Idx];
  case NVVM::Tcgen05LdStShape::SHAPE_16X32BX2:
    return Shape16x32bx2[Idx];
  }
  llvm_unreachable("unhandled tcgen05.ld lowering");
}
349 
// Expands to the tcgen05.st intrinsic ID for a shape / vector-length pair.
#define TCGEN05ST(SHAPE, NUM) llvm::Intrinsic::nvvm_tcgen05_st_##SHAPE##_##NUM

/// Returns the tcgen05.st intrinsic for `shape` and vector length `num`.
/// Mirrors getTcgen05LdIntrinsicID: tables ordered by increasing length.
static llvm::Intrinsic::ID
getTcgen05StIntrinsicID(mlir::NVVM::Tcgen05LdStShape shape, uint32_t num) {
  llvm::Intrinsic::ID Shape16x64b[] = {
      TCGEN05ST(16x64b, x1), TCGEN05ST(16x64b, x2), TCGEN05ST(16x64b, x4),
      TCGEN05ST(16x64b, x8), TCGEN05ST(16x64b, x16), TCGEN05ST(16x64b, x32),
      TCGEN05ST(16x64b, x64), TCGEN05ST(16x64b, x128),
  };

  llvm::Intrinsic::ID Shape16x128b[] = {
      TCGEN05ST(16x128b, x1), TCGEN05ST(16x128b, x2), TCGEN05ST(16x128b, x4),
      TCGEN05ST(16x128b, x8), TCGEN05ST(16x128b, x16), TCGEN05ST(16x128b, x32),
      TCGEN05ST(16x128b, x64),
  };

  llvm::Intrinsic::ID Shape16x256b[] = {
      TCGEN05ST(16x256b, x1), TCGEN05ST(16x256b, x2), TCGEN05ST(16x256b, x4),
      TCGEN05ST(16x256b, x8), TCGEN05ST(16x256b, x16), TCGEN05ST(16x256b, x32),
  };

  llvm::Intrinsic::ID Shape16x32bx2[] = {
      TCGEN05ST(16x32bx2, x1), TCGEN05ST(16x32bx2, x2),
      TCGEN05ST(16x32bx2, x4), TCGEN05ST(16x32bx2, x8),
      TCGEN05ST(16x32bx2, x16), TCGEN05ST(16x32bx2, x32),
      TCGEN05ST(16x32bx2, x64), TCGEN05ST(16x32bx2, x128),
  };

  llvm::Intrinsic::ID Shape32x32b[] = {
      TCGEN05ST(32x32b, x1), TCGEN05ST(32x32b, x2), TCGEN05ST(32x32b, x4),
      TCGEN05ST(32x32b, x8), TCGEN05ST(32x32b, x16), TCGEN05ST(32x32b, x32),
      TCGEN05ST(32x32b, x64), TCGEN05ST(32x32b, x128),
  };

  // `num` contains the length of vector and log2 of `num` returns the index
  // into the shape array.
  // NOTE(review): assumes `num` is a nonzero power of two within the range
  // supported by the selected shape -- presumably enforced upstream.
  unsigned Idx = std::log2(num);

  switch (shape) {
  case NVVM::Tcgen05LdStShape::SHAPE_16X64B:
    return Shape16x64b[Idx];
  case NVVM::Tcgen05LdStShape::SHAPE_16X128B:
    // Offset by one: table slot 0 (x1) corresponds to num == 2.
    return Shape16x128b[Idx - 1];
  case NVVM::Tcgen05LdStShape::SHAPE_16X256B:
    // Offset by two: table slot 0 (x1) corresponds to num == 4.
    return Shape16x256b[Idx - 2];
  case NVVM::Tcgen05LdStShape::SHAPE_32X32B:
    return Shape32x32b[Idx];
  case NVVM::Tcgen05LdStShape::SHAPE_16X32BX2:
    return Shape16x32bx2[Idx];
  }
  llvm_unreachable("unhandled tcgen05.st lowering");
}
402 
403 namespace {
404 /// Implementation of the dialect interface that converts operations belonging
405 /// to the NVVM dialect to LLVM IR.
406 class NVVMDialectLLVMIRTranslationInterface
408 public:
410 
  /// Translates the given operation to LLVM IR using the provided IR builder
  /// and saving the state in `moduleTranslation`.
  LogicalResult
  convertOperation(Operation *op, llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation) const final {
    Operation &opInst = *op;
    // Tablegen-generated per-op conversions; the included code is expected to
    // reference `opInst`, `builder` and `moduleTranslation` and to return
    // success for any NVVM op it matches (contents not visible here).
#include "mlir/Dialect/LLVMIR/NVVMConversions.inc"

    // No generated pattern matched this operation.
    return failure();
  }
421 
422  /// Attaches module-level metadata for functions marked as kernels.
423  LogicalResult
424  amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
425  NamedAttribute attribute,
426  LLVM::ModuleTranslation &moduleTranslation) const final {
427  auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
428  if (!func)
429  return failure();
430  llvm::Function *llvmFunc = moduleTranslation.lookupFunction(func.getName());
431 
432  if (attribute.getName() == NVVM::NVVMDialect::getMaxntidAttrName()) {
433  if (!isa<DenseI32ArrayAttr>(attribute.getValue()))
434  return failure();
435  auto values = cast<DenseI32ArrayAttr>(attribute.getValue());
436  const std::string attr = llvm::formatv(
437  "{0:$[,]}", llvm::make_range(values.asArrayRef().begin(),
438  values.asArrayRef().end()));
439  llvmFunc->addFnAttr("nvvm.maxntid", attr);
440  } else if (attribute.getName() == NVVM::NVVMDialect::getReqntidAttrName()) {
441  if (!isa<DenseI32ArrayAttr>(attribute.getValue()))
442  return failure();
443  auto values = cast<DenseI32ArrayAttr>(attribute.getValue());
444  const std::string attr = llvm::formatv(
445  "{0:$[,]}", llvm::make_range(values.asArrayRef().begin(),
446  values.asArrayRef().end()));
447  llvmFunc->addFnAttr("nvvm.reqntid", attr);
448  } else if (attribute.getName() ==
449  NVVM::NVVMDialect::getClusterDimAttrName()) {
450  if (!isa<DenseI32ArrayAttr>(attribute.getValue()))
451  return failure();
452  auto values = cast<DenseI32ArrayAttr>(attribute.getValue());
453  const std::string attr = llvm::formatv(
454  "{0:$[,]}", llvm::make_range(values.asArrayRef().begin(),
455  values.asArrayRef().end()));
456  llvmFunc->addFnAttr("nvvm.cluster_dim", attr);
457  } else if (attribute.getName() ==
458  NVVM::NVVMDialect::getClusterMaxBlocksAttrName()) {
459  auto value = dyn_cast<IntegerAttr>(attribute.getValue());
460  llvmFunc->addFnAttr("nvvm.maxclusterrank", llvm::utostr(value.getInt()));
461  } else if (attribute.getName() ==
462  NVVM::NVVMDialect::getMinctasmAttrName()) {
463  auto value = dyn_cast<IntegerAttr>(attribute.getValue());
464  llvmFunc->addFnAttr("nvvm.minctasm", llvm::utostr(value.getInt()));
465  } else if (attribute.getName() == NVVM::NVVMDialect::getMaxnregAttrName()) {
466  auto value = dyn_cast<IntegerAttr>(attribute.getValue());
467  llvmFunc->addFnAttr("nvvm.maxnreg", llvm::utostr(value.getInt()));
468  } else if (attribute.getName() ==
469  NVVM::NVVMDialect::getKernelFuncAttrName()) {
470  llvmFunc->setCallingConv(llvm::CallingConv::PTX_Kernel);
471  } else if (attribute.getName() ==
472  NVVM::NVVMDialect::getBlocksAreClustersAttrName()) {
473  llvmFunc->addFnAttr("nvvm.blocksareclusters");
474  }
475 
476  return success();
477  }
478 
479  LogicalResult
480  convertParameterAttr(LLVMFuncOp funcOp, int argIdx, NamedAttribute attribute,
481  LLVM::ModuleTranslation &moduleTranslation) const final {
482 
483  llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
484  llvm::Function *llvmFunc =
485  moduleTranslation.lookupFunction(funcOp.getName());
486 
487  if (attribute.getName() == NVVM::NVVMDialect::getGridConstantAttrName()) {
488  llvmFunc->addParamAttr(
489  argIdx, llvm::Attribute::get(llvmContext, "nvvm.grid_constant"));
490  }
491  return success();
492  }
493 };
494 } // namespace
495 
497  registry.insert<NVVM::NVVMDialect>();
498  registry.addExtension(+[](MLIRContext *ctx, NVVM::NVVMDialect *dialect) {
499  dialect->addInterfaces<NVVMDialectLLVMIRTranslationInterface>();
500  });
501 }
502 
504  DialectRegistry registry;
506  context.appendDialectRegistry(registry);
507 }
union mlir::linalg::@1243::ArityGroupAndKind::Kind kind
static LogicalResult convertParameterAttr(llvm::AttrBuilder &attrBuilder, llvm::Attribute::AttrKind llvmKind, NamedAttribute namedAttr, ModuleTranslation &moduleTranslation, Location loc)
static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num, NVVM::LdStMatrixShapeAttr shape, NVVM::LdStMatrixEltType eltType)
#define GET_REDUX_F32_ID(op, hasAbs, hasNaN)
static llvm::Intrinsic::ID getStMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num, NVVM::LdStMatrixShapeAttr shape, NVVM::LdStMatrixEltType eltType)
Return the intrinsic ID associated with stmatrix for the given parameters.
static llvm::Intrinsic::ID getTcgen05StIntrinsicID(mlir::NVVM::Tcgen05LdStShape shape, uint32_t num)
static llvm::Intrinsic::ID getTcgen05LdIntrinsicID(mlir::NVVM::Tcgen05LdStShape shape, uint32_t num)
static unsigned getUnidirectionalFenceProxyID(NVVM::ProxyKind fromProxy, NVVM::ProxyKind toProxy, NVVM::MemScopeKind scope, bool isRelease)
#define TCGEN05ST(SHAPE, NUM)
static llvm::Intrinsic::ID getReduxIntrinsicId(llvm::Type *resultType, NVVM::ReduxKind kind, bool hasAbs, bool hasNaN)
#define TCGEN05LD(SHAPE, NUM)
static llvm::Intrinsic::ID getShflIntrinsicId(llvm::Type *resultType, NVVM::ShflKind kind, bool withPredicate)
static llvm::Intrinsic::ID getVoteSyncIntrinsicId(NVVM::VoteSyncKind kind)
static llvm::Intrinsic::ID getMatchSyncIntrinsicId(Type valType, NVVM::MatchSyncKind kind)
static llvm::Intrinsic::ID getStBulkIntrinsicId(LLVM::LLVMPointerType addrType)
Return the intrinsic ID associated with st.bulk for the given address type.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static bool isSharedMemory(MemRefType type)
Return true if this is a shared memory memref type.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Base class for dialect interfaces providing translation to LLVM IR.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
void appendDialectRegistry(const DialectRegistry &registry)
Append the contents of the given dialect registry to the registry associated with this context.
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:164
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:56
llvm::CallInst * createIntrinsicCall(llvm::IRBuilderBase &builder, llvm::Intrinsic::ID intrinsic, ArrayRef< llvm::Value * > args={}, ArrayRef< llvm::Type * > tys={})
Creates a call to an LLVM IR intrinsic function with the given arguments.
Include the generated interface declarations.
void registerNVVMDialectTranslation(DialectRegistry &registry)
Register the NVVM dialect and the translation from it to the LLVM IR in the given registry;.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...