MLIR 23.0.0git
NVVMToLLVMIRTranslation.cpp
Go to the documentation of this file.
1//===- NVVMToLLVMIRTranslation.cpp - Translate NVVM to LLVM IR ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a translation between the MLIR NVVM dialect and
10// LLVM IR.
11//
12//===----------------------------------------------------------------------===//
13
16#include "mlir/IR/Operation.h"
18
19#include "llvm/ADT/StringExtras.h"
20#include "llvm/ADT/iterator_range.h"
21#include "llvm/IR/IRBuilder.h"
22#include "llvm/IR/IntrinsicsNVPTX.h"
23#include "llvm/Support/FormatVariadic.h"
24#include "llvm/Support/NVVMAttributes.h"
25
26using namespace mlir;
27using namespace mlir::LLVM;
29
// Helper macros that build the f32 redux.sync intrinsic name for min/max via
// token pasting, selecting the `.abs` and NaN-propagating variants.
#define REDUX_F32_ID_IMPL(op, abs, hasNaN)                                     \
  hasNaN ? llvm::Intrinsic::nvvm_redux_sync_f##op##abs##_NaN                   \
         : llvm::Intrinsic::nvvm_redux_sync_f##op##abs

#define GET_REDUX_F32_ID(op, hasAbs, hasNaN)                                   \
  hasAbs ? REDUX_F32_ID_IMPL(op, _abs, hasNaN) : REDUX_F32_ID_IMPL(op, , hasNaN)

