MLIR 23.0.0git
NVVMToLLVMIRTranslation.cpp
Go to the documentation of this file.
1//===- NVVMToLLVMIRTranslation.cpp - Translate NVVM to LLVM IR ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a translation between the MLIR NVVM dialect and
10// LLVM IR.
11//
12//===----------------------------------------------------------------------===//
13
#include "mlir/Target/LLVMIR/Dialect/NVVMIR/NVVMToLLVMIRTranslation.h"

#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/Target/LLVMIR/ModuleTranslation.h"

#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/Support/FormatVariadic.h"

#include <cmath>
25using namespace mlir;
26using namespace mlir::LLVM;
28
// Helper macros selecting the f32 redux.sync intrinsic variant. The token
// pasting composes the intrinsic name from the operation (`min`/`max`), an
// optional `_abs` modifier and an optional `_NaN` (NaN-propagating) suffix.
#define REDUX_F32_ID_IMPL(op, abs, hasNaN) \
  hasNaN ? llvm::Intrinsic::nvvm_redux_sync_f##op##abs##_NaN \
         : llvm::Intrinsic::nvvm_redux_sync_f##op##abs

#define GET_REDUX_F32_ID(op, hasAbs, hasNaN) \
  hasAbs ? REDUX_F32_ID_IMPL(op, _abs, hasNaN) : REDUX_F32_ID_IMPL(op, , hasNaN)

