NVVMDialect.cpp
1//===- NVVMDialect.cpp - NVVM IR Ops and Dialect registration -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the types and operation details for the NVVM IR dialect in
10// MLIR, and the LLVM IR dialect. It also registers the dialect.
11//
12// The NVVM dialect only contains GPU specific additions on top of the general
13// LLVM dialect.
14//
15//===----------------------------------------------------------------------===//
16
18
22#include "mlir/IR/Builders.h"
25#include "mlir/IR/Diagnostics.h"
27#include "mlir/IR/MLIRContext.h"
28#include "mlir/IR/Operation.h"
30#include "mlir/IR/Types.h"
31#include "llvm/ADT/STLExtras.h"
32#include "llvm/ADT/TypeSwitch.h"
33#include "llvm/IR/IRBuilder.h"
34#include "llvm/IR/NVVMIntrinsicUtils.h"
35#include "llvm/Support/Casting.h"
36#include "llvm/Support/FormatVariadic.h"
37#include "llvm/Support/NVPTXAddrSpace.h"
38#include "llvm/Support/raw_ostream.h"
39#include <cassert>
40#include <optional>
41#include <string>
42
43using namespace mlir;
44using namespace NVVM;
45
46#include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
47#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"
48
49static constexpr unsigned notIntrinsic = llvm::Intrinsic::not_intrinsic;
50
51//===----------------------------------------------------------------------===//
52// Helper/Utility methods
53//===----------------------------------------------------------------------===//
54
55static bool isPtrInAddrSpace(mlir::Value ptr, NVVMMemorySpace targetAS) {
56 auto ptrTy = llvm::cast<LLVM::LLVMPointerType>(ptr.getType());
57 return ptrTy.getAddressSpace() == static_cast<unsigned>(targetAS);
58}
59
static bool isPtrInGenericSpace(mlir::Value ptr) {
  return isPtrInAddrSpace(ptr, NVVMMemorySpace::Generic);
}

static bool isPtrInSharedCTASpace(mlir::Value ptr) {
  return isPtrInAddrSpace(ptr, NVVMMemorySpace::Shared);
}

static bool isPtrInSharedClusterSpace(mlir::Value ptr) {
  return isPtrInAddrSpace(ptr, NVVMMemorySpace::SharedCluster);
}
71
72static llvm::Value *castPtrToAddrSpace(llvm::IRBuilderBase &builder,
73 llvm::Value *ptr,
74 NVVMMemorySpace targetAS) {
75 unsigned AS = static_cast<unsigned>(targetAS);
76 return builder.CreateAddrSpaceCast(
77 ptr, llvm::PointerType::get(builder.getContext(), AS));
78}
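
// Example (hypothetical lowering snippet, assuming a translated llvm::Value
// named `ctaPtr`): widening a shared::cta pointer to the shared::cluster
// address space before passing it to an intrinsic that expects
// address-space 7:
//
//   llvm::Value *clusterPtr =
//       castPtrToAddrSpace(builder, ctaPtr, NVVMMemorySpace::SharedCluster);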
79
80// Helper method to convert CtaGroupKind in NVVM Dialect to CtaGroupKind in LLVM
81static llvm::nvvm::CTAGroupKind
82getNVVMCtaGroupKind(NVVM::CTAGroupKind ctaGroup) {
83 switch (ctaGroup) {
84 case NVVM::CTAGroupKind::CTA_1:
85 return llvm::nvvm::CTAGroupKind::CG_1;
86 case NVVM::CTAGroupKind::CTA_2:
87 return llvm::nvvm::CTAGroupKind::CG_2;
88 }
89 llvm_unreachable("unsupported cta_group value");
90}
91
92//===----------------------------------------------------------------------===//
93// Verifier methods
94//===----------------------------------------------------------------------===//
95
96// This verifier is shared among the following Ops:
97// CpAsyncBulkTensorSharedCTAToGlobalOp (TMA Store)
98// CpAsyncBulkTensorReduceOp (TMA Store-Reduce)
99static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims,
100 bool isIm2Col,
101 size_t numIm2ColOffsets,
102 Location loc) {
103 if (tensorDims < 1 || tensorDims > 5)
    return emitError(loc, "expects coordinates between 1 and 5 dimensions");
105
106 // For Im2Col mode, there are two constraints:
107 if (isIm2Col) {
108 // 1. Tensor must always be at least 3-d.
109 if (tensorDims < 3)
110 return emitError(
111 loc,
112 "to use im2col mode, the tensor has to be at least 3-dimensional");
113 // 2. When there are Im2ColOffsets, they must be (Dims - 2) in number.
114 if (numIm2ColOffsets && (tensorDims != (numIm2ColOffsets + 2)))
115 return emitError(
116 loc, "im2col offsets must be 2 less than number of coordinates");
117 }
118 return success();
119}
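
// Example of the checks above: a TILE store over coordinates [c0, c1]
// (tensorDims == 2) is accepted, while an IM2COL store over the same
// coordinates fails the "at least 3-dimensional" check. A 5-d IM2COL store
// that passes im2col offsets must pass exactly 3 of them (5 - 2).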
120
121LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
122 TMAStoreMode mode = getMode();
123 // We lower through inline-ptx when getPredicate() is true.
124 // a) Only TILE mode is supported
125 // b) Cache-hint is not supported
126 if (getPredicate()) {
127 if (mode != TMAStoreMode::TILE)
128 return emitError("Inline-ptx lowering supported only for Tile mode.");
129 if (getL2CacheHint())
130 return emitError("Inline-ptx lowering unsupported with L2 cache-hint.");
131 }
132
133 size_t dims = getCoordinates().size();
134 switch (mode) {
135 case TMAStoreMode::TILE:
136 return cpAsyncBulkTensorCommonVerifier(dims, false, 0, getLoc());
137 case TMAStoreMode::IM2COL:
138 return cpAsyncBulkTensorCommonVerifier(dims, true, 0, getLoc());
139 case TMAStoreMode::TILE_SCATTER4:
140 if (dims != 5)
141 return emitError("Scatter4 mode expects 5 coordinates");
142 }
143 return success();
144}
145
146LogicalResult CpAsyncOp::verify() {
147 if (getModifier() != LoadCacheModifierKind::CG &&
148 getModifier() != LoadCacheModifierKind::CA)
149 return emitError("Only CG and CA cache modifiers are supported.");
150 if (getSize() != 4 && getSize() != 8 && getSize() != 16)
151 return emitError("expected byte size to be either 4, 8 or 16.");
152 if (getModifier() == LoadCacheModifierKind::CG && getSize() != 16)
153 return emitError("CG cache modifier is only support for 16 bytes copy.");
154 return success();
155}
156
// This parameter verification is shared across TMA Load and Prefetch Ops.
158static LogicalResult verifyTMALoadParams(size_t tensorDims, size_t numIm2colOff,
159 TMALoadMode mode, Location loc) {
160 if (tensorDims < 1 || tensorDims > 5)
    return emitError(loc, "expects coordinates between 1 and 5 dimensions");
162
163 auto checkTMALoadParams = [&](TMALoadMode mode, bool isIm2col,
164 size_t expectedIm2colOff) -> LogicalResult {
165 if (isIm2col && (tensorDims < 3))
166 return emitError(loc)
167 << "to use " << stringifyEnum(mode)
168 << " mode, the tensor has to be at least 3-dimensional";
169
170 if (numIm2colOff != expectedIm2colOff)
171 return emitError(loc) << " im2col offsets expected " << expectedIm2colOff
172 << " (provided " << numIm2colOff << ")";
173
174 return success();
175 };
176
177 switch (mode) {
178 case TMALoadMode::TILE:
179 return checkTMALoadParams(mode, false, 0);
180 case TMALoadMode::IM2COL:
181 return checkTMALoadParams(mode, true, tensorDims - 2);
182 case TMALoadMode::IM2COL_W:
183 case TMALoadMode::IM2COL_W_128:
184 return checkTMALoadParams(mode, true, 2);
185 case TMALoadMode::TILE_GATHER4:
186 return (tensorDims == 5)
187 ? checkTMALoadParams(mode, false, 0)
188 : emitError(loc, "Gather4 mode expects 5 coordinates");
189 }
190 return success();
191}
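
// Summary of the expectations enforced above (tensorDims is the number of
// coordinates):
//   TILE          -> 0 im2col offsets
//   IM2COL        -> tensorDims - 2 offsets (tensorDims >= 3)
//   IM2COL_W      -> 2 offsets (tensorDims >= 3)
//   IM2COL_W_128  -> 2 offsets (tensorDims >= 3)
//   TILE_GATHER4  -> 0 offsets and exactly 5 coordinates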
192
193LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
194 return verifyTMALoadParams(getCoordinates().size(), getIm2colOffsets().size(),
195 getMode(), getLoc());
196}
197
198LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
199 TMALoadMode mode = getMode();
200 bool isCTAOnly = getIsCTAOnly();
201 if (getPredicate()) { // Inline-asm based lowering
202 if (isCTAOnly)
203 return emitError("Predicate is supported only for shared::cluster mode.");
204 if (mode != TMALoadMode::TILE && mode != TMALoadMode::IM2COL)
205 return emitError(
206 "Predicate is supported only for Tile and Im2col modes.");
207 } else { // Intrinsics-based lowering
208 NVVMMemorySpace expectedAS =
209 isCTAOnly ? NVVMMemorySpace::Shared : NVVMMemorySpace::SharedCluster;
210 unsigned AS = llvm::cast<LLVM::LLVMPointerType>(getDstMem().getType())
211 .getAddressSpace();
212 if (AS != expectedAS)
213 return emitError()
214 << (isCTAOnly
215 ? "Shared::cta destination requires address-space 3."
216 : "Shared::cluster destination requires address-space 7.");
217 // Checks specific to shared::cta mode
218 if (isCTAOnly) {
219 if (getMulticastMask())
220 return emitError("Multicast is not supported with shared::cta mode.");
221 if (getGroup())
222 return emitError("CTAGroup is not supported with shared::cta mode.");
223 }
224 }
225
226 return verifyTMALoadParams(getCoordinates().size(), getIm2colOffsets().size(),
227 getMode(), getLoc());
228}
229
230LogicalResult CpAsyncBulkTensorReduceOp::verify() {
231 TMAStoreMode mode = getMode();
232 size_t dims = getCoordinates().size();
233 switch (mode) {
234 case TMAStoreMode::TILE:
235 return cpAsyncBulkTensorCommonVerifier(dims, false, 0, getLoc());
236 case TMAStoreMode::IM2COL:
237 return cpAsyncBulkTensorCommonVerifier(dims, true, 0, getLoc());
238 case TMAStoreMode::TILE_SCATTER4:
239 return emitError("Scatter mode unsupported for CpAsyncBulkTensorReduceOp");
240 }
241 return success();
242}
243
244LogicalResult CpAsyncBulkGlobalToSharedClusterOp::verify() {
245 bool isSharedCTA = isPtrInSharedCTASpace(getDstMem());
246 if (isSharedCTA && getMulticastMask())
247 return emitError("Multicast is not supported with shared::cta mode.");
248
249 return success();
250}
251
252static LogicalResult verifyMBarrierArriveLikeOp(Operation *op, Value addr,
253 NVVM::MemScopeKind scope,
254 Value retVal = nullptr) {
255 if (scope != NVVM::MemScopeKind::CTA && scope != NVVM::MemScopeKind::CLUSTER)
256 return op->emitError("mbarrier scope must be either CTA or Cluster");
257
258 bool isSharedCluster = isPtrInSharedClusterSpace(addr);
259 bool hasRetValue = static_cast<bool>(retVal);
260 if (isSharedCluster && hasRetValue)
261 return op->emitError(
262 "mbarrier in shared_cluster space cannot return any value");
263
264 return success();
265}
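
// Example of the rules above: an arrive on an mbarrier whose address lives in
// the shared::cluster space (address-space 7) must not produce a result
// token, whereas the same arrive on a shared::cta mbarrier (address-space 3)
// may return one. Scopes other than CTA and CLUSTER are rejected for all
// arrive-like ops.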
266
267LogicalResult MBarrierArriveOp::verify() {
268 return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope(),
269 getRes());
270}
271
272LogicalResult MBarrierArriveDropOp::verify() {
273 return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope(),
274 getRes());
275}
276
277LogicalResult MBarrierArriveExpectTxOp::verify() {
278 // The inline-ptx version of this Op does not support all features.
279 // With predicate, this Op lowers to inline-ptx. So, verify and
280 // error-out if there are unsupported features.
281 if (getPredicate()) {
282 if (getScope() != NVVM::MemScopeKind::CTA)
283 return emitError("mbarrier scope must be CTA when using predicate");
284
285 if (isPtrInSharedClusterSpace(getAddr()))
286 return emitError("mbarrier in shared_cluster space is not supported when "
287 "using predicate");
288
289 if (getRes())
290 return emitError("return-value is not supported when using predicate");
291
    if (getRelaxed())
293 return emitError("mbarrier with relaxed semantics is not supported when "
294 "using predicate");
295 }
296 return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope(),
297 getRes());
298}
299
300LogicalResult MBarrierArriveDropExpectTxOp::verify() {
301 return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope(),
302 getRes());
303}
304
305LogicalResult MBarrierExpectTxOp::verify() {
306 return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope());
307}
308
309LogicalResult MBarrierCompleteTxOp::verify() {
310 return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope());
311}
312
313LogicalResult MBarrierTestWaitOp::verify() {
314 return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope());
315}
316
317LogicalResult MBarrierTryWaitOp::verify() {
318 return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope());
319}
320
321LogicalResult ConvertFloatToTF32Op::verify() {
322 using RndMode = NVVM::FPRoundingMode;
323 switch (getRnd()) {
324 case RndMode::RNA:
325 if (getRelu())
326 return emitError("Relu not supported with rna rounding mode.");
327 break;
328 case RndMode::RN:
329 case RndMode::RZ:
330 break;
331 default:
332 return emitError(
333 "Only {rn,rz,rna} rounding modes supported for ConvertFloatToTF32Op.");
334 }
335 return success();
336}
337
338LogicalResult ConvertF32x2ToF6x2Op::verify() {
  MLIRContext *ctx = getContext();

341 if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy())) {
342 return emitOpError("Only ")
343 << mlir::Float6E2M3FNType::get(ctx) << " and "
344 << mlir::Float6E3M2FNType::get(ctx)
345 << " types are supported for conversions from f32x2 to f6x2.";
346 }
347 return success();
348}
349
350LogicalResult ConvertF32x2ToF8x2Op::verify() {
351 using RndMode = NVVM::FPRoundingMode;
352 using SatMode = NVVM::SaturationMode;
353
354 bool isRoundingModeRN = getRnd() == RndMode::RN;
355 bool isRoundingModeRZ = getRnd() == RndMode::RZ;
356 bool isRoundingModeRP = getRnd() == RndMode::RP;
357 bool isSatFinite = getSat() == SatMode::SATFINITE;
358
359 bool hasRelu = getRelu();

  MLIRContext *ctx = getContext();

  return llvm::TypeSwitch<Type, LogicalResult>(getDstTy())
      .Case<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(
365 [&](mlir::Type) -> LogicalResult {
366 if (!isRoundingModeRN) {
367 return emitOpError("Only RN rounding mode is supported for "
368 "conversions from f32x2 to ")
369 << mlir::Float8E4M3FNType::get(ctx) << " and "
370 << mlir::Float8E5M2Type::get(ctx) << " types";
371 }
372 if (!isSatFinite) {
373 return emitOpError("Only SATFINITE saturation mode is supported "
374 "for conversions "
375 "from f32x2 to ")
376 << mlir::Float8E4M3FNType::get(ctx) << " and "
377 << mlir::Float8E5M2Type::get(ctx) << " types";
378 }
379 return success();
380 })
381 .Case<mlir::Float8E8M0FNUType>([&](mlir::Type) -> LogicalResult {
382 if (!(isRoundingModeRZ || isRoundingModeRP)) {
383 return emitOpError("Only RZ and RP rounding modes are supported for "
384 "conversions from f32x2 to ")
385 << mlir::Float8E8M0FNUType::get(ctx) << " type";
386 }
387 if (hasRelu) {
388 return emitOpError("relu not supported for conversions to ")
389 << mlir::Float8E8M0FNUType::get(ctx) << " type";
390 }
391 return success();
392 })
393 .Default([&](mlir::Type) {
394 return emitOpError("Only ")
395 << mlir::Float8E4M3FNType::get(ctx) << ", "
396 << mlir::Float8E5M2Type::get(ctx) << ", and "
397 << mlir::Float8E8M0FNUType::get(ctx)
398 << " types are "
399 "supported for conversions from f32x2 to f8x2";
400 });
401}
402
403LogicalResult ConvertF16x2ToF8x2Op::verify() {
  MLIRContext *ctx = getContext();

406 if (!llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(getDstTy())) {
407 return emitOpError("Only ")
408 << mlir::Float8E4M3FNType::get(ctx) << " and "
409 << mlir::Float8E5M2Type::get(ctx)
410 << " types are supported for conversions from f16x2 to f8x2.";
411 }
412 return success();
413}
414
415LogicalResult ConvertBF16x2ToF8x2Op::verify() {
416 using RndMode = NVVM::FPRoundingMode;
417
418 if (!llvm::isa<mlir::Float8E8M0FNUType>(getDstTy()))
419 return emitOpError("Only ") << mlir::Float8E8M0FNUType::get(getContext())
420 << " type is supported for conversions from "
421 "bf16x2 to f8x2.";
422
423 auto rnd = getRnd();
424 if (rnd != RndMode::RZ && rnd != RndMode::RP)
425 return emitOpError("Only RZ and RP rounding modes are supported for "
426 "conversions from bf16x2 to f8x2.");
427
428 return success();
429}
430
431LogicalResult ConvertF32x2ToF4x2Op::verify() {
  MLIRContext *ctx = getContext();

434 if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
435 return emitOpError("Only ")
436 << mlir::Float4E2M1FNType::get(ctx)
437 << " type is supported for conversions from f32x2 to f4x2.";
438
439 return success();
440}
441
442LogicalResult ConvertF8x2ToF16x2Op::verify() {
  MLIRContext *ctx = getContext();

445 if (!llvm::isa<Float8E4M3FNType, Float8E5M2Type>(getSrcType()))
446 return emitOpError("Only ")
447 << mlir::Float8E4M3FNType::get(ctx) << " and "
448 << mlir::Float8E5M2Type::get(ctx)
449 << " types are supported for conversions from f8x2 to f16x2.";
450
451 return success();
452}
453
454LogicalResult ConvertF8x2ToBF16x2Op::verify() {
  MLIRContext *ctx = getContext();
  if (!llvm::isa<Float8E8M0FNUType>(getSrcType()))
457 return emitOpError("Only ")
458 << mlir::Float8E8M0FNUType::get(ctx)
459 << " type is supported for conversions from f8x2 to bf16x2.";
460
461 return success();
462}
463
464LogicalResult ConvertF6x2ToF16x2Op::verify() {
  MLIRContext *ctx = getContext();

467 if (!llvm::isa<Float6E2M3FNType, Float6E3M2FNType>(getSrcType()))
468 return emitOpError("Only ")
469 << mlir::Float6E2M3FNType::get(ctx) << " and "
470 << mlir::Float6E3M2FNType::get(ctx)
471 << " types are supported for conversions from f6x2 to f16x2.";
472
473 return success();
474}
475
476LogicalResult ConvertF4x2ToF16x2Op::verify() {
  MLIRContext *ctx = getContext();

479 if (!llvm::isa<Float4E2M1FNType>(getSrcType()))
480 return emitOpError("Only ")
481 << mlir::Float4E2M1FNType::get(ctx)
482 << " type is supported for conversions from f4x2 to f16x2.";
483
484 return success();
485}
486
487LogicalResult PermuteOp::verify() {
488 using Mode = NVVM::PermuteMode;
489 bool hasHi = static_cast<bool>(getHi());
490
491 switch (getMode()) {
492 case Mode::DEFAULT:
493 case Mode::F4E:
494 case Mode::B4E:
495 if (!hasHi)
496 return emitError("mode '")
497 << stringifyPermuteMode(getMode()) << "' requires 'hi' operand.";
498 break;
499 case Mode::RC8:
500 case Mode::ECL:
501 case Mode::ECR:
502 case Mode::RC16:
503 if (hasHi)
504 return emitError("mode '") << stringifyPermuteMode(getMode())
505 << "' does not accept 'hi' operand.";
506 break;
507 }
508
509 return success();
510}
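
// Mode/operand pairing enforced above:
//   DEFAULT, F4E, B4E    -> require the 'hi' operand
//   RC8, ECL, ECR, RC16  -> must not take a 'hi' operand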
511
512//===----------------------------------------------------------------------===//
513// Stochastic Rounding Conversion Ops
514//===----------------------------------------------------------------------===//
515
516static LogicalResult verifyConvertF32x2ToFP16x2Op(Twine dstType,
517 FPRoundingMode rnd,
518 bool hasRandomBits,
519 Operation *op) {
520 static constexpr FPRoundingMode validRndModes[] = {
521 FPRoundingMode::RN, FPRoundingMode::RZ, FPRoundingMode::RS};
522
523 if (!llvm::is_contained(validRndModes, rnd)) {
524 return op->emitOpError(
525 "Only RN, RZ, and RS rounding modes are supported for "
526 "conversions from f32x2 to ")
527 << dstType << ".";
528 }
529
530 if (rnd == FPRoundingMode::RS) {
531 if (!hasRandomBits) {
532 return op->emitOpError("random_bits is required for RS rounding mode.");
533 }
534 } else {
535 if (hasRandomBits) {
536 return op->emitOpError(
537 "random_bits not supported for RN and RZ rounding modes.");
538 }
539 }
540
541 return success();
542}
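
// Valid combinations checked above (rounding mode x random_bits):
//   RN or RZ -> random_bits must be absent
//   RS       -> random_bits must be present
//   any other rounding mode is rejected outright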
543
544LogicalResult ConvertF32x2ToF16x2Op::verify() {
  return verifyConvertF32x2ToFP16x2Op("f16x2", getRnd(),
                                      static_cast<bool>(getRandomBits()),
                                      *this);
547}
548
549LogicalResult ConvertF32x2ToBF16x2Op::verify() {
  return verifyConvertF32x2ToFP16x2Op("bf16x2", getRnd(),
                                      static_cast<bool>(getRandomBits()),
                                      *this);
552}
553
554LogicalResult ConvertF32x4ToF8x4Op::verify() {
  MLIRContext *ctx = getContext();

557 if (!llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(getDstTy()))
558 return emitOpError("Only ")
559 << mlir::Float8E4M3FNType::get(ctx) << " and "
560 << mlir::Float8E5M2Type::get(ctx)
561 << " types are supported for conversions from f32x4 to f8x4.";
562
563 return success();
564}
565
566LogicalResult ConvertF32x4ToF6x4Op::verify() {
  MLIRContext *ctx = getContext();

569 if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy()))
570 return emitOpError("Only ")
571 << mlir::Float6E2M3FNType::get(ctx) << " and "
572 << mlir::Float6E3M2FNType::get(ctx)
573 << " types are supported for conversions from f32x4 to f6x4.";
574
575 return success();
576}
577
578LogicalResult ConvertF32x4ToF4x4Op::verify() {
  MLIRContext *ctx = getContext();

581 if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
582 return emitOpError("Only ") << mlir::Float4E2M1FNType::get(ctx)
583 << " type is supported for conversions from "
584 "f32x4 to f4x4.";
585
586 return success();
587}
588
589LogicalResult BulkStoreOp::verify() {
590 if (getInitVal() != 0)
591 return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
592 return success();
593}
594
595LogicalResult PMEventOp::verify() {
596 auto eventId = getEventId();
597 auto maskedEventId = getMaskedEventId();
598 if (!maskedEventId && !eventId) {
599 return emitOpError() << "either `id` or `mask` must be set";
600 }
601
602 if (maskedEventId && eventId) {
603 return emitOpError() << "`id` and `mask` cannot be set at the same time";
604 }
605
606 if (eventId) {
607 if (eventId < 0 || eventId > 15) {
608 return emitOpError() << "`id` must be between 0 and 15";
609 }
610 }
611
612 return llvm::success();
613}
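
// Examples of the constraints above: an `id` of 7 with no mask is valid, a
// `mask` with no `id` is valid, supplying both (or neither) is rejected, and
// an `id` of 16 is outside the accepted 0-15 range.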
614
615// Given the element type of an operand and whether or not it is an accumulator,
616// this function returns the PTX type (`NVVM::MMATypes`) that corresponds to the
617// operand's element type.
618std::optional<mlir::NVVM::MMATypes>
619MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) {
620 auto half2Type =
621 VectorType::get(2, Float16Type::get(operandElType.getContext()));
622 if (operandElType.isF64())
623 return NVVM::MMATypes::f64;
624 if (operandElType.isF16() || operandElType == half2Type)
625 return NVVM::MMATypes::f16;
626 if (operandElType.isF32() && isAccumulator)
627 return NVVM::MMATypes::f32;
628 if (operandElType.isF32() && !isAccumulator)
629 return NVVM::MMATypes::tf32;
630 if (llvm::isa<IntegerType>(operandElType)) {
631 if (isAccumulator)
632 return NVVM::MMATypes::s32;
633 return std::nullopt;
634 }
635
636 if (auto structType = llvm::dyn_cast<LLVM::LLVMStructType>(operandElType)) {
637 if (structType.getBody().empty())
638 return std::nullopt;
639 return inferOperandMMAType(structType.getBody()[0], isAccumulator);
640 }
641
642 return std::nullopt;
643}
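
// Examples of the inference above:
//   f64                          -> MMATypes::f64
//   f16 or vector<2xf16>         -> MMATypes::f16
//   f32 (accumulator)            -> MMATypes::f32
//   f32 (multiplicand)           -> MMATypes::tf32
//   i32 (accumulator)            -> MMATypes::s32
//   integer multiplicand         -> std::nullopt (must be given explicitly)
//   !llvm.struct<(...)>          -> inferred from the first struct element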
644
645static bool isInt4PtxType(MMATypes type) {
646 return (type == MMATypes::u4 || type == MMATypes::s4);
647}
648
649static bool isInt8PtxType(MMATypes type) {
650 return (type == MMATypes::u8 || type == MMATypes::s8);
651}
652
653static bool isIntegerPtxType(MMATypes type) {
654 return isInt4PtxType(type) || isInt8PtxType(type) || type == MMATypes::b1 ||
655 type == MMATypes::s32;
656}
657
658MMATypes MmaOp::accumPtxType() {
659 std::optional<mlir::NVVM::MMATypes> val = inferOperandMMAType(
660 getODSOperands(2).getTypes().front(), /*isAccumulator=*/true);
661 assert(val.has_value() && "accumulator PTX type should always be inferrable");
662 return val.value();
663}
664
665MMATypes MmaOp::resultPtxType() {
666 std::optional<mlir::NVVM::MMATypes> val =
667 inferOperandMMAType(getResult().getType(), /*isAccumulator=*/true);
668 assert(val.has_value() && "result PTX type should always be inferrable");
669 return val.value();
670}
671
672void MmaOp::print(OpAsmPrinter &p) {
673 SmallVector<Type, 4> regTypes;
674 struct MMAOperandFragment {
675 StringRef operandName;
676 StringRef ptxTypeAttr;
677 SmallVector<Value, 4> regs;
678 explicit MMAOperandFragment(StringRef name, StringRef ptxTypeName)
679 : operandName(name), ptxTypeAttr(ptxTypeName) {}
680 };
681
682 std::array<MMAOperandFragment, 3> frags{
683 MMAOperandFragment("A", getMultiplicandAPtxTypeAttrName()),
684 MMAOperandFragment("B", getMultiplicandBPtxTypeAttrName()),
685 MMAOperandFragment("C", "")};
686 SmallVector<StringRef, 4> ignoreAttrNames{
687 mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()};
688
689 for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
690 auto &frag = frags[fragIdx];
691 auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
692 for (auto operandIdx = varOperandSpec.first;
693 operandIdx < varOperandSpec.first + varOperandSpec.second;
694 operandIdx++) {
695 frag.regs.push_back(this->getOperand(operandIdx));
696 if (operandIdx == 0) {
697 regTypes.push_back(this->getOperand(operandIdx).getType());
698 }
699 }
700 std::optional<MMATypes> inferredType = MmaOp::inferOperandMMAType(
701 regTypes.back(), /*isAccumulator=*/fragIdx >= 2);
702 if (inferredType)
703 ignoreAttrNames.push_back(frag.ptxTypeAttr);
704 }
705
706 auto printMmaOperand = [&](const MMAOperandFragment &frag) -> void {
707 p << " " << frag.operandName;
708 p << "[";
709 p.printOperands(frag.regs);
710 p << "] ";
711 };
712
713 for (const auto &frag : frags) {
714 printMmaOperand(frag);
715 }
716
717 p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);
718
719 // Print the types of the operands and result.
720 p << " : "
721 << "(";
722 llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
723 frags[1].regs[0].getType(),
724 frags[2].regs[0].getType()},
725 p);
726 p << ")";
727 p.printArrowTypeList(TypeRange{this->getRes().getType()});
728}
729
730void MmaOp::build(OpBuilder &builder, OperationState &result, Type resultType,
731 ValueRange operandA, ValueRange operandB, ValueRange operandC,
732 ArrayRef<int64_t> shape, std::optional<MMAB1Op> b1Op,
733 std::optional<MMAIntOverflow> intOverflow,
734 std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
735 std::optional<std::array<MMALayout, 2>> multiplicandLayouts) {
736
737 assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
738 MLIRContext *ctx = builder.getContext();
739 result.addAttribute(
740 "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));
741
742 result.addOperands(operandA);
743 result.addOperands(operandB);
744 result.addOperands(operandC);
745
746 if (multiplicandPtxTypes) {
747 result.addAttribute("multiplicandAPtxType",
748 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
749 result.addAttribute("multiplicandBPtxType",
750 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
751 } else {
752 if (auto res = inferOperandMMAType(operandA[0].getType(), false))
753 result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
754 if (auto res = inferOperandMMAType(operandB[0].getType(), false))
755 result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
756 }
757
758 if (multiplicandLayouts) {
759 result.addAttribute("layoutA",
760 MMALayoutAttr::get(ctx, (*multiplicandLayouts)[0]));
761 result.addAttribute("layoutB",
762 MMALayoutAttr::get(ctx, (*multiplicandLayouts)[1]));
763 } else {
764 result.addAttribute("layoutA", MMALayoutAttr::get(ctx, MMALayout::row));
765 result.addAttribute("layoutB", MMALayoutAttr::get(ctx, MMALayout::col));
766 }
767
768 if (intOverflow.has_value())
769 result.addAttribute("intOverflowBehavior",
770 MMAIntOverflowAttr::get(ctx, *intOverflow));
771 if (b1Op.has_value())
772 result.addAttribute("b1Op", MMAB1OpAttr::get(ctx, *b1Op));
773
774 result.addTypes(resultType);
775 result.addAttribute(
776 MmaOp::getOperandSegmentSizeAttr(),
777 builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()),
778 static_cast<int32_t>(operandB.size()),
779 static_cast<int32_t>(operandC.size())}));
780}
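
// A caller could drive the builder above along these lines (a sketch,
// assuming an OpBuilder `b`, a result struct type, and pre-created operand
// Values a0..a3, b0..b1, c0..c1):
//
//   b.create<NVVM::MmaOp>(loc, resultStructTy, ValueRange{a0, a1, a2, a3},
//                         ValueRange{b0, b1}, ValueRange{c0, c1},
//                         /*shape=*/ArrayRef<int64_t>{16, 8, 16},
//                         /*b1Op=*/std::nullopt, /*intOverflow=*/std::nullopt,
//                         /*multiplicandPtxTypes=*/std::nullopt,
//                         /*multiplicandLayouts=*/std::nullopt);
//
// in which case the PTX types are inferred and the layouts default to
// row/col as implemented above.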
781
782// <operation> :=
783// A `[` $operandA `]` B `[` $operandB `]` C `[` $operandC `]`
784// attr-dict : (type($operandA[0]), type($operandB[0]), type($operandC[0]))
785// `->` type($res)
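//
// For example, an m16n8k16 f16 MMA roughly reads as:
//
//   %d = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1]
//          {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
//           shape = #nvvm.shape<m = 16, n = 8, k = 16>}
//          : (vector<2xf16>, vector<2xf16>, vector<2xf16>)
//            -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
//
// (the exact attribute syntax follows the shape/layout attribute definitions
// in the ODS file).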
786ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) {
787 struct MMAOperandFragment {
788 std::optional<MMATypes> elemtype;
789 SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
790 SmallVector<Type> regTypes;
791 };
792
793 Builder &builder = parser.getBuilder();
794 std::array<MMAOperandFragment, 4> frags;
795
796 NamedAttrList namedAttributes;
797
798 // A helper to parse the operand segments.
799 auto parseMmaOperand = [&](StringRef operandName,
800 MMAOperandFragment &frag) -> LogicalResult {
801 if (parser.parseKeyword(operandName).failed())
802 return failure();
803 if (parser
804 .parseOperandList(frag.regs, OpAsmParser::Delimiter::OptionalSquare)
805 .failed())
806 return failure();
807 return success();
808 };
809
810 // Parse the operand segments.
811 if (parseMmaOperand("A", frags[0]).failed())
812 return failure();
813 if (parseMmaOperand("B", frags[1]).failed())
814 return failure();
815 if (parseMmaOperand("C", frags[2]).failed())
816 return failure();
817
818 if (parser.parseOptionalAttrDict(namedAttributes).failed())
819 return failure();
820
821 // Parse the type specification and resolve operands.
822 SmallVector<Type, 3> operandTypes;
823 if (failed(parser.parseColon()))
824 return failure();
825 if (failed(parser.parseLParen()))
826 return failure();
827 if (failed(parser.parseTypeList(operandTypes)))
828 return failure();
  if (failed(parser.parseRParen()))
    return failure();
  if (operandTypes.size() != 3)
831 return parser.emitError(
832 parser.getNameLoc(),
833 "expected one type for each operand segment but got " +
834 Twine(operandTypes.size()) + " types");
835 for (const auto &iter : llvm::enumerate(operandTypes)) {
836 auto &frag = frags[iter.index()];
837 frag.regTypes.resize(frag.regs.size(), iter.value());
838 if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
839 parser.getNameLoc(), result.operands)))
840 return failure();
    frag.elemtype = inferOperandMMAType(frag.regTypes[0],
                                        /*isAccumulator=*/iter.index() >= 2);
843 }
844
845 Type resultType;
846 if (parser.parseArrow() || parser.parseType(resultType))
847 return failure();
848 frags[3].elemtype = inferOperandMMAType(resultType, /*isAccumulator*/ true);
849
850 std::array<StringRef, 2> names{"multiplicandAPtxType",
851 "multiplicandBPtxType"};
852 for (unsigned idx = 0; idx < names.size(); idx++) {
853 const auto &frag = frags[idx];
854 std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
855 if (!frag.elemtype.has_value() && !attr.has_value()) {
856 return parser.emitError(
857 parser.getNameLoc(),
858 "attribute " + names[idx] +
859 " is not provided explicitly and cannot be inferred");
860 }
861 if (!attr.has_value())
862 result.addAttribute(
863 names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));
864 }
865
866 result.addTypes(resultType);
867 if (!namedAttributes.empty())
868 result.addAttributes(namedAttributes);
869 result.addAttribute(MmaOp::getOperandSegmentSizeAttr(),
870 builder.getDenseI32ArrayAttr({
871 static_cast<int32_t>(frags[0].regs.size()),
872 static_cast<int32_t>(frags[1].regs.size()),
873 static_cast<int32_t>(frags[2].regs.size()),
874 }));
875 return success();
876}
877
878LogicalResult MmaOp::verify() {
879 MLIRContext *context = getContext();
880 auto f16Ty = Float16Type::get(context);
881 auto i32Ty = IntegerType::get(context, 32);
882 auto f16x2Ty = VectorType::get(2, f16Ty);
883 auto f32Ty = Float32Type::get(context);
884 auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
885 context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
886
887 auto s32x4StructTy =
888 LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
889 auto f32x8StructTy =
890 LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
891 auto f16x2x2StructTy =
892 LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
893 auto f32x4StructTy =
894 LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
895 auto s32x2StructTy =
896 LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});
897
898 std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
899 getShapeAttr().getK()};
900
901 // These variables define the set of allowed data types for matrices A, B, C,
902 // and result.
903 using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
904 using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>;
905 AllowedShapes allowedShapes;
906 AllowedTypes expectedA;
907 AllowedTypes expectedB;
908 AllowedTypes expectedC;
909 SmallVector<Type> expectedResult;
910
911 // When M = 16, we just need to calculate the number of 8xk tiles, where
912 // k is a factor that depends on the data type.
913 if (mmaShape[0] == 16) {
914 int64_t kFactor;
915 Type multiplicandFragType;
916 switch (*getMultiplicandAPtxType()) {
917 case MMATypes::tf32:
918 kFactor = 4;
919 multiplicandFragType = i32Ty;
920 expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
921 context, {f32Ty, f32Ty, f32Ty, f32Ty}));
922 break;
923 case MMATypes::bf16:
924 kFactor = 8;
925 multiplicandFragType = i32Ty;
926 expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
927 context, {f32Ty, f32Ty, f32Ty, f32Ty}));
928 break;
929 case MMATypes::f16:
930 kFactor = 8;
931 multiplicandFragType = f16x2Ty;
932 expectedResult.push_back(f16x2x2StructTy);
933 expectedResult.push_back(f32x4StructTy);
934 break;
935 case MMATypes::s4:
936 case MMATypes::u4:
937 kFactor = 32;
938 break;
939 case MMATypes::b1:
940 kFactor = 128;
941 break;
942 case MMATypes::s8:
943 case MMATypes::u8:
944 kFactor = 16;
945 break;
946 default:
947 return emitError("invalid shape or multiplicand type: " +
948 stringifyEnum(getMultiplicandAPtxType().value()));
949 }
950
951 if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
952 expectedResult.push_back(s32x4StructTy);
953 expectedC.emplace_back(4, i32Ty);
954 multiplicandFragType = i32Ty;
955 } else {
956 expectedC.emplace_back(2, f16x2Ty);
957 expectedC.emplace_back(4, f32Ty);
958 }
959
960 int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
961 int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
962 expectedA.emplace_back(unitA, multiplicandFragType);
963 expectedB.emplace_back(unitB, multiplicandFragType);
964 allowedShapes.push_back({16, 8, kFactor});
965 allowedShapes.push_back({16, 8, kFactor * 2});
966
967 if (resultPtxType() != accumPtxType())
968 return emitOpError("ctype does not match dtype");
969 }
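
  // Worked example of the fragment math above: for m16n8k16 with f16
  // multiplicands, kFactor = 8, so unitA = (16/8) * (16/8) = 4 and
  // unitB = (8/8) * (16/8) = 2, i.e. A must supply 4 vector<2xf16> registers
  // and B must supply 2, with C being either 2 x vector<2xf16> or 4 x f32.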
970
971 // In the M=8 case, there is only 1 possible case per data type.
972 if (mmaShape[0] == 8) {
973 if (*getMultiplicandAPtxType() == MMATypes::f16) {
974 expectedA.emplace_back(2, f16x2Ty);
975 expectedB.emplace_back(2, f16x2Ty);
976 expectedResult.push_back(f16x2x4StructTy);
977 expectedResult.push_back(f32x8StructTy);
978 expectedC.emplace_back(4, f16x2Ty);
979 expectedC.emplace_back(8, f32Ty);
980 allowedShapes.push_back({8, 8, 4});
981 }
982 if (*getMultiplicandAPtxType() == MMATypes::f64) {
983 Type f64Ty = Float64Type::get(context);
984 expectedA.emplace_back(1, f64Ty);
985 expectedB.emplace_back(1, f64Ty);
986 expectedC.emplace_back(2, f64Ty);
987 expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
988 context, SmallVector<Type>(2, f64Ty)));
989 allowedShapes.push_back({8, 8, 4});
990 }
991 if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
992 expectedA.push_back({i32Ty});
993 expectedB.push_back({i32Ty});
994 expectedC.push_back({i32Ty, i32Ty});
995 expectedResult.push_back(s32x2StructTy);
996 if (isInt4PtxType(getMultiplicandAPtxType().value()))
997 allowedShapes.push_back({8, 8, 32});
998 if (isInt8PtxType(getMultiplicandAPtxType().value()))
999 allowedShapes.push_back({8, 8, 16});
1000 if (getMultiplicandAPtxType().value() == MMATypes::b1)
1001 allowedShapes.push_back({8, 8, 128});
1002 }
1003 }
1004
1005 std::string errorMessage;
1006 llvm::raw_string_ostream errorStream(errorMessage);
1007
1008 // Check that we matched an existing shape/dtype combination.
1009 if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
1010 !llvm::is_contained(allowedShapes, mmaShape)) {
1011 errorStream << "unimplemented variant for MMA shape <";
1012 llvm::interleaveComma(mmaShape, errorStream);
1013 errorStream << ">";
1014 return emitOpError(errorMessage);
1015 }
1016
1017 // Verify the operand types for segments of A, B, and C operands.
1018 std::array<StringRef, 3> operandNames{"A", "B", "C"};
1019 for (const auto &iter : llvm::enumerate(
1020 SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) {
1021 auto spec = this->getODSOperandIndexAndLength(iter.index());
1022 SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first,
1023 operand_type_begin() + spec.first +
1024 spec.second);
1025 bool match = llvm::is_contained(iter.value(), operandTySeg);
1026
1027 if (!match) {
1028 errorStream << "Could not match types for the "
1029 << operandNames[iter.index()]
1030 << " operands; expected one of ";
1031 for (const auto &x : iter.value()) {
1032 errorStream << x.size() << "x" << x[0] << " ";
1033 }
1034 errorStream << "but got ";
1035 llvm::interleaveComma(operandTySeg, errorStream);
1036 return emitOpError(errorMessage);
1037 }
1038 }
1039
1040 // Check the result type
1041 if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
1042 return expectedResultType == getResult().getType();
1043 })) {
1044 errorStream
1045 << "Could not match allowed types for the result; expected one of ";
1046 llvm::interleaveComma(expectedResult, errorStream);
1047 errorStream << " but got " << getResult().getType();
1048 return emitOpError(errorMessage);
1049 }
1050
1051 // Ensure that binary MMA variants have a b1 MMA operation defined.
1052 if (getMultiplicandAPtxType() == MMATypes::b1 && !getB1Op()) {
1053 return emitOpError("op requires " + getB1OpAttrName().strref() +
1054 " attribute");
1055 }
1056
1057 // Ensure int4/int8 MMA variants specify the accum overflow behavior
1058 // attribute.
1059 if (isInt4PtxType(*getMultiplicandAPtxType()) ||
1060 isInt8PtxType(*getMultiplicandAPtxType())) {
1061 if (!getIntOverflowBehavior())
1062 return emitOpError("op requires " +
1063 getIntOverflowBehaviorAttrName().strref() +
1064 " attribute");
1065 }
1066
1067 // Validate layout combinations. According to the operation description, most
1068 // MMA operations require layoutA=row and layoutB=col. Only m8n8k4 with f16
1069 // can use other layout combinations.
1070 bool isM8N8K4_F16 =
1071 (mmaShape[0] == 8 && mmaShape[1] == 8 && mmaShape[2] == 4 &&
1072 getMultiplicandAPtxType() == MMATypes::f16);
1073
1074 if (!isM8N8K4_F16) {
1075 // For all other shapes/types, layoutA must be row and layoutB must be col
1076 if (getLayoutA() != MMALayout::row || getLayoutB() != MMALayout::col) {
1077 return emitOpError("requires layoutA = #nvvm.mma_layout<row> and "
1078 "layoutB = #nvvm.mma_layout<col> for shape <")
1079 << mmaShape[0] << ", " << mmaShape[1] << ", " << mmaShape[2]
1080 << "> with element types "
1081 << stringifyEnum(*getMultiplicandAPtxType()) << " and "
1082 << stringifyEnum(*getMultiplicandBPtxType())
1083 << ". Only m8n8k4 with f16 supports other layouts.";
1084 }
1085 }
1086
1087 return success();
1088}
1089
1090MMATypes MmaSpOp::accumPtxType() {
1091 std::optional<mlir::NVVM::MMATypes> val = MmaOp::inferOperandMMAType(
1092 getODSOperands(2).getTypes().front(), /*isAccumulator=*/true);
1093 assert(val.has_value() && "accumulator PTX type should always be inferrable");
1094 return val.value();
1095}
1096
1097MMATypes MmaSpOp::resultPtxType() {
1098 std::optional<mlir::NVVM::MMATypes> val =
1099 MmaOp::inferOperandMMAType(getResult().getType(), /*isAccumulator=*/true);
1100 assert(val.has_value() && "result PTX type should always be inferrable");
1101 return val.value();
1102}
1103
mlir::NVVM::IDArgPair
MmaSpOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
                               llvm::IRBuilderBase &builder) {
1107 auto thisOp = cast<NVVM::MmaSpOp>(op);
1108
  // Get operands
  llvm::SmallVector<llvm::Value *> args;
  for (mlir::Value v : thisOp.getOperands())
    args.push_back(mt.lookupValue(v));
1113
1114 // Get intrinsic ID using the existing getIntrinsicID method
1115 auto intId = MmaSpOp::getIntrinsicID(
1116 thisOp.getShape().getM(), thisOp.getShape().getN(),
1117 thisOp.getShape().getK(), thisOp.getIntOverflowBehavior(),
1118 thisOp.getOrderedMetadata(), thisOp.getKind(),
1119 *thisOp.getMultiplicandAPtxType(), *thisOp.getMultiplicandBPtxType(),
1120 thisOp.accumPtxType(), thisOp.resultPtxType());
1121
1122 return {intId, args};
1123}
1124
1125void MmaSpOp::print(OpAsmPrinter &p) {
1126 SmallVector<Type, 4> regTypes;
1127 struct MMAOperandFragment {
1128 StringRef operandName;
1129 StringRef ptxTypeAttr;
1130 SmallVector<Value, 4> regs;
1131 explicit MMAOperandFragment(StringRef name, StringRef ptxTypeName)
1132 : operandName(name), ptxTypeAttr(ptxTypeName) {}
1133 };
1134
1135 std::array<MMAOperandFragment, 5> frags{
1136 MMAOperandFragment("A", getMultiplicandAPtxTypeAttrName()),
1137 MMAOperandFragment("B", getMultiplicandBPtxTypeAttrName()),
1138 MMAOperandFragment("C", ""), MMAOperandFragment("sparseMetadata", ""),
1139 MMAOperandFragment("selector", "")};
1140 SmallVector<StringRef, 4> ignoreAttrNames{
1141 mlir::NVVM::MmaSpOp::getOperandSegmentSizeAttr()};
1142
1143 // Handle variadic operands A, B, C
1144 for (unsigned fragIdx = 0; fragIdx < 3; fragIdx++) {
1145 auto &frag = frags[fragIdx];
1146 auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
1147 for (auto operandIdx = varOperandSpec.first;
1148 operandIdx < varOperandSpec.first + varOperandSpec.second;
1149 operandIdx++) {
1150 frag.regs.push_back(this->getOperand(operandIdx));
1151 if (operandIdx == varOperandSpec.first) {
1152 regTypes.push_back(this->getOperand(operandIdx).getType());
1153 }
1154 }
1155 std::optional<MMATypes> inferredType = MmaOp::inferOperandMMAType(
1156 regTypes.back(), /*isAccumulator=*/fragIdx >= 2);
1157 if (inferredType)
1158 ignoreAttrNames.push_back(frag.ptxTypeAttr);
1159 }
1160
1161 // Handle sparse metadata and selector (single operands)
1162 frags[3].regs.push_back(getSparseMetadata());
1163 frags[4].regs.push_back(getSparsitySelector());
1164
1165 auto printMmaSpOperand = [&](const MMAOperandFragment &frag) -> void {
1166 p << " " << frag.operandName;
1167 p << "[";
1168 p.printOperands(frag.regs);
1169 p << "]";
1170 };
1171
1172 for (const auto &frag : frags)
1173 printMmaSpOperand(frag);
1174
1175 p.printOptionalAttrDict((*this)->getAttrs(), ignoreAttrNames);
1176 p << " : ";
1177 p << "(";
1178 for (int i = 0; i < 3; ++i) {
1179 p << regTypes[i];
1180 if (i < 2)
1181 p << ", ";
1182 }
1183 p << ") -> " << getResult().getType();
1184}
1185
1186void MmaSpOp::build(
1187 OpBuilder &builder, OperationState &result, Type resultType,
1188 ValueRange operandA, ValueRange operandB, ValueRange operandC,
1189 Value sparseMetadata, Value sparsitySelector, ArrayRef<int64_t> shape,
1190 std::optional<MMAIntOverflow> intOverflow,
1191 std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes) {
1192
1193 assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
1194 MLIRContext *ctx = builder.getContext();
1195 result.addAttribute(
1196 "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));
1197
1198 result.addOperands(operandA);
1199 result.addOperands(operandB);
1200 result.addOperands(operandC);
1201 result.addOperands(sparseMetadata);
1202 result.addOperands(sparsitySelector);
1203
1204 if (multiplicandPtxTypes) {
1205 result.addAttribute("multiplicandAPtxType",
1206 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
1207 result.addAttribute("multiplicandBPtxType",
1208 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
1209 } else {
1210 if (auto res = MmaOp::inferOperandMMAType(operandA[0].getType(), false))
1211 result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
1212 if (auto res = MmaOp::inferOperandMMAType(operandB[0].getType(), false))
1213 result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
1214 }
1215
1216 if (intOverflow.has_value())
1217 result.addAttribute("intOverflowBehavior",
1218 MMAIntOverflowAttr::get(ctx, *intOverflow));
1219
1220 result.addTypes(resultType);
1221 result.addAttribute(
1222 MmaSpOp::getOperandSegmentSizeAttr(),
1223 builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()),
1224 static_cast<int32_t>(operandB.size()),
1225 static_cast<int32_t>(operandC.size()), 1,
1226 1})); // sparseMetadata and sparsitySelector
1227}
1228
1229ParseResult MmaSpOp::parse(OpAsmParser &parser, OperationState &result) {
1230 struct MMAOperandFragment {
1231 std::optional<MMATypes> elemtype;
1232 SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
1233 SmallVector<Type> regTypes;
1234 };
1235
1236 Builder &builder = parser.getBuilder();
1237 std::array<MMAOperandFragment, 6> frags; // A, B, C, sparseMetadata, selector
1238
1239 NamedAttrList namedAttributes;
1240
1241 // A helper to parse the operand segments.
1242 auto parseMmaSpOperand = [&](StringRef operandName,
1243 MMAOperandFragment &frag) -> LogicalResult {
1244 if (parser.parseKeyword(operandName).failed())
1245 return failure();
1246 if (parser
1247 .parseOperandList(frag.regs, OpAsmParser::Delimiter::OptionalSquare)
1248 .failed())
1249 return failure();
1250 return success();
1251 };
1252
1253 // Parse the operand segments.
1254 if (parseMmaSpOperand("A", frags[0]).failed())
1255 return failure();
1256 if (parseMmaSpOperand("B", frags[1]).failed())
1257 return failure();
1258 if (parseMmaSpOperand("C", frags[2]).failed())
1259 return failure();
1260 if (parseMmaSpOperand("sparseMetadata", frags[3]).failed())
1261 return failure();
1262 if (parseMmaSpOperand("selector", frags[4]).failed())
1263 return failure();
1264
1265 if (parser.parseOptionalAttrDict(namedAttributes).failed())
1266 return failure();
1267
1268 // Parse the type specification and resolve operands.
1269 SmallVector<Type, 3> operandTypes;
1270 if (failed(parser.parseColon()))
1271 return failure();
1272 if (failed(parser.parseLParen()))
1273 return failure();
1274 if (failed(parser.parseTypeList(operandTypes)))
1275 return failure();
1276 if (failed(parser.parseRParen()))
1277 return failure();
1278 if (operandTypes.size() != 3)
1279 return parser.emitError(
1280 parser.getNameLoc(),
1281 "expected one type for each operand segment but got " +
1282 Twine(operandTypes.size()) + " types");
1283 for (const auto &iter : llvm::enumerate(operandTypes)) {
1284 auto &frag = frags[iter.index()];
1285 frag.regTypes.resize(frag.regs.size(), iter.value());
1286 if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
1287 parser.getNameLoc(), result.operands)))
1288 return failure();
1289 frag.elemtype =
1290 MmaOp::inferOperandMMAType(frag.regTypes[0],
1291 /*isAccumulator*/ iter.index() >= 2);
1292 }
1293
1294 Type resultType;
1295 if (parser.parseArrow() || parser.parseType(resultType))
1296 return failure();
1297 frags[5].elemtype =
1298 MmaOp::inferOperandMMAType(resultType, /*isAccumulator*/ true);
1299
1300 // Resolve sparse metadata and selector (assume i32 type)
1301 Type i32Type = builder.getIntegerType(32);
1302 if (parser
1303 .resolveOperands(frags[3].regs, i32Type, parser.getCurrentLocation(),
1304 result.operands)
1305 .failed())
1306 return failure();
1307 if (parser
1308 .resolveOperands(frags[4].regs, i32Type, parser.getCurrentLocation(),
1309 result.operands)
1310 .failed())
1311 return failure();
1312
1313 std::array<StringRef, 2> names{"multiplicandAPtxType",
1314 "multiplicandBPtxType"};
1315 for (unsigned idx = 0; idx < names.size(); idx++) {
1316 const auto &frag = frags[idx];
1317 std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
1318 if (!frag.elemtype.has_value() && !attr.has_value()) {
1319 return parser.emitError(
1320 parser.getNameLoc(),
1321 "attribute " + names[idx] +
1322 " is not provided explicitly and cannot be inferred");
1323 }
1324 if (!attr.has_value())
1325 result.addAttribute(
1326 names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));
1327 }
1328
1329 result.addTypes(resultType);
1330 if (!namedAttributes.empty())
1331 result.addAttributes(namedAttributes);
1332 result.addAttribute(MmaSpOp::getOperandSegmentSizeAttr(),
1333 builder.getDenseI32ArrayAttr({
1334 static_cast<int32_t>(frags[0].regs.size()),
1335 static_cast<int32_t>(frags[1].regs.size()),
1336 static_cast<int32_t>(frags[2].regs.size()),
1337 1, // sparseMetadata
1338 1 // sparsitySelector
1339 }));
1340 return success();
1341}
1342
1343LogicalResult MmaSpOp::verify() {
1344 MLIRContext *context = getContext();
1345 auto f16Ty = Float16Type::get(context);
1346 auto i32Ty = IntegerType::get(context, 32);
1347 auto f16x2Ty = VectorType::get(2, f16Ty);
1348 auto f32Ty = Float32Type::get(context);
1349 auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
1350 context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
1351
1352 auto s32x4StructTy =
1353 LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
1354 auto f32x8StructTy =
1355 LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
1356 auto f16x2x2StructTy =
1357 LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
1358 auto f32x4StructTy =
1359 LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
1360 auto s32x2StructTy =
1361 LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});
1362
1363 std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
1364 getShapeAttr().getK()};
1365
1366 // These variables define the set of allowed data types for matrices A, B, C,
1367 // and result.
1368 using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
1369 using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>;
1370 AllowedShapes allowedShapes;
1371 AllowedTypes expectedA;
1372 AllowedTypes expectedB;
1373 AllowedTypes expectedC;
1374 SmallVector<Type> expectedResult;
1375
1376 // When M = 16, we just need to calculate the number of 8xk tiles, where
1377 // k is a factor that depends on the data type.
1378 if (mmaShape[0] == 16) {
1379 int64_t kFactor;
1380 Type multiplicandFragType;
1381 switch (*getMultiplicandAPtxType()) {
1382 case MMATypes::tf32:
1383 kFactor = 4;
1384 multiplicandFragType = i32Ty;
1385 expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
1386 context, {f32Ty, f32Ty, f32Ty, f32Ty}));
1387 // Sparse MMA supports m16n8k8 and m16n8k16 for tf32
1388 allowedShapes.push_back({16, 8, 8});
1389 allowedShapes.push_back({16, 8, 16});
1390 break;
1391 case MMATypes::bf16:
1392 kFactor = 8;
1393 multiplicandFragType = i32Ty;
1394 expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
1395 context, {f32Ty, f32Ty, f32Ty, f32Ty}));
1396 // Sparse MMA supports m16n8k16 and m16n8k32 for bf16
1397 allowedShapes.push_back({16, 8, 16});
1398 allowedShapes.push_back({16, 8, 32});
1399 break;
1400 case MMATypes::f16:
1401 kFactor = 8;
1402 multiplicandFragType = f16x2Ty;
1403 expectedResult.push_back(f16x2x2StructTy);
1404 expectedResult.push_back(f32x4StructTy);
1405 // Sparse MMA supports m16n8k16 and m16n8k32 for f16
1406 allowedShapes.push_back({16, 8, 16});
1407 allowedShapes.push_back({16, 8, 32});
1408 break;
1409 case MMATypes::s4:
1410 case MMATypes::u4:
1411 kFactor = 32;
1412 // Sparse MMA supports m16n8k64 and m16n8k128 for s4/u4
1413 allowedShapes.push_back({16, 8, 64});
1414 allowedShapes.push_back({16, 8, 128});
1415 break;
1416 case MMATypes::s8:
1417 case MMATypes::u8:
1418 kFactor = 16;
1419 // Sparse MMA supports m16n8k32 and m16n8k64 for s8/u8
1420 allowedShapes.push_back({16, 8, 32});
1421 allowedShapes.push_back({16, 8, 64});
1422 break;
1423 case MMATypes::e4m3:
1424 case MMATypes::e5m2:
1425 case MMATypes::e3m2:
1426 case MMATypes::e2m3:
1427 case MMATypes::e2m1:
1428 kFactor = 32;
1429 multiplicandFragType = i32Ty;
1430 expectedResult.push_back(f16x2x2StructTy);
1431 expectedResult.push_back(f32x4StructTy);
1432 // Sparse MMA supports m16n8k64 for FP8 types
1433 allowedShapes.push_back({16, 8, 64});
1434 break;
1435 default:
1436 return emitError("invalid shape or multiplicand type: " +
1437 stringifyEnum(getMultiplicandAPtxType().value()));
1438 }
1439
1440 if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
1441 expectedResult.push_back(s32x4StructTy);
1442 expectedC.emplace_back(4, i32Ty);
1443 multiplicandFragType = i32Ty;
1444 } else if (*getMultiplicandAPtxType() >= MMATypes::e4m3 &&
1445 *getMultiplicandAPtxType() <= MMATypes::e2m1) {
1446 // FP8 types
1447 expectedC.emplace_back(2, f16x2Ty);
1448 expectedC.emplace_back(4, f32Ty);
1449 } else {
1450 expectedC.emplace_back(2, f16x2Ty);
1451 expectedC.emplace_back(4, f32Ty);
1452 }
1453
1454 // For sparse MMA, A operand is compressed (2:4 sparsity means half the
1455 // elements)
1456 int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor) / 2;
1457 int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
1458 expectedA.emplace_back(unitA, multiplicandFragType);
1459 expectedB.emplace_back(unitB, multiplicandFragType);
1460
1461 if (resultPtxType() != accumPtxType())
1462 return emitOpError("ctype does not match dtype");
1463 }
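
  // Worked example of the sparse fragment math above: for m16n8k32 with f16
  // multiplicands, kFactor = 8, so unitA = (16/8) * (32/8) / 2 = 4 (A is
  // compressed 2:4) and unitB = (8/8) * (32/8) = 4.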
1464
1465 // In the M=8 case, there is only 1 possible case per data type.
1466 if (mmaShape[0] == 8) {
1467 if (*getMultiplicandAPtxType() == MMATypes::f16) {
1468 expectedA.emplace_back(2, f16x2Ty);
1469 expectedB.emplace_back(2, f16x2Ty);
1470 expectedResult.push_back(f16x2x4StructTy);
1471 expectedResult.push_back(f32x8StructTy);
1472 expectedC.emplace_back(4, f16x2Ty);
1473 expectedC.emplace_back(8, f32Ty);
1474 allowedShapes.push_back({8, 8, 4});
1475 }
1476 if (*getMultiplicandAPtxType() == MMATypes::f64) {
1477 Type f64Ty = Float64Type::get(context);
1478 expectedA.emplace_back(1, f64Ty);
1479 expectedB.emplace_back(1, f64Ty);
1480 expectedC.emplace_back(2, f64Ty);
1481 expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
1482 context, SmallVector<Type>(2, f64Ty)));
1483 allowedShapes.push_back({8, 8, 4});
1484 }
1485 if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
1486 expectedA.push_back({i32Ty});
1487 expectedB.push_back({i32Ty});
1488 expectedC.push_back({i32Ty, i32Ty});
1489 expectedResult.push_back(s32x2StructTy);
1490 if (isInt4PtxType(getMultiplicandAPtxType().value()))
1491 allowedShapes.push_back({8, 8, 32});
1492 if (isInt8PtxType(getMultiplicandAPtxType().value()))
1493 allowedShapes.push_back({8, 8, 16});
1494 }
1495 }
1496
1497 std::string errorMessage;
1498 llvm::raw_string_ostream errorStream(errorMessage);
1499
1500 // Check that we matched an existing shape/dtype combination.
1501 if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
1502 !llvm::is_contained(allowedShapes, mmaShape)) {
1503 errorStream << "unimplemented variant for MMA shape <";
1504 llvm::interleaveComma(mmaShape, errorStream);
1505 errorStream << ">";
1506 return emitOpError(errorMessage);
1507 }
1508
1509 // Verify the operand types for segments of A, B, and C operands.
1510 std::array<StringRef, 3> operandNames{"A", "B", "C"};
1511 for (const auto &iter : llvm::enumerate(
1512 SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) {
1513 auto spec = this->getODSOperandIndexAndLength(iter.index());
1514 SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first,
1515 operand_type_begin() + spec.first +
1516 spec.second);
1517 bool match = llvm::is_contained(iter.value(), operandTySeg);
1518
1519 if (!match) {
1520 errorStream << "Could not match types for the "
1521 << operandNames[iter.index()]
1522 << " operands; expected one of ";
1523 for (const auto &x : iter.value()) {
1524 errorStream << x.size() << "x" << x[0] << " ";
1525 }
1526 errorStream << "but got ";
1527 llvm::interleaveComma(operandTySeg, errorStream);
1528 return emitOpError(errorMessage);
1529 }
1530 }
1531
1532 // Check the result type
1533 if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
1534 return expectedResultType == getResult().getType();
1535 })) {
1536 errorStream
1537 << "Could not match allowed types for the result; expected one of ";
1538 llvm::interleaveComma(expectedResult, errorStream);
1539 errorStream << " but got " << getResult().getType();
1540 return emitOpError(errorMessage);
1541 }
1542
1543 // Ensure int4/int8 MMA variants specify the accum overflow behavior
1544 // attribute.
1545 if (isInt4PtxType(*getMultiplicandAPtxType()) ||
1546 isInt8PtxType(*getMultiplicandAPtxType())) {
1547 if (!getIntOverflowBehavior())
1548 return emitOpError("op requires " +
1549 getIntOverflowBehaviorAttrName().strref() +
1550 " attribute");
1551 }
1552
1553 // Validate sparse metadata type (should be i32)
1554 if (!getSparseMetadata().getType().isInteger(32)) {
1555 return emitOpError() << "sparse metadata must be i32 type";
1556 }
1557
1558 // Validate sparsity selector type (should be i32)
1559 if (!getSparsitySelector().getType().isInteger(32)) {
1560 return emitOpError() << "sparsity selector must be i32 type";
1561 }
1562
1563 return success();
1564}
1565
1566//===----------------------------------------------------------------------===//
1567// MMA Block Scale Operations - Shared Helpers
1568//===----------------------------------------------------------------------===//
1569
1570namespace {
1571// Shared structure for MMA operand fragments (A, B, C)
1572struct MMAOperandFragment {
1573 StringRef operandName;
1574 StringRef ptxTypeAttr;
1575 SmallVector<Value, 4> regs;
1576 explicit MMAOperandFragment(StringRef name, StringRef ptxTypeName)
1577 : operandName(name), ptxTypeAttr(ptxTypeName) {}
1578};
1579} // namespace
1580
1581// Helper to print operand list in the format: name[operands]
1582static void printOperandList(OpAsmPrinter &p, StringRef name,
1583 ArrayRef<Value> operands) {
1584 p << " " << name << "[";
1585 p.printOperands(operands);
1586 p << "]";
1587}
1588
1589// Helper to parse operand list in the format: name[operands]
static LogicalResult
parseMmaOperand(OpAsmParser &parser, StringRef operandName,
                SmallVectorImpl<OpAsmParser::UnresolvedOperand> &regs) {
  if (parser.parseKeyword(operandName).failed())
    return failure();
  if (parser.parseOperandList(regs, OpAsmParser::Delimiter::OptionalSquare)
          .failed())
    return failure();
  return success();
}
1600
1601// Helper to process operand fragments and determine which attributes can be
1602// inferred
1603template <typename Op>
1604static void
1605processOperandFragments(Op &op, std::array<MMAOperandFragment, 3> &frags,
1606 SmallVectorImpl<Type> &regTypes,
1607 SmallVectorImpl<StringRef> &ignoreAttrNames) {
1608 for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
1609 auto &frag = frags[fragIdx];
1610 auto varOperandSpec = op.getODSOperandIndexAndLength(fragIdx);
1611 for (auto operandIdx = varOperandSpec.first;
1612 operandIdx < varOperandSpec.first + varOperandSpec.second;
1613 operandIdx++) {
1614 frag.regs.push_back(op.getOperand(operandIdx));
1615 if (fragIdx == 0 && operandIdx == varOperandSpec.first) {
1616 regTypes.push_back(op.getOperand(operandIdx).getType());
1617 }
1618 }
1619 if (fragIdx < 2) {
1620 regTypes.push_back(frag.regs[0].getType());
1621 }
1622 std::optional<MMATypes> inferredType =
1623 MmaOp::inferOperandMMAType(regTypes.back(),
1624 /*isAccumulator=*/fragIdx >= 2);
1625 if (inferredType)
1626 ignoreAttrNames.push_back(frag.ptxTypeAttr);
1627 }
1628}
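// Example: if the first register of the A fragment has type vector<2xf16>,
// MmaOp::inferOperandMMAType can recover MMATypes::f16 from it, so the
// corresponding PTX-type attribute is added to `ignoreAttrNames` and elided
// from the printed attribute dictionary; fragments whose PTX type cannot be
// inferred keep their explicit attribute.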
1629
1630// Helper to parse type signature: (A_type, B_type, C_type)
1631static LogicalResult
1632 parseMmaTypeSignature(OpAsmParser &parser,
1633 SmallVectorImpl<Type> &operandTypes) {
1634 if (parser.parseColon().failed() || parser.parseLParen().failed())
1635 return failure();
1636
1637 auto typeParser = [&]() {
1638 Type ty;
1639 if (parser.parseType(ty).failed())
1640 return failure();
1641 operandTypes.push_back(ty);
1642 return success();
1643 };
1644 if (parser.parseCommaSeparatedList(typeParser))
1645 return failure();
1646
1647 if (operandTypes.size() != 3)
1648 return parser.emitError(parser.getCurrentLocation(),
1649 "expected exactly 3 types");
1650
1651 return parser.parseRParen();
1652}
1653
1654// Helper to infer and set multiplicand PTX type attributes
1655static void
1656 inferAndSetMultiplicandTypes(MLIRContext *ctx, NamedAttrList &attrs,
1657 const SmallVectorImpl<Type> &operandTypes) {
1658 if (!attrs.get("multiplicandAPtxType")) {
1659 if (auto inferredType =
1660 MmaOp::inferOperandMMAType(operandTypes[0], false)) {
1661 attrs.set("multiplicandAPtxType", MMATypesAttr::get(ctx, *inferredType));
1662 }
1663 }
1664 if (!attrs.get("multiplicandBPtxType")) {
1665 if (auto inferredType =
1666 MmaOp::inferOperandMMAType(operandTypes[1], false)) {
1667 attrs.set("multiplicandBPtxType", MMATypesAttr::get(ctx, *inferredType));
1668 }
1669 }
1670}
1671
1672// Helper to add common block scale properties
1673template <typename OpType>
1674 static void
1675 addCommonBlockScaleProperties(OpBuilder &builder, OperationState &result,
1676 ArrayRef<int64_t> shape, ScaleVecSize scaleVecSize,
1677 BlockScaleFormat blockScaleFormat,
1678 MMABlockScaleKind kind) {
1679 MLIRContext *ctx = builder.getContext();
1680 auto &properties = result.getOrAddProperties<typename OpType::Properties>();
1681 properties.setShape(
1682 builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));
1683 properties.setScaleVecSize(ScaleVecSizeAttr::get(ctx, scaleVecSize));
1684 properties.setBlockScaleFormat(
1685 BlockScaleFormatAttr::get(ctx, blockScaleFormat));
1686 properties.setKind(MMABlockScaleKindAttr::get(ctx, kind));
1687}
1688
1689// Helper to infer and add multiplicand PTX types to builder
1690 static void
1691 addInferredMultiplicandTypes(MLIRContext *ctx, OperationState &result,
1692 ValueRange operandA, ValueRange operandB,
1693 std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes) {
1694 if (multiplicandPtxTypes) {
1695 result.addAttribute("multiplicandAPtxType",
1696 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
1697 result.addAttribute("multiplicandBPtxType",
1698 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
1699 } else {
1700 if (auto res = MmaOp::inferOperandMMAType(operandA[0].getType(), false))
1701 result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
1702 if (auto res = MmaOp::inferOperandMMAType(operandB[0].getType(), false))
1703 result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
1704 }
1705}
1706
1707// Template helper for common accumPtxType/resultPtxType implementation
1708template <typename OpTy>
1709static MMATypes inferPtxTypeFromResult(OpTy op) {
1710 return *MmaOp::inferOperandMMAType(
1711 cast<LLVM::LLVMStructType>(op.getRes().getType()).getBody()[0],
1712 /*isAccumulator=*/true);
1713}
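// Example: for a result of type !llvm.struct<(f32, f32, f32, f32)>, the first
// struct element is f32, which inferOperandMMAType maps to MMATypes::f32 when
// treated as an accumulator.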
1714
1715//===----------------------------------------------------------------------===//
1716// MmaBlockScaleOp
1717//===----------------------------------------------------------------------===//
1718
1719void MmaBlockScaleOp::print(OpAsmPrinter &p) {
1720 SmallVector<Type, 4> regTypes;
1721 std::array<MMAOperandFragment, 3> frags{
1722 MMAOperandFragment("A", getMultiplicandAPtxTypeAttrName()),
1723 MMAOperandFragment("B", getMultiplicandBPtxTypeAttrName()),
1724 MMAOperandFragment("C", "")};
1725 SmallVector<StringRef, 4> ignoreAttrNames{
1726 mlir::NVVM::MmaBlockScaleOp::getOperandSegmentSizeAttr()};
1727
1728 processOperandFragments(*this, frags, regTypes, ignoreAttrNames);
1729
1730 // Print A, B, C operands
1731 for (const auto &frag : frags)
1732 printOperandList(p, frag.operandName, frag.regs);
1733
1734 // Print scale operands
1735 printOperandList(p, "scaleA",
1736 {getScaleAData(), getByteIdA(), getThreadIdA()});
1737 printOperandList(p, "scaleB",
1738 {getScaleBData(), getByteIdB(), getThreadIdB()});
1739
1740 p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);
1741
1742 // Print type signature
1743 p << " : (";
1744 llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
1745 frags[1].regs[0].getType(),
1746 frags[2].regs[0].getType()},
1747 p);
1748 p << ")";
1749 p.printArrowTypeList(TypeRange{this->getRes().getType()});
1750}
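// Sketch of the custom assembly emitted by the printer above (register counts
// and attributes vary with the chosen shape, kind, and element types):
//
//   A[<a regs>] B[<b regs>] C[<c regs>]
//       scaleA[<scaleAData>, <byteIdA>, <threadIdA>]
//       scaleB[<scaleBData>, <byteIdB>, <threadIdB>]
//       {<attr-dict>} : (<typeA>, <typeB>, <typeC>) -> <resultType>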
1751
1752ParseResult MmaBlockScaleOp::parse(OpAsmParser &parser,
1753 OperationState &result) {
1754 struct LocalOperandFragment {
1755 std::optional<MMATypes> elemtype;
1756 SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
1757 };
1758
1759 Builder &builder = parser.getBuilder();
1760 std::array<LocalOperandFragment, 3> frags;
1761 NamedAttrList namedAttributes;
1762
1763 // Parse A[...] B[...] C[...]
1764 if (parseMmaOperand(parser, "A", frags[0].regs).failed() ||
1765 parseMmaOperand(parser, "B", frags[1].regs).failed() ||
1766 parseMmaOperand(parser, "C", frags[2].regs).failed())
1767 return failure();
1768
1769 // Parse scale operands: scaleA[...] scaleB[...]
1770 SmallVector<OpAsmParser::UnresolvedOperand, 3> scaleAOperands, scaleBOperands;
1771 if (parseMmaOperand(parser, "scaleA", scaleAOperands).failed() ||
1772 parseMmaOperand(parser, "scaleB", scaleBOperands).failed())
1773 return failure();
1774
1775 if (parser.parseOptionalAttrDict(namedAttributes).failed())
1776 return failure();
1777
1778 // Parse type signature
1779 SmallVector<Type, 3> operandTypes;
1780 if (parseMmaTypeSignature(parser, operandTypes).failed())
1781 return failure();
1782
1783 // Parse result type
1784 SmallVector<Type, 1> resultTypes;
1785 if (parser.parseArrowTypeList(resultTypes).failed())
1786 return failure();
1787
1788 // Infer element types and resolve operands
1789 for (const auto &[idx, frag] : llvm::enumerate(frags)) {
1790 frag.elemtype = MmaOp::inferOperandMMAType(operandTypes[idx],
1791 /*isAccumulator=*/idx >= 2);
1792 if (parser
1793 .resolveOperands(frag.regs, operandTypes[idx], parser.getNameLoc(),
1794 result.operands)
1795 .failed())
1796 return failure();
1797 }
1798
1799 // Resolve scale operands
1800 SmallVector<Type, 3> scaleTypes = {builder.getI32Type(), builder.getI16Type(),
1801 builder.getI16Type()};
1802 if (parser
1803 .resolveOperands(scaleAOperands, scaleTypes, parser.getNameLoc(),
1804 result.operands)
1805 .failed() ||
1806 parser
1807 .resolveOperands(scaleBOperands, scaleTypes, parser.getNameLoc(),
1808 result.operands)
1809 .failed())
1810 return failure();
1811
1812 // Add attributes
1813 result.addAttributes(namedAttributes);
1814 inferAndSetMultiplicandTypes(parser.getContext(), result.attributes,
1815 operandTypes);
1816
1817 result.addTypes(resultTypes);
1818 result.addAttribute(MmaBlockScaleOp::getOperandSegmentSizeAttr(),
1819 builder.getDenseI32ArrayAttr({
1820 static_cast<int32_t>(frags[0].regs.size()),
1821 static_cast<int32_t>(frags[1].regs.size()),
1822 static_cast<int32_t>(frags[2].regs.size()),
1823 1, // scaleAData
1824 1, // byteIdA
1825 1, // threadIdA
1826 1, // scaleBData
1827 1, // byteIdB
1828 1 // threadIdB
1829 }));
1830 return success();
1831}
1832
1833void MmaBlockScaleOp::build(
1834 OpBuilder &builder, OperationState &result, Type resultType,
1835 ValueRange operandA, ValueRange operandB, ValueRange operandC,
1836 Value scaleAData, Value byteIdA, Value threadIdA, Value scaleBData,
1837 Value byteIdB, Value threadIdB, ArrayRef<int64_t> shape,
1838 std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
1839 ScaleVecSize scaleVecSize, BlockScaleFormat blockScaleFormat,
1840 MMABlockScaleKind kind) {
1841 assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
1842
1843 addCommonBlockScaleProperties<MmaBlockScaleOp>(builder, result, shape, scaleVecSize,
1844 blockScaleFormat, kind);
1845
1846 result.addOperands(operandA);
1847 result.addOperands(operandB);
1848 result.addOperands(operandC);
1849 result.addOperands(
1850 {scaleAData, byteIdA, threadIdA, scaleBData, byteIdB, threadIdB});
1851
1852 addInferredMultiplicandTypes(builder.getContext(), result, operandA, operandB,
1853 multiplicandPtxTypes);
1854
1855 result.addTypes(resultType);
1856 result.addAttribute(MmaBlockScaleOp::getOperandSegmentSizeAttr(),
1857 builder.getDenseI32ArrayAttr({
1858 static_cast<int32_t>(operandA.size()),
1859 static_cast<int32_t>(operandB.size()),
1860 static_cast<int32_t>(operandC.size()),
1861 1, // scaleAData
1862 1, // byteIdA
1863 1, // threadIdA
1864 1, // scaleBData
1865 1, // byteIdB
1866 1 // threadIdB
1867 }));
1868}
1869
1870NVVM::IDArgPair MmaBlockScaleOp::getIntrinsicIDAndArgs(
1871 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1872 auto curOp = cast<NVVM::MmaBlockScaleOp>(op);
1873
1874 llvm::SmallVector<llvm::Value *> args;
1875 // Add A, B, C operands
1876 for (Value operand : curOp.getOperandA())
1877 args.push_back(mt.lookupValue(operand));
1878 for (Value operand : curOp.getOperandB())
1879 args.push_back(mt.lookupValue(operand));
1880 for (Value operand : curOp.getOperandC())
1881 args.push_back(mt.lookupValue(operand));
1882
1883 // Add scale operands
1884 args.push_back(mt.lookupValue(curOp.getScaleAData()));
1885 args.push_back(mt.lookupValue(curOp.getByteIdA()));
1886 args.push_back(mt.lookupValue(curOp.getThreadIdA()));
1887 args.push_back(mt.lookupValue(curOp.getScaleBData()));
1888 args.push_back(mt.lookupValue(curOp.getByteIdB()));
1889 args.push_back(mt.lookupValue(curOp.getThreadIdB()));
1890
1891 unsigned intId = MmaBlockScaleOp::getIntrinsicID(
1892 curOp.getShape().getM(), curOp.getShape().getN(), curOp.getShape().getK(),
1893 *curOp.getMultiplicandAPtxType(), *curOp.getMultiplicandBPtxType(),
1894 inferPtxTypeFromResult(curOp), curOp.getScaleVecSize(),
1895 curOp.getBlockScaleFormat(), curOp.getKind());
1896
1897 return {intId, args};
1898}
1899
1900LogicalResult MmaBlockScaleOp::verify() {
1901 LogicalResult result = success();
1902 int m = getShape().getM();
1903 int n = getShape().getN();
1904 int k = getShape().getK();
1905
1906 if (m == 16 && n == 8 && k == 64) {
1907 if (getMultiplicandAPtxType() != NVVM::MMATypes::e2m1 ||
1908 getMultiplicandBPtxType() != NVVM::MMATypes::e2m1)
1910 "unsupported MMATypes attribute for mma.m16n8k64.(mxf4nvf4|mxf4)");
1911 if (getKind() == NVVM::MMABlockScaleKind::MXF4) {
1912 if (getScaleVecSize() != NVVM::ScaleVecSize::X2)
1914 "unsupported ScaleVecSize attribute for mma.m16n8k64.mxf4");
1915 if (getBlockScaleFormat() != NVVM::BlockScaleFormat::UE8M0)
1917 "unsupported BlockScaleFormat attribute for mma.m16n8k64.mxf4");
1918 } else if (getKind() == NVVM::MMABlockScaleKind::MXF4NVF4) {
1919 if (!((getScaleVecSize() == NVVM::ScaleVecSize::X2 &&
1920 getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0) ||
1921 (getScaleVecSize() == NVVM::ScaleVecSize::X4 &&
1922 getBlockScaleFormat() == NVVM::BlockScaleFormat::UE4M3)))
1923 result = emitOpError("unsupported ScaleVecSize and BlockScaleFormat "
1924 "attributes for mma.m16n8k64.mxf4nvf4");
1925 } else {
1926 result = emitOpError("unsupported Kind attribute for mma.m16n8k64");
1927 }
1928 } else if (m == 16 && n == 8 && k == 32) {
1929 if (!(getKind() == NVVM::MMABlockScaleKind::MXF8F6F4 &&
1930 getScaleVecSize() == NVVM::ScaleVecSize::X1 &&
1931 getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0))
1932 result =
1933 emitOpError("unsupported Kind, ScaleVecSize and BlockScaleFormat "
1934 "attributes for mma.m16n8k32");
1935 } else {
1936 result = emitOpError("unsupported Geom for mma with block scaling");
1937 }
1938 return result;
1939}
1940
1941//===----------------------------------------------------------------------===//
1942// MmaSpBlockScaleOp
1943//===----------------------------------------------------------------------===//
1944
1945void MmaSpBlockScaleOp::print(OpAsmPrinter &p) {
1946 SmallVector<Type, 4> regTypes;
1947 std::array<MMAOperandFragment, 3> frags{
1948 MMAOperandFragment("A", getMultiplicandAPtxTypeAttrName()),
1949 MMAOperandFragment("B", getMultiplicandBPtxTypeAttrName()),
1950 MMAOperandFragment("C", "")};
1951 SmallVector<StringRef, 4> ignoreAttrNames{
1952 mlir::NVVM::MmaSpBlockScaleOp::getOperandSegmentSizeAttr()};
1953
1954 processOperandFragments(*this, frags, regTypes, ignoreAttrNames);
1955
1956 // Print A, B, C operands
1957 for (const auto &frag : frags)
1958 printOperandList(p, frag.operandName, frag.regs);
1959
1960 // Print sparse-specific operands
1961 printOperandList(p, "sparseMetadata", {getSparseMetadata()});
1962 printOperandList(p, "selector", {getSparsitySelector()});
1963
1964 // Print scale operands
1965 printOperandList(p, "scaleA",
1966 {getScaleAData(), getByteIdA(), getThreadIdA()});
1967 printOperandList(p, "scaleB",
1968 {getScaleBData(), getByteIdB(), getThreadIdB()});
1969
1970 p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);
1971
1972 // Print type signature
1973 p << " : (";
1974 llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
1975 frags[1].regs[0].getType(),
1976 frags[2].regs[0].getType()},
1977 p);
1978 p << ")";
1979 p.printArrowTypeList(TypeRange{this->getRes().getType()});
1980}
1981
1982ParseResult MmaSpBlockScaleOp::parse(OpAsmParser &parser,
1983 OperationState &result) {
1984 struct LocalOperandFragment {
1985 std::optional<MMATypes> elemtype;
1986 SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
1987 };
1988
1989 Builder &builder = parser.getBuilder();
1990 std::array<LocalOperandFragment, 3> frags;
1991 NamedAttrList namedAttributes;
1992
1993 // Parse A[...] B[...] C[...]
1994 if (parseMmaOperand(parser, "A", frags[0].regs).failed() ||
1995 parseMmaOperand(parser, "B", frags[1].regs).failed() ||
1996 parseMmaOperand(parser, "C", frags[2].regs).failed())
1997 return failure();
1998
1999 // Parse sparse-specific operands
2000 SmallVector<OpAsmParser::UnresolvedOperand, 1> metadataOperands,
2001 selectorOperands;
2002 if (parseMmaOperand(parser, "sparseMetadata", metadataOperands).failed() ||
2003 parseMmaOperand(parser, "selector", selectorOperands).failed())
2004 return failure();
2005
2006 // Parse scale operands
2007 SmallVector<OpAsmParser::UnresolvedOperand, 3> scaleAOperands, scaleBOperands;
2008 if (parseMmaOperand(parser, "scaleA", scaleAOperands).failed() ||
2009 parseMmaOperand(parser, "scaleB", scaleBOperands).failed())
2010 return failure();
2011
2012 if (parser.parseOptionalAttrDict(namedAttributes).failed())
2013 return failure();
2014
2015 // Parse type signature
2016 SmallVector<Type, 3> operandTypes;
2017 if (parseMmaTypeSignature(parser, operandTypes).failed())
2018 return failure();
2019
2020 // Parse result type
2021 SmallVector<Type, 1> resultTypes;
2022 if (parser.parseArrowTypeList(resultTypes).failed())
2023 return failure();
2024
2025 // Infer element types and resolve operands
2026 for (const auto &[idx, frag] : llvm::enumerate(frags)) {
2027 frag.elemtype = MmaOp::inferOperandMMAType(operandTypes[idx],
2028 /*isAccumulator=*/idx >= 2);
2029 if (parser
2030 .resolveOperands(frag.regs, operandTypes[idx], parser.getNameLoc(),
2031 result.operands)
2032 .failed())
2033 return failure();
2034 }
2035
2036 // Resolve sparse metadata and selector
2037 Type i32Type = builder.getI32Type();
2038 if (parser
2039 .resolveOperands(metadataOperands, i32Type, parser.getNameLoc(),
2040 result.operands)
2041 .failed() ||
2042 parser
2043 .resolveOperands(selectorOperands, i32Type, parser.getNameLoc(),
2044 result.operands)
2045 .failed())
2046 return failure();
2047
2048 // Resolve scale operands
2049 SmallVector<Type, 3> scaleTypes = {i32Type, builder.getI16Type(),
2050 builder.getI16Type()};
2051 if (parser
2052 .resolveOperands(scaleAOperands, scaleTypes, parser.getNameLoc(),
2053 result.operands)
2054 .failed() ||
2055 parser
2056 .resolveOperands(scaleBOperands, scaleTypes, parser.getNameLoc(),
2057 result.operands)
2058 .failed())
2059 return failure();
2060
2061 // Add attributes
2062 result.addAttributes(namedAttributes);
2063 inferAndSetMultiplicandTypes(parser.getContext(), result.attributes,
2064 operandTypes);
2065
2066 // orderedMetadata is mandatory; add it implicitly when not spelled out.
2067 if (!result.attributes.get("orderedMetadata"))
2068 result.addAttribute("orderedMetadata", builder.getUnitAttr());
2069
2070 result.addTypes(resultTypes);
2071 result.addAttribute(MmaSpBlockScaleOp::getOperandSegmentSizeAttr(),
2072 builder.getDenseI32ArrayAttr({
2073 static_cast<int32_t>(frags[0].regs.size()),
2074 static_cast<int32_t>(frags[1].regs.size()),
2075 static_cast<int32_t>(frags[2].regs.size()),
2076 1, // sparseMetadata
2077 1, // sparsitySelector
2078 1, // scaleAData
2079 1, // byteIdA
2080 1, // threadIdA
2081 1, // scaleBData
2082 1, // byteIdB
2083 1 // threadIdB
2084 }));
2085 return success();
2086}
2087
2088void MmaSpBlockScaleOp::build(
2089 OpBuilder &builder, OperationState &result, Type resultType,
2090 ValueRange operandA, ValueRange operandB, ValueRange operandC,
2091 Value sparseMetadata, Value sparsitySelector, Value scaleAData,
2092 Value byteIdA, Value threadIdA, Value scaleBData, Value byteIdB,
2093 Value threadIdB, ArrayRef<int64_t> shape,
2094 std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
2095 ScaleVecSize scaleVecSize, BlockScaleFormat blockScaleFormat,
2096 MMABlockScaleKind kind) {
2097 assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
2098
2099 addCommonBlockScaleProperties<MmaSpBlockScaleOp>(
2100 builder, result, shape, scaleVecSize, blockScaleFormat, kind);
2101 result.addAttribute("orderedMetadata", builder.getUnitAttr());
2102
2103 result.addOperands(operandA);
2104 result.addOperands(operandB);
2105 result.addOperands(operandC);
2106 result.addOperands({sparseMetadata, sparsitySelector, scaleAData, byteIdA,
2107 threadIdA, scaleBData, byteIdB, threadIdB});
2108
2109 addInferredMultiplicandTypes(builder.getContext(), result, operandA, operandB,
2110 multiplicandPtxTypes);
2111
2112 result.addTypes(resultType);
2113 result.addAttribute(MmaSpBlockScaleOp::getOperandSegmentSizeAttr(),
2114 builder.getDenseI32ArrayAttr({
2115 static_cast<int32_t>(operandA.size()),
2116 static_cast<int32_t>(operandB.size()),
2117 static_cast<int32_t>(operandC.size()),
2118 1, // sparseMetadata
2119 1, // sparsitySelector
2120 1, // scaleAData
2121 1, // byteIdA
2122 1, // threadIdA
2123 1, // scaleBData
2124 1, // byteIdB
2125 1 // threadIdB
2126 }));
2127}
2128
2129NVVM::IDArgPair MmaSpBlockScaleOp::getIntrinsicIDAndArgs(
2130 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
2131 auto curOp = cast<NVVM::MmaSpBlockScaleOp>(op);
2132
2133 llvm::SmallVector<llvm::Value *> args;
2134 // Add A, B, C operands
2135 for (Value operand : curOp.getOperandA())
2136 args.push_back(mt.lookupValue(operand));
2137 for (Value operand : curOp.getOperandB())
2138 args.push_back(mt.lookupValue(operand));
2139 for (Value operand : curOp.getOperandC())
2140 args.push_back(mt.lookupValue(operand));
2141
2142 // Add sparse metadata and selector
2143 args.push_back(mt.lookupValue(curOp.getSparseMetadata()));
2144 args.push_back(mt.lookupValue(curOp.getSparsitySelector()));
2145
2146 // Add scale operands
2147 args.push_back(mt.lookupValue(curOp.getScaleAData()));
2148 args.push_back(mt.lookupValue(curOp.getByteIdA()));
2149 args.push_back(mt.lookupValue(curOp.getThreadIdA()));
2150 args.push_back(mt.lookupValue(curOp.getScaleBData()));
2151 args.push_back(mt.lookupValue(curOp.getByteIdB()));
2152 args.push_back(mt.lookupValue(curOp.getThreadIdB()));
2153
2154 unsigned intId = MmaSpBlockScaleOp::getIntrinsicID(
2155 curOp.getShape().getM(), curOp.getShape().getN(), curOp.getShape().getK(),
2156 *curOp.getMultiplicandAPtxType(), *curOp.getMultiplicandBPtxType(),
2157 inferPtxTypeFromResult(curOp), curOp.getScaleVecSize(),
2158 curOp.getBlockScaleFormat(), curOp.getKind());
2159
2160 return {intId, args};
2161}
2162
2163LogicalResult MmaSpBlockScaleOp::verify() {
2164 // Check that orderedMetadata is present
2165 if (!getOrderedMetadata()) {
2166 return emitOpError("'orderedMetadata' attribute is mandatory");
2167 }
2168
2169 LogicalResult result = success();
2170 int m = getShape().getM();
2171 int n = getShape().getN();
2172 int k = getShape().getK();
2173
2174 if (m == 16 && n == 8 && k == 128) {
2175 if (getMultiplicandAPtxType() != NVVM::MMATypes::e2m1 ||
2176 getMultiplicandBPtxType() != NVVM::MMATypes::e2m1)
2178 "unsupported MMATypes attribute for mma.m16n8k128.(mxf4nvf4|mxf4)");
2179 if (getKind() == NVVM::MMABlockScaleKind::MXF4) {
2180 if (getScaleVecSize() != NVVM::ScaleVecSize::X2)
2182 "unsupported ScaleVecSize attribute for mma.m16n8k128.mxf4");
2183 if (getBlockScaleFormat() != NVVM::BlockScaleFormat::UE8M0)
2185 "unsupported BlockScaleFormat attribute for mma.m16n8k128.mxf4");
2186 } else if (getKind() == NVVM::MMABlockScaleKind::MXF4NVF4) {
2187 if (!((getScaleVecSize() == NVVM::ScaleVecSize::X2 &&
2188 getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0) ||
2189 (getScaleVecSize() == NVVM::ScaleVecSize::X4 &&
2190 getBlockScaleFormat() == NVVM::BlockScaleFormat::UE4M3)))
2191 result = emitOpError("unsupported ScaleVecSize and BlockScaleFormat "
2192 "attributes for mma.m16n8k128.mxf4nvf4");
2193 } else {
2194 result = emitOpError("unsupported Kind attribute for mma.m16n8k128");
2195 }
2196 } else if (m == 16 && n == 8 && k == 64) {
2197 if (!(getKind() == NVVM::MMABlockScaleKind::MXF8F6F4 &&
2198 getScaleVecSize() == NVVM::ScaleVecSize::X1 &&
2199 getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0))
2200 result =
2201 emitOpError("unsupported Kind, ScaleVecSize and BlockScaleFormat "
2202 "attributes for mma.m16n8k64");
2203 } else {
2204 result = emitOpError("unsupported Geom for sparse mma with block scaling");
2205 }
2206 return result;
2207}
2208
2209LogicalResult ShflOp::verify() {
2210 auto returnStructType = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
2211
2212 auto verifyTypeError = [&](Twine desc, Type expectedType,
2213 Type actualType) -> LogicalResult {
2214 return emitOpError("expected " + desc + " to be of type ")
2215 << expectedType << " but got " << actualType << " instead";
2216 };
2217
2218 if (returnStructType) {
2219 if (!getReturnValueAndIsValid())
2220 return emitOpError("\"return_value_and_is_valid\" attribute must be "
2221 "specified when the return type is a struct type");
2222
2223 if (returnStructType.getBody().size() != 2)
2224 return emitOpError("expected return type to be a two-element struct");
2225
2226 llvm::ArrayRef<Type> returnStruct = returnStructType.getBody();
2227 auto resultType = returnStruct[0];
2228 if (resultType != getVal().getType())
2229 return verifyTypeError("first element in the returned struct",
2230 getVal().getType(), resultType);
2231
2232 auto predicateType = returnStruct[1];
2233 if (!predicateType.isInteger(1))
2234 return verifyTypeError("second element in the returned struct",
2235 mlir::IntegerType::get(getContext(), 1),
2236 predicateType);
2237 } else {
2238 if (getReturnValueAndIsValid())
2239 return emitOpError("expected return type to be a two-element struct");
2240
2241 if (getType() != getVal().getType())
2242 return verifyTypeError("return type", getVal().getType(), getType());
2243 }
2244 return success();
2245}
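// Example of the two accepted result forms: shuffling an i32 value may return
// a plain i32, or, when the "return_value_and_is_valid" attribute is present,
// a two-element !llvm.struct<(i32, i1)> whose second element is the validity
// predicate.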
2246
2247std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
2248 NVVM::MMAFrag frag, int nRow,
2249 int nCol,
2250 MLIRContext *context) {
2251 unsigned numberElements = 0;
2252 Type elementType;
2253 OpBuilder builder(context);
2254 Type f16x2 = VectorType::get(2, builder.getF16Type());
2255 if (type == NVVM::MMATypes::f16) {
2256 elementType = f16x2;
2257 if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
2258 numberElements = 8;
2259 else
2260 numberElements = 4;
2261 } else if (type == NVVM::MMATypes::f32) {
2262 elementType = builder.getF32Type();
2263 numberElements = 8;
2264 } else if (type == NVVM::MMATypes::f64) {
2265 elementType = builder.getF64Type();
2266 if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
2267 numberElements = 1;
2268 else
2269 numberElements = 2;
2270 } else if (type == NVVM::MMATypes::tf32) {
2271 elementType = builder.getI32Type();
2272 numberElements = 4;
2273 } else if (type == NVVM::MMATypes::s8 || type == NVVM::MMATypes::u8) {
2274 elementType = builder.getI32Type();
2275 int parallelSize = 0;
2276 if (frag == NVVM::MMAFrag::a)
2277 parallelSize = nRow;
2278 if (frag == NVVM::MMAFrag::b)
2279 parallelSize = nCol;
2280
2281 // m == 16 && n == 16 && k == 16
2282 if (parallelSize == 16)
2283 numberElements = 2;
2284 // m == 8 && n == 32 && k == 16 or m == 32 && n == 8 && k == 16
2285 else if (parallelSize == 8)
2286 numberElements = 1;
2287 else if (parallelSize == 32)
2288 numberElements = 4;
2289 } else if (type == NVVM::MMATypes::s32) {
2290 elementType = builder.getI32Type();
2291 numberElements = 8;
2292 }
2293 assert(numberElements != 0 && elementType != nullptr);
2294 return std::make_pair(elementType, numberElements);
2295}
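// Worked example of the mapping above: for f16 the register type is
// vector<2xf16>, with 8 registers for the a/b fragments and 4 for the
// accumulator; for f64 the a/b fragments use a single f64 register each and
// the accumulator uses two.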
2296
2297static std::pair<mlir::Type, unsigned>
2298inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n,
2299 int k, MLIRContext *context) {
2300 int nRow, nCol;
2301 if (frag == NVVM::MMAFrag::a) {
2302 nRow = m;
2303 nCol = k;
2304 } else if (frag == NVVM::MMAFrag::b) {
2305 nRow = k;
2306 nCol = n;
2307 } else {
2308 nRow = m;
2309 nCol = n;
2310 }
2311 assert(nRow && nCol);
2312 return inferMMAType(type, frag, nRow, nCol, context);
2313}
2314
2315LogicalResult NVVM::WMMALoadOp::verify() {
2316 unsigned addressSpace =
2317 llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
2318 if (addressSpace != 0 && addressSpace != NVVMMemorySpace::Global &&
2319 addressSpace != NVVMMemorySpace::Shared)
2320 return emitOpError("expected source pointer in memory "
2321 "space 0, 1, 3");
2322
2323 if (NVVM::WMMALoadOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
2324 getEltype(), getFrag()) == 0)
2325 return emitOpError() << "invalid attribute combination";
2326 std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
2327 getEltype(), getFrag(), getM(), getN(), getK(), getContext());
2328 // Special case for f64 fragments
2329 Type f64Ty = Float64Type::get(getContext());
2330 if (typeInfo.first == f64Ty && typeInfo.second == 1) {
2331 if (getType() != f64Ty)
2332 return emitOpError("expected destination type to be f64");
2333 return success();
2334 }
2335 // Everything else is a struct
2336 Type dstType = LLVM::LLVMStructType::getLiteral(
2337 getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
2338 if (getType() != dstType)
2339 return emitOpError("expected destination type is a structure of ")
2340 << typeInfo.second << " elements of type " << typeInfo.first;
2341 return success();
2342}
2343
2344LogicalResult NVVM::WMMAStoreOp::verify() {
2345 unsigned addressSpace =
2346 llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
2347 if (addressSpace != 0 && addressSpace != NVVMMemorySpace::Global &&
2348 addressSpace != NVVMMemorySpace::Shared)
2349 return emitOpError("expected operands to be a source pointer in memory "
2350 "space 0, 1, 3");
2351
2352 if (NVVM::WMMAStoreOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
2353 getEltype()) == 0)
2354 return emitOpError() << "invalid attribute combination";
2355 std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
2356 getEltype(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
2357 if (getArgs().size() != typeInfo.second)
2358 return emitOpError() << "expected " << typeInfo.second << " data operands";
2359 if (llvm::any_of(getArgs(), [&typeInfo](Value operands) {
2360 return operands.getType() != typeInfo.first;
2361 }))
2362 return emitOpError() << "expected data operands of type " << typeInfo.first;
2363 return success();
2364}
2365
2366LogicalResult NVVM::WMMAMmaOp::verify() {
2367 if (NVVM::WMMAMmaOp::getIntrinsicID(getM(), getN(), getK(), getLayoutA(),
2368 getLayoutB(), getEltypeA(),
2369 getEltypeB()) == 0)
2370 return emitOpError() << "invalid attribute combination";
2371 std::pair<Type, unsigned> typeInfoA = inferMMATypeFromMNK(
2372 getEltypeA(), NVVM::MMAFrag::a, getM(), getN(), getK(), getContext());
2373 std::pair<Type, unsigned> typeInfoB = inferMMATypeFromMNK(
2374 getEltypeA(), NVVM::MMAFrag::b, getM(), getN(), getK(), getContext());
2375 std::pair<Type, unsigned> typeInfoC = inferMMATypeFromMNK(
2376 getEltypeB(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
2377 SmallVector<Type, 32> arguments;
2378 arguments.append(typeInfoA.second, typeInfoA.first);
2379 arguments.append(typeInfoB.second, typeInfoB.first);
2380 arguments.append(typeInfoC.second, typeInfoC.first);
2381 unsigned numArgs = arguments.size();
2382 if (getArgs().size() != numArgs)
2383 return emitOpError() << "expected " << numArgs << " arguments";
2384 for (unsigned i = 0; i < numArgs; i++) {
2385 if (getArgs()[i].getType() != arguments[i])
2386 return emitOpError() << "expected argument " << i << " to be of type "
2387 << arguments[i];
2388 }
2389 Type dstType = LLVM::LLVMStructType::getLiteral(
2390 getContext(), SmallVector<Type, 8>(typeInfoC.second, typeInfoC.first));
2391 if (getType() != dstType)
2392 return emitOpError("expected destination type is a structure of ")
2393 << typeInfoC.second << " elements of type " << typeInfoC.first;
2394 return success();
2395}
2396
2397LogicalResult NVVM::LdMatrixOp::verify() {
2398 uint32_t num = getNum(), m = getShape().getM(), n = getShape().getN();
2399 if (m == 8 && n == 8) {
2400 if (num != 1 && num != 2 && num != 4) {
2401 return emitOpError("expected num attribute to be 1, 2 or 4 for 8x8 "
2402 "matrix");
2403 }
2404 if (getEltType() != LdStMatrixEltType::B16) {
2405 return emitOpError("expected element type to be b16 for 8x8 matrix");
2406 }
2407 } else if (m == 8 && n == 16) {
2408 if (num != 1 && num != 2 && num != 4) {
2409 return emitOpError("expected num attribute to be 1, 2 or 4 for 8x16 "
2410 "matrix");
2411 }
2412 if (getLayout() != MMALayout::row) {
2413 return emitOpError("expected layout to be row for 8x16 matrix");
2414 }
2415 if (getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 &&
2416 getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) {
2417 return emitOpError("expected element type to be b8x16.b4x16_p64 or "
2418 "b8x16.b6x16_p32 for 8x16 matrix");
2419 }
2420 } else if (m == 16 && n == 16) {
2421 if (num != 1 && num != 2) {
2422 return emitOpError("expected num attribute to be 1 or 2 for 16x16 "
2423 "matrix");
2424 }
2425 if (getLayout() != MMALayout::col) {
2426 return emitOpError("expected layout to be col for 16x16 matrix");
2427 }
2428 if (getEltType() != LdStMatrixEltType::B8 &&
2429 getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 &&
2430 getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) {
2431 return emitOpError("expected element type to be b8, b8x16.b4x16_p64 or "
2432 "b8x16.b6x16_p32 for 16x16 matrix");
2433 }
2434 } else {
2435 return emitOpError("expected shape to be 8x8, 8x16 or 16x16");
2436 }
2437
2438 Type i32 = IntegerType::get(getContext(), 32);
2439 uint32_t numElements = (m == 16 && n == 16 ? num * 2 : num);
2440 if (numElements == 1 && getType() != i32)
2441 return emitOpError("expected destination type is i32");
2442 if (numElements == 2 || numElements == 4) {
2443 Type dstType = LLVM::LLVMStructType::getLiteral(
2444 getContext(), SmallVector<Type>(numElements, i32));
2445 if (getType() != dstType)
2446 return emitOpError("expected destination type is a structure of ")
2447 << numElements << " elements of type i32";
2448 }
2449
2450 return success();
2451}
2452
2453LogicalResult NVVM::StMatrixOp::verify() {
2454 int numMatrix = getSources().size();
2455 if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
2456 return emitOpError("expected num attribute to be 1, 2 or 4");
2457
2458 int m = getShape().getM(), n = getShape().getN();
2459 if (m == 8 && n == 8) {
2460 if (getEltType() != NVVM::LdStMatrixEltType::B16) {
2461 return emitOpError("expected element type to be B16 for 8x8 matrix");
2462 }
2463 } else if (m == 16 && n == 8) {
2464 if (getEltType() != NVVM::LdStMatrixEltType::B8) {
2465 return emitOpError("expected element type to be B8 for 16x8 matrix");
2466 }
2467 if (getLayout() != NVVM::MMALayout::col) {
2468 return emitOpError("expected layout to be col for 16x8 matrix");
2469 }
2470 } else {
2471 return emitOpError("expected shape to be 8x8 or 16x8");
2472 }
2473
2474 return success();
2475}
2476
2477static FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) {
2478 if (typeA == NVVM::WGMMATypes::tf32)
2479 return 8;
2480 if (typeA == NVVM::WGMMATypes::f16 || typeA == NVVM::WGMMATypes::bf16)
2481 return 16;
2482 if (typeA == NVVM::WGMMATypes::s8 || typeA == NVVM::WGMMATypes::u8)
2483 return 32;
2484 if (typeA == NVVM::WGMMATypes::e4m3 || typeA == NVVM::WGMMATypes::e5m2)
2485 return 32;
2486 if (typeA == NVVM::WGMMATypes::b1)
2487 return 256;
2488 return failure();
2489}
2490
2491static LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD,
2492 NVVM::WGMMATypes typeA,
2493 NVVM::WGMMATypes typeB) {
2494 switch (typeA) {
2495 case NVVM::WGMMATypes::f16:
2496 if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
2497 typeB == NVVM::WGMMATypes::f16)
2498 return success();
2499 break;
2500 case NVVM::WGMMATypes::tf32:
2501 if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32)
2502 return success();
2503 break;
2504 case NVVM::WGMMATypes::u8:
2505 case NVVM::WGMMATypes::s8:
2506 if (typeD == NVVM::WGMMATypes::s32 &&
2507 (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8))
2508 return success();
2509 break;
2510 case NVVM::WGMMATypes::b1:
2511 if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1)
2512 return success();
2513 break;
2514 case NVVM::WGMMATypes::bf16:
2515 if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
2516 typeB == NVVM::WGMMATypes::bf16)
2517 return success();
2518 break;
2519 case NVVM::WGMMATypes::e4m3:
2520 case NVVM::WGMMATypes::e5m2:
2521 if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
2522 (typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3))
2523 return success();
2524 break;
2525 case WGMMATypes::f32:
2526 case WGMMATypes::s32:
2527 llvm_unreachable("unsupported input types");
2528 break;
2529 }
2530 return failure();
2531}
2532
2533static LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA) {
2534 SmallVector<int> allowedN = {8, 16, 24, 32, 40, 48, 56, 64,
2535 72, 80, 88, 96, 104, 112, 120, 128,
2536 136, 144, 152, 160, 168, 176, 184, 192,
2537 200, 208, 216, 224, 232, 240, 248, 256};
2538 SmallVector<int> allowedNshort = {8, 16, 24, 32, 48, 64,
2539 80, 96, 112, 128, 144, 160,
2540 176, 192, 208, 224, 240, 256};
2541 switch (typeA) {
2542 case WGMMATypes::f16:
2543 case WGMMATypes::tf32:
2544 case WGMMATypes::bf16:
2545 case WGMMATypes::e4m3:
2546 case WGMMATypes::e5m2:
2547 if (llvm::is_contained(allowedN, sizeN))
2548 return success();
2549 break;
2550 case WGMMATypes::u8:
2551 case WGMMATypes::s8:
2552 case WGMMATypes::b1:
2553 if (llvm::is_contained(allowedNshort, sizeN))
2554 return success();
2555 break;
2556 case WGMMATypes::f32:
2557 case WGMMATypes::s32:
2558 llvm_unreachable("unsupported input types");
2559 break;
2560 }
2561 return failure();
2562}
2563
2564LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
2565 Value outValue = getResults();
2566 auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
2567 if (!stype)
2568 return emitOpError() << "expected results to be struct";
2569 int outputSize = stype.getBody().size();
2570 WGMMATypes typeD = getTypeD();
2571 WGMMATypes typeA = getTypeA();
2572 WGMMATypes typeB = getTypeB();
2573
2574 for (Type t : stype.getBody()) {
2575 if (t != stype.getBody().front())
2576 return emitOpError()
2577 << "all elements in struct must be same type but there is " << t;
2578 }
2579
2580 if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
2581 typeD != WGMMATypes::s32) {
2582 return emitOpError() << "does not support the given output type "
2583 << NVVM::stringifyWGMMATypes(typeD);
2584 }
2585 if (typeD == WGMMATypes::s32 &&
2586 (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
2587 return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg";
2588 }
2589
2590 if (failed(isAllowedWGMMADataType(typeD, typeA, typeB))) {
2591 return emitOpError() << NVVM::stringifyWGMMATypes(typeD)
2592 << " += " << NVVM::stringifyWGMMATypes(typeA) << " * "
2593 << NVVM::stringifyWGMMATypes(typeB)
2594 << ", it is not supported.";
2595 }
2596
2597 // Check M
2598 if (getShape().getM() != 64)
2599 return emitOpError() << "shape 'm' must be 64";
2600
2601 // Check K
2602 FailureOr<int> allowedK = getAllowedSizeK(typeA);
2603 if (failed(allowedK) || allowedK.value() != getShape().getK())
2604 return emitOpError() << "shape 'k' must be " << allowedK.value()
2605 << " for input type "
2606 << NVVM::stringifyWGMMATypes(typeA);
2607
2608 // Check N
2609 if (failed(isAllowedSizeN(getShape().getN(), typeA))) {
2610 return emitOpError() << "has input type "
2611 << NVVM::stringifyWGMMATypes(typeA) << " n is set to "
2612 << getShape().getN() << ", it is not supported.";
2613 }
2614
2615 // Check transpose (only available for f16/bf16)
2616 // Matrix A should be stored in row-major and B in column-major order.
2617 // Only f16/bf16 matrices can be stored in either column-major or row-major
2618 // order by setting the transpose values (imm-trans-a, imm-trans-b) in PTX code.
2619 if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
2620 (getLayoutA() == mlir::NVVM::MMALayout::col ||
2621 getLayoutB() == mlir::NVVM::MMALayout::row)) {
2622 return emitOpError()
2623 << "given layouts layout_a = " << stringifyMMALayout(getLayoutA())
2624 << " and layout_b = " << stringifyMMALayout(getLayoutB())
2625 << " for input types " << stringifyWGMMATypes(typeA) << " and "
2626 << stringifyWGMMATypes(typeB)
2627 << " requires transpose. However, this is only supported for: "
2628 << stringifyMMATypes(MMATypes::f16) << " and "
2629 << stringifyMMATypes(MMATypes::bf16);
2630 }
2631
2632 // Check result registers
2633 int expectedOutput = 0;
2634 if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32)
2635 expectedOutput = getShape().getN() / 2;
2636 if (typeD == WGMMATypes::f16)
2637 expectedOutput = getShape().getN() / 4;
2638 if (outputSize != expectedOutput) {
2639 return emitOpError() << "results " << expectedOutput
2640 << ", however output struct has " << outputSize
2641 << " elements";
2642 }
2643 // Check satfinite (only available for s32 accumulator)
2644 if (typeD != WGMMATypes::s32 &&
2645 getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
2646 NVVM::MMAIntOverflow::satfinite) {
2647 return emitOpError()
2648 << " `satfinite` can be only used with s32 accumulator, however "
2649 "the current accumulator is "
2650 << NVVM::stringifyWGMMATypes(typeD);
2651 }
2652
2653 return success();
2654}
2655
2656std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
2657
2658 int m = getShape().getM(), n = getShape().getN(), k = getShape().getK();
2659 bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
2660
2661 StringRef outputTypeName = stringifyWGMMATypes(getTypeD());
2662
2663 int expectedOutputRegisters = 0;
2664 if (getTypeD() == WGMMATypes::f16)
2665 expectedOutputRegisters = getShape().getN() / 4;
2666 else
2667 expectedOutputRegisters = getShape().getN() / 2;
2668
2669 std::string ptx;
2670 llvm::raw_string_ostream ss(ptx);
2671
2672 ss << "{\n"
2673 ".reg .pred p;\n"
2674 "setp.ne.b32 p, $"
2675 << ((expectedOutputRegisters * 2) + 2)
2676 << ", 0;\n"
2677 "wgmma.mma_async.sync.aligned.m"
2678 << m << "n" << n << "k" << k << "." << outputTypeName << "."
2679 << stringifyWGMMATypes(getTypeA()) << "."
2680 << stringifyWGMMATypes(getTypeB());
2681 if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
2682 NVVM::MMAIntOverflow::satfinite)
2683 ss << ".satfinite";
2684 ss << " {";
2685 int regCnt = 0;
2686 for (; regCnt < expectedOutputRegisters; ++regCnt) {
2687 ss << "$" << regCnt;
2688 if (regCnt != expectedOutputRegisters - 1)
2689 ss << ", ";
2690 }
2691
2692 ss << "},";
2693 // Need to map read/write registers correctly.
2694 regCnt = (regCnt * 2);
2695 ss << " $" << (regCnt) << ","
2696 << " $" << (regCnt + 1) << ","
2697 << " p";
2698 if (getTypeD() != WGMMATypes::s32) {
2699 ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4);
2700 }
2701 // Don't add transpose parameters unless needed.
2702 if (isF16) {
2703 ss << ", $" << (regCnt + 5) << ", $" << (regCnt + 6);
2704 }
2705 ss << ";\n"
2706 << "}\n";
2707 return ptx;
2708}
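// Worked example of the register numbering above, assuming an f32 accumulator
// and n = 64: expectedOutputRegisters = 64 / 2 = 32, so the result registers
// are $0..$31, the scale-d predicate tested by setp is operand
// $(32 * 2 + 2) = $66, and the two matrix descriptors map to $64 and $65.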
2709
2710bool NVVM::WgmmaMmaAsyncOp::getAsmValues(
2711 RewriterBase &rewriter,
2712 llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
2713 &asmValues) {
2714 bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
2715 if (getResults())
2716 asmValues.push_back({getResults(), mlir::NVVM::PTXRegisterMod::Write});
2717 if (getInouts())
2718 asmValues.push_back({getInouts(), mlir::NVVM::PTXRegisterMod::ReadWrite});
2719 asmValues.push_back({getDescriptorA(), mlir::NVVM::PTXRegisterMod::Read});
2720 asmValues.push_back({getDescriptorB(), mlir::NVVM::PTXRegisterMod::Read});
2721 asmValues.push_back({makeConstantI32(rewriter, static_cast<int>(getScaleD())),
2722 mlir::NVVM::PTXRegisterMod::Read});
2723 if (getTypeD() != WGMMATypes::s32) {
2724 asmValues.push_back(
2725 {makeConstantI32(rewriter,
2726 getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
2727 mlir::NVVM::PTXRegisterMod::Read});
2728 asmValues.push_back(
2729 {makeConstantI32(rewriter,
2730 getScaleB() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
2731 mlir::NVVM::PTXRegisterMod::Read});
2732 }
2733 if (isF16) {
2734 asmValues.push_back(
2735 {makeConstantI32(rewriter, static_cast<int>(getLayoutA())),
2736 mlir::NVVM::PTXRegisterMod::Read});
2737 asmValues.push_back(
2738 {makeConstantI32(rewriter, 1 - static_cast<int>(getLayoutB())),
2739 mlir::NVVM::PTXRegisterMod::Read});
2740 }
2741 return true; // Has manual mapping
2742}
2743
2744LogicalResult NVVM::FenceSyncRestrictOp::verify() {
2745 if (getOrder() != NVVM::MemOrderKind::ACQUIRE &&
2746 getOrder() != NVVM::MemOrderKind::RELEASE)
2747 return emitOpError("only acquire and release semantics are supported");
2748 return success();
2749}
2750
2751LogicalResult NVVM::FenceProxyOp::verify() {
2752 if (getKind() == NVVM::ProxyKind::TENSORMAP)
2753 return emitOpError() << "tensormap proxy is not a supported proxy kind";
2754 if (getKind() == NVVM::ProxyKind::GENERIC)
2755 return emitOpError() << "generic proxy not a supported proxy kind";
2756 if (getKind() == NVVM::ProxyKind::async_shared && !getSpace().has_value()) {
2757 return emitOpError() << "async_shared fence requires space attribute";
2758 }
2759 if (getKind() != NVVM::ProxyKind::async_shared && getSpace().has_value()) {
2760 return emitOpError() << "only async_shared fence can have space attribute";
2761 }
2762 return success();
2763}
2764
2765LogicalResult NVVM::FenceProxyAcquireOp::verify() {
2766 if (getFromProxy() != NVVM::ProxyKind::GENERIC)
2767 return emitOpError("uni-directional proxies only support generic for "
2768 "from_proxy attribute");
2769
2770 if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
2771 return emitOpError("uni-directional proxies only support tensormap "
2772 "for to_proxy attribute");
2773 return success();
2774}
2775
2776LogicalResult NVVM::FenceProxyReleaseOp::verify() {
2777 if (getFromProxy() != NVVM::ProxyKind::GENERIC)
2778 return emitOpError("uni-directional proxies only support generic for "
2779 "from_proxy attribute");
2780
2781 if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
2782 return emitOpError("uni-directional proxies only support tensormap "
2783 "for to_proxy attribute");
2784 return success();
2785}
2786
2787LogicalResult NVVM::FenceProxySyncRestrictOp::verify() {
2788 if (getOrder() != NVVM::MemOrderKind::ACQUIRE &&
2789 getOrder() != NVVM::MemOrderKind::RELEASE)
2790 return emitOpError("only acquire and release semantics are supported");
2791
2792 if (getFromProxy() != NVVM::ProxyKind::GENERIC)
2793 return emitOpError("only generic is support for from_proxy attribute");
2794
2795 if (getToProxy() != NVVM::ProxyKind::async)
2796 return emitOpError("only async is supported for to_proxy attribute");
2797 return success();
2798}
2799
2800LogicalResult NVVM::SetMaxRegisterOp::verify() {
2801 if (getRegCount() % 8)
2802 return emitOpError("new register size must be multiple of 8");
2803 if (getRegCount() < 24 || getRegCount() > 256)
2804 return emitOpError("new register size must be in between 24 to 256");
2805 return success();
2806}
2807
2808LogicalResult NVVM::BarrierOp::verify() {
2809 if (getNumberOfThreads() && !getBarrierId())
2810 return emitOpError(
2811 "barrier id is missing, it should be set between 0 to 15");
2812
2813 if (getBarrierId() && (getReductionOp() || getReductionPredicate()))
2814 return emitOpError("reduction are only available when id is 0");
2815
2816 if ((getReductionOp() && !getReductionPredicate()) ||
2817 (!getReductionOp() && getReductionPredicate()))
2818 return emitOpError("reduction predicate and reduction operation must be "
2819 "specified together");
2820
2821 return success();
2822}
2823
2824LogicalResult NVVM::Tcgen05CpOp::verify() {
2825 auto mc = getMulticast();
2826
2827 using SH = Tcgen05CpShape;
2828 using MC = Tcgen05CpMulticast;
2829 switch (getShape()) {
2830 case SH::SHAPE_128x256b:
2831 case SH::SHAPE_128x128b:
2832 case SH::SHAPE_4x256b:
2833 if (mc != MC::NONE)
2834 return emitError("Invalid multicast type for tcgen05.cp Op");
2835 break;
2836 case SH::SHAPE_64x128b:
2837 if (mc != MC::WARPX2_01_23 && mc != MC::WARPX2_02_13)
2838 return emitError("Shape 64x128b requires multicast warpx2_01_23 or "
2839 "warpx2_02_13 for tcgen05.cp Op");
2840 break;
2841 case SH::SHAPE_32x128b:
2842 if (mc != MC::WARPX4)
2843 return emitError(
2844 "Shape 32x128b requires multicast warpx4 for tcgen05.cp Op");
2845 break;
2846 }
2847 return success();
2848}
2849
2850LogicalResult NVVM::MatchSyncOp::verify() {
2851 if (getKind() == NVVM::MatchSyncKind::all) {
2852 auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
2853 if (!type || type.getBody().size() != 2 ||
2854 !type.getBody()[0].isInteger(32) || !type.getBody()[1].isInteger(1)) {
2855 return emitOpError("match.sync 'all' returns a two element struct with "
2856 "first element as i32 and second element as i1");
2857 }
2858 } else {
2859 if (!getType().isInteger(32)) {
2860 return emitOpError("match.sync 'any' returns an i32");
2861 }
2862 }
2863 return success();
2864}
2865
2866LogicalResult NVVM::VoteSyncOp::verify() {
2867 if (getKind() == NVVM::VoteSyncKind::ballot) {
2868 if (!getType().isInteger(32)) {
2869 return emitOpError("vote.sync 'ballot' returns an i32");
2870 }
2871 } else {
2872 if (!getType().isInteger(1)) {
2873 return emitOpError("vote.sync 'any', 'all' and 'uni' returns an i1");
2874 }
2875 }
2876 return success();
2877}
2878
2879LogicalResult NVVM::PrefetchOp::verify() {
2880 using MemSpace = NVVM::NVVMMemorySpace;
2881 using CacheLevel = NVVM::PrefetchCacheLevel;
2882
2883 unsigned addressSpace =
2884 llvm::cast<LLVM::LLVMPointerType>(getAddr().getType()).getAddressSpace();
2885 std::optional<NVVM::CacheEvictionPriority> evictPriority = getEvictPriority();
2886 std::optional<NVVM::PrefetchCacheLevel> cacheLevel = getCacheLevel();
2887
2888 if (getTensormap() && cacheLevel)
2889 return emitOpError("cannot specify both tensormap and cache level");
2890
2891 if (getTensormap()) {
2892 if (addressSpace != MemSpace::Generic &&
2893 addressSpace != MemSpace::Constant) {
2894 return emitOpError(
2895 "prefetch tensormap requires a generic or constant pointer");
2896 }
2897
2898 if (evictPriority) {
2899 return emitOpError(
2900 "prefetch tensormap does not support eviction priority");
2901 }
2902
2903 if (getInParamSpace() && addressSpace != MemSpace::Generic) {
2904 return emitOpError(
2905 "in_param_space can only be specified for a generic pointer");
2906 }
2907
2908 } else if (cacheLevel) {
2909 if (addressSpace != MemSpace::Generic && addressSpace != MemSpace::Global &&
2910 addressSpace != MemSpace::Local) {
2911 return emitOpError("prefetch to cache level requires a generic, global, "
2912 "or local pointer");
2913 }
2914
2915 if (getUniform()) {
2916 if (*cacheLevel != CacheLevel::L1) {
2917 return emitOpError(
2918 "unsupported cache level, the only supported uniform "
2919 "cache level is L1");
2920 }
2921
2922 if (addressSpace != MemSpace::Generic) {
2923 return emitOpError(
2924 "prefetch to uniform cache requires a generic pointer");
2925 }
2926 }
2927
2928 if (evictPriority) {
2929 if (*cacheLevel != CacheLevel::L2)
2930 return emitOpError(
2931 "cache eviction priority supported only for cache level L2");
2932
2933 if (addressSpace != MemSpace::Global)
2934 return emitOpError("cache eviction priority requires a global pointer");
2935
2936 if (*evictPriority != NVVM::CacheEvictionPriority::EvictNormal &&
2937 *evictPriority != NVVM::CacheEvictionPriority::EvictLast)
2938 return emitOpError(
2939 "unsupported cache eviction priority, only evict_last and "
2940 "evict_normal are supported");
2941 }
2942
2943 if (getPredicate())
2944 return emitOpError("predicate supported only on prefetch tensormap");
2945
2946 } else {
2947 return emitOpError(
2948 "requires specification of either cache level or tensormap");
2949 }
2950
2951 return success();
2952}
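// Illustrative combinations accepted by this verifier: a tensormap prefetch
// through a generic or constant pointer (optionally predicated), or a
// cache-level prefetch such as L2 on a global pointer with an evict_last or
// evict_normal eviction priority. Mixing tensormap with a cache level, or
// requesting a uniform L1 prefetch through a non-generic pointer, is rejected.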
2953
2954LogicalResult NVVM::ClusterLaunchControlQueryCancelOp::verify() {
2955 switch (getQueryType()) {
2956 case NVVM::ClusterLaunchControlQueryType::IS_CANCELED:
2957 if (!getType().isInteger(1))
2958 return emitOpError("is_canceled query type returns an i1");
2959 break;
2960 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_X:
2961 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Y:
2962 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Z:
2963 if (!getType().isInteger(32)) {
2964 return emitOpError("get_first_cta_id_x, get_first_cta_id_y, "
2965 "get_first_cta_id_z query types return an i32");
2966 }
2967 break;
2968 }
2969 return success();
2970}
2971
2972LogicalResult NVVM::ReduxOp::verify() {
2973 mlir::Type reduxType = getType();
2974
2975 if (!reduxType.isF32()) {
2976 if (getAbs())
2977 return emitOpError("abs attribute is supported only for f32 type");
2978 if (getNan())
2979 return emitOpError("nan attribute is supported only for f32 type");
2980 }
2981
2982 NVVM::ReduxKind kind = getKind();
2983 switch (kind) {
2984 case NVVM::ReduxKind::ADD:
2985 case NVVM::ReduxKind::AND:
2986 case NVVM::ReduxKind::OR:
2987 case NVVM::ReduxKind::XOR:
2988 case NVVM::ReduxKind::MAX:
2989 case NVVM::ReduxKind::MIN:
2990 case NVVM::ReduxKind::UMAX:
2991 case NVVM::ReduxKind::UMIN:
2992 if (!reduxType.isInteger(32))
2993 return emitOpError("'")
2994 << stringifyEnum(kind) << "' redux kind unsupported with "
2995 << reduxType << " type. Only supported type is 'i32'.";
2996 break;
2997 case NVVM::ReduxKind::FMIN:
2998 case NVVM::ReduxKind::FMAX:
2999 if (!reduxType.isF32())
3000 return emitOpError("'")
3001 << stringifyEnum(kind) << "' redux kind unsupported with "
3002 << reduxType << " type. Only supported type is 'f32'.";
3003 break;
3004 }
3005
3006 return success();
3007}
3008
3009/// Packs the given `field` into the `result`.
3010/// The `result` is 64-bits and each `field` can be 32-bits or narrower.
3011static llvm::Value *
3012packValInto64Bits(llvm::IRBuilderBase &builder,
3013 llvm::Value *result, // the `result` (unset bits are zero)
3014 llvm::Value *field, // `field` to pack into `result`
3015 unsigned sizeInBits, // Size of `field` in bits
3016 unsigned start) { // Starting bit within `result`
3017 field = builder.CreateZExtOrBitCast(field, builder.getInt32Ty());
3018
3019 unsigned mask = (sizeInBits < 32 ? ((1u << sizeInBits) - 1) : 0xffffffffu);
3020 if (mask != 0xffffffffu)
3021 field = builder.CreateAnd(field, builder.getInt32(mask));
3022
3023 field = builder.CreateZExtOrBitCast(field, builder.getInt64Ty());
3024 field = builder.CreateShl(field, start);
3025
3026 return builder.CreateOr(result, field);
3027}
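// A minimal standalone sketch of the same packing arithmetic in plain C++
// (the helper name `pack` below is illustrative, not part of this file):
//
//   uint64_t pack(uint64_t result, uint32_t field, unsigned sizeInBits,
//                 unsigned start) {
//     uint32_t mask = sizeInBits < 32 ? ((1u << sizeInBits) - 1) : 0xffffffffu;
//     return result | (static_cast<uint64_t>(field & mask) << start);
//   }
//
// For example, pack(0, /*field=*/0b101, /*sizeInBits=*/3, /*start=*/49)
// yields 0x000a000000000000.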
3028
3029void Tcgen05MmaSmemDescOp::createSmemDescriptor(Operation &op,
3030 LLVM::ModuleTranslation &mt,
3031 llvm::IRBuilderBase &builder) {
3032 auto thisOp = cast<NVVM::Tcgen05MmaSmemDescOp>(op);
3033 llvm::Value *smemDesc = builder.getInt64(0);
3034
3035 smemDesc = packValInto64Bits(builder, smemDesc,
3036 mt.lookupValue(thisOp.getStartAddr()), 14, 0);
3037 smemDesc = packValInto64Bits(
3038 builder, smemDesc, mt.lookupValue(thisOp.getLeadingDimOffset()), 14, 16);
3039 smemDesc = packValInto64Bits(
3040 builder, smemDesc, mt.lookupValue(thisOp.getStrideDimOffset()), 14, 32);
3041
3042 smemDesc = packValInto64Bits(builder, smemDesc, builder.getInt32(1), 3, 46);
3043 smemDesc = packValInto64Bits(builder, smemDesc,
3044 mt.lookupValue(thisOp.getBaseOffset()), 3, 49);
3045 smemDesc = packValInto64Bits(
3046 builder, smemDesc, mt.lookupValue(thisOp.getLeadingDimMode()), 1, 52);
3047 smemDesc = packValInto64Bits(builder, smemDesc,
3048 mt.lookupValue(thisOp.getSwizzleMode()), 3, 61);
3049
3050 mt.mapValue(thisOp.getRes()) = smemDesc;
3051}
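// Summary of the descriptor layout encoded above (bit positions follow the
// packValInto64Bits calls): start address in bits [0, 14), leading dimension
// offset in bits [16, 30), stride dimension offset in bits [32, 46), a
// constant 0b001 in bits [46, 49), base offset in bits [49, 52), the leading
// dimension mode in bit 52, and the swizzle mode in bits [61, 64).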
3052
3053//===----------------------------------------------------------------------===//
3054// getPtx methods
3055//===----------------------------------------------------------------------===//
3056
3057std::string NVVM::MBarrierInitOp::getPtx() {
3058 bool isShared = isPtrInSharedCTASpace(getAddr());
3059 return isShared ? std::string("mbarrier.init.shared.b64 [%0], %1;")
3060 : std::string("mbarrier.init.b64 [%0], %1;");
3061}
3062
3063std::string NVVM::MBarrierArriveExpectTxOp::getPtx() {
3064 bool isShared = isPtrInSharedCTASpace(getAddr());
3065 return isShared
3066 ? std::string("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;")
3067 : std::string("mbarrier.arrive.expect_tx.b64 _, [%0], %1;");
3068}
3069
3070std::string NVVM::MBarrierTryWaitParityOp::getPtx() {
3071 bool isShared = isPtrInSharedCTASpace(getAddr());
3072 llvm::StringRef space = isShared ? ".shared" : "";
3073
3074 return llvm::formatv("{\n\t"
3075 ".reg .pred P1; \n\t"
3076 "LAB_WAIT: \n\t"
3077 "mbarrier.try_wait.parity{0}.b64 P1, [%0], %1, %2; \n\t"
3078 "@P1 bra.uni DONE; \n\t"
3079 "bra.uni LAB_WAIT; \n\t"
3080 "DONE: \n\t"
3081 "}",
3082 space);
3083}
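// For a shared-CTA mbarrier address the formatv call above expands to the
// following PTX spin loop (escape sequences omitted):
//
//   {
//     .reg .pred P1;
//   LAB_WAIT:
//     mbarrier.try_wait.parity.shared.b64 P1, [%0], %1, %2;
//     @P1 bra.uni DONE;
//     bra.uni LAB_WAIT;
//   DONE:
//   }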
3084
3085//===----------------------------------------------------------------------===//
3086// getIntrinsicID/getIntrinsicIDAndArgs methods
3087//===----------------------------------------------------------------------===//
3088
3089mlir::NVVM::IDArgPair NVVM::BarrierOp::getIntrinsicIDAndArgs(
3090 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3091 auto thisOp = cast<NVVM::BarrierOp>(op);
3092 llvm::Value *barrierId = thisOp.getBarrierId()
3093 ? mt.lookupValue(thisOp.getBarrierId())
3094 : builder.getInt32(0);
3095 llvm::Intrinsic::ID id;
3096 llvm::SmallVector<llvm::Value *> args = {barrierId};
3097 if (thisOp.getNumberOfThreads()) {
3098 id = llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_count;
3099 args.push_back(mt.lookupValue(thisOp.getNumberOfThreads()));
3100 } else if (thisOp.getReductionOp()) {
3101 switch (*thisOp.getReductionOp()) {
3102 case NVVM::BarrierReduction::AND:
3103 id = llvm::Intrinsic::nvvm_barrier_cta_red_and_aligned_all;
3104 break;
3105 case NVVM::BarrierReduction::OR:
3106 id = llvm::Intrinsic::nvvm_barrier_cta_red_or_aligned_all;
3107 break;
3108 case NVVM::BarrierReduction::POPC:
3109 id = llvm::Intrinsic::nvvm_barrier_cta_red_popc_aligned_all;
3110 break;
3111 }
3112 args.push_back(builder.CreateICmpNE(
3113 mt.lookupValue(thisOp.getReductionPredicate()), builder.getInt32(0)));
3114 } else {
3115 id = llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_all;
3116 }
3117
3118 return {id, std::move(args)};
3119}
3120
3121NVVM::IDArgPair
3122PMEventOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
3123 llvm::IRBuilderBase &builder) {
3124 auto thisOp = cast<NVVM::PMEventOp>(op);
3125 llvm::Type *i16Ty = llvm::Type::getInt16Ty(mt.getLLVMContext());
3126
3127 // With event-id, mask is generated as (1 << event-id)
3128 llvm::Value *maskVal;
3129 if (auto eventAttr = thisOp.getEventIdAttr()) {
3130 uint16_t mask = static_cast<uint16_t>(1u << eventAttr.getInt());
3131 maskVal = llvm::ConstantInt::get(i16Ty, mask);
3132 } else {
3133 maskVal =
3134 llvm::ConstantInt::get(i16Ty, thisOp.getMaskedEventIdAttr().getValue());
3135 }
3136
3137 return {llvm::Intrinsic::nvvm_pm_event_mask, {maskVal}};
3138}
3139
3140mlir::NVVM::IDArgPair MBarrierInitOp::getIntrinsicIDAndArgs(
3141 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3142 auto thisOp = cast<NVVM::MBarrierInitOp>(op);
3143 bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
3144 llvm::Intrinsic::ID id = isShared ? llvm::Intrinsic::nvvm_mbarrier_init_shared
3145 : llvm::Intrinsic::nvvm_mbarrier_init;
3146
3147 // Fill the Intrinsic Args
3148 llvm::SmallVector<llvm::Value *> args;
3149 args.push_back(mt.lookupValue(thisOp.getAddr()));
3150 args.push_back(mt.lookupValue(thisOp.getCount()));
3151
3152 return {id, std::move(args)};
3153}
3154
3155mlir::NVVM::IDArgPair MBarrierInvalOp::getIntrinsicIDAndArgs(
3156 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3157 auto thisOp = cast<NVVM::MBarrierInvalOp>(op);
3158 bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
3159 llvm::Intrinsic::ID id = isShared
3160 ? llvm::Intrinsic::nvvm_mbarrier_inval_shared
3161 : llvm::Intrinsic::nvvm_mbarrier_inval;
3162
3163 return {id, {mt.lookupValue(thisOp.getAddr())}};
3164}
3165
3166mlir::NVVM::IDArgPair MBarrierExpectTxOp::getIntrinsicIDAndArgs(
3167 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3168 auto thisOp = cast<NVVM::MBarrierExpectTxOp>(op);
3169
3170 bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr());
3171 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3172 // bit-0: Space
3173 // bit-1: Scope
3174 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3175
3176 static constexpr llvm::Intrinsic::ID IDs[] = {
3177 llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cta_space_cta,
3178 llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cta_space_cluster,
3179 llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cluster_space_cta,
3180 llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cluster_space_cluster};
3181
3182 // Fill the Intrinsic Args
3183 llvm::SmallVector<llvm::Value *> args;
3184 args.push_back(mt.lookupValue(thisOp.getAddr()));
3185 args.push_back(mt.lookupValue(thisOp.getTxcount()));
3186
3187 return {IDs[index], std::move(args)};
3188}
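// A small sketch (illustration only; the lowerings above and below inline
// this logic): the 2-bit index packs the memory space into bit-0 and the
// scope into bit-1, which is exactly the order of the IDs[] tables used by
// the mbarrier ops in this file.
namespace {
constexpr unsigned packSpaceScope(bool isClusterSpace, bool isClusterScope) {
  return ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
}
static_assert(packSpaceScope(false, false) == 0, "scope=cta, space=cta");
static_assert(packSpaceScope(true, false) == 1, "scope=cta, space=cluster");
static_assert(packSpaceScope(false, true) == 2, "scope=cluster, space=cta");
static_assert(packSpaceScope(true, true) == 3, "scope=cluster, space=cluster");
} // namespace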
3189
3190mlir::NVVM::IDArgPair MBarrierCompleteTxOp::getIntrinsicIDAndArgs(
3191 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3192 auto thisOp = cast<NVVM::MBarrierCompleteTxOp>(op);
3193
3194 bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr());
3195 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3196 // bit-0: Space
3197 // bit-1: Scope
3198 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3199
3200 static constexpr llvm::Intrinsic::ID IDs[] = {
3201 llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cta_space_cta,
3202 llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cta_space_cluster,
3203 llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cluster_space_cta,
3204 llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cluster_space_cluster};
3205
3206 // Fill the Intrinsic Args
3207 llvm::SmallVector<llvm::Value *> args;
3208 args.push_back(mt.lookupValue(thisOp.getAddr()));
3209 args.push_back(mt.lookupValue(thisOp.getTxcount()));
3210
3211 return {IDs[index], std::move(args)};
3212}
3213
3214mlir::NVVM::IDArgPair MBarrierArriveOp::getIntrinsicIDAndArgs(
3215 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3216 auto thisOp = cast<NVVM::MBarrierArriveOp>(op);
3217
3218 bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr());
3219 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3220 // bit-0: Space
3221 // bit-1: Scope
3222 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3223
3224 static constexpr llvm::Intrinsic::ID IDs[] = {
3225 llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cta_space_cta,
3226 llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cta_space_cluster,
3227 llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cluster_space_cta,
3228 llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cluster_space_cluster};
3229 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3230 llvm::Intrinsic::nvvm_mbarrier_arrive_relaxed_scope_cta_space_cta,
3231 llvm::Intrinsic::nvvm_mbarrier_arrive_relaxed_scope_cta_space_cluster,
3232 llvm::Intrinsic::nvvm_mbarrier_arrive_relaxed_scope_cluster_space_cta,
3233 llvm::Intrinsic::
3234 nvvm_mbarrier_arrive_relaxed_scope_cluster_space_cluster};
3235 auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];
3236
3237 // Tidy-up the Intrinsic Args
3238 bool needCast = isPtrInGenericSpace(thisOp.getAddr());
3239 llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
3240 if (needCast)
3241 mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);
3242
3243 // The most basic form of mbarrier.arrive is available from sm_80 onwards.
3244 // It supports space=cta, scope=cta, no relaxed semantics, and no explicit
3245 // count; only for that exact combination do we use the legacy intrinsic.
3246 bool hasCount = static_cast<bool>(thisOp.getCount());
3247 if (!hasCount &&
3248 (id == llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cta_space_cta))
3249 return {llvm::Intrinsic::nvvm_mbarrier_arrive_shared, {mbar}};
3250
3251 // When count is not explicitly specified, the default is 1.
3252 llvm::LLVMContext &ctx = mt.getLLVMContext();
3253 llvm::Value *count =
3254 hasCount ? mt.lookupValue(thisOp.getCount())
3255 : llvm::ConstantInt::get(llvm::Type::getInt32Ty(ctx), 1);
3256 return {id, {mbar, count}};
3257}
3258
3259mlir::NVVM::IDArgPair MBarrierArriveDropOp::getIntrinsicIDAndArgs(
3260 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3261 auto thisOp = cast<NVVM::MBarrierArriveDropOp>(op);
3262
3263 bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr());
3264 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3265 // bit-0: Space
3266 // bit-1: Scope
3267 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3268
3269 static constexpr llvm::Intrinsic::ID IDs[] = {
3270 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cta_space_cta,
3271 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cta_space_cluster,
3272 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cluster_space_cta,
3273 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cluster_space_cluster};
3274 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3275 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_relaxed_scope_cta_space_cta,
3276 llvm::Intrinsic::
3277 nvvm_mbarrier_arrive_drop_relaxed_scope_cta_space_cluster,
3278 llvm::Intrinsic::
3279 nvvm_mbarrier_arrive_drop_relaxed_scope_cluster_space_cta,
3280 llvm::Intrinsic::
3281 nvvm_mbarrier_arrive_drop_relaxed_scope_cluster_space_cluster};
3282 auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];
3283
3284 // Tidy-up the Intrinsic Args
3285 bool needCast = isPtrInGenericSpace(thisOp.getAddr());
3286 llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
3287 if (needCast)
3288 mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);
3289
3290 // When count is not explicitly specified, the default is 1.
3291 llvm::LLVMContext &ctx = mt.getLLVMContext();
3292 bool hasCount = static_cast<bool>(thisOp.getCount());
3293 llvm::Value *count =
3294 hasCount ? mt.lookupValue(thisOp.getCount())
3295 : llvm::ConstantInt::get(llvm::Type::getInt32Ty(ctx), 1);
3296
3297 return {id, {mbar, count}};
3298}
3299
3300bool MBarrierArriveExpectTxOp::getAsmValues(
3301 RewriterBase &rewriter,
3302 llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
3303 &asmValues) {
3304 // Add all the operands but not the attrs to the asmValues list.
3305 // The attrs here are used to generate the right variants for
3306 // intrinsics-lowering. So, we ignore them while generating inline-PTX.
3307 for (auto val : getOperands())
3308 asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read});
3309
3310 return false;
3311}
3312
3313mlir::NVVM::IDArgPair MBarrierArriveExpectTxOp::getIntrinsicIDAndArgs(
3314 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3315 auto thisOp = cast<NVVM::MBarrierArriveExpectTxOp>(op);
3316
3317 bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr());
3318 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3319 // bit-0: Space
3320 // bit-1: Scope
3321 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3322
3323 // clang-format off
3324 static constexpr llvm::Intrinsic::ID IDs[] = {
3325 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cta_space_cta,
3326 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cta_space_cluster,
3327 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cluster_space_cta,
3328 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cluster_space_cluster};
3329 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3330 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cta_space_cta,
3331 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cta_space_cluster,
3332 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cluster_space_cta,
3333 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cluster_space_cluster};
3334 // clang-format on
3335 auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];
3336
3337 // Tidy-up the Intrinsic Args
3338 llvm::Value *txcount = mt.lookupValue(thisOp.getTxcount());
3339 llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
3340 bool needCast = isPtrInGenericSpace(thisOp.getAddr());
3341 if (needCast)
3342 mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);
3343
3344 return {id, {mbar, txcount}};
3345}
3346
3347mlir::NVVM::IDArgPair MBarrierArriveDropExpectTxOp::getIntrinsicIDAndArgs(
3348 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3349 auto thisOp = cast<NVVM::MBarrierArriveDropExpectTxOp>(op);
3350
3351 bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr());
3352 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3353 // bit-0: Space
3354 // bit-1: Scope
3355 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3356
3357 // clang-format off
3358 static constexpr llvm::Intrinsic::ID IDs[] = {
3359 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cta_space_cta,
3360 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cta_space_cluster,
3361 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cluster_space_cta,
3362 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cluster_space_cluster};
3363 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3364 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cta_space_cta,
3365 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cta_space_cluster,
3366 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cluster_space_cta,
3367 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cluster_space_cluster};
3368 // clang-format on
3369 auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];
3370
3371 // Tidy-up the Intrinsic Args
3372 llvm::Value *txcount = mt.lookupValue(thisOp.getTxcount());
3373 llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
3374 bool needCast = isPtrInGenericSpace(thisOp.getAddr());
3375 if (needCast)
3376 mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);
3377
3378 return {id, {mbar, txcount}};
3379}
3380
3381mlir::NVVM::IDArgPair MBarrierArriveNocompleteOp::getIntrinsicIDAndArgs(
3382 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3383 auto thisOp = cast<NVVM::MBarrierArriveNocompleteOp>(op);
3384 bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
3385 llvm::Intrinsic::ID id =
3386 isShared ? llvm::Intrinsic::nvvm_mbarrier_arrive_noComplete_shared
3387 : llvm::Intrinsic::nvvm_mbarrier_arrive_noComplete;
3388 // Fill the Intrinsic Args
3389 llvm::SmallVector<llvm::Value *> args;
3390 args.push_back(mt.lookupValue(thisOp.getAddr()));
3391 args.push_back(mt.lookupValue(thisOp.getCount()));
3392
3393 return {id, std::move(args)};
3394}
3395
3396mlir::NVVM::IDArgPair MBarrierArriveDropNocompleteOp::getIntrinsicIDAndArgs(
3397 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3398 auto thisOp = cast<NVVM::MBarrierArriveDropNocompleteOp>(op);
3399 bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
3400 llvm::Intrinsic::ID id =
3401 isShared ? llvm::Intrinsic::nvvm_mbarrier_arrive_drop_noComplete_shared
3402 : llvm::Intrinsic::nvvm_mbarrier_arrive_drop_noComplete;
3403 // Fill the Intrinsic Args
3404 llvm::SmallVector<llvm::Value *> args;
3405 args.push_back(mt.lookupValue(thisOp.getAddr()));
3406 args.push_back(mt.lookupValue(thisOp.getCount()));
3407
3408 return {id, std::move(args)};
3409}
3410
3411mlir::NVVM::IDArgPair MBarrierTestWaitOp::getIntrinsicIDAndArgs(
3412 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3413 auto thisOp = cast<NVVM::MBarrierTestWaitOp>(op);
3414 bool isPhaseParity = thisOp.getStateOrPhase().getType().isInteger(32);
3415 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3416 // bit-0: isPhaseParity
3417 // bit-1: Scope
3418 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isPhaseParity ? 1 : 0);
3419
3420 // clang-format off
3421 static constexpr llvm::Intrinsic::ID IDs[] = {
3422 llvm::Intrinsic::nvvm_mbarrier_test_wait_scope_cta_space_cta,
3423 llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_scope_cta_space_cta,
3424 llvm::Intrinsic::nvvm_mbarrier_test_wait_scope_cluster_space_cta,
3425 llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_scope_cluster_space_cta};
3426 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3427 llvm::Intrinsic::nvvm_mbarrier_test_wait_relaxed_scope_cta_space_cta,
3428 llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_relaxed_scope_cta_space_cta,
3429 llvm::Intrinsic::nvvm_mbarrier_test_wait_relaxed_scope_cluster_space_cta,
3430 llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_relaxed_scope_cluster_space_cta};
3431 // clang-format on
3432 auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];
3433
3434 // Tidy-up the Intrinsic Args
3435 llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
3436 llvm::Value *input = mt.lookupValue(thisOp.getStateOrPhase());
3437 bool needCast = isPtrInGenericSpace(thisOp.getAddr());
3438 if (needCast)
3439 mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);
3440
3441 return {id, {mbar, input}};
3442}
3443
3444mlir::NVVM::IDArgPair MBarrierTryWaitOp::getIntrinsicIDAndArgs(
3445 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3446 auto thisOp = cast<NVVM::MBarrierTryWaitOp>(op);
3447 bool isPhaseParity = thisOp.getStateOrPhase().getType().isInteger(32);
3448 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3449 bool hasTicks = static_cast<bool>(thisOp.getTicks());
3450 // bit-0: isPhaseParity
3451 // bit-1: Scope
3452 // bit-2: hasTicks
3453 size_t index = ((hasTicks ? 1 : 0) << 2) | ((isClusterScope ? 1 : 0) << 1) |
3454 (isPhaseParity ? 1 : 0);
3455
3456 // clang-format off
3457 static constexpr llvm::Intrinsic::ID IDs[] = {
3458 llvm::Intrinsic::nvvm_mbarrier_try_wait_scope_cta_space_cta,
3459 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_scope_cta_space_cta,
3460 llvm::Intrinsic::nvvm_mbarrier_try_wait_scope_cluster_space_cta,
3461 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_scope_cluster_space_cta,
3462 llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_scope_cta_space_cta,
3463 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_scope_cta_space_cta,
3464 llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_scope_cluster_space_cta,
3465 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_scope_cluster_space_cta};
3466 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3467 llvm::Intrinsic::nvvm_mbarrier_try_wait_relaxed_scope_cta_space_cta,
3468 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_relaxed_scope_cta_space_cta,
3469 llvm::Intrinsic::nvvm_mbarrier_try_wait_relaxed_scope_cluster_space_cta,
3470 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_relaxed_scope_cluster_space_cta,
3471 llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_relaxed_scope_cta_space_cta,
3472 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_relaxed_scope_cta_space_cta,
3473 llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_relaxed_scope_cluster_space_cta,
3474 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_relaxed_scope_cluster_space_cta};
3475 // clang-format on
3476 auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];
3477
3478 // Tidy-up the mbarrier pointer
3479 llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
3480 bool needCast = isPtrInGenericSpace(thisOp.getAddr());
3481 if (needCast)
3482 mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);
3483
3484 // Fill the Intrinsic Args
3485 llvm::SmallVector<llvm::Value *> args;
3486 args.push_back(mbar);
3487 args.push_back(mt.lookupValue(thisOp.getStateOrPhase()));
3488 if (hasTicks)
3489 args.push_back(mt.lookupValue(thisOp.getTicks()));
3490
3491 return {id, std::move(args)};
3492}
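// Note on the 3-bit index above (illustration only): hasTicks selects the
// "_tl" (time-limit) half of each table. For example, a phase-parity wait at
// CTA scope with a ticks operand packs to (1 << 2) | (0 << 1) | 1 == 5, which
// selects nvvm_mbarrier_try_wait_parity_tl_scope_cta_space_cta.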
3493
3494mlir::NVVM::IDArgPair CpAsyncMBarrierArriveOp::getIntrinsicIDAndArgs(
3495 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3496 auto thisOp = cast<NVVM::CpAsyncMBarrierArriveOp>(op);
3497 bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
3498
3499 llvm::Intrinsic::ID id;
3500 if (thisOp.getNoinc()) {
3501 id = isShared ? llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_noinc_shared
3502 : llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_noinc;
3503 } else {
3504 id = isShared ? llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_shared
3505 : llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive;
3506 }
3507
3508 return {id, {mt.lookupValue(thisOp.getAddr())}};
3509}
3510
3511#define CP_ASYNC_ID_IMPL(mod, size, suffix) \
3512 llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix
3513
3514#define GET_CP_ASYNC_ID(mod, size, has_cpsize) \
3515 has_cpsize ? CP_ASYNC_ID_IMPL(mod, size, _s) : CP_ASYNC_ID_IMPL(mod, size, )
3516
3517llvm::Intrinsic::ID
3518CpAsyncOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
3519 llvm::SmallVector<llvm::Value *> &args) {
3520 llvm::Intrinsic::ID id;
3521
3522 auto cpAsyncOp = cast<NVVM::CpAsyncOp>(op);
3523 bool hasCpSize = static_cast<bool>(cpAsyncOp.getCpSize());
3524 switch (cpAsyncOp.getSize()) {
3525 case 4:
3526 id = GET_CP_ASYNC_ID(ca, 4, hasCpSize);
3527 break;
3528 case 8:
3529 id = GET_CP_ASYNC_ID(ca, 8, hasCpSize);
3530 break;
3531 case 16:
3532 id = (cpAsyncOp.getModifier() == NVVM::LoadCacheModifierKind::CG)
3533 ? GET_CP_ASYNC_ID(cg, 16, hasCpSize)
3534 : GET_CP_ASYNC_ID(ca, 16, hasCpSize);
3535 break;
3536 default:
3537 llvm_unreachable("Invalid copy size in CpAsyncOp.");
3538 }
3539
3540 // Fill the Intrinsic Args
3541 args.push_back(mt.lookupValue(cpAsyncOp.getDst()));
3542 args.push_back(mt.lookupValue(cpAsyncOp.getSrc()));
3543 if (hasCpSize)
3544 args.push_back(mt.lookupValue(cpAsyncOp.getCpSize()));
3545
3546 return id;
3547}
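// Manual expansion of the macros above, for reference: with size == 16,
// modifier == CG and a cp-size operand present, GET_CP_ASYNC_ID(cg, 16,
// hasCpSize) selects llvm::Intrinsic::nvvm_cp_async_cg_shared_global_16_s;
// without the cp-size operand it selects
// llvm::Intrinsic::nvvm_cp_async_cg_shared_global_16.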
3548
3549mlir::NVVM::IDArgPair CpAsyncBulkPrefetchOp::getIntrinsicIDAndArgs(
3550 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3551 auto thisOp = cast<NVVM::CpAsyncBulkPrefetchOp>(op);
3552 llvm::SmallVector<llvm::Value *> args;
3553 llvm::Intrinsic::ID id = llvm::Intrinsic::nvvm_cp_async_bulk_prefetch_L2;
3554
3555 // Fill the Intrinsic Args
3556 args.push_back(mt.lookupValue(thisOp.getSrcMem()));
3557 args.push_back(mt.lookupValue(thisOp.getSize()));
3558
3559 mlir::Value cacheHint = thisOp.getL2CacheHint();
3560 const bool hasCacheHint = static_cast<bool>(cacheHint);
3561 llvm::Value *i64Unused =
3562 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
3563 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
3564 args.push_back(builder.getInt1(hasCacheHint));
3565
3566 return {id, std::move(args)};
3567}
3568
3569mlir::NVVM::IDArgPair CpAsyncBulkGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
3570 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3571 auto thisOp = cast<NVVM::CpAsyncBulkGlobalToSharedClusterOp>(op);
3572 llvm::SmallVector<llvm::Value *> args;
3573
3574 // Fill the Intrinsic Args: dst, mbar, src, size.
3575 args.push_back(mt.lookupValue(thisOp.getDstMem()));
3576 args.push_back(mt.lookupValue(thisOp.getMbar()));
3577 args.push_back(mt.lookupValue(thisOp.getSrcMem()));
3578 args.push_back(mt.lookupValue(thisOp.getSize()));
3579
3580 // Multicast mask for shared::cluster only, if available.
3581 mlir::Value multicastMask = thisOp.getMulticastMask();
3582 const bool hasMulticastMask = static_cast<bool>(multicastMask);
3583 const bool isSharedCTA = isPtrInSharedCTASpace(thisOp.getDstMem());
3584 if (!isSharedCTA) {
3585 llvm::Value *i16Unused = llvm::ConstantInt::get(builder.getInt16Ty(), 0);
3586 args.push_back(hasMulticastMask ? mt.lookupValue(multicastMask)
3587 : i16Unused);
3588 }
3589
3590 // Cache hint, if available.
3591 mlir::Value cacheHint = thisOp.getL2CacheHint();
3592 const bool hasCacheHint = static_cast<bool>(cacheHint);
3593 llvm::Value *i64Unused = llvm::ConstantInt::get(builder.getInt64Ty(), 0);
3594 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
3595
3596 // Flag arguments for multicast and cachehint.
3597 if (!isSharedCTA)
3598 args.push_back(builder.getInt1(hasMulticastMask));
3599 args.push_back(builder.getInt1(hasCacheHint));
3600
3601 llvm::Intrinsic::ID id =
3602 isSharedCTA
3603 ? llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cta
3604 : llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster;
3605
3606 return {id, std::move(args)};
3607}
3608
3609mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
3610 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3611 auto thisOp = cast<NVVM::CpAsyncBulkSharedCTAToGlobalOp>(op);
3612 llvm::SmallVector<llvm::Value *> args;
3613 llvm::Intrinsic::ID id =
3614 llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global;
3615
3616 // Fill the Intrinsic Args
3617 args.push_back(mt.lookupValue(thisOp.getDstMem()));
3618 args.push_back(mt.lookupValue(thisOp.getSrcMem()));
3619 args.push_back(mt.lookupValue(thisOp.getSize()));
3620
3621 mlir::Value cacheHint = thisOp.getL2CacheHint();
3622 const bool hasCacheHint = static_cast<bool>(cacheHint);
3623 llvm::Value *i64Unused =
3624 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
3625 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
3626 args.push_back(builder.getInt1(hasCacheHint));
3627
3628 // Choose the bytemask variant
3629 if (mlir::Value byteMask = thisOp.getByteMask()) {
3630 args.push_back(mt.lookupValue(byteMask));
3631 id = llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global_bytemask;
3632 }
3633
3634 return {id, std::move(args)};
3635}
3636
3637bool CpAsyncBulkTensorGlobalToSharedClusterOp::getAsmValues(
3638 RewriterBase &rewriter,
3639 llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
3640 &asmValues) {
3641 // Add all the operands but not the attrs to the asmValues list.
3642 // The attrs here are used to generate the right variants for
3643 // intrinsics-lowering. So, we ignore them while generating inline-PTX.
3644 for (auto val : getOperands())
3645 asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read});
3646
3647 return false;
3648}
3649
3650 mlir::NVVM::IDArgPair
3651 CpAsyncBulkTensorGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
3652 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3653 auto thisOp = cast<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(op);
3654 const bool isCTAOnly = thisOp.getIsCTAOnly();
3655 llvm::SmallVector<llvm::Value *> args;
3656
3657 // Fill the Intrinsic Args
3658 args.push_back(mt.lookupValue(thisOp.getDstMem()));
3659 args.push_back(mt.lookupValue(thisOp.getMbar()));
3660 args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
3661
3662 // Coordinates and im2col-offsets
3663 for (mlir::Value v : thisOp.getCoordinates())
3664 args.push_back(mt.lookupValue(v));
3665 for (mlir::Value v : thisOp.getIm2colOffsets())
3666 args.push_back(mt.lookupValue(v));
3667
3668 // MulticastMask, if available
3669 mlir::Value mcMask = thisOp.getMulticastMask();
3670 const bool hasMC = static_cast<bool>(mcMask);
3671 llvm::Value *i16Zero =
3672 llvm::ConstantInt::get(llvm::Type::getInt16Ty(mt.getLLVMContext()), 0);
3673
3674 // CacheHint, if available
3675 mlir::Value cacheHint = thisOp.getL2CacheHint();
3676 const bool hasCacheHint = static_cast<bool>(cacheHint);
3677 llvm::Value *i64Zero =
3678 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
3679
3680 // Flag argument CTAGroup
3681 // CTA_1/2 is mapped to values 1 and 2 for the intrinsics.
3682 // Hence, the +1 to getGroup().
3683 const int32_t val =
3684 thisOp.getGroup() ? (static_cast<int32_t>(*thisOp.getGroup()) + 1) : 0;
3685 llvm::Value *cg =
3686 llvm::ConstantInt::get(llvm::Type::getInt32Ty(mt.getLLVMContext()), val);
3687
3688 if (!isCTAOnly) {
3689 // For shared::cluster, all the arguments that we build are applicable.
3690 args.push_back(hasMC ? mt.lookupValue(mcMask) : i16Zero);
3691 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Zero);
3692 args.push_back(builder.getInt1(hasMC));
3693 args.push_back(builder.getInt1(hasCacheHint));
3694 args.push_back(cg);
3695 } else {
3696 // For shared::cta, only cache-hint is applicable.
3697 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Zero);
3698 args.push_back(builder.getInt1(hasCacheHint));
3699 }
3700
3701 constexpr size_t numDims = 5; // 1D to 5D
3702 constexpr size_t numModes = 5; // Tile, Im2col, w, w_128, gather4
3703 using rowTy = std::array<llvm::Intrinsic::ID, numDims + 1>;
3704 using TableTy = std::array<rowTy, numModes>;
3705 static constexpr TableTy IDTable{
3706 {{notIntrinsic, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_1d,
3707 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_2d,
3708 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_3d,
3709 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_4d,
3710 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_5d},
3711 {notIntrinsic, notIntrinsic, notIntrinsic,
3712 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d,
3713 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d,
3714 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d},
3715 {notIntrinsic, notIntrinsic, notIntrinsic,
3716 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_3d,
3717 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_4d,
3718 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_5d},
3719 {notIntrinsic, notIntrinsic, notIntrinsic,
3720 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_3d,
3721 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_4d,
3722 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_5d},
3723 {notIntrinsic, notIntrinsic, notIntrinsic, notIntrinsic, notIntrinsic,
3724 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_gather4_2d}}};
3725
3726 static constexpr TableTy IDTableCTA{
3727 {{notIntrinsic,
3728 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_1d,
3729 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_2d,
3730 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_3d,
3731 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_4d,
3732 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_5d},
3733 {notIntrinsic, notIntrinsic, notIntrinsic,
3734 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_3d,
3735 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_4d,
3736 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_5d},
3737 {notIntrinsic, notIntrinsic, notIntrinsic,
3738 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_3d,
3739 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_4d,
3740 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_5d},
3741 {notIntrinsic, notIntrinsic, notIntrinsic,
3742 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_3d,
3743 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_4d,
3744 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_5d},
3745 {notIntrinsic, notIntrinsic, notIntrinsic, notIntrinsic, notIntrinsic,
3746 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_gather4_2d}}};
3747
3748 static_assert(
3749 (getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1) &&
3750 (getMaxEnumValForTMALoadMode() == std::size(IDTableCTA) - 1),
3751 "TMALoadModes must match number of rows in IDTable and IDTableCTA");
3752 size_t mode = static_cast<size_t>(thisOp.getMode());
3753 size_t dim = thisOp.getCoordinates().size();
3754 auto id = isCTAOnly ? IDTableCTA[mode][dim] : IDTable[mode][dim];
3755 assert(id != notIntrinsic &&
3756 "Invalid intrinsic for CpAsyncBulkTensorGlobalToSharedClusterOp.");
3757
3758 return {id, std::move(args)};
3759}
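// Example of the table lookup above (illustration; assumes the tile mode is
// row 0 of TMALoadMode): a tile-mode load with three coordinates and
// isCTAOnly == false reads IDTable[0][3], i.e.
// llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_3d, while the same op
// with isCTAOnly == true reads IDTableCTA[0][3], i.e.
// llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_3d.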
3760
3761mlir::NVVM::IDArgPair CpAsyncBulkTensorPrefetchOp::getIntrinsicIDAndArgs(
3762 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3763 auto thisOp = cast<NVVM::CpAsyncBulkTensorPrefetchOp>(op);
3764 llvm::SmallVector<llvm::Value *> args;
3765
3766 // Fill the Intrinsic Args
3767 args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
3768
3769 for (auto v : thisOp.getCoordinates())
3770 args.push_back(mt.lookupValue(v));
3771 for (auto v : thisOp.getIm2colOffsets())
3772 args.push_back(mt.lookupValue(v));
3773
3774 mlir::Value cacheHint = thisOp.getL2CacheHint();
3775 const bool hasCacheHint = static_cast<bool>(cacheHint);
3776 llvm::Value *i64Unused =
3777 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
3778 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
3779 args.push_back(builder.getInt1(hasCacheHint));
3780
3781 const unsigned NI = llvm::Intrinsic::not_intrinsic;
3782 static constexpr llvm::Intrinsic::ID IDTable[][6] = {
3783 {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d,
3784 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d,
3785 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d,
3786 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d,
3787 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d},
3788 {NI, NI, NI,
3789 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d,
3790 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d,
3791 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d},
3792 {NI, NI, NI,
3793 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_3d,
3794 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_4d,
3795 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_5d},
3796 {NI, NI, NI,
3797 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_3d,
3798 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_4d,
3799 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_5d},
3800 {NI, NI, NI, NI, NI,
3801 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_gather4_2d}};
3802
3803 static_assert(getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1,
3804 "TMALoadModes must match number of rows in IDTable");
3805 size_t mode = static_cast<size_t>(thisOp.getMode());
3806 size_t dim = thisOp.getCoordinates().size();
3807 llvm::Intrinsic::ID id = IDTable[mode][dim];
3808 if (id == llvm::Intrinsic::not_intrinsic)
3809 llvm_unreachable("Invalid intrinsic for CpAsyncBulkTensorPrefetchOp.");
3810
3811 return {id, std::move(args)};
3812}
3813
3814 mlir::NVVM::IDArgPair
3815 CpAsyncBulkTensorSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
3816 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3817 auto thisOp = cast<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(op);
3818 llvm::SmallVector<llvm::Value *> args;
3819
3820 // Fill the Intrinsic Args
3821 args.push_back(mt.lookupValue(thisOp.getSrcMem()));
3822 args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
3823
3824 for (auto v : thisOp.getCoordinates())
3825 args.push_back(mt.lookupValue(v));
3826
3827 mlir::Value cacheHint = thisOp.getL2CacheHint();
3828 const bool hasCacheHint = static_cast<bool>(cacheHint);
3829 llvm::Value *i64Unused =
3830 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
3831 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
3832 args.push_back(builder.getInt1(hasCacheHint));
3833
3834 const unsigned NI = llvm::Intrinsic::not_intrinsic;
3835 static constexpr llvm::Intrinsic::ID IDTable[][6] = {
3836 {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_1d,
3837 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_2d,
3838 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_3d,
3839 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_4d,
3840 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_5d},
3841 {NI, NI, NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_3d,
3842 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_4d,
3843 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_5d},
3844 {NI, NI, NI, NI, NI,
3845 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_scatter4_2d}};
3846
3847 static_assert(getMaxEnumValForTMAStoreMode() == std::size(IDTable) - 1,
3848 "TMAStoreModes must match number of rows in IDTable");
3849 size_t mode = static_cast<size_t>(thisOp.getMode());
3850 size_t dim = thisOp.getCoordinates().size();
3851 llvm::Intrinsic::ID id = IDTable[mode][dim];
3852 if (id == llvm::Intrinsic::not_intrinsic)
3853 llvm_unreachable(
3854 "Invalid intrinsic for CpAsyncBulkTensorSharedCTAToGlobalOp.");
3855
3856 return {id, std::move(args)};
3857}
3858
3859NVVM::IDArgPair CpAsyncBulkTensorReduceOp::getIntrinsicIDAndArgs(
3860 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3861 auto thisOp = cast<NVVM::CpAsyncBulkTensorReduceOp>(op);
3862 llvm::LLVMContext &ctx = mt.getLLVMContext();
3863
3864 llvm::SmallVector<llvm::Value *> args;
3865
3866 // Arguments to the intrinsic:
3867 // shared_mem_ptr, tmaDesc, tensorDims
3868 // cache_hint(if applicable) and flag(boolean)
3869 args.push_back(mt.lookupValue(thisOp.getSrcMem()));
3870 args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
3871
3872 for (Value v : thisOp.getCoordinates())
3873 args.push_back(mt.lookupValue(v));
3874
3875 mlir::Value cacheHint = thisOp.getL2CacheHint();
3876 const bool hasCacheHint = static_cast<bool>(cacheHint);
3877 llvm::Value *i64ZeroValue =
3878 llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), 0);
3879 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64ZeroValue);
3880 args.push_back(builder.getInt1(hasCacheHint));
3881
3882 const llvm::Intrinsic::ID notIntrinsic = llvm::Intrinsic::not_intrinsic;
3883
3884 constexpr unsigned numRedKinds = 8; // ADD, MIN, MAX, INC, DEC, AND, OR, XOR
3885 constexpr unsigned numLayouts = 2; // TILE, IM2COL
3886 constexpr unsigned maxDim = 5; // 1D to 5D
3887 using row = std::array<llvm::Intrinsic::ID, maxDim + 1>;
3888 using layoutTable = std::array<row, numLayouts>;
3889 using fullTable = std::array<layoutTable, numRedKinds>;
3890 static constexpr fullTable IDTable{
3891 {// RedTy::ADD
3892 {{{{notIntrinsic,
3893 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_1d,
3894 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_2d,
3895 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_3d,
3896 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_4d,
3897 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_5d}},
3898 {{notIntrinsic, notIntrinsic, notIntrinsic,
3899 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_3d,
3900 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_4d,
3901 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_5d}}}},
3902 // RedTy::MIN
3903 {{{{notIntrinsic,
3904 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_1d,
3905 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_2d,
3906 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_3d,
3907 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_4d,
3908 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_5d}},
3909 {{notIntrinsic, notIntrinsic, notIntrinsic,
3910 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_3d,
3911 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_4d,
3912 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_5d}}}},
3913 // RedTy::MAX
3914 {{{{notIntrinsic,
3915 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_1d,
3916 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_2d,
3917 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_3d,
3918 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_4d,
3919 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_5d}},
3920 {{notIntrinsic, notIntrinsic, notIntrinsic,
3921 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_3d,
3922 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_4d,
3923 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_5d}}}},
3924 // RedTy::INC
3925 {{{{notIntrinsic,
3926 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_1d,
3927 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_2d,
3928 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_3d,
3929 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_4d,
3930 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_5d}},
3931 {{notIntrinsic, notIntrinsic, notIntrinsic,
3932 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_3d,
3933 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_4d,
3934 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_5d}}}},
3935 // RedTy::DEC
3936 {{{{notIntrinsic,
3937 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_1d,
3938 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_2d,
3939 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_3d,
3940 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_4d,
3941 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_5d}},
3942 {{notIntrinsic, notIntrinsic, notIntrinsic,
3943 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_3d,
3944 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_4d,
3945 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_5d}}}},
3946 // RedTy::AND
3947 {{{{notIntrinsic,
3948 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_1d,
3949 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_2d,
3950 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_3d,
3951 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_4d,
3952 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_5d}},
3953 {{notIntrinsic, notIntrinsic, notIntrinsic,
3954 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_3d,
3955 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_4d,
3956 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_5d}}}},
3957 // RedTy::OR
3958 {{{{notIntrinsic,
3959 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_1d,
3960 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_2d,
3961 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_3d,
3962 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_4d,
3963 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_5d}},
3964 {{notIntrinsic, notIntrinsic, notIntrinsic,
3965 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_3d,
3966 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_4d,
3967 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_5d}}}},
3968 // RedTy::XOR
3969 {{{{notIntrinsic,
3970 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_1d,
3971 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_2d,
3972 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_3d,
3973 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_4d,
3974 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_5d}},
3975 {{notIntrinsic, notIntrinsic, notIntrinsic,
3976 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_3d,
3977 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_4d,
3978 llvm::Intrinsic::
3979 nvvm_cp_async_bulk_tensor_reduce_xor_im2col_5d}}}}}};
3980
3981 static_assert(getMaxEnumValForTMAReduxKind() == std::size(IDTable) - 1,
3982 "TMAReduxKinds must match number of rows in IDTable");
3983
3984 size_t redKind = static_cast<size_t>(thisOp.getRedKind());
3985 size_t mode = static_cast<size_t>(thisOp.getMode());
3986 size_t dim = thisOp.getCoordinates().size();
3987
3988 assert(redKind < IDTable.size() &&
3989 "Invalid redKind for CpAsyncBulkTensorReduceOp");
3990 assert(mode < IDTable[redKind].size() &&
3991 "Invalid mode for CpAsyncBulkTensorReduceOp");
3992 assert(dim < IDTable[redKind][mode].size() &&
3993 "Invalid dim for CpAsyncBulkTensorReduceOp");
3994
3995 llvm::Intrinsic::ID intrinsicID = IDTable[redKind][mode][dim];
3996
3997 assert(intrinsicID != notIntrinsic &&
3998 "Invalid intrinsic for CpAsyncBulkTensorReduceOp.");
3999
4000 return {intrinsicID, std::move(args)};
4001}
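// Example (illustration; assumes ADD is row 0 of TMAReduxKind and tile mode
// is row 0 of the layout dimension): an add-reduction in tile mode over a
// 2-D tensor reads IDTable[0][0][2], i.e.
// llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_2d.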
4002
4003#define _none
4004
4005#define CVT_F2TF32_ID_IMPL(rnd, relu, sf) \
4006 hasRelu ? llvm::Intrinsic::nvvm_f2tf32_##rnd##relu##sf \
4007 : llvm::Intrinsic::nvvm_f2tf32_##rnd##sf
4008
4009#define GET_CVT_F2TF32_ID(rnd, relu, sf) \
4010 hasSatFinite ? CVT_F2TF32_ID_IMPL(rnd, relu, sf) \
4011 : CVT_F2TF32_ID_IMPL(rnd, relu, )
4012
4013llvm::Intrinsic::ID
4014ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
4015 NVVM::SaturationMode sat, bool hasRelu) {
4016 using RndMode = NVVM::FPRoundingMode;
4017 bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
4018 switch (rnd) {
4019 case RndMode::RN:
4020 return GET_CVT_F2TF32_ID(rn, _relu, _satfinite);
4021 case RndMode::RZ:
4022 return GET_CVT_F2TF32_ID(rz, _relu, _satfinite);
4023 case RndMode::RNA:
4024 return GET_CVT_F2TF32_ID(rna, _none, _satfinite);
4025 default:
4026 llvm_unreachable("Invalid RoundingMode for CvtFloatToTF32Op");
4027 }
4028}
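// Manual expansion for reference: RN with relu and satfinite should resolve
// to llvm::Intrinsic::nvvm_f2tf32_rn_relu_satfinite. For RNA, the _none
// argument (an empty macro) is expanded away before the inner token paste, so
// both branches of CVT_F2TF32_ID_IMPL collapse to the relu-less names
// nvvm_f2tf32_rna_satfinite / nvvm_f2tf32_rna.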
4029
4030 NVVM::IDArgPair
4031 ConvertF32x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF4x2Op op,
4032 LLVM::ModuleTranslation &mt,
4033 llvm::IRBuilderBase &builder) {
4034 llvm::SmallVector<llvm::Value *> args;
4035 args.push_back(mt.lookupValue(op.getA()));
4036 args.push_back(mt.lookupValue(op.getB()));
4037
4038 bool hasRelu = op.getRelu();
4039
4040 llvm::Intrinsic::ID intId =
4041 hasRelu ? llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_relu_satfinite
4042 : llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_satfinite;
4043
4044 return {intId, std::move(args)};
4045}
4046
4047#define GET_F32x2_TO_F6x2_ID(type, has_relu) \
4048 has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite \
4049 : llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite
4050
4051llvm::Intrinsic::ID ConvertF32x2ToF6x2Op::getIntrinsicID(mlir::Type dstTy,
4052 bool hasRelu) {
4053 return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
4054 .Case<mlir::Float6E2M3FNType>([&](mlir::Float6E2M3FNType) {
4055 return GET_F32x2_TO_F6x2_ID(e2m3x2, hasRelu);
4056 })
4057 .Case<mlir::Float6E3M2FNType>([&](mlir::Float6E3M2FNType) {
4058 return GET_F32x2_TO_F6x2_ID(e3m2x2, hasRelu);
4059 })
4060 .Default([](mlir::Type) {
4061 llvm_unreachable("Invalid conversion in ConvertF32x2ToF6x2Op");
4062 return llvm::Intrinsic::not_intrinsic;
4063 });
4064}
4065
4066#define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf) \
4067 has_satf ? llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd##_satfinite \
4068 : llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd
4069
4070#define GET_F32x2_TO_F8X2_S_ID(type, has_relu) \
4071 has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu \
4072 : llvm::Intrinsic::nvvm_ff_to_##type##_rn
4073
4074llvm::Intrinsic::ID
4075ConvertF32x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy, NVVM::FPRoundingMode rnd,
4076 NVVM::SaturationMode sat, bool hasRelu) {
4077 bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
4078 bool hasRoundingModeRZ = (rnd == NVVM::FPRoundingMode::RZ);
4079 bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP);
4080
4081 return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
4082 .Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
4083 return GET_F32x2_TO_F8X2_S_ID(e4m3x2, hasRelu);
4084 })
4085 .Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
4086 return GET_F32x2_TO_F8X2_S_ID(e5m2x2, hasRelu);
4087 })
4088 .Case<mlir::Float8E8M0FNUType>([&](mlir::Float8E8M0FNUType) {
4089 if (hasRoundingModeRZ)
4090 return GET_F32x2_TO_F8X2_US_ID(rz, hasSatFinite);
4091 else if (hasRoundingModeRP)
4092 return GET_F32x2_TO_F8X2_US_ID(rp, hasSatFinite);
4093
4094 llvm_unreachable("Invalid conversion in ConvertF32x2ToF8x2Op");
4095 })
4096 .Default([](mlir::Type) {
4097 llvm_unreachable("Invalid conversion in ConvertF32x2ToF8x2Op");
4098 return llvm::Intrinsic::not_intrinsic;
4099 });
4100}
4101
4102#define GET_F16x2_TO_F8X2_ID(type, has_relu) \
4103 has_relu ? llvm::Intrinsic::nvvm_f16x2_to_##type##_rn_relu \
4104 : llvm::Intrinsic::nvvm_f16x2_to_##type##_rn
4105
4106llvm::Intrinsic::ID ConvertF16x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy,
4107 bool hasRelu) {
4108 return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
4109 .Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
4110 return GET_F16x2_TO_F8X2_ID(e4m3x2, hasRelu);
4111 })
4112 .Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
4113 return GET_F16x2_TO_F8X2_ID(e5m2x2, hasRelu);
4114 })
4115 .Default([](mlir::Type) {
4116 llvm_unreachable("Invalid conversion in ConvertF16x2ToF8x2Op");
4117 return llvm::Intrinsic::not_intrinsic;
4118 });
4119}
4120
4121#define GET_BF16X2_TO_F8X2_ID(rnd, has_satf) \
4122 has_satf ? llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd##_satfinite \
4123 : llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd
4124
4125llvm::Intrinsic::ID
4126ConvertBF16x2ToF8x2Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
4127 NVVM::SaturationMode sat) {
4128 bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
4129 switch (rnd) {
4130 case NVVM::FPRoundingMode::RZ:
4131 return GET_BF16X2_TO_F8X2_ID(rz, hasSatFinite);
4132 case NVVM::FPRoundingMode::RP:
4133 return GET_BF16X2_TO_F8X2_ID(rp, hasSatFinite);
4134 default:
4135 llvm_unreachable("Invalid rounding mode for CvtBF16x2ToF8x2Op");
4136 }
4137}
4138
4139NVVM::IDArgPair ConvertF8x2ToF16x2Op::getIntrinsicIDAndArgs(
4140 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4141 auto curOp = cast<NVVM::ConvertF8x2ToF16x2Op>(op);
4142
4143 bool hasRelu = curOp.getRelu();
4144
4145 llvm::Intrinsic::ID intId =
4147 .Case<Float8E4M3FNType>([&](Float8E4M3FNType type) {
4148 return hasRelu ? llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn_relu
4149 : llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn;
4150 })
4151 .Case<Float8E5M2Type>([&](Float8E5M2Type type) {
4152 return hasRelu ? llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn_relu
4153 : llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn;
4154 })
4155 .Default([](mlir::Type type) {
4156 llvm_unreachable("Invalid type for ConvertF8x2ToF16x2Op");
4157 return llvm::Intrinsic::not_intrinsic;
4158 });
4159
4160 llvm::Value *packedI16 =
4161 builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
4162 llvm::Type::getInt16Ty(builder.getContext()));
4163
4164 return {intId, {packedI16}};
4165}
4166
4167NVVM::IDArgPair ConvertF8x2ToBF16x2Op::getIntrinsicIDAndArgs(
4168 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4169 auto curOp = cast<NVVM::ConvertF8x2ToBF16x2Op>(op);
4170
4171 llvm::Intrinsic::ID intId = llvm::Intrinsic::nvvm_ue8m0x2_to_bf16x2;
4172 llvm::Value *packedI16 =
4173 builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
4174 llvm::Type::getInt16Ty(builder.getContext()));
4175
4176 return {intId, {packedI16}};
4177}
4178
4179NVVM::IDArgPair ConvertF6x2ToF16x2Op::getIntrinsicIDAndArgs(
4180 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4181 auto curOp = cast<NVVM::ConvertF6x2ToF16x2Op>(op);
4182
4183 bool hasRelu = curOp.getRelu();
4184
4185 llvm::Intrinsic::ID intId =
4187 .Case<Float6E2M3FNType>([&](Float6E2M3FNType type) {
4188 return hasRelu ? llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn_relu
4189 : llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn;
4190 })
4191 .Case<Float6E3M2FNType>([&](Float6E3M2FNType type) {
4192 return hasRelu ? llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn_relu
4193 : llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn;
4194 })
4195 .Default([](mlir::Type type) {
4196 llvm_unreachable("Invalid type for ConvertF6x2ToF16x2Op");
4197 return llvm::Intrinsic::not_intrinsic;
4198 });
4199
4200 llvm::Value *packedI16 =
4201 builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
4202 llvm::Type::getInt16Ty(builder.getContext()));
4203
4204 return {intId, {packedI16}};
4205}
4206
4207NVVM::IDArgPair ConvertF4x2ToF16x2Op::getIntrinsicIDAndArgs(
4208 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4209 auto curOp = cast<NVVM::ConvertF4x2ToF16x2Op>(op);
4210
4211 bool hasRelu = curOp.getRelu();
4212
4213 llvm::Intrinsic::ID intId =
4215 .Case<Float4E2M1FNType>([&](Float4E2M1FNType type) {
4216 return hasRelu ? llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn_relu
4217 : llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn;
4218 })
4219 .Default([](mlir::Type type) {
4220 llvm_unreachable("Invalid type for ConvertF4x2ToF16x2Op");
4221 return llvm::Intrinsic::not_intrinsic;
4222 });
4223
4224 llvm::Value *extendedI16 =
4225 builder.CreateZExt(mt.lookupValue(curOp.getSrc()),
4226 llvm::Type::getInt16Ty(builder.getContext()));
4227
4228 return {intId, {extendedI16}};
4229}
4230
4231llvm::Intrinsic::ID
4232Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op,
4233 LLVM::ModuleTranslation &mt,
4234 llvm::SmallVector<llvm::Value *> &args) {
4235 auto curOp = cast<NVVM::Tcgen05AllocOp>(op);
4236 unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
4237 .getAddressSpace();
4238 bool isShared = as == NVVMMemorySpace::Shared;
4239 bool is2CTAMode = curOp.getGroup() == CTAGroupKind::CTA_2;
4240
4241 llvm::Intrinsic::ID id;
4242 if (isShared) {
4243 id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg2
4244 : llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg1;
4245 } else {
4246 id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_cg2
4247 : llvm::Intrinsic::nvvm_tcgen05_alloc_cg1;
4248 }
4249
4250 // Fill the Intrinsic Args
4251 args.push_back(mt.lookupValue(curOp.getAddr()));
4252 args.push_back(mt.lookupValue(curOp.getNCols()));
4253
4254 return id;
4255}
4256
4257llvm::Intrinsic::ID Tcgen05DeallocOp::getIntrinsicIDAndArgs(
4258 Operation &op, LLVM::ModuleTranslation &mt,
4259 llvm::SmallVector<llvm::Value *> &args) {
4260 auto curOp = cast<NVVM::Tcgen05DeallocOp>(op);
4261 auto id = (curOp.getGroup() == CTAGroupKind::CTA_1)
4262 ? llvm::Intrinsic::nvvm_tcgen05_dealloc_cg1
4263 : llvm::Intrinsic::nvvm_tcgen05_dealloc_cg2;
4264
4265 // Fill the Intrinsic Args
4266 args.push_back(mt.lookupValue(curOp.getTaddr()));
4267 args.push_back(mt.lookupValue(curOp.getNCols()));
4268
4269 return id;
4270}
4271
4272#define TCGEN05_COMMIT_IMPL(cg, is_shared, mc) \
4273 is_shared ? llvm::Intrinsic::nvvm_tcgen05_commit##mc##_shared##_##cg \
4274 : llvm::Intrinsic::nvvm_tcgen05_commit##mc##_##cg
4275
4276#define GET_TCGEN05_COMMIT_ID(cta_group, is_shared, has_mc) \
4277 has_mc ? TCGEN05_COMMIT_IMPL(cta_group, is_shared, _mc) \
4278 : TCGEN05_COMMIT_IMPL(cta_group, is_shared, )
4279
4280llvm::Intrinsic::ID
4281Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
4282 LLVM::ModuleTranslation &mt,
4283 llvm::SmallVector<llvm::Value *> &args) {
4284 auto curOp = cast<NVVM::Tcgen05CommitOp>(op);
4285 unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
4286 .getAddressSpace();
4287 bool isShared = as == NVVMMemorySpace::Shared;
4288 bool hasMulticast = static_cast<bool>(curOp.getMulticastMask());
4289 bool is2CTAMode = curOp.getGroup() == CTAGroupKind::CTA_2;
4290
4291 llvm::Intrinsic::ID id =
4292 is2CTAMode ? GET_TCGEN05_COMMIT_ID(cg2, isShared, hasMulticast)
4293 : GET_TCGEN05_COMMIT_ID(cg1, isShared, hasMulticast);
4294
4295 // Fill the Intrinsic Args
4296 args.push_back(mt.lookupValue(curOp.getAddr()));
4297 if (hasMulticast)
4298 args.push_back(mt.lookupValue(curOp.getMulticastMask()));
4299
4300 return id;
4301}
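// Manual expansion for reference: with cta_group = CTA_2, a shared-memory
// address and a multicast mask, GET_TCGEN05_COMMIT_ID(cg2, isShared,
// hasMulticast) selects llvm::Intrinsic::nvvm_tcgen05_commit_mc_shared_cg2;
// dropping the multicast mask selects nvvm_tcgen05_commit_shared_cg2.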
4302
4303#define TCGEN05_CP_IMPL(shape_mc, src_fmt, cg) \
4304 llvm::Intrinsic::nvvm_tcgen05_cp##shape_mc##src_fmt##cg
4305
4306#define TCGEN05_CP_2CTA(shape_mc, src_fmt, is_2cta) \
4307 is_2cta ? TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg2) \
4308 : TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg1)
4309
4310#define GET_TCGEN05_CP_ID(shape_mc, src_fmt, is_2cta) \
4311 [&]() -> auto { \
4312 if ((src_fmt) == Tcgen05CpSrcFormat::B6x16_P32) \
4313 return TCGEN05_CP_2CTA(shape_mc, _b6x16_p32, is_2cta); \
4314 if ((src_fmt) == Tcgen05CpSrcFormat::B4x16_P64) \
4315 return TCGEN05_CP_2CTA(shape_mc, _b4x16_p64, is_2cta); \
4316 return TCGEN05_CP_2CTA(shape_mc, , is_2cta); \
4317 }()
4318
4319 NVVM::IDArgPair
4320 ConvertF32x2ToF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF16x2Op &op,
4321 LLVM::ModuleTranslation &mt,
4322 llvm::IRBuilderBase &builder) {
4323 static constexpr llvm::Intrinsic::ID rndRNIds[] = {
4324 llvm::Intrinsic::nvvm_ff2f16x2_rn,
4325 llvm::Intrinsic::nvvm_ff2f16x2_rn_relu,
4326 llvm::Intrinsic::nvvm_ff2f16x2_rn_satfinite,
4327 llvm::Intrinsic::nvvm_ff2f16x2_rn_relu_satfinite,
4328 };
4329 static constexpr llvm::Intrinsic::ID rndRZIds[] = {
4330 llvm::Intrinsic::nvvm_ff2f16x2_rz,
4331 llvm::Intrinsic::nvvm_ff2f16x2_rz_relu,
4332 llvm::Intrinsic::nvvm_ff2f16x2_rz_satfinite,
4333 llvm::Intrinsic::nvvm_ff2f16x2_rz_relu_satfinite,
4334 };
4335 static constexpr llvm::Intrinsic::ID rndRSIds[] = {
4336 llvm::Intrinsic::nvvm_ff2f16x2_rs,
4337 llvm::Intrinsic::nvvm_ff2f16x2_rs_relu,
4338 llvm::Intrinsic::nvvm_ff2f16x2_rs_satfinite,
4339 llvm::Intrinsic::nvvm_ff2f16x2_rs_relu_satfinite,
4340 };
4341
4342 unsigned hasRelu = op.getRelu() ? 1 : 0;
4343 unsigned hasSatFinite =
4344 (op.getSat() == NVVM::SaturationMode::SATFINITE) ? 1 : 0;
4345 // idx: bit-0 - relu
4346 // bit-1 - satfinite
4347 unsigned idx = (hasSatFinite << 1) | hasRelu;
4348
4350 args.push_back(mt.lookupValue(op.getSrcHi()));
4351 args.push_back(mt.lookupValue(op.getSrcLo()));
4352 if (op.getRandomBits())
4353 args.push_back(mt.lookupValue(op.getRandomBits()));
4354
4355 switch (op.getRnd()) {
4356 case FPRoundingMode::RN:
4357 return {rndRNIds[idx], std::move(args)};
4358 case FPRoundingMode::RZ:
4359 return {rndRZIds[idx], std::move(args)};
4360 case FPRoundingMode::RS:
4361 return {rndRSIds[idx], std::move(args)};
4362 default:
4363 llvm_unreachable("Invalid rounding mode for ConvertF32x2ToF16x2Op");
4364 }
4365}
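// Example of the idx packing above (illustration only): relu together with
// satfinite gives idx == 3, so RN rounding selects
// llvm::Intrinsic::nvvm_ff2f16x2_rn_relu_satfinite; RS rounding additionally
// consumes the random-bits operand pushed above.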
4366
4367 NVVM::IDArgPair
4368 ConvertF32x2ToBF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToBF16x2Op &op,
4369 LLVM::ModuleTranslation &mt,
4370 llvm::IRBuilderBase &builder) {
4371 static constexpr llvm::Intrinsic::ID rndRNIds[] = {
4372 llvm::Intrinsic::nvvm_ff2bf16x2_rn,
4373 llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu,
4374 llvm::Intrinsic::nvvm_ff2bf16x2_rn_satfinite,
4375 llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu_satfinite,
4376 };
4377 static constexpr llvm::Intrinsic::ID rndRZIds[] = {
4378 llvm::Intrinsic::nvvm_ff2bf16x2_rz,
4379 llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu,
4380 llvm::Intrinsic::nvvm_ff2bf16x2_rz_satfinite,
4381 llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu_satfinite,
4382 };
4383 static constexpr llvm::Intrinsic::ID rndRSIds[] = {
4384 llvm::Intrinsic::nvvm_ff2bf16x2_rs,
4385 llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu,
4386 llvm::Intrinsic::nvvm_ff2bf16x2_rs_satfinite,
4387 llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu_satfinite,
4388 };
4389
4390 unsigned hasRelu = op.getRelu() ? 1 : 0;
4391 unsigned hasSatFinite =
4392 (op.getSat() == NVVM::SaturationMode::SATFINITE) ? 1 : 0;
4393 // idx: bit-0 - relu
4394 // bit-1 - satfinite
4395 unsigned idx = (hasSatFinite << 1) | hasRelu;
4396
4398 args.push_back(mt.lookupValue(op.getSrcHi()));
4399 args.push_back(mt.lookupValue(op.getSrcLo()));
4400 if (op.getRandomBits())
4401 args.push_back(mt.lookupValue(op.getRandomBits()));
4402
4403 switch (op.getRnd()) {
4404 case FPRoundingMode::RN:
4405 return {rndRNIds[idx], std::move(args)};
4406 case FPRoundingMode::RZ:
4407 return {rndRZIds[idx], std::move(args)};
4408 case FPRoundingMode::RS:
4409 return {rndRSIds[idx], std::move(args)};
4410 default:
4411 llvm_unreachable("Invalid rounding mode for ConvertF32x2ToBF16x2Op");
4412 }
4413}
4414
4415llvm::Intrinsic::ID ConvertF32x4ToF8x4Op::getIntrinsicID() {
4416 mlir::Type dstTy = getDstTy();
4417 bool hasRelu = getRelu();
4418
4419 return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
4420 .Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
4421 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite
4422 : llvm::Intrinsic::nvvm_f32x4_to_e4m3x4_rs_satfinite;
4423 })
4424 .Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
4425 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e5m2x4_rs_relu_satfinite
4426 : llvm::Intrinsic::nvvm_f32x4_to_e5m2x4_rs_satfinite;
4427 })
4428 .Default([](mlir::Type) {
4429 llvm_unreachable("Invalid F8 type in ConvertF32x4ToF8x4Op");
4430 return llvm::Intrinsic::not_intrinsic;
4431 });
4432}
4433
4434llvm::Intrinsic::ID ConvertF32x4ToF6x4Op::getIntrinsicID() {
4435 mlir::Type dstTy = getDstTy();
4436 bool hasRelu = getRelu();
4437
4438 return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
4439 .Case<mlir::Float6E2M3FNType>([&](mlir::Float6E2M3FNType) {
4440 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e2m3x4_rs_relu_satfinite
4441 : llvm::Intrinsic::nvvm_f32x4_to_e2m3x4_rs_satfinite;
4442 })
4443 .Case<mlir::Float6E3M2FNType>([&](mlir::Float6E3M2FNType) {
4444 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e3m2x4_rs_relu_satfinite
4445 : llvm::Intrinsic::nvvm_f32x4_to_e3m2x4_rs_satfinite;
4446 })
4447 .Default([](mlir::Type) {
4448 llvm_unreachable("Invalid F6 type in ConvertF32x4ToF6x4Op");
4449 return llvm::Intrinsic::not_intrinsic;
4450 });
4451}
4452
4453llvm::Intrinsic::ID ConvertF32x4ToF4x4Op::getIntrinsicID() {
4454 mlir::Type dstTy = getDstTy();
4455 bool hasRelu = getRelu();
4456
4457 return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
4458 .Case<mlir::Float4E2M1FNType>([&](mlir::Float4E2M1FNType) {
4459 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e2m1x4_rs_relu_satfinite
4460 : llvm::Intrinsic::nvvm_f32x4_to_e2m1x4_rs_satfinite;
4461 })
4462 .Default([](mlir::Type) {
4463 llvm_unreachable("Invalid F4 type in ConvertF32x4ToF4x4Op");
4464 return llvm::Intrinsic::not_intrinsic;
4465 });
4466}
4467
4468llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(Operation &op) {
4469 auto curOp = cast<NVVM::Tcgen05CpOp>(op);
4470 bool is2CTA = curOp.getGroup() == CTAGroupKind::CTA_2;
4471 auto srcFmt = curOp.getSrcFormat();
4472 auto mc = curOp.getMulticast();
4473
4474 switch (curOp.getShape()) {
4475 case Tcgen05CpShape::SHAPE_128x256b:
4476 return GET_TCGEN05_CP_ID(_128x256b, srcFmt, is2CTA);
4477 case Tcgen05CpShape::SHAPE_128x128b:
4478 return GET_TCGEN05_CP_ID(_128x128b, srcFmt, is2CTA);
4479 case Tcgen05CpShape::SHAPE_4x256b:
4480 return GET_TCGEN05_CP_ID(_4x256b, srcFmt, is2CTA);
4481 case Tcgen05CpShape::SHAPE_32x128b:
4482 return GET_TCGEN05_CP_ID(_32x128b_warpx4, srcFmt, is2CTA);
4483 case Tcgen05CpShape::SHAPE_64x128b:
4484 return (mc == Tcgen05CpMulticast::WARPX2_01_23)
4485 ? GET_TCGEN05_CP_ID(_64x128b_warpx2_01_23, srcFmt, is2CTA)
4486 : GET_TCGEN05_CP_ID(_64x128b_warpx2_02_13, srcFmt, is2CTA);
4487 }
4488 llvm_unreachable("Invalid shape in tcgen05 cp Op");
4489}
4490
4491 // Checks whether the given vector length is valid for the given shape. The
4492 // function models the table mentioned in the tcgen05.{ld, st} Op description.
4493 static bool isValidVectorLength(NVVM::Tcgen05LdStShape shape,
4494 unsigned vecLen) {
4495 if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X128B)
4496 return vecLen >= 2;
4497 if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X256B)
4498 return vecLen >= 4;
4499 return true;
4500}
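// Illustrative reading of the table above: shape 16x128b requires a vector of
// at least 2 elements and 16x256b at least 4, so e.g. a tcgen05.ld with shape
// 16x128b and a scalar (length-1) result is rejected, while a 2-element vector
// result is accepted.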
4501
4502LogicalResult Tcgen05LdOp::verify() {
4503 LogicalResult result = success();
4504 if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
4505 result = emitError("shape 16x32bx2 requires offset argument");
4506
4507 if (getShape() != NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && getOffset())
4508 result = emitError("offset argument is only supported for shape 16x32bx2");
4509
4510 auto resTy = getRes().getType();
4511 unsigned resLen = isa<VectorType>(resTy)
4512 ? llvm::cast<VectorType>(resTy).getNumElements()
4513 : 1;
4514 if (!isValidVectorLength(getShape(), resLen))
4515 result = emitError(llvm::formatv("invalid result type length {0} for shape "
4516 "{1} in tcgen05.ld Op",
4517 resLen, stringifyEnum(getShape())));
4518
4519 return result;
4520}
4521
4522LogicalResult Tcgen05StOp::verify() {
4523 LogicalResult result = success();
4524 if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
4525 result = emitError("shape 16x32bx2 requires offset argument");
4526
4527 auto valTy = getVal().getType();
4528 unsigned valLen = isa<VectorType>(valTy)
4529 ? llvm::cast<VectorType>(valTy).getNumElements()
4530 : 1;
4531 if (!isValidVectorLength(getShape(), valLen))
4532 result = emitError(llvm::formatv("invalid input length {0} for shape "
4533 "{1} in tcgen05.st Op",
4534 valLen, stringifyEnum(getShape())));
4535
4536 return result;
4537}
4538
4539/// Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might
4540 /// have ConstantRangeAttr.
4541 static void nvvmInferResultRanges(Operation *op, Value result,
4542                                   ArrayRef<::mlir::ConstantIntRanges> argRanges,
4543 SetIntRangeFn setResultRanges) {
4544 if (auto rangeAttr = op->getAttrOfType<LLVM::ConstantRangeAttr>("range")) {
4545 setResultRanges(result, {rangeAttr.getLower(), rangeAttr.getUpper(),
4546 rangeAttr.getLower(), rangeAttr.getUpper()});
4547 } else {
4548 setResultRanges(result, IntegerValueRange::getMaxRange(result).getValue());
4549 }
4550}
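// Illustrative example: when a special-register op carries a "range" attribute
// (e.g. a thread-id register known to stay below the block size), both the
// unsigned and signed result ranges are seeded from the attribute's lower and
// upper bounds; otherwise the maximal range of the result type is used.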
4551
4552/// Verify the range attribute satisfies LLVM ConstantRange constructor
4553/// requirements for NVVM SpecialRangeableRegisterOp.
4554 static LogicalResult
4555 verifyConstantRangeAttr(Operation *op,
4556 std::optional<LLVM::ConstantRangeAttr> rangeAttr) {
4557 if (!rangeAttr)
4558 return success();
4559
4560 const llvm::APInt &lower = rangeAttr->getLower();
4561 const llvm::APInt &upper = rangeAttr->getUpper();
4562
4563 // Check LLVM ConstantRange constructor condition
4564 if (lower == upper && !lower.isMaxValue() && !lower.isMinValue()) {
4565 unsigned bitWidth = lower.getBitWidth();
4566 llvm::APInt minVal = llvm::APInt::getMinValue(bitWidth);
4567 llvm::APInt maxVal = llvm::APInt::getMaxValue(bitWidth);
4568 return op->emitOpError(
4569 "invalid range attribute: Lower == Upper, but they aren't min (")
4570 << llvm::toString(minVal, 10, false) << ") or max ("
4571 << llvm::toString(maxVal, 10, false)
4572 << ") value! This is an invalid constant range.";
4573 }
4574
4575 return success();
4576}
4577
4578static llvm::Value *getAsPackedI32(llvm::Value *arg,
4579 llvm::IRBuilderBase &builder) {
4580 return builder.CreateBitCast(arg,
4581 llvm::Type::getInt32Ty(builder.getContext()));
4582}
4583
4584NVVM::IDArgPair DotAccumulate4WayOp::getIntrinsicIDAndArgs(
4585 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4586 auto curOp = cast<NVVM::DotAccumulate4WayOp>(op);
4587
4588   llvm::SmallVector<llvm::Value *> args;
4589 args.push_back(getAsPackedI32(mt.lookupValue(curOp.getA()), builder));
4590 args.push_back(getAsPackedI32(mt.lookupValue(curOp.getB()), builder));
4591 args.push_back(mt.lookupValue(curOp.getC()));
4592
4593 bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
4594 bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
4595 unsigned type = (isASigned << 1) | isBSigned;
4596 const llvm::Intrinsic::ID ids[] = {
4597 llvm::Intrinsic::nvvm_idp4a_u_u,
4598 llvm::Intrinsic::nvvm_idp4a_u_s,
4599 llvm::Intrinsic::nvvm_idp4a_s_u,
4600 llvm::Intrinsic::nvvm_idp4a_s_s,
4601 };
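  // Worked example (illustrative): a signed A operand with an unsigned B
  // operand gives type = (1 << 1) | 0 = 2, selecting ids[2] = nvvm_idp4a_s_u.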
4602 return {ids[type], args};
4603}
4604
4605NVVM::IDArgPair DotAccumulate2WayOp::getIntrinsicIDAndArgs(
4606 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4607 auto curOp = cast<NVVM::DotAccumulate2WayOp>(op);
4608
4609   llvm::SmallVector<llvm::Value *> args;
4610 args.push_back(getAsPackedI32(mt.lookupValue(curOp.getA()), builder));
4611 args.push_back(getAsPackedI32(mt.lookupValue(curOp.getB()), builder));
4612 args.push_back(builder.getInt1(curOp.getBHi()));
4613 args.push_back(mt.lookupValue(curOp.getC()));
4614
4615 bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
4616 bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
4617 unsigned type = (isASigned << 1) | isBSigned;
4618 const llvm::Intrinsic::ID ids[] = {
4619 llvm::Intrinsic::nvvm_idp2a_u_u,
4620 llvm::Intrinsic::nvvm_idp2a_u_s,
4621 llvm::Intrinsic::nvvm_idp2a_s_u,
4622 llvm::Intrinsic::nvvm_idp2a_s_s,
4623 };
4624 return {ids[type], args};
4625}
4626
4627static llvm::Value *getParamCastedAddr(llvm::Value *addr,
4628 llvm::IRBuilderBase &builder) {
4629 return builder.CreateAddrSpaceCast(
4630 addr,
4631 llvm::PointerType::get(builder.getContext(),
4632 llvm::NVPTXAS::AddressSpace::ADDRESS_SPACE_PARAM));
4633}
4634
4635 NVVM::IDArgPair
4636 PrefetchOp::getIntrinsicIDAndArgs(NVVM::PrefetchOp &op,
4637                                   LLVM::ModuleTranslation &mt,
4638 llvm::IRBuilderBase &builder) {
4639 using MemSpace = NVVM::NVVMMemorySpace;
4640 using CacheLevel = NVVM::PrefetchCacheLevel;
4641
4642 std::optional<NVVM::PrefetchCacheLevel> cacheLevel = op.getCacheLevel();
4643 std::optional<NVVM::CacheEvictionPriority> evictPriority =
4644 op.getEvictPriority();
4645 unsigned addressSpace =
4646 llvm::cast<LLVM::LLVMPointerType>(op.getAddr().getType())
4647 .getAddressSpace();
4648
4649   llvm::SmallVector<llvm::Value *> args;
4650 llvm::Value *addr = mt.lookupValue(op.getAddr());
4651 args.push_back(op.getInParamSpace() ? getParamCastedAddr(addr, builder)
4652 : addr);
4653
4654 if (op.getTensormap())
4655 return {llvm::Intrinsic::nvvm_prefetch_tensormap, args};
4656
4657 assert(cacheLevel && "expected cache level for non-tensormap prefetch");
4658
4659 if (op.getUniform() && *cacheLevel == CacheLevel::L1)
4660 return {llvm::Intrinsic::nvvm_prefetchu_L1, args};
4661
4662 if (evictPriority && *cacheLevel == CacheLevel::L2) {
4663 switch (*evictPriority) {
4664 case NVVM::CacheEvictionPriority::EvictLast:
4665 return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_last, args};
4666 case NVVM::CacheEvictionPriority::EvictNormal:
4667 return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_normal, args};
4668 default:
4669 llvm_unreachable("Invalid cache eviction priority");
4670 }
4671 }
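  // Illustrative example: an L2 prefetch of a global-space pointer with
  // evict_last priority maps to nvvm_prefetch_global_L2_evict_last above,
  // while the same prefetch without an eviction priority falls through to the
  // address-space switch below and maps to nvvm_prefetch_global_L2.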
4672
4673 switch (static_cast<MemSpace>(addressSpace)) {
4674 case MemSpace::Generic:
4675 return *cacheLevel == CacheLevel::L1
4676 ? NVVM::IDArgPair({llvm::Intrinsic::nvvm_prefetch_L1, args})
4677 : NVVM::IDArgPair({llvm::Intrinsic::nvvm_prefetch_L2, args});
4678 case MemSpace::Global:
4679     return *cacheLevel == CacheLevel::L1
4680                ? NVVM::IDArgPair(
4681 {llvm::Intrinsic::nvvm_prefetch_global_L1, args})
4682 : NVVM::IDArgPair(
4683 {llvm::Intrinsic::nvvm_prefetch_global_L2, args});
4684 case MemSpace::Local:
4685     return *cacheLevel == CacheLevel::L1
4686                ? NVVM::IDArgPair(
4687 {llvm::Intrinsic::nvvm_prefetch_local_L1, args})
4688 : NVVM::IDArgPair(
4689 {llvm::Intrinsic::nvvm_prefetch_local_L2, args});
4690 default:
4691 llvm_unreachable("Invalid pointer address space");
4692 }
4693}
4694
4695bool NVVM::InlinePtxOp::getAsmValues(
4696 RewriterBase &rewriter,
4697 llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
4698 &asmValues) {
4699 for (auto arg : getReadWriteArgs())
4700 asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::ReadWrite});
4701 for (auto arg : getResults())
4702 asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::Write});
4703 for (auto arg : getReadOnlyArgs())
4704 asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::Read});
4705 if (getPredicate())
4706 asmValues.push_back({getPredicate(), mlir::NVVM::PTXRegisterMod::Read});
4707 return false; // No manual mapping needed
4708}
4709
4710NVVM::IDArgPair ClusterLaunchControlTryCancelOp::getIntrinsicIDAndArgs(
4711 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4712   auto curOp = cast<NVVM::ClusterLaunchControlTryCancelOp>(op);
4713   llvm::SmallVector<llvm::Value *> args;
4714 args.push_back(mt.lookupValue(curOp.getSmemAddress()));
4715 args.push_back(mt.lookupValue(curOp.getMbarrier()));
4716
4717 llvm::Intrinsic::ID intrinsicID =
4718 curOp.getMulticast()
4719 ? llvm::Intrinsic::
4720 nvvm_clusterlaunchcontrol_try_cancel_async_multicast_shared
4721 : llvm::Intrinsic::nvvm_clusterlaunchcontrol_try_cancel_async_shared;
4722
4723 return {intrinsicID, args};
4724}
4725
4726NVVM::IDArgPair ClusterLaunchControlQueryCancelOp::getIntrinsicIDAndArgs(
4727 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4728   auto curOp = cast<NVVM::ClusterLaunchControlQueryCancelOp>(op);
4729   llvm::SmallVector<llvm::Value *> args;
4730 args.push_back(mt.lookupValue(curOp.getTryCancelResponse()));
4731
4732 llvm::Intrinsic::ID intrinsicID;
4733
4734 switch (curOp.getQueryType()) {
4735 case NVVM::ClusterLaunchControlQueryType::IS_CANCELED:
4736 intrinsicID =
4737 llvm::Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_is_canceled;
4738 break;
4739 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_X:
4740 intrinsicID = llvm::Intrinsic::
4741 nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_x;
4742 break;
4743 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Y:
4744 intrinsicID = llvm::Intrinsic::
4745 nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_y;
4746 break;
4747 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Z:
4748 intrinsicID = llvm::Intrinsic::
4749 nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_z;
4750 break;
4751 }
4752 return {intrinsicID, args};
4753}
4754
4755 NVVM::IDArgPair
4756PermuteOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
4757 llvm::IRBuilderBase &builder) {
4758 auto thisOp = cast<NVVM::PermuteOp>(op);
4759 NVVM::PermuteMode mode = thisOp.getMode();
4760
4761 static constexpr llvm::Intrinsic::ID IDs[] = {
4762 llvm::Intrinsic::nvvm_prmt, llvm::Intrinsic::nvvm_prmt_f4e,
4763 llvm::Intrinsic::nvvm_prmt_b4e, llvm::Intrinsic::nvvm_prmt_rc8,
4764 llvm::Intrinsic::nvvm_prmt_ecl, llvm::Intrinsic::nvvm_prmt_ecr,
4765 llvm::Intrinsic::nvvm_prmt_rc16};
4766
4767   unsigned modeIndex = static_cast<unsigned>(mode);
4768   llvm::SmallVector<llvm::Value *> args;
4769 args.push_back(mt.lookupValue(thisOp.getLo()));
4770
4771 // Only first 3 modes (Default, f4e, b4e) need the hi operand.
4772 if (modeIndex < 3)
4773 args.push_back(mt.lookupValue(thisOp.getHi()));
4774
4775 args.push_back(mt.lookupValue(thisOp.getSelector()));
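  // Illustrative example (assuming the PermuteMode enumerators follow the same
  // order as the IDs table above): mode b4e (index 2) lowers to nvvm_prmt_b4e
  // with {lo, hi, selector}, while rc8 (index 3) lowers to nvvm_prmt_rc8 with
  // only {lo, selector}.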
4776
4777 return {IDs[modeIndex], args};
4778}
4779
4780//===----------------------------------------------------------------------===//
4781// NVVM tcgen05.mma functions
4782//===----------------------------------------------------------------------===//
4783
4784 NVVM::IDArgPair
4785Tcgen05MMAOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
4786 llvm::IRBuilderBase &builder) {
4787
4788   auto thisOp = cast<NVVM::Tcgen05MMAOp>(op);
4789   llvm::SmallVector<llvm::Value *> args;
4790
4791 args.push_back(mt.lookupValue(thisOp.getMatrixD()));
4792
4793 llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
4794 const bool isATensor = isa<llvm::PointerType>(A->getType());
4795 args.push_back(A);
4796
4797 args.push_back(mt.lookupValue(thisOp.getMatrixB()));
4798 args.push_back(mt.lookupValue(thisOp.getIdesc()));
4799 args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
4800
4801 using EnableAShiftArray = std::array<llvm::Intrinsic::ID, 2>;
4802 using CtaGroupArray = std::array<EnableAShiftArray, 2>;
4803 using IsATensorArray = std::array<CtaGroupArray, 2>;
4804 using HasScaleInputDArray = std::array<IsATensorArray, 2>;
4805 using HasDisableOutputLaneArray = std::array<HasScaleInputDArray, 2>;
4806
4807 // [hasDisableOutputLane][hasScaleInputD][isATensor][CtaGroup][EnableAShift]
4808 static constexpr HasDisableOutputLaneArray tcgen05MMAIDs = {
4809     { // without disable output lane
4810 {{// without scale input D
4811 {{
4812 // shared
4813 {{// cg1
4814 {llvm::Intrinsic::nvvm_tcgen05_mma_shared, notIntrinsic},
4815 // cg2
4816 {llvm::Intrinsic::nvvm_tcgen05_mma_shared, notIntrinsic}}},
4817 {{// tensor
4818 {
4819 // cg1
4820 llvm::Intrinsic::nvvm_tcgen05_mma_tensor,
4821 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_ashift,
4822 },
4823 {
4824 // cg2
4825 llvm::Intrinsic::nvvm_tcgen05_mma_tensor,
4826 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_ashift,
4827 }}},
4828 }},
4829 // with scale input D
4830 {{ // shared
4831 {{// cg1
4832 {llvm::Intrinsic::nvvm_tcgen05_mma_shared_scale_d, notIntrinsic},
4833 // cg2
4834 {llvm::Intrinsic::nvvm_tcgen05_mma_shared_scale_d, notIntrinsic}}},
4835 {{// tensor
4836 {
4837 // cg1
4838 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d,
4839 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_ashift,
4840 },
4841 {
4842 // cg2
4843 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d,
4844 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_ashift,
4845 }}}}}}},
4846 // with disable output lane
4847 {{ // without scale input D
4848 {{ // shared
4849 {{// cg1
4850 {llvm::Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg1,
4851 notIntrinsic},
4852 // cg2
4853 {llvm::Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg2,
4854 notIntrinsic}}},
4855 {{// cg1
4856 {
4857 llvm::Intrinsic::
4858 nvvm_tcgen05_mma_tensor_disable_output_lane_cg1,
4859 llvm::Intrinsic::
4860 nvvm_tcgen05_mma_tensor_disable_output_lane_cg1_ashift,
4861 },
4862 // cg2
4863 {
4864 llvm::Intrinsic::
4865 nvvm_tcgen05_mma_tensor_disable_output_lane_cg2,
4866 llvm::Intrinsic::
4867 nvvm_tcgen05_mma_tensor_disable_output_lane_cg2_ashift,
4868 }}}}},
4869 // with scale input D
4870 {{ // shared
4871 {{// cg1
4872 {llvm::Intrinsic::
4873 nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg1,
4874 notIntrinsic},
4875 // cg2
4876 {llvm::Intrinsic::
4877 nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg2,
4878 notIntrinsic}}},
4879 // tensor
4880 {{// cg1
4881 {llvm::Intrinsic::
4882 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1,
4883 llvm::Intrinsic::
4884 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1_ashift},
4885 // cg2
4886 {
4887 llvm::Intrinsic::
4888 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2,
4889 llvm::Intrinsic::
4890 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2_ashift,
4891 }}}}}}}}};
4892
4893 llvm::Value *ScaleInputD = mt.lookupValue(thisOp.getScaleInputD());
4894 bool hasScaleInputD = ScaleInputD != nullptr;
4895
4896 llvm::Value *DisableOutputLane =
4897 mt.lookupValue(thisOp.getDisableOutputLane());
4898 bool hasDisableOutputLane = DisableOutputLane != nullptr;
4899
4900 const unsigned ctaGroup =
4901 static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup()));
4902
4903 llvm::Intrinsic::ID ID =
4904 tcgen05MMAIDs[hasDisableOutputLane][hasScaleInputD][isATensor]
4905 [ctaGroup - 1][thisOp.getAShift()];
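  // Worked example (illustrative): with no disable-output-lane operand, no
  // scale-input-D operand, matrix A in tensor memory, cta_group::1 and ashift
  // set, the lookup above yields nvvm_tcgen05_mma_tensor_ashift.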
4906
4907 assert(ID != notIntrinsic && "Invalid intrinsic for Tcgen05MMAOp.");
4908
4909 if (hasScaleInputD)
4910 args.push_back(ScaleInputD);
4911
4912 if (hasDisableOutputLane)
4913 args.push_back(DisableOutputLane);
4914
4915 args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
4916
4917 if (!hasDisableOutputLane)
4918 args.push_back(builder.getInt32(ctaGroup));
4919
4920 args.push_back(
4921 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
4922
4923 return {ID, args};
4924}
4925
4926static LogicalResult
4927verifyTcgen05MMAOp(bool isATensor, mlir::Value disableOutputLane,
4928 NVVM::CTAGroupKind ctaGroup, bool hasAShift,
4929 NVVM::Tcgen05MMACollectorOp collectorOp, Location loc) {
4930
4931 if (disableOutputLane) {
4932 mlir::VectorType disableOutputLaneType =
4933 cast<mlir::VectorType>(disableOutputLane.getType());
4934 if ((ctaGroup == NVVM::CTAGroupKind::CTA_1 &&
4935 disableOutputLaneType.getNumElements() != 4) ||
4936 (ctaGroup == NVVM::CTAGroupKind::CTA_2 &&
4937 disableOutputLaneType.getNumElements() != 8))
4938 return emitError(loc) << "Disable Output Lane of length "
4939 << disableOutputLaneType.getNumElements()
4940 << " is incompatible with CtaGroupAttr";
4941 }
4942
4943 if (hasAShift && !isATensor)
4944 return emitError(
4945 loc, "A-shift can be applied only when matrix A is in tensor memory");
4946
4947   if (hasAShift && (collectorOp == Tcgen05MMACollectorOp::FILL ||
4948 collectorOp == Tcgen05MMACollectorOp::USE))
4949 return emitError(
4950 loc, "Cannot use collector buffer operation fill or use with ashift");
4951
4952 return success();
4953}
4954
4955LogicalResult Tcgen05MMAOp::verify() {
4956 return verifyTcgen05MMAOp(isa<LLVM::LLVMPointerType>(getMatrixA().getType()),
4957 getDisableOutputLane(), getCtaGroup(), getAShift(),
4958 getCollectorOp(), getLoc());
4959}
4960
4961//===----------------------------------------------------------------------===//
4962// NVVM tcgen05.mma.sp functions
4963//===----------------------------------------------------------------------===//
4964
4965mlir::NVVM::IDArgPair Tcgen05MMASparseOp::getIntrinsicIDAndArgs(
4966 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4967
4968   auto thisOp = cast<NVVM::Tcgen05MMASparseOp>(op);
4969   llvm::SmallVector<llvm::Value *> args;
4970
4971 args.push_back(mt.lookupValue(thisOp.getMatrixD()));
4972
4973 llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
4974 bool isATensor = isa<llvm::PointerType>(A->getType());
4975 args.push_back(A);
4976
4977 args.push_back(mt.lookupValue(thisOp.getMatrixB()));
4978 args.push_back(mt.lookupValue(thisOp.getIdesc()));
4979 args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
4980 args.push_back(mt.lookupValue(thisOp.getSparseMetadata()));
4981
4982 using EnableAShiftArray = std::array<llvm::Intrinsic::ID, 2>;
4983 using CtaGroupArray = std::array<EnableAShiftArray, 2>;
4984 using IsATensorArray = std::array<CtaGroupArray, 2>;
4985 using HasScaleInputDArray = std::array<IsATensorArray, 2>;
4986 using HasDisableOutputLaneArray = std::array<HasScaleInputDArray, 2>;
4987
4988 // [hasDisableOutputLane][hasScaleInputD][isATensor][CtaGroup][EnableAShift]
4989 static constexpr HasDisableOutputLaneArray tcgen05MMASparseIDs = {
4990     { // without disable output lane
4991 {{// without scale input D
4992 {{
4993 // shared
4994 {{// cg1
4995 {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared, notIntrinsic},
4996 // cg2
4997 {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared, notIntrinsic}}},
4998 {{// tensor
4999 {
5000 // cg1
5001 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor,
5002 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_ashift,
5003 },
5004 {
5005 // cg2
5006 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor,
5007 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_ashift,
5008 }}},
5009 }},
5010 // with scale input D
5011 {{ // shared
5012 {{// cg1
5013 {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d,
5014 notIntrinsic},
5015 // cg2
5016 {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d,
5017 notIntrinsic}}},
5018 {{// tensor
5019 {
5020 // cg1
5021 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d,
5022 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_ashift,
5023 },
5024 {
5025 // cg2
5026 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d,
5027 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_ashift,
5028 }}}}}}},
5029 // with disable output lane
5030 {{ // without scale input D
5031 {{ // shared
5032 {{// cg1
5033 {llvm::Intrinsic::
5034 nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg1,
5035 notIntrinsic},
5036 // cg2
5037 {llvm::Intrinsic::
5038 nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg2,
5039 notIntrinsic}}},
5040 {{// cg1
5041 {
5042 llvm::Intrinsic::
5043 nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1,
5044 llvm::Intrinsic::
5045 nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1_ashift,
5046 },
5047 // cg2
5048 {
5049 llvm::Intrinsic::
5050 nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2,
5051 llvm::Intrinsic::
5052 nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2_ashift,
5053 }}}}},
5054 // with scale input D
5055 {{ // shared
5056 {{// cg1
5057 {llvm::Intrinsic::
5058 nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg1,
5059 notIntrinsic},
5060 // cg2
5061 {llvm::Intrinsic::
5062 nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg2,
5063 notIntrinsic}}},
5064 // tensor
5065 {{// cg1
5066 {llvm::Intrinsic::
5067 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1,
5068 llvm::Intrinsic::
5069 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1_ashift},
5070 // cg2
5071 {
5072 llvm::Intrinsic::
5073 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2,
5074 llvm::Intrinsic::
5075 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2_ashift,
5076 }}}}}}}}};
5077
5078 llvm::Value *ScaleInputD = mt.lookupValue(thisOp.getScaleInputD());
5079 bool hasScaleInputD = ScaleInputD != nullptr;
5080
5081 llvm::Value *DisableOutputLane =
5082 mt.lookupValue(thisOp.getDisableOutputLane());
5083 bool hasDisableOutputLane = DisableOutputLane != nullptr;
5084
5085 unsigned ctaGroup =
5086 static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup()));
5087
5088 llvm::Intrinsic::ID ID =
5089 tcgen05MMASparseIDs[hasDisableOutputLane][hasScaleInputD][isATensor]
5090 [ctaGroup - 1][thisOp.getAShift()];
5091
5092 assert(ID != notIntrinsic && "Invalid intrinsic for Tcgen05MMASparseOp.");
5093
5094 if (hasScaleInputD)
5095 args.push_back(ScaleInputD);
5096
5097 if (hasDisableOutputLane)
5098 args.push_back(DisableOutputLane);
5099
5100 args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
5101
5102 if (!hasDisableOutputLane)
5103 args.push_back(builder.getInt32(ctaGroup));
5104
5105 args.push_back(
5106 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
5107
5108 return {ID, args};
5109}
5110
5111LogicalResult Tcgen05MMASparseOp::verify() {
5112 return verifyTcgen05MMAOp(isa<LLVM::LLVMPointerType>(getMatrixA().getType()),
5113 getDisableOutputLane(), getCtaGroup(), getAShift(),
5114 getCollectorOp(), getLoc());
5115}
5116
5117//===----------------------------------------------------------------------===//
5118// NVVM tcgen05.mma.block_scale functions
5119//===----------------------------------------------------------------------===//
5120
5121mlir::NVVM::IDArgPair Tcgen05MMABlockScaleOp::getIntrinsicIDAndArgs(
5122 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
5123
5124   auto thisOp = cast<NVVM::Tcgen05MMABlockScaleOp>(op);
5125   llvm::SmallVector<llvm::Value *> args;
5126
5127 args.push_back(mt.lookupValue(thisOp.getMatrixD()));
5128
5129 llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
5130 bool isATensor = isa<llvm::PointerType>(A->getType());
5131 args.push_back(A);
5132
5133 args.push_back(mt.lookupValue(thisOp.getMatrixB()));
5134 args.push_back(mt.lookupValue(thisOp.getIdesc()));
5135 args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
5136 args.push_back(mt.lookupValue(thisOp.getScaleA()));
5137 args.push_back(mt.lookupValue(thisOp.getScaleB()));
5138 args.push_back(builder.getInt32(
5139 static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup()))));
5140 args.push_back(
5141 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
5142
5143 auto kind = thisOp.getKind();
5144 auto blockScale = thisOp.getBlockScale();
5145 llvm::Intrinsic::ID ID = [&]() {
5146 if (kind == NVVM::MMABlockScaleKind::MXF8F6F4) {
5147 if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
5148 return isATensor ? llvm::Intrinsic::
5149 nvvm_tcgen05_mma_tensor_mxf8f6f4_block_scale
5150 : llvm::Intrinsic::
5151 nvvm_tcgen05_mma_shared_mxf8f6f4_block_scale;
5152 } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
5153 return isATensor
5154 ? llvm::Intrinsic::
5155 nvvm_tcgen05_mma_tensor_mxf8f6f4_block_scale_block32
5156 : llvm::Intrinsic::
5157 nvvm_tcgen05_mma_shared_mxf8f6f4_block_scale_block32;
5158 }
5159 } else if (kind == NVVM::MMABlockScaleKind::MXF4) {
5160 if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
5161 return isATensor
5162 ? llvm::Intrinsic::nvvm_tcgen05_mma_tensor_mxf4_block_scale
5163 : llvm::Intrinsic::nvvm_tcgen05_mma_shared_mxf4_block_scale;
5164 } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
5165 return isATensor ? llvm::Intrinsic::
5166 nvvm_tcgen05_mma_tensor_mxf4_block_scale_block32
5167 : llvm::Intrinsic::
5168 nvvm_tcgen05_mma_shared_mxf4_block_scale_block32;
5169 }
5170 } else if (kind == NVVM::MMABlockScaleKind::MXF4NVF4) {
5171 if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
5172 return isATensor
5173 ? llvm::Intrinsic::
5174 nvvm_tcgen05_mma_tensor_mxf4nvf4_block_scale_block32
5175 : llvm::Intrinsic::
5176 nvvm_tcgen05_mma_shared_mxf4nvf4_block_scale_block32;
5177
5178 } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16) {
5179 return isATensor
5180 ? llvm::Intrinsic::
5181 nvvm_tcgen05_mma_tensor_mxf4nvf4_block_scale_block16
5182 : llvm::Intrinsic::
5183 nvvm_tcgen05_mma_shared_mxf4nvf4_block_scale_block16;
5184 }
5185 }
5186 llvm_unreachable("Invalid tcgen05.mma.block_scale attributes");
5187 }();
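  // Illustrative example: kind mxf8f6f4 with the default block scale and a
  // shared-memory matrix A selects nvvm_tcgen05_mma_shared_mxf8f6f4_block_scale.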
5188
5189 return {ID, args};
5190}
5191
5192static LogicalResult verifyTcgen05MMABlockScaleOp(
5193 NVVM::Tcgen05MMACollectorOp collectorOp, NVVM::MMABlockScaleKind kind,
5194 NVVM::Tcgen05MMABlockScale blockScale, Location loc) {
5195
5196 if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT &&
5197 kind == MMABlockScaleKind::MXF4NVF4)
5198 return emitError(loc, "mxf4nvf4 requires block scale attribute");
5199
5200 if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16 &&
5201 kind != MMABlockScaleKind::MXF4NVF4)
5202 return emitError(loc,
5203 llvm::formatv("{} kind does not support block16 attribute",
5204 stringifyEnum(kind)));
5205
5206 return success();
5207}
5208
5209LogicalResult Tcgen05MMABlockScaleOp::verify() {
5210 return verifyTcgen05MMABlockScaleOp(getCollectorOp(), getKind(),
5211 getBlockScale(), getLoc());
5212}
5213
5214//===----------------------------------------------------------------------===//
5215// NVVM tcgen05.mma.sp.block_scale functions
5216//===----------------------------------------------------------------------===//
5217
5218mlir::NVVM::IDArgPair Tcgen05MMASparseBlockScaleOp::getIntrinsicIDAndArgs(
5219 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
5220
5221   auto thisOp = cast<NVVM::Tcgen05MMASparseBlockScaleOp>(op);
5222   llvm::SmallVector<llvm::Value *> args;
5223
5224 args.push_back(mt.lookupValue(thisOp.getMatrixD()));
5225
5226 llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
5227 bool isATensor = isa<llvm::PointerType>(A->getType());
5228 args.push_back(A);
5229
5230 args.push_back(mt.lookupValue(thisOp.getMatrixB()));
5231 args.push_back(mt.lookupValue(thisOp.getIdesc()));
5232 args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
5233 args.push_back(mt.lookupValue(thisOp.getSparseMetadata()));
5234 args.push_back(mt.lookupValue(thisOp.getScaleA()));
5235 args.push_back(mt.lookupValue(thisOp.getScaleB()));
5236 args.push_back(builder.getInt32(
5237 static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup()))));
5238 args.push_back(
5239 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
5240
5241 auto kind = thisOp.getKind();
5242 auto blockScale = thisOp.getBlockScale();
5243 llvm::Intrinsic::ID ID = [&]() {
5244 if (kind == NVVM::MMABlockScaleKind::MXF8F6F4) {
5245 if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
5246 return isATensor ? llvm::Intrinsic::
5247 nvvm_tcgen05_mma_sp_tensor_mxf8f6f4_block_scale
5248 : llvm::Intrinsic::
5249 nvvm_tcgen05_mma_sp_shared_mxf8f6f4_block_scale;
5250 } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
5251 return isATensor
5252 ? llvm::Intrinsic::
5253 nvvm_tcgen05_mma_sp_tensor_mxf8f6f4_block_scale_block32
5254 : llvm::Intrinsic::
5255 nvvm_tcgen05_mma_sp_shared_mxf8f6f4_block_scale_block32;
5256 }
5257 } else if (kind == NVVM::MMABlockScaleKind::MXF4) {
5258 if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
5259 return isATensor ? llvm::Intrinsic::
5260 nvvm_tcgen05_mma_sp_tensor_mxf4_block_scale
5261 : llvm::Intrinsic::
5262 nvvm_tcgen05_mma_sp_shared_mxf4_block_scale;
5263 } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
5264 return isATensor
5265 ? llvm::Intrinsic::
5266 nvvm_tcgen05_mma_sp_tensor_mxf4_block_scale_block32
5267 : llvm::Intrinsic::
5268 nvvm_tcgen05_mma_sp_shared_mxf4_block_scale_block32;
5269 }
5270 } else if (kind == NVVM::MMABlockScaleKind::MXF4NVF4) {
5271 if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
5272 return isATensor
5273 ? llvm::Intrinsic::
5274 nvvm_tcgen05_mma_sp_tensor_mxf4nvf4_block_scale_block32
5275 : llvm::Intrinsic::
5276 nvvm_tcgen05_mma_sp_shared_mxf4nvf4_block_scale_block32;
5277
5278 } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16) {
5279 return isATensor
5280 ? llvm::Intrinsic::
5281 nvvm_tcgen05_mma_sp_tensor_mxf4nvf4_block_scale_block16
5282 : llvm::Intrinsic::
5283 nvvm_tcgen05_mma_sp_shared_mxf4nvf4_block_scale_block16;
5284 }
5285 }
5286 llvm_unreachable("Invalid tcgen05.mma.sp.block_scale attributes");
5287 }();
5288
5289 return {ID, args};
5290}
5291
5292LogicalResult Tcgen05MMASparseBlockScaleOp::verify() {
5293 return verifyTcgen05MMABlockScaleOp(getCollectorOp(), getKind(),
5294 getBlockScale(), getLoc());
5295}
5296
5297//===----------------------------------------------------------------------===//
5298// NVVM tcgen05.mma.ws functions
5299//===----------------------------------------------------------------------===//
5300
5301mlir::NVVM::IDArgPair Tcgen05MMAWsOp::getIntrinsicIDAndArgs(
5302 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
5303
5304   auto thisOp = cast<NVVM::Tcgen05MMAWsOp>(op);
5305   llvm::SmallVector<llvm::Value *> args;
5306
5307 args.push_back(mt.lookupValue(thisOp.getMatrixD()));
5308
5309 llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
5310 bool isATensor = isa<llvm::PointerType>(A->getType());
5311 args.push_back(A);
5312
5313 args.push_back(mt.lookupValue(thisOp.getMatrixB()));
5314 args.push_back(mt.lookupValue(thisOp.getIdesc()));
5315 args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
5316
5317 mlir::Value ZeroColMask = thisOp.getZeroColMask();
5318 llvm::Intrinsic::ID ID = notIntrinsic;
5319 if (ZeroColMask) {
5320 args.push_back(mt.lookupValue(ZeroColMask));
5321 ID = isATensor ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_tensor_zero_col_mask
5322 : llvm::Intrinsic::nvvm_tcgen05_mma_ws_shared_zero_col_mask;
5323 } else
5324 ID = isATensor ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_tensor
5325 : llvm::Intrinsic::nvvm_tcgen05_mma_ws_shared;
5326
5327 args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
5328 args.push_back(
5329 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorBBuffer())));
5330 args.push_back(
5331 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
5332
5333 return {ID, args};
5334}
5335
5336//===----------------------------------------------------------------------===//
5337// NVVM tcgen05.mma.ws.sp functions
5338//===----------------------------------------------------------------------===//
5339
5340mlir::NVVM::IDArgPair Tcgen05MMAWsSparseOp::getIntrinsicIDAndArgs(
5341 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
5342
5343   auto thisOp = cast<NVVM::Tcgen05MMAWsSparseOp>(op);
5344   llvm::SmallVector<llvm::Value *> args;
5345
5346 args.push_back(mt.lookupValue(thisOp.getMatrixD()));
5347
5348 llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
5349 bool isATensor = isa<llvm::PointerType>(A->getType());
5350 args.push_back(A);
5351
5352 args.push_back(mt.lookupValue(thisOp.getMatrixB()));
5353 args.push_back(mt.lookupValue(thisOp.getIdesc()));
5354 args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
5355 args.push_back(mt.lookupValue(thisOp.getSparseMetadata()));
5356
5357 mlir::Value ZeroColMask = thisOp.getZeroColMask();
5358 llvm::Intrinsic::ID ID = notIntrinsic;
5359 if (ZeroColMask) {
5360 args.push_back(mt.lookupValue(ZeroColMask));
5361 ID = isATensor
5362 ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_tensor_zero_col_mask
5363 : llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_shared_zero_col_mask;
5364 } else
5365 ID = isATensor ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_tensor
5366 : llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_shared;
5367
5368 args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
5369 args.push_back(
5370 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorBBuffer())));
5371 args.push_back(
5372 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
5373
5374 return {ID, args};
5375}
5376
5377//===----------------------------------------------------------------------===//
5378// NVVMDialect initialization, type parsing, and registration.
5379//===----------------------------------------------------------------------===//
5380
5381// TODO: This should be the llvm.nvvm dialect once this is supported.
5382void NVVMDialect::initialize() {
5383 addOperations<
5384#define GET_OP_LIST
5385#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
5386 >();
5387 addAttributes<
5388#define GET_ATTRDEF_LIST
5389#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
5390 >();
5391
5392 // Support unknown operations because not all NVVM operations are
5393 // registered.
5394 allowUnknownOperations();
5395 declarePromisedInterface<ConvertToLLVMPatternInterface, NVVMDialect>();
5396 declarePromisedInterface<gpu::TargetAttrInterface, NVVMTargetAttr>();
5397}
5398
5399LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
5400 NamedAttribute attr) {
5401 StringAttr attrName = attr.getName();
5402 // Kernel function attribute should be attached to functions.
5403 if (attrName == NVVMDialect::getKernelFuncAttrName()) {
5404 if (!isa<LLVM::LLVMFuncOp>(op)) {
5405 return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName()
5406 << "' attribute attached to unexpected op";
5407 }
5408 }
5409   // If maxntid / reqntid / cluster_dim is present, it must be an integer
5410   // array with at most 3 elements.
5411 if (attrName == NVVMDialect::getMaxntidAttrName() ||
5412 attrName == NVVMDialect::getReqntidAttrName() ||
5413 attrName == NVVMDialect::getClusterDimAttrName()) {
5414 auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue());
5415 if (!values || values.empty() || values.size() > 3) {
5416 return op->emitError()
5417 << "'" << attrName
5418 << "' attribute must be integer array with maximum 3 index";
5419 }
5420 }
5421 // If minctasm / maxnreg / cluster_max_blocks exist, it must be an integer
5422 // attribute
5423 if (attrName == NVVMDialect::getMinctasmAttrName() ||
5424 attrName == NVVMDialect::getMaxnregAttrName() ||
5425 attrName == NVVMDialect::getClusterMaxBlocksAttrName()) {
5426 if (!llvm::dyn_cast<IntegerAttr>(attr.getValue())) {
5427 return op->emitError()
5428 << "'" << attrName << "' attribute must be integer constant";
5429 }
5430 }
5431 // blocksareclusters must be used along with reqntid and cluster_dim
5432 if (attrName == NVVMDialect::getBlocksAreClustersAttrName()) {
5433 if (!op->hasAttr(NVVMDialect::getReqntidAttrName()) ||
5434 !op->hasAttr(NVVMDialect::getClusterDimAttrName())) {
5435 return op->emitError()
5436 << "'" << attrName << "' attribute must be used along with "
5437 << "'" << NVVMDialect::getReqntidAttrName() << "' and "
5438 << "'" << NVVMDialect::getClusterDimAttrName() << "'";
5439 }
5440 }
5441
5442 return success();
5443}
5444
5445LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
5446 unsigned regionIndex,
5447 unsigned argIndex,
5448 NamedAttribute argAttr) {
5449 auto funcOp = dyn_cast<FunctionOpInterface>(op);
5450 if (!funcOp)
5451 return success();
5452
5453 bool isKernel = op->hasAttr(NVVMDialect::getKernelFuncAttrName());
5454 StringAttr attrName = argAttr.getName();
5455 if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
5456 if (!isKernel) {
5457 return op->emitError()
5458 << "'" << attrName
5459 << "' attribute must be present only on kernel arguments";
5460 }
5461 if (!isa<UnitAttr>(argAttr.getValue()))
5462 return op->emitError() << "'" << attrName << "' must be a unit attribute";
5463 if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName())) {
5464 return op->emitError()
5465 << "'" << attrName
5466 << "' attribute requires the argument to also have attribute '"
5467 << LLVM::LLVMDialect::getByValAttrName() << "'";
5468 }
5469 }
5470
5471 return success();
5472}
5473
5474//===----------------------------------------------------------------------===//
5475// NVVM Address Space Attr
5476//===----------------------------------------------------------------------===//
5477
5478unsigned NVVMMemorySpaceAttr::getAddressSpace() const {
5479 return static_cast<unsigned>(getValue());
5480}
5481
5482bool NVVMMemorySpaceAttr::isValidLoad(
5483 Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
5484     const ::mlir::DataLayout *dataLayout,
5485     function_ref<InFlightDiagnostic()> emitError) const {
5486 return LLVM::detail::isValidLoadStoreImpl(type, ordering, alignment,
5487 dataLayout, emitError);
5488}
5489
5490bool NVVMMemorySpaceAttr::isValidStore(
5491 Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
5492     const ::mlir::DataLayout *dataLayout,
5493     function_ref<InFlightDiagnostic()> emitError) const {
5494 return LLVM::detail::isValidLoadStoreImpl(type, ordering, alignment,
5495 dataLayout, emitError);
5496}
5497
5498bool NVVMMemorySpaceAttr::isValidAtomicOp(
5499 ptr::AtomicBinOp op, Type type, ptr::AtomicOrdering ordering,
5500     std::optional<int64_t> alignment, const ::mlir::DataLayout *dataLayout,
5501     function_ref<InFlightDiagnostic()> emitError) const {
5502 // TODO: update this method once `ptr.atomic_rmw` is implemented.
5503 assert(false && "unimplemented, see TODO in the source.");
5504 return false;
5505}
5506
5507bool NVVMMemorySpaceAttr::isValidAtomicXchg(
5508 Type type, ptr::AtomicOrdering successOrdering,
5509 ptr::AtomicOrdering failureOrdering, std::optional<int64_t> alignment,
5510     const ::mlir::DataLayout *dataLayout,
5511     function_ref<InFlightDiagnostic()> emitError) const {
5512 // TODO: update this method once `ptr.atomic_cmpxchg` is implemented.
5513 assert(false && "unimplemented, see TODO in the source.");
5514 return false;
5515}
5516
5517bool NVVMMemorySpaceAttr::isValidAddrSpaceCast(
5518 Type tgt, Type src, function_ref<InFlightDiagnostic()> emitError) const {
5519 // TODO: update this method once the `ptr.addrspace_cast` op is added to the
5520 // dialect.
5521 assert(false && "unimplemented, see TODO in the source.");
5522 return false;
5523}
5524
5525bool NVVMMemorySpaceAttr::isValidPtrIntCast(
5526     Type intLikeTy, Type ptrLikeTy,
5527     function_ref<InFlightDiagnostic()> emitError) const {
5528 // TODO: update this method once the int-cast ops are added to the `ptr`
5529 // dialect.
5530 assert(false && "unimplemented, see TODO in the source.");
5531 return false;
5532}
5533
5534//===----------------------------------------------------------------------===//
5535// NVVM target attribute.
5536//===----------------------------------------------------------------------===//
5537LogicalResult
5538NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
5539 int optLevel, StringRef triple, StringRef chip,
5540 StringRef features, DictionaryAttr flags,
5541 ArrayAttr files, bool verifyTarget) {
5542 if (optLevel < 0 || optLevel > 3) {
5543 emitError() << "The optimization level must be a number between 0 and 3.";
5544 return failure();
5545 }
5546 if (triple.empty()) {
5547 emitError() << "The target triple cannot be empty.";
5548 return failure();
5549 }
5550 if (chip.empty()) {
5551 emitError() << "The target chip cannot be empty.";
5552 return failure();
5553 }
5554 if (files && !llvm::all_of(files, [](::mlir::Attribute attr) {
5555 return mlir::isa_and_nonnull<StringAttr>(attr);
5556 })) {
5557 emitError() << "All the elements in the `link` array must be strings.";
5558 return failure();
5559 }
5560 return success();
5561}
5562
5563LogicalResult NVVMTargetAttr::verifyTarget(Operation *gpuModule) {
5564 if (!getVerifyTarget())
5565 return success();
5566
5567 auto gpuModuleOp = llvm::dyn_cast<gpu::GPUModuleOp>(gpuModule);
5568 if (!gpuModuleOp) {
5569 return emitError(gpuModule->getLoc(),
5570 "NVVM target attribute must be attached to a GPU module");
5571 }
5572
5573   const NVVMCheckSMVersion targetSMVersion =
5574       NVVMCheckSMVersion::getTargetSMVersionFromStr(getChip());
5575 if (!targetSMVersion.isMinimumSMVersion()) {
5576 return emitError(gpuModule->getLoc(),
5577 "Minimum NVVM target SM version is sm_20");
5578 }
5579
5580 if (gpuModuleOp
5581 ->walk([&](Operation *op) {
5582 if (auto reqOp = llvm::dyn_cast<NVVM::RequiresSMInterface>(op)) {
5583 const NVVMCheckSMVersion requirement =
5584 reqOp.getRequiredMinSMVersion();
5585 if (!requirement.isCompatibleWith(targetSMVersion)) {
5586 op->emitOpError() << "is not supported on " << getChip();
5587 return WalkResult::interrupt();
5588 }
5589 }
5590 return WalkResult::advance();
5591 })
5592 .wasInterrupted())
5593 return failure();
5594
5595 return success();
5596}
5597
5598#define GET_OP_CLASSES
5599#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
5600
5601#define GET_ATTRDEF_CLASSES
5602#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"