1//===- NVVMDialect.cpp - NVVM IR Ops and Dialect registration -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the types and operation details for the NVVM IR dialect
10// in MLIR. It also registers the dialect.
11//
12// The NVVM dialect only contains GPU-specific additions on top of the general
13// LLVM dialect.
14//
15//===----------------------------------------------------------------------===//
16
17#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
18
22#include "mlir/IR/Builders.h"
25#include "mlir/IR/Diagnostics.h"
27#include "mlir/IR/MLIRContext.h"
28#include "mlir/IR/Operation.h"
30#include "mlir/IR/Types.h"
31#include "llvm/ADT/STLExtras.h"
32#include "llvm/ADT/TypeSwitch.h"
33#include "llvm/IR/IRBuilder.h"
34#include "llvm/IR/NVVMIntrinsicUtils.h"
35#include "llvm/Support/Casting.h"
36#include "llvm/Support/FormatVariadic.h"
37#include "llvm/Support/NVPTXAddrSpace.h"
38#include "llvm/Support/raw_ostream.h"
39#include <cassert>
40#include <optional>
41#include <string>
42
43using namespace mlir;
44using namespace NVVM;
45
46#include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
47#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"
48
49static constexpr unsigned notIntrinsic = llvm::Intrinsic::not_intrinsic;
50
51//===----------------------------------------------------------------------===//
52// Helper/Utility methods
53//===----------------------------------------------------------------------===//
54
55static bool isPtrInAddrSpace(mlir::Value ptr, NVVMMemorySpace targetAS) {
56 auto ptrTy = llvm::cast<LLVM::LLVMPointerType>(ptr.getType());
57 return ptrTy.getAddressSpace() == static_cast<unsigned>(targetAS);
58}
59
61 return isPtrInAddrSpace(ptr, NVVMMemorySpace::Shared);
62}
63
64// Helper method to convert CTAGroupKind in the NVVM dialect to CTAGroupKind in LLVM.
65static llvm::nvvm::CTAGroupKind
66getNVVMCtaGroupKind(NVVM::CTAGroupKind ctaGroup) {
67 switch (ctaGroup) {
68 case NVVM::CTAGroupKind::CTA_1:
69 return llvm::nvvm::CTAGroupKind::CG_1;
70 case NVVM::CTAGroupKind::CTA_2:
71 return llvm::nvvm::CTAGroupKind::CG_2;
72 }
73 llvm_unreachable("unsupported cta_group value");
74}
75
76//===----------------------------------------------------------------------===//
77// Verifier methods
78//===----------------------------------------------------------------------===//
79
80// This verifier is shared among the following Ops:
81// CpAsyncBulkTensorSharedCTAToGlobalOp (TMA Store)
82// CpAsyncBulkTensorReduceOp (TMA Store-Reduce)
83static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims,
84 bool isIm2Col,
85 size_t numIm2ColOffsets,
86 Location loc) {
87 if (tensorDims < 1 || tensorDims > 5)
88 return emitError(loc, "expects coordinates between 1 and 5 dimensions");
89
90 // For Im2Col mode, there are two constraints:
91 if (isIm2Col) {
92 // 1. Tensor must always be at least 3-d.
93 if (tensorDims < 3)
94 return emitError(
95 loc,
96 "to use im2col mode, the tensor has to be at least 3-dimensional");
97 // 2. When there are Im2ColOffsets, they must be (Dims - 2) in number.
98 if (numIm2ColOffsets && (tensorDims != (numIm2ColOffsets + 2)))
99 return emitError(
100 loc, "im2col offsets must be 2 less than number of coordinates");
101 }
102 return success();
103}
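// Illustration of the constraints above (hypothetical calls, not part of the
// original file):
//   cpAsyncBulkTensorCommonVerifier(/*tensorDims=*/4, /*isIm2Col=*/true,
//                                   /*numIm2ColOffsets=*/2, loc); // ok: 4 - 2 == 2
//   cpAsyncBulkTensorCommonVerifier(/*tensorDims=*/2, /*isIm2Col=*/true,
//                                   /*numIm2ColOffsets=*/0, loc); // error: fewer than 3 dims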
104
105LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
106 TMAStoreMode mode = getMode();
107 // We lower through inline-ptx when getPredicate() is true.
108 // a) Only TILE mode is supported
109 // b) Cache-hint is not supported
110 if (getPredicate()) {
111 if (mode != TMAStoreMode::TILE)
112 return emitError("Inline-ptx lowering supported only for Tile mode.");
113 if (getL2CacheHint())
114 return emitError("Inline-ptx lowering unsupported with L2 cache-hint.");
115 }
116
117 size_t dims = getCoordinates().size();
118 switch (mode) {
119 case TMAStoreMode::TILE:
120 return cpAsyncBulkTensorCommonVerifier(dims, false, 0, getLoc());
121 case TMAStoreMode::IM2COL:
122 return cpAsyncBulkTensorCommonVerifier(dims, true, 0, getLoc());
123 case TMAStoreMode::TILE_SCATTER4:
124 if (dims != 5)
125 return emitError("Scatter4 mode expects 5 coordinates");
126 }
127 return success();
128}
129
130LogicalResult CpAsyncOp::verify() {
131 if (getModifier() != LoadCacheModifierKind::CG &&
132 getModifier() != LoadCacheModifierKind::CA)
133 return emitError("Only CG and CA cache modifiers are supported.");
134 if (getSize() != 4 && getSize() != 8 && getSize() != 16)
135 return emitError("expected byte size to be either 4, 8 or 16.");
136 if (getModifier() == LoadCacheModifierKind::CG && getSize() != 16)
137 return emitError("CG cache modifier is only support for 16 bytes copy.");
138 return success();
139}
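// Example IR accepted by this verifier (illustrative only; the exact assembly
// format is assumed from the op definition, which is not part of this file):
//   nvvm.cp.async.shared.global %dst, %src, 16, cache = cg
//       : !llvm.ptr<3>, !llvm.ptr<1>
//   nvvm.cp.async.shared.global %dst, %src, 4, cache = ca
//       : !llvm.ptr<3>, !llvm.ptr<1>
// A cg copy of 4 or 8 bytes is rejected by the size check above.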
140
141// This verifier is shared across the TMA Load and Prefetch Ops.
142static LogicalResult verifyTMALoadParams(size_t tensorDims, size_t numIm2colOff,
143 TMALoadMode mode, Location loc) {
144 if (tensorDims < 1 || tensorDims > 5)
145 return emitError(loc, "expects coordinates between 1 and 5 dimensions");
146
147 auto checkTMALoadParams = [&](TMALoadMode mode, bool isIm2col,
148 size_t expectedIm2colOff) -> LogicalResult {
149 if (isIm2col && (tensorDims < 3))
150 return emitError(loc)
151 << "to use " << stringifyEnum(mode)
152 << " mode, the tensor has to be at least 3-dimensional";
153
154 if (numIm2colOff != expectedIm2colOff)
155 return emitError(loc) << " im2col offsets expected " << expectedIm2colOff
156 << " (provided " << numIm2colOff << ")";
157
158 return success();
159 };
160
161 switch (mode) {
162 case TMALoadMode::TILE:
163 return checkTMALoadParams(mode, false, 0);
164 case TMALoadMode::IM2COL:
165 return checkTMALoadParams(mode, true, tensorDims - 2);
166 case TMALoadMode::IM2COL_W:
167 case TMALoadMode::IM2COL_W_128:
168 return checkTMALoadParams(mode, true, 2);
169 case TMALoadMode::TILE_GATHER4:
170 return (tensorDims == 5)
171 ? checkTMALoadParams(mode, false, 0)
172 : emitError(loc, "Gather4 mode expects 5 coordinates");
173 }
174 return success();
175}
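// Expected im2col offset counts per mode, as checked above (hypothetical calls):
//   verifyTMALoadParams(5, 3, TMALoadMode::IM2COL, loc);        // ok: 5 - 2 == 3
//   verifyTMALoadParams(5, 2, TMALoadMode::IM2COL_W, loc);      // ok: exactly 2
//   verifyTMALoadParams(4, 0, TMALoadMode::TILE_GATHER4, loc);  // error: needs 5 dims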
176
177LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
178 return verifyTMALoadParams(getCoordinates().size(), getIm2colOffsets().size(),
179 getMode(), getLoc());
180}
181
182LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
183 TMALoadMode mode = getMode();
184 bool isCTAOnly = getIsCTAOnly();
185 if (getPredicate()) { // Inline-asm based lowering
186 if (isCTAOnly)
187 return emitError("Predicate is supported only for shared::cluster mode.");
188 if (mode != TMALoadMode::TILE && mode != TMALoadMode::IM2COL)
189 return emitError(
190 "Predicate is supported only for Tile and Im2col modes.");
191 } else { // Intrinsics-based lowering
192 NVVMMemorySpace expectedAS =
193 isCTAOnly ? NVVMMemorySpace::Shared : NVVMMemorySpace::SharedCluster;
194 unsigned AS = llvm::cast<LLVM::LLVMPointerType>(getDstMem().getType())
195 .getAddressSpace();
196 if (AS != expectedAS)
197 return emitError()
198 << (isCTAOnly
199 ? "Shared::cta destination requires address-space 3."
200 : "Shared::cluster destination requires address-space 7.");
201 // Checks specific to shared::cta mode
202 if (isCTAOnly) {
203 if (getMulticastMask())
204 return emitError("Multicast is not supported with shared::cta mode.");
205 if (getGroup())
206 return emitError("CTAGroup is not supported with shared::cta mode.");
207 }
208 }
209
210 return verifyTMALoadParams(getCoordinates().size(), getIm2colOffsets().size(),
211 getMode(), getLoc());
212}
213
214LogicalResult CpAsyncBulkTensorReduceOp::verify() {
215 TMAStoreMode mode = getMode();
216 size_t dims = getCoordinates().size();
217 switch (mode) {
218 case TMAStoreMode::TILE:
219 return cpAsyncBulkTensorCommonVerifier(dims, false, 0, getLoc());
220 case TMAStoreMode::IM2COL:
221 return cpAsyncBulkTensorCommonVerifier(dims, true, 0, getLoc());
222 case TMAStoreMode::TILE_SCATTER4:
223 return emitError("Scatter mode unsupported for CpAsyncBulkTensorReduceOp");
224 }
225 return success();
226}
227
228LogicalResult CpAsyncBulkGlobalToSharedClusterOp::verify() {
229 bool isSharedCTA = isPtrInSharedCTASpace(getDstMem());
230 if (isSharedCTA && getMulticastMask())
231 return emitError("Multicast is not supported with shared::cta mode.");
232
233 return success();
234}
235
236LogicalResult ConvertFloatToTF32Op::verify() {
237 using RndMode = NVVM::FPRoundingMode;
238 switch (getRnd()) {
239 case RndMode::RNA:
240 if (getRelu())
241 return emitError("Relu not supported with rna rounding mode.");
242 break;
243 case RndMode::RN:
244 case RndMode::RZ:
245 break;
246 default:
247 return emitError(
248 "Only {rn,rz,rna} rounding modes supported for ConvertFloatToTF32Op.");
249 }
250 return success();
251}
252
253LogicalResult ConvertF32x2ToF6x2Op::verify() {
254 MLIRContext *ctx = getContext();
255
256 if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy())) {
257 return emitOpError("Only ")
258 << mlir::Float6E2M3FNType::get(ctx) << " and "
259 << mlir::Float6E3M2FNType::get(ctx)
260 << " types are supported for conversions from f32x2 to f6x2.";
261 }
262 return success();
263}
264
265LogicalResult ConvertF32x2ToF8x2Op::verify() {
266 using RndMode = NVVM::FPRoundingMode;
267 using SatMode = NVVM::SaturationMode;
268
269 bool isRoundingModeRN = getRnd() == RndMode::RN;
270 bool isRoundingModeRZ = getRnd() == RndMode::RZ;
271 bool isRoundingModeRP = getRnd() == RndMode::RP;
272 bool isSatFinite = getSat() == SatMode::SATFINITE;
273
274 bool hasRelu = getRelu();
275
276 MLIRContext *ctx = getContext();
277
278 return llvm::TypeSwitch<Type, LogicalResult>(getDstTy())
279 .Case<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(
280 [&](mlir::Type) -> LogicalResult {
281 if (!isRoundingModeRN) {
282 return emitOpError("Only RN rounding mode is supported for "
283 "conversions from f32x2 to ")
284 << mlir::Float8E4M3FNType::get(ctx) << " and "
285 << mlir::Float8E5M2Type::get(ctx) << " types";
286 }
287 if (!isSatFinite) {
288 return emitOpError("Only SATFINITE saturation mode is supported "
289 "for conversions "
290 "from f32x2 to ")
291 << mlir::Float8E4M3FNType::get(ctx) << " and "
292 << mlir::Float8E5M2Type::get(ctx) << " types";
293 }
294 return success();
295 })
296 .Case<mlir::Float8E8M0FNUType>([&](mlir::Type) -> LogicalResult {
297 if (!(isRoundingModeRZ || isRoundingModeRP)) {
298 return emitOpError("Only RZ and RP rounding modes are supported for "
299 "conversions from f32x2 to ")
300 << mlir::Float8E8M0FNUType::get(ctx) << " type";
301 }
302 if (hasRelu) {
303 return emitOpError("relu not supported for conversions to ")
304 << mlir::Float8E8M0FNUType::get(ctx) << " type";
305 }
306 return success();
307 })
308 .Default([&](mlir::Type) {
309 return emitOpError("Only ")
310 << mlir::Float8E4M3FNType::get(ctx) << ", "
311 << mlir::Float8E5M2Type::get(ctx) << ", and "
312 << mlir::Float8E8M0FNUType::get(ctx)
313 << " types are "
314 "supported for conversions from f32x2 to f8x2";
315 });
316}
317
318LogicalResult ConvertF16x2ToF8x2Op::verify() {
319 MLIRContext *ctx = getContext();
320
321 if (!llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(getDstTy())) {
322 return emitOpError("Only ")
323 << mlir::Float8E4M3FNType::get(ctx) << " and "
324 << mlir::Float8E5M2Type::get(ctx)
325 << " types are supported for conversions from f16x2 to f8x2.";
326 }
327 return success();
328}
329
330LogicalResult ConvertBF16x2ToF8x2Op::verify() {
331 using RndMode = NVVM::FPRoundingMode;
332
333 if (!llvm::isa<mlir::Float8E8M0FNUType>(getDstTy()))
334 return emitOpError("Only ") << mlir::Float8E8M0FNUType::get(getContext())
335 << " type is supported for conversions from "
336 "bf16x2 to f8x2.";
337
338 auto rnd = getRnd();
339 if (!(rnd == RndMode::RZ || rnd == RndMode::RP))
340 return emitOpError("Only RZ and RP rounding modes are supported for "
341 "conversions from bf16x2 to f8x2.");
342
343 return success();
344}
345
346LogicalResult ConvertF32x2ToF4x2Op::verify() {
347 MLIRContext *ctx = getContext();
348
349 if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
350 return emitOpError("Only ")
351 << mlir::Float4E2M1FNType::get(ctx)
352 << " type is supported for conversions from f32x2 to f4x2.";
353
354 return success();
355}
356
357LogicalResult ConvertF8x2ToF16x2Op::verify() {
358 MLIRContext *ctx = getContext();
359
360 if (!llvm::isa<Float8E4M3FNType, Float8E5M2Type>(getSrcType()))
361 return emitOpError("Only ")
362 << mlir::Float8E4M3FNType::get(ctx) << " and "
363 << mlir::Float8E5M2Type::get(ctx)
364 << " types are supported for conversions from f8x2 to f16x2.";
365
366 return success();
367}
368
369LogicalResult ConvertF8x2ToBF16x2Op::verify() {
370 MLIRContext *ctx = getContext();
371 if (!llvm::isa<Float8E8M0FNUType>(getSrcType()))
372 return emitOpError("Only ")
373 << mlir::Float8E8M0FNUType::get(ctx)
374 << " type is supported for conversions from f8x2 to bf16x2.";
375
376 return success();
377}
378
379LogicalResult ConvertF6x2ToF16x2Op::verify() {
380 MLIRContext *ctx = getContext();
381
382 if (!llvm::isa<Float6E2M3FNType, Float6E3M2FNType>(getSrcType()))
383 return emitOpError("Only ")
384 << mlir::Float6E2M3FNType::get(ctx) << " and "
385 << mlir::Float6E3M2FNType::get(ctx)
386 << " types are supported for conversions from f6x2 to f16x2.";
387
388 return success();
389}
390
391LogicalResult ConvertF4x2ToF16x2Op::verify() {
392 MLIRContext *ctx = getContext();
393
394 if (!llvm::isa<Float4E2M1FNType>(getSrcType()))
395 return emitOpError("Only ")
396 << mlir::Float4E2M1FNType::get(ctx)
397 << " type is supported for conversions from f4x2 to f16x2.";
398
399 return success();
400}
401
402//===----------------------------------------------------------------------===//
403// Stochastic Rounding Conversion Ops
404//===----------------------------------------------------------------------===//
405
406LogicalResult ConvertF32x2ToF16x2Op::verify() {
407 if (getRnd() != FPRoundingMode::RS)
408 return emitOpError("Only RS rounding mode is supported for "
409 "conversions from f32x2 to f16x2.");
410 return success();
411}
412
413LogicalResult ConvertF32x2ToBF16x2Op::verify() {
414 if (getRnd() != FPRoundingMode::RS)
415 return emitOpError("Only RS rounding mode is supported for "
416 "conversions from f32x2 to bf16x2.");
417 return success();
418}
419
420LogicalResult ConvertF32x4ToF8x4Op::verify() {
421 MLIRContext *ctx = getContext();
422
423 if (!llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(getDstTy()))
424 return emitOpError("Only ")
425 << mlir::Float8E4M3FNType::get(ctx) << " and "
426 << mlir::Float8E5M2Type::get(ctx)
427 << " types are supported for conversions from f32x4 to f8x4.";
428
429 return success();
430}
431
432LogicalResult ConvertF32x4ToF6x4Op::verify() {
433 MLIRContext *ctx = getContext();
434
435 if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy()))
436 return emitOpError("Only ")
437 << mlir::Float6E2M3FNType::get(ctx) << " and "
438 << mlir::Float6E3M2FNType::get(ctx)
439 << " types are supported for conversions from f32x4 to f6x4.";
440
441 return success();
442}
443
444LogicalResult ConvertF32x4ToF4x4Op::verify() {
445 MLIRContext *ctx = getContext();
446
447 if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
448 return emitOpError("Only ") << mlir::Float4E2M1FNType::get(ctx)
449 << " type is supported for conversions from "
450 "f32x4 to f4x4.";
451
452 return success();
453}
454
455LogicalResult BulkStoreOp::verify() {
456 if (getInitVal() != 0)
457 return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
458 return success();
459}
460
461LogicalResult PMEventOp::verify() {
462 auto eventId = getEventId();
463 auto maskedEventId = getMaskedEventId();
464 if (!maskedEventId && !eventId) {
465 return emitOpError() << "either `id` or `mask` must be set";
466 }
467
468 if (maskedEventId && eventId) {
469 return emitOpError() << "`id` and `mask` cannot be set at the same time";
470 }
471
472 if (eventId) {
473 if (eventId < 0 || eventId > 15) {
474 return emitOpError() << "`id` must be between 0 and 15";
475 }
476 }
477
478 return llvm::success();
479}
480
481// Given the element type of an operand and whether or not it is an accumulator,
482// this function returns the PTX type (`NVVM::MMATypes`) that corresponds to the
483// operand's element type.
484std::optional<mlir::NVVM::MMATypes>
485MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) {
486 auto half2Type =
487 VectorType::get(2, Float16Type::get(operandElType.getContext()));
488 if (operandElType.isF64())
489 return NVVM::MMATypes::f64;
490 if (operandElType.isF16() || operandElType == half2Type)
491 return NVVM::MMATypes::f16;
492 if (operandElType.isF32() && isAccumulator)
493 return NVVM::MMATypes::f32;
494 if (operandElType.isF32() && !isAccumulator)
495 return NVVM::MMATypes::tf32;
496 if (llvm::isa<IntegerType>(operandElType)) {
497 if (isAccumulator)
498 return NVVM::MMATypes::s32;
499 return std::nullopt;
500 }
501
502 if (auto structType = llvm::dyn_cast<LLVM::LLVMStructType>(operandElType)) {
503 if (structType.getBody().empty())
504 return std::nullopt;
505 return inferOperandMMAType(structType.getBody()[0], isAccumulator);
506 }
507
508 return std::nullopt;
509}
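// Summary of the mapping above (illustrative):
//   f64                 -> f64       f16 / vector<2xf16>  -> f16
//   f32 (accumulator)   -> f32       f32 (multiplicand)   -> tf32
//   i32 (accumulator)   -> s32       integer multiplicand -> std::nullopt
// For struct types, the first body element decides the PTX type.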
510
511static bool isInt4PtxType(MMATypes type) {
512 return (type == MMATypes::u4 || type == MMATypes::s4);
513}
514
515static bool isInt8PtxType(MMATypes type) {
516 return (type == MMATypes::u8 || type == MMATypes::s8);
517}
518
519static bool isIntegerPtxType(MMATypes type) {
520 return isInt4PtxType(type) || isInt8PtxType(type) || type == MMATypes::b1 ||
521 type == MMATypes::s32;
522}
523
524MMATypes MmaOp::accumPtxType() {
525 std::optional<mlir::NVVM::MMATypes> val = inferOperandMMAType(
526 getODSOperands(2).getTypes().front(), /*isAccumulator=*/true);
527 assert(val.has_value() && "accumulator PTX type should always be inferrable");
528 return val.value();
529}
530
531MMATypes MmaOp::resultPtxType() {
532 std::optional<mlir::NVVM::MMATypes> val =
533 inferOperandMMAType(getResult().getType(), /*isAccumulator=*/true);
534 assert(val.has_value() && "result PTX type should always be inferrable");
535 return val.value();
536}
537
538void MmaOp::print(OpAsmPrinter &p) {
539 SmallVector<Type, 4> regTypes;
540 struct OperandFragment {
541 StringRef operandName;
542 StringRef ptxTypeAttr;
543 SmallVector<Value, 4> regs;
544 explicit OperandFragment(StringRef name, StringRef ptxTypeName)
545 : operandName(name), ptxTypeAttr(ptxTypeName) {}
546 };
547
548 std::array<OperandFragment, 3> frags{
549 OperandFragment("A", getMultiplicandAPtxTypeAttrName()),
550 OperandFragment("B", getMultiplicandBPtxTypeAttrName()),
551 OperandFragment("C", "")};
552 SmallVector<StringRef, 4> ignoreAttrNames{
553 mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()};
554
555 for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
556 auto &frag = frags[fragIdx];
557 auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
558 for (auto operandIdx = varOperandSpec.first;
559 operandIdx < varOperandSpec.first + varOperandSpec.second;
560 operandIdx++) {
561 frag.regs.push_back(this->getOperand(operandIdx));
562 if (operandIdx == 0) {
563 regTypes.push_back(this->getOperand(operandIdx).getType());
564 }
565 }
566 std::optional<MMATypes> inferredType =
567 inferOperandMMAType(regTypes.back(), /*isAccumulator=*/fragIdx >= 2);
568 if (inferredType)
569 ignoreAttrNames.push_back(frag.ptxTypeAttr);
570 }
571
572 auto printMmaOperand = [&](const OperandFragment &frag) -> void {
573 p << " " << frag.operandName;
574 p << "[";
575 p.printOperands(frag.regs);
576 p << "] ";
577 };
578
579 for (const auto &frag : frags) {
580 printMmaOperand(frag);
581 }
582
583 p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);
584
585 // Print the types of the operands and result.
586 p << " : "
587 << "(";
588 llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
589 frags[1].regs[0].getType(),
590 frags[2].regs[0].getType()},
591 p);
592 p << ")";
593 p.printArrowTypeList(TypeRange{this->getRes().getType()});
594}
595
596void MmaOp::build(OpBuilder &builder, OperationState &result, Type resultType,
597 ValueRange operandA, ValueRange operandB, ValueRange operandC,
598 ArrayRef<int64_t> shape, std::optional<MMAB1Op> b1Op,
599 std::optional<MMAIntOverflow> intOverflow,
600 std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
601 std::optional<std::array<MMALayout, 2>> multiplicandLayouts) {
602
603 assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
604 MLIRContext *ctx = builder.getContext();
605 result.addAttribute(
606 "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));
607
608 result.addOperands(operandA);
609 result.addOperands(operandB);
610 result.addOperands(operandC);
611
612 if (multiplicandPtxTypes) {
613 result.addAttribute("multiplicandAPtxType",
614 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
615 result.addAttribute("multiplicandBPtxType",
616 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
617 } else {
618 if (auto res = inferOperandMMAType(operandA[0].getType(), false))
619 result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
620 if (auto res = inferOperandMMAType(operandB[0].getType(), false))
621 result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
622 }
623
624 if (multiplicandLayouts) {
625 result.addAttribute("layoutA",
626 MMALayoutAttr::get(ctx, (*multiplicandLayouts)[0]));
627 result.addAttribute("layoutB",
628 MMALayoutAttr::get(ctx, (*multiplicandLayouts)[1]));
629 } else {
630 result.addAttribute("layoutA", MMALayoutAttr::get(ctx, MMALayout::row));
631 result.addAttribute("layoutB", MMALayoutAttr::get(ctx, MMALayout::col));
632 }
633
634 if (intOverflow.has_value())
635 result.addAttribute("intOverflowBehavior",
636 MMAIntOverflowAttr::get(ctx, *intOverflow));
637 if (b1Op.has_value())
638 result.addAttribute("b1Op", MMAB1OpAttr::get(ctx, *b1Op));
639
640 result.addTypes(resultType);
641 result.addAttribute(
642 MmaOp::getOperandSegmentSizeAttr(),
643 builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()),
644 static_cast<int32_t>(operandB.size()),
645 static_cast<int32_t>(operandC.size())}));
646}
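// Hypothetical builder usage (value names are placeholders, not from this
// file); it relies only on the build() signature above:
//   builder.create<NVVM::MmaOp>(loc, resultStructTy,
//                               /*operandA=*/ValueRange{a0, a1},
//                               /*operandB=*/ValueRange{b0},
//                               /*operandC=*/ValueRange{c0, c1},
//                               /*shape=*/ArrayRef<int64_t>{16, 8, 8},
//                               /*b1Op=*/std::nullopt,
//                               /*intOverflow=*/std::nullopt,
//                               /*multiplicandPtxTypes=*/std::nullopt,
//                               /*multiplicandLayouts=*/std::nullopt);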
647
648// <operation> :=
649// A `[` $operandA `]` B `[` $operandB `]` C `[` $operandC `]`
650// attr-dict : (type($operandA[0]), type($operandB[0]), type($operandC[0]))
651// `->` type($res)
652ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) {
653 struct OperandFragment {
654 std::optional<MMATypes> elemtype;
655 SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
656 SmallVector<Type> regTypes;
657 };
658
659 Builder &builder = parser.getBuilder();
660 std::array<OperandFragment, 4> frags;
661
662 NamedAttrList namedAttributes;
663
664 // A helper to parse the operand segments.
665 auto parseMmaOperand = [&](StringRef operandName,
666 OperandFragment &frag) -> LogicalResult {
667 if (parser.parseKeyword(operandName).failed())
668 return failure();
669 if (parser
670 .parseOperandList(frag.regs, OpAsmParser::Delimiter::OptionalSquare)
671 .failed())
672 return failure();
673 return success();
674 };
675
676 // Parse the operand segments.
677 if (parseMmaOperand("A", frags[0]).failed())
678 return failure();
679 if (parseMmaOperand("B", frags[1]).failed())
680 return failure();
681 if (parseMmaOperand("C", frags[2]).failed())
682 return failure();
683
684 if (parser.parseOptionalAttrDict(namedAttributes).failed())
685 return failure();
686
687 // Parse the type specification and resolve operands.
688 SmallVector<Type, 3> operandTypes;
689 if (failed(parser.parseColon()))
690 return failure();
691 if (failed(parser.parseLParen()))
692 return failure();
693 if (failed(parser.parseTypeList(operandTypes)))
694 return failure();
695 if (failed(parser.parseRParen()))
    return failure();
696 if (operandTypes.size() != 3)
697 return parser.emitError(
698 parser.getNameLoc(),
699 "expected one type for each operand segment but got " +
700 Twine(operandTypes.size()) + " types");
701 for (const auto &iter : llvm::enumerate(operandTypes)) {
702 auto &frag = frags[iter.index()];
703 frag.regTypes.resize(frag.regs.size(), iter.value());
704 if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
705 parser.getNameLoc(), result.operands)))
706 return failure();
707 frag.elemtype = inferOperandMMAType(frag.regTypes[0],
708 /*isAccumulator*/ iter.index() < 2);
709 }
710
711 Type resultType;
712 if (parser.parseArrow() || parser.parseType(resultType))
713 return failure();
714 frags[3].elemtype = inferOperandMMAType(resultType, /*isAccumulator*/ true);
715
716 std::array<StringRef, 2> names{"multiplicandAPtxType",
717 "multiplicandBPtxType"};
718 for (unsigned idx = 0; idx < names.size(); idx++) {
719 const auto &frag = frags[idx];
720 std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
721 if (!frag.elemtype.has_value() && !attr.has_value()) {
722 return parser.emitError(
723 parser.getNameLoc(),
724 "attribute " + names[idx] +
725 " is not provided explicitly and cannot be inferred");
726 }
727 if (!attr.has_value())
728 result.addAttribute(
729 names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));
730 }
731
732 result.addTypes(resultType);
733 if (!namedAttributes.empty())
734 result.addAttributes(namedAttributes);
735 result.addAttribute(MmaOp::getOperandSegmentSizeAttr(),
736 builder.getDenseI32ArrayAttr({
737 static_cast<int32_t>(frags[0].regs.size()),
738 static_cast<int32_t>(frags[1].regs.size()),
739 static_cast<int32_t>(frags[2].regs.size()),
740 }));
741 return success();
742}
743
744LogicalResult MmaOp::verify() {
745 MLIRContext *context = getContext();
746 auto f16Ty = Float16Type::get(context);
747 auto i32Ty = IntegerType::get(context, 32);
748 auto f16x2Ty = VectorType::get(2, f16Ty);
749 auto f32Ty = Float32Type::get(context);
750 auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
751 context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
752
753 auto s32x4StructTy =
754 LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
755 auto f32x8StructTy =
756 LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
757 auto f16x2x2StructTy =
758 LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
759 auto f32x4StructTy =
760 LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
761 auto s32x2StructTy =
762 LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});
763
764 std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
765 getShapeAttr().getK()};
766
767 // These variables define the set of allowed data types for matrices A, B, C,
768 // and result.
769 using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
770 using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>;
771 AllowedShapes allowedShapes;
772 AllowedTypes expectedA;
773 AllowedTypes expectedB;
774 AllowedTypes expectedC;
775 SmallVector<Type> expectedResult;
776
777 // When M = 16, we just need to calculate the number of 8xk tiles, where
778 // k is a factor that depends on the data type.
779 if (mmaShape[0] == 16) {
780 int64_t kFactor;
781 Type multiplicandFragType;
782 switch (*getMultiplicandAPtxType()) {
783 case MMATypes::tf32:
784 kFactor = 4;
785 multiplicandFragType = i32Ty;
786 expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
787 context, {f32Ty, f32Ty, f32Ty, f32Ty}));
788 break;
789 case MMATypes::bf16:
790 kFactor = 8;
791 multiplicandFragType = i32Ty;
792 expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
793 context, {f32Ty, f32Ty, f32Ty, f32Ty}));
794 break;
795 case MMATypes::f16:
796 kFactor = 8;
797 multiplicandFragType = f16x2Ty;
798 expectedResult.push_back(f16x2x2StructTy);
799 expectedResult.push_back(f32x4StructTy);
800 break;
801 case MMATypes::s4:
802 case MMATypes::u4:
803 kFactor = 32;
804 break;
805 case MMATypes::b1:
806 kFactor = 128;
807 break;
808 case MMATypes::s8:
809 case MMATypes::u8:
810 kFactor = 16;
811 break;
812 default:
813 return emitError("invalid shape or multiplicand type: " +
814 stringifyEnum(getMultiplicandAPtxType().value()));
815 }
816
817 if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
818 expectedResult.push_back(s32x4StructTy);
819 expectedC.emplace_back(4, i32Ty);
820 multiplicandFragType = i32Ty;
821 } else {
822 expectedC.emplace_back(2, f16x2Ty);
823 expectedC.emplace_back(4, f32Ty);
824 }
825
826 int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
827 int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
828 expectedA.emplace_back(unitA, multiplicandFragType);
829 expectedB.emplace_back(unitB, multiplicandFragType);
830 allowedShapes.push_back({16, 8, kFactor});
831 allowedShapes.push_back({16, 8, kFactor * 2});
832
833 if (resultPtxType() != accumPtxType())
834 return emitOpError("ctype does not match dtype");
835 }
836
837 // In the M=8 case, there is only 1 possible case per data type.
838 if (mmaShape[0] == 8) {
839 if (*getMultiplicandAPtxType() == MMATypes::f16) {
840 expectedA.emplace_back(2, f16x2Ty);
841 expectedB.emplace_back(2, f16x2Ty);
842 expectedResult.push_back(f16x2x4StructTy);
843 expectedResult.push_back(f32x8StructTy);
844 expectedC.emplace_back(4, f16x2Ty);
845 expectedC.emplace_back(8, f32Ty);
846 allowedShapes.push_back({8, 8, 4});
847 }
848 if (*getMultiplicandAPtxType() == MMATypes::f64) {
849 Type f64Ty = Float64Type::get(context);
850 expectedA.emplace_back(1, f64Ty);
851 expectedB.emplace_back(1, f64Ty);
852 expectedC.emplace_back(2, f64Ty);
853 expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
854 context, SmallVector<Type>(2, f64Ty)));
855 allowedShapes.push_back({8, 8, 4});
856 }
857 if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
858 expectedA.push_back({i32Ty});
859 expectedB.push_back({i32Ty});
860 expectedC.push_back({i32Ty, i32Ty});
861 expectedResult.push_back(s32x2StructTy);
862 if (isInt4PtxType(getMultiplicandAPtxType().value()))
863 allowedShapes.push_back({8, 8, 32});
864 if (isInt8PtxType(getMultiplicandAPtxType().value()))
865 allowedShapes.push_back({8, 8, 16});
866 if (getMultiplicandAPtxType().value() == MMATypes::b1)
867 allowedShapes.push_back({8, 8, 128});
868 }
869 }
870
871 std::string errorMessage;
872 llvm::raw_string_ostream errorStream(errorMessage);
873
874 // Check that we matched an existing shape/dtype combination.
875 if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
876 !llvm::is_contained(allowedShapes, mmaShape)) {
877 errorStream << "unimplemented variant for MMA shape <";
878 llvm::interleaveComma(mmaShape, errorStream);
879 errorStream << ">";
880 return emitOpError(errorMessage);
881 }
882
883 // Verify the operand types for segments of A, B, and C operands.
884 std::array<StringRef, 3> operandNames{"A", "B", "C"};
885 for (const auto &iter : llvm::enumerate(
886 SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) {
887 auto spec = this->getODSOperandIndexAndLength(iter.index());
888 SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first,
889 operand_type_begin() + spec.first +
890 spec.second);
891 bool match = llvm::is_contained(iter.value(), operandTySeg);
892
893 if (!match) {
894 errorStream << "Could not match types for the "
895 << operandNames[iter.index()]
896 << " operands; expected one of ";
897 for (const auto &x : iter.value()) {
898 errorStream << x.size() << "x" << x[0] << " ";
899 }
900 errorStream << "but got ";
901 llvm::interleaveComma(operandTySeg, errorStream);
902 return emitOpError(errorMessage);
903 }
904 }
905
906 // Check the result type
907 if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
908 return expectedResultType == getResult().getType();
909 })) {
910 errorStream
911 << "Could not match allowed types for the result; expected one of ";
912 llvm::interleaveComma(expectedResult, errorStream);
913 errorStream << " but got " << getResult().getType();
914 return emitOpError(errorMessage);
915 }
916
917 // Ensure that binary MMA variants have a b1 MMA operation defined.
918 if (getMultiplicandAPtxType() == MMATypes::b1 && !getB1Op()) {
919 return emitOpError("op requires " + getB1OpAttrName().strref() +
920 " attribute");
921 }
922
923 // Ensure int4/int8 MMA variants specify the accum overflow behavior
924 // attribute.
925 if (isInt4PtxType(*getMultiplicandAPtxType()) ||
926 isInt8PtxType(*getMultiplicandAPtxType())) {
927 if (!getIntOverflowBehavior())
928 return emitOpError("op requires " +
929 getIntOverflowBehaviorAttrName().strref() +
930 " attribute");
931 }
932
933 // Validate layout combinations. According to the operation description, most
934 // MMA operations require layoutA=row and layoutB=col. Only m8n8k4 with f16
935 // can use other layout combinations.
936 bool isM8N8K4_F16 =
937 (mmaShape[0] == 8 && mmaShape[1] == 8 && mmaShape[2] == 4 &&
938 getMultiplicandAPtxType() == MMATypes::f16);
939
940 if (!isM8N8K4_F16) {
941 // For all other shapes/types, layoutA must be row and layoutB must be col
942 if (getLayoutA() != MMALayout::row || getLayoutB() != MMALayout::col) {
943 return emitOpError("requires layoutA = #nvvm.mma_layout<row> and "
944 "layoutB = #nvvm.mma_layout<col> for shape <")
945 << mmaShape[0] << ", " << mmaShape[1] << ", " << mmaShape[2]
946 << "> with element types "
947 << stringifyEnum(*getMultiplicandAPtxType()) << " and "
948 << stringifyEnum(*getMultiplicandBPtxType())
949 << ". Only m8n8k4 with f16 supports other layouts.";
950 }
951 }
952
953 return success();
954}
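// Worked example of the M == 16 shape arithmetic above: for s8 operands with
// shape m16n8k32, kFactor = 16, so unitA = (16/8) * (32/16) = 4 and
// unitB = (8/8) * (32/16) = 2, i.e. A is 4 x i32, B is 2 x i32, and both C and
// the result are 4 x i32 (s32 accumulator).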
955
956LogicalResult ShflOp::verify() {
957 auto returnStructType = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
958
959 auto verifyTypeError = [&](Twine desc, Type expectedType,
960 Type actualType) -> LogicalResult {
961 return emitOpError("expected " + desc + " to be of type ")
962 << expectedType << " but got " << actualType << " instead";
963 };
964
965 if (returnStructType) {
966 if (!getReturnValueAndIsValid())
967 return emitOpError("\"return_value_and_is_valid\" attribute must be "
968 "specified when the return type is a struct type");
969
970 if (returnStructType.getBody().size() != 2)
971 return emitOpError("expected return type to be a two-element struct");
972
973 llvm::ArrayRef<Type> returnStruct = returnStructType.getBody();
974 auto resultType = returnStruct[0];
975 if (resultType != getVal().getType())
976 return verifyTypeError("first element in the returned struct",
977 getVal().getType(), resultType);
978
979 auto predicateType = returnStruct[1];
980 if (!predicateType.isInteger(1))
981 return verifyTypeError("second element in the returned struct",
982 mlir::IntegerType::get(getContext(), 1),
983 predicateType);
984 } else {
985 if (getReturnValueAndIsValid())
986 return emitOpError("expected return type to be a two-element struct");
987
988 if (getType() != getVal().getType())
989 return verifyTypeError("return type", getVal().getType(), getType());
990 }
991 return success();
992}
993
994std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
995 NVVM::MMAFrag frag, int nRow,
996 int nCol,
997 MLIRContext *context) {
998 unsigned numberElements = 0;
999 Type elementType;
1000 OpBuilder builder(context);
1001 Type f16x2 = VectorType::get(2, builder.getF16Type());
1002 if (type == NVVM::MMATypes::f16) {
1003 elementType = f16x2;
1004 if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
1005 numberElements = 8;
1006 else
1007 numberElements = 4;
1008 } else if (type == NVVM::MMATypes::f32) {
1009 elementType = builder.getF32Type();
1010 numberElements = 8;
1011 } else if (type == NVVM::MMATypes::f64) {
1012 elementType = builder.getF64Type();
1013 if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
1014 numberElements = 1;
1015 else
1016 numberElements = 2;
1017 } else if (type == NVVM::MMATypes::tf32) {
1018 elementType = builder.getI32Type();
1019 numberElements = 4;
1020 } else if (type == NVVM::MMATypes::s8 || type == NVVM::MMATypes::u8) {
1021 elementType = builder.getI32Type();
1022 int parallelSize = 0;
1023 if (frag == NVVM::MMAFrag::a)
1024 parallelSize = nRow;
1025 if (frag == NVVM::MMAFrag::b)
1026 parallelSize = nCol;
1027
1028 // m == 16 && n == 16 && k == 16
1029 if (parallelSize == 16)
1030 numberElements = 2;
1031 // m == 8 && n == 32 && k == 16 or m == 32 && n == 8 && k == 16
1032 else if (parallelSize == 8)
1033 numberElements = 1;
1034 else if (parallelSize == 32)
1035 numberElements = 4;
1036 } else if (type == NVVM::MMATypes::s32) {
1037 elementType = builder.getI32Type();
1038 numberElements = 8;
1039 }
1040 assert(numberElements != 0 && elementType != nullptr);
1041 return std::make_pair(elementType, numberElements);
1042}
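// Illustrative results of the rules above (hypothetical calls):
//   inferMMAType(MMATypes::f16, MMAFrag::a, 16, 16, ctx) -> {vector<2xf16>, 8}
//   inferMMAType(MMATypes::f64, MMAFrag::c, 8, 8, ctx)   -> {f64, 2}
//   inferMMAType(MMATypes::s8,  MMAFrag::a, 32, 16, ctx) -> {i32, 4}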
1043
1044static std::pair<mlir::Type, unsigned>
1045inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n,
1046 int k, MLIRContext *context) {
1047 int nRow, nCol;
1048 if (frag == NVVM::MMAFrag::a) {
1049 nRow = m;
1050 nCol = k;
1051 } else if (frag == NVVM::MMAFrag::b) {
1052 nRow = k;
1053 nCol = n;
1054 } else {
1055 nRow = m;
1056 nCol = n;
1057 }
1058 assert(nRow && nCol);
1059 return inferMMAType(type, frag, nRow, nCol, context);
1060}
1061
1062LogicalResult NVVM::WMMALoadOp::verify() {
1063 unsigned addressSpace =
1064 llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
1065 if (addressSpace != 0 && addressSpace != NVVMMemorySpace::Global &&
1066 addressSpace != NVVMMemorySpace::Shared)
1067 return emitOpError("expected source pointer in memory "
1068 "space 0, 1, 3");
1069
1070 if (NVVM::WMMALoadOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
1071 getEltype(), getFrag()) == 0)
1072 return emitOpError() << "invalid attribute combination";
1073 std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
1074 getEltype(), getFrag(), getM(), getN(), getK(), getContext());
1075 // Special case for f64 fragments
1076 Type f64Ty = Float64Type::get(getContext());
1077 if (typeInfo.first == f64Ty && typeInfo.second == 1) {
1078 if (getType() != f64Ty)
1079 return emitOpError("expected destination type to be f64");
1080 return success();
1081 }
1082 // Everything else is a struct
1083 Type dstType = LLVM::LLVMStructType::getLiteral(
1084 getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
1085 if (getType() != dstType)
1086 return emitOpError("expected destination type is a structure of ")
1087 << typeInfo.second << " elements of type " << typeInfo.first;
1088 return success();
1089}
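// E.g. (illustrative): a 16x16x16 f16 A-fragment load must have type
//   !llvm.struct<(vector<2xf16>, vector<2xf16>, ..., vector<2xf16>)>  // 8 elements
// while an 8x8x4 f64 A-fragment load returns a plain f64 (the special case above).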
1090
1091LogicalResult NVVM::WMMAStoreOp::verify() {
1092 unsigned addressSpace =
1093 llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
1094 if (addressSpace != 0 && addressSpace != NVVMMemorySpace::Global &&
1095 addressSpace != NVVMMemorySpace::Shared)
1096 return emitOpError("expected operands to be a source pointer in memory "
1097 "space 0, 1, 3");
1098
1099 if (NVVM::WMMAStoreOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
1100 getEltype()) == 0)
1101 return emitOpError() << "invalid attribute combination";
1102 std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
1103 getEltype(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
1104 if (getArgs().size() != typeInfo.second)
1105 return emitOpError() << "expected " << typeInfo.second << " data operands";
1106 if (llvm::any_of(getArgs(), [&typeInfo](Value operands) {
1107 return operands.getType() != typeInfo.first;
1108 }))
1109 return emitOpError() << "expected data operands of type " << typeInfo.first;
1110 return success();
1111}
1112
1113LogicalResult NVVM::WMMAMmaOp::verify() {
1114 if (NVVM::WMMAMmaOp::getIntrinsicID(getM(), getN(), getK(), getLayoutA(),
1115 getLayoutB(), getEltypeA(),
1116 getEltypeB()) == 0)
1117 return emitOpError() << "invalid attribute combination";
1118 std::pair<Type, unsigned> typeInfoA = inferMMATypeFromMNK(
1119 getEltypeA(), NVVM::MMAFrag::a, getM(), getN(), getK(), getContext());
1120 std::pair<Type, unsigned> typeInfoB = inferMMATypeFromMNK(
1121 getEltypeA(), NVVM::MMAFrag::b, getM(), getN(), getK(), getContext());
1122 std::pair<Type, unsigned> typeInfoC = inferMMATypeFromMNK(
1123 getEltypeB(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
1124 SmallVector<Type, 32> arguments;
1125 arguments.append(typeInfoA.second, typeInfoA.first);
1126 arguments.append(typeInfoB.second, typeInfoB.first);
1127 arguments.append(typeInfoC.second, typeInfoC.first);
1128 unsigned numArgs = arguments.size();
1129 if (getArgs().size() != numArgs)
1130 return emitOpError() << "expected " << numArgs << " arguments";
1131 for (unsigned i = 0; i < numArgs; i++) {
1132 if (getArgs()[i].getType() != arguments[i])
1133 return emitOpError() << "expected argument " << i << " to be of type "
1134 << arguments[i];
1135 }
1136 Type dstType = LLVM::LLVMStructType::getLiteral(
1137 getContext(), SmallVector<Type, 8>(typeInfoC.second, typeInfoC.first));
1138 if (getType() != dstType)
1139 return emitOpError("expected destination type is a structure of ")
1140 << typeInfoC.second << " elements of type " << typeInfoC.first;
1141 return success();
1142}
1143
1144LogicalResult NVVM::LdMatrixOp::verify() {
1145 uint32_t num = getNum(), m = getShape().getM(), n = getShape().getN();
1146 if (m == 8 && n == 8) {
1147 if (num != 1 && num != 2 && num != 4) {
1148 return emitOpError("expected num attribute to be 1, 2 or 4 for 8x8 "
1149 "matrix");
1150 }
1151 if (getEltType() != LdStMatrixEltType::B16) {
1152 return emitOpError("expected element type to be b16 for 8x8 matrix");
1153 }
1154 } else if (m == 8 && n == 16) {
1155 if (num != 1 && num != 2 && num != 4) {
1156 return emitOpError("expected num attribute to be 1, 2 or 4 for 8x16 "
1157 "matrix");
1158 }
1159 if (getLayout() != MMALayout::row) {
1160 return emitOpError("expected layout to be row for 8x16 matrix");
1161 }
1162 if (getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 &&
1163 getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) {
1164 return emitOpError("expected element type to be b8x16.b4x16_p64 or "
1165 "b8x16.b6x16_p32 for 8x16 matrix");
1166 }
1167 } else if (m == 16 && n == 16) {
1168 if (num != 1 && num != 2) {
1169 return emitOpError("expected num attribute to be 1 or 2 for 16x16 "
1170 "matrix");
1171 }
1172 if (getLayout() != MMALayout::col) {
1173 return emitOpError("expected layout to be col for 16x16 matrix");
1174 }
1175 if (getEltType() != LdStMatrixEltType::B8 &&
1176 getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 &&
1177 getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) {
1178 return emitOpError("expected element type to be b8, b8x16.b4x16_p64 or "
1179 "b8x16.b6x16_p32 for 16x16 matrix");
1180 }
1181 } else {
1182 return emitOpError("expected shape to be 8x8, 8x16 or 16x16");
1183 }
1184
1185 Type i32 = IntegerType::get(getContext(), 32);
1186 uint32_t numElements = (m == 16 && n == 16 ? num * 2 : num);
1187 if (numElements == 1 && getType() != i32)
1188 return emitOpError("expected destination type is i32");
1189 if (numElements == 2 || numElements == 4) {
1190 Type dstType = LLVM::LLVMStructType::getLiteral(
1191 getContext(), SmallVector<Type>(numElements, i32));
1192 if (getType() != dstType)
1193 return emitOpError("expected destination type is a structure of ")
1194 << numElements << " elements of type i32";
1195 }
1196
1197 return success();
1198}
1199
1200LogicalResult NVVM::StMatrixOp::verify() {
1201 int numMatrix = getSources().size();
1202 if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
1203 return emitOpError("expected num attribute to be 1, 2 or 4");
1204
1205 int m = getShape().getM(), n = getShape().getN();
1206 if (m == 8 && n == 8) {
1207 if (getEltType() != NVVM::LdStMatrixEltType::B16) {
1208 return emitOpError("expected element type to be B16 for 8x8 matrix");
1209 }
1210 } else if (m == 16 && n == 8) {
1211 if (getEltType() != NVVM::LdStMatrixEltType::B8) {
1212 return emitOpError("expected element type to be B8 for 16x8 matrix");
1213 }
1214 if (getLayout() != NVVM::MMALayout::col) {
1215 return emitOpError("expected layout to be col for 16x8 matrix");
1216 }
1217 } else {
1218 return emitOpError("expected shape to be 8x8 or 16x8");
1219 }
1220
1221 return success();
1222}
1223
1224static FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) {
1225 if (typeA == NVVM::WGMMATypes::tf32)
1226 return 8;
1227 if (typeA == NVVM::WGMMATypes::f16 || typeA == NVVM::WGMMATypes::bf16)
1228 return 16;
1229 if (typeA == NVVM::WGMMATypes::s8 || typeA == NVVM::WGMMATypes::u8)
1230 return 32;
1231 if (typeA == NVVM::WGMMATypes::e4m3 || typeA == NVVM::WGMMATypes::e5m2)
1232 return 32;
1233 if (typeA == NVVM::WGMMATypes::b1)
1234 return 256;
1235 return failure();
1236}
1237
1238static LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD,
1239 NVVM::WGMMATypes typeA,
1240 NVVM::WGMMATypes typeB) {
1241 switch (typeA) {
1242 case NVVM::WGMMATypes::f16:
1243 if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
1244 typeB == NVVM::WGMMATypes::f16)
1245 return success();
1246 break;
1247 case NVVM::WGMMATypes::tf32:
1248 if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32)
1249 return success();
1250 break;
1251 case NVVM::WGMMATypes::u8:
1252 case NVVM::WGMMATypes::s8:
1253 if (typeD == NVVM::WGMMATypes::s32 &&
1254 (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8))
1255 return success();
1256 break;
1257 case NVVM::WGMMATypes::b1:
1258 if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1)
1259 return success();
1260 break;
1261 case NVVM::WGMMATypes::bf16:
1262 if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
1263 typeB == NVVM::WGMMATypes::bf16)
1264 return success();
1265 break;
1266 case NVVM::WGMMATypes::e4m3:
1267 case NVVM::WGMMATypes::e5m2:
1268 if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
1269 (typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3))
1270 return success();
1271 break;
1272 case WGMMATypes::f32:
1273 case WGMMATypes::s32:
1274 llvm_unreachable("unsupported input types");
1275 break;
1276 }
1277 return failure();
1278}
1279
1280static LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA) {
1281 SmallVector<int> allowedN = {8, 16, 24, 32, 40, 48, 56, 64,
1282 72, 80, 88, 96, 104, 112, 120, 128,
1283 136, 144, 152, 160, 168, 176, 184, 192,
1284 200, 208, 216, 224, 232, 240, 248, 256};
1285 SmallVector<int> allowedNshort = {8, 16, 24, 32, 48, 64,
1286 80, 96, 112, 128, 144, 160,
1287 176, 192, 208, 224, 240, 256};
1288 switch (typeA) {
1289 case WGMMATypes::f16:
1290 case WGMMATypes::tf32:
1291 case WGMMATypes::bf16:
1292 case WGMMATypes::e4m3:
1293 case WGMMATypes::e5m2:
1294 if (llvm::is_contained(allowedN, sizeN))
1295 return success();
1296 break;
1297 case WGMMATypes::u8:
1298 case WGMMATypes::s8:
1299 case WGMMATypes::b1:
1300 if (llvm::is_contained(allowedNshort, sizeN))
1301 return success();
1302 break;
1303 case WGMMATypes::f32:
1304 case WGMMATypes::s32:
1305 llvm_unreachable("unsupported input types");
1306 break;
1307 }
1308 return failure();
1309}
1310
1311LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
1312 Value outValue = getResults();
1313 auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
1314 if (!stype)
1315 return emitOpError() << "expected results to be struct";
1316 int outputSize = stype.getBody().size();
1317 WGMMATypes typeD = getTypeD();
1318 WGMMATypes typeA = getTypeA();
1319 WGMMATypes typeB = getTypeB();
1320
1321 for (Type t : stype.getBody()) {
1322 if (t != stype.getBody().front())
1323 return emitOpError()
1324 << "all elements in struct must be same type but there is " << t;
1325 }
1326
1327 if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
1328 typeD != WGMMATypes::s32) {
1329 return emitOpError() << "does not support the given output type "
1330 << NVVM::stringifyWGMMATypes(typeD);
1331 }
1332 if (typeD == WGMMATypes::s32 &&
1333 (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
1334 return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg";
1335 }
1336
1337 if (failed(isAllowedWGMMADataType(typeD, typeA, typeB))) {
1338 return emitOpError() << NVVM::stringifyWGMMATypes(typeD)
1339 << " += " << NVVM::stringifyWGMMATypes(typeA) << " * "
1340 << NVVM::stringifyWGMMATypes(typeB)
1341 << ", it is not supported.";
1342 }
1343
1344 // Check M
1345 if (getShape().getM() != 64)
1346 return emitOpError() << "shape 'm' must be 64";
1347
1348 // Check K
1349 FailureOr<int> allowedK = getAllowedSizeK(typeA);
1350 if (failed(allowedK) || allowedK.value() != getShape().getK())
1351 return emitOpError() << "shape 'k' must be " << allowedK.value()
1352 << " for input type "
1353 << NVVM::stringifyWGMMATypes(typeA);
1354
1355 // Check N
1356 if (failed(isAllowedSizeN(getShape().getN(), typeA))) {
1357 return emitOpError() << "has input type "
1358 << NVVM::stringifyWGMMATypes(typeA) << " n is set to "
1359 << getShape().getN() << ", it is not supported.";
1360 }
1361
1362 // Check transpose (only available for f16/bf16)
1363 // Matrices A should be stored in row-major and B in column-major.
1364 // Only f16/bf16 matrices can be stored in either column-major or row-major
1365 // by setting the transpose value (imm-trans-a, imm-trans-b) in PTX code.
1366 if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
1367 (getLayoutA() == mlir::NVVM::MMALayout::col ||
1368 getLayoutB() == mlir::NVVM::MMALayout::row)) {
1369 return emitOpError()
1370 << "given layouts layout_a = " << stringifyMMALayout(getLayoutA())
1371 << " and layout_b = " << stringifyMMALayout(getLayoutB())
1372 << " for input types " << stringifyWGMMATypes(typeA) << " and "
1373 << stringifyWGMMATypes(typeB)
1374 << " requires transpose. However, this is only supported for: "
1375 << stringifyMMATypes(MMATypes::f16) << " and "
1376 << stringifyMMATypes(MMATypes::bf16);
1377 }
1378
1379 // Check result registers
1380 int expectedOutput = 0;
1381 if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32)
1382 expectedOutput = getShape().getN() / 2;
1383 if (typeD == WGMMATypes::f16)
1384 expectedOutput = getShape().getN() / 4;
1385 if (outputSize != expectedOutput) {
1386 return emitOpError() << "results " << expectedOutput
1387 << ", however output struct has " << outputSize
1388 << " elements";
1389 }
1390 // Check satfinite (only available for s32 accumulator)
1391 if (typeD != WGMMATypes::s32 &&
1392 getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
1393 NVVM::MMAIntOverflow::satfinite) {
1394 return emitOpError()
1395 << " `satfinite` can be only used with s32 accumulator, however "
1396 "the current accumulator is "
1397 << NVVM::stringifyWGMMATypes(typeD);
1398 }
1399
1400 return success();
1401}
1402
1403std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
1404
1405 int m = getShape().getM(), n = getShape().getN(), k = getShape().getK();
1406 bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
1407
1408 StringRef outputTypeName = stringifyWGMMATypes(getTypeD());
1409
1410 int expectedOutputRegisters = 0;
1411 if (getTypeD() == WGMMATypes::f16)
1412 expectedOutputRegisters = getShape().getN() / 4;
1413 else
1414 expectedOutputRegisters = getShape().getN() / 2;
1415
1416 std::string ptx;
1417 llvm::raw_string_ostream ss(ptx);
1418
1419 ss << "{\n"
1420 ".reg .pred p;\n"
1421 "setp.ne.b32 p, $"
1422 << ((expectedOutputRegisters * 2) + 2)
1423 << ", 0;\n"
1424 "wgmma.mma_async.sync.aligned.m"
1425 << m << "n" << n << "k" << k << "." << outputTypeName << "."
1426 << stringifyWGMMATypes(getTypeA()) << "."
1427 << stringifyWGMMATypes(getTypeB());
1428 if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
1429 NVVM::MMAIntOverflow::satfinite)
1430 ss << ".satfinite";
1431 ss << " {";
1432 int regCnt = 0;
1433 for (; regCnt < expectedOutputRegisters; ++regCnt) {
1434 ss << "$" << regCnt;
1435 if (regCnt != expectedOutputRegisters - 1)
1436 ss << ", ";
1437 }
1438
1439 ss << "},";
1440 // Need to map read/write registers correctly.
1441 regCnt = (regCnt * 2);
1442 ss << " $" << (regCnt) << ","
1443 << " $" << (regCnt + 1) << ","
1444 << " p";
1445 if (getTypeD() != WGMMATypes::s32) {
1446 ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4);
1447 }
1448 // Don't add transpose parameters unless needed.
1449 if (isF16) {
1450 ss << ", $" << (regCnt + 5) << ", $" << (regCnt + 6);
1451 }
1452 ss << ";\n"
1453 << "}\n";
1454 return ptx;
1455}
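// For reference, with typeD = f32, typeA = typeB = f16 and shape m64n8k16, the
// string built above is roughly (line-wrapped here for readability):
//   {
//   .reg .pred p;
//   setp.ne.b32 p, $10, 0;
//   wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16
//       {$0, $1, $2, $3}, $8, $9, p, $11, $12, $13, $14;
//   }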
1456
1457bool NVVM::WgmmaMmaAsyncOp::getAsmValues(
1458 RewriterBase &rewriter,
1459 llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
1460 &asmValues) {
1461 bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
1462 if (getResults())
1463 asmValues.push_back({getResults(), mlir::NVVM::PTXRegisterMod::Write});
1464 if (getInouts())
1465 asmValues.push_back({getInouts(), mlir::NVVM::PTXRegisterMod::ReadWrite});
1466 asmValues.push_back({getDescriptorA(), mlir::NVVM::PTXRegisterMod::Read});
1467 asmValues.push_back({getDescriptorB(), mlir::NVVM::PTXRegisterMod::Read});
1468 asmValues.push_back({makeConstantI32(rewriter, static_cast<int>(getScaleD())),
1469 mlir::NVVM::PTXRegisterMod::Read});
1470 if (getTypeD() != WGMMATypes::s32) {
1471 asmValues.push_back(
1472 {makeConstantI32(rewriter,
1473 getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
1474 mlir::NVVM::PTXRegisterMod::Read});
1475 asmValues.push_back(
1476 {makeConstantI32(rewriter,
1477 getScaleB() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
1478 mlir::NVVM::PTXRegisterMod::Read});
1479 }
1480 if (isF16) {
1481 asmValues.push_back(
1482 {makeConstantI32(rewriter, static_cast<int>(getLayoutA())),
1483 mlir::NVVM::PTXRegisterMod::Read});
1484 asmValues.push_back(
1485 {makeConstantI32(rewriter, 1 - static_cast<int>(getLayoutB())),
1486 mlir::NVVM::PTXRegisterMod::Read});
1487 }
1488 return true; // Has manual mapping
1489}
1490
1491LogicalResult NVVM::FenceProxyOp::verify() {
1492 if (getKind() == NVVM::ProxyKind::TENSORMAP)
1493 return emitOpError() << "tensormap proxy is not a supported proxy kind";
1494 if (getKind() == NVVM::ProxyKind::GENERIC)
1495 return emitOpError() << "generic proxy is not a supported proxy kind";
1496 if (getKind() == NVVM::ProxyKind::async_shared && !getSpace().has_value()) {
1497 return emitOpError() << "async_shared fence requires space attribute";
1498 }
1499 if (getKind() != NVVM::ProxyKind::async_shared && getSpace().has_value()) {
1500 return emitOpError() << "only async_shared fence can have space attribute";
1501 }
1502 return success();
1503}
1504
1505LogicalResult NVVM::FenceProxyAcquireOp::verify() {
1506 if (getFromProxy() != NVVM::ProxyKind::GENERIC)
1507 return emitOpError("uni-directional proxies only support generic for "
1508 "from_proxy attribute");
1509
1510 if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
1511 return emitOpError("uni-directional proxies only support tensormap "
1512 "for to_proxy attribute");
1513
1514 return success();
1515}
1516
1517LogicalResult NVVM::FenceProxyReleaseOp::verify() {
1518 if (getFromProxy() != NVVM::ProxyKind::GENERIC)
1519 return emitOpError("uni-directional proxies only support generic for "
1520 "from_proxy attribute");
1521
1522 if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
1523 return emitOpError("uni-directional proxies only support tensormap "
1524 "for to_proxy attribute");
1525
1526 return success();
1527}
1528
1529LogicalResult NVVM::SetMaxRegisterOp::verify() {
1530 if (getRegCount() % 8)
1531 return emitOpError("new register size must be multiple of 8");
1532 if (getRegCount() < 24 || getRegCount() > 256)
1533 return emitOpError("new register size must be between 24 and 256");
1534 return success();
1535}
1536
1537LogicalResult NVVM::BarrierOp::verify() {
1538 if (getNumberOfThreads() && !getBarrierId())
1539 return emitOpError(
1540 "barrier id is missing, it should be set between 0 to 15");
1541
1542 if (getBarrierId() && (getReductionOp() || getReductionPredicate()))
1543 return emitOpError("reductions are only available when id is 0");
1544
1545 if ((getReductionOp() && !getReductionPredicate()) ||
1546 (!getReductionOp() && getReductionPredicate()))
1547 return emitOpError("reduction predicate and reduction operation must be "
1548 "specified together");
1549
1550 return success();
1551}
1552
1553LogicalResult NVVM::Tcgen05CpOp::verify() {
1554 auto mc = getMulticast();
1555
1556 using SH = Tcgen05CpShape;
1557 using MC = Tcgen05CpMulticast;
1558 switch (getShape()) {
1559 case SH::SHAPE_128x256b:
1560 case SH::SHAPE_128x128b:
1561 case SH::SHAPE_4x256b:
1562 if (mc != MC::NONE)
1563 return emitError("Invalid multicast type for tcgen05.cp Op");
1564 break;
1565 case SH::SHAPE_64x128b:
1566 if (mc != MC::WARPX2_01_23 && mc != MC::WARPX2_02_13)
1567 return emitError("Shape 64x128b requires multicast warpx2_01_23 or "
1568 "warpx2_02_13 for tcgen05.cp Op");
1569 break;
1570 case SH::SHAPE_32x128b:
1571 if (mc != MC::WARPX4)
1572 return emitError(
1573 "Shape 32x128b requires multicast warpx4 for tcgen05.cp Op");
1574 break;
1575 }
1576 return success();
1577}
1578
1579LogicalResult NVVM::MatchSyncOp::verify() {
1580 if (getKind() == NVVM::MatchSyncKind::all) {
1581 auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
1582 if (!type || type.getBody().size() != 2 ||
1583 !type.getBody()[0].isInteger(32) || !type.getBody()[1].isInteger(1)) {
1584 return emitOpError("match.sync 'all' returns a two element struct with "
1585 "first element as i32 and second element as i1");
1586 }
1587 } else {
1588 if (!getType().isInteger(32)) {
1589 return emitOpError("match.sync 'any' returns an i32");
1590 }
1591 }
1592 return success();
1593}
1594
1595LogicalResult NVVM::VoteSyncOp::verify() {
1596 if (getKind() == NVVM::VoteSyncKind::ballot) {
1597 if (!getType().isInteger(32)) {
1598 return emitOpError("vote.sync 'ballot' returns an i32");
1599 }
1600 } else {
1601 if (!getType().isInteger(1)) {
1602 return emitOpError("vote.sync 'any', 'all' and 'uni' returns an i1");
1603 }
1604 }
1605 return success();
1606}
1607
1608LogicalResult NVVM::PrefetchOp::verify() {
1609 using MemSpace = NVVM::NVVMMemorySpace;
1610 using CacheLevel = NVVM::PrefetchCacheLevel;
1611
1612 unsigned addressSpace =
1613 llvm::cast<LLVM::LLVMPointerType>(getAddr().getType()).getAddressSpace();
1614 std::optional<NVVM::CacheEvictionPriority> evictPriority = getEvictPriority();
1615 std::optional<NVVM::PrefetchCacheLevel> cacheLevel = getCacheLevel();
1616
1617 if (getTensormap() && cacheLevel)
1618 return emitOpError("cannot specify both tensormap and cache level");
1619
1620 if (getTensormap()) {
1621 if (addressSpace != MemSpace::Generic &&
1622 addressSpace != MemSpace::Constant) {
1623 return emitOpError(
1624 "prefetch tensormap requires a generic or constant pointer");
1625 }
1626
1627 if (evictPriority) {
1628 return emitOpError(
1629 "prefetch tensormap does not support eviction priority");
1630 }
1631
1632 if (getInParamSpace() && addressSpace != MemSpace::Generic) {
1633 return emitOpError(
1634 "in_param_space can only be specified for a generic pointer");
1635 }
1636
1637 } else if (cacheLevel) {
1638 if (addressSpace != MemSpace::Generic && addressSpace != MemSpace::Global &&
1639 addressSpace != MemSpace::Local) {
1640 return emitOpError("prefetch to cache level requires a generic, global, "
1641 "or local pointer");
1642 }
1643
1644 if (getUniform()) {
1645 if (*cacheLevel != CacheLevel::L1) {
1646 return emitOpError(
1647 "unsupported cache level, the only supported uniform "
1648 "cache level is L1");
1649 }
1650
1651 if (addressSpace != MemSpace::Generic) {
1652 return emitOpError(
1653 "prefetch to uniform cache requires a generic pointer");
1654 }
1655 }
1656
1657 if (evictPriority) {
1658 if (*cacheLevel != CacheLevel::L2)
1659 return emitOpError(
1660 "cache eviction priority supported only for cache level L2");
1661
1662 if (addressSpace != MemSpace::Global)
1663 return emitOpError("cache eviction priority requires a global pointer");
1664
1665 if (*evictPriority != NVVM::CacheEvictionPriority::EvictNormal &&
1666 *evictPriority != NVVM::CacheEvictionPriority::EvictLast)
1667 return emitOpError(
1668 "unsupported cache eviction priority, only evict_last and "
1669 "evict_normal are supported");
1670 }
1671
1672 if (getPredicate())
1673 return emitOpError("predicate supported only on prefetch tensormap");
1674
1675 } else {
1676 return emitOpError(
1677 "requires specification of either cache level or tensormap");
1678 }
1679
1680 return success();
1681}
1682
1683LogicalResult NVVM::ClusterLaunchControlQueryCancelOp::verify() {
1684 switch (getQueryType()) {
1685 case NVVM::ClusterLaunchControlQueryType::IS_CANCELED:
1686 if (!getType().isInteger(1))
1687 return emitOpError("is_canceled query type returns an i1");
1688 break;
1689 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_X:
1690 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Y:
1691 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Z:
1692 if (!getType().isInteger(32)) {
1693 return emitOpError("get_first_cta_id_x, get_first_cta_id_y, "
1694 "get_first_cta_id_z query types return an i32");
1695 }
1696 break;
1697 }
1698 return success();
1699}
1700
1701LogicalResult NVVM::ReduxOp::verify() {
1702 mlir::Type reduxType = getType();
1703
1704 if (!reduxType.isF32()) {
1705 if (getAbs())
1706 return emitOpError("abs attribute is supported only for f32 type");
1707 if (getNan())
1708 return emitOpError("nan attribute is supported only for f32 type");
1709 }
1710
1711 NVVM::ReduxKind kind = getKind();
1712 switch (kind) {
1713 case NVVM::ReduxKind::ADD:
1714 case NVVM::ReduxKind::AND:
1715 case NVVM::ReduxKind::OR:
1716 case NVVM::ReduxKind::XOR:
1717 case NVVM::ReduxKind::MAX:
1718 case NVVM::ReduxKind::MIN:
1719 case NVVM::ReduxKind::UMAX:
1720 case NVVM::ReduxKind::UMIN:
1721 if (!reduxType.isInteger(32))
1722 return emitOpError("'")
1723 << stringifyEnum(kind) << "' redux kind unsupported with "
1724 << reduxType << " type. Only supported type is 'i32'.";
1725 break;
1726 case NVVM::ReduxKind::FMIN:
1727 case NVVM::ReduxKind::FMAX:
1728 if (!reduxType.isF32())
1729 return emitOpError("'")
1730 << stringifyEnum(kind) << "' redux kind unsupported with "
1731 << reduxType << " type. Only supported type is 'f32'.";
1732 break;
1733 }
1734
1735 return success();
1736}
1737
1738/// Packs the given `field` into the `result`.
1739 /// The `result` is 64 bits wide and each `field` can be 32 bits or narrower.
1740static llvm::Value *
1741packValInto64Bits(llvm::IRBuilderBase &builder,
1742 llvm::Value *result, // the `result` (unset bits are zero)
1743 llvm::Value *field, // `field` to pack into `result`
1744 unsigned sizeInBits, // Size of `field` in bits
1745 unsigned start) { // Starting bit within `result`
1746 field = builder.CreateZExtOrBitCast(field, builder.getInt32Ty());
1747
1748 unsigned mask = (sizeInBits < 32 ? ((1u << sizeInBits) - 1) : 0xffffffffu);
1749 if (mask != 0xffffffffu)
1750 field = builder.CreateAnd(field, builder.getInt32(mask));
1751
1752 field = builder.CreateZExtOrBitCast(field, builder.getInt64Ty());
1753 field = builder.CreateShl(field, start);
1754
1755 return builder.CreateOr(result, field);
1756}
1757
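// Field layout produced by the packValInto64Bits calls below (bit positions
// taken directly from the call sites; shown here only as a reading aid):
//   bits [ 0,14) start address          bits [16,30) leading-dim offset
//   bits [32,46) stride-dim offset      bits [46,49) constant value 1
//   bits [49,52) base offset            bit  52      leading-dim mode
//   bits [61,64) swizzle mode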
1758void Tcgen05MmaSmemDescOp::createSmemDescriptor(Operation &op,
1759 LLVM::ModuleTranslation &mt,
1760 llvm::IRBuilderBase &builder) {
1761 auto thisOp = cast<NVVM::Tcgen05MmaSmemDescOp>(op);
1762 llvm::Value *smemDesc = builder.getInt64(0);
1763
1764 smemDesc = packValInto64Bits(builder, smemDesc,
1765 mt.lookupValue(thisOp.getStartAddr()), 14, 0);
1766 smemDesc = packValInto64Bits(
1767 builder, smemDesc, mt.lookupValue(thisOp.getLeadingDimOffset()), 14, 16);
1768 smemDesc = packValInto64Bits(
1769 builder, smemDesc, mt.lookupValue(thisOp.getStrideDimOffset()), 14, 32);
1770
1771 smemDesc = packValInto64Bits(builder, smemDesc, builder.getInt32(1), 3, 46);
1772 smemDesc = packValInto64Bits(builder, smemDesc,
1773 mt.lookupValue(thisOp.getBaseOffset()), 3, 49);
1774 smemDesc = packValInto64Bits(
1775 builder, smemDesc, mt.lookupValue(thisOp.getLeadingDimMode()), 1, 52);
1776 smemDesc = packValInto64Bits(builder, smemDesc,
1777 mt.lookupValue(thisOp.getSwizzleMode()), 3, 61);
1778
1779 mt.mapValue(thisOp.getRes()) = smemDesc;
1780}
1781
1782//===----------------------------------------------------------------------===//
1783// getPtx methods
1784//===----------------------------------------------------------------------===//
1785
1786std::string NVVM::MBarrierInitOp::getPtx() {
1787 bool isShared = isPtrInSharedCTASpace(getAddr());
1788 return isShared ? std::string("mbarrier.init.shared.b64 [%0], %1;")
1789 : std::string("mbarrier.init.b64 [%0], %1;");
1790}
1791
1792std::string NVVM::MBarrierArriveExpectTxOp::getPtx() {
1793 bool isShared = isPtrInSharedCTASpace(getAddr());
1794 return isShared
1795 ? std::string("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;")
1796 : std::string("mbarrier.arrive.expect_tx.b64 _, [%0], %1;");
1797}
1798
1799std::string NVVM::MBarrierTryWaitParityOp::getPtx() {
1800 bool isShared = isPtrInSharedCTASpace(getAddr());
1801 llvm::StringRef space = isShared ? ".shared" : "";
1802
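// The generated PTX spins in a loop: try_wait.parity sets predicate P1 once
// the mbarrier phase with the requested parity completes; until then the
// branch falls back to LAB_WAIT and retries.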
1803 return llvm::formatv("{\n\t"
1804 ".reg .pred P1; \n\t"
1805 "LAB_WAIT: \n\t"
1806 "mbarrier.try_wait.parity{0}.b64 P1, [%0], %1, %2; \n\t"
1807 "@P1 bra.uni DONE; \n\t"
1808 "bra.uni LAB_WAIT; \n\t"
1809 "DONE: \n\t"
1810 "}",
1811 space);
1812}
1813
1814//===----------------------------------------------------------------------===//
1815// getIntrinsicID/getIntrinsicIDAndArgs methods
1816//===----------------------------------------------------------------------===//
1817
1818mlir::NVVM::IDArgPair NVVM::BarrierOp::getIntrinsicIDAndArgs(
1819 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1820 auto thisOp = cast<NVVM::BarrierOp>(op);
1821 llvm::Value *barrierId = thisOp.getBarrierId()
1822 ? mt.lookupValue(thisOp.getBarrierId())
1823 : builder.getInt32(0);
1824 llvm::Intrinsic::ID id;
1825 llvm::SmallVector<llvm::Value *> args;
1826 if (thisOp.getNumberOfThreads()) {
1827 id = llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_count;
1828 args.push_back(barrierId);
1829 args.push_back(mt.lookupValue(thisOp.getNumberOfThreads()));
1830 } else if (thisOp.getReductionOp()) {
1831 switch (*thisOp.getReductionOp()) {
1832 case NVVM::BarrierReduction::AND:
1833 id = llvm::Intrinsic::nvvm_barrier0_and;
1834 break;
1835 case NVVM::BarrierReduction::OR:
1836 id = llvm::Intrinsic::nvvm_barrier0_or;
1837 break;
1838 case NVVM::BarrierReduction::POPC:
1839 id = llvm::Intrinsic::nvvm_barrier0_popc;
1840 break;
1841 }
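// Reduction barriers operate on barrier 0 only (the verifier rejects an
// explicit id together with a reduction), so only the predicate is passed.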
1842 args.push_back(mt.lookupValue(thisOp.getReductionPredicate()));
1843 } else {
1844 id = llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_all;
1845 args.push_back(barrierId);
1846 }
1847
1848 return {id, std::move(args)};
1849}
1850
1851mlir::NVVM::IDArgPair MBarrierInitOp::getIntrinsicIDAndArgs(
1852 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1853 auto thisOp = cast<NVVM::MBarrierInitOp>(op);
1854 bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
1855 llvm::Intrinsic::ID id = isShared ? llvm::Intrinsic::nvvm_mbarrier_init_shared
1856 : llvm::Intrinsic::nvvm_mbarrier_init;
1857
1858 // Fill the Intrinsic Args
1859 llvm::SmallVector<llvm::Value *> args;
1860 args.push_back(mt.lookupValue(thisOp.getAddr()));
1861 args.push_back(mt.lookupValue(thisOp.getCount()));
1862
1863 return {id, std::move(args)};
1864}
1865
1866mlir::NVVM::IDArgPair MBarrierInvalOp::getIntrinsicIDAndArgs(
1867 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1868 auto thisOp = cast<NVVM::MBarrierInvalOp>(op);
1869 bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
1870 llvm::Intrinsic::ID id = isShared
1871 ? llvm::Intrinsic::nvvm_mbarrier_inval_shared
1872 : llvm::Intrinsic::nvvm_mbarrier_inval;
1873
1874 return {id, {mt.lookupValue(thisOp.getAddr())}};
1875}
1876
1877mlir::NVVM::IDArgPair MBarrierArriveOp::getIntrinsicIDAndArgs(
1878 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1879 auto thisOp = cast<NVVM::MBarrierArriveOp>(op);
1880 bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
1881 llvm::Intrinsic::ID id = isShared
1882 ? llvm::Intrinsic::nvvm_mbarrier_arrive_shared
1883 : llvm::Intrinsic::nvvm_mbarrier_arrive;
1884
1885 return {id, {mt.lookupValue(thisOp.getAddr())}};
1886}
1887
1888mlir::NVVM::IDArgPair MBarrierArriveNocompleteOp::getIntrinsicIDAndArgs(
1889 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1890 auto thisOp = cast<NVVM::MBarrierArriveNocompleteOp>(op);
1891 bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
1892 llvm::Intrinsic::ID id =
1893 isShared ? llvm::Intrinsic::nvvm_mbarrier_arrive_noComplete_shared
1894 : llvm::Intrinsic::nvvm_mbarrier_arrive_noComplete;
1895 // Fill the Intrinsic Args
1896 llvm::SmallVector<llvm::Value *> args;
1897 args.push_back(mt.lookupValue(thisOp.getAddr()));
1898 args.push_back(mt.lookupValue(thisOp.getCount()));
1899
1900 return {id, std::move(args)};
1901}
1902
1903mlir::NVVM::IDArgPair MBarrierTestWaitOp::getIntrinsicIDAndArgs(
1904 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1905 auto thisOp = cast<NVVM::MBarrierTestWaitOp>(op);
1906 bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
1907 llvm::Intrinsic::ID id = isShared
1908 ? llvm::Intrinsic::nvvm_mbarrier_test_wait_shared
1909 : llvm::Intrinsic::nvvm_mbarrier_test_wait;
1910 // Fill the Intrinsic Args
1911 llvm::SmallVector<llvm::Value *> args;
1912 args.push_back(mt.lookupValue(thisOp.getAddr()));
1913 args.push_back(mt.lookupValue(thisOp.getState()));
1914
1915 return {id, std::move(args)};
1916}
1917
1918mlir::NVVM::IDArgPair CpAsyncMBarrierArriveOp::getIntrinsicIDAndArgs(
1919 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1920 auto thisOp = cast<NVVM::CpAsyncMBarrierArriveOp>(op);
1921 bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
1922
1923 llvm::Intrinsic::ID id;
1924 if (thisOp.getNoinc()) {
1925 id = isShared ? llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_noinc_shared
1926 : llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_noinc;
1927 } else {
1928 id = isShared ? llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_shared
1929 : llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive;
1930 }
1931
1932 return {id, {mt.lookupValue(thisOp.getAddr())}};
1933}
1934
1935#define CP_ASYNC_ID_IMPL(mod, size, suffix) \
1936 llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix
1937
1938#define GET_CP_ASYNC_ID(mod, size, has_cpsize) \
1939 has_cpsize ? CP_ASYNC_ID_IMPL(mod, size, _s) : CP_ASYNC_ID_IMPL(mod, size, )
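// For example, GET_CP_ASYNC_ID(ca, 4, true) expands to
// llvm::Intrinsic::nvvm_cp_async_ca_shared_global_4_s, and
// GET_CP_ASYNC_ID(cg, 16, false) to nvvm_cp_async_cg_shared_global_16.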
1940
1941llvm::Intrinsic::ID
1942CpAsyncOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
1944 llvm::Intrinsic::ID id;
1945
1946 auto cpAsyncOp = cast<NVVM::CpAsyncOp>(op);
1947 bool hasCpSize = static_cast<bool>(cpAsyncOp.getCpSize());
1948 switch (cpAsyncOp.getSize()) {
1949 case 4:
1950 id = GET_CP_ASYNC_ID(ca, 4, hasCpSize);
1951 break;
1952 case 8:
1953 id = GET_CP_ASYNC_ID(ca, 8, hasCpSize);
1954 break;
1955 case 16:
1956 id = (cpAsyncOp.getModifier() == NVVM::LoadCacheModifierKind::CG)
1957 ? GET_CP_ASYNC_ID(cg, 16, hasCpSize)
1958 : GET_CP_ASYNC_ID(ca, 16, hasCpSize);
1959 break;
1960 default:
1961 llvm_unreachable("Invalid copy size in CpAsyncOp.");
1962 }
1963
1964 // Fill the Intrinsic Args
1965 args.push_back(mt.lookupValue(cpAsyncOp.getDst()));
1966 args.push_back(mt.lookupValue(cpAsyncOp.getSrc()));
1967 if (hasCpSize)
1968 args.push_back(mt.lookupValue(cpAsyncOp.getCpSize()));
1969
1970 return id;
1971}
1972
1973mlir::NVVM::IDArgPair CpAsyncBulkPrefetchOp::getIntrinsicIDAndArgs(
1974 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1975 auto thisOp = cast<NVVM::CpAsyncBulkPrefetchOp>(op);
1976 llvm::SmallVector<llvm::Value *> args;
1977 llvm::Intrinsic::ID id = llvm::Intrinsic::nvvm_cp_async_bulk_prefetch_L2;
1978
1979 // Fill the Intrinsic Args
1980 args.push_back(mt.lookupValue(thisOp.getSrcMem()));
1981 args.push_back(mt.lookupValue(thisOp.getSize()));
1982
1983 mlir::Value cacheHint = thisOp.getL2CacheHint();
1984 const bool hasCacheHint = static_cast<bool>(cacheHint);
1985 llvm::Value *i64Unused =
1986 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
1987 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
1988 args.push_back(builder.getInt1(hasCacheHint));
1989
1990 return {id, std::move(args)};
1991}
1992
1993mlir::NVVM::IDArgPair CpAsyncBulkGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
1994 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1995 auto thisOp = cast<NVVM::CpAsyncBulkGlobalToSharedClusterOp>(op);
1996 llvm::SmallVector<llvm::Value *> args;
1997
1998 // Fill the Intrinsic Args: dst, mbar, src, size.
1999 args.push_back(mt.lookupValue(thisOp.getDstMem()));
2000 args.push_back(mt.lookupValue(thisOp.getMbar()));
2001 args.push_back(mt.lookupValue(thisOp.getSrcMem()));
2002 args.push_back(mt.lookupValue(thisOp.getSize()));
2003
2004 // Multicast mask for shared::cluster only, if available.
2005 mlir::Value multicastMask = thisOp.getMulticastMask();
2006 const bool hasMulticastMask = static_cast<bool>(multicastMask);
2007 const bool isSharedCTA = isPtrInSharedCTASpace(thisOp.getDstMem());
2008 if (!isSharedCTA) {
2009 llvm::Value *i16Unused = llvm::ConstantInt::get(builder.getInt16Ty(), 0);
2010 args.push_back(hasMulticastMask ? mt.lookupValue(multicastMask)
2011 : i16Unused);
2012 }
2013
2014 // Cache hint, if available.
2015 mlir::Value cacheHint = thisOp.getL2CacheHint();
2016 const bool hasCacheHint = static_cast<bool>(cacheHint);
2017 llvm::Value *i64Unused = llvm::ConstantInt::get(builder.getInt64Ty(), 0);
2018 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
2019
2020 // Flag arguments for multicast and cachehint.
2021 if (!isSharedCTA)
2022 args.push_back(builder.getInt1(hasMulticastMask));
2023 args.push_back(builder.getInt1(hasCacheHint));
2024
2025 llvm::Intrinsic::ID id =
2026 isSharedCTA
2027 ? llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cta
2028 : llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster;
2029
2030 return {id, std::move(args)};
2031}
2032
2033mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
2034 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
2035 auto thisOp = cast<NVVM::CpAsyncBulkSharedCTAToGlobalOp>(op);
2036 llvm::SmallVector<llvm::Value *> args;
2037 llvm::Intrinsic::ID id =
2038 llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global;
2039
2040 // Fill the Intrinsic Args
2041 args.push_back(mt.lookupValue(thisOp.getDstMem()));
2042 args.push_back(mt.lookupValue(thisOp.getSrcMem()));
2043 args.push_back(mt.lookupValue(thisOp.getSize()));
2044
2045 mlir::Value cacheHint = thisOp.getL2CacheHint();
2046 const bool hasCacheHint = static_cast<bool>(cacheHint);
2047 llvm::Value *i64Unused =
2048 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
2049 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
2050 args.push_back(builder.getInt1(hasCacheHint));
2051
2052 // Choose the bytemask variant
2053 if (mlir::Value byteMask = thisOp.getByteMask()) {
2054 args.push_back(mt.lookupValue(byteMask));
2055 id = llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global_bytemask;
2056 }
2057
2058 return {id, std::move(args)};
2059}
2060
2061bool CpAsyncBulkTensorGlobalToSharedClusterOp::getAsmValues(
2062 RewriterBase &rewriter,
2063 llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
2064 &asmValues) {
2065 // Add all the operands, but not the attributes, to the asmValues list.
2066 // The attributes are only used to select the right intrinsic variant during
2067 // lowering, so they are ignored when generating inline PTX.
2068 for (auto val : getOperands())
2069 asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read});
2070
2071 return false;
2072}
2073
2074 mlir::NVVM::IDArgPair
2075 CpAsyncBulkTensorGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
2076 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
2077 auto thisOp = cast<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(op);
2078 const bool isCTAOnly = thisOp.getIsCTAOnly();
2079 llvm::SmallVector<llvm::Value *> args;
2080
2081 // Fill the Intrinsic Args
2082 args.push_back(mt.lookupValue(thisOp.getDstMem()));
2083 args.push_back(mt.lookupValue(thisOp.getMbar()));
2084 args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
2085
2086 // Coordinates and im2col-offsets
2087 for (mlir::Value v : thisOp.getCoordinates())
2088 args.push_back(mt.lookupValue(v));
2089 for (mlir::Value v : thisOp.getIm2colOffsets())
2090 args.push_back(mt.lookupValue(v));
2091
2092 // MulticastMask, if available
2093 mlir::Value mcMask = thisOp.getMulticastMask();
2094 const bool hasMC = static_cast<bool>(mcMask);
2095 llvm::Value *i16Zero =
2096 llvm::ConstantInt::get(llvm::Type::getInt16Ty(mt.getLLVMContext()), 0);
2097
2098 // CacheHint, if available
2099 mlir::Value cacheHint = thisOp.getL2CacheHint();
2100 const bool hasCacheHint = static_cast<bool>(cacheHint);
2101 llvm::Value *i64Zero =
2102 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
2103
2104 // Flag argument CTAGroup
2105 // CTA_1/2 is mapped to values 1 and 2 for the intrinsics.
2106 // Hence, the +1 to getGroup().
2107 const int32_t val =
2108 thisOp.getGroup() ? (static_cast<int32_t>(*thisOp.getGroup()) + 1) : 0;
2109 llvm::Value *cg =
2110 llvm::ConstantInt::get(llvm::Type::getInt32Ty(mt.getLLVMContext()), val);
2111
2112 if (!isCTAOnly) {
2113 // For shared::cluster, all the arguments that we build are applicable.
2114 args.push_back(hasMC ? mt.lookupValue(mcMask) : i16Zero);
2115 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Zero);
2116 args.push_back(builder.getInt1(hasMC));
2117 args.push_back(builder.getInt1(hasCacheHint));
2118 args.push_back(cg);
2119 } else {
2120 // For shared::cta, only cache-hint is applicable.
2121 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Zero);
2122 args.push_back(builder.getInt1(hasCacheHint));
2123 }
2124
2125 constexpr size_t numDims = 5; // 1D to 5D
2126 constexpr size_t numModes = 5; // Tile, Im2col, w, w_128, gather4
2127 using rowTy = std::array<llvm::Intrinsic::ID, numDims + 1>;
2128 using TableTy = std::array<rowTy, numModes>;
2129 static constexpr TableTy IDTable{
2130 {{notIntrinsic, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_1d,
2131 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_2d,
2132 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_3d,
2133 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_4d,
2134 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_5d},
2135 {notIntrinsic, notIntrinsic, notIntrinsic,
2136 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d,
2137 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d,
2138 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d},
2139 {notIntrinsic, notIntrinsic, notIntrinsic,
2140 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_3d,
2141 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_4d,
2142 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_5d},
2143 {notIntrinsic, notIntrinsic, notIntrinsic,
2144 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_3d,
2145 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_4d,
2146 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_5d},
2147 {notIntrinsic, notIntrinsic, notIntrinsic, notIntrinsic, notIntrinsic,
2148 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_gather4_2d}}};
2149
2150 static constexpr TableTy IDTableCTA{
2151 {{notIntrinsic,
2152 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_1d,
2153 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_2d,
2154 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_3d,
2155 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_4d,
2156 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_5d},
2157 {notIntrinsic, notIntrinsic, notIntrinsic,
2158 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_3d,
2159 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_4d,
2160 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_5d},
2161 {notIntrinsic, notIntrinsic, notIntrinsic,
2162 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_3d,
2163 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_4d,
2164 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_5d},
2165 {notIntrinsic, notIntrinsic, notIntrinsic,
2166 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_3d,
2167 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_4d,
2168 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_5d},
2169 {notIntrinsic, notIntrinsic, notIntrinsic, notIntrinsic, notIntrinsic,
2170 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_gather4_2d}}};
2171
2172 static_assert(
2173 (getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1) &&
2174 (getMaxEnumValForTMALoadMode() == std::size(IDTableCTA) - 1),
2175 "TMALoadModes must match number of rows in IDTable and IDTableCTA");
2176 size_t mode = static_cast<size_t>(thisOp.getMode());
2177 size_t dim = thisOp.getCoordinates().size();
2178 auto id = isCTAOnly ? IDTableCTA[mode][dim] : IDTable[mode][dim];
2179 assert(id != notIntrinsic &&
2180 "Invalid intrinsic for CpAsyncBulkTensorGlobalToSharedClusterOp.");
2181
2182 return {id, std::move(args)};
2183}
2184
2185mlir::NVVM::IDArgPair CpAsyncBulkTensorPrefetchOp::getIntrinsicIDAndArgs(
2186 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
2187 auto thisOp = cast<NVVM::CpAsyncBulkTensorPrefetchOp>(op);
2188 llvm::SmallVector<llvm::Value *> args;
2189
2190 // Fill the Intrinsic Args
2191 args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
2192
2193 for (auto v : thisOp.getCoordinates())
2194 args.push_back(mt.lookupValue(v));
2195 for (auto v : thisOp.getIm2colOffsets())
2196 args.push_back(mt.lookupValue(v));
2197
2198 mlir::Value cacheHint = thisOp.getL2CacheHint();
2199 const bool hasCacheHint = static_cast<bool>(cacheHint);
2200 llvm::Value *i64Unused =
2201 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
2202 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
2203 args.push_back(builder.getInt1(hasCacheHint));
2204
2205 const unsigned NI = llvm::Intrinsic::not_intrinsic;
2206 static constexpr llvm::Intrinsic::ID IDTable[][6] = {
2207 {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d,
2208 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d,
2209 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d,
2210 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d,
2211 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d},
2212 {NI, NI, NI,
2213 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d,
2214 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d,
2215 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d},
2216 {NI, NI, NI,
2217 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_3d,
2218 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_4d,
2219 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_5d},
2220 {NI, NI, NI,
2221 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_3d,
2222 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_4d,
2223 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_5d},
2224 {NI, NI, NI, NI, NI,
2225 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_gather4_2d}};
2226
2227 static_assert(getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1,
2228 "TMALoadModes must match number of rows in IDTable");
2229 size_t mode = static_cast<size_t>(thisOp.getMode());
2230 size_t dim = thisOp.getCoordinates().size();
2231 llvm::Intrinsic::ID id = IDTable[mode][dim];
2232 if (id == llvm::Intrinsic::not_intrinsic)
2233 llvm_unreachable("Invalid intrinsic for CpAsyncBulkTensorPrefetchOp.");
2234
2235 return {id, std::move(args)};
2236}
2237
2238 mlir::NVVM::IDArgPair
2239 CpAsyncBulkTensorSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
2240 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
2241 auto thisOp = cast<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(op);
2242 llvm::SmallVector<llvm::Value *> args;
2243
2244 // Fill the Intrinsic Args
2245 args.push_back(mt.lookupValue(thisOp.getSrcMem()));
2246 args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
2247
2248 for (auto v : thisOp.getCoordinates())
2249 args.push_back(mt.lookupValue(v));
2250
2251 mlir::Value cacheHint = thisOp.getL2CacheHint();
2252 const bool hasCacheHint = static_cast<bool>(cacheHint);
2253 llvm::Value *i64Unused =
2254 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
2255 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
2256 args.push_back(builder.getInt1(hasCacheHint));
2257
2258 const unsigned NI = llvm::Intrinsic::not_intrinsic;
2259 static constexpr llvm::Intrinsic::ID IDTable[][6] = {
2260 {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_1d,
2261 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_2d,
2262 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_3d,
2263 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_4d,
2264 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_5d},
2265 {NI, NI, NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_3d,
2266 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_4d,
2267 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_5d},
2268 {NI, NI, NI, NI, NI,
2269 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_scatter4_2d}};
2270
2271 static_assert(getMaxEnumValForTMAStoreMode() == std::size(IDTable) - 1,
2272 "TMAStoreModes must match number of rows in IDTable");
2273 size_t mode = static_cast<size_t>(thisOp.getMode());
2274 size_t dim = thisOp.getCoordinates().size();
2275 llvm::Intrinsic::ID id = IDTable[mode][dim];
2276 if (id == llvm::Intrinsic::not_intrinsic)
2277 llvm_unreachable(
2278 "Invalid intrinsic for CpAsyncBulkTensorSharedCTAToGlobalOp.");
2279
2280 return {id, std::move(args)};
2281}
2282
2283NVVM::IDArgPair CpAsyncBulkTensorReduceOp::getIntrinsicIDAndArgs(
2284 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
2285 auto thisOp = cast<NVVM::CpAsyncBulkTensorReduceOp>(op);
2286 llvm::LLVMContext &ctx = mt.getLLVMContext();
2287
2288 llvm::SmallVector<llvm::Value *> args;
2289
2290 // Arguments to the intrinsic:
2291 // shared_mem_ptr, tmaDesc, tensorDims
2292 // cache_hint(if applicable) and flag(boolean)
2293 args.push_back(mt.lookupValue(thisOp.getSrcMem()));
2294 args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
2295
2296 for (Value v : thisOp.getCoordinates())
2297 args.push_back(mt.lookupValue(v));
2298
2299 mlir::Value cacheHint = thisOp.getL2CacheHint();
2300 const bool hasCacheHint = static_cast<bool>(cacheHint);
2301 llvm::Value *i64ZeroValue =
2302 llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), 0);
2303 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64ZeroValue);
2304 args.push_back(builder.getInt1(hasCacheHint));
2305
2306 const llvm::Intrinsic::ID notIntrinsic = llvm::Intrinsic::not_intrinsic;
2307
2308 constexpr unsigned numRedKinds = 8; // ADD, MIN, MAX, INC, DEC, AND, OR, XOR
2309 constexpr unsigned numLayouts = 2; // TILE, IM2COL
2310 constexpr unsigned maxDim = 5; // 1D to 5D
2311 using row = std::array<llvm::Intrinsic::ID, maxDim + 1>;
2312 using layoutTable = std::array<row, numLayouts>;
2313 using fullTable = std::array<layoutTable, numRedKinds>;
2314 static constexpr fullTable IDTable{
2315 {// RedTy::ADD
2316 {{{{notIntrinsic,
2317 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_1d,
2318 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_2d,
2319 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_3d,
2320 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_4d,
2321 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_5d}},
2322 {{notIntrinsic, notIntrinsic, notIntrinsic,
2323 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_3d,
2324 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_4d,
2325 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_5d}}}},
2326 // RedTy::MIN
2327 {{{{notIntrinsic,
2328 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_1d,
2329 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_2d,
2330 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_3d,
2331 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_4d,
2332 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_5d}},
2333 {{notIntrinsic, notIntrinsic, notIntrinsic,
2334 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_3d,
2335 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_4d,
2336 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_5d}}}},
2337 // RedTy::MAX
2338 {{{{notIntrinsic,
2339 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_1d,
2340 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_2d,
2341 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_3d,
2342 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_4d,
2343 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_5d}},
2344 {{notIntrinsic, notIntrinsic, notIntrinsic,
2345 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_3d,
2346 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_4d,
2347 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_5d}}}},
2348 // RedTy::INC
2349 {{{{notIntrinsic,
2350 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_1d,
2351 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_2d,
2352 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_3d,
2353 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_4d,
2354 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_5d}},
2355 {{notIntrinsic, notIntrinsic, notIntrinsic,
2356 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_3d,
2357 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_4d,
2358 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_5d}}}},
2359 // RedTy::DEC
2360 {{{{notIntrinsic,
2361 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_1d,
2362 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_2d,
2363 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_3d,
2364 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_4d,
2365 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_5d}},
2366 {{notIntrinsic, notIntrinsic, notIntrinsic,
2367 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_3d,
2368 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_4d,
2369 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_5d}}}},
2370 // RedTy::AND
2371 {{{{notIntrinsic,
2372 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_1d,
2373 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_2d,
2374 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_3d,
2375 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_4d,
2376 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_5d}},
2377 {{notIntrinsic, notIntrinsic, notIntrinsic,
2378 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_3d,
2379 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_4d,
2380 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_5d}}}},
2381 // RedTy::OR
2382 {{{{notIntrinsic,
2383 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_1d,
2384 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_2d,
2385 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_3d,
2386 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_4d,
2387 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_5d}},
2388 {{notIntrinsic, notIntrinsic, notIntrinsic,
2389 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_3d,
2390 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_4d,
2391 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_5d}}}},
2392 // RedTy::XOR
2393 {{{{notIntrinsic,
2394 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_1d,
2395 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_2d,
2396 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_3d,
2397 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_4d,
2398 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_5d}},
2399 {{notIntrinsic, notIntrinsic, notIntrinsic,
2400 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_3d,
2401 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_4d,
2402 llvm::Intrinsic::
2403 nvvm_cp_async_bulk_tensor_reduce_xor_im2col_5d}}}}}};
2404
2405 static_assert(getMaxEnumValForTMAReduxKind() == std::size(IDTable) - 1,
2406 "TMAReduxKinds must match number of rows in IDTable");
2407
2408 size_t redKind = static_cast<size_t>(thisOp.getRedKind());
2409 size_t mode = static_cast<size_t>(thisOp.getMode());
2410 size_t dim = thisOp.getCoordinates().size();
2411
2412 assert(redKind < IDTable.size() &&
2413 "Invalid redKind for CpAsyncBulkTensorReduceOp");
2414 assert(mode < IDTable[redKind].size() &&
2415 "Invalid mode for CpAsyncBulkTensorReduceOp");
2416 assert(dim < IDTable[redKind][mode].size() &&
2417 "Invalid dim for CpAsyncBulkTensorReduceOp");
2418
2419 llvm::Intrinsic::ID intrinsicID = IDTable[redKind][mode][dim];
2420
2421 assert(intrinsicID != notIntrinsic &&
2422 "Invalid intrinsic for CpAsyncBulkTensorReduceOp.");
2423
2424 return {intrinsicID, std::move(args)};
2425}
2426
2427#define _none
2428
2429#define CVT_F2TF32_ID_IMPL(rnd, relu, sf) \
2430 hasRelu ? llvm::Intrinsic::nvvm_f2tf32_##rnd##relu##sf \
2431 : llvm::Intrinsic::nvvm_f2tf32_##rnd##sf
2432
2433#define GET_CVT_F2TF32_ID(rnd, relu, sf) \
2434 hasSatFinite ? CVT_F2TF32_ID_IMPL(rnd, relu, sf) \
2435 : CVT_F2TF32_ID_IMPL(rnd, relu, )
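// For example, with hasRelu and hasSatFinite both set,
// GET_CVT_F2TF32_ID(rn, _relu, _satfinite) expands to
// llvm::Intrinsic::nvvm_f2tf32_rn_relu_satfinite; without satfinite it falls
// back to nvvm_f2tf32_rn_relu.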
2436
2437llvm::Intrinsic::ID
2438ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
2439 NVVM::SaturationMode sat, bool hasRelu) {
2440 using RndMode = NVVM::FPRoundingMode;
2441 bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
2442 switch (rnd) {
2443 case RndMode::RN:
2444 return GET_CVT_F2TF32_ID(rn, _relu, _satfinite);
2445 case RndMode::RZ:
2446 return GET_CVT_F2TF32_ID(rz, _relu, _satfinite);
2447 case RndMode::RNA:
2448 return GET_CVT_F2TF32_ID(rna, _none, _satfinite);
2449 default:
2450 llvm_unreachable("Invalid RoundingMode for CvtFloatToTF32Op");
2451 }
2452}
2453
2454 NVVM::IDArgPair
2455 ConvertF32x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF4x2Op op,
2456 LLVM::ModuleTranslation &mt,
2457 llvm::IRBuilderBase &builder) {
2458 llvm::SmallVector<llvm::Value *> args;
2459 args.push_back(mt.lookupValue(op.getA()));
2460 args.push_back(mt.lookupValue(op.getB()));
2461
2462 bool hasRelu = op.getRelu();
2463
2464 llvm::Intrinsic::ID intId =
2465 hasRelu ? llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_relu_satfinite
2466 : llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_satfinite;
2467
2468 return {intId, std::move(args)};
2469}
2470
2471#define GET_F32x2_TO_F6x2_ID(type, has_relu) \
2472 has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite \
2473 : llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite
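// For example, GET_F32x2_TO_F6x2_ID(e2m3x2, true) expands to
// llvm::Intrinsic::nvvm_ff_to_e2m3x2_rn_relu_satfinite.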
2474
2475llvm::Intrinsic::ID ConvertF32x2ToF6x2Op::getIntrinsicID(mlir::Type dstTy,
2476 bool hasRelu) {
2477 return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
2478 .Case<mlir::Float6E2M3FNType>([&](mlir::Float6E2M3FNType) {
2479 return GET_F32x2_TO_F6x2_ID(e2m3x2, hasRelu);
2480 })
2481 .Case<mlir::Float6E3M2FNType>([&](mlir::Float6E3M2FNType) {
2482 return GET_F32x2_TO_F6x2_ID(e3m2x2, hasRelu);
2483 })
2484 .Default([](mlir::Type) {
2485 llvm_unreachable("Invalid conversion in ConvertF32x2ToF6x2Op");
2486 return llvm::Intrinsic::not_intrinsic;
2487 });
2488}
2489
2490#define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf) \
2491 has_satf ? llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd##_satfinite \
2492 : llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd
2493
2494#define GET_F32x2_TO_F8X2_S_ID(type, has_relu) \
2495 has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu \
2496 : llvm::Intrinsic::nvvm_ff_to_##type##_rn
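// For example, GET_F32x2_TO_F8X2_US_ID(rz, true) expands to
// llvm::Intrinsic::nvvm_ff_to_ue8m0x2_rz_satfinite, while
// GET_F32x2_TO_F8X2_S_ID(e4m3x2, false) expands to nvvm_ff_to_e4m3x2_rn.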
2497
2498llvm::Intrinsic::ID
2499ConvertF32x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy, NVVM::FPRoundingMode rnd,
2500 NVVM::SaturationMode sat, bool hasRelu) {
2501 bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
2502 bool hasRoundingModeRZ = (rnd == NVVM::FPRoundingMode::RZ);
2503 bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP);
2504
2505 return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
2506 .Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
2507 return GET_F32x2_TO_F8X2_S_ID(e4m3x2, hasRelu);
2508 })
2509 .Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
2510 return GET_F32x2_TO_F8X2_S_ID(e5m2x2, hasRelu);
2511 })
2512 .Case<mlir::Float8E8M0FNUType>([&](mlir::Float8E8M0FNUType) {
2513 if (hasRoundingModeRZ)
2514 return GET_F32x2_TO_F8X2_US_ID(rz, hasSatFinite);
2515 else if (hasRoundingModeRP)
2516 return GET_F32x2_TO_F8X2_US_ID(rp, hasSatFinite);
2517
2518 llvm_unreachable("Invalid conversion in ConvertF32x2ToF8x2Op");
2519 })
2520 .Default([](mlir::Type) {
2521 llvm_unreachable("Invalid conversion in ConvertF32x2ToF8x2Op");
2522 return llvm::Intrinsic::not_intrinsic;
2523 });
2524}
2525
2526#define GET_F16x2_TO_F8X2_ID(type, has_relu) \
2527 has_relu ? llvm::Intrinsic::nvvm_f16x2_to_##type##_rn_relu \
2528 : llvm::Intrinsic::nvvm_f16x2_to_##type##_rn
2529
2530llvm::Intrinsic::ID ConvertF16x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy,
2531 bool hasRelu) {
2532 return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
2533 .Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
2534 return GET_F16x2_TO_F8X2_ID(e4m3x2, hasRelu);
2535 })
2536 .Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
2537 return GET_F16x2_TO_F8X2_ID(e5m2x2, hasRelu);
2538 })
2539 .Default([](mlir::Type) {
2540 llvm_unreachable("Invalid conversion in ConvertF16x2ToF8x2Op");
2541 return llvm::Intrinsic::not_intrinsic;
2542 });
2543}
2544
2545#define GET_BF16X2_TO_F8X2_ID(rnd, has_satf) \
2546 has_satf ? llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd##_satfinite \
2547 : llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd
2548
2549llvm::Intrinsic::ID
2550ConvertBF16x2ToF8x2Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
2551 NVVM::SaturationMode sat) {
2552 bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
2553 switch (rnd) {
2554 case NVVM::FPRoundingMode::RZ:
2555 return GET_BF16X2_TO_F8X2_ID(rz, hasSatFinite);
2556 case NVVM::FPRoundingMode::RP:
2557 return GET_BF16X2_TO_F8X2_ID(rp, hasSatFinite);
2558 default:
2559 llvm_unreachable("Invalid rounding mode for CvtBF16x2ToF8x2Op");
2560 }
2561}
2562
2563NVVM::IDArgPair ConvertF8x2ToF16x2Op::getIntrinsicIDAndArgs(
2564 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
2565 auto curOp = cast<NVVM::ConvertF8x2ToF16x2Op>(op);
2566
2567 bool hasRelu = curOp.getRelu();
2568
2569 llvm::Intrinsic::ID intId =
2571 .Case<Float8E4M3FNType>([&](Float8E4M3FNType type) {
2572 return hasRelu ? llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn_relu
2573 : llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn;
2574 })
2575 .Case<Float8E5M2Type>([&](Float8E5M2Type type) {
2576 return hasRelu ? llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn_relu
2577 : llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn;
2578 })
2579 .Default([](mlir::Type type) {
2580 llvm_unreachable("Invalid type for ConvertF8x2ToF16x2Op");
2581 return llvm::Intrinsic::not_intrinsic;
2582 });
2583
2584 llvm::Value *packedI16 =
2585 builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
2586 llvm::Type::getInt16Ty(builder.getContext()));
2587
2588 return {intId, {packedI16}};
2589}
2590
2591NVVM::IDArgPair ConvertF8x2ToBF16x2Op::getIntrinsicIDAndArgs(
2592 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
2593 auto curOp = cast<NVVM::ConvertF8x2ToBF16x2Op>(op);
2594
2595 llvm::Intrinsic::ID intId = llvm::Intrinsic::nvvm_ue8m0x2_to_bf16x2;
2596 llvm::Value *packedI16 =
2597 builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
2598 llvm::Type::getInt16Ty(builder.getContext()));
2599
2600 return {intId, {packedI16}};
2601}
2602
2603NVVM::IDArgPair ConvertF6x2ToF16x2Op::getIntrinsicIDAndArgs(
2604 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
2605 auto curOp = cast<NVVM::ConvertF6x2ToF16x2Op>(op);
2606
2607 bool hasRelu = curOp.getRelu();
2608
2609 llvm::Intrinsic::ID intId =
2611 .Case<Float6E2M3FNType>([&](Float6E2M3FNType type) {
2612 return hasRelu ? llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn_relu
2613 : llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn;
2614 })
2615 .Case<Float6E3M2FNType>([&](Float6E3M2FNType type) {
2616 return hasRelu ? llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn_relu
2617 : llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn;
2618 })
2619 .Default([](mlir::Type type) {
2620 llvm_unreachable("Invalid type for ConvertF6x2ToF16x2Op");
2621 return llvm::Intrinsic::not_intrinsic;
2622 });
2623
2624 llvm::Value *packedI16 =
2625 builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
2626 llvm::Type::getInt16Ty(builder.getContext()));
2627
2628 return {intId, {packedI16}};
2629}
2630
2631NVVM::IDArgPair ConvertF4x2ToF16x2Op::getIntrinsicIDAndArgs(
2632 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
2633 auto curOp = cast<NVVM::ConvertF4x2ToF16x2Op>(op);
2634
2635 bool hasRelu = curOp.getRelu();
2636
2637 llvm::Intrinsic::ID intId =
2639 .Case<Float4E2M1FNType>([&](Float4E2M1FNType type) {
2640 return hasRelu ? llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn_relu
2641 : llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn;
2642 })
2643 .Default([](mlir::Type type) {
2644 llvm_unreachable("Invalid type for ConvertF4x2ToF16x2Op");
2645 return llvm::Intrinsic::not_intrinsic;
2646 });
2647
2648 llvm::Value *extendedI16 =
2649 builder.CreateZExt(mt.lookupValue(curOp.getSrc()),
2650 llvm::Type::getInt16Ty(builder.getContext()));
2651
2652 return {intId, {extendedI16}};
2653}
2654
2655llvm::Intrinsic::ID
2656Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op,
2657 LLVM::ModuleTranslation &mt,
2658 llvm::SmallVector<llvm::Value *> &args) {
2659 auto curOp = cast<NVVM::Tcgen05AllocOp>(op);
2660 unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
2661 .getAddressSpace();
2662 bool isShared = as == NVVMMemorySpace::Shared;
2663 bool is2CTAMode = curOp.getGroup() == CTAGroupKind::CTA_2;
2664
2665 llvm::Intrinsic::ID id;
2666 if (isShared) {
2667 id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg2
2668 : llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg1;
2669 } else {
2670 id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_cg2
2671 : llvm::Intrinsic::nvvm_tcgen05_alloc_cg1;
2672 }
2673
2674 // Fill the Intrinsic Args
2675 args.push_back(mt.lookupValue(curOp.getAddr()));
2676 args.push_back(mt.lookupValue(curOp.getNCols()));
2677
2678 return id;
2679}
2680
2681llvm::Intrinsic::ID Tcgen05DeallocOp::getIntrinsicIDAndArgs(
2682 Operation &op, LLVM::ModuleTranslation &mt,
2683 llvm::SmallVector<llvm::Value *> &args) {
2684 auto curOp = cast<NVVM::Tcgen05DeallocOp>(op);
2685 auto id = (curOp.getGroup() == CTAGroupKind::CTA_1)
2686 ? llvm::Intrinsic::nvvm_tcgen05_dealloc_cg1
2687 : llvm::Intrinsic::nvvm_tcgen05_dealloc_cg2;
2688
2689 // Fill the Intrinsic Args
2690 args.push_back(mt.lookupValue(curOp.getTaddr()));
2691 args.push_back(mt.lookupValue(curOp.getNCols()));
2692
2693 return id;
2694}
2695
2696#define TCGEN05_COMMIT_IMPL(cg, is_shared, mc) \
2697 is_shared ? llvm::Intrinsic::nvvm_tcgen05_commit##mc##_shared##_##cg \
2698 : llvm::Intrinsic::nvvm_tcgen05_commit##mc##_##cg
2699
2700#define GET_TCGEN05_COMMIT_ID(cta_group, is_shared, has_mc) \
2701 has_mc ? TCGEN05_COMMIT_IMPL(cta_group, is_shared, _mc) \
2702 : TCGEN05_COMMIT_IMPL(cta_group, is_shared, )
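// For example, GET_TCGEN05_COMMIT_ID(cg2, true, true) expands to
// llvm::Intrinsic::nvvm_tcgen05_commit_mc_shared_cg2, and
// GET_TCGEN05_COMMIT_ID(cg1, false, false) to nvvm_tcgen05_commit_cg1.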
2703
2704llvm::Intrinsic::ID
2705Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
2706 LLVM::ModuleTranslation &mt,
2707 llvm::SmallVector<llvm::Value *> &args) {
2708 auto curOp = cast<NVVM::Tcgen05CommitOp>(op);
2709 unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
2710 .getAddressSpace();
2711 bool isShared = as == NVVMMemorySpace::Shared;
2712 bool hasMulticast = static_cast<bool>(curOp.getMulticastMask());
2713 bool is2CTAMode = curOp.getGroup() == CTAGroupKind::CTA_2;
2714
2715 llvm::Intrinsic::ID id =
2716 is2CTAMode ? GET_TCGEN05_COMMIT_ID(cg2, isShared, hasMulticast)
2717 : GET_TCGEN05_COMMIT_ID(cg1, isShared, hasMulticast);
2718
2719 // Fill the Intrinsic Args
2720 args.push_back(mt.lookupValue(curOp.getAddr()));
2721 if (hasMulticast)
2722 args.push_back(mt.lookupValue(curOp.getMulticastMask()));
2723
2724 return id;
2725}
2726
2727#define TCGEN05_CP_IMPL(shape_mc, src_fmt, cg) \
2728 llvm::Intrinsic::nvvm_tcgen05_cp##shape_mc##src_fmt##cg
2729
2730#define TCGEN05_CP_2CTA(shape_mc, src_fmt, is_2cta) \
2731 is_2cta ? TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg2) \
2732 : TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg1)
2733
2734#define GET_TCGEN05_CP_ID(shape_mc, src_fmt, is_2cta) \
2735 [&]() -> auto { \
2736 if ((src_fmt) == Tcgen05CpSrcFormat::B6x16_P32) \
2737 return TCGEN05_CP_2CTA(shape_mc, _b6x16_p32, is_2cta); \
2738 if ((src_fmt) == Tcgen05CpSrcFormat::B4x16_P64) \
2739 return TCGEN05_CP_2CTA(shape_mc, _b4x16_p64, is_2cta); \
2740 return TCGEN05_CP_2CTA(shape_mc, , is_2cta); \
2741 }()
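// For example, GET_TCGEN05_CP_ID(_128x256b, Tcgen05CpSrcFormat::B6x16_P32, true)
// evaluates to llvm::Intrinsic::nvvm_tcgen05_cp_128x256b_b6x16_p32_cg2.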
2742
2743llvm::Intrinsic::ID ConvertF32x2ToF16x2Op::getIntrinsicID() {
2744 bool hasRelu = getRelu();
2745 bool hasSatFinite = (getSat() == NVVM::SaturationMode::SATFINITE);
2746
2747 if (hasRelu && hasSatFinite)
2748 return llvm::Intrinsic::nvvm_ff2f16x2_rs_relu_satfinite;
2749 if (hasRelu)
2750 return llvm::Intrinsic::nvvm_ff2f16x2_rs_relu;
2751 if (hasSatFinite)
2752 return llvm::Intrinsic::nvvm_ff2f16x2_rs_satfinite;
2753 return llvm::Intrinsic::nvvm_ff2f16x2_rs;
2754}
2755
2756llvm::Intrinsic::ID ConvertF32x2ToBF16x2Op::getIntrinsicID() {
2757 bool hasRelu = getRelu();
2758 bool hasSatFinite = (getSat() == NVVM::SaturationMode::SATFINITE);
2759
2760 if (hasRelu && hasSatFinite)
2761 return llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu_satfinite;
2762 if (hasRelu)
2763 return llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu;
2764 if (hasSatFinite)
2765 return llvm::Intrinsic::nvvm_ff2bf16x2_rs_satfinite;
2766 return llvm::Intrinsic::nvvm_ff2bf16x2_rs;
2767}
2768
2769llvm::Intrinsic::ID ConvertF32x4ToF8x4Op::getIntrinsicID() {
2770 mlir::Type dstTy = getDstTy();
2771 bool hasRelu = getRelu();
2772
2773 return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
2774 .Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
2775 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite
2776 : llvm::Intrinsic::nvvm_f32x4_to_e4m3x4_rs_satfinite;
2777 })
2778 .Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
2779 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e5m2x4_rs_relu_satfinite
2780 : llvm::Intrinsic::nvvm_f32x4_to_e5m2x4_rs_satfinite;
2781 })
2782 .Default([](mlir::Type) {
2783 llvm_unreachable("Invalid F8 type in ConvertF32x4ToF8x4Op");
2784 return llvm::Intrinsic::not_intrinsic;
2785 });
2786}
2787
2788llvm::Intrinsic::ID ConvertF32x4ToF6x4Op::getIntrinsicID() {
2789 mlir::Type dstTy = getDstTy();
2790 bool hasRelu = getRelu();
2791
2792 return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
2793 .Case<mlir::Float6E2M3FNType>([&](mlir::Float6E2M3FNType) {
2794 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e2m3x4_rs_relu_satfinite
2795 : llvm::Intrinsic::nvvm_f32x4_to_e2m3x4_rs_satfinite;
2796 })
2797 .Case<mlir::Float6E3M2FNType>([&](mlir::Float6E3M2FNType) {
2798 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e3m2x4_rs_relu_satfinite
2799 : llvm::Intrinsic::nvvm_f32x4_to_e3m2x4_rs_satfinite;
2800 })
2801 .Default([](mlir::Type) {
2802 llvm_unreachable("Invalid F6 type in ConvertF32x4ToF6x4Op");
2803 return llvm::Intrinsic::not_intrinsic;
2804 });
2805}
2806
2807llvm::Intrinsic::ID ConvertF32x4ToF4x4Op::getIntrinsicID() {
2808 mlir::Type dstTy = getDstTy();
2809 bool hasRelu = getRelu();
2810
2811 return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
2812 .Case<mlir::Float4E2M1FNType>([&](mlir::Float4E2M1FNType) {
2813 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e2m1x4_rs_relu_satfinite
2814 : llvm::Intrinsic::nvvm_f32x4_to_e2m1x4_rs_satfinite;
2815 })
2816 .Default([](mlir::Type) {
2817 llvm_unreachable("Invalid F4 type in ConvertF32x4ToF4x4Op");
2818 return llvm::Intrinsic::not_intrinsic;
2819 });
2820}
2821
2822llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(Operation &op) {
2823 auto curOp = cast<NVVM::Tcgen05CpOp>(op);
2824 bool is2CTA = curOp.getGroup() == CTAGroupKind::CTA_2;
2825 auto srcFmt = curOp.getSrcFormat();
2826 auto mc = curOp.getMulticast();
2827
2828 switch (curOp.getShape()) {
2829 case Tcgen05CpShape::SHAPE_128x256b:
2830 return GET_TCGEN05_CP_ID(_128x256b, srcFmt, is2CTA);
2831 case Tcgen05CpShape::SHAPE_128x128b:
2832 return GET_TCGEN05_CP_ID(_128x128b, srcFmt, is2CTA);
2833 case Tcgen05CpShape::SHAPE_4x256b:
2834 return GET_TCGEN05_CP_ID(_4x256b, srcFmt, is2CTA);
2835 case Tcgen05CpShape::SHAPE_32x128b:
2836 return GET_TCGEN05_CP_ID(_32x128b_warpx4, srcFmt, is2CTA);
2837 case Tcgen05CpShape::SHAPE_64x128b:
2838 return (mc == Tcgen05CpMulticast::WARPX2_01_23)
2839 ? GET_TCGEN05_CP_ID(_64x128b_warpx2_01_23, srcFmt, is2CTA)
2840 : GET_TCGEN05_CP_ID(_64x128b_warpx2_02_13, srcFmt, is2CTA);
2841 }
2842 llvm_unreachable("Invalid shape in tcgen05 cp Op");
2843}
2844
2845 // Returns true if the given vector length is valid for the given shape; this
2846 // models the table in the tcgen05.{ld, st} Op description.
2847 static bool isValidVectorLength(NVVM::Tcgen05LdStShape shape,
2848 unsigned vecLen) {
2849 if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X128B)
2850 return vecLen >= 2;
2851 if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X256B)
2852 return vecLen >= 4;
2853 return true;
2854}
2855
2856LogicalResult Tcgen05LdOp::verify() {
2857 LogicalResult result = success();
2858 if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
2859 result = emitError("shape 16x32bx2 requires offset argument");
2860
2861 if (getShape() != NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && getOffset())
2862 result = emitError("offset argument is only supported for shape 16x32bx2");
2863
2864 auto resTy = getRes().getType();
2865 unsigned resLen = isa<VectorType>(resTy)
2866 ? llvm::cast<VectorType>(resTy).getNumElements()
2867 : 1;
2868 if (!isValidVectorLength(getShape(), resLen))
2869 result = emitError(llvm::formatv("invalid result type length {0} for shape "
2870 "{1} in tcgen05.ld Op",
2871 resLen, stringifyEnum(getShape())));
2872
2873 return result;
2874}
2875
2876LogicalResult Tcgen05StOp::verify() {
2877 LogicalResult result = success();
2878 if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
2879 result = emitError("shape 16x32bx2 requires offset argument");
2880
2881 auto valTy = getVal().getType();
2882 unsigned valLen = isa<VectorType>(valTy)
2883 ? llvm::cast<VectorType>(valTy).getNumElements()
2884 : 1;
2885 if (!isValidVectorLength(getShape(), valLen))
2886 result = emitError(llvm::formatv("invalid input length {0} for shape "
2887 "{1} in tcgen05.st Op",
2888 valLen, stringifyEnum(getShape())));
2889
2890 return result;
2891}
2892
2893/// Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might
2894/// have ConstantRangeAttr.
2895 static void nvvmInferResultRanges(Operation *op, Value result,
2896 ArrayRef<::mlir::ConstantIntRanges> argRanges,
2897 SetIntRangeFn setResultRanges) {
2898 if (auto rangeAttr = op->getAttrOfType<LLVM::ConstantRangeAttr>("range")) {
2899 setResultRanges(result, {rangeAttr.getLower(), rangeAttr.getUpper(),
2900 rangeAttr.getLower(), rangeAttr.getUpper()});
2901 }
2902}
2903
2904/// Verify the range attribute satisfies LLVM ConstantRange constructor
2905/// requirements for NVVM SpecialRangeableRegisterOp.
2906static LogicalResult
2908 std::optional<LLVM::ConstantRangeAttr> rangeAttr) {
2909 if (!rangeAttr)
2910 return success();
2911
2912 const llvm::APInt &lower = rangeAttr->getLower();
2913 const llvm::APInt &upper = rangeAttr->getUpper();
2914
2915 // Check LLVM ConstantRange constructor condition
2916 if (lower == upper && !lower.isMaxValue() && !lower.isMinValue()) {
2917 unsigned bitWidth = lower.getBitWidth();
2918 llvm::APInt minVal = llvm::APInt::getMinValue(bitWidth);
2919 llvm::APInt maxVal = llvm::APInt::getMaxValue(bitWidth);
2920 return op->emitOpError(
2921 "invalid range attribute: Lower == Upper, but they aren't min (")
2922 << llvm::toString(minVal, 10, false) << ") or max ("
2923 << llvm::toString(maxVal, 10, false)
2924 << ") value! This is an invalid constant range.";
2925 }
2926
2927 return success();
2928}
2929
2930static llvm::Value *getAsPackedI32(llvm::Value *arg,
2931 llvm::IRBuilderBase &builder) {
2932 return builder.CreateBitCast(arg,
2933 llvm::Type::getInt32Ty(builder.getContext()));
2934}
2935
2936NVVM::IDArgPair DotAccumulate4WayOp::getIntrinsicIDAndArgs(
2937 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
2938 auto curOp = cast<NVVM::DotAccumulate4WayOp>(op);
2939
2940 llvm::SmallVector<llvm::Value *> args;
2941 args.push_back(getAsPackedI32(mt.lookupValue(curOp.getA()), builder));
2942 args.push_back(getAsPackedI32(mt.lookupValue(curOp.getB()), builder));
2943 args.push_back(mt.lookupValue(curOp.getC()));
2944
2945 bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
2946 bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
2947 unsigned type = (isASigned << 1) | isBSigned;
2948 const llvm::Intrinsic::ID ids[] = {
2949 llvm::Intrinsic::nvvm_idp4a_u_u,
2950 llvm::Intrinsic::nvvm_idp4a_u_s,
2951 llvm::Intrinsic::nvvm_idp4a_s_u,
2952 llvm::Intrinsic::nvvm_idp4a_s_s,
2953 };
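// The index packs the signedness of A and B as (isASigned << 1) | isBSigned;
// e.g. a signed A with an unsigned B selects nvvm_idp4a_s_u.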
2954 return {ids[type], args};
2955}
2956
2957NVVM::IDArgPair DotAccumulate2WayOp::getIntrinsicIDAndArgs(
2958 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
2959 auto curOp = cast<NVVM::DotAccumulate2WayOp>(op);
2960
2961 llvm::SmallVector<llvm::Value *> args;
2962 args.push_back(getAsPackedI32(mt.lookupValue(curOp.getA()), builder));
2963 args.push_back(getAsPackedI32(mt.lookupValue(curOp.getB()), builder));
2964 args.push_back(builder.getInt1(curOp.getBHi()));
2965 args.push_back(mt.lookupValue(curOp.getC()));
2966
2967 bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
2968 bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
2969 unsigned type = (isASigned << 1) | isBSigned;
2970 const llvm::Intrinsic::ID ids[] = {
2971 llvm::Intrinsic::nvvm_idp2a_u_u,
2972 llvm::Intrinsic::nvvm_idp2a_u_s,
2973 llvm::Intrinsic::nvvm_idp2a_s_u,
2974 llvm::Intrinsic::nvvm_idp2a_s_s,
2975 };
2976 return {ids[type], args};
2977}
2978
2979static llvm::Value *getParamCastedAddr(llvm::Value *addr,
2980 llvm::IRBuilderBase &builder) {
2981 return builder.CreateAddrSpaceCast(
2982 addr,
2983 llvm::PointerType::get(builder.getContext(),
2984 llvm::NVPTXAS::AddressSpace::ADDRESS_SPACE_PARAM));
2985}
2986
2987 NVVM::IDArgPair
2988 PrefetchOp::getIntrinsicIDAndArgs(NVVM::PrefetchOp &op,
2989 LLVM::ModuleTranslation &mt,
2990 llvm::IRBuilderBase &builder) {
2991 using MemSpace = NVVM::NVVMMemorySpace;
2992 using CacheLevel = NVVM::PrefetchCacheLevel;
2993
2994 std::optional<NVVM::PrefetchCacheLevel> cacheLevel = op.getCacheLevel();
2995 std::optional<NVVM::CacheEvictionPriority> evictPriority =
2996 op.getEvictPriority();
2997 unsigned addressSpace =
2998 llvm::cast<LLVM::LLVMPointerType>(op.getAddr().getType())
2999 .getAddressSpace();
3000
3001 llvm::SmallVector<llvm::Value *> args;
3002 llvm::Value *addr = mt.lookupValue(op.getAddr());
3003 args.push_back(op.getInParamSpace() ? getParamCastedAddr(addr, builder)
3004 : addr);
3005
3006 if (op.getTensormap())
3007 return {llvm::Intrinsic::nvvm_prefetch_tensormap, args};
3008
3009 assert(cacheLevel && "expected cache level for non-tensormap prefetch");
3010
3011 if (op.getUniform() && *cacheLevel == CacheLevel::L1)
3012 return {llvm::Intrinsic::nvvm_prefetchu_L1, args};
3013
3014 if (evictPriority && *cacheLevel == CacheLevel::L2) {
3015 switch (*evictPriority) {
3016 case NVVM::CacheEvictionPriority::EvictLast:
3017 return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_last, args};
3018 case NVVM::CacheEvictionPriority::EvictNormal:
3019 return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_normal, args};
3020 default:
3021 llvm_unreachable("Invalid cache eviction priority");
3022 }
3023 }
3024
3025 switch (static_cast<MemSpace>(addressSpace)) {
3026 case MemSpace::Generic:
3027 return *cacheLevel == CacheLevel::L1
3028 ? NVVM::IDArgPair({llvm::Intrinsic::nvvm_prefetch_L1, args})
3029 : NVVM::IDArgPair({llvm::Intrinsic::nvvm_prefetch_L2, args});
3030 case MemSpace::Global:
3031 return *cacheLevel == CacheLevel::L1
3032 ? NVVM::IDArgPair(
3033 {llvm::Intrinsic::nvvm_prefetch_global_L1, args})
3034 : NVVM::IDArgPair(
3035 {llvm::Intrinsic::nvvm_prefetch_global_L2, args});
3036 case MemSpace::Local:
3037 return *cacheLevel == CacheLevel::L1
3038 ? NVVM::IDArgPair(
3039 {llvm::Intrinsic::nvvm_prefetch_local_L1, args})
3040 : NVVM::IDArgPair(
3041 {llvm::Intrinsic::nvvm_prefetch_local_L2, args});
3042 default:
3043 llvm_unreachable("Invalid pointer address space");
3044 }
3045}
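// Summary of the intrinsic selection above:
//   tensormap operand         -> nvvm_prefetch_tensormap
//   uniform + L1              -> nvvm_prefetchu_L1
//   L2 + eviction priority    -> nvvm_prefetch_global_L2_evict_{last,normal}
//   generic pointer, L1/L2    -> nvvm_prefetch_{L1,L2}
//   global pointer, L1/L2     -> nvvm_prefetch_global_{L1,L2}
//   local pointer, L1/L2      -> nvvm_prefetch_local_{L1,L2}
// In every case the address is first cast into the param address space when
// getInParamSpace() is set.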
3046
3047bool NVVM::InlinePtxOp::getAsmValues(
3048 RewriterBase &rewriter,
3049 llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
3050 &asmValues) {
3051 for (auto arg : getReadWriteArgs())
3052 asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::ReadWrite});
3053 for (auto arg : getResults())
3054 asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::Write});
3055 for (auto arg : getReadOnlyArgs())
3056 asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::Read});
3057 if (getPredicate())
3058 asmValues.push_back({getPredicate(), mlir::NVVM::PTXRegisterMod::Read});
3059 return false; // No manual mapping needed
3060}
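// The order in which operands are appended above is the order they are
// exposed to the inline PTX lowering: read-write arguments ('+' modifier),
// then results as writes ('='), then read-only arguments, and finally the
// optional predicate as a read-only operand.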
3061
3062NVVM::IDArgPair ClusterLaunchControlTryCancelOp::getIntrinsicIDAndArgs(
3063 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3064 auto curOp = cast<NVVM::ClusterLaunchControlTryCancelOp>(op);
3065 llvm::SmallVector<llvm::Value *> args;
3066 args.push_back(mt.lookupValue(curOp.getSmemAddress()));
3067 args.push_back(mt.lookupValue(curOp.getMbarrier()));
3068
3069 llvm::Intrinsic::ID intrinsicID =
3070 curOp.getMulticast()
3071 ? llvm::Intrinsic::
3072 nvvm_clusterlaunchcontrol_try_cancel_async_multicast_shared
3073 : llvm::Intrinsic::nvvm_clusterlaunchcontrol_try_cancel_async_shared;
3074
3075 return {intrinsicID, args};
3076}
3077
3078NVVM::IDArgPair ClusterLaunchControlQueryCancelOp::getIntrinsicIDAndArgs(
3079 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3080 auto curOp = cast<NVVM::ClusterLaunchControlQueryCancelOp>(op);
3081 llvm::SmallVector<llvm::Value *> args;
3082 args.push_back(mt.lookupValue(curOp.getTryCancelResponse()));
3083
3084 llvm::Intrinsic::ID intrinsicID;
3085
3086 switch (curOp.getQueryType()) {
3087 case NVVM::ClusterLaunchControlQueryType::IS_CANCELED:
3088 intrinsicID =
3089 llvm::Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_is_canceled;
3090 break;
3091 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_X:
3092 intrinsicID = llvm::Intrinsic::
3093 nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_x;
3094 break;
3095 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Y:
3096 intrinsicID = llvm::Intrinsic::
3097 nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_y;
3098 break;
3099 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Z:
3100 intrinsicID = llvm::Intrinsic::
3101 nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_z;
3102 break;
3103 }
3104 return {intrinsicID, args};
3105}
3106
3107//===----------------------------------------------------------------------===//
3108// NVVM tcgen05.mma functions
3109//===----------------------------------------------------------------------===//
3110
3111 mlir::NVVM::IDArgPair
3112 Tcgen05MMAOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
3113 llvm::IRBuilderBase &builder) {
3114
3115 auto thisOp = cast<NVVM::Tcgen05MMAOp>(op);
3116 llvm::SmallVector<llvm::Value *> args;
3117
3118 args.push_back(mt.lookupValue(thisOp.getMatrixD()));
3119
3120 llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
3121 const bool isATensor = isa<llvm::PointerType>(A->getType());
3122 args.push_back(A);
3123
3124 args.push_back(mt.lookupValue(thisOp.getMatrixB()));
3125 args.push_back(mt.lookupValue(thisOp.getIdesc()));
3126 args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
3127
3128 using EnableAShiftArray = std::array<llvm::Intrinsic::ID, 2>;
3129 using CtaGroupArray = std::array<EnableAShiftArray, 2>;
3130 using IsATensorArray = std::array<CtaGroupArray, 2>;
3131 using HasScaleInputDArray = std::array<IsATensorArray, 2>;
3132 using HasDisableOutputLaneArray = std::array<HasScaleInputDArray, 2>;
3133
3134 // [hasDisableOutputLane][hasScaleInputD][isATensor][CtaGroup][EnableAShift]
3135 static constexpr HasDisableOutputLaneArray tcgen05MMAIDs = {
3136 { // without disable output lane
3137 {{// without scale input D
3138 {{
3139 // shared
3140 {{// cg1
3141 {llvm::Intrinsic::nvvm_tcgen05_mma_shared, notIntrinsic},
3142 // cg2
3143 {llvm::Intrinsic::nvvm_tcgen05_mma_shared, notIntrinsic}}},
3144 {{// tensor
3145 {
3146 // cg1
3147 llvm::Intrinsic::nvvm_tcgen05_mma_tensor,
3148 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_ashift,
3149 },
3150 {
3151 // cg2
3152 llvm::Intrinsic::nvvm_tcgen05_mma_tensor,
3153 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_ashift,
3154 }}},
3155 }},
3156 // with scale input D
3157 {{ // shared
3158 {{// cg1
3159 {llvm::Intrinsic::nvvm_tcgen05_mma_shared_scale_d, notIntrinsic},
3160 // cg2
3161 {llvm::Intrinsic::nvvm_tcgen05_mma_shared_scale_d, notIntrinsic}}},
3162 {{// tensor
3163 {
3164 // cg1
3165 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d,
3166 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_ashift,
3167 },
3168 {
3169 // cg2
3170 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d,
3171 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_ashift,
3172 }}}}}}},
3173 // with disable output lane
3174 {{ // without scale input D
3175 {{ // shared
3176 {{// cg1
3177 {llvm::Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg1,
3178 notIntrinsic},
3179 // cg2
3180 {llvm::Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg2,
3181 notIntrinsic}}},
3182 {{// cg1
3183 {
3184 llvm::Intrinsic::
3185 nvvm_tcgen05_mma_tensor_disable_output_lane_cg1,
3186 llvm::Intrinsic::
3187 nvvm_tcgen05_mma_tensor_disable_output_lane_cg1_ashift,
3188 },
3189 // cg2
3190 {
3191 llvm::Intrinsic::
3192 nvvm_tcgen05_mma_tensor_disable_output_lane_cg2,
3193 llvm::Intrinsic::
3194 nvvm_tcgen05_mma_tensor_disable_output_lane_cg2_ashift,
3195 }}}}},
3196 // with scale input D
3197 {{ // shared
3198 {{// cg1
3199 {llvm::Intrinsic::
3200 nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg1,
3201 notIntrinsic},
3202 // cg2
3203 {llvm::Intrinsic::
3204 nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg2,
3205 notIntrinsic}}},
3206 // tensor
3207 {{// cg1
3208 {llvm::Intrinsic::
3209 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1,
3210 llvm::Intrinsic::
3211 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1_ashift},
3212 // cg2
3213 {
3214 llvm::Intrinsic::
3215 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2,
3216 llvm::Intrinsic::
3217 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2_ashift,
3218 }}}}}}}}};
3219
3220 llvm::Value *ScaleInputD = mt.lookupValue(thisOp.getScaleInputD());
3221 bool hasScaleInputD = ScaleInputD != nullptr;
3222
3223 llvm::Value *DisableOutputLane =
3224 mt.lookupValue(thisOp.getDisableOutputLane());
3225 bool hasDisableOutputLane = DisableOutputLane != nullptr;
3226
3227 const unsigned ctaGroup =
3228 static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup()));
3229
3230 llvm::Intrinsic::ID ID =
3231 tcgen05MMAIDs[hasDisableOutputLane][hasScaleInputD][isATensor]
3232 [ctaGroup - 1][thisOp.getAShift()];
3233
3234 assert(ID != notIntrinsic && "Invalid intrinsic for Tcgen05MMAOp.");
3235
3236 if (hasScaleInputD)
3237 args.push_back(ScaleInputD);
3238
3239 if (hasDisableOutputLane)
3240 args.push_back(DisableOutputLane);
3241
3242 args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
3243
3244 if (!hasDisableOutputLane)
3245 args.push_back(builder.getInt32(ctaGroup));
3246
3247 args.push_back(
3248 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
3249
3250 return {ID, args};
3251}
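// Example of the table lookup above: an op with a scale-input-D operand,
// matrix A in tensor memory, cta_group 2 and no ashift resolves to
//   tcgen05MMAIDs[0][1][1][1][0] == llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d,
// and, because no disable-output-lane operand is present, the cta_group value
// is passed as an explicit i32 argument instead of being encoded in the
// intrinsic name.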
3252
3253static LogicalResult
3254verifyTcgen05MMAOp(bool isATensor, mlir::Value disableOutputLane,
3255 NVVM::CTAGroupKind ctaGroup, bool hasAShift,
3256 NVVM::Tcgen05MMACollectorOp collectorOp, Location loc) {
3257
3258 if (disableOutputLane) {
3259 mlir::VectorType disableOutputLaneType =
3260 cast<mlir::VectorType>(disableOutputLane.getType());
3261 if ((ctaGroup == NVVM::CTAGroupKind::CTA_1 &&
3262 disableOutputLaneType.getNumElements() != 4) ||
3263 (ctaGroup == NVVM::CTAGroupKind::CTA_2 &&
3264 disableOutputLaneType.getNumElements() != 8))
3265 return emitError(loc) << "Disable Output Lane of length "
3266 << disableOutputLaneType.getNumElements()
3267 << " is incompatible with CtaGroupAttr";
3268 }
3269
3270 if (hasAShift && !isATensor)
3271 return emitError(
3272 loc, "A-shift can be applied only when matrix A is in tensor memory");
3273
3274 if (hasAShift && (collectorOp == Tcgen05MMACollectorOp::FILL ||
3275 collectorOp == Tcgen05MMACollectorOp::USE))
3276 return emitError(
3277 loc, "Cannot use collector buffer operation fill or use with ashift");
3278
3279 return success();
3280}
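// The checks above tie the disable-output-lane vector length to the CTA
// group (4 elements for cta_group 1, 8 for cta_group 2), restrict ashift to
// the case where matrix A lives in tensor memory, and reject the fill/use
// collector operations when ashift is enabled.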
3281
3282LogicalResult Tcgen05MMAOp::verify() {
3283 return verifyTcgen05MMAOp(isa<LLVM::LLVMPointerType>(getMatrixA().getType()),
3284 getDisableOutputLane(), getCtaGroup(), getAShift(),
3285 getCollectorOp(), getLoc());
3286}
3287
3288//===----------------------------------------------------------------------===//
3289// NVVM tcgen05.mma.sp functions
3290//===----------------------------------------------------------------------===//
3291
3292mlir::NVVM::IDArgPair Tcgen05MMASparseOp::getIntrinsicIDAndArgs(
3293 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3294
3295 auto thisOp = cast<NVVM::Tcgen05MMASparseOp>(op);
3296 llvm::SmallVector<llvm::Value *> args;
3297
3298 args.push_back(mt.lookupValue(thisOp.getMatrixD()));
3299
3300 llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
3301 bool isATensor = isa<llvm::PointerType>(A->getType());
3302 args.push_back(A);
3303
3304 args.push_back(mt.lookupValue(thisOp.getMatrixB()));
3305 args.push_back(mt.lookupValue(thisOp.getIdesc()));
3306 args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
3307 args.push_back(mt.lookupValue(thisOp.getSparseMetadata()));
3308
3309 using EnableAShiftArray = std::array<llvm::Intrinsic::ID, 2>;
3310 using CtaGroupArray = std::array<EnableAShiftArray, 2>;
3311 using IsATensorArray = std::array<CtaGroupArray, 2>;
3312 using HasScaleInputDArray = std::array<IsATensorArray, 2>;
3313 using HasDisableOutputLaneArray = std::array<HasScaleInputDArray, 2>;
3314
3315 // [hasDisableOutputLane][hasScaleInputD][isATensor][CtaGroup][EnableAShift]
3316 static constexpr HasDisableOutputLaneArray tcgen05MMASparseIDs = {
3317 { // without disable output lane
3318 {{// without scale input D
3319 {{
3320 // shared
3321 {{// cg1
3322 {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared, notIntrinsic},
3323 // cg2
3324 {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared, notIntrinsic}}},
3325 {{// tensor
3326 {
3327 // cg1
3328 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor,
3329 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_ashift,
3330 },
3331 {
3332 // cg2
3333 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor,
3334 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_ashift,
3335 }}},
3336 }},
3337 // with scale input D
3338 {{ // shared
3339 {{// cg1
3340 {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d,
3341 notIntrinsic},
3342 // cg2
3343 {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d,
3344 notIntrinsic}}},
3345 {{// tensor
3346 {
3347 // cg1
3348 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d,
3349 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_ashift,
3350 },
3351 {
3352 // cg2
3353 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d,
3354 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_ashift,
3355 }}}}}}},
3356 // with disable output lane
3357 {{ // without scale input D
3358 {{ // shared
3359 {{// cg1
3360 {llvm::Intrinsic::
3361 nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg1,
3362 notIntrinsic},
3363 // cg2
3364 {llvm::Intrinsic::
3365 nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg2,
3366 notIntrinsic}}},
3367 {{// cg1
3368 {
3369 llvm::Intrinsic::
3370 nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1,
3371 llvm::Intrinsic::
3372 nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1_ashift,
3373 },
3374 // cg2
3375 {
3376 llvm::Intrinsic::
3377 nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2,
3378 llvm::Intrinsic::
3379 nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2_ashift,
3380 }}}}},
3381 // with scale input D
3382 {{ // shared
3383 {{// cg1
3384 {llvm::Intrinsic::
3385 nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg1,
3386 notIntrinsic},
3387 // cg2
3388 {llvm::Intrinsic::
3389 nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg2,
3390 notIntrinsic}}},
3391 // tensor
3392 {{// cg1
3393 {llvm::Intrinsic::
3394 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1,
3395 llvm::Intrinsic::
3396 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1_ashift},
3397 // cg2
3398 {
3399 llvm::Intrinsic::
3400 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2,
3401 llvm::Intrinsic::
3402 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2_ashift,
3403 }}}}}}}}};
3404
3405 llvm::Value *ScaleInputD = mt.lookupValue(thisOp.getScaleInputD());
3406 bool hasScaleInputD = ScaleInputD != nullptr;
3407
3408 llvm::Value *DisableOutputLane =
3409 mt.lookupValue(thisOp.getDisableOutputLane());
3410 bool hasDisableOutputLane = DisableOutputLane != nullptr;
3411
3412 unsigned ctaGroup =
3413 static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup()));
3414
3415 llvm::Intrinsic::ID ID =
3416 tcgen05MMASparseIDs[hasDisableOutputLane][hasScaleInputD][isATensor]
3417 [ctaGroup - 1][thisOp.getAShift()];
3418
3419 assert(ID != notIntrinsic && "Invalid intrinsic for Tcgen05MMASparseOp.");
3420
3421 if (hasScaleInputD)
3422 args.push_back(ScaleInputD);
3423
3424 if (hasDisableOutputLane)
3425 args.push_back(DisableOutputLane);
3426
3427 args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
3428
3429 if (!hasDisableOutputLane)
3430 args.push_back(builder.getInt32(ctaGroup));
3431
3432 args.push_back(
3433 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
3434
3435 return {ID, args};
3436}
3437
3438LogicalResult Tcgen05MMASparseOp::verify() {
3439 return verifyTcgen05MMAOp(isa<LLVM::LLVMPointerType>(getMatrixA().getType()),
3440 getDisableOutputLane(), getCtaGroup(), getAShift(),
3441 getCollectorOp(), getLoc());
3442}
3443
3444//===----------------------------------------------------------------------===//
3445// NVVM tcgen05.mma.block_scale functions
3446//===----------------------------------------------------------------------===//
3447
3448mlir::NVVM::IDArgPair Tcgen05MMABlockScaleOp::getIntrinsicIDAndArgs(
3449 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3450
3451 auto thisOp = cast<NVVM::Tcgen05MMABlockScaleOp>(op);
3452 llvm::SmallVector<llvm::Value *> args;
3453
3454 args.push_back(mt.lookupValue(thisOp.getMatrixD()));
3455
3456 llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
3457 bool isATensor = isa<llvm::PointerType>(A->getType());
3458 args.push_back(A);
3459
3460 args.push_back(mt.lookupValue(thisOp.getMatrixB()));
3461 args.push_back(mt.lookupValue(thisOp.getIdesc()));
3462 args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
3463 args.push_back(mt.lookupValue(thisOp.getScaleA()));
3464 args.push_back(mt.lookupValue(thisOp.getScaleB()));
3465 args.push_back(builder.getInt32(
3466 static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup()))));
3467 args.push_back(
3468 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
3469
3470 auto kind = thisOp.getKind();
3471 auto blockScale = thisOp.getBlockScale();
3472 llvm::Intrinsic::ID ID = [&]() {
3473 if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF8F6F4) {
3474 if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
3475 return isATensor ? llvm::Intrinsic::
3476 nvvm_tcgen05_mma_tensor_mxf8f6f4_block_scale
3477 : llvm::Intrinsic::
3478 nvvm_tcgen05_mma_shared_mxf8f6f4_block_scale;
3479 } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
3480 return isATensor
3481 ? llvm::Intrinsic::
3482 nvvm_tcgen05_mma_tensor_mxf8f6f4_block_scale_block32
3483 : llvm::Intrinsic::
3484 nvvm_tcgen05_mma_shared_mxf8f6f4_block_scale_block32;
3485 }
3486 } else if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF4) {
3487 if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
3488 return isATensor
3489 ? llvm::Intrinsic::nvvm_tcgen05_mma_tensor_mxf4_block_scale
3490 : llvm::Intrinsic::nvvm_tcgen05_mma_shared_mxf4_block_scale;
3491 } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
3492 return isATensor ? llvm::Intrinsic::
3493 nvvm_tcgen05_mma_tensor_mxf4_block_scale_block32
3494 : llvm::Intrinsic::
3495 nvvm_tcgen05_mma_shared_mxf4_block_scale_block32;
3496 }
3497 } else if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF4NVF4) {
3498 if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
3499 return isATensor
3500 ? llvm::Intrinsic::
3501 nvvm_tcgen05_mma_tensor_mxf4nvf4_block_scale_block32
3502 : llvm::Intrinsic::
3503 nvvm_tcgen05_mma_shared_mxf4nvf4_block_scale_block32;
3504
3505 } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16) {
3506 return isATensor
3507 ? llvm::Intrinsic::
3508 nvvm_tcgen05_mma_tensor_mxf4nvf4_block_scale_block16
3509 : llvm::Intrinsic::
3510 nvvm_tcgen05_mma_shared_mxf4nvf4_block_scale_block16;
3511 }
3512 }
3513 llvm_unreachable("Invalid tcgen05.mma.block_scale attributes");
3514 }();
3515
3516 return {ID, args};
3517}
3518
3519static LogicalResult
3520verifyTcgen05MMABlockScaleOp(NVVM::Tcgen05MMACollectorOp collectorOp,
3521 NVVM::Tcgen05MMABlockScaleKind kind,
3522 NVVM::Tcgen05MMABlockScale blockScale,
3523 Location loc) {
3524
3525 if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT &&
3526 kind == Tcgen05MMABlockScaleKind::MXF4NVF4)
3527 return emitError(loc, "mxf4nvf4 requires block scale attribute");
3528
3529 if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16 &&
3530 kind != Tcgen05MMABlockScaleKind::MXF4NVF4)
3531 return emitError(loc,
3532 llvm::formatv("{} kind does not support block16 attribute",
3533 stringifyEnum(kind)));
3534
3535 return success();
3536}
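// The verifier above (shared with the sparse block-scale op below) accepts
// the following kind/block-size combinations:
//   mxf8f6f4 : default, block32
//   mxf4     : default, block32
//   mxf4nvf4 : block32, block16
// mxf4nvf4 with the default block size and block16 with any other kind are
// rejected.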
3537
3538LogicalResult Tcgen05MMABlockScaleOp::verify() {
3539 return verifyTcgen05MMABlockScaleOp(getCollectorOp(), getKind(),
3540 getBlockScale(), getLoc());
3541}
3542
3543//===----------------------------------------------------------------------===//
3544// NVVM tcgen05.mma.sp.block_scale functions
3545//===----------------------------------------------------------------------===//
3546
3547mlir::NVVM::IDArgPair Tcgen05MMASparseBlockScaleOp::getIntrinsicIDAndArgs(
3548 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3549
3550 auto thisOp = cast<NVVM::Tcgen05MMASparseBlockScaleOp>(op);
3551 llvm::SmallVector<llvm::Value *> args;
3552
3553 args.push_back(mt.lookupValue(thisOp.getMatrixD()));
3554
3555 llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
3556 bool isATensor = isa<llvm::PointerType>(A->getType());
3557 args.push_back(A);
3558
3559 args.push_back(mt.lookupValue(thisOp.getMatrixB()));
3560 args.push_back(mt.lookupValue(thisOp.getIdesc()));
3561 args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
3562 args.push_back(mt.lookupValue(thisOp.getSparseMetadata()));
3563 args.push_back(mt.lookupValue(thisOp.getScaleA()));
3564 args.push_back(mt.lookupValue(thisOp.getScaleB()));
3565 args.push_back(builder.getInt32(
3566 static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup()))));
3567 args.push_back(
3568 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
3569
3570 auto kind = thisOp.getKind();
3571 auto blockScale = thisOp.getBlockScale();
3572 llvm::Intrinsic::ID ID = [&]() {
3573 if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF8F6F4) {
3574 if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
3575 return isATensor ? llvm::Intrinsic::
3576 nvvm_tcgen05_mma_sp_tensor_mxf8f6f4_block_scale
3577 : llvm::Intrinsic::
3578 nvvm_tcgen05_mma_sp_shared_mxf8f6f4_block_scale;
3579 } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
3580 return isATensor
3581 ? llvm::Intrinsic::
3582 nvvm_tcgen05_mma_sp_tensor_mxf8f6f4_block_scale_block32
3583 : llvm::Intrinsic::
3584 nvvm_tcgen05_mma_sp_shared_mxf8f6f4_block_scale_block32;
3585 }
3586 } else if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF4) {
3587 if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
3588 return isATensor ? llvm::Intrinsic::
3589 nvvm_tcgen05_mma_sp_tensor_mxf4_block_scale
3590 : llvm::Intrinsic::
3591 nvvm_tcgen05_mma_sp_shared_mxf4_block_scale;
3592 } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
3593 return isATensor
3594 ? llvm::Intrinsic::
3595 nvvm_tcgen05_mma_sp_tensor_mxf4_block_scale_block32
3596 : llvm::Intrinsic::
3597 nvvm_tcgen05_mma_sp_shared_mxf4_block_scale_block32;
3598 }
3599 } else if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF4NVF4) {
3600 if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
3601 return isATensor
3602 ? llvm::Intrinsic::
3603 nvvm_tcgen05_mma_sp_tensor_mxf4nvf4_block_scale_block32
3604 : llvm::Intrinsic::
3605 nvvm_tcgen05_mma_sp_shared_mxf4nvf4_block_scale_block32;
3606
3607 } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16) {
3608 return isATensor
3609 ? llvm::Intrinsic::
3610 nvvm_tcgen05_mma_sp_tensor_mxf4nvf4_block_scale_block16
3611 : llvm::Intrinsic::
3612 nvvm_tcgen05_mma_sp_shared_mxf4nvf4_block_scale_block16;
3613 }
3614 }
3615 llvm_unreachable("Invalid tcgen05.mma.sp.block_scale attributes");
3616 }();
3617
3618 return {ID, args};
3619}
3620
3621LogicalResult Tcgen05MMASparseBlockScaleOp::verify() {
3622 return verifyTcgen05MMABlockScaleOp(getCollectorOp(), getKind(),
3623 getBlockScale(), getLoc());
3624}
3625
3626//===----------------------------------------------------------------------===//
3627// NVVM tcgen05.mma.ws functions
3628//===----------------------------------------------------------------------===//
3629
3630mlir::NVVM::IDArgPair Tcgen05MMAWsOp::getIntrinsicIDAndArgs(
3631 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3632
3633 auto thisOp = cast<NVVM::Tcgen05MMAWsOp>(op);
3634 llvm::SmallVector<llvm::Value *> args;
3635
3636 args.push_back(mt.lookupValue(thisOp.getMatrixD()));
3637
3638 llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
3639 bool isATensor = isa<llvm::PointerType>(A->getType());
3640 args.push_back(A);
3641
3642 args.push_back(mt.lookupValue(thisOp.getMatrixB()));
3643 args.push_back(mt.lookupValue(thisOp.getIdesc()));
3644 args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
3645
3646 mlir::Value ZeroColMask = thisOp.getZeroColMask();
3647 llvm::Intrinsic::ID ID = notIntrinsic;
3648 if (ZeroColMask) {
3649 args.push_back(mt.lookupValue(ZeroColMask));
3650 ID = isATensor ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_tensor_zero_col_mask
3651 : llvm::Intrinsic::nvvm_tcgen05_mma_ws_shared_zero_col_mask;
3652 } else
3653 ID = isATensor ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_tensor
3654 : llvm::Intrinsic::nvvm_tcgen05_mma_ws_shared;
3655
3656 args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
3657 args.push_back(
3658 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorBBuffer())));
3659 args.push_back(
3660 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
3661
3662 return {ID, args};
3663}
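// The intrinsic selection above is a simple 2x2 choice: tensor vs. shared
// memory for matrix A, with or without the optional zero-column-mask operand,
// which is appended to the argument list only when present. The sparse
// variant below follows the same structure with the ws.sp intrinsics.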
3664
3665//===----------------------------------------------------------------------===//
3666// NVVM tcgen05.mma.ws.sp functions
3667//===----------------------------------------------------------------------===//
3668
3669mlir::NVVM::IDArgPair Tcgen05MMAWsSparseOp::getIntrinsicIDAndArgs(
3670 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3671
3672 auto thisOp = cast<NVVM::Tcgen05MMAWsSparseOp>(op);
3673 llvm::SmallVector<llvm::Value *> args;
3674
3675 args.push_back(mt.lookupValue(thisOp.getMatrixD()));
3676
3677 llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
3678 bool isATensor = isa<llvm::PointerType>(A->getType());
3679 args.push_back(A);
3680
3681 args.push_back(mt.lookupValue(thisOp.getMatrixB()));
3682 args.push_back(mt.lookupValue(thisOp.getIdesc()));
3683 args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
3684 args.push_back(mt.lookupValue(thisOp.getSparseMetadata()));
3685
3686 mlir::Value ZeroColMask = thisOp.getZeroColMask();
3687 llvm::Intrinsic::ID ID = notIntrinsic;
3688 if (ZeroColMask) {
3689 args.push_back(mt.lookupValue(ZeroColMask));
3690 ID = isATensor
3691 ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_tensor_zero_col_mask
3692 : llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_shared_zero_col_mask;
3693 } else
3694 ID = isATensor ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_tensor
3695 : llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_shared;
3696
3697 args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
3698 args.push_back(
3699 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorBBuffer())));
3700 args.push_back(
3701 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
3702
3703 return {ID, args};
3704}
3705
3706//===----------------------------------------------------------------------===//
3707// NVVMDialect initialization, type parsing, and registration.
3708//===----------------------------------------------------------------------===//
3709
3710// TODO: This should be the llvm.nvvm dialect once this is supported.
3711void NVVMDialect::initialize() {
3712 addOperations<
3713#define GET_OP_LIST
3714#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
3715 >();
3716 addAttributes<
3717#define GET_ATTRDEF_LIST
3718#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
3719 >();
3720
3721 // Support unknown operations because not all NVVM operations are
3722 // registered.
3723 allowUnknownOperations();
3724 declarePromisedInterface<ConvertToLLVMPatternInterface, NVVMDialect>();
3725 declarePromisedInterface<gpu::TargetAttrInterface, NVVMTargetAttr>();
3726}
3727
3728LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
3729 NamedAttribute attr) {
3730 StringAttr attrName = attr.getName();
3731 // Kernel function attribute should be attached to functions.
3732 if (attrName == NVVMDialect::getKernelFuncAttrName()) {
3733 if (!isa<LLVM::LLVMFuncOp>(op)) {
3734 return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName()
3735 << "' attribute attached to unexpected op";
3736 }
3737 }
3738 // If maxntid / reqntid / cluster_dim exist, they must be integer arrays
3739 // with at most 3 dimensions.
3740 if (attrName == NVVMDialect::getMaxntidAttrName() ||
3741 attrName == NVVMDialect::getReqntidAttrName() ||
3742 attrName == NVVMDialect::getClusterDimAttrName()) {
3743 auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue());
3744 if (!values || values.empty() || values.size() > 3) {
3745 return op->emitError()
3746 << "'" << attrName
3747 << "' attribute must be an integer array with at most 3 elements";
3748 }
3749 }
3750 // If minctasm / maxnreg / cluster_max_blocks exist, they must be integer
3751 // attributes.
3752 if (attrName == NVVMDialect::getMinctasmAttrName() ||
3753 attrName == NVVMDialect::getMaxnregAttrName() ||
3754 attrName == NVVMDialect::getClusterMaxBlocksAttrName()) {
3755 if (!llvm::dyn_cast<IntegerAttr>(attr.getValue())) {
3756 return op->emitError()
3757 << "'" << attrName << "' attribute must be integer constant";
3758 }
3759 }
3760 // blocksareclusters must be used along with reqntid and cluster_dim
3761 if (attrName == NVVMDialect::getBlocksAreClustersAttrName()) {
3762 if (!op->hasAttr(NVVMDialect::getReqntidAttrName()) ||
3763 !op->hasAttr(NVVMDialect::getClusterDimAttrName())) {
3764 return op->emitError()
3765 << "'" << attrName << "' attribute must be used along with "
3766 << "'" << NVVMDialect::getReqntidAttrName() << "' and "
3767 << "'" << NVVMDialect::getClusterDimAttrName() << "'";
3768 }
3769 }
3770
3771 return success();
3772}
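// In short, the module-level attribute rules enforced above are: the kernel
// attribute may only appear on llvm.func ops; maxntid, reqntid and
// cluster_dim must be dense i32 arrays with 1 to 3 elements; minctasm,
// maxnreg and cluster_max_blocks must be integer attributes; and
// blocksareclusters is only valid together with reqntid and cluster_dim.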
3773
3774LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
3775 unsigned regionIndex,
3776 unsigned argIndex,
3777 NamedAttribute argAttr) {
3778 auto funcOp = dyn_cast<FunctionOpInterface>(op);
3779 if (!funcOp)
3780 return success();
3781
3782 bool isKernel = op->hasAttr(NVVMDialect::getKernelFuncAttrName());
3783 StringAttr attrName = argAttr.getName();
3784 if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
3785 if (!isKernel) {
3786 return op->emitError()
3787 << "'" << attrName
3788 << "' attribute must be present only on kernel arguments";
3789 }
3790 if (!isa<UnitAttr>(argAttr.getValue()))
3791 return op->emitError() << "'" << attrName << "' must be a unit attribute";
3792 if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName())) {
3793 return op->emitError()
3794 << "'" << attrName
3795 << "' attribute requires the argument to also have attribute '"
3796 << LLVM::LLVMDialect::getByValAttrName() << "'";
3797 }
3798 }
3799
3800 return success();
3801}
3802
3803//===----------------------------------------------------------------------===//
3804// NVVM Address Space Attr
3805//===----------------------------------------------------------------------===//
3806
3807unsigned NVVMMemorySpaceAttr::getAddressSpace() const {
3808 return static_cast<unsigned>(getValue());
3809}
3810
3811bool NVVMMemorySpaceAttr::isValidLoad(
3812 Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
3813 const ::mlir::DataLayout *dataLayout,
3814 function_ref<InFlightDiagnostic()> emitError) const {
3815 return LLVM::detail::isValidLoadStoreImpl(type, ordering, alignment,
3816 dataLayout, emitError);
3817}
3818
3819bool NVVMMemorySpaceAttr::isValidStore(
3820 Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
3821 const ::mlir::DataLayout *dataLayout,
3822 function_ref<InFlightDiagnostic()> emitError) const {
3823 return LLVM::detail::isValidLoadStoreImpl(type, ordering, alignment,
3824 dataLayout, emitError);
3825}
3826
3827bool NVVMMemorySpaceAttr::isValidAtomicOp(
3828 ptr::AtomicBinOp op, Type type, ptr::AtomicOrdering ordering,
3829 std::optional<int64_t> alignment, const ::mlir::DataLayout *dataLayout,
3830 function_ref<InFlightDiagnostic()> emitError) const {
3831 // TODO: update this method once `ptr.atomic_rmw` is implemented.
3832 assert(false && "unimplemented, see TODO in the source.");
3833 return false;
3834}
3835
3836bool NVVMMemorySpaceAttr::isValidAtomicXchg(
3837 Type type, ptr::AtomicOrdering successOrdering,
3838 ptr::AtomicOrdering failureOrdering, std::optional<int64_t> alignment,
3839 const ::mlir::DataLayout *dataLayout,
3840 function_ref<InFlightDiagnostic()> emitError) const {
3841 // TODO: update this method once `ptr.atomic_cmpxchg` is implemented.
3842 assert(false && "unimplemented, see TODO in the source.");
3843 return false;
3844}
3845
3846bool NVVMMemorySpaceAttr::isValidAddrSpaceCast(
3847 Type tgt, Type src, function_ref<InFlightDiagnostic()> emitError) const {
3848 // TODO: update this method once the `ptr.addrspace_cast` op is added to the
3849 // dialect.
3850 assert(false && "unimplemented, see TODO in the source.");
3851 return false;
3852}
3853
3854bool NVVMMemorySpaceAttr::isValidPtrIntCast(
3855 Type intLikeTy, Type ptrLikeTy,
3856 function_ref<InFlightDiagnostic()> emitError) const {
3857 // TODO: update this method once the int-cast ops are added to the `ptr`
3858 // dialect.
3859 assert(false && "unimplemented, see TODO in the source.");
3860 return false;
3861}
3862
3863//===----------------------------------------------------------------------===//
3864// NVVM target attribute.
3865//===----------------------------------------------------------------------===//
3866LogicalResult
3867NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
3868 int optLevel, StringRef triple, StringRef chip,
3869 StringRef features, DictionaryAttr flags,
3870 ArrayAttr files, bool verifyTarget) {
3871 if (optLevel < 0 || optLevel > 3) {
3872 emitError() << "The optimization level must be a number between 0 and 3.";
3873 return failure();
3874 }
3875 if (triple.empty()) {
3876 emitError() << "The target triple cannot be empty.";
3877 return failure();
3878 }
3879 if (chip.empty()) {
3880 emitError() << "The target chip cannot be empty.";
3881 return failure();
3882 }
3883 if (files && !llvm::all_of(files, [](::mlir::Attribute attr) {
3884 return mlir::isa_and_nonnull<StringAttr>(attr);
3885 })) {
3886 emitError() << "All the elements in the `link` array must be strings.";
3887 return failure();
3888 }
3889 return success();
3890}
3891
3892LogicalResult NVVMTargetAttr::verifyTarget(Operation *gpuModule) {
3893 if (!getVerifyTarget())
3894 return success();
3895
3896 auto gpuModuleOp = llvm::dyn_cast<gpu::GPUModuleOp>(gpuModule);
3897 if (!gpuModuleOp) {
3898 return emitError(gpuModule->getLoc(),
3899 "NVVM target attribute must be attached to a GPU module");
3900 }
3901
3902 const NVVMCheckSMVersion targetSMVersion =
3903 NVVMCheckSMVersion::getTargetSMVersionFromStr(getChip());
3904 if (!targetSMVersion.isMinimumSMVersion()) {
3905 return emitError(gpuModule->getLoc(),
3906 "Minimum NVVM target SM version is sm_20");
3907 }
3908
3909 gpuModuleOp->walk([&](Operation *op) {
3910 if (auto reqOp = llvm::dyn_cast<NVVM::RequiresSMInterface>(op)) {
3911 const NVVMCheckSMVersion requirement = reqOp.getRequiredMinSMVersion();
3912 if (!requirement.isCompatibleWith(targetSMVersion)) {
3913 op->emitOpError() << "is not supported on " << getChip();
3914 return WalkResult::interrupt();
3915 }
3916 }
3917 return WalkResult::advance();
3918 });
3919
3920 return success();
3921}
3922
3923#define GET_OP_CLASSES
3924#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
3925
3926#define GET_ATTRDEF_CLASSES
3927#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"