/// Returns the llvm.nvvm.redux.sync.* intrinsic for the given reduction kind.
/// `hasAbs` and `hasNaN` only affect the float reductions (FMIN/FMAX); the
/// integer reductions ignore them. `resultType` is currently unused here.
static llvm::Intrinsic::ID getReduxIntrinsicId(llvm::Type *resultType,
                                               NVVM::ReductionKind kind,
                                               bool hasAbs, bool hasNaN) {
  switch (kind) {
  case NVVM::ReductionKind::ADD:
    return llvm::Intrinsic::nvvm_redux_sync_add;
  case NVVM::ReductionKind::UMAX:
    return llvm::Intrinsic::nvvm_redux_sync_umax;
  case NVVM::ReductionKind::UMIN:
    return llvm::Intrinsic::nvvm_redux_sync_umin;
  case NVVM::ReductionKind::AND:
    return llvm::Intrinsic::nvvm_redux_sync_and;
  case NVVM::ReductionKind::OR:
    return llvm::Intrinsic::nvvm_redux_sync_or;
  case NVVM::ReductionKind::XOR:
    return llvm::Intrinsic::nvvm_redux_sync_xor;
  case NVVM::ReductionKind::MAX:
    return llvm::Intrinsic::nvvm_redux_sync_max;
  case NVVM::ReductionKind::MIN:
    return llvm::Intrinsic::nvvm_redux_sync_min;
  case NVVM::ReductionKind::FMIN:
    return GET_REDUX_F32_ID(min, hasAbs, hasNaN);
  case NVVM::ReductionKind::FMAX:
    return GET_REDUX_F32_ID(max, hasAbs, hasNaN);
  }
  llvm_unreachable("unknown reduction kind");
}
63
64static llvm::Intrinsic::ID getShflIntrinsicId(llvm::Type *resultType,
65 NVVM::ShflKind kind,
66 bool withPredicate) {
67
68 if (withPredicate) {
69 resultType = cast<llvm::StructType>(resultType)->getElementType(0);
70 switch (kind) {
71 case NVVM::ShflKind::bfly:
72 return resultType->isFloatTy()
73 ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32p
74 : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32p;
75 case NVVM::ShflKind::up:
76 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32p
77 : llvm::Intrinsic::nvvm_shfl_sync_up_i32p;
78 case NVVM::ShflKind::down:
79 return resultType->isFloatTy()
80 ? llvm::Intrinsic::nvvm_shfl_sync_down_f32p
81 : llvm::Intrinsic::nvvm_shfl_sync_down_i32p;
82 case NVVM::ShflKind::idx:
83 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32p
84 : llvm::Intrinsic::nvvm_shfl_sync_idx_i32p;
85 }
86 } else {
87 switch (kind) {
88 case NVVM::ShflKind::bfly:
89 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32
90 : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32;
91 case NVVM::ShflKind::up:
92 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32
93 : llvm::Intrinsic::nvvm_shfl_sync_up_i32;
94 case NVVM::ShflKind::down:
95 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_down_f32
96 : llvm::Intrinsic::nvvm_shfl_sync_down_i32;
97 case NVVM::ShflKind::idx:
98 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32
99 : llvm::Intrinsic::nvvm_shfl_sync_idx_i32;
100 }
101 }
102 llvm_unreachable("unknown shuffle kind");
103}
104
105static llvm::Intrinsic::ID getMatchSyncIntrinsicId(Type valType,
106 NVVM::MatchSyncKind kind) {
107 switch (kind) {
108 case NVVM::MatchSyncKind::any:
109 return valType.isInteger(32) ? llvm::Intrinsic::nvvm_match_any_sync_i32
110 : llvm::Intrinsic::nvvm_match_any_sync_i64;
111 case NVVM::MatchSyncKind::all:
112 // match.all instruction has two variants -- one returns a single value,
113 // another returns a pair {value, predicate}. We currently only implement
114 // the latter as that's the variant exposed by CUDA API.
115 return valType.isInteger(32) ? llvm::Intrinsic::nvvm_match_all_sync_i32p
116 : llvm::Intrinsic::nvvm_match_all_sync_i64p;
117 }
118 llvm_unreachable("unsupported match sync kind");
119}
120
121static llvm::Intrinsic::ID getVoteSyncIntrinsicId(NVVM::VoteSyncKind kind) {
122 switch (kind) {
123 case NVVM::VoteSyncKind::any:
124 return llvm::Intrinsic::nvvm_vote_any_sync;
125 case NVVM::VoteSyncKind::all:
126 return llvm::Intrinsic::nvvm_vote_all_sync;
127 case NVVM::VoteSyncKind::ballot:
128 return llvm::Intrinsic::nvvm_vote_ballot_sync;
129 case NVVM::VoteSyncKind::uni:
130 return llvm::Intrinsic::nvvm_vote_uni_sync;
131 }
132 llvm_unreachable("unsupported vote kind");
133}
134
/// Return the intrinsic ID associated with ldmatrix for the given parameters:
/// tile shape (MxN), number of matrices (`num`), element type and layout
/// (column-major selects the `trans` variants).
static llvm::Intrinsic::ID
getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num,
                       NVVM::LdStMatrixShapeAttr shape,
                       NVVM::LdStMatrixEltType eltType) {
  // 8x8 tiles use b16 elements; layout picks plain vs. transposed.
  if (shape.getM() == 8 && shape.getN() == 8) {
    switch (num) {
    case 1:
      return (layout == NVVM::MMALayout::row)
                 ? llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16
                 : llvm::Intrinsic::
                       nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16;
    case 2:
      return (layout == NVVM::MMALayout::row)
                 ? llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16
                 : llvm::Intrinsic::
                       nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16;
    case 4:
      return (layout == NVVM::MMALayout::row)
                 ? llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16
                 : llvm::Intrinsic::
                       nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16;
    }
  } else if (shape.getM() == 8 && shape.getN() == 16) {
    // 8x16 tiles exist only for the packed sub-byte element types; layout is
    // not consulted for these variants.
    if (eltType == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) {
      switch (num) {
      case 1:
        return llvm::Intrinsic::
            nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b6x16_p32;
      case 2:
        return llvm::Intrinsic::
            nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b6x16_p32;
      case 4:
        return llvm::Intrinsic::
            nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b6x16_p32;
      }
    } else if (eltType == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) {
      switch (num) {
      case 1:
        return llvm::Intrinsic::
            nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b4x16_p64;
      case 2:
        return llvm::Intrinsic::
            nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b4x16_p64;
      case 4:
        return llvm::Intrinsic::
            nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b4x16_p64;
      }
    }
  } else if (shape.getM() == 16 && shape.getN() == 16) {
    // 16x16 tiles only have transposed variants, and only x1/x2 counts.
    if (eltType == NVVM::LdStMatrixEltType::B8) {
      switch (num) {
      case 1:
        return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8;
      case 2:
        return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8;
      }
    } else if (eltType == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) {
      switch (num) {
      case 1:
        return llvm::Intrinsic::
            nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b6x16_p32;
      case 2:
        return llvm::Intrinsic::
            nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b6x16_p32;
      }
    } else if (eltType == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) {
      switch (num) {
      case 1:
        return llvm::Intrinsic::
            nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b4x16_p64;
      case 2:
        return llvm::Intrinsic::
            nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b4x16_p64;
      }
    }
  }
  // Any other shape/num/eltType combination is presumably rejected by the op
  // verifier before lowering reaches this point.
  llvm_unreachable("unknown ldmatrix kind");
}
213
/// Return the intrinsic ID associated with stmatrix for the given parameters:
/// tile shape (MxN), number of matrices (`num`) and layout (column-major
/// selects the `trans` variants for 8x8 tiles).
static llvm::Intrinsic::ID
getStMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num,
                       NVVM::LdStMatrixShapeAttr shape,
                       NVVM::LdStMatrixEltType eltType) {
  // 8x8 tiles use b16 elements; layout picks plain vs. transposed.
  if (shape.getM() == 8 && shape.getN() == 8) {
    switch (num) {
    case 1:
      return (layout == NVVM::MMALayout::row)
                 ? llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_b16
                 : llvm::Intrinsic::
                       nvvm_stmatrix_sync_aligned_m8n8_x1_trans_b16;
    case 2:
      return (layout == NVVM::MMALayout::row)
                 ? llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_b16
                 : llvm::Intrinsic::
                       nvvm_stmatrix_sync_aligned_m8n8_x2_trans_b16;
    case 4:
      return (layout == NVVM::MMALayout::row)
                 ? llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_b16
                 : llvm::Intrinsic::
                       nvvm_stmatrix_sync_aligned_m8n8_x4_trans_b16;
    }
  } else if (shape.getM() == 16 && shape.getN() == 8) {
    // 16x8 tiles only have transposed b8 variants; layout and eltType are not
    // consulted here.
    switch (num) {
    case 1:
      return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x1_trans_b8;
    case 2:
      return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x2_trans_b8;
    case 4:
      return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x4_trans_b8;
    }
  }
  llvm_unreachable("unknown stmatrix kind");
}
249
250/// Return the intrinsic ID associated with st.bulk for the given address type.
251static llvm::Intrinsic::ID
252getStBulkIntrinsicId(LLVM::LLVMPointerType addrType) {
253 bool isSharedMemory = addrType.getAddressSpace() ==
254 static_cast<unsigned>(NVVM::NVVMMemorySpace::Shared);
255 return isSharedMemory ? llvm::Intrinsic::nvvm_st_bulk_shared_cta
256 : llvm::Intrinsic::nvvm_st_bulk;
257}
258
259static unsigned getUnidirectionalFenceProxyID(NVVM::ProxyKind fromProxy,
260 NVVM::ProxyKind toProxy,
261 NVVM::MemScopeKind scope,
262 bool isRelease) {
263 if (fromProxy == NVVM::ProxyKind::GENERIC &&
264 toProxy == NVVM::ProxyKind::TENSORMAP) {
265 switch (scope) {
266 case NVVM::MemScopeKind::CTA: {
267 if (isRelease)
268 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_cta;
269 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_cta;
270 }
271 case NVVM::MemScopeKind::CLUSTER: {
272 if (isRelease)
273 return llvm::Intrinsic::
274 nvvm_fence_proxy_tensormap_generic_release_cluster;
275 return llvm::Intrinsic::
276 nvvm_fence_proxy_tensormap_generic_acquire_cluster;
277 }
278 case NVVM::MemScopeKind::GPU: {
279 if (isRelease)
280 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_gpu;
281 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_gpu;
282 }
283 case NVVM::MemScopeKind::SYS: {
284 if (isRelease)
285 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_sys;
286 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_sys;
287 }
288 }
289 llvm_unreachable("Unknown scope for uni-directional fence.proxy operation");
290 }
291 llvm_unreachable("Unsupported proxy kinds");
292}
293
/// Returns the memory-barrier intrinsic for the given scope.
static unsigned getMembarIntrinsicID(NVVM::MemScopeKind scope) {
  switch (scope) {
  case NVVM::MemScopeKind::CTA:
    return llvm::Intrinsic::nvvm_membar_cta;
  case NVVM::MemScopeKind::CLUSTER:
    // Cluster scope maps to fence.sc.cluster rather than a membar intrinsic —
    // NOTE(review): presumably because no membar.cluster variant exists;
    // confirm against the NVPTX intrinsic set.
    return llvm::Intrinsic::nvvm_fence_sc_cluster;
  case NVVM::MemScopeKind::GPU:
    return llvm::Intrinsic::nvvm_membar_gl;
  case NVVM::MemScopeKind::SYS:
    return llvm::Intrinsic::nvvm_membar_sys;
  }
  llvm_unreachable("Unknown scope for memory barrier");
}
307
// Composes the llvm.nvvm.tcgen05.ld.<shape>.<num> intrinsic name.
#define TCGEN05LD(SHAPE, NUM) llvm::Intrinsic::nvvm_tcgen05_ld_##SHAPE##_##NUM

