1//===- NVVMDialect.cpp - NVVM IR Ops and Dialect registration -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the types and operation details for the NVVM IR dialect in
10// MLIR, and the LLVM IR dialect. It also registers the dialect.
11//
12// The NVVM dialect only contains GPU specific additions on top of the general
13// LLVM dialect.
14//
15//===----------------------------------------------------------------------===//
16
18
22#include "mlir/IR/Builders.h"
25#include "mlir/IR/Diagnostics.h"
27#include "mlir/IR/MLIRContext.h"
28#include "mlir/IR/Operation.h"
30#include "mlir/IR/Types.h"
31#include "llvm/ADT/STLExtras.h"
32#include "llvm/ADT/TypeSwitch.h"
33#include "llvm/IR/IRBuilder.h"
34#include "llvm/Support/Casting.h"
35#include "llvm/Support/FormatVariadic.h"
36#include "llvm/Support/NVPTXAddrSpace.h"
37#include "llvm/Support/raw_ostream.h"
38#include <cassert>
39#include <optional>
40#include <string>
41
42using namespace mlir;
43using namespace NVVM;
44
45#include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
46#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"
47
48static constexpr unsigned notIntrinsic = llvm::Intrinsic::not_intrinsic;
49
50//===----------------------------------------------------------------------===//
51// Helper/Utility methods
52//===----------------------------------------------------------------------===//
53
54static bool isPtrInAddrSpace(mlir::Value ptr, NVVMMemorySpace targetAS) {
55 auto ptrTy = llvm::cast<LLVM::LLVMPointerType>(ptr.getType());
56 return ptrTy.getAddressSpace() == static_cast<unsigned>(targetAS);
57}
58
59 static bool isPtrInSharedCTASpace(mlir::Value ptr) {
60 return isPtrInAddrSpace(ptr, NVVMMemorySpace::Shared);
61}
62
63//===----------------------------------------------------------------------===//
64// Verifier methods
65//===----------------------------------------------------------------------===//
66
67// This verifier is shared among the following Ops:
68// CpAsyncBulkTensorSharedCTAToGlobalOp (TMA Store)
69// CpAsyncBulkTensorReduceOp (TMA Store-Reduce)
70static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims,
71 bool isIm2Col,
72 size_t numIm2ColOffsets,
73 Location loc) {
74 if (tensorDims < 1 || tensorDims > 5)
75 return emitError(loc, "expects coordinates between 1 to 5 dimension");
76
77 // For Im2Col mode, there are two constraints:
78 if (isIm2Col) {
79 // 1. Tensor must always be at least 3-d.
80 if (tensorDims < 3)
81 return emitError(
82 loc,
83 "to use im2col mode, the tensor has to be at least 3-dimensional");
84 // 2. When there are Im2ColOffsets, they must be (Dims - 2) in number.
85 if (numIm2ColOffsets && (tensorDims != (numIm2ColOffsets + 2)))
86 return emitError(
87 loc, "im2col offsets must be 2 less than number of coordinates");
88 }
89 return success();
90}
91
92LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
93 TMAStoreMode mode = getMode();
94 // We lower through inline-ptx when getPredicate() is true.
95 // a) Only TILE mode is supported
96 // b) Cache-hint is not supported
97 if (getPredicate()) {
98 if (mode != TMAStoreMode::TILE)
99 return emitError("Inline-ptx lowering supported only for Tile mode.");
100 if (getL2CacheHint())
101 return emitError("Inline-ptx lowering unsupported with L2 cache-hint.");
102 }
103
104 size_t dims = getCoordinates().size();
105 switch (mode) {
106 case TMAStoreMode::TILE:
107 return cpAsyncBulkTensorCommonVerifier(dims, false, 0, getLoc());
108 case TMAStoreMode::IM2COL:
109 return cpAsyncBulkTensorCommonVerifier(dims, true, 0, getLoc());
110 case TMAStoreMode::TILE_SCATTER4:
111 if (dims != 5)
112 return emitError("Scatter4 mode expects 5 coordinates");
113 }
114 return success();
115}
116
117LogicalResult CpAsyncOp::verify() {
118 if (getModifier() != LoadCacheModifierKind::CG &&
119 getModifier() != LoadCacheModifierKind::CA)
120 return emitError("Only CG and CA cache modifiers are supported.");
121 if (getSize() != 4 && getSize() != 8 && getSize() != 16)
122 return emitError("expected byte size to be either 4, 8 or 16.");
123 if (getModifier() == LoadCacheModifierKind::CG && getSize() != 16)
124 return emitError("CG cache modifier is only support for 16 bytes copy.");
125 return success();
126}
127
128// This verifier is shared across the TMA Load and Prefetch Ops.
129static LogicalResult verifyTMALoadParams(size_t tensorDims, size_t numIm2colOff,
130 TMALoadMode mode, Location loc) {
131 if (tensorDims < 1 || tensorDims > 5)
132 return emitError(loc, "expects coordinates between 1 to 5 dimension");
133
134 auto checkTMALoadParams = [&](TMALoadMode mode, bool isIm2col,
135 size_t expectedIm2colOff) -> LogicalResult {
136 if (isIm2col && (tensorDims < 3))
137 return emitError(loc)
138 << "to use " << stringifyEnum(mode)
139 << " mode, the tensor has to be at least 3-dimensional";
140
141 if (numIm2colOff != expectedIm2colOff)
142 return emitError(loc) << " im2col offsets expected " << expectedIm2colOff
143 << " (provided " << numIm2colOff << ")";
144
145 return success();
146 };
147
148 switch (mode) {
149 case TMALoadMode::TILE:
150 return checkTMALoadParams(mode, false, 0);
151 case TMALoadMode::IM2COL:
152 return checkTMALoadParams(mode, true, tensorDims - 2);
153 case TMALoadMode::IM2COL_W:
154 case TMALoadMode::IM2COL_W_128:
155 return checkTMALoadParams(mode, true, 2);
156 case TMALoadMode::TILE_GATHER4:
157 return (tensorDims == 5)
158 ? checkTMALoadParams(mode, false, 0)
159 : emitError(loc, "Gather4 mode expects 5 coordinates");
160 }
161 return success();
162}
163
164LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
165 return verifyTMALoadParams(getCoordinates().size(), getIm2colOffsets().size(),
166 getMode(), getLoc());
167}
168
169LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
170 TMALoadMode mode = getMode();
171 bool isCTAOnly = getIsCTAOnly();
172 if (getPredicate()) { // Inline-asm based lowering
173 if (isCTAOnly)
174 return emitError("Predicate is supported only for shared::cluster mode.");
175 if (mode != TMALoadMode::TILE && mode != TMALoadMode::IM2COL)
176 return emitError(
177 "Predicate is supported only for Tile and Im2col modes.");
178 } else { // Intrinsics-based lowering
179 NVVMMemorySpace expectedAS =
180 isCTAOnly ? NVVMMemorySpace::Shared : NVVMMemorySpace::SharedCluster;
181 unsigned AS = llvm::cast<LLVM::LLVMPointerType>(getDstMem().getType())
182 .getAddressSpace();
183 if (AS != expectedAS)
184 return emitError()
185 << (isCTAOnly
186 ? "Shared::cta destination requires address-space 3."
187 : "Shared::cluster destination requires address-space 7.");
188 // Checks specific to shared::cta mode
189 if (isCTAOnly) {
190 if (getMulticastMask())
191 return emitError("Multicast is not supported with shared::cta mode.");
192 if (getGroup())
193 return emitError("CTAGroup is not supported with shared::cta mode.");
194 }
195 }
196
197 return verifyTMALoadParams(getCoordinates().size(), getIm2colOffsets().size(),
198 getMode(), getLoc());
199}
200
201LogicalResult CpAsyncBulkTensorReduceOp::verify() {
202 TMAStoreMode mode = getMode();
203 size_t dims = getCoordinates().size();
204 switch (mode) {
205 case TMAStoreMode::TILE:
206 return cpAsyncBulkTensorCommonVerifier(dims, false, 0, getLoc());
207 case TMAStoreMode::IM2COL:
208 return cpAsyncBulkTensorCommonVerifier(dims, true, 0, getLoc());
209 case TMAStoreMode::TILE_SCATTER4:
210 return emitError("Scatter mode unsupported for CpAsyncBulkTensorReduceOp");
211 }
212 return success();
213}
214
215LogicalResult ConvertFloatToTF32Op::verify() {
216 using RndMode = NVVM::FPRoundingMode;
217 switch (getRnd()) {
218 case RndMode::RNA:
219 if (getRelu())
220 return emitError("Relu not supported with rna rounding mode.");
221 break;
222 case RndMode::RN:
223 case RndMode::RZ:
224 break;
225 default:
226 return emitError(
227 "Only {rn,rz,rna} rounding modes supported for ConvertFloatToTF32Op.");
228 }
229 return success();
230}
231
232LogicalResult ConvertF32x2ToF6x2Op::verify() {
233 MLIRContext *ctx = getContext();
234
235 if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy())) {
236 return emitOpError("Only ")
237 << mlir::Float6E2M3FNType::get(ctx) << " and "
238 << mlir::Float6E3M2FNType::get(ctx)
239 << " types are supported for conversions from f32x2 to f6x2.";
240 }
241 return success();
242}
243
244LogicalResult ConvertF32x2ToF8x2Op::verify() {
245 using RndMode = NVVM::FPRoundingMode;
246 using SatMode = NVVM::SaturationMode;
247
248 bool isRoundingModeRN = getRnd() == RndMode::RN;
249 bool isRoundingModeRZ = getRnd() == RndMode::RZ;
250 bool isRoundingModeRP = getRnd() == RndMode::RP;
251 bool isSatFinite = getSat() == SatMode::SATFINITE;
252
253 bool hasRelu = getRelu();
254
255 MLIRContext *ctx = getContext();
256
257 return llvm::TypeSwitch<Type, LogicalResult>(getDstTy())
258 .Case<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(
259 [&](mlir::Type) -> LogicalResult {
260 if (!isRoundingModeRN) {
261 return emitOpError("Only RN rounding mode is supported for "
262 "conversions from f32x2 to ")
263 << mlir::Float8E4M3FNType::get(ctx) << " and "
264 << mlir::Float8E5M2Type::get(ctx) << " types";
265 }
266 if (!isSatFinite) {
267 return emitOpError("Only SATFINITE saturation mode is supported "
268 "for conversions "
269 "from f32x2 to ")
270 << mlir::Float8E4M3FNType::get(ctx) << " and "
271 << mlir::Float8E5M2Type::get(ctx) << " types";
272 }
273 return success();
274 })
275 .Case<mlir::Float8E8M0FNUType>([&](mlir::Type) -> LogicalResult {
276 if (!(isRoundingModeRZ || isRoundingModeRP)) {
277 return emitOpError("Only RZ and RP rounding modes are supported for "
278 "conversions from f32x2 to ")
279 << mlir::Float8E8M0FNUType::get(ctx) << " type";
280 }
281 if (hasRelu) {
282 return emitOpError("relu not supported for conversions to ")
283 << mlir::Float8E8M0FNUType::get(ctx) << " type";
284 }
285 return success();
286 })
287 .Default([&](mlir::Type) {
288 return emitOpError("Only ")
289 << mlir::Float8E4M3FNType::get(ctx) << ", "
290 << mlir::Float8E5M2Type::get(ctx) << ", and "
291 << mlir::Float8E8M0FNUType::get(ctx)
292 << " types are "
293 "supported for conversions from f32x2 to f8x2";
294 });
295}
296
297LogicalResult ConvertF16x2ToF8x2Op::verify() {
298 MLIRContext *ctx = getContext();
299
300 if (!llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(getDstTy())) {
301 return emitOpError("Only ")
302 << mlir::Float8E4M3FNType::get(ctx) << " and "
303 << mlir::Float8E5M2Type::get(ctx)
304 << " types are supported for conversions from f16x2 to f8x2.";
305 }
306 return success();
307}
308
309LogicalResult ConvertBF16x2ToF8x2Op::verify() {
310 using RndMode = NVVM::FPRoundingMode;
311
312 if (!llvm::isa<mlir::Float8E8M0FNUType>(getDstTy()))
313 return emitOpError("Only ") << mlir::Float8E8M0FNUType::get(getContext())
314 << " type is supported for conversions from "
315 "bf16x2 to f8x2.";
316
317 auto rnd = getRnd();
318 if (!(rnd == RndMode::RZ || rnd == RndMode::RP))
319 return emitOpError("Only RZ and RP rounding modes are supported for "
320 "conversions from bf16x2 to f8x2.");
321
322 return success();
323}
324
325LogicalResult ConvertF32x2ToF4x2Op::verify() {
326 MLIRContext *ctx = getContext();
327
328 if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
329 return emitOpError("Only ")
330 << mlir::Float4E2M1FNType::get(ctx)
331 << " type is supported for conversions from f32x2 to f4x2.";
332
333 return success();
334}
335
336LogicalResult ConvertF8x2ToF16x2Op::verify() {
337 MLIRContext *ctx = getContext();
338
339 if (!llvm::isa<Float8E4M3FNType, Float8E5M2Type>(getSrcType()))
340 return emitOpError("Only ")
341 << mlir::Float8E4M3FNType::get(ctx) << " and "
342 << mlir::Float8E5M2Type::get(ctx)
343 << " types are supported for conversions from f8x2 to f16x2.";
344
345 return success();
346}
347
348LogicalResult ConvertF8x2ToBF16x2Op::verify() {
349 MLIRContext *ctx = getContext();
350 if (!llvm::isa<Float8E8M0FNUType>(getSrcType()))
351 return emitOpError("Only ")
352 << mlir::Float8E8M0FNUType::get(ctx)
353 << " type is supported for conversions from f8x2 to bf16x2.";
354
355 return success();
356}
357
358LogicalResult ConvertF6x2ToF16x2Op::verify() {
359 MLIRContext *ctx = getContext();
360
361 if (!llvm::isa<Float6E2M3FNType, Float6E3M2FNType>(getSrcType()))
362 return emitOpError("Only ")
363 << mlir::Float6E2M3FNType::get(ctx) << " and "
364 << mlir::Float6E3M2FNType::get(ctx)
365 << " types are supported for conversions from f6x2 to f16x2.";
366
367 return success();
368}
369
370LogicalResult ConvertF4x2ToF16x2Op::verify() {
371 MLIRContext *ctx = getContext();
372
373 if (!llvm::isa<Float4E2M1FNType>(getSrcType()))
374 return emitOpError("Only ")
375 << mlir::Float4E2M1FNType::get(ctx)
376 << " type is supported for conversions from f4x2 to f16x2.";
377
378 return success();
379}
380
381//===----------------------------------------------------------------------===//
382// Stochastic Rounding Conversion Ops
383//===----------------------------------------------------------------------===//
384
385LogicalResult ConvertF32x2ToF16x2Op::verify() {
386 if (getRnd() != FPRoundingMode::RS)
387 return emitOpError("Only RS rounding mode is supported for "
388 "conversions from f32x2 to f16x2.");
389 return success();
390}
391
392LogicalResult ConvertF32x2ToBF16x2Op::verify() {
393 if (getRnd() != FPRoundingMode::RS)
394 return emitOpError("Only RS rounding mode is supported for "
395 "conversions from f32x2 to bf16x2.");
396 return success();
397}
398
399LogicalResult ConvertF32x4ToF8x4Op::verify() {
400 MLIRContext *ctx = getContext();
401
402 if (!llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(getDstTy()))
403 return emitOpError("Only ")
404 << mlir::Float8E4M3FNType::get(ctx) << " and "
405 << mlir::Float8E5M2Type::get(ctx)
406 << " types are supported for conversions from f32x4 to f8x4.";
407
408 return success();
409}
410
411LogicalResult ConvertF32x4ToF6x4Op::verify() {
412 MLIRContext *ctx = getContext();
413
414 if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy()))
415 return emitOpError("Only ")
416 << mlir::Float6E2M3FNType::get(ctx) << " and "
417 << mlir::Float6E3M2FNType::get(ctx)
418 << " types are supported for conversions from f32x4 to f6x4.";
419
420 return success();
421}
422
423LogicalResult ConvertF32x4ToF4x4Op::verify() {
424 MLIRContext *ctx = getContext();
425
426 if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
427 return emitOpError("Only ") << mlir::Float4E2M1FNType::get(ctx)
428 << " type is supported for conversions from "
429 "f32x4 to f4x4.";
430
431 return success();
432}
433
434LogicalResult BulkStoreOp::verify() {
435 if (getInitVal() != 0)
436 return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
437 return success();
438}
439
440LogicalResult PMEventOp::verify() {
441 auto eventId = getEventId();
442 auto maskedEventId = getMaskedEventId();
443 if (!maskedEventId && !eventId) {
444 return emitOpError() << "either `id` or `mask` must be set";
445 }
446
447 if (maskedEventId && eventId) {
448 return emitOpError() << "`id` and `mask` cannot be set at the same time";
449 }
450
451 if (eventId) {
452 if (eventId < 0 || eventId > 15) {
453 return emitOpError() << "`id` must be between 0 and 15";
454 }
455 }
456
457 return llvm::success();
458}
459
460// Given the element type of an operand and whether or not it is an accumulator,
461// this function returns the PTX type (`NVVM::MMATypes`) that corresponds to the
462// operand's element type.
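// For example (illustrative, derived from the rules implemented below): f16 or
// vector<2xf16> maps to MMATypes::f16; f32 maps to MMATypes::f32 when used as
// an accumulator and to MMATypes::tf32 otherwise; integer element types are
// only inferrable for accumulators, where they map to MMATypes::s32.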
463std::optional<mlir::NVVM::MMATypes>
464MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) {
465 auto half2Type =
466 VectorType::get(2, Float16Type::get(operandElType.getContext()));
467 if (operandElType.isF64())
468 return NVVM::MMATypes::f64;
469 if (operandElType.isF16() || operandElType == half2Type)
470 return NVVM::MMATypes::f16;
471 if (operandElType.isF32() && isAccumulator)
472 return NVVM::MMATypes::f32;
473 if (operandElType.isF32() && !isAccumulator)
474 return NVVM::MMATypes::tf32;
475 if (llvm::isa<IntegerType>(operandElType)) {
476 if (isAccumulator)
477 return NVVM::MMATypes::s32;
478 return std::nullopt;
479 }
480
481 if (auto structType = llvm::dyn_cast<LLVM::LLVMStructType>(operandElType)) {
482 if (structType.getBody().empty())
483 return std::nullopt;
484 return inferOperandMMAType(structType.getBody()[0], isAccumulator);
485 }
486
487 return std::nullopt;
488}
489
490static bool isInt4PtxType(MMATypes type) {
491 return (type == MMATypes::u4 || type == MMATypes::s4);
492}
493
494static bool isInt8PtxType(MMATypes type) {
495 return (type == MMATypes::u8 || type == MMATypes::s8);
496}
497
498static bool isIntegerPtxType(MMATypes type) {
499 return isInt4PtxType(type) || isInt8PtxType(type) || type == MMATypes::b1 ||
500 type == MMATypes::s32;
501}
502
503MMATypes MmaOp::accumPtxType() {
504 std::optional<mlir::NVVM::MMATypes> val = inferOperandMMAType(
505 getODSOperands(2).getTypes().front(), /*isAccumulator=*/true);
506 assert(val.has_value() && "accumulator PTX type should always be inferrable");
507 return val.value();
508}
509
510MMATypes MmaOp::resultPtxType() {
511 std::optional<mlir::NVVM::MMATypes> val =
512 inferOperandMMAType(getResult().getType(), /*isAccumulator=*/true);
513 assert(val.has_value() && "result PTX type should always be inferrable");
514 return val.value();
515}
516
517void MmaOp::print(OpAsmPrinter &p) {
518 SmallVector<Type, 4> regTypes;
519 struct OperandFragment {
520 StringRef operandName;
521 StringRef ptxTypeAttr;
522 SmallVector<Value, 4> regs;
523 explicit OperandFragment(StringRef name, StringRef ptxTypeName)
524 : operandName(name), ptxTypeAttr(ptxTypeName) {}
525 };
526
527 std::array<OperandFragment, 3> frags{
528 OperandFragment("A", getMultiplicandAPtxTypeAttrName()),
529 OperandFragment("B", getMultiplicandBPtxTypeAttrName()),
530 OperandFragment("C", "")};
531 SmallVector<StringRef, 4> ignoreAttrNames{
532 mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()};
533
534 for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
535 auto &frag = frags[fragIdx];
536 auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
537 for (auto operandIdx = varOperandSpec.first;
538 operandIdx < varOperandSpec.first + varOperandSpec.second;
539 operandIdx++) {
540 frag.regs.push_back(this->getOperand(operandIdx));
541 if (operandIdx == 0) {
542 regTypes.push_back(this->getOperand(operandIdx).getType());
543 }
544 }
545 std::optional<MMATypes> inferredType =
546 inferOperandMMAType(regTypes.back(), /*isAccumulator=*/fragIdx >= 2);
547 if (inferredType)
548 ignoreAttrNames.push_back(frag.ptxTypeAttr);
549 }
550
551 auto printMmaOperand = [&](const OperandFragment &frag) -> void {
552 p << " " << frag.operandName;
553 p << "[";
554 p.printOperands(frag.regs);
555 p << "] ";
556 };
557
558 for (const auto &frag : frags) {
559 printMmaOperand(frag);
560 }
561
562 p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);
563
564 // Print the types of the operands and result.
565 p << " : "
566 << "(";
567 llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
568 frags[1].regs[0].getType(),
569 frags[2].regs[0].getType()},
570 p);
571 p << ")";
572 p.printArrowTypeList(TypeRange{this->getRes().getType()});
573}
574
575void MmaOp::build(OpBuilder &builder, OperationState &result, Type resultType,
576 ValueRange operandA, ValueRange operandB, ValueRange operandC,
577 ArrayRef<int64_t> shape, std::optional<MMAB1Op> b1Op,
578 std::optional<MMAIntOverflow> intOverflow,
579 std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
580 std::optional<std::array<MMALayout, 2>> multiplicandLayouts) {
581
582 assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
583 MLIRContext *ctx = builder.getContext();
584 result.addAttribute(
585 "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));
586
587 result.addOperands(operandA);
588 result.addOperands(operandB);
589 result.addOperands(operandC);
590
591 if (multiplicandPtxTypes) {
592 result.addAttribute("multiplicandAPtxType",
593 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
594 result.addAttribute("multiplicandBPtxType",
595 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
596 } else {
597 if (auto res = inferOperandMMAType(operandA[0].getType(), false))
598 result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
599 if (auto res = inferOperandMMAType(operandB[0].getType(), false))
600 result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
601 }
602
603 if (multiplicandLayouts) {
604 result.addAttribute("layoutA",
605 MMALayoutAttr::get(ctx, (*multiplicandLayouts)[0]));
606 result.addAttribute("layoutB",
607 MMALayoutAttr::get(ctx, (*multiplicandLayouts)[1]));
608 } else {
609 result.addAttribute("layoutA", MMALayoutAttr::get(ctx, MMALayout::row));
610 result.addAttribute("layoutB", MMALayoutAttr::get(ctx, MMALayout::col));
611 }
612
613 if (intOverflow.has_value())
614 result.addAttribute("intOverflowBehavior",
615 MMAIntOverflowAttr::get(ctx, *intOverflow));
616 if (b1Op.has_value())
617 result.addAttribute("b1Op", MMAB1OpAttr::get(ctx, *b1Op));
618
619 result.addTypes(resultType);
620 result.addAttribute(
621 MmaOp::getOperandSegmentSizeAttr(),
622 builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()),
623 static_cast<int32_t>(operandB.size()),
624 static_cast<int32_t>(operandC.size())}));
625}
626
627// <operation> :=
628// A `[` $operandA `]` B `[` $operandB `]` C `[` $operandC `]`
629// attr-dict : (type($operandA[0]), type($operandB[0]), type($operandC[0]))
630// `->` type($res)
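// A hedged illustration of this custom form (operand names, shape, and types
// are illustrative only; they follow the m16n8k16 f16 rules checked in
// MmaOp::verify() below):
//   %d = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1]
//          {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
//           shape = #nvvm.shape<m = 16, n = 8, k = 16>}
//          : (vector<2xf16>, vector<2xf16>, vector<2xf16>)
//          -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>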
631ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) {
632 struct OperandFragment {
633 std::optional<MMATypes> elemtype;
634 SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
635 SmallVector<Type> regTypes;
636 };
637
638 Builder &builder = parser.getBuilder();
639 std::array<OperandFragment, 4> frags;
640
641 NamedAttrList namedAttributes;
642
643 // A helper to parse the operand segments.
644 auto parseMmaOperand = [&](StringRef operandName,
645 OperandFragment &frag) -> LogicalResult {
646 if (parser.parseKeyword(operandName).failed())
647 return failure();
648 if (parser
649 .parseOperandList(frag.regs, OpAsmParser::Delimiter::OptionalSquare)
650 .failed())
651 return failure();
652 return success();
653 };
654
655 // Parse the operand segments.
656 if (parseMmaOperand("A", frags[0]).failed())
657 return failure();
658 if (parseMmaOperand("B", frags[1]).failed())
659 return failure();
660 if (parseMmaOperand("C", frags[2]).failed())
661 return failure();
662
663 if (parser.parseOptionalAttrDict(namedAttributes).failed())
664 return failure();
665
666 // Parse the type specification and resolve operands.
667 SmallVector<Type, 3> operandTypes;
668 if (failed(parser.parseColon()))
669 return failure();
670 if (failed(parser.parseLParen()))
671 return failure();
672 if (failed(parser.parseTypeList(operandTypes)))
673 return failure();
674 if (failed(parser.parseRParen()))
675 if (operandTypes.size() != 3)
676 return parser.emitError(
677 parser.getNameLoc(),
678 "expected one type for each operand segment but got " +
679 Twine(operandTypes.size()) + " types");
680 for (const auto &iter : llvm::enumerate(operandTypes)) {
681 auto &frag = frags[iter.index()];
682 frag.regTypes.resize(frag.regs.size(), iter.value());
683 if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
684 parser.getNameLoc(), result.operands)))
685 return failure();
686 frag.elemtype = inferOperandMMAType(frag.regTypes[0],
687 /*isAccumulator*/ iter.index() < 2);
688 }
689
690 Type resultType;
691 if (parser.parseArrow() || parser.parseType(resultType))
692 return failure();
693 frags[3].elemtype = inferOperandMMAType(resultType, /*isAccumulator*/ true);
694
695 std::array<StringRef, 2> names{"multiplicandAPtxType",
696 "multiplicandBPtxType"};
697 for (unsigned idx = 0; idx < names.size(); idx++) {
698 const auto &frag = frags[idx];
699 std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
700 if (!frag.elemtype.has_value() && !attr.has_value()) {
701 return parser.emitError(
702 parser.getNameLoc(),
703 "attribute " + names[idx] +
704 " is not provided explicitly and cannot be inferred");
705 }
706 if (!attr.has_value())
707 result.addAttribute(
708 names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));
709 }
710
711 result.addTypes(resultType);
712 if (!namedAttributes.empty())
713 result.addAttributes(namedAttributes);
714 result.addAttribute(MmaOp::getOperandSegmentSizeAttr(),
715 builder.getDenseI32ArrayAttr({
716 static_cast<int32_t>(frags[0].regs.size()),
717 static_cast<int32_t>(frags[1].regs.size()),
718 static_cast<int32_t>(frags[2].regs.size()),
719 }));
720 return success();
721}
722
723LogicalResult MmaOp::verify() {
724 MLIRContext *context = getContext();
725 auto f16Ty = Float16Type::get(context);
726 auto i32Ty = IntegerType::get(context, 32);
727 auto f16x2Ty = VectorType::get(2, f16Ty);
728 auto f32Ty = Float32Type::get(context);
729 auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
730 context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
731
732 auto s32x4StructTy =
733 LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
734 auto f32x8StructTy =
735 LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
736 auto f16x2x2StructTy =
737 LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
738 auto f32x4StructTy =
739 LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
740 auto s32x2StructTy =
741 LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});
742
743 std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
744 getShapeAttr().getK()};
745
746 // These variables define the set of allowed data types for matrices A, B, C,
747 // and result.
748 using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
749 using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>;
750 AllowedShapes allowedShapes;
751 AllowedTypes expectedA;
752 AllowedTypes expectedB;
753 AllowedTypes expectedC;
754 SmallVector<Type> expectedResult;
755
756 // When M = 16, we just need to calculate the number of 8xk tiles, where
757 // k is a factor that depends on the data type.
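// Worked example (illustrative): for f16 multiplicands with shape m16n8k16,
// kFactor is 8 below, so unitA = (16/8) * (16/8) = 4 registers and
// unitB = (8/8) * (16/8) = 2 registers, each of type vector<2xf16>.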
758 if (mmaShape[0] == 16) {
759 int64_t kFactor;
760 Type multiplicandFragType;
761 switch (*getMultiplicandAPtxType()) {
762 case MMATypes::tf32:
763 kFactor = 4;
764 multiplicandFragType = i32Ty;
765 expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
766 context, {f32Ty, f32Ty, f32Ty, f32Ty}));
767 break;
768 case MMATypes::bf16:
769 kFactor = 8;
770 multiplicandFragType = i32Ty;
771 expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
772 context, {f32Ty, f32Ty, f32Ty, f32Ty}));
773 break;
774 case MMATypes::f16:
775 kFactor = 8;
776 multiplicandFragType = f16x2Ty;
777 expectedResult.push_back(f16x2x2StructTy);
778 expectedResult.push_back(f32x4StructTy);
779 break;
780 case MMATypes::s4:
781 case MMATypes::u4:
782 kFactor = 32;
783 break;
784 case MMATypes::b1:
785 kFactor = 128;
786 break;
787 case MMATypes::s8:
788 case MMATypes::u8:
789 kFactor = 16;
790 break;
791 default:
792 return emitError("invalid shape or multiplicand type: " +
793 stringifyEnum(getMultiplicandAPtxType().value()));
794 }
795
796 if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
797 expectedResult.push_back(s32x4StructTy);
798 expectedC.emplace_back(4, i32Ty);
799 multiplicandFragType = i32Ty;
800 } else {
801 expectedC.emplace_back(2, f16x2Ty);
802 expectedC.emplace_back(4, f32Ty);
803 }
804
805 int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
806 int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
807 expectedA.emplace_back(unitA, multiplicandFragType);
808 expectedB.emplace_back(unitB, multiplicandFragType);
809 allowedShapes.push_back({16, 8, kFactor});
810 allowedShapes.push_back({16, 8, kFactor * 2});
811
812 if (resultPtxType() != accumPtxType())
813 return emitOpError("ctype does not match dtype");
814 }
815
816 // In the M=8 case, there is only 1 possible case per data type.
817 if (mmaShape[0] == 8) {
818 if (*getMultiplicandAPtxType() == MMATypes::f16) {
819 expectedA.emplace_back(2, f16x2Ty);
820 expectedB.emplace_back(2, f16x2Ty);
821 expectedResult.push_back(f16x2x4StructTy);
822 expectedResult.push_back(f32x8StructTy);
823 expectedC.emplace_back(4, f16x2Ty);
824 expectedC.emplace_back(8, f32Ty);
825 allowedShapes.push_back({8, 8, 4});
826 }
827 if (*getMultiplicandAPtxType() == MMATypes::f64) {
828 Type f64Ty = Float64Type::get(context);
829 expectedA.emplace_back(1, f64Ty);
830 expectedB.emplace_back(1, f64Ty);
831 expectedC.emplace_back(2, f64Ty);
832 expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
833 context, SmallVector<Type>(2, f64Ty)));
834 allowedShapes.push_back({8, 8, 4});
835 }
836 if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
837 expectedA.push_back({i32Ty});
838 expectedB.push_back({i32Ty});
839 expectedC.push_back({i32Ty, i32Ty});
840 expectedResult.push_back(s32x2StructTy);
841 if (isInt4PtxType(getMultiplicandAPtxType().value()))
842 allowedShapes.push_back({8, 8, 32});
843 if (isInt8PtxType(getMultiplicandAPtxType().value()))
844 allowedShapes.push_back({8, 8, 16});
845 if (getMultiplicandAPtxType().value() == MMATypes::b1)
846 allowedShapes.push_back({8, 8, 128});
847 }
848 }
849
850 std::string errorMessage;
851 llvm::raw_string_ostream errorStream(errorMessage);
852
853 // Check that we matched an existing shape/dtype combination.
854 if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
855 !llvm::is_contained(allowedShapes, mmaShape)) {
856 errorStream << "unimplemented variant for MMA shape <";
857 llvm::interleaveComma(mmaShape, errorStream);
858 errorStream << ">";
859 return emitOpError(errorMessage);
860 }
861
862 // Verify the operand types for segments of A, B, and C operands.
863 std::array<StringRef, 3> operandNames{"A", "B", "C"};
864 for (const auto &iter : llvm::enumerate(
865 SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) {
866 auto spec = this->getODSOperandIndexAndLength(iter.index());
867 SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first,
868 operand_type_begin() + spec.first +
869 spec.second);
870 bool match = llvm::is_contained(iter.value(), operandTySeg);
871
872 if (!match) {
873 errorStream << "Could not match types for the "
874 << operandNames[iter.index()]
875 << " operands; expected one of ";
876 for (const auto &x : iter.value()) {
877 errorStream << x.size() << "x" << x[0] << " ";
878 }
879 errorStream << "but got ";
880 llvm::interleaveComma(operandTySeg, errorStream);
881 return emitOpError(errorMessage);
882 }
883 }
884
885 // Check the result type
886 if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
887 return expectedResultType == getResult().getType();
888 })) {
889 errorStream
890 << "Could not match allowed types for the result; expected one of ";
891 llvm::interleaveComma(expectedResult, errorStream);
892 errorStream << " but got " << getResult().getType();
893 return emitOpError(errorMessage);
894 }
895
896 // Ensure that binary MMA variants have a b1 MMA operation defined.
897 if (getMultiplicandAPtxType() == MMATypes::b1 && !getB1Op()) {
898 return emitOpError("op requires " + getB1OpAttrName().strref() +
899 " attribute");
900 }
901
902 // Ensure int4/int8 MMA variants specify the accum overflow behavior
903 // attribute.
904 if (isInt4PtxType(*getMultiplicandAPtxType()) ||
905 isInt8PtxType(*getMultiplicandAPtxType())) {
906 if (!getIntOverflowBehavior())
907 return emitOpError("op requires " +
908 getIntOverflowBehaviorAttrName().strref() +
909 " attribute");
910 }
911
912 // Validate layout combinations. According to the operation description, most
913 // MMA operations require layoutA=row and layoutB=col. Only m8n8k4 with f16
914 // can use other layout combinations.
915 bool isM8N8K4_F16 =
916 (mmaShape[0] == 8 && mmaShape[1] == 8 && mmaShape[2] == 4 &&
917 getMultiplicandAPtxType() == MMATypes::f16);
918
919 if (!isM8N8K4_F16) {
920 // For all other shapes/types, layoutA must be row and layoutB must be col
921 if (getLayoutA() != MMALayout::row || getLayoutB() != MMALayout::col) {
922 return emitOpError("requires layoutA = #nvvm.mma_layout<row> and "
923 "layoutB = #nvvm.mma_layout<col> for shape <")
924 << mmaShape[0] << ", " << mmaShape[1] << ", " << mmaShape[2]
925 << "> with element types "
926 << stringifyEnum(*getMultiplicandAPtxType()) << " and "
927 << stringifyEnum(*getMultiplicandBPtxType())
928 << ". Only m8n8k4 with f16 supports other layouts.";
929 }
930 }
931
932 return success();
933}
934
935LogicalResult ShflOp::verify() {
936 auto returnStructType = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
937
938 auto verifyTypeError = [&](Twine desc, Type expectedType,
939 Type actualType) -> LogicalResult {
940 return emitOpError("expected " + desc + " to be of type ")
941 << expectedType << " but got " << actualType << " instead";
942 };
943
944 if (returnStructType) {
945 if (!getReturnValueAndIsValid())
946 return emitOpError("\"return_value_and_is_valid\" attribute must be "
947 "specified when the return type is a struct type");
948
949 if (returnStructType.getBody().size() != 2)
950 return emitOpError("expected return type to be a two-element struct");
951
952 llvm::ArrayRef<Type> returnStruct = returnStructType.getBody();
953 auto resultType = returnStruct[0];
954 if (resultType != getVal().getType())
955 return verifyTypeError("first element in the returned struct",
956 getVal().getType(), resultType);
957
958 auto predicateType = returnStruct[1];
959 if (!predicateType.isInteger(1))
960 return verifyTypeError("second element in the returned struct",
961 mlir::IntegerType::get(getContext(), 1),
962 predicateType);
963 } else {
964 if (getReturnValueAndIsValid())
965 return emitOpError("expected return type to be a two-element struct");
966
967 if (getType() != getVal().getType())
968 return verifyTypeError("return type", getVal().getType(), getType());
969 }
970 return success();
971}
972
973std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
974 NVVM::MMAFrag frag, int nRow,
975 int nCol,
976 MLIRContext *context) {
977 unsigned numberElements = 0;
978 Type elementType;
979 OpBuilder builder(context);
980 Type f16x2 = VectorType::get(2, builder.getF16Type());
981 if (type == NVVM::MMATypes::f16) {
982 elementType = f16x2;
983 if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
984 numberElements = 8;
985 else
986 numberElements = 4;
987 } else if (type == NVVM::MMATypes::f32) {
988 elementType = builder.getF32Type();
989 numberElements = 8;
990 } else if (type == NVVM::MMATypes::f64) {
991 elementType = builder.getF64Type();
992 if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
993 numberElements = 1;
994 else
995 numberElements = 2;
996 } else if (type == NVVM::MMATypes::tf32) {
997 elementType = builder.getI32Type();
998 numberElements = 4;
999 } else if (type == NVVM::MMATypes::s8 || type == NVVM::MMATypes::u8) {
1000 elementType = builder.getI32Type();
1001 int parallelSize = 0;
1002 if (frag == NVVM::MMAFrag::a)
1003 parallelSize = nRow;
1004 if (frag == NVVM::MMAFrag::b)
1005 parallelSize = nCol;
1006
1007 // m == 16 && n == 16 && k == 16
1008 if (parallelSize == 16)
1009 numberElements = 2;
1010 // m == 8 && n == 32 && k == 16 or m == 32 && n == 8 && k == 16
1011 else if (parallelSize == 8)
1012 numberElements = 1;
1013 else if (parallelSize == 32)
1014 numberElements = 4;
1015 } else if (type == NVVM::MMATypes::s32) {
1016 elementType = builder.getI32Type();
1017 numberElements = 8;
1018 }
1019 assert(numberElements != 0 && elementType != nullptr);
1020 return std::make_pair(elementType, numberElements);
1021}
1022
1023static std::pair<mlir::Type, unsigned>
1024inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n,
1025 int k, MLIRContext *context) {
1026 int nRow, nCol;
1027 if (frag == NVVM::MMAFrag::a) {
1028 nRow = m;
1029 nCol = k;
1030 } else if (frag == NVVM::MMAFrag::b) {
1031 nRow = k;
1032 nCol = n;
1033 } else {
1034 nRow = m;
1035 nCol = n;
1036 }
1037 assert(nRow && nCol);
1038 return inferMMAType(type, frag, nRow, nCol, context);
1039}
1040
1041LogicalResult NVVM::WMMALoadOp::verify() {
1042 unsigned addressSpace =
1043 llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
1044 if (addressSpace != 0 && addressSpace != NVVMMemorySpace::Global &&
1045 addressSpace != NVVMMemorySpace::Shared)
1046 return emitOpError("expected source pointer in memory "
1047 "space 0, 1, 3");
1048
1049 if (NVVM::WMMALoadOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
1050 getEltype(), getFrag()) == 0)
1051 return emitOpError() << "invalid attribute combination";
1052 std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
1053 getEltype(), getFrag(), getM(), getN(), getK(), getContext());
1054 // Special case for f64 fragments
1055 Type f64Ty = Float64Type::get(getContext());
1056 if (typeInfo.first == f64Ty && typeInfo.second == 1) {
1057 if (getType() != f64Ty)
1058 return emitOpError("expected destination type to be f64");
1059 return success();
1060 }
1061 // Everything else is a struct
1062 Type dstType = LLVM::LLVMStructType::getLiteral(
1063 getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
1064 if (getType() != dstType)
1065 return emitOpError("expected destination type is a structure of ")
1066 << typeInfo.second << " elements of type " << typeInfo.first;
1067 return success();
1068}
1069
1070LogicalResult NVVM::WMMAStoreOp::verify() {
1071 unsigned addressSpace =
1072 llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
1073 if (addressSpace != 0 && addressSpace != NVVMMemorySpace::Global &&
1074 addressSpace != NVVMMemorySpace::Shared)
1075 return emitOpError("expected operands to be a source pointer in memory "
1076 "space 0, 1, 3");
1077
1078 if (NVVM::WMMAStoreOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
1079 getEltype()) == 0)
1080 return emitOpError() << "invalid attribute combination";
1081 std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
1082 getEltype(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
1083 if (getArgs().size() != typeInfo.second)
1084 return emitOpError() << "expected " << typeInfo.second << " data operands";
1085 if (llvm::any_of(getArgs(), [&typeInfo](Value operands) {
1086 return operands.getType() != typeInfo.first;
1087 }))
1088 return emitOpError() << "expected data operands of type " << typeInfo.first;
1089 return success();
1090}
1091
1092LogicalResult NVVM::WMMAMmaOp::verify() {
1093 if (NVVM::WMMAMmaOp::getIntrinsicID(getM(), getN(), getK(), getLayoutA(),
1094 getLayoutB(), getEltypeA(),
1095 getEltypeB()) == 0)
1096 return emitOpError() << "invalid attribute combination";
1097 std::pair<Type, unsigned> typeInfoA = inferMMATypeFromMNK(
1098 getEltypeA(), NVVM::MMAFrag::a, getM(), getN(), getK(), getContext());
1099 std::pair<Type, unsigned> typeInfoB = inferMMATypeFromMNK(
1100 getEltypeA(), NVVM::MMAFrag::b, getM(), getN(), getK(), getContext());
1101 std::pair<Type, unsigned> typeInfoC = inferMMATypeFromMNK(
1102 getEltypeB(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
1103 SmallVector<Type, 32> arguments;
1104 arguments.append(typeInfoA.second, typeInfoA.first);
1105 arguments.append(typeInfoB.second, typeInfoB.first);
1106 arguments.append(typeInfoC.second, typeInfoC.first);
1107 unsigned numArgs = arguments.size();
1108 if (getArgs().size() != numArgs)
1109 return emitOpError() << "expected " << numArgs << " arguments";
1110 for (unsigned i = 0; i < numArgs; i++) {
1111 if (getArgs()[i].getType() != arguments[i])
1112 return emitOpError() << "expected argument " << i << " to be of type "
1113 << arguments[i];
1114 }
1115 Type dstType = LLVM::LLVMStructType::getLiteral(
1116 getContext(), SmallVector<Type, 8>(typeInfoC.second, typeInfoC.first));
1117 if (getType() != dstType)
1118 return emitOpError("expected destination type is a structure of ")
1119 << typeInfoC.second << " elements of type " << typeInfoC.first;
1120 return success();
1121}
1122
1123LogicalResult NVVM::LdMatrixOp::verify() {
1124 uint32_t num = getNum(), m = getShape().getM(), n = getShape().getN();
1125 if (m == 8 && n == 8) {
1126 if (num != 1 && num != 2 && num != 4) {
1127 return emitOpError("expected num attribute to be 1, 2 or 4 for 8x8 "
1128 "matrix");
1129 }
1130 if (getEltType() != LdStMatrixEltType::B16) {
1131 return emitOpError("expected element type to be b16 for 8x8 matrix");
1132 }
1133 } else if (m == 8 && n == 16) {
1134 if (num != 1 && num != 2 && num != 4) {
1135 return emitOpError("expected num attribute to be 1, 2 or 4 for 8x16 "
1136 "matrix");
1137 }
1138 if (getLayout() != MMALayout::row) {
1139 return emitOpError("expected layout to be row for 8x16 matrix");
1140 }
1141 if (getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 &&
1142 getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) {
1143 return emitOpError("expected element type to be b8x16.b4x16_p64 or "
1144 "b8x16.b6x16_p32 for 8x16 matrix");
1145 }
1146 } else if (m == 16 && n == 16) {
1147 if (num != 1 && num != 2) {
1148 return emitOpError("expected num attribute to be 1 or 2 for 16x16 "
1149 "matrix");
1150 }
1151 if (getLayout() != MMALayout::col) {
1152 return emitOpError("expected layout to be col for 16x16 matrix");
1153 }
1154 if (getEltType() != LdStMatrixEltType::B8 &&
1155 getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 &&
1156 getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) {
1157 return emitOpError("expected element type to be b8, b8x16.b4x16_p64 or "
1158 "b8x16.b6x16_p32 for 16x16 matrix");
1159 }
1160 } else {
1161 return emitOpError("expected shape to be 8x8, 8x16 or 16x16");
1162 }
1163
1164 Type i32 = IntegerType::get(getContext(), 32);
1165 uint32_t numElements = (m == 16 && n == 16 ? num * 2 : num);
1166 if (numElements == 1 && getType() != i32)
1167 return emitOpError("expected destination type is i32");
1168 if (numElements == 2 || numElements == 4) {
1169 Type dstType = LLVM::LLVMStructType::getLiteral(
1170 getContext(), SmallVector<Type>(numElements, i32));
1171 if (getType() != dstType)
1172 return emitOpError("expected destination type is a structure of ")
1173 << numElements << " elements of type i32";
1174 }
1175
1176 return success();
1177}
1178
1179LogicalResult NVVM::StMatrixOp::verify() {
1180 int numMatrix = getSources().size();
1181 if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
1182 return emitOpError("expected num attribute to be 1, 2 or 4");
1183
1184 int m = getShape().getM(), n = getShape().getN();
1185 if (m == 8 && n == 8) {
1186 if (getEltType() != NVVM::LdStMatrixEltType::B16) {
1187 return emitOpError("expected element type to be B16 for 8x8 matrix");
1188 }
1189 } else if (m == 16 && n == 8) {
1190 if (getEltType() != NVVM::LdStMatrixEltType::B8) {
1191 return emitOpError("expected element type to be B8 for 16x8 matrix");
1192 }
1193 if (getLayout() != NVVM::MMALayout::col) {
1194 return emitOpError("expected layout to be col for 16x8 matrix");
1195 }
1196 } else {
1197 return emitOpError("expected shape to be 8x8 or 16x8");
1198 }
1199
1200 return success();
1201}
1202
1203static FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) {
1204 if (typeA == NVVM::WGMMATypes::tf32)
1205 return 8;
1206 if (typeA == NVVM::WGMMATypes::f16 || typeA == NVVM::WGMMATypes::bf16)
1207 return 16;
1208 if (typeA == NVVM::WGMMATypes::s8 || typeA == NVVM::WGMMATypes::u8)
1209 return 32;
1210 if (typeA == NVVM::WGMMATypes::e4m3 || typeA == NVVM::WGMMATypes::e5m2)
1211 return 32;
1212 if (typeA == NVVM::WGMMATypes::b1)
1213 return 256;
1214 return failure();
1215}
1216
1217static LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD,
1218 NVVM::WGMMATypes typeA,
1219 NVVM::WGMMATypes typeB) {
1220 switch (typeA) {
1221 case NVVM::WGMMATypes::f16:
1222 if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
1223 typeB == NVVM::WGMMATypes::f16)
1224 return success();
1225 break;
1226 case NVVM::WGMMATypes::tf32:
1227 if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32)
1228 return success();
1229 break;
1230 case NVVM::WGMMATypes::u8:
1231 case NVVM::WGMMATypes::s8:
1232 if (typeD == NVVM::WGMMATypes::s32 &&
1233 (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8))
1234 return success();
1235 break;
1236 case NVVM::WGMMATypes::b1:
1237 if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1)
1238 return success();
1239 break;
1240 case NVVM::WGMMATypes::bf16:
1241 if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
1242 typeB == NVVM::WGMMATypes::bf16)
1243 return success();
1244 break;
1245 case NVVM::WGMMATypes::e4m3:
1246 case NVVM::WGMMATypes::e5m2:
1247 if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
1248 (typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3))
1249 return success();
1250 break;
1251 case WGMMATypes::f32:
1252 case WGMMATypes::s32:
1253 llvm_unreachable("unsupported input types");
1254 break;
1255 }
1256 return failure();
1257}
1258
1259static LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA) {
1260 SmallVector<int> allowedN = {8, 16, 24, 32, 40, 48, 56, 64,
1261 72, 80, 88, 96, 104, 112, 120, 128,
1262 136, 144, 152, 160, 168, 176, 184, 192,
1263 200, 208, 216, 224, 232, 240, 248, 256};
1264 SmallVector<int> allowedNshort = {8, 16, 24, 32, 48, 64,
1265 80, 96, 112, 128, 144, 160,
1266 176, 192, 208, 224, 240, 256};
1267 switch (typeA) {
1268 case WGMMATypes::f16:
1269 case WGMMATypes::tf32:
1270 case WGMMATypes::bf16:
1271 case WGMMATypes::e4m3:
1272 case WGMMATypes::e5m2:
1273 if (llvm::is_contained(allowedN, sizeN))
1274 return success();
1275 break;
1276 case WGMMATypes::u8:
1277 case WGMMATypes::s8:
1278 case WGMMATypes::b1:
1279 if (llvm::is_contained(allowedNshort, sizeN))
1280 return success();
1281 break;
1282 case WGMMATypes::f32:
1283 case WGMMATypes::s32:
1284 llvm_unreachable("unsupported input types");
1285 break;
1286 }
1287 return failure();
1288}
1289
1290LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
1291 Value outValue = getResults();
1292 auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
1293 if (!stype)
1294 return emitOpError() << "expected results to be struct";
1295 int outputSize = stype.getBody().size();
1296 WGMMATypes typeD = getTypeD();
1297 WGMMATypes typeA = getTypeA();
1298 WGMMATypes typeB = getTypeB();
1299
1300 for (Type t : stype.getBody()) {
1301 if (t != stype.getBody().front())
1302 return emitOpError()
1303 << "all elements in struct must be same type but there is " << t;
1304 }
1305
1306 if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
1307 typeD != WGMMATypes::s32) {
1308 return emitOpError() << "does not support the given output type "
1309 << NVVM::stringifyWGMMATypes(typeD);
1310 }
1311 if (typeD == WGMMATypes::s32 &&
1312 (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
1313 return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg";
1314 }
1315
1316 if (failed(isAllowedWGMMADataType(typeD, typeA, typeB))) {
1317 return emitOpError() << NVVM::stringifyWGMMATypes(typeD)
1318 << " += " << NVVM::stringifyWGMMATypes(typeA) << " * "
1319 << NVVM::stringifyWGMMATypes(typeB)
1320 << ", it is not supported.";
1321 }
1322
1323 // Check M
1324 if (getShape().getM() != 64)
1325 return emitOpError() << "shape 'm' must be 64";
1326
1327 // Check K
1328 FailureOr<int> allowedK = getAllowedSizeK(typeA);
1329 if (failed(allowedK) || allowedK.value() != getShape().getK())
1330 return emitOpError() << "shape 'k' must be " << allowedK.value()
1331 << " for input type "
1332 << NVVM::stringifyWGMMATypes(typeA);
1333
1334 // Check N
1335 if (failed(isAllowedSizeN(getShape().getN(), typeA))) {
1336 return emitOpError() << "has input type "
1337 << NVVM::stringifyWGMMATypes(typeA) << " n is set to "
1338 << getShape().getN() << ", it is not supported.";
1339 }
1340
1341 // Check transpose (only available for f16/bf16)
1342 // Matrices A should be stored in row-major and B in column-major.
1343 // Only f16/bf16 matrices can be stored in either column-major or row-major
1344 // by setting the transpose value(imm-trans-a,imm-trans-b) in PTX code.
1345 if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
1346 (getLayoutA() == mlir::NVVM::MMALayout::col ||
1347 getLayoutB() == mlir::NVVM::MMALayout::row)) {
1348 return emitOpError()
1349 << "given layouts layout_a = " << stringifyMMALayout(getLayoutA())
1350 << " and layout_b = " << stringifyMMALayout(getLayoutB())
1351 << " for input types " << stringifyWGMMATypes(typeA) << " and "
1352 << stringifyWGMMATypes(typeB)
1353 << " requires transpose. However, this is only supported for: "
1354 << stringifyMMATypes(MMATypes::f16) << " and "
1355 << stringifyMMATypes(MMATypes::bf16);
1356 }
1357
1358 // Check result registers
1359 int expectedOutput = 0;
1360 if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32)
1361 expectedOutput = getShape().getN() / 2;
1362 if (typeD == WGMMATypes::f16)
1363 expectedOutput = getShape().getN() / 4;
1364 if (outputSize != expectedOutput) {
1365 return emitOpError() << "results " << expectedOutput
1366 << ", however output struct has " << outputSize
1367 << " elements";
1368 }
1369 // Check satfinite (only available for s32 accumulator)
1370 if (typeD != WGMMATypes::s32 &&
1371 getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
1372 NVVM::MMAIntOverflow::satfinite) {
1373 return emitOpError()
1374 << " `satfinite` can be only used with s32 accumulator, however "
1375 "the current accumulator is "
1376 << NVVM::stringifyWGMMATypes(typeD);
1377 }
1378
1379 return success();
1380}
1381
1382std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
1383
1384 int m = getShape().getM(), n = getShape().getN(), k = getShape().getK();
1385 bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
1386
1387 StringRef outputTypeName = stringifyWGMMATypes(getTypeD());
1388
1389 int expectedOutputRegisters = 0;
1390 if (getTypeD() == WGMMATypes::f16)
1391 expectedOutputRegisters = getShape().getN() / 4;
1392 else
1393 expectedOutputRegisters = getShape().getN() / 2;
1394
1395 std::string ptx;
1396 llvm::raw_string_ostream ss(ptx);
1397
1398 ss << "{\n"
1399 ".reg .pred p;\n"
1400 "setp.ne.b32 p, $"
1401 << ((expectedOutputRegisters * 2) + 2)
1402 << ", 0;\n"
1403 "wgmma.mma_async.sync.aligned.m"
1404 << m << "n" << n << "k" << k << "." << outputTypeName << "."
1405 << stringifyWGMMATypes(getTypeA()) << "."
1406 << stringifyWGMMATypes(getTypeB());
1407 if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
1408 NVVM::MMAIntOverflow::satfinite)
1409 ss << ".satfinite";
1410 ss << " {";
1411 int regCnt = 0;
1412 for (; regCnt < expectedOutputRegisters; ++regCnt) {
1413 ss << "$" << regCnt;
1414 if (regCnt != expectedOutputRegisters - 1)
1415 ss << ", ";
1416 }
1417
1418 ss << "},";
1419 // Need to map read/write registers correctly.
1420 regCnt = (regCnt * 2);
1421 ss << " $" << (regCnt) << ","
1422 << " $" << (regCnt + 1) << ","
1423 << " p";
1424 if (getTypeD() != WGMMATypes::s32) {
1425 ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4);
1426 }
1427 // Don't add transpose parameters unless needed.
1428 if (isF16) {
1429 ss << ", $" << (regCnt + 5) << ", $" << (regCnt + 6);
1430 }
1431 ss << ";\n"
1432 << "}\n";
1433 return ptx;
1434}
1435
1436bool NVVM::WgmmaMmaAsyncOp::getAsmValues(
1437 RewriterBase &rewriter,
1438 llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
1439 &asmValues) {
1440 bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
1441 if (getResults())
1442 asmValues.push_back({getResults(), mlir::NVVM::PTXRegisterMod::Write});
1443 if (getInouts())
1444 asmValues.push_back({getInouts(), mlir::NVVM::PTXRegisterMod::ReadWrite});
1445 asmValues.push_back({getDescriptorA(), mlir::NVVM::PTXRegisterMod::Read});
1446 asmValues.push_back({getDescriptorB(), mlir::NVVM::PTXRegisterMod::Read});
1447 asmValues.push_back({makeConstantI32(rewriter, static_cast<int>(getScaleD())),
1448 mlir::NVVM::PTXRegisterMod::Read});
1449 if (getTypeD() != WGMMATypes::s32) {
1450 asmValues.push_back(
1451 {makeConstantI32(rewriter,
1452 getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
1453 mlir::NVVM::PTXRegisterMod::Read});
1454 asmValues.push_back(
1455 {makeConstantI32(rewriter,
1456 getScaleB() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
1457 mlir::NVVM::PTXRegisterMod::Read});
1458 }
1459 if (isF16) {
1460 asmValues.push_back(
1461 {makeConstantI32(rewriter, static_cast<int>(getLayoutA())),
1462 mlir::NVVM::PTXRegisterMod::Read});
1463 asmValues.push_back(
1464 {makeConstantI32(rewriter, 1 - static_cast<int>(getLayoutB())),
1465 mlir::NVVM::PTXRegisterMod::Read});
1466 }
1467 return true; // Has manual mapping
1468}
1469
1470LogicalResult NVVM::FenceProxyOp::verify() {
1471 if (getKind() == NVVM::ProxyKind::TENSORMAP)
1472 return emitOpError() << "tensormap proxy is not a supported proxy kind";
1473 if (getKind() == NVVM::ProxyKind::GENERIC)
1474 return emitOpError() << "generic proxy not a supported proxy kind";
1475 if (getKind() == NVVM::ProxyKind::async_shared && !getSpace().has_value()) {
1476 return emitOpError() << "async_shared fence requires space attribute";
1477 }
1478 if (getKind() != NVVM::ProxyKind::async_shared && getSpace().has_value()) {
1479 return emitOpError() << "only async_shared fence can have space attribute";
1480 }
1481 return success();
1482}
1483
1484LogicalResult NVVM::FenceProxyAcquireOp::verify() {
1485 if (getFromProxy() != NVVM::ProxyKind::GENERIC)
1486 return emitOpError("uni-directional proxies only support generic for "
1487 "from_proxy attribute");
1488
1489 if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
1490 return emitOpError("uni-directional proxies only support tensormap "
1491 "for to_proxy attribute");
1492
1493 return success();
1494}
1495
1496LogicalResult NVVM::FenceProxyReleaseOp::verify() {
1497 if (getFromProxy() != NVVM::ProxyKind::GENERIC)
1498 return emitOpError("uni-directional proxies only support generic for "
1499 "from_proxy attribute");
1500
1501 if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
1502 return emitOpError("uni-directional proxies only support tensormap "
1503 "for to_proxy attribute");
1504
1505 return success();
1506}
1507
1508LogicalResult NVVM::SetMaxRegisterOp::verify() {
1509 if (getRegCount() % 8)
1510 return emitOpError("new register size must be multiple of 8");
1511 if (getRegCount() < 24 || getRegCount() > 256)
1512 return emitOpError("new register size must be in between 24 to 256");
1513 return success();
1514}
1515
1516LogicalResult NVVM::BarrierOp::verify() {
1517 if (getNumberOfThreads() && !getBarrierId())
1518 return emitOpError(
1519 "barrier id is missing, it should be set between 0 to 15");
1520
1521 if (getBarrierId() && (getReductionOp() || getReductionPredicate()))
1522 return emitOpError("reduction are only available when id is 0");
1523
1524 if ((getReductionOp() && !getReductionPredicate()) ||
1525 (!getReductionOp() && getReductionPredicate()))
1526 return emitOpError("reduction predicate and reduction operation must be "
1527 "specified together");
1528
1529 return success();
1530}
1531
1532LogicalResult NVVM::Tcgen05CpOp::verify() {
1533 auto mc = getMulticast();
1534
1535 using SH = Tcgen05CpShape;
1536 using MC = Tcgen05CpMulticast;
1537 switch (getShape()) {
1538 case SH::SHAPE_128x256b:
1539 case SH::SHAPE_128x128b:
1540 case SH::SHAPE_4x256b:
1541 if (mc != MC::NONE)
1542 return emitError("Invalid multicast type for tcgen05.cp Op");
1543 break;
1544 case SH::SHAPE_64x128b:
1545 if (mc != MC::WARPX2_01_23 && mc != MC::WARPX2_02_13)
1546 return emitError("Shape 64x128b requires multicast warpx2_01_23 or "
1547 "warpx2_02_13 for tcgen05.cp Op");
1548 break;
1549 case SH::SHAPE_32x128b:
1550 if (mc != MC::WARPX4)
1551 return emitError(
1552 "Shape 32x128b requires multicast warpx4 for tcgen05.cp Op");
1553 break;
1554 }
1555 return success();
1556}
1557
1558LogicalResult NVVM::MatchSyncOp::verify() {
1559 if (getKind() == NVVM::MatchSyncKind::all) {
1560 auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
1561 if (!type || type.getBody().size() != 2 ||
1562 !type.getBody()[0].isInteger(32) || !type.getBody()[1].isInteger(1)) {
1563 return emitOpError("match.sync 'all' returns a two element struct with "
1564 "first element as i32 and second element as i1");
1565 }
1566 } else {
1567 if (!getType().isInteger(32)) {
1568 return emitOpError("match.sync 'any' returns an i32");
1569 }
1570 }
1571 return success();
1572}
1573
1574LogicalResult NVVM::VoteSyncOp::verify() {
1575 if (getKind() == NVVM::VoteSyncKind::ballot) {
1576 if (!getType().isInteger(32)) {
1577 return emitOpError("vote.sync 'ballot' returns an i32");
1578 }
1579 } else {
1580 if (!getType().isInteger(1)) {
1581 return emitOpError("vote.sync 'any', 'all' and 'uni' returns an i1");
1582 }
1583 }
1584 return success();
1585}
1586
1587LogicalResult NVVM::PrefetchOp::verify() {
1588 using MemSpace = NVVM::NVVMMemorySpace;
1589 using CacheLevel = NVVM::PrefetchCacheLevel;
1590
1591 unsigned addressSpace =
1592 llvm::cast<LLVM::LLVMPointerType>(getAddr().getType()).getAddressSpace();
1593 std::optional<NVVM::CacheEvictionPriority> evictPriority = getEvictPriority();
1594 std::optional<NVVM::PrefetchCacheLevel> cacheLevel = getCacheLevel();
1595
1596 if (getTensormap() && cacheLevel)
1597 return emitOpError("cannot specify both tensormap and cache level");
1598
1599 if (getTensormap()) {
1600 if (addressSpace != MemSpace::Generic &&
1601 addressSpace != MemSpace::Constant) {
1602 return emitOpError(
1603 "prefetch tensormap requires a generic or constant pointer");
1604 }
1605
1606 if (evictPriority) {
1607 return emitOpError(
1608 "prefetch tensormap does not support eviction priority");
1609 }
1610
1611 if (getInParamSpace() && addressSpace != MemSpace::Generic) {
1612 return emitOpError(
1613 "in_param_space can only be specified for a generic pointer");
1614 }
1615
1616 } else if (cacheLevel) {
1617 if (addressSpace != MemSpace::Generic && addressSpace != MemSpace::Global &&
1618 addressSpace != MemSpace::Local) {
1619 return emitOpError("prefetch to cache level requires a generic, global, "
1620 "or local pointer");
1621 }
1622
1623 if (getUniform()) {
1624 if (*cacheLevel != CacheLevel::L1) {
1625 return emitOpError(
1626 "unsupported cache level, the only supported uniform "
1627 "cache level is L1");
1628 }
1629
1630 if (addressSpace != MemSpace::Generic) {
1631 return emitOpError(
1632 "prefetch to uniform cache requires a generic pointer");
1633 }
1634 }
1635
1636 if (evictPriority) {
1637 if (*cacheLevel != CacheLevel::L2)
1638 return emitOpError(
1639 "cache eviction priority supported only for cache level L2");
1640
1641 if (addressSpace != MemSpace::Global)
1642 return emitOpError("cache eviction priority requires a global pointer");
1643
1644 if (*evictPriority != NVVM::CacheEvictionPriority::EvictNormal &&
1645 *evictPriority != NVVM::CacheEvictionPriority::EvictLast)
1646 return emitOpError(
1647 "unsupported cache eviction priority, only evict_last and "
1648 "evict_normal are supported");
1649 }
1650
1651 if (getPredicate())
1652 return emitOpError("predicate supported only on prefetch tensormap");
1653
1654 } else {
1655 return emitOpError(
1656 "requires specification of either cache level or tensormap");
1657 }
1658
1659 return success();
1660}
1661
1662LogicalResult NVVM::ClusterLaunchControlQueryCancelOp::verify() {
1663 switch (getQueryType()) {
1664 case NVVM::ClusterLaunchControlQueryType::IS_CANCELED:
1665 if (!getType().isInteger(1))
1666 return emitOpError("is_canceled query type returns an i1");
1667 break;
1668 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_X:
1669 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Y:
1670 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Z:
1671 if (!getType().isInteger(32)) {
1672 return emitOpError("get_first_cta_id_x, get_first_cta_id_y, "
1673 "get_first_cta_id_z query types return an i32");
1674 }
1675 break;
1676 }
1677 return success();
1678}
1679
1680LogicalResult NVVM::ReduxOp::verify() {
1681 mlir::Type reduxType = getType();
1682
1683 if (!reduxType.isF32()) {
1684 if (getAbs())
1685 return emitOpError("abs attribute is supported only for f32 type");
1686 if (getNan())
1687 return emitOpError("nan attribute is supported only for f32 type");
1688 }
1689
1690 NVVM::ReduxKind kind = getKind();
1691 switch (kind) {
1692 case NVVM::ReduxKind::ADD:
1693 case NVVM::ReduxKind::AND:
1694 case NVVM::ReduxKind::OR:
1695 case NVVM::ReduxKind::XOR:
1696 case NVVM::ReduxKind::MAX:
1697 case NVVM::ReduxKind::MIN:
1698 case NVVM::ReduxKind::UMAX:
1699 case NVVM::ReduxKind::UMIN:
1700 if (!reduxType.isInteger(32))
1701 return emitOpError("'")
1702 << stringifyEnum(kind) << "' redux kind unsupported with "
1703 << reduxType << " type. Only supported type is 'i32'.";
1704 break;
1705 case NVVM::ReduxKind::FMIN:
1706 case NVVM::ReduxKind::FMAX:
1707 if (!reduxType.isF32())
1708 return emitOpError("'")
1709 << stringifyEnum(kind) << "' redux kind unsupported with "
1710 << reduxType << " type. Only supported type is 'f32'.";
1711 break;
1712 }
1713
1714 return success();
1715}
1716
1717/// Packs the given `field` into the `result`.
1718/// The `result` is 64 bits wide and each `field` can be 32 bits or narrower.
1719static llvm::Value *
1720packValInto64Bits(llvm::IRBuilderBase &builder,
1721 llvm::Value *result, // the `result` (unset bits are zero)
1722 llvm::Value *field, // `field` to pack into `result`
1723 unsigned sizeInBits, // Size of `field` in bits
1724 unsigned start) { // Starting bit within `result`
1725 field = builder.CreateZExtOrBitCast(field, builder.getInt32Ty());
1726
1727 unsigned mask = (sizeInBits < 32 ? ((1u << sizeInBits) - 1) : 0xffffffffu);
1728 if (mask != 0xffffffffu)
1729 field = builder.CreateAnd(field, builder.getInt32(mask));
1730
1731 field = builder.CreateZExtOrBitCast(field, builder.getInt64Ty());
1732 field = builder.CreateShl(field, start);
1733
1734 return builder.CreateOr(result, field);
1735}
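// As an illustrative sketch of the helper above: packing a 14-bit `field` at
// `start` = 16 into a zero-initialized `result` is equivalent to
//   result |= (uint64_t)(field & 0x3fff) << 16;
// i.e. `field` is masked to `sizeInBits` bits and OR-ed into `result` at `start`.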
1736
1737void Tcgen05MmaSmemDescOp::createSmemDescriptor(Operation &op,
1738 LLVM::ModuleTranslation &mt,
1739 llvm::IRBuilderBase &builder) {
1740 auto thisOp = cast<NVVM::Tcgen05MmaSmemDescOp>(op);
1741 llvm::Value *smemDesc = builder.getInt64(0);
1742
1743 smemDesc = packValInto64Bits(builder, smemDesc,
1744 mt.lookupValue(thisOp.getStartAddr()), 14, 0);
1745 smemDesc = packValInto64Bits(
1746 builder, smemDesc, mt.lookupValue(thisOp.getLeadingDimOffset()), 14, 16);
1747 smemDesc = packValInto64Bits(
1748 builder, smemDesc, mt.lookupValue(thisOp.getStrideDimOffset()), 14, 32);
1749
1750 smemDesc = packValInto64Bits(builder, smemDesc, builder.getInt32(1), 3, 46);
1751 smemDesc = packValInto64Bits(builder, smemDesc,
1752 mt.lookupValue(thisOp.getBaseOffset()), 3, 49);
1753 smemDesc = packValInto64Bits(
1754 builder, smemDesc, mt.lookupValue(thisOp.getLeadingDimMode()), 1, 52);
1755 smemDesc = packValInto64Bits(builder, smemDesc,
1756 mt.lookupValue(thisOp.getSwizzleMode()), 3, 61);
1757
1758 mt.mapValue(thisOp.getRes()) = smemDesc;
1759}
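// For reference, the 64-bit descriptor assembled above uses the following bit
// layout (directly mirroring the pack calls): bits [0,14) start address,
// [16,30) leading-dimension offset, [32,46) stride-dimension offset, [46,49) a
// constant 0b001, [49,52) base offset, bit 52 leading-dimension mode, and
// [61,64) swizzle mode; all remaining bits stay zero.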
1760
1761//===----------------------------------------------------------------------===//
1762// getPtx methods
1763//===----------------------------------------------------------------------===//
1764
1765std::string NVVM::MBarrierInitOp::getPtx() {
1766 bool isShared = isPtrInSharedCTASpace(getAddr());
1767 return isShared ? std::string("mbarrier.init.shared.b64 [%0], %1;")
1768 : std::string("mbarrier.init.b64 [%0], %1;");
1769}
1770
1771std::string NVVM::MBarrierArriveExpectTxOp::getPtx() {
1772 bool isShared = isPtrInSharedCTASpace(getAddr());
1773 return isShared
1774 ? std::string("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;")
1775 : std::string("mbarrier.arrive.expect_tx.b64 _, [%0], %1;");
1776}
1777
1778std::string NVVM::MBarrierTryWaitParityOp::getPtx() {
1779 bool isShared = isPtrInSharedCTASpace(getAddr());
1780 llvm::StringRef space = isShared ? ".shared" : "";
1781
1782 return llvm::formatv("{\n\t"
1783 ".reg .pred P1; \n\t"
1784 "LAB_WAIT: \n\t"
1785 "mbarrier.try_wait.parity{0}.b64 P1, [%0], %1, %2; \n\t"
1786 "@P1 bra.uni DONE; \n\t"
1787 "bra.uni LAB_WAIT; \n\t"
1788 "DONE: \n\t"
1789 "}",
1790 space);
1791}
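// For a shared-memory mbarrier, the formatv call above produces PTX roughly of
// the form:
//   { .reg .pred P1;
//     LAB_WAIT: mbarrier.try_wait.parity.shared.b64 P1, [%0], %1, %2;
//     @P1 bra.uni DONE; bra.uni LAB_WAIT; DONE: }
// with the ".shared" qualifier dropped for a generic pointer.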
1792
1793//===----------------------------------------------------------------------===//
1794// getIntrinsicID/getIntrinsicIDAndArgs methods
1795//===----------------------------------------------------------------------===//
1796
1797mlir::NVVM::IDArgPair NVVM::BarrierOp::getIntrinsicIDAndArgs(
1798 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1799 auto thisOp = cast<NVVM::BarrierOp>(op);
1800 llvm::Value *barrierId = thisOp.getBarrierId()
1801 ? mt.lookupValue(thisOp.getBarrierId())
1802 : builder.getInt32(0);
1803 llvm::Intrinsic::ID id;
1804 llvm::SmallVector<llvm::Value *> args;
1805 if (thisOp.getNumberOfThreads()) {
1806 id = llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_count;
1807 args.push_back(barrierId);
1808 args.push_back(mt.lookupValue(thisOp.getNumberOfThreads()));
1809 } else if (thisOp.getReductionOp()) {
1810 switch (*thisOp.getReductionOp()) {
1811 case NVVM::BarrierReduction::AND:
1812 id = llvm::Intrinsic::nvvm_barrier0_and;
1813 break;
1814 case NVVM::BarrierReduction::OR:
1815 id = llvm::Intrinsic::nvvm_barrier0_or;
1816 break;
1817 case NVVM::BarrierReduction::POPC:
1818 id = llvm::Intrinsic::nvvm_barrier0_popc;
1819 break;
1820 }
1821 args.push_back(mt.lookupValue(thisOp.getReductionPredicate()));
1822 } else {
1823 id = llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_all;
1824 args.push_back(barrierId);
1825 }
1826
1827 return {id, std::move(args)};
1828}
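// In short: a barrier with an explicit thread count lowers to
// barrier.cta.sync.aligned.count(id, numThreads); a reduction barrier lowers to
// the matching barrier0 reduction intrinsic, which takes only the predicate;
// and the plain form lowers to barrier.cta.sync.aligned.all(id).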
1829
1830mlir::NVVM::IDArgPair MBarrierInitOp::getIntrinsicIDAndArgs(
1831 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1832 auto thisOp = cast<NVVM::MBarrierInitOp>(op);
1833 bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
1834 llvm::Intrinsic::ID id = isShared ? llvm::Intrinsic::nvvm_mbarrier_init_shared
1835 : llvm::Intrinsic::nvvm_mbarrier_init;
1836
1837 // Fill the Intrinsic Args
1838 llvm::SmallVector<llvm::Value *> args;
1839 args.push_back(mt.lookupValue(thisOp.getAddr()));
1840 args.push_back(mt.lookupValue(thisOp.getCount()));
1841
1842 return {id, std::move(args)};
1843}
1844
1845mlir::NVVM::IDArgPair MBarrierInvalOp::getIntrinsicIDAndArgs(
1846 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1847 auto thisOp = cast<NVVM::MBarrierInvalOp>(op);
1848 bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
1849 llvm::Intrinsic::ID id = isShared
1850 ? llvm::Intrinsic::nvvm_mbarrier_inval_shared
1851 : llvm::Intrinsic::nvvm_mbarrier_inval;
1852
1853 return {id, {mt.lookupValue(thisOp.getAddr())}};
1854}
1855
1856mlir::NVVM::IDArgPair MBarrierArriveOp::getIntrinsicIDAndArgs(
1857 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1858 auto thisOp = cast<NVVM::MBarrierArriveOp>(op);
1859 bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
1860 llvm::Intrinsic::ID id = isShared
1861 ? llvm::Intrinsic::nvvm_mbarrier_arrive_shared
1862 : llvm::Intrinsic::nvvm_mbarrier_arrive;
1863
1864 return {id, {mt.lookupValue(thisOp.getAddr())}};
1865}
1866
1867mlir::NVVM::IDArgPair MBarrierArriveNocompleteOp::getIntrinsicIDAndArgs(
1868 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1869 auto thisOp = cast<NVVM::MBarrierArriveNocompleteOp>(op);
1870 bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
1871 llvm::Intrinsic::ID id =
1872 isShared ? llvm::Intrinsic::nvvm_mbarrier_arrive_noComplete_shared
1873 : llvm::Intrinsic::nvvm_mbarrier_arrive_noComplete;
1874 // Fill the Intrinsic Args
1875 llvm::SmallVector<llvm::Value *> args;
1876 args.push_back(mt.lookupValue(thisOp.getAddr()));
1877 args.push_back(mt.lookupValue(thisOp.getCount()));
1878
1879 return {id, std::move(args)};
1880}
1881
1882mlir::NVVM::IDArgPair MBarrierTestWaitOp::getIntrinsicIDAndArgs(
1883 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1884 auto thisOp = cast<NVVM::MBarrierTestWaitOp>(op);
1885 bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
1886 llvm::Intrinsic::ID id = isShared
1887 ? llvm::Intrinsic::nvvm_mbarrier_test_wait_shared
1888 : llvm::Intrinsic::nvvm_mbarrier_test_wait;
1889 // Fill the Intrinsic Args
1890 llvm::SmallVector<llvm::Value *> args;
1891 args.push_back(mt.lookupValue(thisOp.getAddr()));
1892 args.push_back(mt.lookupValue(thisOp.getState()));
1893
1894 return {id, std::move(args)};
1895}
1896
1897mlir::NVVM::IDArgPair CpAsyncMBarrierArriveOp::getIntrinsicIDAndArgs(
1898 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1899 auto thisOp = cast<NVVM::CpAsyncMBarrierArriveOp>(op);
1900 bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
1901
1902 llvm::Intrinsic::ID id;
1903 if (thisOp.getNoinc()) {
1904 id = isShared ? llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_noinc_shared
1905 : llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_noinc;
1906 } else {
1907 id = isShared ? llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_shared
1908 : llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive;
1909 }
1910
1911 return {id, {mt.lookupValue(thisOp.getAddr())}};
1912}
1913
1914#define CP_ASYNC_ID_IMPL(mod, size, suffix) \
1915 llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix
1916
1917#define GET_CP_ASYNC_ID(mod, size, has_cpsize) \
1918 has_cpsize ? CP_ASYNC_ID_IMPL(mod, size, _s) : CP_ASYNC_ID_IMPL(mod, size, )
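// For example, GET_CP_ASYNC_ID(ca, 4, /*has_cpsize=*/true) expands to
// llvm::Intrinsic::nvvm_cp_async_ca_shared_global_4_s, and the same query
// without a cp-size expands to llvm::Intrinsic::nvvm_cp_async_ca_shared_global_4.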
1919
1920llvm::Intrinsic::ID
1921CpAsyncOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
1922 llvm::SmallVector<llvm::Value *> &args) {
1923 llvm::Intrinsic::ID id;
1924
1925 auto cpAsyncOp = cast<NVVM::CpAsyncOp>(op);
1926 bool hasCpSize = static_cast<bool>(cpAsyncOp.getCpSize());
1927 switch (cpAsyncOp.getSize()) {
1928 case 4:
1929 id = GET_CP_ASYNC_ID(ca, 4, hasCpSize);
1930 break;
1931 case 8:
1932 id = GET_CP_ASYNC_ID(ca, 8, hasCpSize);
1933 break;
1934 case 16:
1935 id = (cpAsyncOp.getModifier() == NVVM::LoadCacheModifierKind::CG)
1936 ? GET_CP_ASYNC_ID(cg, 16, hasCpSize)
1937 : GET_CP_ASYNC_ID(ca, 16, hasCpSize);
1938 break;
1939 default:
1940 llvm_unreachable("Invalid copy size in CpAsyncOp.");
1941 }
1942
1943 // Fill the Intrinsic Args
1944 args.push_back(mt.lookupValue(cpAsyncOp.getDst()));
1945 args.push_back(mt.lookupValue(cpAsyncOp.getSrc()));
1946 if (hasCpSize)
1947 args.push_back(mt.lookupValue(cpAsyncOp.getCpSize()));
1948
1949 return id;
1950}
1951
1952mlir::NVVM::IDArgPair CpAsyncBulkPrefetchOp::getIntrinsicIDAndArgs(
1953 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1954 auto thisOp = cast<NVVM::CpAsyncBulkPrefetchOp>(op);
1955 llvm::SmallVector<llvm::Value *> args;
1956 llvm::Intrinsic::ID id = llvm::Intrinsic::nvvm_cp_async_bulk_prefetch_L2;
1957
1958 // Fill the Intrinsic Args
1959 args.push_back(mt.lookupValue(thisOp.getSrcMem()));
1960 args.push_back(mt.lookupValue(thisOp.getSize()));
1961
1962 mlir::Value cacheHint = thisOp.getL2CacheHint();
1963 const bool hasCacheHint = static_cast<bool>(cacheHint);
1964 llvm::Value *i64Unused =
1965 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
1966 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
1967 args.push_back(builder.getInt1(hasCacheHint));
1968
1969 return {id, std::move(args)};
1970}
1971
1972mlir::NVVM::IDArgPair CpAsyncBulkGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
1973 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1974 auto thisOp = cast<NVVM::CpAsyncBulkGlobalToSharedClusterOp>(op);
1975 llvm::SmallVector<llvm::Value *> args;
1976
1977 // Fill the Intrinsic Args: dst, mbar, src, size.
1978 args.push_back(mt.lookupValue(thisOp.getDstMem()));
1979 args.push_back(mt.lookupValue(thisOp.getMbar()));
1980 args.push_back(mt.lookupValue(thisOp.getSrcMem()));
1981 args.push_back(mt.lookupValue(thisOp.getSize()));
1982
1983 // Multicast mask, if available.
1984 mlir::Value multicastMask = thisOp.getMulticastMask();
1985 const bool hasMulticastMask = static_cast<bool>(multicastMask);
1986 llvm::Value *i16Unused = llvm::ConstantInt::get(builder.getInt16Ty(), 0);
1987 args.push_back(hasMulticastMask ? mt.lookupValue(multicastMask) : i16Unused);
1988
1989 // Cache hint, if available.
1990 mlir::Value cacheHint = thisOp.getL2CacheHint();
1991 const bool hasCacheHint = static_cast<bool>(cacheHint);
1992 llvm::Value *i64Unused = llvm::ConstantInt::get(builder.getInt64Ty(), 0);
1993 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
1994
1995 // Flag arguments for multicast and cachehint.
1996 args.push_back(builder.getInt1(hasMulticastMask));
1997 args.push_back(builder.getInt1(hasCacheHint));
1998
1999 llvm::Intrinsic::ID id =
2000 llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster;
2001
2002 return {id, std::move(args)};
2003}
2004
2005mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
2006 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
2007 auto thisOp = cast<NVVM::CpAsyncBulkSharedCTAToGlobalOp>(op);
2008 llvm::SmallVector<llvm::Value *> args;
2009 llvm::Intrinsic::ID id =
2010 llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global;
2011
2012 // Fill the Intrinsic Args
2013 args.push_back(mt.lookupValue(thisOp.getDstMem()));
2014 args.push_back(mt.lookupValue(thisOp.getSrcMem()));
2015 args.push_back(mt.lookupValue(thisOp.getSize()));
2016
2017 mlir::Value cacheHint = thisOp.getL2CacheHint();
2018 const bool hasCacheHint = static_cast<bool>(cacheHint);
2019 llvm::Value *i64Unused =
2020 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
2021 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
2022 args.push_back(builder.getInt1(hasCacheHint));
2023
2024 // Choose the bytemask variant
2025 if (mlir::Value byteMask = thisOp.getByteMask()) {
2026 args.push_back(mt.lookupValue(byteMask));
2027 id = llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global_bytemask;
2028 }
2029
2030 return {id, std::move(args)};
2031}
2032
2033bool CpAsyncBulkTensorGlobalToSharedClusterOp::getAsmValues(
2034 RewriterBase &rewriter,
2035 llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
2036 &asmValues) {
2037 // Add all the operands but not the attrs to the asmValues list.
2038 // The attrs here are used to generate the right variants for
2039 // intrinsics-lowering. So, we ignore them while generating inline-PTX.
2040 for (auto val : getOperands())
2041 asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read});
2042
2043 return false;
2044}
2045
2046mlir::NVVM::IDArgPair
2047CpAsyncBulkTensorGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
2048 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
2049 auto thisOp = cast<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(op);
2050 const bool isCTAOnly = thisOp.getIsCTAOnly();
2051 llvm::SmallVector<llvm::Value *> args;
2052
2053 // Fill the Intrinsic Args
2054 args.push_back(mt.lookupValue(thisOp.getDstMem()));
2055 args.push_back(mt.lookupValue(thisOp.getMbar()));
2056 args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
2057
2058 // Coordinates and im2col-offsets
2059 for (mlir::Value v : thisOp.getCoordinates())
2060 args.push_back(mt.lookupValue(v));
2061 for (mlir::Value v : thisOp.getIm2colOffsets())
2062 args.push_back(mt.lookupValue(v));
2063
2064 // MulticastMask, if available
2065 mlir::Value mcMask = thisOp.getMulticastMask();
2066 const bool hasMC = static_cast<bool>(mcMask);
2067 llvm::Value *i16Zero =
2068 llvm::ConstantInt::get(llvm::Type::getInt16Ty(mt.getLLVMContext()), 0);
2069
2070 // CacheHint, if available
2071 mlir::Value cacheHint = thisOp.getL2CacheHint();
2072 const bool hasCacheHint = static_cast<bool>(cacheHint);
2073 llvm::Value *i64Zero =
2074 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
2075
2076 // Flag argument CTAGroup
2077 // CTA_1/2 is mapped to values 1 and 2 for the intrinsics.
2078 // Hence, the +1 to getGroup().
2079 const int32_t val =
2080 thisOp.getGroup() ? (static_cast<int32_t>(*thisOp.getGroup()) + 1) : 0;
2081 llvm::Value *cg =
2082 llvm::ConstantInt::get(llvm::Type::getInt32Ty(mt.getLLVMContext()), val);
2083
2084 if (!isCTAOnly) {
2085 // For shared::cluster, all the arguments that we build are applicable.
2086 args.push_back(hasMC ? mt.lookupValue(mcMask) : i16Zero);
2087 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Zero);
2088 args.push_back(builder.getInt1(hasMC));
2089 args.push_back(builder.getInt1(hasCacheHint));
2090 args.push_back(cg);
2091 } else {
2092 // For shared::cta, only cache-hint is applicable.
2093 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Zero);
2094 args.push_back(builder.getInt1(hasCacheHint));
2095 }
2096
2097 constexpr size_t numDims = 5; // 1D to 5D
2098 constexpr size_t numModes = 5; // Tile, Im2col, w, w_128, gather4
2099 using rowTy = std::array<llvm::Intrinsic::ID, numDims + 1>;
2100 using TableTy = std::array<rowTy, numModes>;
2101 static constexpr TableTy IDTable{
2102 {{notIntrinsic, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_1d,
2103 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_2d,
2104 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_3d,
2105 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_4d,
2106 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_5d},
2107 {notIntrinsic, notIntrinsic, notIntrinsic,
2108 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d,
2109 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d,
2110 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d},
2111 {notIntrinsic, notIntrinsic, notIntrinsic,
2112 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_3d,
2113 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_4d,
2114 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_5d},
2115 {notIntrinsic, notIntrinsic, notIntrinsic,
2116 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_3d,
2117 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_4d,
2118 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_5d},
2119 {notIntrinsic, notIntrinsic, notIntrinsic, notIntrinsic, notIntrinsic,
2120 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_gather4_2d}}};
2121
2122 static constexpr TableTy IDTableCTA{
2123 {{notIntrinsic,
2124 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_1d,
2125 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_2d,
2126 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_3d,
2127 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_4d,
2128 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_5d},
2129 {notIntrinsic, notIntrinsic, notIntrinsic,
2130 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_3d,
2131 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_4d,
2132 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_5d},
2133 {notIntrinsic, notIntrinsic, notIntrinsic,
2134 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_3d,
2135 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_4d,
2136 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_5d},
2137 {notIntrinsic, notIntrinsic, notIntrinsic,
2138 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_3d,
2139 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_4d,
2140 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_5d},
2141 {notIntrinsic, notIntrinsic, notIntrinsic, notIntrinsic, notIntrinsic,
2142 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_gather4_2d}}};
2143
2144 static_assert(
2145 (getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1) &&
2146 (getMaxEnumValForTMALoadMode() == std::size(IDTableCTA) - 1),
2147 "TMALoadModes must match number of rows in IDTable and IDTableCTA");
2148 size_t mode = static_cast<size_t>(thisOp.getMode());
2149 size_t dim = thisOp.getCoordinates().size();
2150 auto id = isCTAOnly ? IDTableCTA[mode][dim] : IDTable[mode][dim];
2151 assert(id != notIntrinsic &&
2152 "Invalid intrinsic for CpAsyncBulkTensorGlobalToSharedClusterOp.");
2153
2154 return {id, std::move(args)};
2155}
2156
2157mlir::NVVM::IDArgPair CpAsyncBulkTensorPrefetchOp::getIntrinsicIDAndArgs(
2158 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
2159 auto thisOp = cast<NVVM::CpAsyncBulkTensorPrefetchOp>(op);
2160 llvm::SmallVector<llvm::Value *> args;
2161
2162 // Fill the Intrinsic Args
2163 args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
2164
2165 for (auto v : thisOp.getCoordinates())
2166 args.push_back(mt.lookupValue(v));
2167 for (auto v : thisOp.getIm2colOffsets())
2168 args.push_back(mt.lookupValue(v));
2169
2170 mlir::Value cacheHint = thisOp.getL2CacheHint();
2171 const bool hasCacheHint = static_cast<bool>(cacheHint);
2172 llvm::Value *i64Unused =
2173 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
2174 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
2175 args.push_back(builder.getInt1(hasCacheHint));
2176
2177 const unsigned NI = llvm::Intrinsic::not_intrinsic;
2178 static constexpr llvm::Intrinsic::ID IDTable[][6] = {
2179 {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d,
2180 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d,
2181 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d,
2182 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d,
2183 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d},
2184 {NI, NI, NI,
2185 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d,
2186 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d,
2187 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d},
2188 {NI, NI, NI,
2189 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_3d,
2190 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_4d,
2191 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_5d},
2192 {NI, NI, NI,
2193 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_3d,
2194 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_4d,
2195 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_5d},
2196 {NI, NI, NI, NI, NI,
2197 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_gather4_2d}};
2198
2199 static_assert(getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1,
2200 "TMALoadModes must match number of rows in IDTable");
2201 size_t mode = static_cast<size_t>(thisOp.getMode());
2202 size_t dim = thisOp.getCoordinates().size();
2203 llvm::Intrinsic::ID id = IDTable[mode][dim];
2204 if (id == llvm::Intrinsic::not_intrinsic)
2205 llvm_unreachable("Invalid intrinsic for CpAsyncBulkTensorPrefetchOp.");
2206
2207 return {id, std::move(args)};
2208}
2209
2210mlir::NVVM::IDArgPair
2211CpAsyncBulkTensorSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
2212 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
2213 auto thisOp = cast<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(op);
2214 llvm::SmallVector<llvm::Value *> args;
2215
2216 // Fill the Intrinsic Args
2217 args.push_back(mt.lookupValue(thisOp.getSrcMem()));
2218 args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
2219
2220 for (auto v : thisOp.getCoordinates())
2221 args.push_back(mt.lookupValue(v));
2222
2223 mlir::Value cacheHint = thisOp.getL2CacheHint();
2224 const bool hasCacheHint = static_cast<bool>(cacheHint);
2225 llvm::Value *i64Unused =
2226 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
2227 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
2228 args.push_back(builder.getInt1(hasCacheHint));
2229
2230 const unsigned NI = llvm::Intrinsic::not_intrinsic;
2231 static constexpr llvm::Intrinsic::ID IDTable[][6] = {
2232 {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_1d,
2233 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_2d,
2234 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_3d,
2235 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_4d,
2236 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_5d},
2237 {NI, NI, NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_3d,
2238 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_4d,
2239 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_5d},
2240 {NI, NI, NI, NI, NI,
2241 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_scatter4_2d}};
2242
2243 static_assert(getMaxEnumValForTMAStoreMode() == std::size(IDTable) - 1,
2244 "TMAStoreModes must match number of rows in IDTable");
2245 size_t mode = static_cast<size_t>(thisOp.getMode());
2246 size_t dim = thisOp.getCoordinates().size();
2247 llvm::Intrinsic::ID id = IDTable[mode][dim];
2248 if (id == llvm::Intrinsic::not_intrinsic)
2249 llvm_unreachable(
2250 "Invalid intrinsic for CpAsyncBulkTensorSharedCTAToGlobalOp.");
2251
2252 return {id, std::move(args)};
2253}
2254
2255NVVM::IDArgPair CpAsyncBulkTensorReduceOp::getIntrinsicIDAndArgs(
2256 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
2257 auto thisOp = cast<NVVM::CpAsyncBulkTensorReduceOp>(op);
2258 llvm::LLVMContext &ctx = mt.getLLVMContext();
2259
2260 llvm::SmallVector<llvm::Value *> args;
2261
2262 // Arguments to the intrinsic:
2263 // shared_mem_ptr, tmaDesc, tensorDims
2264 // cache_hint(if applicable) and flag(boolean)
2265 args.push_back(mt.lookupValue(thisOp.getSrcMem()));
2266 args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
2267
2268 for (Value v : thisOp.getCoordinates())
2269 args.push_back(mt.lookupValue(v));
2270
2271 mlir::Value cacheHint = thisOp.getL2CacheHint();
2272 const bool hasCacheHint = static_cast<bool>(cacheHint);
2273 llvm::Value *i64ZeroValue =
2274 llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), 0);
2275 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64ZeroValue);
2276 args.push_back(builder.getInt1(hasCacheHint));
2277
2278 const llvm::Intrinsic::ID notIntrinsic = llvm::Intrinsic::not_intrinsic;
2279
2280 constexpr unsigned numRedKinds = 8; // ADD, MIN, MAX, INC, DEC, AND, OR, XOR
2281 constexpr unsigned numLayouts = 2; // TILE, IM2COL
2282 constexpr unsigned maxDim = 5; // 1D to 5D
2283 using row = std::array<llvm::Intrinsic::ID, maxDim + 1>;
2284 using layoutTable = std::array<row, numLayouts>;
2285 using fullTable = std::array<layoutTable, numRedKinds>;
2286 static constexpr fullTable IDTable{
2287 {// RedTy::ADD
2288 {{{{notIntrinsic,
2289 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_1d,
2290 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_2d,
2291 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_3d,
2292 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_4d,
2293 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_5d}},
2294 {{notIntrinsic, notIntrinsic, notIntrinsic,
2295 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_3d,
2296 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_4d,
2297 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_5d}}}},
2298 // RedTy::MIN
2299 {{{{notIntrinsic,
2300 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_1d,
2301 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_2d,
2302 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_3d,
2303 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_4d,
2304 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_5d}},
2305 {{notIntrinsic, notIntrinsic, notIntrinsic,
2306 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_3d,
2307 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_4d,
2308 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_5d}}}},
2309 // RedTy::MAX
2310 {{{{notIntrinsic,
2311 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_1d,
2312 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_2d,
2313 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_3d,
2314 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_4d,
2315 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_5d}},
2316 {{notIntrinsic, notIntrinsic, notIntrinsic,
2317 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_3d,
2318 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_4d,
2319 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_5d}}}},
2320 // RedTy::INC
2321 {{{{notIntrinsic,
2322 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_1d,
2323 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_2d,
2324 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_3d,
2325 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_4d,
2326 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_5d}},
2327 {{notIntrinsic, notIntrinsic, notIntrinsic,
2328 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_3d,
2329 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_4d,
2330 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_5d}}}},
2331 // RedTy::DEC
2332 {{{{notIntrinsic,
2333 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_1d,
2334 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_2d,
2335 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_3d,
2336 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_4d,
2337 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_5d}},
2338 {{notIntrinsic, notIntrinsic, notIntrinsic,
2339 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_3d,
2340 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_4d,
2341 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_5d}}}},
2342 // RedTy::AND
2343 {{{{notIntrinsic,
2344 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_1d,
2345 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_2d,
2346 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_3d,
2347 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_4d,
2348 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_5d}},
2349 {{notIntrinsic, notIntrinsic, notIntrinsic,
2350 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_3d,
2351 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_4d,
2352 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_5d}}}},
2353 // RedTy::OR
2354 {{{{notIntrinsic,
2355 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_1d,
2356 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_2d,
2357 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_3d,
2358 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_4d,
2359 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_5d}},
2360 {{notIntrinsic, notIntrinsic, notIntrinsic,
2361 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_3d,
2362 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_4d,
2363 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_5d}}}},
2364 // RedTy::XOR
2365 {{{{notIntrinsic,
2366 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_1d,
2367 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_2d,
2368 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_3d,
2369 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_4d,
2370 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_5d}},
2371 {{notIntrinsic, notIntrinsic, notIntrinsic,
2372 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_3d,
2373 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_4d,
2374 llvm::Intrinsic::
2375 nvvm_cp_async_bulk_tensor_reduce_xor_im2col_5d}}}}}};
2376
2377 static_assert(getMaxEnumValForTMAReduxKind() == std::size(IDTable) - 1,
2378 "TMAReduxKinds must match number of rows in IDTable");
2379
2380 size_t redKind = static_cast<size_t>(thisOp.getRedKind());
2381 size_t mode = static_cast<size_t>(thisOp.getMode());
2382 size_t dim = thisOp.getCoordinates().size();
2383
2384 assert(redKind < IDTable.size() &&
2385 "Invalid redKind for CpAsyncBulkTensorReduceOp");
2386 assert(mode < IDTable[redKind].size() &&
2387 "Invalid mode for CpAsyncBulkTensorReduceOp");
2388 assert(dim < IDTable[redKind][mode].size() &&
2389 "Invalid dim for CpAsyncBulkTensorReduceOp");
2390
2391 llvm::Intrinsic::ID intrinsicID = IDTable[redKind][mode][dim];
2392
2393 assert(intrinsicID != notIntrinsic &&
2394 "Invalid intrinsic for CpAsyncBulkTensorReduceOp.");
2395
2396 return {intrinsicID, std::move(args)};
2397}
2398
2399#define _none
2400
2401#define CVT_F2TF32_ID_IMPL(rnd, relu, sf) \
2402 hasRelu ? llvm::Intrinsic::nvvm_f2tf32_##rnd##relu##sf \
2403 : llvm::Intrinsic::nvvm_f2tf32_##rnd##sf
2404
2405#define GET_CVT_F2TF32_ID(rnd, relu, sf) \
2406 hasSatFinite ? CVT_F2TF32_ID_IMPL(rnd, relu, sf) \
2407 : CVT_F2TF32_ID_IMPL(rnd, relu, )
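// For example, GET_CVT_F2TF32_ID(rn, _relu, _satfinite) selects
// llvm::Intrinsic::nvvm_f2tf32_rn_relu_satfinite when both hasRelu and
// hasSatFinite are set, and llvm::Intrinsic::nvvm_f2tf32_rn when neither is.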
2408
2409llvm::Intrinsic::ID
2410ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
2411 NVVM::SaturationMode sat, bool hasRelu) {
2412 using RndMode = NVVM::FPRoundingMode;
2413 bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
2414 switch (rnd) {
2415 case RndMode::RN:
2416 return GET_CVT_F2TF32_ID(rn, _relu, _satfinite);
2417 case RndMode::RZ:
2418 return GET_CVT_F2TF32_ID(rz, _relu, _satfinite);
2419 case RndMode::RNA:
2420 return GET_CVT_F2TF32_ID(rna, _none, _satfinite);
2421 default:
2422 llvm_unreachable("Invalid RoundingMode for CvtFloatToTF32Op");
2423 }
2424}
2425
2426NVVM::IDArgPair
2427ConvertF32x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF4x2Op op,
2428 LLVM::ModuleTranslation &mt,
2429 llvm::IRBuilderBase &builder) {
2430 llvm::SmallVector<llvm::Value *> args;
2431 args.push_back(mt.lookupValue(op.getA()));
2432 args.push_back(mt.lookupValue(op.getB()));
2433
2434 bool hasRelu = op.getRelu();
2435
2436 llvm::Intrinsic::ID intId =
2437 hasRelu ? llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_relu_satfinite
2438 : llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_satfinite;
2439
2440 return {intId, std::move(args)};
2441}
2442
2443#define GET_F32x2_TO_F6x2_ID(type, has_relu) \
2444 has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite \
2445 : llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite
2446
2447llvm::Intrinsic::ID ConvertF32x2ToF6x2Op::getIntrinsicID(mlir::Type dstTy,
2448 bool hasRelu) {
2449 return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
2450 .Case<mlir::Float6E2M3FNType>([&](mlir::Float6E2M3FNType) {
2451 return GET_F32x2_TO_F6x2_ID(e2m3x2, hasRelu);
2452 })
2453 .Case<mlir::Float6E3M2FNType>([&](mlir::Float6E3M2FNType) {
2454 return GET_F32x2_TO_F6x2_ID(e3m2x2, hasRelu);
2455 })
2456 .Default([](mlir::Type) {
2457 llvm_unreachable("Invalid conversion in ConvertF32x2ToF6x2Op");
2458 return llvm::Intrinsic::not_intrinsic;
2459 });
2460}
2461
2462#define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf) \
2463 has_satf ? llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd##_satfinite \
2464 : llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd
2465
2466#define GET_F32x2_TO_F8X2_S_ID(type, has_relu) \
2467 has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu \
2468 : llvm::Intrinsic::nvvm_ff_to_##type##_rn
2469
2470llvm::Intrinsic::ID
2471ConvertF32x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy, NVVM::FPRoundingMode rnd,
2472 NVVM::SaturationMode sat, bool hasRelu) {
2473 bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
2474 bool hasRoundingModeRZ = (rnd == NVVM::FPRoundingMode::RZ);
2475 bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP);
2476
2477 return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
2478 .Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
2479 return GET_F32x2_TO_F8X2_S_ID(e4m3x2, hasRelu);
2480 })
2481 .Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
2482 return GET_F32x2_TO_F8X2_S_ID(e5m2x2, hasRelu);
2483 })
2484 .Case<mlir::Float8E8M0FNUType>([&](mlir::Float8E8M0FNUType) {
2485 if (hasRoundingModeRZ)
2486 return GET_F32x2_TO_F8X2_US_ID(rz, hasSatFinite);
2487 else if (hasRoundingModeRP)
2488 return GET_F32x2_TO_F8X2_US_ID(rp, hasSatFinite);
2489
2490 llvm_unreachable("Invalid conversion in ConvertF32x2ToF8x2Op");
2491 })
2492 .Default([](mlir::Type) {
2493 llvm_unreachable("Invalid conversion in ConvertF32x2ToF8x2Op");
2494 return llvm::Intrinsic::not_intrinsic;
2495 });
2496}
2497
2498#define GET_F16x2_TO_F8X2_ID(type, has_relu) \
2499 has_relu ? llvm::Intrinsic::nvvm_f16x2_to_##type##_rn_relu \
2500 : llvm::Intrinsic::nvvm_f16x2_to_##type##_rn
2501
2502llvm::Intrinsic::ID ConvertF16x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy,
2503 bool hasRelu) {
2504 return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
2505 .Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
2506 return GET_F16x2_TO_F8X2_ID(e4m3x2, hasRelu);
2507 })
2508 .Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
2509 return GET_F16x2_TO_F8X2_ID(e5m2x2, hasRelu);
2510 })
2511 .Default([](mlir::Type) {
2512 llvm_unreachable("Invalid conversion in ConvertF16x2ToF8x2Op");
2513 return llvm::Intrinsic::not_intrinsic;
2514 });
2515}
2516
2517#define GET_BF16X2_TO_F8X2_ID(rnd, has_satf) \
2518 has_satf ? llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd##_satfinite \
2519 : llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd
2520
2521llvm::Intrinsic::ID
2522ConvertBF16x2ToF8x2Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
2523 NVVM::SaturationMode sat) {
2524 bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
2525 switch (rnd) {
2526 case NVVM::FPRoundingMode::RZ:
2527 return GET_BF16X2_TO_F8X2_ID(rz, hasSatFinite);
2528 case NVVM::FPRoundingMode::RP:
2529 return GET_BF16X2_TO_F8X2_ID(rp, hasSatFinite);
2530 default:
2531 llvm_unreachable("Invalid rounding mode for CvtBF16x2ToF8x2Op");
2532 }
2533}
2534
2535NVVM::IDArgPair ConvertF8x2ToF16x2Op::getIntrinsicIDAndArgs(
2536 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
2537 auto curOp = cast<NVVM::ConvertF8x2ToF16x2Op>(op);
2538
2539 bool hasRelu = curOp.getRelu();
2540
2541 llvm::Intrinsic::ID intId =
2543 .Case<Float8E4M3FNType>([&](Float8E4M3FNType type) {
2544 return hasRelu ? llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn_relu
2545 : llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn;
2546 })
2547 .Case<Float8E5M2Type>([&](Float8E5M2Type type) {
2548 return hasRelu ? llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn_relu
2549 : llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn;
2550 })
2551 .Default([](mlir::Type type) {
2552 llvm_unreachable("Invalid type for ConvertF8x2ToF16x2Op");
2553 return llvm::Intrinsic::not_intrinsic;
2554 });
2555
2556 llvm::Value *packedI16 =
2557 builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
2558 llvm::Type::getInt16Ty(builder.getContext()));
2559
2560 return {intId, {packedI16}};
2561}
2562
2563NVVM::IDArgPair ConvertF8x2ToBF16x2Op::getIntrinsicIDAndArgs(
2564 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
2565 auto curOp = cast<NVVM::ConvertF8x2ToBF16x2Op>(op);
2566
2567 llvm::Intrinsic::ID intId = llvm::Intrinsic::nvvm_ue8m0x2_to_bf16x2;
2568 llvm::Value *packedI16 =
2569 builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
2570 llvm::Type::getInt16Ty(builder.getContext()));
2571
2572 return {intId, {packedI16}};
2573}
2574
2575NVVM::IDArgPair ConvertF6x2ToF16x2Op::getIntrinsicIDAndArgs(
2576 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
2577 auto curOp = cast<NVVM::ConvertF6x2ToF16x2Op>(op);
2578
2579 bool hasRelu = curOp.getRelu();
2580
2581 llvm::Intrinsic::ID intId =
2583 .Case<Float6E2M3FNType>([&](Float6E2M3FNType type) {
2584 return hasRelu ? llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn_relu
2585 : llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn;
2586 })
2587 .Case<Float6E3M2FNType>([&](Float6E3M2FNType type) {
2588 return hasRelu ? llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn_relu
2589 : llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn;
2590 })
2591 .Default([](mlir::Type type) {
2592 llvm_unreachable("Invalid type for ConvertF6x2ToF16x2Op");
2593 return llvm::Intrinsic::not_intrinsic;
2594 });
2595
2596 llvm::Value *packedI16 =
2597 builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
2598 llvm::Type::getInt16Ty(builder.getContext()));
2599
2600 return {intId, {packedI16}};
2601}
2602
2603NVVM::IDArgPair ConvertF4x2ToF16x2Op::getIntrinsicIDAndArgs(
2604 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
2605 auto curOp = cast<NVVM::ConvertF4x2ToF16x2Op>(op);
2606
2607 bool hasRelu = curOp.getRelu();
2608
2609 llvm::Intrinsic::ID intId =
2611 .Case<Float4E2M1FNType>([&](Float4E2M1FNType type) {
2612 return hasRelu ? llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn_relu
2613 : llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn;
2614 })
2615 .Default([](mlir::Type type) {
2616 llvm_unreachable("Invalid type for ConvertF4x2ToF16x2Op");
2617 return llvm::Intrinsic::not_intrinsic;
2618 });
2619
2620 llvm::Value *extendedI16 =
2621 builder.CreateZExt(mt.lookupValue(curOp.getSrc()),
2622 llvm::Type::getInt16Ty(builder.getContext()));
2623
2624 return {intId, {extendedI16}};
2625}
2626
2627llvm::Intrinsic::ID
2628Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op,
2629 LLVM::ModuleTranslation &mt,
2630 llvm::SmallVector<llvm::Value *> &args) {
2631 auto curOp = cast<NVVM::Tcgen05AllocOp>(op);
2632 unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
2633 .getAddressSpace();
2634 bool isShared = as == NVVMMemorySpace::Shared;
2635 bool is2CTAMode = curOp.getGroup() == CTAGroupKind::CTA_2;
2636
2637 llvm::Intrinsic::ID id;
2638 if (isShared) {
2639 id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg2
2640 : llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg1;
2641 } else {
2642 id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_cg2
2643 : llvm::Intrinsic::nvvm_tcgen05_alloc_cg1;
2644 }
2645
2646 // Fill the Intrinsic Args
2647 args.push_back(mt.lookupValue(curOp.getAddr()));
2648 args.push_back(mt.lookupValue(curOp.getNCols()));
2649
2650 return id;
2651}
2652
2653llvm::Intrinsic::ID Tcgen05DeallocOp::getIntrinsicIDAndArgs(
2654 Operation &op, LLVM::ModuleTranslation &mt,
2655 llvm::SmallVector<llvm::Value *> &args) {
2656 auto curOp = cast<NVVM::Tcgen05DeallocOp>(op);
2657 auto id = (curOp.getGroup() == CTAGroupKind::CTA_1)
2658 ? llvm::Intrinsic::nvvm_tcgen05_dealloc_cg1
2659 : llvm::Intrinsic::nvvm_tcgen05_dealloc_cg2;
2660
2661 // Fill the Intrinsic Args
2662 args.push_back(mt.lookupValue(curOp.getTaddr()));
2663 args.push_back(mt.lookupValue(curOp.getNCols()));
2664
2665 return id;
2666}
2667
2668#define TCGEN05_COMMIT_IMPL(cg, is_shared, mc) \
2669 is_shared ? llvm::Intrinsic::nvvm_tcgen05_commit##mc##_shared##_##cg \
2670 : llvm::Intrinsic::nvvm_tcgen05_commit##mc##_##cg
2671
2672#define GET_TCGEN05_COMMIT_ID(cta_group, is_shared, has_mc) \
2673 has_mc ? TCGEN05_COMMIT_IMPL(cta_group, is_shared, _mc) \
2674 : TCGEN05_COMMIT_IMPL(cta_group, is_shared, )
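// For example, with cta_group = cg2, a shared-memory address, and a multicast
// mask present, GET_TCGEN05_COMMIT_ID expands to
// llvm::Intrinsic::nvvm_tcgen05_commit_mc_shared_cg2.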
2675
2676llvm::Intrinsic::ID
2677Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
2678 LLVM::ModuleTranslation &mt,
2679 llvm::SmallVector<llvm::Value *> &args) {
2680 auto curOp = cast<NVVM::Tcgen05CommitOp>(op);
2681 unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
2682 .getAddressSpace();
2683 bool isShared = as == NVVMMemorySpace::Shared;
2684 bool hasMulticast = static_cast<bool>(curOp.getMulticastMask());
2685 bool is2CTAMode = curOp.getGroup() == CTAGroupKind::CTA_2;
2686
2687 llvm::Intrinsic::ID id =
2688 is2CTAMode ? GET_TCGEN05_COMMIT_ID(cg2, isShared, hasMulticast)
2689 : GET_TCGEN05_COMMIT_ID(cg1, isShared, hasMulticast);
2690
2691 // Fill the Intrinsic Args
2692 args.push_back(mt.lookupValue(curOp.getAddr()));
2693 if (hasMulticast)
2694 args.push_back(mt.lookupValue(curOp.getMulticastMask()));
2695
2696 return id;
2697}
2698
2699#define TCGEN05_CP_IMPL(shape_mc, src_fmt, cg) \
2700 llvm::Intrinsic::nvvm_tcgen05_cp##shape_mc##src_fmt##cg
2701
2702#define TCGEN05_CP_2CTA(shape_mc, src_fmt, is_2cta) \
2703 is_2cta ? TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg2) \
2704 : TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg1)
2705
2706#define GET_TCGEN05_CP_ID(shape_mc, src_fmt, is_2cta) \
2707 [&]() -> auto { \
2708 if ((src_fmt) == Tcgen05CpSrcFormat::B6x16_P32) \
2709 return TCGEN05_CP_2CTA(shape_mc, _b6x16_p32, is_2cta); \
2710 if ((src_fmt) == Tcgen05CpSrcFormat::B4x16_P64) \
2711 return TCGEN05_CP_2CTA(shape_mc, _b4x16_p64, is_2cta); \
2712 return TCGEN05_CP_2CTA(shape_mc, , is_2cta); \
2713 }()
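// For example, the 128x256b shape with source format b6x16_p32 in 2-CTA mode
// resolves to llvm::Intrinsic::nvvm_tcgen05_cp_128x256b_b6x16_p32_cg2.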
2714
2715llvm::Intrinsic::ID ConvertF32x2ToF16x2Op::getIntrinsicID() {
2716 bool hasRelu = getRelu();
2717 bool hasSatFinite = (getSat() == NVVM::SaturationMode::SATFINITE);
2718
2719 if (hasRelu && hasSatFinite)
2720 return llvm::Intrinsic::nvvm_ff2f16x2_rs_relu_satfinite;
2721 if (hasRelu)
2722 return llvm::Intrinsic::nvvm_ff2f16x2_rs_relu;
2723 if (hasSatFinite)
2724 return llvm::Intrinsic::nvvm_ff2f16x2_rs_satfinite;
2725 return llvm::Intrinsic::nvvm_ff2f16x2_rs;
2726}
2727
2728llvm::Intrinsic::ID ConvertF32x2ToBF16x2Op::getIntrinsicID() {
2729 bool hasRelu = getRelu();
2730 bool hasSatFinite = (getSat() == NVVM::SaturationMode::SATFINITE);
2731
2732 if (hasRelu && hasSatFinite)
2733 return llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu_satfinite;
2734 if (hasRelu)
2735 return llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu;
2736 if (hasSatFinite)
2737 return llvm::Intrinsic::nvvm_ff2bf16x2_rs_satfinite;
2738 return llvm::Intrinsic::nvvm_ff2bf16x2_rs;
2739}
2740
2741llvm::Intrinsic::ID ConvertF32x4ToF8x4Op::getIntrinsicID() {
2742 mlir::Type dstTy = getDstTy();
2743 bool hasRelu = getRelu();
2744
2745 return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
2746 .Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
2747 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite
2748 : llvm::Intrinsic::nvvm_f32x4_to_e4m3x4_rs_satfinite;
2749 })
2750 .Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
2751 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e5m2x4_rs_relu_satfinite
2752 : llvm::Intrinsic::nvvm_f32x4_to_e5m2x4_rs_satfinite;
2753 })
2754 .Default([](mlir::Type) {
2755 llvm_unreachable("Invalid F8 type in ConvertF32x4ToF8x4Op");
2756 return llvm::Intrinsic::not_intrinsic;
2757 });
2758}
2759
2760llvm::Intrinsic::ID ConvertF32x4ToF6x4Op::getIntrinsicID() {
2761 mlir::Type dstTy = getDstTy();
2762 bool hasRelu = getRelu();
2763
2764 return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
2765 .Case<mlir::Float6E2M3FNType>([&](mlir::Float6E2M3FNType) {
2766 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e2m3x4_rs_relu_satfinite
2767 : llvm::Intrinsic::nvvm_f32x4_to_e2m3x4_rs_satfinite;
2768 })
2769 .Case<mlir::Float6E3M2FNType>([&](mlir::Float6E3M2FNType) {
2770 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e3m2x4_rs_relu_satfinite
2771 : llvm::Intrinsic::nvvm_f32x4_to_e3m2x4_rs_satfinite;
2772 })
2773 .Default([](mlir::Type) {
2774 llvm_unreachable("Invalid F6 type in ConvertF32x4ToF6x4Op");
2775 return llvm::Intrinsic::not_intrinsic;
2776 });
2777}
2778
2779llvm::Intrinsic::ID ConvertF32x4ToF4x4Op::getIntrinsicID() {
2780 mlir::Type dstTy = getDstTy();
2781 bool hasRelu = getRelu();
2782
2783 return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
2784 .Case<mlir::Float4E2M1FNType>([&](mlir::Float4E2M1FNType) {
2785 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e2m1x4_rs_relu_satfinite
2786 : llvm::Intrinsic::nvvm_f32x4_to_e2m1x4_rs_satfinite;
2787 })
2788 .Default([](mlir::Type) {
2789 llvm_unreachable("Invalid F4 type in ConvertF32x4ToF4x4Op");
2790 return llvm::Intrinsic::not_intrinsic;
2791 });
2792}
2793
2794llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(Operation &op) {
2795 auto curOp = cast<NVVM::Tcgen05CpOp>(op);
2796 bool is2CTA = curOp.getGroup() == CTAGroupKind::CTA_2;
2797 auto srcFmt = curOp.getSrcFormat();
2798 auto mc = curOp.getMulticast();
2799
2800 switch (curOp.getShape()) {
2801 case Tcgen05CpShape::SHAPE_128x256b:
2802 return GET_TCGEN05_CP_ID(_128x256b, srcFmt, is2CTA);
2803 case Tcgen05CpShape::SHAPE_128x128b:
2804 return GET_TCGEN05_CP_ID(_128x128b, srcFmt, is2CTA);
2805 case Tcgen05CpShape::SHAPE_4x256b:
2806 return GET_TCGEN05_CP_ID(_4x256b, srcFmt, is2CTA);
2807 case Tcgen05CpShape::SHAPE_32x128b:
2808 return GET_TCGEN05_CP_ID(_32x128b_warpx4, srcFmt, is2CTA);
2809 case Tcgen05CpShape::SHAPE_64x128b:
2810 return (mc == Tcgen05CpMulticast::WARPX2_01_23)
2811 ? GET_TCGEN05_CP_ID(_64x128b_warpx2_01_23, srcFmt, is2CTA)
2812 : GET_TCGEN05_CP_ID(_64x128b_warpx2_02_13, srcFmt, is2CTA);
2813 }
2814 llvm_unreachable("Invalid shape in tcgen05 cp Op");
2815}
2816
2817// Returns true if the given vector length is valid for the given shape. The
2818// function models the table mentioned in the tcgen05.{ld, st} Op description.
2819static bool isValidVectorLength(NVVM::Tcgen05LdStShape shape,
2820 unsigned vecLen) {
2821 if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X128B)
2822 return vecLen >= 2;
2823 if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X256B)
2824 return vecLen >= 4;
2825 return true;
2826}
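// In other words, the minimum vector lengths modeled here are 2 for shape
// 16x128b and 4 for shape 16x256b; the remaining shapes carry no extra
// constraint on the vector length.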
2827
2828LogicalResult Tcgen05LdOp::verify() {
2829 LogicalResult result = success();
2830 if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
2831 result = emitError("shape 16x32bx2 requires offset argument");
2832
2833 if (getShape() != NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && getOffset())
2834 result = emitError("offset argument is only supported for shape 16x32bx2");
2835
2836 auto resTy = getRes().getType();
2837 unsigned resLen = isa<VectorType>(resTy)
2838 ? llvm::cast<VectorType>(resTy).getNumElements()
2839 : 1;
2840 if (!isValidVectorLength(getShape(), resLen))
2841 result = emitError(llvm::formatv("invalid result type length {0} for shape "
2842 "{1} in tcgen05.ld Op",
2843 resLen, stringifyEnum(getShape())));
2844
2845 return result;
2846}
2847
2848LogicalResult Tcgen05StOp::verify() {
2849 LogicalResult result = success();
2850 if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
2851 result = emitError("shape 16x32bx2 requires offset argument");
2852
2853 auto valTy = getVal().getType();
2854 unsigned valLen = isa<VectorType>(valTy)
2855 ? llvm::cast<VectorType>(valTy).getNumElements()
2856 : 1;
2857 if (!isValidVectorLength(getShape(), valLen))
2858 result = emitError(llvm::formatv("invalid input length {0} for shape "
2859 "{1} in tcgen05.st Op",
2860 valLen, stringifyEnum(getShape())));
2861
2862 return result;
2863}
2864
2865/// Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might
2866/// have ConstantRangeAttr.
2867static void nvvmInferResultRanges(Operation *op, Value result,
2868 ArrayRef<::mlir::ConstantIntRanges> argRanges,
2869 SetIntRangeFn setResultRanges) {
2870 if (auto rangeAttr = op->getAttrOfType<LLVM::ConstantRangeAttr>("range")) {
2871 setResultRanges(result, {rangeAttr.getLower(), rangeAttr.getUpper(),
2872 rangeAttr.getLower(), rangeAttr.getUpper()});
2873 }
2874}
2875
2876/// Verify the range attribute satisfies LLVM ConstantRange constructor
2877/// requirements for NVVM SpecialRangeableRegisterOp.
2878static LogicalResult
2880 std::optional<LLVM::ConstantRangeAttr> rangeAttr) {
2881 if (!rangeAttr)
2882 return success();
2883
2884 const llvm::APInt &lower = rangeAttr->getLower();
2885 const llvm::APInt &upper = rangeAttr->getUpper();
2886
2887 // Check LLVM ConstantRange constructor condition
2888 if (lower == upper && !lower.isMaxValue() && !lower.isMinValue()) {
2889 unsigned bitWidth = lower.getBitWidth();
2890 llvm::APInt minVal = llvm::APInt::getMinValue(bitWidth);
2891 llvm::APInt maxVal = llvm::APInt::getMaxValue(bitWidth);
2892 return op->emitOpError(
2893 "invalid range attribute: Lower == Upper, but they aren't min (")
2894 << llvm::toString(minVal, 10, false) << ") or max ("
2895 << llvm::toString(maxVal, 10, false)
2896 << ") value! This is an invalid constant range.";
2897 }
2898
2899 return success();
2900}
2901
2902static llvm::Value *getAsPackedI32(llvm::Value *arg,
2903 llvm::IRBuilderBase &builder) {
2904 return builder.CreateBitCast(arg,
2905 llvm::Type::getInt32Ty(builder.getContext()));
2906}
2907
2908NVVM::IDArgPair DotAccumulate4WayOp::getIntrinsicIDAndArgs(
2909 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
2910 auto curOp = cast<NVVM::DotAccumulate4WayOp>(op);
2911
2912 llvm::SmallVector<llvm::Value *> args;
2913 args.push_back(getAsPackedI32(mt.lookupValue(curOp.getA()), builder));
2914 args.push_back(getAsPackedI32(mt.lookupValue(curOp.getB()), builder));
2915 args.push_back(mt.lookupValue(curOp.getC()));
2916
2917 bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
2918 bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
2919 unsigned type = (isASigned << 1) | isBSigned;
2920 const llvm::Intrinsic::ID ids[] = {
2921 llvm::Intrinsic::nvvm_idp4a_u_u,
2922 llvm::Intrinsic::nvvm_idp4a_u_s,
2923 llvm::Intrinsic::nvvm_idp4a_s_u,
2924 llvm::Intrinsic::nvvm_idp4a_s_s,
2925 };
2926 return {ids[type], args};
2927}
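// The two signedness bits index the table above as (isASigned << 1) | isBSigned:
// 0b00 -> idp4a_u_u, 0b01 -> idp4a_u_s, 0b10 -> idp4a_s_u, 0b11 -> idp4a_s_s.
// DotAccumulate2WayOp below uses the same encoding for the idp2a intrinsics.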
2928
2929NVVM::IDArgPair DotAccumulate2WayOp::getIntrinsicIDAndArgs(
2930 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
2931 auto curOp = cast<NVVM::DotAccumulate2WayOp>(op);
2932
2933 llvm::SmallVector<llvm::Value *> args;
2934 args.push_back(getAsPackedI32(mt.lookupValue(curOp.getA()), builder));
2935 args.push_back(getAsPackedI32(mt.lookupValue(curOp.getB()), builder));
2936 args.push_back(builder.getInt1(curOp.getBHi()));
2937 args.push_back(mt.lookupValue(curOp.getC()));
2938
2939 bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
2940 bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
2941 unsigned type = (isASigned << 1) | isBSigned;
2942 const llvm::Intrinsic::ID ids[] = {
2943 llvm::Intrinsic::nvvm_idp2a_u_u,
2944 llvm::Intrinsic::nvvm_idp2a_u_s,
2945 llvm::Intrinsic::nvvm_idp2a_s_u,
2946 llvm::Intrinsic::nvvm_idp2a_s_s,
2947 };
2948 return {ids[type], args};
2949}
2950
2951static llvm::Value *getParamCastedAddr(llvm::Value *addr,
2952 llvm::IRBuilderBase &builder) {
2953 return builder.CreateAddrSpaceCast(
2954 addr,
2955 llvm::PointerType::get(builder.getContext(),
2956 llvm::NVPTXAS::AddressSpace::ADDRESS_SPACE_PARAM));
2957}
2958
2959NVVM::IDArgPair
2960PrefetchOp::getIntrinsicIDAndArgs(NVVM::PrefetchOp &op,
2961 LLVM::ModuleTranslation &mt,
2962 llvm::IRBuilderBase &builder) {
2963 using MemSpace = NVVM::NVVMMemorySpace;
2964 using CacheLevel = NVVM::PrefetchCacheLevel;
2965
2966 std::optional<NVVM::PrefetchCacheLevel> cacheLevel = op.getCacheLevel();
2967 std::optional<NVVM::CacheEvictionPriority> evictPriority =
2968 op.getEvictPriority();
2969 unsigned addressSpace =
2970 llvm::cast<LLVM::LLVMPointerType>(op.getAddr().getType())
2971 .getAddressSpace();
2972
2973 llvm::SmallVector<llvm::Value *> args;
2974 llvm::Value *addr = mt.lookupValue(op.getAddr());
2975 args.push_back(op.getInParamSpace() ? getParamCastedAddr(addr, builder)
2976 : addr);
2977
2978 if (op.getTensormap())
2979 return {llvm::Intrinsic::nvvm_prefetch_tensormap, args};
2980
2981 assert(cacheLevel && "expected cache level for non-tensormap prefetch");
2982
2983 if (op.getUniform() && *cacheLevel == CacheLevel::L1)
2984 return {llvm::Intrinsic::nvvm_prefetchu_L1, args};
2985
2986 if (evictPriority && *cacheLevel == CacheLevel::L2) {
2987 switch (*evictPriority) {
2988 case NVVM::CacheEvictionPriority::EvictLast:
2989 return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_last, args};
2990 case NVVM::CacheEvictionPriority::EvictNormal:
2991 return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_normal, args};
2992 default:
2993 llvm_unreachable("Invalid cache eviction priority");
2994 }
2995 }
2996
2997 switch (static_cast<MemSpace>(addressSpace)) {
2998 case MemSpace::Generic:
2999 return *cacheLevel == CacheLevel::L1
3000 ? NVVM::IDArgPair({llvm::Intrinsic::nvvm_prefetch_L1, args})
3001 : NVVM::IDArgPair({llvm::Intrinsic::nvvm_prefetch_L2, args});
3002 case MemSpace::Global:
3003 return *cacheLevel == CacheLevel::L1
3004 ? NVVM::IDArgPair(
3005 {llvm::Intrinsic::nvvm_prefetch_global_L1, args})
3006 : NVVM::IDArgPair(
3007 {llvm::Intrinsic::nvvm_prefetch_global_L2, args});
3008 case MemSpace::Local:
3009 return *cacheLevel == CacheLevel::L1
3010 ? NVVM::IDArgPair(
3011 {llvm::Intrinsic::nvvm_prefetch_local_L1, args})
3012 : NVVM::IDArgPair(
3013 {llvm::Intrinsic::nvvm_prefetch_local_L2, args});
3014 default:
3015 llvm_unreachable("Invalid pointer address space");
3016 }
3017}
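// Illustrative summary of the selection above (abbreviated; see the op
// definition for the authoritative assembly format):
//   tensormap prefetch              -> llvm.nvvm.prefetch.tensormap
//   uniform + L1                    -> llvm.nvvm.prefetchu.L1
//   L2 + an eviction priority       -> llvm.nvvm.prefetch.global.L2.evict.{last,normal}
//   otherwise the intrinsic is chosen from the pointer's address space
//   (generic/global/local) and the cache level (L1/L2).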
3018
3019bool NVVM::InlinePtxOp::getAsmValues(
3020 RewriterBase &rewriter,
3021 llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
3022 &asmValues) {
3023 for (auto arg : getReadWriteArgs())
3024 asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::ReadWrite});
3025 for (auto arg : getResults())
3026 asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::Write});
3027 for (auto arg : getReadOnlyArgs())
3028 asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::Read});
3029 if (getPredicate())
3030 asmValues.push_back({getPredicate(), mlir::NVVM::PTXRegisterMod::Read});
3031 return false; // No manual mapping needed
3032}
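// Illustrative note: the push-back order above fixes the PTX operand
// numbering for the generated inline assembly: read-write values first
// ('+' modifier), then results ('=' modifier), then read-only values and the
// optional predicate. For example, an op with one read-write arg, one result,
// and one read-only arg would see them as $0, $1, and $2 in the PTX string.
// This is a sketch of the expected mapping, not normative documentation.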
3033
3034NVVM::IDArgPair ClusterLaunchControlTryCancelOp::getIntrinsicIDAndArgs(
3035 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3036 auto curOp = cast<NVVM::ClusterLaunchControlTryCancelOp>(op);
3037 llvm::SmallVector<llvm::Value *> args;
3038 args.push_back(mt.lookupValue(curOp.getSmemAddress()));
3039 args.push_back(mt.lookupValue(curOp.getMbarrier()));
3040
3041 llvm::Intrinsic::ID intrinsicID =
3042 curOp.getMulticast()
3043 ? llvm::Intrinsic::
3044 nvvm_clusterlaunchcontrol_try_cancel_async_multicast_shared
3045 : llvm::Intrinsic::nvvm_clusterlaunchcontrol_try_cancel_async_shared;
3046
3047 return {intrinsicID, args};
3048}
3049
3050NVVM::IDArgPair ClusterLaunchControlQueryCancelOp::getIntrinsicIDAndArgs(
3051 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3052 auto curOp = cast<NVVM::ClusterLaunchControlQueryCancelOp>(op);
3053 llvm::SmallVector<llvm::Value *> args;
3054 args.push_back(mt.lookupValue(curOp.getTryCancelResponse()));
3055
3056 llvm::Intrinsic::ID intrinsicID;
3057
3058 switch (curOp.getQueryType()) {
3059 case NVVM::ClusterLaunchControlQueryType::IS_CANCELED:
3060 intrinsicID =
3061 llvm::Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_is_canceled;
3062 break;
3063 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_X:
3064 intrinsicID = llvm::Intrinsic::
3065 nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_x;
3066 break;
3067 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Y:
3068 intrinsicID = llvm::Intrinsic::
3069 nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_y;
3070 break;
3071 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Z:
3072 intrinsicID = llvm::Intrinsic::
3073 nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_z;
3074 break;
3075 }
3076 return {intrinsicID, args};
3077}
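// Illustrative example (abbreviated MLIR syntax, assumed for exposition):
// querying the first CTA id in X from a try-cancel response, e.g.
//
//   %x = nvvm.clusterlaunchcontrol.query.cancel query = first_cta_id_x, %resp
//
// lowers through the switch above to
// llvm.nvvm.clusterlaunchcontrol.query.cancel.get.first.ctaid.x with the
// response value as the sole argument.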
3078
3079//===----------------------------------------------------------------------===//
3080// NVVMDialect initialization, type parsing, and registration.
3081//===----------------------------------------------------------------------===//
3082
3083// TODO: This should be the llvm.nvvm dialect once this is supported.
3084void NVVMDialect::initialize() {
3085 addOperations<
3086#define GET_OP_LIST
3087#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
3088 >();
3089 addAttributes<
3090#define GET_ATTRDEF_LIST
3091#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
3092 >();
3093
3094 // Support unknown operations because not all NVVM operations are
3095 // registered.
3096 allowUnknownOperations();
3097 declarePromisedInterface<ConvertToLLVMPatternInterface, NVVMDialect>();
3098 declarePromisedInterface<gpu::TargetAttrInterface, NVVMTargetAttr>();
3099}
3100
3101LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
3102 NamedAttribute attr) {
3103 StringAttr attrName = attr.getName();
3104 // Kernel function attribute should be attached to functions.
3105 if (attrName == NVVMDialect::getKernelFuncAttrName()) {
3106 if (!isa<LLVM::LLVMFuncOp>(op)) {
3107 return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName()
3108 << "' attribute attached to unexpected op";
3109 }
3110 }
3111 // If maxntid / reqntid / cluster_dim exist, they must be integer arrays
3112 // with at most 3 elements.
3113 if (attrName == NVVMDialect::getMaxntidAttrName() ||
3114 attrName == NVVMDialect::getReqntidAttrName() ||
3115 attrName == NVVMDialect::getClusterDimAttrName()) {
3116 auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue());
3117 if (!values || values.empty() || values.size() > 3) {
3118 return op->emitError()
3119 << "'" << attrName
3120 << "' attribute must be integer array with maximum 3 index";
3121 }
3122 }
3123 // If minctasm / maxnreg / cluster_max_blocks exist, they must be integer
3124 // attributes.
3125 if (attrName == NVVMDialect::getMinctasmAttrName() ||
3126 attrName == NVVMDialect::getMaxnregAttrName() ||
3127 attrName == NVVMDialect::getClusterMaxBlocksAttrName()) {
3128 if (!llvm::dyn_cast<IntegerAttr>(attr.getValue())) {
3129 return op->emitError()
3130 << "'" << attrName << "' attribute must be integer constant";
3131 }
3132 }
3133 // blocksareclusters must be used along with reqntid and cluster_dim
3134 if (attrName == NVVMDialect::getBlocksAreClustersAttrName()) {
3135 if (!op->hasAttr(NVVMDialect::getReqntidAttrName()) ||
3136 !op->hasAttr(NVVMDialect::getClusterDimAttrName())) {
3137 return op->emitError()
3138 << "'" << attrName << "' attribute must be used along with "
3139 << "'" << NVVMDialect::getReqntidAttrName() << "' and "
3140 << "'" << NVVMDialect::getClusterDimAttrName() << "'";
3141 }
3142 }
3143
3144 return success();
3145}
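// Illustrative example of a function accepted by the checks above (syntax
// abbreviated; attribute spellings follow the getters used in this verifier):
//
//   llvm.func @kernel() attributes {
//     nvvm.kernel,
//     nvvm.reqntid = array<i32: 128, 1, 1>,
//     nvvm.cluster_dim = array<i32: 2, 1, 1>,
//     nvvm.minctasm = 2 : i32 } { ... }
//
// i.e. the dimension attributes are integer arrays with at most three
// elements and the scalar tuning attributes are integer constants.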
3146
3147LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
3148 unsigned regionIndex,
3149 unsigned argIndex,
3150 NamedAttribute argAttr) {
3151 auto funcOp = dyn_cast<FunctionOpInterface>(op);
3152 if (!funcOp)
3153 return success();
3154
3155 bool isKernel = op->hasAttr(NVVMDialect::getKernelFuncAttrName());
3156 StringAttr attrName = argAttr.getName();
3157 if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
3158 if (!isKernel) {
3159 return op->emitError()
3160 << "'" << attrName
3161 << "' attribute must be present only on kernel arguments";
3162 }
3163 if (!isa<UnitAttr>(argAttr.getValue()))
3164 return op->emitError() << "'" << attrName << "' must be a unit attribute";
3165 if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName())) {
3166 return op->emitError()
3167 << "'" << attrName
3168 << "' attribute requires the argument to also have attribute '"
3169 << LLVM::LLVMDialect::getByValAttrName() << "'";
3170 }
3171 }
3172
3173 return success();
3174}
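// Illustrative example: the grid-constant attribute is only accepted on a
// kernel argument that is also annotated byval (syntax abbreviated):
//
//   llvm.func @kern(%arg0: !llvm.ptr
//                     {llvm.byval = !llvm.struct<(i32)>, nvvm.grid_constant})
//       attributes {nvvm.kernel} { ... }
//
// Dropping nvvm.kernel or llvm.byval, or using a non-unit value, triggers the
// diagnostics emitted above.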
3175
3176//===----------------------------------------------------------------------===//
3177// NVVM Address Space Attr
3178//===----------------------------------------------------------------------===//
3179
3180unsigned NVVMMemorySpaceAttr::getAddressSpace() const {
3181 return static_cast<unsigned>(getValue());
3182}
3183
3184bool NVVMMemorySpaceAttr::isValidLoad(
3185 Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
3186 const ::mlir::DataLayout *dataLayout,
3187 function_ref<InFlightDiagnostic()> emitError) const {
3188 return LLVM::detail::isValidLoadStoreImpl(type, ordering, alignment,
3189 dataLayout, emitError);
3190}
3191
3192bool NVVMMemorySpaceAttr::isValidStore(
3193 Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
3194 const ::mlir::DataLayout *dataLayout,
3195 function_ref<InFlightDiagnostic()> emitError) const {
3196 return LLVM::detail::isValidLoadStoreImpl(type, ordering, alignment,
3197 dataLayout, emitError);
3198}
3199
3200bool NVVMMemorySpaceAttr::isValidAtomicOp(
3201 ptr::AtomicBinOp op, Type type, ptr::AtomicOrdering ordering,
3202 std::optional<int64_t> alignment, const ::mlir::DataLayout *dataLayout,
3203 function_ref<InFlightDiagnostic()> emitError) const {
3204 // TODO: update this method once `ptr.atomic_rmw` is implemented.
3205 assert(false && "unimplemented, see TODO in the source.");
3206 return false;
3207}
3208
3209bool NVVMMemorySpaceAttr::isValidAtomicXchg(
3210 Type type, ptr::AtomicOrdering successOrdering,
3211 ptr::AtomicOrdering failureOrdering, std::optional<int64_t> alignment,
3212 const ::mlir::DataLayout *dataLayout,
3213 function_ref<InFlightDiagnostic()> emitError) const {
3214 // TODO: update this method once `ptr.atomic_cmpxchg` is implemented.
3215 assert(false && "unimplemented, see TODO in the source.");
3216 return false;
3217}
3218
3219bool NVVMMemorySpaceAttr::isValidAddrSpaceCast(
3220 Type tgt, Type src, function_ref<InFlightDiagnostic()> emitError) const {
3221 // TODO: update this method once the `ptr.addrspace_cast` op is added to the
3222 // dialect.
3223 assert(false && "unimplemented, see TODO in the source.");
3224 return false;
3225}
3226
3227bool NVVMMemorySpaceAttr::isValidPtrIntCast(
3228 Type intLikeTy, Type ptrLikeTy,
3229 function_ref<InFlightDiagnostic()> emitError) const {
3230 // TODO: update this method once the int-cast ops are added to the `ptr`
3231 // dialect.
3232 assert(false && "unimplemented, see TODO in the source.");
3233 return false;
3234}
3235
3236//===----------------------------------------------------------------------===//
3237// NVVM target attribute.
3238//===----------------------------------------------------------------------===//
3239LogicalResult
3240NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
3241 int optLevel, StringRef triple, StringRef chip,
3242 StringRef features, DictionaryAttr flags,
3243 ArrayAttr files, bool verifyTarget) {
3244 if (optLevel < 0 || optLevel > 3) {
3245 emitError() << "The optimization level must be a number between 0 and 3.";
3246 return failure();
3247 }
3248 if (triple.empty()) {
3249 emitError() << "The target triple cannot be empty.";
3250 return failure();
3251 }
3252 if (chip.empty()) {
3253 emitError() << "The target chip cannot be empty.";
3254 return failure();
3255 }
3256 if (files && !llvm::all_of(files, [](::mlir::Attribute attr) {
3257 return mlir::isa_and_nonnull<StringAttr>(attr);
3258 })) {
3259 emitError() << "All the elements in the `link` array must be strings.";
3260 return failure();
3261 }
3262 return success();
3263}
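// Illustrative example of a target attribute that passes the checks above
// (field spelling abbreviated; see the attribute definition for the exact
// syntax):
//
//   #nvvm.target<O = 3, triple = "nvptx64-nvidia-cuda", chip = "sm_90",
//                features = "+ptx80", link = ["device_lib.bc"]>
//
// An optimization level outside [0, 3], an empty triple or chip, or a
// non-string entry in `link` is rejected.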
3264
3265LogicalResult NVVMTargetAttr::verifyTarget(Operation *gpuModule) {
3266 if (!getVerifyTarget())
3267 return success();
3268
3269 auto gpuModuleOp = llvm::dyn_cast<gpu::GPUModuleOp>(gpuModule);
3270 if (!gpuModuleOp) {
3271 return emitError(gpuModule->getLoc(),
3272 "NVVM target attribute must be attached to a GPU module");
3273 }
3274
3275 const NVVMCheckSMVersion targetSMVersion =
3276 NVVMCheckSMVersion::getTargetSMVersionFromStr(getChip());
3277 if (!targetSMVersion.isMinimumSMVersion()) {
3278 return emitError(gpuModule->getLoc(),
3279 "Minimum NVVM target SM version is sm_20");
3280 }
3281
3282 gpuModuleOp->walk([&](Operation *op) {
3283 if (auto reqOp = llvm::dyn_cast<NVVM::RequiresSMInterface>(op)) {
3284 const NVVMCheckSMVersion requirement = reqOp.getRequiredMinSMVersion();
3285 if (!requirement.isCompatibleWith(targetSMVersion)) {
3286 op->emitOpError() << "is not supported on " << getChip();
3287 return WalkResult::interrupt();
3288 }
3289 }
3290 return WalkResult::advance();
3291 });
3292
3293 return success();
3294}
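// Illustrative note: with verifyTarget enabled, the walk above rejects, for
// example, a gpu.module targeting chip "sm_80" that contains an op whose
// RequiresSMInterface reports a higher minimum SM version, emitting the
// "is not supported on sm_80" diagnostic.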
3295
3296#define GET_OP_CLASSES
3297#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
3298
3299#define GET_ATTRDEF_CLASSES
3300#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"