/// Returns the redux.sync intrinsic corresponding to the reduction `kind`.
/// For the floating-point kinds (FMIN/FMAX), `hasAbs` and `hasNaN` select the
/// `.abs` and `.NaN` intrinsic variants. `resultType` is unused here; it is
/// part of the signature expected by the generated conversion code.
static llvm::Intrinsic::ID getReduxIntrinsicId(llvm::Type *resultType,
                                               NVVM::ReductionKind kind,
                                               bool hasAbs, bool hasNaN) {
  switch (kind) {
  case NVVM::ReductionKind::ADD:
    return llvm::Intrinsic::nvvm_redux_sync_add;
  case NVVM::ReductionKind::UMAX:
    return llvm::Intrinsic::nvvm_redux_sync_umax;
  case NVVM::ReductionKind::UMIN:
    return llvm::Intrinsic::nvvm_redux_sync_umin;
  case NVVM::ReductionKind::AND:
    return llvm::Intrinsic::nvvm_redux_sync_and;
  case NVVM::ReductionKind::OR:
    return llvm::Intrinsic::nvvm_redux_sync_or;
  case NVVM::ReductionKind::XOR:
    return llvm::Intrinsic::nvvm_redux_sync_xor;
  case NVVM::ReductionKind::MAX:
    return llvm::Intrinsic::nvvm_redux_sync_max;
  case NVVM::ReductionKind::MIN:
    return llvm::Intrinsic::nvvm_redux_sync_min;
  case NVVM::ReductionKind::FMIN:
    return GET_REDUX_F32_ID(min, hasAbs, hasNaN);
  case NVVM::ReductionKind::FMAX:
    return GET_REDUX_F32_ID(max, hasAbs, hasNaN);
  }
  llvm_unreachable("unknown reduction kind");
}
64
65static llvm::Intrinsic::ID getShflIntrinsicId(llvm::Type *resultType,
66 NVVM::ShflKind kind,
67 bool withPredicate) {
68
69 if (withPredicate) {
70 resultType = cast<llvm::StructType>(resultType)->getElementType(0);
71 switch (kind) {
72 case NVVM::ShflKind::bfly:
73 return resultType->isFloatTy()
74 ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32p
75 : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32p;
76 case NVVM::ShflKind::up:
77 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32p
78 : llvm::Intrinsic::nvvm_shfl_sync_up_i32p;
79 case NVVM::ShflKind::down:
80 return resultType->isFloatTy()
81 ? llvm::Intrinsic::nvvm_shfl_sync_down_f32p
82 : llvm::Intrinsic::nvvm_shfl_sync_down_i32p;
83 case NVVM::ShflKind::idx:
84 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32p
85 : llvm::Intrinsic::nvvm_shfl_sync_idx_i32p;
86 }
87 } else {
88 switch (kind) {
89 case NVVM::ShflKind::bfly:
90 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32
91 : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32;
92 case NVVM::ShflKind::up:
93 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32
94 : llvm::Intrinsic::nvvm_shfl_sync_up_i32;
95 case NVVM::ShflKind::down:
96 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_down_f32
97 : llvm::Intrinsic::nvvm_shfl_sync_down_i32;
98 case NVVM::ShflKind::idx:
99 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32
100 : llvm::Intrinsic::nvvm_shfl_sync_idx_i32;
101 }
102 }
103 llvm_unreachable("unknown shuffle kind");
104}
105
106static llvm::Intrinsic::ID getMatchSyncIntrinsicId(Type valType,
107 NVVM::MatchSyncKind kind) {
108 switch (kind) {
109 case NVVM::MatchSyncKind::any:
110 return valType.isInteger(32) ? llvm::Intrinsic::nvvm_match_any_sync_i32
111 : llvm::Intrinsic::nvvm_match_any_sync_i64;
112 case NVVM::MatchSyncKind::all:
113 // match.all instruction has two variants -- one returns a single value,
114 // another returns a pair {value, predicate}. We currently only implement
115 // the latter as that's the variant exposed by CUDA API.
116 return valType.isInteger(32) ? llvm::Intrinsic::nvvm_match_all_sync_i32p
117 : llvm::Intrinsic::nvvm_match_all_sync_i64p;
118 }
119 llvm_unreachable("unsupported match sync kind");
120}
121
122static llvm::Intrinsic::ID getVoteSyncIntrinsicId(NVVM::VoteSyncKind kind) {
123 switch (kind) {
124 case NVVM::VoteSyncKind::any:
125 return llvm::Intrinsic::nvvm_vote_any_sync;
126 case NVVM::VoteSyncKind::all:
127 return llvm::Intrinsic::nvvm_vote_all_sync;
128 case NVVM::VoteSyncKind::ballot:
129 return llvm::Intrinsic::nvvm_vote_ballot_sync;
130 case NVVM::VoteSyncKind::uni:
131 return llvm::Intrinsic::nvvm_vote_uni_sync;
132 }
133 llvm_unreachable("unsupported vote kind");
134}
135
/// Returns the ldmatrix intrinsic for the given `layout`, number of matrices
/// `num`, tile `shape`, and element type. Combinations not covered below are
/// rejected earlier (falling through here is a compiler bug).
static llvm::Intrinsic::ID
getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num,
                       NVVM::LdStMatrixShapeAttr shape,
                       NVVM::LdStMatrixEltType eltType) {
  // 8x8 tiles use b16 elements; column-major selects the `trans` variants.
  if (shape.getM() == 8 && shape.getN() == 8) {
    switch (num) {
    case 1:
      return (layout == NVVM::MMALayout::row)
                 ? llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16
                 : llvm::Intrinsic::
                       nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16;
    case 2:
      return (layout == NVVM::MMALayout::row)
                 ? llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16
                 : llvm::Intrinsic::
                       nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16;
    case 4:
      return (layout == NVVM::MMALayout::row)
                 ? llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16
                 : llvm::Intrinsic::
                       nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16;
    }
  } else if (shape.getM() == 8 && shape.getN() == 16) {
    // 8x16 tiles are only dispatched on the packed sub-byte element types;
    // `layout` plays no role for these variants.
    if (eltType == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) {
      switch (num) {
      case 1:
        return llvm::Intrinsic::
            nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b6x16_p32;
      case 2:
        return llvm::Intrinsic::
            nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b6x16_p32;
      case 4:
        return llvm::Intrinsic::
            nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b6x16_p32;
      }
    } else if (eltType == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) {
      switch (num) {
      case 1:
        return llvm::Intrinsic::
            nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b4x16_p64;
      case 2:
        return llvm::Intrinsic::
            nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b4x16_p64;
      case 4:
        return llvm::Intrinsic::
            nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b4x16_p64;
      }
    }
  } else if (shape.getM() == 16 && shape.getN() == 16) {
    // 16x16 tiles only exist in transposed form; note x4 is not available
    // for these element types.
    if (eltType == NVVM::LdStMatrixEltType::B8) {
      switch (num) {
      case 1:
        return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8;
      case 2:
        return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8;
      }
    } else if (eltType == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) {
      switch (num) {
      case 1:
        return llvm::Intrinsic::
            nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b6x16_p32;
      case 2:
        return llvm::Intrinsic::
            nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b6x16_p32;
      }
    } else if (eltType == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) {
      switch (num) {
      case 1:
        return llvm::Intrinsic::
            nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b4x16_p64;
      case 2:
        return llvm::Intrinsic::
            nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b4x16_p64;
      }
    }
  }
  llvm_unreachable("unknown ldmatrix kind");
}
214
/// Return the intrinsic ID associated with stmatrix for the given parameters.
/// 8x8 tiles use b16 elements with `trans` variants for column-major; 16x8
/// tiles exist only in transposed b8 form (`layout`/`eltType` are not
/// consulted on that path).
static llvm::Intrinsic::ID
getStMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num,
                       NVVM::LdStMatrixShapeAttr shape,
                       NVVM::LdStMatrixEltType eltType) {
  if (shape.getM() == 8 && shape.getN() == 8) {
    switch (num) {
    case 1:
      return (layout == NVVM::MMALayout::row)
                 ? llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_b16
                 : llvm::Intrinsic::
                       nvvm_stmatrix_sync_aligned_m8n8_x1_trans_b16;
    case 2:
      return (layout == NVVM::MMALayout::row)
                 ? llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_b16
                 : llvm::Intrinsic::
                       nvvm_stmatrix_sync_aligned_m8n8_x2_trans_b16;
    case 4:
      return (layout == NVVM::MMALayout::row)
                 ? llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_b16
                 : llvm::Intrinsic::
                       nvvm_stmatrix_sync_aligned_m8n8_x4_trans_b16;
    }
  } else if (shape.getM() == 16 && shape.getN() == 8) {
    switch (num) {
    case 1:
      return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x1_trans_b8;
    case 2:
      return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x2_trans_b8;
    case 4:
      return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x4_trans_b8;
    }
  }
  llvm_unreachable("unknown stmatrix kind");
}
250
251/// Return the intrinsic ID associated with st.bulk for the given address type.
252static llvm::Intrinsic::ID
253getStBulkIntrinsicId(LLVM::LLVMPointerType addrType) {
254 bool isSharedMemory = addrType.getAddressSpace() ==
255 static_cast<unsigned>(NVVM::NVVMMemorySpace::Shared);
256 return isSharedMemory ? llvm::Intrinsic::nvvm_st_bulk_shared_cta
257 : llvm::Intrinsic::nvvm_st_bulk;
258}
259
/// Returns the uni-directional fence.proxy intrinsic for the given proxy
/// pair, memory `scope`, and direction (`isRelease` selects the release
/// variant, otherwise acquire). Only the generic->tensormap proxy pair is
/// supported here.
static unsigned getUnidirectionalFenceProxyID(NVVM::ProxyKind fromProxy,
                                              NVVM::ProxyKind toProxy,
                                              NVVM::MemScopeKind scope,
                                              bool isRelease) {
  if (fromProxy == NVVM::ProxyKind::GENERIC &&
      toProxy == NVVM::ProxyKind::TENSORMAP) {
    switch (scope) {
    case NVVM::MemScopeKind::CTA: {
      if (isRelease)
        return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_cta;
      return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_cta;
    }
    case NVVM::MemScopeKind::CLUSTER: {
      if (isRelease)
        return llvm::Intrinsic::
            nvvm_fence_proxy_tensormap_generic_release_cluster;
      return llvm::Intrinsic::
          nvvm_fence_proxy_tensormap_generic_acquire_cluster;
    }
    case NVVM::MemScopeKind::GPU: {
      if (isRelease)
        return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_gpu;
      return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_gpu;
    }
    case NVVM::MemScopeKind::SYS: {
      if (isRelease)
        return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_sys;
      return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_sys;
    }
    }
    llvm_unreachable("Unknown scope for uni-directional fence.proxy operation");
  }
  llvm_unreachable("Unsupported proxy kinds");
}
294
/// Returns the memory-barrier intrinsic for the given `scope`.
static unsigned getMembarIntrinsicID(NVVM::MemScopeKind scope) {
  switch (scope) {
  case NVVM::MemScopeKind::CTA:
    return llvm::Intrinsic::nvvm_membar_cta;
  case NVVM::MemScopeKind::CLUSTER:
    // NOTE(review): cluster scope maps to fence.sc.cluster rather than a
    // membar intrinsic -- presumably because no membar.cluster exists;
    // confirm against the NVVM IR spec.
    return llvm::Intrinsic::nvvm_fence_sc_cluster;
  case NVVM::MemScopeKind::GPU:
    return llvm::Intrinsic::nvvm_membar_gl;
  case NVVM::MemScopeKind::SYS:
    return llvm::Intrinsic::nvvm_membar_sys;
  }
  llvm_unreachable("Unknown scope for memory barrier");
}
308
309#define TCGEN05LD(SHAPE, NUM) llvm::Intrinsic::nvvm_tcgen05_ld_##SHAPE##_##NUM
310
311static llvm::Intrinsic::ID
312getTcgen05LdIntrinsicID(mlir::NVVM::Tcgen05LdStShape shape, uint32_t num) {
313 llvm::Intrinsic::ID Shape16x64b[] = {
314 TCGEN05LD(16x64b, x1), TCGEN05LD(16x64b, x2), TCGEN05LD(16x64b, x4),
315 TCGEN05LD(16x64b, x8), TCGEN05LD(16x64b, x16), TCGEN05LD(16x64b, x32),
316 TCGEN05LD(16x64b, x64), TCGEN05LD(16x64b, x128),
317 };
318
319 llvm::Intrinsic::ID Shape16x128b[] = {
320 TCGEN05LD(16x128b, x1), TCGEN05LD(16x128b, x2), TCGEN05LD(16x128b, x4),
321 TCGEN05LD(16x128b, x8), TCGEN05LD(16x128b, x16), TCGEN05LD(16x128b, x32),
322 TCGEN05LD(16x128b, x64),
323 };
324
325 llvm::Intrinsic::ID Shape16x256b[] = {
326 TCGEN05LD(16x256b, x1), TCGEN05LD(16x256b, x2), TCGEN05LD(16x256b, x4),
327 TCGEN05LD(16x256b, x8), TCGEN05LD(16x256b, x16), TCGEN05LD(16x256b, x32),
328 };
329
330 llvm::Intrinsic::ID Shape16x32bx2[] = {
331 TCGEN05LD(16x32bx2, x1), TCGEN05LD(16x32bx2, x2),
332 TCGEN05LD(16x32bx2, x4), TCGEN05LD(16x32bx2, x8),
333 TCGEN05LD(16x32bx2, x16), TCGEN05LD(16x32bx2, x32),
334 TCGEN05LD(16x32bx2, x64), TCGEN05LD(16x32bx2, x128),
335 };
336
337 llvm::Intrinsic::ID Shape32x32b[] = {
338 TCGEN05LD(32x32b, x1), TCGEN05LD(32x32b, x2), TCGEN05LD(32x32b, x4),
339 TCGEN05LD(32x32b, x8), TCGEN05LD(32x32b, x16), TCGEN05LD(32x32b, x32),
340 TCGEN05LD(32x32b, x64), TCGEN05LD(32x32b, x128),
341 };
342
343 // `num` contains the length of vector and log2 of `num` returns the index
344 // into the shape array
345 unsigned Idx = std::log2(num);
346
347 switch (shape) {
348 case NVVM::Tcgen05LdStShape::SHAPE_16X64B:
349 return Shape16x64b[Idx];
350 case NVVM::Tcgen05LdStShape::SHAPE_16X128B:
351 return Shape16x128b[Idx - 1];
352 case NVVM::Tcgen05LdStShape::SHAPE_16X256B:
353 return Shape16x256b[Idx - 2];
354 case NVVM::Tcgen05LdStShape::SHAPE_32X32B:
355 return Shape32x32b[Idx];
356 case NVVM::Tcgen05LdStShape::SHAPE_16X32BX2:
357 return Shape16x32bx2[Idx];
358 }
359 llvm_unreachable("unhandled tcgen05.ld lowering");
360}
361
362#define TCGEN05ST(SHAPE, NUM) llvm::Intrinsic::nvvm_tcgen05_st_##SHAPE##_##NUM
363
364static llvm::Intrinsic::ID
365getTcgen05StIntrinsicID(mlir::NVVM::Tcgen05LdStShape shape, uint32_t num) {
366 llvm::Intrinsic::ID Shape16x64b[] = {
367 TCGEN05ST(16x64b, x1), TCGEN05ST(16x64b, x2), TCGEN05ST(16x64b, x4),
368 TCGEN05ST(16x64b, x8), TCGEN05ST(16x64b, x16), TCGEN05ST(16x64b, x32),
369 TCGEN05ST(16x64b, x64), TCGEN05ST(16x64b, x128),
370 };
371
372 llvm::Intrinsic::ID Shape16x128b[] = {
373 TCGEN05ST(16x128b, x1), TCGEN05ST(16x128b, x2), TCGEN05ST(16x128b, x4),
374 TCGEN05ST(16x128b, x8), TCGEN05ST(16x128b, x16), TCGEN05ST(16x128b, x32),
375 TCGEN05ST(16x128b, x64),
376 };
377
378 llvm::Intrinsic::ID Shape16x256b[] = {
379 TCGEN05ST(16x256b, x1), TCGEN05ST(16x256b, x2), TCGEN05ST(16x256b, x4),
380 TCGEN05ST(16x256b, x8), TCGEN05ST(16x256b, x16), TCGEN05ST(16x256b, x32),
381 };
382
383 llvm::Intrinsic::ID Shape16x32bx2[] = {
384 TCGEN05ST(16x32bx2, x1), TCGEN05ST(16x32bx2, x2),
385 TCGEN05ST(16x32bx2, x4), TCGEN05ST(16x32bx2, x8),
386 TCGEN05ST(16x32bx2, x16), TCGEN05ST(16x32bx2, x32),
387 TCGEN05ST(16x32bx2, x64), TCGEN05ST(16x32bx2, x128),
388 };
389
390 llvm::Intrinsic::ID Shape32x32b[] = {
391 TCGEN05ST(32x32b, x1), TCGEN05ST(32x32b, x2), TCGEN05ST(32x32b, x4),
392 TCGEN05ST(32x32b, x8), TCGEN05ST(32x32b, x16), TCGEN05ST(32x32b, x32),
393 TCGEN05ST(32x32b, x64), TCGEN05ST(32x32b, x128),
394 };
395
396 // `num` contains the length of vector and log2 of `num` returns the index
397 // into the shape array
398 unsigned Idx = std::log2(num);
399
400 switch (shape) {
401 case NVVM::Tcgen05LdStShape::SHAPE_16X64B:
402 return Shape16x64b[Idx];
403 case NVVM::Tcgen05LdStShape::SHAPE_16X128B:
404 return Shape16x128b[Idx - 1];
405 case NVVM::Tcgen05LdStShape::SHAPE_16X256B:
406 return Shape16x256b[Idx - 2];
407 case NVVM::Tcgen05LdStShape::SHAPE_32X32B:
408 return Shape32x32b[Idx];
409 case NVVM::Tcgen05LdStShape::SHAPE_16X32BX2:
410 return Shape16x32bx2[Idx];
411 }
412 llvm_unreachable("unhandled tcgen05.st lowering");
413}
414
415static llvm::Intrinsic::ID getFenceSyncRestrictID(NVVM::MemOrderKind order) {
416 return order == NVVM::MemOrderKind::ACQUIRE
417 ? llvm::Intrinsic::
418 nvvm_fence_acquire_sync_restrict_space_cluster_scope_cluster
419 : llvm::Intrinsic::
420 nvvm_fence_release_sync_restrict_space_cta_scope_cluster;
421}
422
423static llvm::Intrinsic::ID
424getFenceProxyID(NVVM::ProxyKind kind, std::optional<NVVM::SharedSpace> space) {
425 switch (kind) {
426 case NVVM::ProxyKind::alias:
427 return llvm::Intrinsic::nvvm_fence_proxy_alias;
428 case NVVM::ProxyKind::async:
429 return llvm::Intrinsic::nvvm_fence_proxy_async;
430 case NVVM::ProxyKind::async_global:
431 return llvm::Intrinsic::nvvm_fence_proxy_async_global;
432 case NVVM::ProxyKind::async_shared:
433 return *space == NVVM::SharedSpace::shared_cta
434 ? llvm::Intrinsic::nvvm_fence_proxy_async_shared_cta
435 : llvm::Intrinsic::nvvm_fence_proxy_async_shared_cluster;
436 default:
437 llvm_unreachable("unsupported proxy kind");
438 }
439}
440
441static llvm::Intrinsic::ID
442getFenceProxySyncRestrictID(NVVM::MemOrderKind order) {
443 return order == NVVM::MemOrderKind::ACQUIRE
444 ? llvm::Intrinsic::
445 nvvm_fence_proxy_async_generic_acquire_sync_restrict_space_cluster_scope_cluster
446 : llvm::Intrinsic::
447 nvvm_fence_proxy_async_generic_release_sync_restrict_space_cta_scope_cluster;
448}
449
450// Calls an LLVM intrinsic on the given operands. For f32/f64 vector types,
451// the intrinsic is called per-element and the results are packed back into a
452// vector. If retType is non-null, it is forwarded as the return-type
453// overload to `createIntrinsicCall`.
454static llvm::Value *
455createScalarizedIntrinsicCall(llvm::IRBuilderBase &builder,
456 llvm::Intrinsic::ID IID, llvm::Type *opTypeLLVM,
458 llvm::Type *retType) {
459 if (opTypeLLVM->isVectorTy() && (opTypeLLVM->getScalarType()->isFloatTy() ||
460 opTypeLLVM->getScalarType()->isDoubleTy())) {
461 llvm::Value *result = llvm::PoisonValue::get(
462 llvm::FixedVectorType::get(opTypeLLVM->getScalarType(), 2));
463 for (int64_t i = 0; i < 2; ++i) {
465 for (llvm::Value *op : operands)
466 scalarArgs.push_back(
467 builder.CreateExtractElement(op, builder.getInt32(i)));
468 llvm::Value *res = createIntrinsicCall(builder, IID, retType, scalarArgs);
469 result = builder.CreateInsertElement(result, res, builder.getInt32(i));
470 }
471 return result;
472 }
473
474 return createIntrinsicCall(builder, IID, retType, operands);
475}
476
477void NVVM::AddFOp::lowerAddFToLLVMIR(llvm::Value *argLHS, llvm::Value *argRHS,
478 Value res, NVVM::FPRoundingMode rndMode,
479 NVVM::SaturationMode satMode, bool isFTZ,
481 llvm::IRBuilderBase &builder) {
482 llvm::Type *opTypeLLVM = argLHS->getType();
483 bool isVectorOp = opTypeLLVM->isVectorTy();
484 bool isSat = satMode != NVVM::SaturationMode::NONE;
485
486 // FIXME: Add intrinsics for add.rn.ftz.f16x2 and add.rn.ftz.f16 here when
487 // they are available.
488 static constexpr llvm::Intrinsic::ID f16IDs[] = {
489 llvm::Intrinsic::nvvm_add_rn_sat_f16,
490 llvm::Intrinsic::nvvm_add_rn_ftz_sat_f16,
491 llvm::Intrinsic::nvvm_add_rn_sat_v2f16,
492 llvm::Intrinsic::nvvm_add_rn_ftz_sat_v2f16,
493 };
494
495 static constexpr llvm::Intrinsic::ID f32IDs[] = {
496 llvm::Intrinsic::nvvm_add_rn_f, // default rounding mode RN
497 llvm::Intrinsic::nvvm_add_rn_f,
498 llvm::Intrinsic::nvvm_add_rm_f,
499 llvm::Intrinsic::nvvm_add_rp_f,
500 llvm::Intrinsic::nvvm_add_rz_f,
501 llvm::Intrinsic::nvvm_add_rn_sat_f, // default rounding mode RN
502 llvm::Intrinsic::nvvm_add_rn_sat_f,
503 llvm::Intrinsic::nvvm_add_rm_sat_f,
504 llvm::Intrinsic::nvvm_add_rp_sat_f,
505 llvm::Intrinsic::nvvm_add_rz_sat_f,
506 llvm::Intrinsic::nvvm_add_rn_ftz_f, // default rounding mode RN
507 llvm::Intrinsic::nvvm_add_rn_ftz_f,
508 llvm::Intrinsic::nvvm_add_rm_ftz_f,
509 llvm::Intrinsic::nvvm_add_rp_ftz_f,
510 llvm::Intrinsic::nvvm_add_rz_ftz_f,
511 llvm::Intrinsic::nvvm_add_rn_ftz_sat_f, // default rounding mode RN
512 llvm::Intrinsic::nvvm_add_rn_ftz_sat_f,
513 llvm::Intrinsic::nvvm_add_rm_ftz_sat_f,
514 llvm::Intrinsic::nvvm_add_rp_ftz_sat_f,
515 llvm::Intrinsic::nvvm_add_rz_ftz_sat_f,
516 };
517
518 static constexpr llvm::Intrinsic::ID f64IDs[] = {
519 llvm::Intrinsic::nvvm_add_rn_d, // default rounding mode RN
520 llvm::Intrinsic::nvvm_add_rn_d, llvm::Intrinsic::nvvm_add_rm_d,
521 llvm::Intrinsic::nvvm_add_rp_d, llvm::Intrinsic::nvvm_add_rz_d};
522
523 auto addIntrinsic = [&](llvm::Intrinsic::ID IID) -> llvm::Value * {
524 return createScalarizedIntrinsicCall(builder, IID, opTypeLLVM,
525 {argLHS, argRHS}, opTypeLLVM);
526 };
527
528 // f16 + f16 -> f16 / vector<2xf16> + vector<2xf16> -> vector<2xf16>
529 // FIXME: Allow lowering to add.rn.ftz.f16x2 and add.rn.ftz.f16 here when the
530 // intrinsics are available.
531 if (opTypeLLVM->getScalarType()->isHalfTy()) {
532 llvm::Value *result;
533 if (isSat) {
534 unsigned index = (isVectorOp << 1) | isFTZ;
535 result = addIntrinsic(f16IDs[index]);
536 } else {
537 result = builder.CreateFAdd(argLHS, argRHS);
538 }
539 mt.mapValue(res, result);
540 return;
541 }
542
543 // bf16 + bf16 -> bf16 / vector<2xbf16> + vector<2xbf16> -> vector<2xbf16>
544 if (opTypeLLVM->getScalarType()->isBFloatTy()) {
545 mt.mapValue(res, builder.CreateFAdd(argLHS, argRHS));
546 return;
547 }
548
549 // f64 + f64 -> f64 / vector<2xf64> + vector<2xf64> -> vector<2xf64>
550 if (opTypeLLVM->getScalarType()->isDoubleTy()) {
551 unsigned index = static_cast<unsigned>(rndMode);
552 mt.mapValue(res, addIntrinsic(f64IDs[index]));
553 return;
554 }
555
556 // f32 + f32 -> f32 / vector<2xf32> + vector<2xf32> -> vector<2xf32>
557 const unsigned numRndModes = 5; // NONE, RM, RN, RP, RZ
558 if (opTypeLLVM->getScalarType()->isFloatTy()) {
559 unsigned index =
560 ((isFTZ << 1) | isSat) * numRndModes + static_cast<unsigned>(rndMode);
561 mt.mapValue(res, addIntrinsic(f32IDs[index]));
562 return;
563 }
564}
565
566void NVVM::FmaOp::lowerFmaToLLVMIR(Operation &op, LLVM::ModuleTranslation &mt,
567 llvm::IRBuilderBase &builder) {
568 auto thisOp = cast<NVVM::FmaOp>(op);
569 mlir::NVVM::FPRoundingMode rndMode = thisOp.getRnd();
570 unsigned rndIndex = static_cast<unsigned>(rndMode) - 1; // 1-4 mapped to 0-3
571 mlir::NVVM::SaturationMode satMode = thisOp.getSat();
572 bool isFTZ = thisOp.getFtz();
573 bool isRelu = thisOp.getRelu();
574 bool isSat = satMode == NVVM::SaturationMode::SAT;
575 bool isOOB = thisOp.getOob();
576
577 mlir::Type opType = thisOp.getRes().getType();
578 llvm::Type *opTypeLLVM = mt.convertType(opType);
579 bool isVectorFma = opTypeLLVM->isVectorTy();
580
581 llvm::Value *argA = mt.lookupValue(thisOp.getA());
582 llvm::Value *argB = mt.lookupValue(thisOp.getB());
583 llvm::Value *argC = mt.lookupValue(thisOp.getC());
584
585 static constexpr llvm::Intrinsic::ID f16IDs[] = {
586 llvm::Intrinsic::nvvm_fma_rn_f16,
587 llvm::Intrinsic::nvvm_fma_rn_f16x2,
588 llvm::Intrinsic::nvvm_fma_rn_ftz_f16,
589 llvm::Intrinsic::nvvm_fma_rn_ftz_f16x2,
590 llvm::Intrinsic::nvvm_fma_rn_sat_f16,
591 llvm::Intrinsic::nvvm_fma_rn_sat_f16x2,
592 llvm::Intrinsic::nvvm_fma_rn_ftz_sat_f16,
593 llvm::Intrinsic::nvvm_fma_rn_ftz_sat_f16x2,
594 llvm::Intrinsic::nvvm_fma_rn_relu_f16,
595 llvm::Intrinsic::nvvm_fma_rn_relu_f16x2,
596 llvm::Intrinsic::nvvm_fma_rn_ftz_relu_f16,
597 llvm::Intrinsic::nvvm_fma_rn_ftz_relu_f16x2};
598
599 static constexpr llvm::Intrinsic::ID bf16IDs[] = {
600 llvm::Intrinsic::nvvm_fma_rn_bf16, llvm::Intrinsic::nvvm_fma_rn_bf16x2,
601 llvm::Intrinsic::nvvm_fma_rn_relu_bf16,
602 llvm::Intrinsic::nvvm_fma_rn_relu_bf16x2};
603
604 static constexpr llvm::Intrinsic::ID f32IDs[] = {
605 llvm::Intrinsic::nvvm_fma_rn_f,
606 llvm::Intrinsic::nvvm_fma_rm_f,
607 llvm::Intrinsic::nvvm_fma_rp_f,
608 llvm::Intrinsic::nvvm_fma_rz_f,
609 llvm::Intrinsic::nvvm_fma_rn_sat_f,
610 llvm::Intrinsic::nvvm_fma_rm_sat_f,
611 llvm::Intrinsic::nvvm_fma_rp_sat_f,
612 llvm::Intrinsic::nvvm_fma_rz_sat_f,
613 llvm::Intrinsic::nvvm_fma_rn_ftz_f,
614 llvm::Intrinsic::nvvm_fma_rm_ftz_f,
615 llvm::Intrinsic::nvvm_fma_rp_ftz_f,
616 llvm::Intrinsic::nvvm_fma_rz_ftz_f,
617 llvm::Intrinsic::nvvm_fma_rn_ftz_sat_f,
618 llvm::Intrinsic::nvvm_fma_rm_ftz_sat_f,
619 llvm::Intrinsic::nvvm_fma_rp_ftz_sat_f,
620 llvm::Intrinsic::nvvm_fma_rz_ftz_sat_f,
621 };
622
623 static constexpr llvm::Intrinsic::ID f64IDs[] = {
624 llvm::Intrinsic::nvvm_fma_rn_d, llvm::Intrinsic::nvvm_fma_rm_d,
625 llvm::Intrinsic::nvvm_fma_rp_d, llvm::Intrinsic::nvvm_fma_rz_d};
626
627 auto fmaIntrinsic = [&](llvm::Intrinsic::ID IID,
628 llvm::Type *retType) -> llvm::Value * {
630 builder, IID, opTypeLLVM, {argA, argB, argC}, /*retType=*/retType);
631 };
632
633 // f16 + f16 -> f16 / vector<2xf16> + vector<2xf16> -> vector<2xf16>
634 if (opTypeLLVM->getScalarType()->isHalfTy()) {
635 llvm::Value *result;
636 if (isOOB) {
637 result = fmaIntrinsic(isRelu ? llvm::Intrinsic::nvvm_fma_rn_oob_relu
638 : llvm::Intrinsic::nvvm_fma_rn_oob,
639 opTypeLLVM);
640 } else {
641 unsigned index =
642 (isRelu << 3) | (isSat << 2) | (isFTZ << 1) |
643 isVectorFma; // Op verifier ensures that this index is valid
644 result = fmaIntrinsic(f16IDs[index], opTypeLLVM);
645 }
646 mt.mapValue(thisOp.getRes(), result);
647 return;
648 }
649
650 // bf16 + bf16 -> bf16 / vector<2xbf16> + vector<2xbf16> -> vector<2xbf16>
651 if (opTypeLLVM->getScalarType()->isBFloatTy()) {
652 llvm::Value *result;
653 if (isOOB) {
654 result = fmaIntrinsic(isRelu ? llvm::Intrinsic::nvvm_fma_rn_oob_relu
655 : llvm::Intrinsic::nvvm_fma_rn_oob,
656 opTypeLLVM);
657 } else {
658 unsigned index = (isRelu << 1) | isVectorFma;
659 result = fmaIntrinsic(bf16IDs[index], opTypeLLVM);
660 }
661 mt.mapValue(thisOp.getRes(), result);
662 return;
663 }
664
665 // f64 + f64 -> f64 / vector<2xf64> + vector<2xf64> -> vector<2xf64>
666 if (opTypeLLVM->getScalarType()->isDoubleTy()) {
667 mt.mapValue(thisOp.getRes(),
668 fmaIntrinsic(f64IDs[rndIndex], opTypeLLVM->getScalarType()));
669 return;
670 }
671
672 // f32 + f32 -> f32 / vector<2xf32> + vector<2xf32> -> vector<2xf32>
673 const unsigned numRndModes = 4; // RN, RM, RP, RZ
674 if (opTypeLLVM->getScalarType()->isFloatTy()) {
675 unsigned index = ((isFTZ << 1) | isSat) * numRndModes + rndIndex;
676 mt.mapValue(thisOp.getRes(),
677 fmaIntrinsic(f32IDs[index], opTypeLLVM->getScalarType()));
678 return;
679 }
680}
681
namespace {
/// Implementation of the dialect interface that converts operations belonging
/// to the NVVM dialect to LLVM IR.
class NVVMDialectLLVMIRTranslationInterface
    : public LLVMTranslationDialectInterface {
public:
  using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;

  /// Translates the given operation to LLVM IR using the provided IR builder
  /// and saving the state in `moduleTranslation`.
  LogicalResult
  convertOperation(Operation *op, llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation) const final {
    // All NVVM ops are instruction-level and require an active insertion point.
    // A null insert block means the op is misplaced (e.g., at module scope),
    // which would otherwise cause a null dereference in createIntrinsicCall.
    if (!builder.GetInsertBlock())
      return op->emitOpError(
          "cannot be translated to LLVM IR without an active insertion "
          "point; make sure the op is inside a function");
    // The generated conversions included below match on `opInst` and return
    // success() for any op they handle.
    Operation &opInst = *op;
#include "mlir/Dialect/LLVMIR/NVVMConversions.inc"

    return failure();
  }

  /// Attaches module-level metadata for functions marked as kernels.
  LogicalResult
  amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
                 NamedAttribute attribute,
                 LLVM::ModuleTranslation &moduleTranslation) const final {
    // Only LLVM functions carry the NVVM kernel attributes handled here.
    auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
    if (!func)
      return failure();
    llvm::Function *llvmFunc = moduleTranslation.lookupFunction(func.getName());

    // The dims attributes (maxntid/reqntid/cluster_dim) are stored as
    // DenseI32ArrayAttr and emitted as comma-joined string fn attributes.
    if (attribute.getName() == NVVM::NVVMDialect::getMaxntidAttrName()) {
      if (!isa<DenseI32ArrayAttr>(attribute.getValue()))
        return failure();
      auto values = cast<DenseI32ArrayAttr>(attribute.getValue());
      const std::string attr = llvm::formatv(
          "{0:$[,]}", llvm::make_range(values.asArrayRef().begin(),
                                       values.asArrayRef().end()));
      llvmFunc->addFnAttr(llvm::NVVMAttr::MaxNTID, attr);
    } else if (attribute.getName() == NVVM::NVVMDialect::getReqntidAttrName()) {
      if (!isa<DenseI32ArrayAttr>(attribute.getValue()))
        return failure();
      auto values = cast<DenseI32ArrayAttr>(attribute.getValue());
      const std::string attr = llvm::formatv(
          "{0:$[,]}", llvm::make_range(values.asArrayRef().begin(),
                                       values.asArrayRef().end()));
      llvmFunc->addFnAttr(llvm::NVVMAttr::ReqNTID, attr);
    } else if (attribute.getName() ==
               NVVM::NVVMDialect::getClusterDimAttrName()) {
      if (!isa<DenseI32ArrayAttr>(attribute.getValue()))
        return failure();
      auto values = cast<DenseI32ArrayAttr>(attribute.getValue());
      const std::string attr = llvm::formatv(
          "{0:$[,]}", llvm::make_range(values.asArrayRef().begin(),
                                       values.asArrayRef().end()));
      llvmFunc->addFnAttr(llvm::NVVMAttr::ClusterDim, attr);
    } else if (attribute.getName() ==
               NVVM::NVVMDialect::getClusterMaxBlocksAttrName()) {
      // NOTE(review): the dyn_cast result is used without a null check on
      // the integer-valued branches below -- presumably the dialect verifier
      // guarantees these attributes are IntegerAttr; confirm.
      auto value = dyn_cast<IntegerAttr>(attribute.getValue());
      llvmFunc->addFnAttr(llvm::NVVMAttr::MaxClusterRank,
                          llvm::utostr(value.getInt()));
    } else if (attribute.getName() ==
               NVVM::NVVMDialect::getMinctasmAttrName()) {
      auto value = dyn_cast<IntegerAttr>(attribute.getValue());
      llvmFunc->addFnAttr(llvm::NVVMAttr::MinCTASm,
                          llvm::utostr(value.getInt()));
    } else if (attribute.getName() == NVVM::NVVMDialect::getMaxnregAttrName()) {
      auto value = dyn_cast<IntegerAttr>(attribute.getValue());
      llvmFunc->addFnAttr(llvm::NVVMAttr::MaxNReg,
                          llvm::utostr(value.getInt()));
    } else if (attribute.getName() ==
               NVVM::NVVMDialect::getKernelFuncAttrName()) {
      // Kernel entry points are marked via the PTX kernel calling convention.
      llvmFunc->setCallingConv(llvm::CallingConv::PTX_Kernel);
    } else if (attribute.getName() ==
               NVVM::NVVMDialect::getBlocksAreClustersAttrName()) {
      llvmFunc->addFnAttr(llvm::NVVMAttr::BlocksAreClusters);
    }

    return success();
  }