/// Returns the tcgen05.ld intrinsic for the given shape and vector length.
/// `num` is assumed to be a power of two (it is used as log2 index into the
/// per-shape tables below) — presumably enforced by the op verifier.
static llvm::Intrinsic::ID
getTcgen05LdIntrinsicID(mlir::NVVM::Tcgen05LdStShape shape, uint32_t num) {
  llvm::Intrinsic::ID Shape16x64b[] = {
      TCGEN05LD(16x64b, x1), TCGEN05LD(16x64b, x2), TCGEN05LD(16x64b, x4),
      TCGEN05LD(16x64b, x8), TCGEN05LD(16x64b, x16), TCGEN05LD(16x64b, x32),
      TCGEN05LD(16x64b, x64), TCGEN05LD(16x64b, x128),
  };

  // 16x128b starts at x2 (no x1 variant), hence the `Idx - 1` below.
  llvm::Intrinsic::ID Shape16x128b[] = {
      TCGEN05LD(16x128b, x1), TCGEN05LD(16x128b, x2), TCGEN05LD(16x128b, x4),
      TCGEN05LD(16x128b, x8), TCGEN05LD(16x128b, x16), TCGEN05LD(16x128b, x32),
      TCGEN05LD(16x128b, x64),
  };

  // 16x256b starts at x4 (no x1/x2 variants), hence the `Idx - 2` below.
  llvm::Intrinsic::ID Shape16x256b[] = {
      TCGEN05LD(16x256b, x1), TCGEN05LD(16x256b, x2), TCGEN05LD(16x256b, x4),
      TCGEN05LD(16x256b, x8), TCGEN05LD(16x256b, x16), TCGEN05LD(16x256b, x32),
  };

  llvm::Intrinsic::ID Shape16x32bx2[] = {
      TCGEN05LD(16x32bx2, x1), TCGEN05LD(16x32bx2, x2),
      TCGEN05LD(16x32bx2, x4), TCGEN05LD(16x32bx2, x8),
      TCGEN05LD(16x32bx2, x16), TCGEN05LD(16x32bx2, x32),
      TCGEN05LD(16x32bx2, x64), TCGEN05LD(16x32bx2, x128),
  };

  llvm::Intrinsic::ID Shape32x32b[] = {
      TCGEN05LD(32x32b, x1), TCGEN05LD(32x32b, x2), TCGEN05LD(32x32b, x4),
      TCGEN05LD(32x32b, x8), TCGEN05LD(32x32b, x16), TCGEN05LD(32x32b, x32),
      TCGEN05LD(32x32b, x64), TCGEN05LD(32x32b, x128),
  };

  // `num` contains the length of vector and log2 of `num` returns the index
  // into the shape array
  unsigned Idx = std::log2(num);

  switch (shape) {
  case NVVM::Tcgen05LdStShape::SHAPE_16X64B:
    return Shape16x64b[Idx];
  case NVVM::Tcgen05LdStShape::SHAPE_16X128B:
    return Shape16x128b[Idx - 1];
  case NVVM::Tcgen05LdStShape::SHAPE_16X256B:
    return Shape16x256b[Idx - 2];
  case NVVM::Tcgen05LdStShape::SHAPE_32X32B:
    return Shape32x32b[Idx];
  case NVVM::Tcgen05LdStShape::SHAPE_16X32BX2:
    return Shape16x32bx2[Idx];
  }
  llvm_unreachable("unhandled tcgen05.ld lowering");
}
360
// Composes the llvm.nvvm.tcgen05.st.<shape>.<num> intrinsic name.
#define TCGEN05ST(SHAPE, NUM) llvm::Intrinsic::nvvm_tcgen05_st_##SHAPE##_##NUM

/// Returns the tcgen05.st intrinsic for the given shape and vector length.
/// Mirrors getTcgen05LdIntrinsicID: `num` is assumed to be a power of two and
/// is used as a log2 index into the per-shape tables.
static llvm::Intrinsic::ID
getTcgen05StIntrinsicID(mlir::NVVM::Tcgen05LdStShape shape, uint32_t num) {
  llvm::Intrinsic::ID Shape16x64b[] = {
      TCGEN05ST(16x64b, x1), TCGEN05ST(16x64b, x2), TCGEN05ST(16x64b, x4),
      TCGEN05ST(16x64b, x8), TCGEN05ST(16x64b, x16), TCGEN05ST(16x64b, x32),
      TCGEN05ST(16x64b, x64), TCGEN05ST(16x64b, x128),
  };

  // 16x128b starts at x2, hence the `Idx - 1` below.
  llvm::Intrinsic::ID Shape16x128b[] = {
      TCGEN05ST(16x128b, x1), TCGEN05ST(16x128b, x2), TCGEN05ST(16x128b, x4),
      TCGEN05ST(16x128b, x8), TCGEN05ST(16x128b, x16), TCGEN05ST(16x128b, x32),
      TCGEN05ST(16x128b, x64),
  };

  // 16x256b starts at x4, hence the `Idx - 2` below.
  llvm::Intrinsic::ID Shape16x256b[] = {
      TCGEN05ST(16x256b, x1), TCGEN05ST(16x256b, x2), TCGEN05ST(16x256b, x4),
      TCGEN05ST(16x256b, x8), TCGEN05ST(16x256b, x16), TCGEN05ST(16x256b, x32),
  };

  llvm::Intrinsic::ID Shape16x32bx2[] = {
      TCGEN05ST(16x32bx2, x1), TCGEN05ST(16x32bx2, x2),
      TCGEN05ST(16x32bx2, x4), TCGEN05ST(16x32bx2, x8),
      TCGEN05ST(16x32bx2, x16), TCGEN05ST(16x32bx2, x32),
      TCGEN05ST(16x32bx2, x64), TCGEN05ST(16x32bx2, x128),
  };

  llvm::Intrinsic::ID Shape32x32b[] = {
      TCGEN05ST(32x32b, x1), TCGEN05ST(32x32b, x2), TCGEN05ST(32x32b, x4),
      TCGEN05ST(32x32b, x8), TCGEN05ST(32x32b, x16), TCGEN05ST(32x32b, x32),
      TCGEN05ST(32x32b, x64), TCGEN05ST(32x32b, x128),
  };

  // `num` contains the length of vector and log2 of `num` returns the index
  // into the shape array
  unsigned Idx = std::log2(num);

  switch (shape) {
  case NVVM::Tcgen05LdStShape::SHAPE_16X64B:
    return Shape16x64b[Idx];
  case NVVM::Tcgen05LdStShape::SHAPE_16X128B:
    return Shape16x128b[Idx - 1];
  case NVVM::Tcgen05LdStShape::SHAPE_16X256B:
    return Shape16x256b[Idx - 2];
  case NVVM::Tcgen05LdStShape::SHAPE_32X32B:
    return Shape32x32b[Idx];
  case NVVM::Tcgen05LdStShape::SHAPE_16X32BX2:
    return Shape16x32bx2[Idx];
  }
  llvm_unreachable("unhandled tcgen05.st lowering");
}
413
414static llvm::Intrinsic::ID getFenceSyncRestrictID(NVVM::MemOrderKind order) {
415 return order == NVVM::MemOrderKind::ACQUIRE
416 ? llvm::Intrinsic::
417 nvvm_fence_acquire_sync_restrict_space_cluster_scope_cluster
418 : llvm::Intrinsic::
419 nvvm_fence_release_sync_restrict_space_cta_scope_cluster;
420}
421
/// Returns the fence.proxy intrinsic for the given proxy kind. `space` is
/// consulted only for the async_shared kind.
static llvm::Intrinsic::ID
getFenceProxyID(NVVM::ProxyKind kind, std::optional<NVVM::SharedSpace> space) {
  switch (kind) {
  case NVVM::ProxyKind::alias:
    return llvm::Intrinsic::nvvm_fence_proxy_alias;
  case NVVM::ProxyKind::async:
    return llvm::Intrinsic::nvvm_fence_proxy_async;
  case NVVM::ProxyKind::async_global:
    return llvm::Intrinsic::nvvm_fence_proxy_async_global;
  case NVVM::ProxyKind::async_shared:
    // NOTE(review): `space` is dereferenced unchecked — presumably the op
    // verifier guarantees it is present for async_shared; confirm.
    return *space == NVVM::SharedSpace::shared_cta
               ? llvm::Intrinsic::nvvm_fence_proxy_async_shared_cta
               : llvm::Intrinsic::nvvm_fence_proxy_async_shared_cluster;
  default:
    llvm_unreachable("unsupported proxy kind");
  }
}
439
440static llvm::Intrinsic::ID
441getFenceProxySyncRestrictID(NVVM::MemOrderKind order) {
442 return order == NVVM::MemOrderKind::ACQUIRE
443 ? llvm::Intrinsic::
444 nvvm_fence_proxy_async_generic_acquire_sync_restrict_space_cluster_scope_cluster
445 : llvm::Intrinsic::
446 nvvm_fence_proxy_async_generic_release_sync_restrict_space_cta_scope_cluster;
447}
448
449// Calls an LLVM intrinsic on the given operands. For f32/f64 vector types,
450// the intrinsic is called per-element and the results are packed back into a
451// vector. If retType is non-null, it is forwarded as the return-type
452// overload to `createIntrinsicCall`.
453static llvm::Value *
454createScalarizedIntrinsicCall(llvm::IRBuilderBase &builder,
455 llvm::Intrinsic::ID IID, llvm::Type *opTypeLLVM,
457 llvm::Type *retType) {
458 if (opTypeLLVM->isVectorTy() && (opTypeLLVM->getScalarType()->isFloatTy() ||
459 opTypeLLVM->getScalarType()->isDoubleTy())) {
460 llvm::Value *result = llvm::PoisonValue::get(
461 llvm::FixedVectorType::get(opTypeLLVM->getScalarType(), 2));
462 for (int64_t i = 0; i < 2; ++i) {
464 for (llvm::Value *op : operands)
465 scalarArgs.push_back(
466 builder.CreateExtractElement(op, builder.getInt32(i)));
467 llvm::Value *res = createIntrinsicCall(builder, IID, retType, scalarArgs);
468 result = builder.CreateInsertElement(result, res, builder.getInt32(i));
469 }
470 return result;
471 }
472
473 return createIntrinsicCall(builder, IID, retType, operands);
474}
475
476void NVVM::AddFOp::lowerAddFToLLVMIR(llvm::Value *argLHS, llvm::Value *argRHS,
477 Value res, NVVM::FPRoundingMode rndMode,
478 NVVM::SaturationMode satMode, bool isFTZ,
480 llvm::IRBuilderBase &builder) {
481 llvm::Type *opTypeLLVM = argLHS->getType();
482 bool isVectorOp = opTypeLLVM->isVectorTy();
483 bool isSat = satMode != NVVM::SaturationMode::NONE;
484
485 // FIXME: Add intrinsics for add.rn.ftz.f16x2 and add.rn.ftz.f16 here when
486 // they are available.
487 static constexpr llvm::Intrinsic::ID f16IDs[] = {
488 llvm::Intrinsic::nvvm_add_rn_sat_f16,
489 llvm::Intrinsic::nvvm_add_rn_ftz_sat_f16,
490 llvm::Intrinsic::nvvm_add_rn_sat_v2f16,
491 llvm::Intrinsic::nvvm_add_rn_ftz_sat_v2f16,
492 };
493
494 static constexpr llvm::Intrinsic::ID f32IDs[] = {
495 llvm::Intrinsic::nvvm_add_rn_f, // default rounding mode RN
496 llvm::Intrinsic::nvvm_add_rn_f,
497 llvm::Intrinsic::nvvm_add_rm_f,
498 llvm::Intrinsic::nvvm_add_rp_f,
499 llvm::Intrinsic::nvvm_add_rz_f,
500 llvm::Intrinsic::nvvm_add_rn_sat_f, // default rounding mode RN
501 llvm::Intrinsic::nvvm_add_rn_sat_f,
502 llvm::Intrinsic::nvvm_add_rm_sat_f,
503 llvm::Intrinsic::nvvm_add_rp_sat_f,
504 llvm::Intrinsic::nvvm_add_rz_sat_f,
505 llvm::Intrinsic::nvvm_add_rn_ftz_f, // default rounding mode RN
506 llvm::Intrinsic::nvvm_add_rn_ftz_f,
507 llvm::Intrinsic::nvvm_add_rm_ftz_f,
508 llvm::Intrinsic::nvvm_add_rp_ftz_f,
509 llvm::Intrinsic::nvvm_add_rz_ftz_f,
510 llvm::Intrinsic::nvvm_add_rn_ftz_sat_f, // default rounding mode RN
511 llvm::Intrinsic::nvvm_add_rn_ftz_sat_f,
512 llvm::Intrinsic::nvvm_add_rm_ftz_sat_f,
513 llvm::Intrinsic::nvvm_add_rp_ftz_sat_f,
514 llvm::Intrinsic::nvvm_add_rz_ftz_sat_f,
515 };
516
517 static constexpr llvm::Intrinsic::ID f64IDs[] = {
518 llvm::Intrinsic::nvvm_add_rn_d, // default rounding mode RN
519 llvm::Intrinsic::nvvm_add_rn_d, llvm::Intrinsic::nvvm_add_rm_d,
520 llvm::Intrinsic::nvvm_add_rp_d, llvm::Intrinsic::nvvm_add_rz_d};
521
522 auto addIntrinsic = [&](llvm::Intrinsic::ID IID) -> llvm::Value * {
523 return createScalarizedIntrinsicCall(builder, IID, opTypeLLVM,
524 {argLHS, argRHS}, opTypeLLVM);
525 };
526
527 // f16 + f16 -> f16 / vector<2xf16> + vector<2xf16> -> vector<2xf16>
528 // FIXME: Allow lowering to add.rn.ftz.f16x2 and add.rn.ftz.f16 here when the
529 // intrinsics are available.
530 if (opTypeLLVM->getScalarType()->isHalfTy()) {
531 llvm::Value *result;
532 if (isSat) {
533 unsigned index = (isVectorOp << 1) | isFTZ;
534 result = addIntrinsic(f16IDs[index]);
535 } else {
536 result = builder.CreateFAdd(argLHS, argRHS);
537 }
538 mt.mapValue(res, result);
539 return;
540 }
541
542 // bf16 + bf16 -> bf16 / vector<2xbf16> + vector<2xbf16> -> vector<2xbf16>
543 if (opTypeLLVM->getScalarType()->isBFloatTy()) {
544 mt.mapValue(res, builder.CreateFAdd(argLHS, argRHS));
545 return;
546 }
547
548 // f64 + f64 -> f64 / vector<2xf64> + vector<2xf64> -> vector<2xf64>
549 if (opTypeLLVM->getScalarType()->isDoubleTy()) {
550 unsigned index = static_cast<unsigned>(rndMode);
551 mt.mapValue(res, addIntrinsic(f64IDs[index]));
552 return;
553 }
554
555 // f32 + f32 -> f32 / vector<2xf32> + vector<2xf32> -> vector<2xf32>
556 const unsigned numRndModes = 5; // NONE, RM, RN, RP, RZ
557 if (opTypeLLVM->getScalarType()->isFloatTy()) {
558 unsigned index =
559 ((isFTZ << 1) | isSat) * numRndModes + static_cast<unsigned>(rndMode);
560 mt.mapValue(res, addIntrinsic(f32IDs[index]));
561 return;
562 }
563}
564
565void NVVM::FmaOp::lowerFmaToLLVMIR(Operation &op, LLVM::ModuleTranslation &mt,
566 llvm::IRBuilderBase &builder) {
567 auto thisOp = cast<NVVM::FmaOp>(op);
568 mlir::NVVM::FPRoundingMode rndMode = thisOp.getRnd();
569 unsigned rndIndex = static_cast<unsigned>(rndMode) - 1; // 1-4 mapped to 0-3
570 mlir::NVVM::SaturationMode satMode = thisOp.getSat();
571 bool isFTZ = thisOp.getFtz();
572 bool isRelu = thisOp.getRelu();
573 bool isSat = satMode == NVVM::SaturationMode::SAT;
574 bool isOOB = thisOp.getOob();
575
576 mlir::Type opType = thisOp.getRes().getType();
577 llvm::Type *opTypeLLVM = mt.convertType(opType);
578 bool isVectorFma = opTypeLLVM->isVectorTy();
579
580 llvm::Value *argA = mt.lookupValue(thisOp.getA());
581 llvm::Value *argB = mt.lookupValue(thisOp.getB());
582 llvm::Value *argC = mt.lookupValue(thisOp.getC());
583
584 static constexpr llvm::Intrinsic::ID f16IDs[] = {
585 llvm::Intrinsic::nvvm_fma_rn_f16,
586 llvm::Intrinsic::nvvm_fma_rn_f16x2,
587 llvm::Intrinsic::nvvm_fma_rn_ftz_f16,
588 llvm::Intrinsic::nvvm_fma_rn_ftz_f16x2,
589 llvm::Intrinsic::nvvm_fma_rn_sat_f16,
590 llvm::Intrinsic::nvvm_fma_rn_sat_f16x2,
591 llvm::Intrinsic::nvvm_fma_rn_ftz_sat_f16,
592 llvm::Intrinsic::nvvm_fma_rn_ftz_sat_f16x2,
593 llvm::Intrinsic::nvvm_fma_rn_relu_f16,
594 llvm::Intrinsic::nvvm_fma_rn_relu_f16x2,
595 llvm::Intrinsic::nvvm_fma_rn_ftz_relu_f16,
596 llvm::Intrinsic::nvvm_fma_rn_ftz_relu_f16x2};
597
598 static constexpr llvm::Intrinsic::ID bf16IDs[] = {
599 llvm::Intrinsic::nvvm_fma_rn_bf16, llvm::Intrinsic::nvvm_fma_rn_bf16x2,
600 llvm::Intrinsic::nvvm_fma_rn_relu_bf16,
601 llvm::Intrinsic::nvvm_fma_rn_relu_bf16x2};
602
603 static constexpr llvm::Intrinsic::ID f32IDs[] = {
604 llvm::Intrinsic::nvvm_fma_rn_f,
605 llvm::Intrinsic::nvvm_fma_rm_f,
606 llvm::Intrinsic::nvvm_fma_rp_f,
607 llvm::Intrinsic::nvvm_fma_rz_f,
608 llvm::Intrinsic::nvvm_fma_rn_sat_f,
609 llvm::Intrinsic::nvvm_fma_rm_sat_f,
610 llvm::Intrinsic::nvvm_fma_rp_sat_f,
611 llvm::Intrinsic::nvvm_fma_rz_sat_f,
612 llvm::Intrinsic::nvvm_fma_rn_ftz_f,
613 llvm::Intrinsic::nvvm_fma_rm_ftz_f,
614 llvm::Intrinsic::nvvm_fma_rp_ftz_f,
615 llvm::Intrinsic::nvvm_fma_rz_ftz_f,
616 llvm::Intrinsic::nvvm_fma_rn_ftz_sat_f,
617 llvm::Intrinsic::nvvm_fma_rm_ftz_sat_f,
618 llvm::Intrinsic::nvvm_fma_rp_ftz_sat_f,
619 llvm::Intrinsic::nvvm_fma_rz_ftz_sat_f,
620 };
621
622 static constexpr llvm::Intrinsic::ID f64IDs[] = {
623 llvm::Intrinsic::nvvm_fma_rn_d, llvm::Intrinsic::nvvm_fma_rm_d,
624 llvm::Intrinsic::nvvm_fma_rp_d, llvm::Intrinsic::nvvm_fma_rz_d};
625
626 auto fmaIntrinsic = [&](llvm::Intrinsic::ID IID,
627 llvm::Type *retType) -> llvm::Value * {
629 builder, IID, opTypeLLVM, {argA, argB, argC}, /*retType=*/retType);
630 };
631
632 // f16 + f16 -> f16 / vector<2xf16> + vector<2xf16> -> vector<2xf16>
633 if (opTypeLLVM->getScalarType()->isHalfTy()) {
634 llvm::Value *result;
635 if (isOOB) {
636 result = fmaIntrinsic(isRelu ? llvm::Intrinsic::nvvm_fma_rn_oob_relu
637 : llvm::Intrinsic::nvvm_fma_rn_oob,
638 opTypeLLVM);
639 } else {
640 unsigned index =
641 (isRelu << 3) | (isSat << 2) | (isFTZ << 1) |
642 isVectorFma; // Op verifier ensures that this index is valid
643 result = fmaIntrinsic(f16IDs[index], opTypeLLVM);
644 }
645 mt.mapValue(thisOp.getRes(), result);
646 return;
647 }
648
649 // bf16 + bf16 -> bf16 / vector<2xbf16> + vector<2xbf16> -> vector<2xbf16>
650 if (opTypeLLVM->getScalarType()->isBFloatTy()) {
651 llvm::Value *result;
652 if (isOOB) {
653 result = fmaIntrinsic(isRelu ? llvm::Intrinsic::nvvm_fma_rn_oob_relu
654 : llvm::Intrinsic::nvvm_fma_rn_oob,
655 opTypeLLVM);
656 } else {
657 unsigned index = (isRelu << 1) | isVectorFma;
658 result = fmaIntrinsic(bf16IDs[index], opTypeLLVM);
659 }
660 mt.mapValue(thisOp.getRes(), result);
661 return;
662 }
663
664 // f64 + f64 -> f64 / vector<2xf64> + vector<2xf64> -> vector<2xf64>
665 if (opTypeLLVM->getScalarType()->isDoubleTy()) {
666 mt.mapValue(thisOp.getRes(),
667 fmaIntrinsic(f64IDs[rndIndex], opTypeLLVM->getScalarType()));
668 return;
669 }
670
671 // f32 + f32 -> f32 / vector<2xf32> + vector<2xf32> -> vector<2xf32>
672 const unsigned numRndModes = 4; // RN, RM, RP, RZ
673 if (opTypeLLVM->getScalarType()->isFloatTy()) {
674 unsigned index = ((isFTZ << 1) | isSat) * numRndModes + rndIndex;
675 mt.mapValue(thisOp.getRes(),
676 fmaIntrinsic(f32IDs[index], opTypeLLVM->getScalarType()));
677 return;
678 }
679}
680
namespace {
/// Implementation of the dialect interface that converts operations belonging
/// to the NVVM dialect to LLVM IR.
class NVVMDialectLLVMIRTranslationInterface
    : public LLVMTranslationDialectInterface {
public:
  using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;

  /// Translates the given operation to LLVM IR using the provided IR builder
  /// and saving the state in `moduleTranslation`.
  LogicalResult
  convertOperation(Operation *op, llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation) const final {
    Operation &opInst = *op;
    // The generated conversions match `opInst` against each NVVM op and
    // return success() from this function on a match; falling through means
    // the op is not an NVVM op we know how to translate.
#include "mlir/Dialect/LLVMIR/NVVMConversions.inc"

    return failure();
  }

  /// Attaches module-level metadata for functions marked as kernels.
  /// Each recognized NVVM attribute on an LLVMFuncOp is rewritten as an
  /// "nvvm.*" LLVM function attribute (or the PTX_Kernel calling convention).
  LogicalResult
  amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
                 NamedAttribute attribute,
                 LLVM::ModuleTranslation &moduleTranslation) const final {
    // Only function-level attributes are handled here.
    auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
    if (!func)
      return failure();
    llvm::Function *llvmFunc = moduleTranslation.lookupFunction(func.getName());

    if (attribute.getName() == NVVM::NVVMDialect::getMaxntidAttrName()) {
      if (!isa<DenseI32ArrayAttr>(attribute.getValue()))
        return failure();
      // Encode the dimensions as a comma-separated list, e.g. "128,1,1".
      auto values = cast<DenseI32ArrayAttr>(attribute.getValue());
      const std::string attr = llvm::formatv(
          "{0:$[,]}", llvm::make_range(values.asArrayRef().begin(),
                                       values.asArrayRef().end()));
      llvmFunc->addFnAttr("nvvm.maxntid", attr);
    } else if (attribute.getName() == NVVM::NVVMDialect::getReqntidAttrName()) {
      if (!isa<DenseI32ArrayAttr>(attribute.getValue()))
        return failure();
      auto values = cast<DenseI32ArrayAttr>(attribute.getValue());
      const std::string attr = llvm::formatv(
          "{0:$[,]}", llvm::make_range(values.asArrayRef().begin(),
                                       values.asArrayRef().end()));
      llvmFunc->addFnAttr("nvvm.reqntid", attr);
    } else if (attribute.getName() ==
               NVVM::NVVMDialect::getClusterDimAttrName()) {
      if (!isa<DenseI32ArrayAttr>(attribute.getValue()))
        return failure();
      auto values = cast<DenseI32ArrayAttr>(attribute.getValue());
      const std::string attr = llvm::formatv(
          "{0:$[,]}", llvm::make_range(values.asArrayRef().begin(),
                                       values.asArrayRef().end()));
      llvmFunc->addFnAttr("nvvm.cluster_dim", attr);
    } else if (attribute.getName() ==
               NVVM::NVVMDialect::getClusterMaxBlocksAttrName()) {
      // NOTE(review): the dyn_cast results below are used without a null
      // check — presumably the dialect's attribute verifier guarantees the
      // IntegerAttr type; confirm.
      auto value = dyn_cast<IntegerAttr>(attribute.getValue());
      llvmFunc->addFnAttr("nvvm.maxclusterrank", llvm::utostr(value.getInt()));
    } else if (attribute.getName() ==
               NVVM::NVVMDialect::getMinctasmAttrName()) {
      auto value = dyn_cast<IntegerAttr>(attribute.getValue());
      llvmFunc->addFnAttr("nvvm.minctasm", llvm::utostr(value.getInt()));
    } else if (attribute.getName() == NVVM::NVVMDialect::getMaxnregAttrName()) {
      auto value = dyn_cast<IntegerAttr>(attribute.getValue());
      llvmFunc->addFnAttr("nvvm.maxnreg", llvm::utostr(value.getInt()));
    } else if (attribute.getName() ==
               NVVM::NVVMDialect::getKernelFuncAttrName()) {
      // Kernels are marked via the PTX_Kernel calling convention.
      llvmFunc->setCallingConv(llvm::CallingConv::PTX_Kernel);
    } else if (attribute.getName() ==
               NVVM::NVVMDialect::getBlocksAreClustersAttrName()) {
      llvmFunc->addFnAttr("nvvm.blocksareclusters");
    }

    return success();
  }