  /// Translates parameter attributes; currently only the grid_constant
  /// marker is forwarded to the LLVM function's parameter attributes.
  LogicalResult
  convertParameterAttr(LLVMFuncOp funcOp, int argIdx, NamedAttribute attribute,
                       LLVM::ModuleTranslation &moduleTranslation) const final {

    llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
    llvm::Function *llvmFunc =
        moduleTranslation.lookupFunction(funcOp.getName());

    if (attribute.getName() == NVVM::NVVMDialect::getGridConstantAttrName()) {
      llvmFunc->addParamAttr(
          argIdx,
          llvm::Attribute::get(llvmContext, llvm::NVVMAttr::GridConstant));
    }
    return success();
  }
};
} // namespace
785
787 registry.insert<NVVM::NVVMDialect>();
788 registry.addExtension(+[](MLIRContext *ctx, NVVM::NVVMDialect *dialect) {
789 dialect->addInterfaces<NVVMDialectLLVMIRTranslationInterface>();
790 });
791}
792
794 DialectRegistry registry;
796 context.appendDialectRegistry(registry);
797}
return success()
static LogicalResult convertParameterAttr(llvm::AttrBuilder &attrBuilder, llvm::Attribute::AttrKind llvmKind, NamedAttribute namedAttr, ModuleTranslation &moduleTranslation, Location loc)
static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num, NVVM::LdStMatrixShapeAttr shape, NVVM::LdStMatrixEltType eltType)
static llvm::Intrinsic::ID getFenceProxyID(NVVM::ProxyKind kind, std::optional< NVVM::SharedSpace > space)
#define GET_REDUX_F32_ID(op, hasAbs, hasNaN)
static llvm::Intrinsic::ID getStMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num, NVVM::LdStMatrixShapeAttr shape, NVVM::LdStMatrixEltType eltType)
Return the intrinsic ID associated with stmatrix for the given parameters.
static llvm::Intrinsic::ID getTcgen05StIntrinsicID(mlir::NVVM::Tcgen05LdStShape shape, uint32_t num)
static llvm::Intrinsic::ID getTcgen05LdIntrinsicID(mlir::NVVM::Tcgen05LdStShape shape, uint32_t num)
static unsigned getMembarIntrinsicID(NVVM::MemScopeKind scope)
static unsigned getUnidirectionalFenceProxyID(NVVM::ProxyKind fromProxy, NVVM::ProxyKind toProxy, NVVM::MemScopeKind scope, bool isRelease)
llvm::CallInst * createIntrinsicCall(llvm::IRBuilderBase &builder, llvm::Intrinsic::ID intrinsic, ArrayRef< llvm::Value * > args={}, ArrayRef< llvm::Type * > tys={})
Creates a call to an LLVM IR intrinsic function with the given arguments.
static llvm::Intrinsic::ID getFenceProxySyncRestrictID(NVVM::MemOrderKind order)
#define TCGEN05ST(SHAPE, NUM)
static llvm::Intrinsic::ID getReduxIntrinsicId(llvm::Type *resultType, NVVM::ReductionKind kind, bool hasAbs, bool hasNaN)
static llvm::Value * createScalarizedIntrinsicCall(llvm::IRBuilderBase &builder, llvm::Intrinsic::ID IID, llvm::Type *opTypeLLVM, ArrayRef< llvm::Value * > operands, llvm::Type *retType)
#define TCGEN05LD(SHAPE, NUM)
static llvm::Intrinsic::ID getFenceSyncRestrictID(NVVM::MemOrderKind order)
static llvm::Intrinsic::ID getShflIntrinsicId(llvm::Type *resultType, NVVM::ShflKind kind, bool withPredicate)
static llvm::Intrinsic::ID getVoteSyncIntrinsicId(NVVM::VoteSyncKind kind)
static llvm::Intrinsic::ID getMatchSyncIntrinsicId(Type valType, NVVM::MatchSyncKind kind)
static llvm::Intrinsic::ID getStBulkIntrinsicId(LLVM::LLVMPointerType addrType)
Return the intrinsic ID associated with st.bulk for the given address type.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static bool isSharedMemory(MemRefType type)
Return true if this is a shared memory memref type.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Implementation class for module translation.
llvm::Value * lookupValue(Value value) const
Finds an LLVM IR value corresponding to the given MLIR value.
llvm::Type * convertType(Type type)
Converts the type from MLIR LLVM dialect to LLVM.
void mapValue(Value mlir, llvm::Value *llvm)
Stores the mapping between an MLIR value and its LLVM IR counterpart.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
void appendDialectRegistry(const DialectRegistry &registry)
Append the contents of the given dialect registry to the registry associated with this context.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
llvm::CallInst * createIntrinsicCall(llvm::IRBuilderBase &builder, llvm::Intrinsic::ID intrinsic, ArrayRef< llvm::Value * > args={}, ArrayRef< llvm::Type * > tys={})
Creates a call to an LLVM IR intrinsic function with the given arguments.
Include the generated interface declarations.
void registerNVVMDialectTranslation(DialectRegistry &registry)
Register the NVVM dialect and the translation from it to the LLVM IR in the given registry;.