  /// Translates parameter attributes; currently only the grid_constant
  /// marker is forwarded as an LLVM parameter attribute.
  LogicalResult
  convertParameterAttr(LLVMFuncOp funcOp, int argIdx, NamedAttribute attribute,
                       LLVM::ModuleTranslation &moduleTranslation) const final {

    llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
    llvm::Function *llvmFunc =
        moduleTranslation.lookupFunction(funcOp.getName());

    if (attribute.getName() == NVVM::NVVMDialect::getGridConstantAttrName()) {
      llvmFunc->addParamAttr(
          argIdx, llvm::Attribute::get(llvmContext, "nvvm.grid_constant"));
    }
    return success();
  }
};
} // namespace
773
775 registry.insert<NVVM::NVVMDialect>();
776 registry.addExtension(+[](MLIRContext *ctx, NVVM::NVVMDialect *dialect) {
777 dialect->addInterfaces<NVVMDialectLLVMIRTranslationInterface>();
778 });
779}
780
782 DialectRegistry registry;
784 context.appendDialectRegistry(registry);
785}
return success()
static LogicalResult convertParameterAttr(llvm::AttrBuilder &attrBuilder, llvm::Attribute::AttrKind llvmKind, NamedAttribute namedAttr, ModuleTranslation &moduleTranslation, Location loc)
static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num, NVVM::LdStMatrixShapeAttr shape, NVVM::LdStMatrixEltType eltType)
static llvm::Intrinsic::ID getFenceProxyID(NVVM::ProxyKind kind, std::optional< NVVM::SharedSpace > space)
#define GET_REDUX_F32_ID(op, hasAbs, hasNaN)
static llvm::Intrinsic::ID getStMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num, NVVM::LdStMatrixShapeAttr shape, NVVM::LdStMatrixEltType eltType)
Return the intrinsic ID associated with stmatrix for the given parameters.
static llvm::Intrinsic::ID getTcgen05StIntrinsicID(mlir::NVVM::Tcgen05LdStShape shape, uint32_t num)
static llvm::Intrinsic::ID getTcgen05LdIntrinsicID(mlir::NVVM::Tcgen05LdStShape shape, uint32_t num)
static unsigned getMembarIntrinsicID(NVVM::MemScopeKind scope)
static unsigned getUnidirectionalFenceProxyID(NVVM::ProxyKind fromProxy, NVVM::ProxyKind toProxy, NVVM::MemScopeKind scope, bool isRelease)
llvm::CallInst * createIntrinsicCall(llvm::IRBuilderBase &builder, llvm::Intrinsic::ID intrinsic, ArrayRef< llvm::Value * > args={}, ArrayRef< llvm::Type * > tys={})
Creates a call to an LLVM IR intrinsic function with the given arguments.
static llvm::Intrinsic::ID getFenceProxySyncRestrictID(NVVM::MemOrderKind order)
#define TCGEN05ST(SHAPE, NUM)
static llvm::Intrinsic::ID getReduxIntrinsicId(llvm::Type *resultType, NVVM::ReductionKind kind, bool hasAbs, bool hasNaN)
static llvm::Value * createScalarizedIntrinsicCall(llvm::IRBuilderBase &builder, llvm::Intrinsic::ID IID, llvm::Type *opTypeLLVM, ArrayRef< llvm::Value * > operands, llvm::Type *retType)
#define TCGEN05LD(SHAPE, NUM)
static llvm::Intrinsic::ID getFenceSyncRestrictID(NVVM::MemOrderKind order)
static llvm::Intrinsic::ID getShflIntrinsicId(llvm::Type *resultType, NVVM::ShflKind kind, bool withPredicate)
static llvm::Intrinsic::ID getVoteSyncIntrinsicId(NVVM::VoteSyncKind kind)
static llvm::Intrinsic::ID getMatchSyncIntrinsicId(Type valType, NVVM::MatchSyncKind kind)
static llvm::Intrinsic::ID getStBulkIntrinsicId(LLVM::LLVMPointerType addrType)
Return the intrinsic ID associated with st.bulk for the given address type.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static bool isSharedMemory(MemRefType type)
Return true if this is a shared memory memref type.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Implementation class for module translation.
llvm::Value * lookupValue(Value value) const
Finds an LLVM IR value corresponding to the given MLIR value.
llvm::Type * convertType(Type type)
Converts the type from MLIR LLVM dialect to LLVM.
void mapValue(Value mlir, llvm::Value *llvm)
Stores the mapping between an MLIR value and its LLVM IR counterpart.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
void appendDialectRegistry(const DialectRegistry &registry)
Append the contents of the given dialect registry to the registry associated with this context.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
llvm::CallInst * createIntrinsicCall(llvm::IRBuilderBase &builder, llvm::Intrinsic::ID intrinsic, ArrayRef< llvm::Value * > args={}, ArrayRef< llvm::Type * > tys={})
Creates a call to an LLVM IR intrinsic function with the given arguments.
Include the generated interface declarations.
void registerNVVMDialectTranslation(DialectRegistry &registry)
Register the NVVM dialect and the translation from it to the LLVM IR in the given registry;.