NVVMDialect.cpp
1//===- NVVMDialect.cpp - NVVM IR Ops and Dialect registration -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the types and operation details for the NVVM IR dialect in
10// MLIR, and the LLVM IR dialect. It also registers the dialect.
11//
12// The NVVM dialect only contains GPU specific additions on top of the general
13// LLVM dialect.
14//
15//===----------------------------------------------------------------------===//
16
17#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
18
22#include "mlir/IR/Builders.h"
25#include "mlir/IR/Diagnostics.h"
27#include "mlir/IR/MLIRContext.h"
28#include "mlir/IR/Operation.h"
30#include "mlir/IR/Types.h"
31#include "llvm/ADT/STLExtras.h"
32#include "llvm/ADT/TypeSwitch.h"
33#include "llvm/IR/IRBuilder.h"
34#include "llvm/IR/NVVMIntrinsicUtils.h"
35#include "llvm/Support/Casting.h"
36#include "llvm/Support/FormatVariadic.h"
37#include "llvm/Support/NVPTXAddrSpace.h"
38#include "llvm/Support/raw_ostream.h"
39#include <cassert>
40#include <optional>
41#include <string>
42
43using namespace mlir;
44using namespace NVVM;
45
46#include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
47#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"
48
49static constexpr unsigned notIntrinsic = llvm::Intrinsic::not_intrinsic;
50
51//===----------------------------------------------------------------------===//
52// Helper/Utility methods
53//===----------------------------------------------------------------------===//
54
55static bool isPtrInAddrSpace(mlir::Value ptr, NVVMMemorySpace targetAS) {
56 auto ptrTy = llvm::cast<LLVM::LLVMPointerType>(ptr.getType());
57 return ptrTy.getAddressSpace() == static_cast<unsigned>(targetAS);
58}
59
60static bool isPtrInGenericSpace(mlir::Value ptr) {
61 return isPtrInAddrSpace(ptr, NVVMMemorySpace::Generic);
62}
63
64static bool isPtrInSharedCTASpace(mlir::Value ptr) {
65 return isPtrInAddrSpace(ptr, NVVMMemorySpace::Shared);
66}
67
68static bool isPtrInSharedClusterSpace(mlir::Value ptr) {
69 return isPtrInAddrSpace(ptr, NVVMMemorySpace::SharedCluster);
70}
71
72static llvm::Value *castPtrToAddrSpace(llvm::IRBuilderBase &builder,
73 llvm::Value *ptr,
74 NVVMMemorySpace targetAS) {
75 unsigned AS = static_cast<unsigned>(targetAS);
76 return builder.CreateAddrSpaceCast(
77 ptr, llvm::PointerType::get(builder.getContext(), AS));
78}
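
// For reference, the NVPTX address-space numbering behind NVVMMemorySpace as
// used by these pointer helpers (and by the TMA verifier diagnostics below):
// generic = 0, shared (shared::cta) = 3, shared::cluster = 7. For example,
// isPtrInSharedCTASpace(p) holds exactly when p has type !llvm.ptr<3>.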
79
80// Helper method to convert CtaGroupKind in NVVM Dialect to CtaGroupKind in LLVM
81static llvm::nvvm::CTAGroupKind
82getNVVMCtaGroupKind(NVVM::CTAGroupKind ctaGroup) {
83 switch (ctaGroup) {
84 case NVVM::CTAGroupKind::CTA_1:
85 return llvm::nvvm::CTAGroupKind::CG_1;
86 case NVVM::CTAGroupKind::CTA_2:
87 return llvm::nvvm::CTAGroupKind::CG_2;
88 }
89 llvm_unreachable("unsupported cta_group value");
90}
91
92//===----------------------------------------------------------------------===//
93// Verifier methods
94//===----------------------------------------------------------------------===//
95
96// This verifier is shared among the following Ops:
97// CpAsyncBulkTensorSharedCTAToGlobalOp (TMA Store)
98// CpAsyncBulkTensorReduceOp (TMA Store-Reduce)
99static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims,
100 bool isIm2Col,
101 size_t numIm2ColOffsets,
102 Location loc) {
103 if (tensorDims < 1 || tensorDims > 5)
104 return emitError(loc, "expects coordinates between 1 to 5 dimension");
105
106 // For Im2Col mode, there are two constraints:
107 if (isIm2Col) {
108 // 1. Tensor must always be at least 3-d.
109 if (tensorDims < 3)
110 return emitError(
111 loc,
112 "to use im2col mode, the tensor has to be at least 3-dimensional");
113 // 2. When there are Im2ColOffsets, they must be (Dims - 2) in number.
114 if (numIm2ColOffsets && (tensorDims != (numIm2ColOffsets + 2)))
115 return emitError(
116 loc, "im2col offsets must be 2 less than number of coordinates");
117 }
118 return success();
119}
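
// For example, a 4-D im2col access may supply either zero im2col offsets or
// exactly 4 - 2 = 2 of them; supplying 1 or 3 offsets is rejected, and any
// im2col access on a 1-D or 2-D tensor is rejected outright.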
120
121LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
122 TMAStoreMode mode = getMode();
123 // We lower through inline-ptx when getPredicate() is true.
124 // a) Only TILE mode is supported
125 // b) Cache-hint is not supported
126 if (getPredicate()) {
127 if (mode != TMAStoreMode::TILE)
128 return emitError("Inline-ptx lowering supported only for Tile mode.");
129 if (getL2CacheHint())
130 return emitError("Inline-ptx lowering unsupported with L2 cache-hint.");
131 }
132
133 size_t dims = getCoordinates().size();
134 switch (mode) {
135 case TMAStoreMode::TILE:
136 return cpAsyncBulkTensorCommonVerifier(dims, false, 0, getLoc());
137 case TMAStoreMode::IM2COL:
138 return cpAsyncBulkTensorCommonVerifier(dims, true, 0, getLoc());
139 case TMAStoreMode::TILE_SCATTER4:
140 if (dims != 5)
141 return emitError("Scatter4 mode expects 5 coordinates");
142 }
143 return success();
144}
145
146LogicalResult CpAsyncOp::verify() {
147 if (getModifier() != LoadCacheModifierKind::CG &&
148 getModifier() != LoadCacheModifierKind::CA)
149 return emitError("Only CG and CA cache modifiers are supported.");
150 if (getSize() != 4 && getSize() != 8 && getSize() != 16)
151 return emitError("expected byte size to be either 4, 8 or 16.");
152 if (getModifier() == LoadCacheModifierKind::CG && getSize() != 16)
153 return emitError("CG cache modifier is only support for 16 bytes copy.");
154 return success();
155}
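
// Hence the only valid combinations are a CA copy of 4, 8, or 16 bytes and a
// CG copy of exactly 16 bytes.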
156
157// These parameter checks are shared across the TMA Load and Prefetch Ops.
158static LogicalResult verifyTMALoadParams(size_t tensorDims, size_t numIm2colOff,
159 TMALoadMode mode, Location loc) {
160 if (tensorDims < 1 || tensorDims > 5)
161 return emitError(loc, "expects coordinates between 1 to 5 dimension");
162
163 auto checkTMALoadParams = [&](TMALoadMode mode, bool isIm2col,
164 size_t expectedIm2colOff) -> LogicalResult {
165 if (isIm2col && (tensorDims < 3))
166 return emitError(loc)
167 << "to use " << mode
168 << " mode, the tensor has to be at least 3-dimensional";
169
170 if (numIm2colOff != expectedIm2colOff)
171 return emitError(loc) << " im2col offsets expected " << expectedIm2colOff
172 << " (provided " << numIm2colOff << ")";
173
174 return success();
175 };
176
177 switch (mode) {
178 case TMALoadMode::TILE:
179 return checkTMALoadParams(mode, false, 0);
180 case TMALoadMode::IM2COL:
181 return checkTMALoadParams(mode, true, tensorDims - 2);
182 case TMALoadMode::IM2COL_W:
183 case TMALoadMode::IM2COL_W_128:
184 return checkTMALoadParams(mode, true, 2);
185 case TMALoadMode::TILE_GATHER4:
186 return (tensorDims == 5)
187 ? checkTMALoadParams(mode, false, 0)
188 : emitError(loc, "Gather4 mode expects 5 coordinates");
189 }
190 return success();
191}
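
// Summary of the checks above, with dims = number of coordinates:
//   TILE:                    0 im2col offsets
//   IM2COL:                  dims - 2 offsets, dims >= 3
//   IM2COL_W / IM2COL_W_128: exactly 2 offsets, dims >= 3
//   TILE_GATHER4:            0 offsets and exactly 5 coordinates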
192
193LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
194 return verifyTMALoadParams(getCoordinates().size(), getIm2colOffsets().size(),
195 getMode(), getLoc());
196}
197
198LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
199 TMALoadMode mode = getMode();
200 bool isCTAOnly = getIsCTAOnly();
201 if (getPredicate()) { // Inline-asm based lowering
202 if (isCTAOnly)
203 return emitError("Predicate is supported only for shared::cluster mode.");
204 if (mode != TMALoadMode::TILE && mode != TMALoadMode::IM2COL)
205 return emitError(
206 "Predicate is supported only for Tile and Im2col modes.");
207 } else { // Intrinsics-based lowering
208 NVVMMemorySpace expectedAS =
209 isCTAOnly ? NVVMMemorySpace::Shared : NVVMMemorySpace::SharedCluster;
210 unsigned AS = llvm::cast<LLVM::LLVMPointerType>(getDstMem().getType())
211 .getAddressSpace();
212 if (AS != expectedAS)
213 return emitError()
214 << (isCTAOnly
215 ? "Shared::cta destination requires address-space 3."
216 : "Shared::cluster destination requires address-space 7.");
217 // Checks specific to shared::cta mode
218 if (isCTAOnly) {
219 if (getMulticastMask())
220 return emitError("Multicast is not supported with shared::cta mode.");
221 if (getGroup())
222 return emitError("CTAGroup is not supported with shared::cta mode.");
223 }
224 }
225
226 return verifyTMALoadParams(getCoordinates().size(), getIm2colOffsets().size(),
227 getMode(), getLoc());
228}
229
230LogicalResult CpAsyncBulkTensorReduceOp::verify() {
231 TMAStoreMode mode = getMode();
232 size_t dims = getCoordinates().size();
233 switch (mode) {
234 case TMAStoreMode::TILE:
235 return cpAsyncBulkTensorCommonVerifier(dims, false, 0, getLoc());
236 case TMAStoreMode::IM2COL:
237 return cpAsyncBulkTensorCommonVerifier(dims, true, 0, getLoc());
238 case TMAStoreMode::TILE_SCATTER4:
239 return emitError("Scatter mode unsupported for CpAsyncBulkTensorReduceOp");
240 }
241 return success();
242}
243
244LogicalResult CpAsyncBulkGlobalToSharedClusterOp::verify() {
245 bool isSharedCTA = isPtrInSharedCTASpace(getDstMem());
246 if (isSharedCTA && getMulticastMask())
247 return emitError("Multicast is not supported with shared::cta mode.");
248
249 return success();
250}
251
252static LogicalResult verifyMBarrierArriveLikeOp(Operation *op, Value addr,
253 NVVM::MemScopeKind scope,
254 Value retVal = nullptr) {
255 if (scope != NVVM::MemScopeKind::CTA && scope != NVVM::MemScopeKind::CLUSTER)
256 return op->emitError("mbarrier scope must be either CTA or Cluster");
257
258 bool isSharedCluster = isPtrInSharedClusterSpace(addr);
259 bool hasRetValue = static_cast<bool>(retVal);
260 if (isSharedCluster && hasRetValue)
261 return op->emitError(
262 "mbarrier in shared_cluster space cannot return any value");
263
264 return success();
265}
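
// In other words, an arrive-like op may only produce a result token when the
// mbarrier lives in plain shared (CTA) memory, and its scope is restricted to
// CTA or CLUSTER.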
266
267LogicalResult MBarrierArriveOp::verify() {
268 return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope(),
269 getRes());
270}
271
272LogicalResult MBarrierArriveDropOp::verify() {
273 return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope(),
274 getRes());
275}
276
277LogicalResult MBarrierArriveExpectTxOp::verify() {
278 // The inline-ptx version of this Op does not support all features.
279 // With predicate, this Op lowers to inline-ptx. So, verify and
280 // error-out if there are unsupported features.
281 if (getPredicate()) {
282 if (getScope() != NVVM::MemScopeKind::CTA)
283 return emitError("mbarrier scope must be CTA when using predicate");
284
285 if (isPtrInSharedClusterSpace(getAddr()))
286 return emitError("mbarrier in shared_cluster space is not supported when "
287 "using predicate");
288
289 if (getRes())
290 return emitError("return-value is not supported when using predicate");
291
292 if (getRelaxed() == true)
293 return emitError("mbarrier with relaxed semantics is not supported when "
294 "using predicate");
295 }
296 return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope(),
297 getRes());
298}
299
300LogicalResult MBarrierArriveDropExpectTxOp::verify() {
301 return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope(),
302 getRes());
303}
304
305LogicalResult MBarrierExpectTxOp::verify() {
306 return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope());
307}
308
309LogicalResult MBarrierCompleteTxOp::verify() {
310 return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope());
311}
312
313LogicalResult MBarrierTestWaitOp::verify() {
314 return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope());
315}
316
317LogicalResult MBarrierTryWaitOp::verify() {
318 return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope());
319}
320
321LogicalResult ConvertFloatToTF32Op::verify() {
322 using RndMode = NVVM::FPRoundingMode;
323 switch (getRnd()) {
324 case RndMode::RNA:
325 if (getRelu())
326 return emitError("Relu not supported with rna rounding mode.");
327 break;
328 case RndMode::RN:
329 case RndMode::RZ:
330 break;
331 default:
332 return emitError(
333 "Only {rn,rz,rna} rounding modes supported for ConvertFloatToTF32Op.");
334 }
335 return success();
336}
337
338LogicalResult ConvertF32x2ToF6x2Op::verify() {
339 mlir::MLIRContext *ctx = getContext();
340
341 if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy())) {
342 return emitOpError("Only ")
343 << mlir::Float6E2M3FNType::get(ctx) << " and "
344 << mlir::Float6E3M2FNType::get(ctx)
345 << " types are supported for conversions from f32x2 to f6x2.";
346 }
347 return success();
348}
349
350LogicalResult ConvertF32x2ToF8x2Op::verify() {
351 using RndMode = NVVM::FPRoundingMode;
352 using SatMode = NVVM::SaturationMode;
353
354 bool isRoundingModeRN = getRnd() == RndMode::RN;
355 bool isRoundingModeRZ = getRnd() == RndMode::RZ;
356 bool isRoundingModeRP = getRnd() == RndMode::RP;
357 bool isSatFinite = getSat() == SatMode::SATFINITE;
358
359 bool hasRelu = getRelu();
360
361 mlir::MLIRContext *ctx = getContext();
362
363 return llvm::TypeSwitch<mlir::Type, LogicalResult>(getDstTy())
364 .Case<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(
365 [&](mlir::Type) -> LogicalResult {
366 if (!isRoundingModeRN) {
367 return emitOpError("Only RN rounding mode is supported for "
368 "conversions from f32x2 to ")
369 << mlir::Float8E4M3FNType::get(ctx) << " and "
370 << mlir::Float8E5M2Type::get(ctx) << " types";
371 }
372 if (!isSatFinite) {
373 return emitOpError("Only SATFINITE saturation mode is supported "
374 "for conversions "
375 "from f32x2 to ")
376 << mlir::Float8E4M3FNType::get(ctx) << " and "
377 << mlir::Float8E5M2Type::get(ctx) << " types";
378 }
379 return success();
380 })
381 .Case<mlir::Float8E8M0FNUType>([&](mlir::Type) -> LogicalResult {
382 if (!(isRoundingModeRZ || isRoundingModeRP)) {
383 return emitOpError("Only RZ and RP rounding modes are supported for "
384 "conversions from f32x2 to ")
385 << mlir::Float8E8M0FNUType::get(ctx) << " type";
386 }
387 if (hasRelu) {
388 return emitOpError("relu not supported for conversions to ")
389 << mlir::Float8E8M0FNUType::get(ctx) << " type";
390 }
391 return success();
392 })
393 .Default([&](mlir::Type) {
394 return emitOpError("Only ")
395 << mlir::Float8E4M3FNType::get(ctx) << ", "
396 << mlir::Float8E5M2Type::get(ctx) << ", and "
397 << mlir::Float8E8M0FNUType::get(ctx)
398 << " types are "
399 "supported for conversions from f32x2 to f8x2";
400 });
401}
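
// Summary: e4m3/e5m2 destinations require RN rounding with SATFINITE
// saturation, while e8m0 destinations require RZ or RP rounding and do not
// accept relu.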
402
403LogicalResult ConvertF16x2ToF8x2Op::verify() {
404 mlir::MLIRContext *ctx = getContext();
405
406 if (!llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(getDstTy())) {
407 return emitOpError("Only ")
408 << mlir::Float8E4M3FNType::get(ctx) << " and "
409 << mlir::Float8E5M2Type::get(ctx)
410 << " types are supported for conversions from f16x2 to f8x2.";
411 }
412 return success();
413}
414
415LogicalResult ConvertBF16x2ToF8x2Op::verify() {
416 using RndMode = NVVM::FPRoundingMode;
417
418 if (!llvm::isa<mlir::Float8E8M0FNUType>(getDstTy()))
419 return emitOpError("Only ") << mlir::Float8E8M0FNUType::get(getContext())
420 << " type is supported for conversions from "
421 "bf16x2 to f8x2.";
422
423 auto rnd = getRnd();
424 if (rnd != RndMode::RZ && rnd != RndMode::RP)
425 return emitOpError("Only RZ and RP rounding modes are supported for "
426 "conversions from bf16x2 to f8x2.");
427
428 return success();
429}
430
431LogicalResult ConvertF32x2ToF4x2Op::verify() {
432 mlir::MLIRContext *ctx = getContext();
433
434 if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
435 return emitOpError("Only ")
436 << mlir::Float4E2M1FNType::get(ctx)
437 << " type is supported for conversions from f32x2 to f4x2.";
438
439 return success();
440}
441
442LogicalResult ConvertF8x2ToF16x2Op::verify() {
443 mlir::MLIRContext *ctx = getContext();
444
445 if (!llvm::isa<Float8E4M3FNType, Float8E5M2Type>(getSrcType()))
446 return emitOpError("Only ")
447 << mlir::Float8E4M3FNType::get(ctx) << " and "
448 << mlir::Float8E5M2Type::get(ctx)
449 << " types are supported for conversions from f8x2 to f16x2.";
450
451 return success();
452}
453
454LogicalResult ConvertF8x2ToBF16x2Op::verify() {
455 mlir::MLIRContext *ctx = getContext();
456 if (!llvm::isa<Float8E8M0FNUType>(getSrcType()))
457 return emitOpError("Only ")
458 << mlir::Float8E8M0FNUType::get(ctx)
459 << " type is supported for conversions from f8x2 to bf16x2.";
460
461 return success();
462}
463
464LogicalResult ConvertF6x2ToF16x2Op::verify() {
465 mlir::MLIRContext *ctx = getContext();
466
467 if (!llvm::isa<Float6E2M3FNType, Float6E3M2FNType>(getSrcType()))
468 return emitOpError("Only ")
469 << mlir::Float6E2M3FNType::get(ctx) << " and "
470 << mlir::Float6E3M2FNType::get(ctx)
471 << " types are supported for conversions from f6x2 to f16x2.";
472
473 return success();
474}
475
476LogicalResult ConvertF4x2ToF16x2Op::verify() {
477 mlir::MLIRContext *ctx = getContext();
478
479 if (!llvm::isa<Float4E2M1FNType>(getSrcType()))
480 return emitOpError("Only ")
481 << mlir::Float4E2M1FNType::get(ctx)
482 << " type is supported for conversions from f4x2 to f16x2.";
483
484 return success();
485}
486
487LogicalResult PermuteOp::verify() {
488 using Mode = NVVM::PermuteMode;
489 bool hasHi = static_cast<bool>(getHi());
490
491 switch (getMode()) {
492 case Mode::DEFAULT:
493 case Mode::F4E:
494 case Mode::B4E:
495 if (!hasHi)
496 return emitError("mode '") << getMode() << "' requires 'hi' operand.";
497 break;
498 case Mode::RC8:
499 case Mode::ECL:
500 case Mode::ECR:
501 case Mode::RC16:
502 if (hasHi)
503 return emitError("mode '")
504 << getMode() << "' does not accept 'hi' operand.";
505 break;
506 }
507
508 return success();
509}
510
511//===----------------------------------------------------------------------===//
512// Stochastic Rounding Conversion Ops
513//===----------------------------------------------------------------------===//
514
515static LogicalResult verifyConvertF32x2ToFP16x2Op(Twine dstType,
516 FPRoundingMode rnd,
517 bool hasRandomBits,
518 Operation *op) {
519 static constexpr FPRoundingMode validRndModes[] = {
520 FPRoundingMode::RN, FPRoundingMode::RZ, FPRoundingMode::RS};
521
522 if (!llvm::is_contained(validRndModes, rnd)) {
523 return op->emitOpError(
524 "Only RN, RZ, and RS rounding modes are supported for "
525 "conversions from f32x2 to ")
526 << dstType << ".";
527 }
528
529 if (rnd == FPRoundingMode::RS) {
530 if (!hasRandomBits) {
531 return op->emitOpError("random_bits is required for RS rounding mode.");
532 }
533 } else {
534 if (hasRandomBits) {
535 return op->emitOpError(
536 "random_bits not supported for RN and RZ rounding modes.");
537 }
538 }
539
540 return success();
541}
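
// In short: the RS (stochastic rounding) mode requires the random_bits
// operand, RN and RZ forbid it, and every other rounding mode is rejected.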
542
543LogicalResult ConvertF32x2ToF16x2Op::verify() {
544 return verifyConvertF32x2ToFP16x2Op("f16x2", getRnd(),
545 getRandomBits() ? true : false, *this);
546}
547
548LogicalResult ConvertF32x2ToBF16x2Op::verify() {
549 return verifyConvertF32x2ToFP16x2Op("bf16x2", getRnd(),
550 getRandomBits() ? true : false, *this);
551}
552
553LogicalResult ConvertF32x4ToF8x4Op::verify() {
554 mlir::MLIRContext *ctx = getContext();
555
556 if (!llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(getDstTy()))
557 return emitOpError("Only ")
558 << mlir::Float8E4M3FNType::get(ctx) << " and "
559 << mlir::Float8E5M2Type::get(ctx)
560 << " types are supported for conversions from f32x4 to f8x4.";
561
562 return success();
563}
564
565LogicalResult ConvertF32x4ToF6x4Op::verify() {
566 mlir::MLIRContext *ctx = getContext();
567
568 if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy()))
569 return emitOpError("Only ")
570 << mlir::Float6E2M3FNType::get(ctx) << " and "
571 << mlir::Float6E3M2FNType::get(ctx)
572 << " types are supported for conversions from f32x4 to f6x4.";
573
574 return success();
575}
576
577LogicalResult ConvertF32x4ToF4x4Op::verify() {
578 mlir::MLIRContext *ctx = getContext();
579
580 if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
581 return emitOpError("Only ") << mlir::Float4E2M1FNType::get(ctx)
582 << " type is supported for conversions from "
583 "f32x4 to f4x4.";
584
585 return success();
586}
587
588LogicalResult BulkStoreOp::verify() {
589 if (getInitVal() != 0)
590 return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
591 return success();
592}
593
594LogicalResult PMEventOp::verify() {
595 auto eventId = getEventId();
596 auto maskedEventId = getMaskedEventId();
597 if (!maskedEventId && !eventId) {
598 return emitOpError() << "either `id` or `mask` must be set";
599 }
600
601 if (maskedEventId && eventId) {
602 return emitOpError() << "`id` and `mask` cannot be set at the same time";
603 }
604
605 if (eventId) {
606 if (eventId < 0 || eventId > 15) {
607 return emitOpError() << "`id` must be between 0 and 15";
608 }
609 }
610
611 return llvm::success();
612}
613
614// Given the element type of an operand and whether or not it is an accumulator,
615// this function returns the PTX type (`NVVM::MMATypes`) that corresponds to the
616// operand's element type.
617std::optional<mlir::NVVM::MMATypes>
618MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) {
619 auto half2Type =
620 VectorType::get(2, Float16Type::get(operandElType.getContext()));
621 if (operandElType.isF64())
622 return NVVM::MMATypes::f64;
623 if (operandElType.isF16() || operandElType == half2Type)
624 return NVVM::MMATypes::f16;
625 if (operandElType.isF32() && isAccumulator)
626 return NVVM::MMATypes::f32;
627 if (operandElType.isF32() && !isAccumulator)
628 return NVVM::MMATypes::tf32;
629 if (llvm::isa<IntegerType>(operandElType)) {
630 if (isAccumulator)
631 return NVVM::MMATypes::s32;
632 return std::nullopt;
633 }
634
635 if (auto structType = llvm::dyn_cast<LLVM::LLVMStructType>(operandElType)) {
636 if (structType.getBody().empty())
637 return std::nullopt;
638 return inferOperandMMAType(structType.getBody()[0], isAccumulator);
639 }
640
641 return std::nullopt;
642}
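
// Examples of the mapping above: f64 -> f64; f16 or vector<2xf16> -> f16;
// f32 -> f32 as an accumulator but tf32 as a multiplicand; integer element
// types map to s32 only as accumulators (integer multiplicand types such as
// s8/u8/s4/u4 must be given explicitly via attributes). For struct types the
// first body element is used.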
643
644static bool isInt4PtxType(MMATypes type) {
645 return (type == MMATypes::u4 || type == MMATypes::s4);
646}
647
648static bool isInt8PtxType(MMATypes type) {
649 return (type == MMATypes::u8 || type == MMATypes::s8);
650}
651
652static bool isIntegerPtxType(MMATypes type) {
653 return isInt4PtxType(type) || isInt8PtxType(type) || type == MMATypes::b1 ||
654 type == MMATypes::s32;
655}
656
657MMATypes MmaOp::accumPtxType() {
658 std::optional<mlir::NVVM::MMATypes> val = inferOperandMMAType(
659 getODSOperands(2).getTypes().front(), /*isAccumulator=*/true);
660 assert(val.has_value() && "accumulator PTX type should always be inferrable");
661 return val.value();
662}
663
664MMATypes MmaOp::resultPtxType() {
665 std::optional<mlir::NVVM::MMATypes> val =
666 inferOperandMMAType(getResult().getType(), /*isAccumulator=*/true);
667 assert(val.has_value() && "result PTX type should always be inferrable");
668 return val.value();
669}
670
671void MmaOp::print(OpAsmPrinter &p) {
672 SmallVector<Type, 4> regTypes;
673 struct MMAOperandFragment {
674 StringRef operandName;
675 StringRef ptxTypeAttr;
676 SmallVector<Value, 4> regs;
677 explicit MMAOperandFragment(StringRef name, StringRef ptxTypeName)
678 : operandName(name), ptxTypeAttr(ptxTypeName) {}
679 };
680
681 std::array<MMAOperandFragment, 3> frags{
682 MMAOperandFragment("A", getMultiplicandAPtxTypeAttrName()),
683 MMAOperandFragment("B", getMultiplicandBPtxTypeAttrName()),
684 MMAOperandFragment("C", "")};
685 SmallVector<StringRef, 4> ignoreAttrNames{
686 mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()};
687
688 for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
689 auto &frag = frags[fragIdx];
690 auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
691 for (auto operandIdx = varOperandSpec.first;
692 operandIdx < varOperandSpec.first + varOperandSpec.second;
693 operandIdx++) {
694 frag.regs.push_back(this->getOperand(operandIdx));
695 if (operandIdx == 0) {
696 regTypes.push_back(this->getOperand(operandIdx).getType());
697 }
698 }
699 std::optional<MMATypes> inferredType = MmaOp::inferOperandMMAType(
700 regTypes.back(), /*isAccumulator=*/fragIdx >= 2);
701 if (inferredType)
702 ignoreAttrNames.push_back(frag.ptxTypeAttr);
703 }
704
705 auto printMmaOperand = [&](const MMAOperandFragment &frag) -> void {
706 p << " " << frag.operandName;
707 p << "[";
708 p.printOperands(frag.regs);
709 p << "] ";
710 };
711
712 for (const auto &frag : frags) {
713 printMmaOperand(frag);
714 }
715
716 p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);
717
718 // Print the types of the operands and result.
719 p << " : "
720 << "(";
721 llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
722 frags[1].regs[0].getType(),
723 frags[2].regs[0].getType()},
724 p);
725 p << ")";
726 p.printArrowTypeList(TypeRange{this->getRes().getType()});
727}
728
729void MmaOp::build(OpBuilder &builder, OperationState &result, Type resultType,
730 ValueRange operandA, ValueRange operandB, ValueRange operandC,
731 ArrayRef<int64_t> shape, std::optional<MMAB1Op> b1Op,
732 std::optional<MMAIntOverflow> intOverflow,
733 std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
734 std::optional<std::array<MMALayout, 2>> multiplicandLayouts) {
735
736 assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
737 MLIRContext *ctx = builder.getContext();
738 result.addAttribute(
739 "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));
740
741 result.addOperands(operandA);
742 result.addOperands(operandB);
743 result.addOperands(operandC);
744
745 if (multiplicandPtxTypes) {
746 result.addAttribute("multiplicandAPtxType",
747 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
748 result.addAttribute("multiplicandBPtxType",
749 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
750 } else {
751 if (auto res = inferOperandMMAType(operandA[0].getType(), false))
752 result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
753 if (auto res = inferOperandMMAType(operandB[0].getType(), false))
754 result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
755 }
756
757 if (multiplicandLayouts) {
758 result.addAttribute("layoutA",
759 MMALayoutAttr::get(ctx, (*multiplicandLayouts)[0]));
760 result.addAttribute("layoutB",
761 MMALayoutAttr::get(ctx, (*multiplicandLayouts)[1]));
762 } else {
763 result.addAttribute("layoutA", MMALayoutAttr::get(ctx, MMALayout::row));
764 result.addAttribute("layoutB", MMALayoutAttr::get(ctx, MMALayout::col));
765 }
766
767 if (intOverflow.has_value())
768 result.addAttribute("intOverflowBehavior",
769 MMAIntOverflowAttr::get(ctx, *intOverflow));
770 if (b1Op.has_value())
771 result.addAttribute("b1Op", MMAB1OpAttr::get(ctx, *b1Op));
772
773 result.addTypes(resultType);
774 result.addAttribute(
775 MmaOp::getOperandSegmentSizeAttr(),
776 builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()),
777 static_cast<int32_t>(operandB.size()),
778 static_cast<int32_t>(operandC.size())}));
779}
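
// Minimal usage sketch of the builder above (hypothetical values; the PTX
// types and the row/col layouts are inferred or defaulted when the optional
// arguments are std::nullopt). For an m8n8k4 f16 MMA whose result is a struct
// of four vector<2xf16>:
//   MmaOp::build(builder, state, resTy,
//                /*operandA=*/{a0, a1}, /*operandB=*/{b0, b1},
//                /*operandC=*/{c0, c1, c2, c3}, /*shape=*/{8, 8, 4},
//                /*b1Op=*/std::nullopt, /*intOverflow=*/std::nullopt,
//                /*multiplicandPtxTypes=*/std::nullopt,
//                /*multiplicandLayouts=*/std::nullopt);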
780
781// <operation> :=
782// A `[` $operandA `]` B `[` $operandB `]` C `[` $operandC `]`
783// attr-dict : (type($operandA[0]), type($operandB[0]), type($operandC[0]))
784// `->` type($res)
785ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) {
786 struct MMAOperandFragment {
787 std::optional<MMATypes> elemtype;
788 SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
789 SmallVector<Type> regTypes;
790 };
791
792 Builder &builder = parser.getBuilder();
793 std::array<MMAOperandFragment, 4> frags;
794
795 NamedAttrList namedAttributes;
796
797 // A helper to parse the operand segments.
798 auto parseMmaOperand = [&](StringRef operandName,
799 MMAOperandFragment &frag) -> LogicalResult {
800 if (parser.parseKeyword(operandName).failed())
801 return failure();
802 if (parser
803 .parseOperandList(frag.regs, OpAsmParser::Delimiter::OptionalSquare)
804 .failed())
805 return failure();
806 return success();
807 };
808
809 // Parse the operand segments.
810 if (parseMmaOperand("A", frags[0]).failed())
811 return failure();
812 if (parseMmaOperand("B", frags[1]).failed())
813 return failure();
814 if (parseMmaOperand("C", frags[2]).failed())
815 return failure();
816
817 if (parser.parseOptionalAttrDict(namedAttributes).failed())
818 return failure();
819
820 // Parse the type specification and resolve operands.
821 SmallVector<Type, 3> operandTypes;
822 if (failed(parser.parseColon()))
823 return failure();
824 if (failed(parser.parseLParen()))
825 return failure();
826 if (failed(parser.parseTypeList(operandTypes)))
827 return failure();
828 if (failed(parser.parseRParen()))
    return failure();
829 if (operandTypes.size() != 3)
830 return parser.emitError(
831 parser.getNameLoc(),
832 "expected one type for each operand segment but got " +
833 Twine(operandTypes.size()) + " types");
834 for (const auto &iter : llvm::enumerate(operandTypes)) {
835 auto &frag = frags[iter.index()];
836 frag.regTypes.resize(frag.regs.size(), iter.value());
837 if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
838 parser.getNameLoc(), result.operands)))
839 return failure();
840 frag.elemtype = inferOperandMMAType(frag.regTypes[0],
841 /*isAccumulator*/ iter.index() < 2);
842 }
843
844 Type resultType;
845 if (parser.parseArrow() || parser.parseType(resultType))
846 return failure();
847 frags[3].elemtype = inferOperandMMAType(resultType, /*isAccumulator*/ true);
848
849 std::array<StringRef, 2> names{"multiplicandAPtxType",
850 "multiplicandBPtxType"};
851 for (unsigned idx = 0; idx < names.size(); idx++) {
852 const auto &frag = frags[idx];
853 std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
854 if (!frag.elemtype.has_value() && !attr.has_value()) {
855 return parser.emitError(
856 parser.getNameLoc(),
857 "attribute " + names[idx] +
858 " is not provided explicitly and cannot be inferred");
859 }
860 if (!attr.has_value())
861 result.addAttribute(
862 names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));
863 }
864
865 result.addTypes(resultType);
866 if (!namedAttributes.empty())
867 result.addAttributes(namedAttributes);
868 result.addAttribute(MmaOp::getOperandSegmentSizeAttr(),
869 builder.getDenseI32ArrayAttr({
870 static_cast<int32_t>(frags[0].regs.size()),
871 static_cast<int32_t>(frags[1].regs.size()),
872 static_cast<int32_t>(frags[2].regs.size()),
873 }));
874 return success();
875}
876
877LogicalResult MmaOp::verify() {
878 MLIRContext *context = getContext();
879 auto f16Ty = Float16Type::get(context);
880 auto i32Ty = IntegerType::get(context, 32);
881 auto f16x2Ty = VectorType::get(2, f16Ty);
882 auto f32Ty = Float32Type::get(context);
883 auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
884 context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
885
886 auto s32x4StructTy =
887 LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
888 auto f32x8StructTy =
889 LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
890 auto f16x2x2StructTy =
891 LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
892 auto f32x4StructTy =
893 LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
894 auto s32x2StructTy =
895 LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});
896
897 std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
898 getShapeAttr().getK()};
899
900 // These variables define the set of allowed data types for matrices A, B, C,
901 // and result.
902 using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
903 using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>;
904 AllowedShapes allowedShapes;
905 AllowedTypes expectedA;
906 AllowedTypes expectedB;
907 AllowedTypes expectedC;
908 SmallVector<Type> expectedResult;
909
910 // When M = 16, we just need to calculate the number of 8xk tiles, where
911 // k is a factor that depends on the data type.
912 if (mmaShape[0] == 16) {
913 int64_t kFactor;
914 Type multiplicandFragType;
915 switch (*getMultiplicandAPtxType()) {
916 case MMATypes::tf32:
917 kFactor = 4;
918 multiplicandFragType = i32Ty;
919 expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
920 context, {f32Ty, f32Ty, f32Ty, f32Ty}));
921 break;
922 case MMATypes::bf16:
923 kFactor = 8;
924 multiplicandFragType = i32Ty;
925 expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
926 context, {f32Ty, f32Ty, f32Ty, f32Ty}));
927 break;
928 case MMATypes::f16:
929 kFactor = 8;
930 multiplicandFragType = f16x2Ty;
931 expectedResult.push_back(f16x2x2StructTy);
932 expectedResult.push_back(f32x4StructTy);
933 break;
934 case MMATypes::s4:
935 case MMATypes::u4:
936 kFactor = 32;
937 break;
938 case MMATypes::b1:
939 kFactor = 128;
940 break;
941 case MMATypes::s8:
942 case MMATypes::u8:
943 kFactor = 16;
944 break;
945 default:
946 return emitError("invalid shape or multiplicand type: ")
947 << getMultiplicandAPtxType().value();
948 }
949
950 if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
951 expectedResult.push_back(s32x4StructTy);
952 expectedC.emplace_back(4, i32Ty);
953 multiplicandFragType = i32Ty;
954 } else {
955 expectedC.emplace_back(2, f16x2Ty);
956 expectedC.emplace_back(4, f32Ty);
957 }
958
959 int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
960 int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
961 expectedA.emplace_back(unitA, multiplicandFragType);
962 expectedB.emplace_back(unitB, multiplicandFragType);
963 allowedShapes.push_back({16, 8, kFactor});
964 allowedShapes.push_back({16, 8, kFactor * 2});
965
966 if (resultPtxType() != accumPtxType())
967 return emitOpError("ctype does not match dtype");
968 }
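    // Worked example for the arithmetic above, assuming m16n8k16 with f16
    // multiplicands: kFactor = 8, so unitA = (16/8) * (16/8) = 4 and
    // unitB = (8/8) * (16/8) = 2, i.e. A is passed as 4 vector<2xf16>
    // registers and B as 2, with C (and the result) as either 2 x
    // vector<2xf16> or 4 x f32.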
969
970 // In the M=8 case, there is only 1 possible case per data type.
971 if (mmaShape[0] == 8) {
972 if (*getMultiplicandAPtxType() == MMATypes::f16) {
973 expectedA.emplace_back(2, f16x2Ty);
974 expectedB.emplace_back(2, f16x2Ty);
975 expectedResult.push_back(f16x2x4StructTy);
976 expectedResult.push_back(f32x8StructTy);
977 expectedC.emplace_back(4, f16x2Ty);
978 expectedC.emplace_back(8, f32Ty);
979 allowedShapes.push_back({8, 8, 4});
980 }
981 if (*getMultiplicandAPtxType() == MMATypes::f64) {
982 Type f64Ty = Float64Type::get(context);
983 expectedA.emplace_back(1, f64Ty);
984 expectedB.emplace_back(1, f64Ty);
985 expectedC.emplace_back(2, f64Ty);
986 expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
987 context, SmallVector<Type>(2, f64Ty)));
988 allowedShapes.push_back({8, 8, 4});
989 }
990 if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
991 expectedA.push_back({i32Ty});
992 expectedB.push_back({i32Ty});
993 expectedC.push_back({i32Ty, i32Ty});
994 expectedResult.push_back(s32x2StructTy);
995 if (isInt4PtxType(getMultiplicandAPtxType().value()))
996 allowedShapes.push_back({8, 8, 32});
997 if (isInt8PtxType(getMultiplicandAPtxType().value()))
998 allowedShapes.push_back({8, 8, 16});
999 if (getMultiplicandAPtxType().value() == MMATypes::b1)
1000 allowedShapes.push_back({8, 8, 128});
1001 }
1002 }
1003
1004 std::string errorMessage;
1005 llvm::raw_string_ostream errorStream(errorMessage);
1006
1007 // Check that we matched an existing shape/dtype combination.
1008 if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
1009 !llvm::is_contained(allowedShapes, mmaShape)) {
1010 errorStream << "unimplemented variant for MMA shape <";
1011 llvm::interleaveComma(mmaShape, errorStream);
1012 errorStream << ">";
1013 return emitOpError(errorMessage);
1014 }
1015
1016 // Verify the operand types for segments of A, B, and C operands.
1017 std::array<StringRef, 3> operandNames{"A", "B", "C"};
1018 for (const auto &iter : llvm::enumerate(
1019 SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) {
1020 auto spec = this->getODSOperandIndexAndLength(iter.index());
1021 SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first,
1022 operand_type_begin() + spec.first +
1023 spec.second);
1024 bool match = llvm::is_contained(iter.value(), operandTySeg);
1025
1026 if (!match) {
1027 errorStream << "Could not match types for the "
1028 << operandNames[iter.index()]
1029 << " operands; expected one of ";
1030 for (const auto &x : iter.value()) {
1031 errorStream << x.size() << "x" << x[0] << " ";
1032 }
1033 errorStream << "but got ";
1034 llvm::interleaveComma(operandTySeg, errorStream);
1035 return emitOpError(errorMessage);
1036 }
1037 }
1038
1039 // Check the result type
1040 if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
1041 return expectedResultType == getResult().getType();
1042 })) {
1043 errorStream
1044 << "Could not match allowed types for the result; expected one of ";
1045 llvm::interleaveComma(expectedResult, errorStream);
1046 errorStream << " but got " << getResult().getType();
1047 return emitOpError(errorMessage);
1048 }
1049
1050 // Ensure that binary MMA variants have a b1 MMA operation defined.
1051 if (getMultiplicandAPtxType() == MMATypes::b1 && !getB1Op()) {
1052 return emitOpError("op requires " + getB1OpAttrName().strref() +
1053 " attribute");
1054 }
1055
1056 // Ensure int4/int8 MMA variants specify the accum overflow behavior
1057 // attribute.
1058 if (isInt4PtxType(*getMultiplicandAPtxType()) ||
1059 isInt8PtxType(*getMultiplicandAPtxType())) {
1060 if (!getIntOverflowBehavior())
1061 return emitOpError("op requires " +
1062 getIntOverflowBehaviorAttrName().strref() +
1063 " attribute");
1064 }
1065
1066 // Validate layout combinations. According to the operation description, most
1067 // MMA operations require layoutA=row and layoutB=col. Only m8n8k4 with f16
1068 // can use other layout combinations.
1069 bool isM8N8K4_F16 =
1070 (mmaShape[0] == 8 && mmaShape[1] == 8 && mmaShape[2] == 4 &&
1071 getMultiplicandAPtxType() == MMATypes::f16);
1072
1073 if (!isM8N8K4_F16) {
1074 // For all other shapes/types, layoutA must be row and layoutB must be col
1075 if (getLayoutA() != MMALayout::row || getLayoutB() != MMALayout::col) {
1076 return emitOpError("requires layoutA = #nvvm.mma_layout<row> and "
1077 "layoutB = #nvvm.mma_layout<col> for shape <")
1078 << mmaShape[0] << ", " << mmaShape[1] << ", " << mmaShape[2]
1079 << "> with element types " << *getMultiplicandAPtxType() << " and "
1080 << *getMultiplicandBPtxType()
1081 << ". Only m8n8k4 with f16 supports other layouts.";
1082 }
1083 }
1084
1085 return success();
1086}
1087
1088MMATypes MmaSpOp::accumPtxType() {
1089 std::optional<mlir::NVVM::MMATypes> val = MmaOp::inferOperandMMAType(
1090 getODSOperands(2).getTypes().front(), /*isAccumulator=*/true);
1091 assert(val.has_value() && "accumulator PTX type should always be inferrable");
1092 return val.value();
1093}
1094
1095MMATypes MmaSpOp::resultPtxType() {
1096 std::optional<mlir::NVVM::MMATypes> val =
1097 MmaOp::inferOperandMMAType(getResult().getType(), /*isAccumulator=*/true);
1098 assert(val.has_value() && "result PTX type should always be inferrable");
1099 return val.value();
1100}
1101
1102mlir::NVVM::IDArgPair
1103MmaSpOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
1104 llvm::IRBuilderBase &builder) {
1105 auto thisOp = cast<NVVM::MmaSpOp>(op);
1106
1107 // Get operands
1108 llvm::SmallVector<llvm::Value *> args;
1109 for (mlir::Value v : thisOp.getOperands())
1110 args.push_back(mt.lookupValue(v));
1111
1112 // Get intrinsic ID using the existing getIntrinsicID method
1113 auto intId = MmaSpOp::getIntrinsicID(
1114 thisOp.getShape().getM(), thisOp.getShape().getN(),
1115 thisOp.getShape().getK(), thisOp.getIntOverflowBehavior(),
1116 thisOp.getOrderedMetadata(), thisOp.getKind(),
1117 *thisOp.getMultiplicandAPtxType(), *thisOp.getMultiplicandBPtxType(),
1118 thisOp.accumPtxType(), thisOp.resultPtxType());
1119
1120 return {intId, args};
1121}
1122
1123void MmaSpOp::print(OpAsmPrinter &p) {
1124 SmallVector<Type, 4> regTypes;
1125 struct MMAOperandFragment {
1126 StringRef operandName;
1127 StringRef ptxTypeAttr;
1128 SmallVector<Value, 4> regs;
1129 explicit MMAOperandFragment(StringRef name, StringRef ptxTypeName)
1130 : operandName(name), ptxTypeAttr(ptxTypeName) {}
1131 };
1132
1133 std::array<MMAOperandFragment, 5> frags{
1134 MMAOperandFragment("A", getMultiplicandAPtxTypeAttrName()),
1135 MMAOperandFragment("B", getMultiplicandBPtxTypeAttrName()),
1136 MMAOperandFragment("C", ""), MMAOperandFragment("sparseMetadata", ""),
1137 MMAOperandFragment("selector", "")};
1138 SmallVector<StringRef, 4> ignoreAttrNames{
1139 mlir::NVVM::MmaSpOp::getOperandSegmentSizeAttr()};
1140
1141 // Handle variadic operands A, B, C
1142 for (unsigned fragIdx = 0; fragIdx < 3; fragIdx++) {
1143 auto &frag = frags[fragIdx];
1144 auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
1145 for (auto operandIdx = varOperandSpec.first;
1146 operandIdx < varOperandSpec.first + varOperandSpec.second;
1147 operandIdx++) {
1148 frag.regs.push_back(this->getOperand(operandIdx));
1149 if (operandIdx == varOperandSpec.first) {
1150 regTypes.push_back(this->getOperand(operandIdx).getType());
1151 }
1152 }
1153 std::optional<MMATypes> inferredType = MmaOp::inferOperandMMAType(
1154 regTypes.back(), /*isAccumulator=*/fragIdx >= 2);
1155 if (inferredType)
1156 ignoreAttrNames.push_back(frag.ptxTypeAttr);
1157 }
1158
1159 // Handle sparse metadata and selector (single operands)
1160 frags[3].regs.push_back(getSparseMetadata());
1161 frags[4].regs.push_back(getSparsitySelector());
1162
1163 auto printMmaSpOperand = [&](const MMAOperandFragment &frag) -> void {
1164 p << " " << frag.operandName;
1165 p << "[";
1166 p.printOperands(frag.regs);
1167 p << "]";
1168 };
1169
1170 for (const auto &frag : frags)
1171 printMmaSpOperand(frag);
1172
1173 p.printOptionalAttrDict((*this)->getAttrs(), ignoreAttrNames);
1174 p << " : ";
1175 p << "(";
1176 for (int i = 0; i < 3; ++i) {
1177 p << regTypes[i];
1178 if (i < 2)
1179 p << ", ";
1180 }
1181 p << ") -> " << getResult().getType();
1182}
1183
1184void MmaSpOp::build(
1185 OpBuilder &builder, OperationState &result, Type resultType,
1186 ValueRange operandA, ValueRange operandB, ValueRange operandC,
1187 Value sparseMetadata, Value sparsitySelector, ArrayRef<int64_t> shape,
1188 std::optional<MMAIntOverflow> intOverflow,
1189 std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes) {
1190
1191 assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
1192 MLIRContext *ctx = builder.getContext();
1193 result.addAttribute(
1194 "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));
1195
1196 result.addOperands(operandA);
1197 result.addOperands(operandB);
1198 result.addOperands(operandC);
1199 result.addOperands(sparseMetadata);
1200 result.addOperands(sparsitySelector);
1201
1202 if (multiplicandPtxTypes) {
1203 result.addAttribute("multiplicandAPtxType",
1204 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
1205 result.addAttribute("multiplicandBPtxType",
1206 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
1207 } else {
1208 if (auto res = MmaOp::inferOperandMMAType(operandA[0].getType(), false))
1209 result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
1210 if (auto res = MmaOp::inferOperandMMAType(operandB[0].getType(), false))
1211 result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
1212 }
1213
1214 if (intOverflow.has_value())
1215 result.addAttribute("intOverflowBehavior",
1216 MMAIntOverflowAttr::get(ctx, *intOverflow));
1217
1218 result.addTypes(resultType);
1219 result.addAttribute(
1220 MmaSpOp::getOperandSegmentSizeAttr(),
1221 builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()),
1222 static_cast<int32_t>(operandB.size()),
1223 static_cast<int32_t>(operandC.size()), 1,
1224 1})); // sparseMetadata and sparsitySelector
1225}
1226
1227ParseResult MmaSpOp::parse(OpAsmParser &parser, OperationState &result) {
1228 struct MMAOperandFragment {
1229 std::optional<MMATypes> elemtype;
1230 SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
1231 SmallVector<Type> regTypes;
1232 };
1233
1234 Builder &builder = parser.getBuilder();
1235 std::array<MMAOperandFragment, 6> frags; // A, B, C, sparseMetadata, selector
1236
1237 NamedAttrList namedAttributes;
1238
1239 // A helper to parse the operand segments.
1240 auto parseMmaSpOperand = [&](StringRef operandName,
1241 MMAOperandFragment &frag) -> LogicalResult {
1242 if (parser.parseKeyword(operandName).failed())
1243 return failure();
1244 if (parser
1245 .parseOperandList(frag.regs, OpAsmParser::Delimiter::OptionalSquare)
1246 .failed())
1247 return failure();
1248 return success();
1249 };
1250
1251 // Parse the operand segments.
1252 if (parseMmaSpOperand("A", frags[0]).failed())
1253 return failure();
1254 if (parseMmaSpOperand("B", frags[1]).failed())
1255 return failure();
1256 if (parseMmaSpOperand("C", frags[2]).failed())
1257 return failure();
1258 if (parseMmaSpOperand("sparseMetadata", frags[3]).failed())
1259 return failure();
1260 if (parseMmaSpOperand("selector", frags[4]).failed())
1261 return failure();
1262
1263 if (parser.parseOptionalAttrDict(namedAttributes).failed())
1264 return failure();
1265
1266 // Parse the type specification and resolve operands.
1267 SmallVector<Type, 3> operandTypes;
1268 if (failed(parser.parseColon()))
1269 return failure();
1270 if (failed(parser.parseLParen()))
1271 return failure();
1272 if (failed(parser.parseTypeList(operandTypes)))
1273 return failure();
1274 if (failed(parser.parseRParen()))
1275 return failure();
1276 if (operandTypes.size() != 3)
1277 return parser.emitError(
1278 parser.getNameLoc(),
1279 "expected one type for each operand segment but got " +
1280 Twine(operandTypes.size()) + " types");
1281 for (const auto &iter : llvm::enumerate(operandTypes)) {
1282 auto &frag = frags[iter.index()];
1283 frag.regTypes.resize(frag.regs.size(), iter.value());
1284 if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
1285 parser.getNameLoc(), result.operands)))
1286 return failure();
1287 frag.elemtype =
1288 MmaOp::inferOperandMMAType(frag.regTypes[0],
1289 /*isAccumulator*/ iter.index() >= 2);
1290 }
1291
1292 Type resultType;
1293 if (parser.parseArrow() || parser.parseType(resultType))
1294 return failure();
1295 frags[5].elemtype =
1296 MmaOp::inferOperandMMAType(resultType, /*isAccumulator*/ true);
1297
1298 // Resolve sparse metadata and selector (assume i32 type)
1299 Type i32Type = builder.getIntegerType(32);
1300 if (parser
1301 .resolveOperands(frags[3].regs, i32Type, parser.getCurrentLocation(),
1302 result.operands)
1303 .failed())
1304 return failure();
1305 if (parser
1306 .resolveOperands(frags[4].regs, i32Type, parser.getCurrentLocation(),
1307 result.operands)
1308 .failed())
1309 return failure();
1310
1311 std::array<StringRef, 2> names{"multiplicandAPtxType",
1312 "multiplicandBPtxType"};
1313 for (unsigned idx = 0; idx < names.size(); idx++) {
1314 const auto &frag = frags[idx];
1315 std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
1316 if (!frag.elemtype.has_value() && !attr.has_value()) {
1317 return parser.emitError(
1318 parser.getNameLoc(),
1319 "attribute " + names[idx] +
1320 " is not provided explicitly and cannot be inferred");
1321 }
1322 if (!attr.has_value())
1323 result.addAttribute(
1324 names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));
1325 }
1326
1327 result.addTypes(resultType);
1328 if (!namedAttributes.empty())
1329 result.addAttributes(namedAttributes);
1330 result.addAttribute(MmaSpOp::getOperandSegmentSizeAttr(),
1331 builder.getDenseI32ArrayAttr({
1332 static_cast<int32_t>(frags[0].regs.size()),
1333 static_cast<int32_t>(frags[1].regs.size()),
1334 static_cast<int32_t>(frags[2].regs.size()),
1335 1, // sparseMetadata
1336 1 // sparsitySelector
1337 }));
1338 return success();
1339}
1340
1341LogicalResult MmaSpOp::verify() {
1342 MLIRContext *context = getContext();
1343 auto f16Ty = Float16Type::get(context);
1344 auto i32Ty = IntegerType::get(context, 32);
1345 auto f16x2Ty = VectorType::get(2, f16Ty);
1346 auto f32Ty = Float32Type::get(context);
1347 auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
1348 context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
1349
1350 auto s32x4StructTy =
1351 LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
1352 auto f32x8StructTy =
1353 LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
1354 auto f16x2x2StructTy =
1355 LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
1356 auto f32x4StructTy =
1357 LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
1358 auto s32x2StructTy =
1359 LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});
1360
1361 std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
1362 getShapeAttr().getK()};
1363
1364 // These variables define the set of allowed data types for matrices A, B, C,
1365 // and result.
1366 using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
1367 using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>;
1368 AllowedShapes allowedShapes;
1369 AllowedTypes expectedA;
1370 AllowedTypes expectedB;
1371 AllowedTypes expectedC;
1372 SmallVector<Type> expectedResult;
1373
1374 // When M = 16, we just need to calculate the number of 8xk tiles, where
1375 // k is a factor that depends on the data type.
1376 if (mmaShape[0] == 16) {
1377 int64_t kFactor;
1378 Type multiplicandFragType;
1379 switch (*getMultiplicandAPtxType()) {
1380 case MMATypes::tf32:
1381 kFactor = 4;
1382 multiplicandFragType = i32Ty;
1383 expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
1384 context, {f32Ty, f32Ty, f32Ty, f32Ty}));
1385 // Sparse MMA supports m16n8k8 and m16n8k16 for tf32
1386 allowedShapes.push_back({16, 8, 8});
1387 allowedShapes.push_back({16, 8, 16});
1388 break;
1389 case MMATypes::bf16:
1390 kFactor = 8;
1391 multiplicandFragType = i32Ty;
1392 expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
1393 context, {f32Ty, f32Ty, f32Ty, f32Ty}));
1394 // Sparse MMA supports m16n8k16 and m16n8k32 for bf16
1395 allowedShapes.push_back({16, 8, 16});
1396 allowedShapes.push_back({16, 8, 32});
1397 break;
1398 case MMATypes::f16:
1399 kFactor = 8;
1400 multiplicandFragType = f16x2Ty;
1401 expectedResult.push_back(f16x2x2StructTy);
1402 expectedResult.push_back(f32x4StructTy);
1403 // Sparse MMA supports m16n8k16 and m16n8k32 for f16
1404 allowedShapes.push_back({16, 8, 16});
1405 allowedShapes.push_back({16, 8, 32});
1406 break;
1407 case MMATypes::s4:
1408 case MMATypes::u4:
1409 kFactor = 32;
1410 // Sparse MMA supports m16n8k64 and m16n8k128 for s4/u4
1411 allowedShapes.push_back({16, 8, 64});
1412 allowedShapes.push_back({16, 8, 128});
1413 break;
1414 case MMATypes::s8:
1415 case MMATypes::u8:
1416 kFactor = 16;
1417 // Sparse MMA supports m16n8k32 and m16n8k64 for s8/u8
1418 allowedShapes.push_back({16, 8, 32});
1419 allowedShapes.push_back({16, 8, 64});
1420 break;
1421 case MMATypes::e4m3:
1422 case MMATypes::e5m2:
1423 case MMATypes::e3m2:
1424 case MMATypes::e2m3:
1425 case MMATypes::e2m1:
1426 kFactor = 32;
1427 multiplicandFragType = i32Ty;
1428 expectedResult.push_back(f16x2x2StructTy);
1429 expectedResult.push_back(f32x4StructTy);
1430 // Sparse MMA supports m16n8k64 for FP8 types
1431 allowedShapes.push_back({16, 8, 64});
1432 break;
1433 default:
1434 return emitError("invalid shape or multiplicand type: ")
1435 << getMultiplicandAPtxType().value();
1436 }
1437
1438 if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
1439 expectedResult.push_back(s32x4StructTy);
1440 expectedC.emplace_back(4, i32Ty);
1441 multiplicandFragType = i32Ty;
1442 } else if (*getMultiplicandAPtxType() >= MMATypes::e4m3 &&
1443 *getMultiplicandAPtxType() <= MMATypes::e2m1) {
1444 // FP8 types
1445 expectedC.emplace_back(2, f16x2Ty);
1446 expectedC.emplace_back(4, f32Ty);
1447 } else {
1448 expectedC.emplace_back(2, f16x2Ty);
1449 expectedC.emplace_back(4, f32Ty);
1450 }
1451
1452 // For sparse MMA, A operand is compressed (2:4 sparsity means half the
1453 // elements)
1454 int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor) / 2;
1455 int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
1456 expectedA.emplace_back(unitA, multiplicandFragType);
1457 expectedB.emplace_back(unitB, multiplicandFragType);
1458
1459 if (resultPtxType() != accumPtxType())
1460 return emitOpError("ctype does not match dtype");
1461 }
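    // Worked example for the 2:4-sparse arithmetic above, assuming m16n8k32
    // with f16 multiplicands: kFactor = 8, so the compressed A fragment is
    // unitA = (16/8) * (32/8) / 2 = 4 registers and
    // unitB = (8/8) * (32/8) = 4 registers, each of type vector<2xf16>.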
1462
1463 // In the M=8 case, there is only 1 possible case per data type.
1464 if (mmaShape[0] == 8) {
1465 if (*getMultiplicandAPtxType() == MMATypes::f16) {
1466 expectedA.emplace_back(2, f16x2Ty);
1467 expectedB.emplace_back(2, f16x2Ty);
1468 expectedResult.push_back(f16x2x4StructTy);
1469 expectedResult.push_back(f32x8StructTy);
1470 expectedC.emplace_back(4, f16x2Ty);
1471 expectedC.emplace_back(8, f32Ty);
1472 allowedShapes.push_back({8, 8, 4});
1473 }
1474 if (*getMultiplicandAPtxType() == MMATypes::f64) {
1475 Type f64Ty = Float64Type::get(context);
1476 expectedA.emplace_back(1, f64Ty);
1477 expectedB.emplace_back(1, f64Ty);
1478 expectedC.emplace_back(2, f64Ty);
1479 expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
1480 context, SmallVector<Type>(2, f64Ty)));
1481 allowedShapes.push_back({8, 8, 4});
1482 }
1483 if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
1484 expectedA.push_back({i32Ty});
1485 expectedB.push_back({i32Ty});
1486 expectedC.push_back({i32Ty, i32Ty});
1487 expectedResult.push_back(s32x2StructTy);
1488 if (isInt4PtxType(getMultiplicandAPtxType().value()))
1489 allowedShapes.push_back({8, 8, 32});
1490 if (isInt8PtxType(getMultiplicandAPtxType().value()))
1491 allowedShapes.push_back({8, 8, 16});
1492 }
1493 }
1494
1495 std::string errorMessage;
1496 llvm::raw_string_ostream errorStream(errorMessage);
1497
1498 // Check that we matched an existing shape/dtype combination.
1499 if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
1500 !llvm::is_contained(allowedShapes, mmaShape)) {
1501 errorStream << "unimplemented variant for MMA shape <";
1502 llvm::interleaveComma(mmaShape, errorStream);
1503 errorStream << ">";
1504 return emitOpError(errorMessage);
1505 }
1506
1507 // Verify the operand types for segments of A, B, and C operands.
1508 std::array<StringRef, 3> operandNames{"A", "B", "C"};
1509 for (const auto &iter : llvm::enumerate(
1510 SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) {
1511 auto spec = this->getODSOperandIndexAndLength(iter.index());
1512 SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first,
1513 operand_type_begin() + spec.first +
1514 spec.second);
1515 bool match = llvm::is_contained(iter.value(), operandTySeg);
1516
1517 if (!match) {
1518 errorStream << "Could not match types for the "
1519 << operandNames[iter.index()]
1520 << " operands; expected one of ";
1521 for (const auto &x : iter.value()) {
1522 errorStream << x.size() << "x" << x[0] << " ";
1523 }
1524 errorStream << "but got ";
1525 llvm::interleaveComma(operandTySeg, errorStream);
1526 return emitOpError(errorMessage);
1527 }
1528 }
1529
1530 // Check the result type
1531 if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
1532 return expectedResultType == getResult().getType();
1533 })) {
1534 errorStream
1535 << "Could not match allowed types for the result; expected one of ";
1536 llvm::interleaveComma(expectedResult, errorStream);
1537 errorStream << " but got " << getResult().getType();
1538 return emitOpError(errorMessage);
1539 }
1540
1541 // Ensure int4/int8 MMA variants specify the accum overflow behavior
1542 // attribute.
1543 if (isInt4PtxType(*getMultiplicandAPtxType()) ||
1544 isInt8PtxType(*getMultiplicandAPtxType())) {
1545 if (!getIntOverflowBehavior())
1546 return emitOpError("op requires " +
1547 getIntOverflowBehaviorAttrName().strref() +
1548 " attribute");
1549 }
1550
1551 // Validate sparse metadata type (should be i32)
1552 if (!getSparseMetadata().getType().isInteger(32)) {
1553 return emitOpError() << "sparse metadata must be i32 type";
1554 }
1555
1556 // Validate sparsity selector type (should be i32)
1557 if (!getSparsitySelector().getType().isInteger(32)) {
1558 return emitOpError() << "sparsity selector must be i32 type";
1559 }
1560
1561 return success();
1562}
1563
1564//===----------------------------------------------------------------------===//
1565// MMA Block Scale Operations - Shared Helpers
1566//===----------------------------------------------------------------------===//
1567
1568namespace {
1569// Shared structure for MMA operand fragments (A, B, C)
1570struct MMAOperandFragment {
1571 StringRef operandName;
1572 StringRef ptxTypeAttr;
1573 SmallVector<Value, 4> regs;
1574 explicit MMAOperandFragment(StringRef name, StringRef ptxTypeName)
1575 : operandName(name), ptxTypeAttr(ptxTypeName) {}
1576};
1577} // namespace
1578
1579// Helper to print operand list in the format: name[operands]
1580static void printOperandList(OpAsmPrinter &p, StringRef name,
1581 ArrayRef<Value> operands) {
1582 p << " " << name << "[";
1583 p.printOperands(operands);
1584 p << "]";
1585}
1586
1587// Helper to parse operand list in the format: name[operands]
1588static LogicalResult
1589parseMmaOperand(OpAsmParser &parser, StringRef operandName,
1590 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &regs) {
1591 if (parser.parseKeyword(operandName).failed())
1592 return failure();
1593 if (parser.parseOperandList(regs, OpAsmParser::Delimiter::OptionalSquare)
1594 .failed())
1595 return failure();
1596 return success();
1597}
1598
1599// Helper to process operand fragments and determine which attributes can be
1600// inferred
1601template <typename Op>
1602static void
1603processOperandFragments(Op &op, std::array<MMAOperandFragment, 3> &frags,
1604 SmallVectorImpl<Type> &regTypes,
1605 SmallVectorImpl<StringRef> &ignoreAttrNames) {
1606 for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
1607 auto &frag = frags[fragIdx];
1608 auto varOperandSpec = op.getODSOperandIndexAndLength(fragIdx);
1609 for (auto operandIdx = varOperandSpec.first;
1610 operandIdx < varOperandSpec.first + varOperandSpec.second;
1611 operandIdx++) {
1612 frag.regs.push_back(op.getOperand(operandIdx));
1613 if (fragIdx == 0 && operandIdx == varOperandSpec.first) {
1614 regTypes.push_back(op.getOperand(operandIdx).getType());
1615 }
1616 }
1617 if (fragIdx < 2) {
1618 regTypes.push_back(frag.regs[0].getType());
1619 }
1620 std::optional<MMATypes> inferredType =
1621 MmaOp::inferOperandMMAType(regTypes.back(),
1622 /*isAccumulator=*/fragIdx >= 2);
1623 if (inferredType)
1624 ignoreAttrNames.push_back(frag.ptxTypeAttr);
1625 }
1626}
1627
1628// Helper to parse type signature: (A_type, B_type, C_type)
1629static LogicalResult
1630parseMmaTypeSignature(OpAsmParser &parser,
1631                      SmallVectorImpl<Type> &operandTypes) {
1632 if (parser.parseColon().failed() || parser.parseLParen().failed())
1633 return failure();
1634
1635 auto typeParser = [&]() {
1636 Type ty;
1637 if (parser.parseType(ty).failed())
1638 return failure();
1639 operandTypes.push_back(ty);
1640 return success();
1641 };
1642 if (parser.parseCommaSeparatedList(typeParser))
1643 return failure();
1644
1645 if (operandTypes.size() != 3)
1646 return parser.emitError(parser.getCurrentLocation(),
1647 "expected exactly 3 types");
1648
1649 return parser.parseRParen();
1650}
1651
1652// Helper to infer and set multiplicand PTX type attributes
1653static void
1654inferAndSetMultiplicandTypes(MLIRContext *ctx, NamedAttrList &attrs,
1655                             const SmallVectorImpl<Type> &operandTypes) {
1656 if (!attrs.get("multiplicandAPtxType")) {
1657 if (auto inferredType =
1658 MmaOp::inferOperandMMAType(operandTypes[0], false)) {
1659 attrs.set("multiplicandAPtxType", MMATypesAttr::get(ctx, *inferredType));
1660 }
1661 }
1662 if (!attrs.get("multiplicandBPtxType")) {
1663 if (auto inferredType =
1664 MmaOp::inferOperandMMAType(operandTypes[1], false)) {
1665 attrs.set("multiplicandBPtxType", MMATypesAttr::get(ctx, *inferredType));
1666 }
1667 }
1668}
1669
1670// Helper to add common block scale properties
1671template <typename OpType>
1672static void addBlockScaleProperties(OpBuilder &builder, OperationState &result,
1673                                    ArrayRef<int64_t> shape,
1674                                    ScaleVecSize scaleVecSize,
1675 BlockScaleFormat blockScaleFormat,
1676 MMABlockScaleKind kind) {
1677 MLIRContext *ctx = builder.getContext();
1678 auto &properties = result.getOrAddProperties<typename OpType::Properties>();
1679 properties.setShape(
1680 builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));
1681 properties.setScaleVecSize(ScaleVecSizeAttr::get(ctx, scaleVecSize));
1682 properties.setBlockScaleFormat(
1683 BlockScaleFormatAttr::get(ctx, blockScaleFormat));
1684 properties.setKind(MMABlockScaleKindAttr::get(ctx, kind));
1685}
1686
1687// Helper to infer and add multiplicand PTX types to builder
1688static void addInferredMultiplicandTypes(
1689    MLIRContext *ctx, OperationState &result, ValueRange operandA,
1690    ValueRange operandB,
1691 std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes) {
1692 if (multiplicandPtxTypes) {
1693 result.addAttribute("multiplicandAPtxType",
1694 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
1695 result.addAttribute("multiplicandBPtxType",
1696 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
1697 } else {
1698 if (auto res = MmaOp::inferOperandMMAType(operandA[0].getType(), false))
1699 result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
1700 if (auto res = MmaOp::inferOperandMMAType(operandB[0].getType(), false))
1701 result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
1702 }
1703}
1704
1705// Template helper for common accumPtxType/resultPtxType implementation
1706template <typename OpTy>
1707static MMATypes inferPtxTypeFromResult(OpTy op) {
1708 return *MmaOp::inferOperandMMAType(
1709 cast<LLVM::LLVMStructType>(op.getRes().getType()).getBody()[0],
1710 /*isAccumulator=*/true);
1711}
1712
1713//===----------------------------------------------------------------------===//
1714// MmaBlockScaleOp
1715//===----------------------------------------------------------------------===//
1716
1717void MmaBlockScaleOp::print(OpAsmPrinter &p) {
1718 SmallVector<Type, 4> regTypes;
1719 std::array<MMAOperandFragment, 3> frags{
1720 MMAOperandFragment("A", getMultiplicandAPtxTypeAttrName()),
1721 MMAOperandFragment("B", getMultiplicandBPtxTypeAttrName()),
1722 MMAOperandFragment("C", "")};
1723 SmallVector<StringRef, 4> ignoreAttrNames{
1724 mlir::NVVM::MmaBlockScaleOp::getOperandSegmentSizeAttr()};
1725
1726 processOperandFragments(*this, frags, regTypes, ignoreAttrNames);
1727
1728 // Print A, B, C operands
1729 for (const auto &frag : frags)
1730 printOperandList(p, frag.operandName, frag.regs);
1731
1732 // Print scale operands
1733 printOperandList(p, "scaleA",
1734 {getScaleAData(), getByteIdA(), getThreadIdA()});
1735 printOperandList(p, "scaleB",
1736 {getScaleBData(), getByteIdB(), getThreadIdB()});
1737
1738 p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);
1739
1740 // Print type signature
1741 p << " : (";
1742 llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
1743 frags[1].regs[0].getType(),
1744 frags[2].regs[0].getType()},
1745 p);
1746 p << ")";
1747 p.printArrowTypeList(TypeRange{this->getRes().getType()});
1748}
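// Illustrative sketch of the custom assembly form handled by the printer above
// and the parser below (the op mnemonic and attribute spellings are defined by
// the op declaration in NVVMOps.td and are not reproduced here):
//
//   <op-name> A[%a0, ...] B[%b0, ...] C[%c0, ...]
//             scaleA[%scaleAData, %byteIdA, %threadIdA]
//             scaleB[%scaleBData, %byteIdB, %threadIdB]
//             {attr-dict} : (typeA, typeB, typeC) -> resultType
//
// The parser resolves each scale triple as (i32, i16, i16).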
1749
1750ParseResult MmaBlockScaleOp::parse(OpAsmParser &parser,
1751                                   OperationState &result) {
1752  struct LocalOperandFragment {
1753 std::optional<MMATypes> elemtype;
1754 SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
1755 };
1756
1757 Builder &builder = parser.getBuilder();
1758 std::array<LocalOperandFragment, 3> frags;
1759 NamedAttrList namedAttributes;
1760
1761 // Parse A[...] B[...] C[...]
1762 if (parseMmaOperand(parser, "A", frags[0].regs).failed() ||
1763 parseMmaOperand(parser, "B", frags[1].regs).failed() ||
1764 parseMmaOperand(parser, "C", frags[2].regs).failed())
1765 return failure();
1766
1767 // Parse scale operands: scaleA[...] scaleB[...]
1768 SmallVector<OpAsmParser::UnresolvedOperand, 3> scaleAOperands, scaleBOperands;
1769 if (parseMmaOperand(parser, "scaleA", scaleAOperands).failed() ||
1770 parseMmaOperand(parser, "scaleB", scaleBOperands).failed())
1771 return failure();
1772
1773 if (parser.parseOptionalAttrDict(namedAttributes).failed())
1774 return failure();
1775
1776 // Parse type signature
1777 SmallVector<Type, 3> operandTypes;
1778 if (parseMmaTypeSignature(parser, operandTypes).failed())
1779 return failure();
1780
1781 // Parse result type
1782 SmallVector<Type, 1> resultTypes;
1783 if (parser.parseArrowTypeList(resultTypes).failed())
1784 return failure();
1785
1786 // Infer element types and resolve operands
1787 for (const auto &[idx, frag] : llvm::enumerate(frags)) {
1788 frag.elemtype = MmaOp::inferOperandMMAType(operandTypes[idx],
1789 /*isAccumulator=*/idx >= 2);
1790 if (parser
1791 .resolveOperands(frag.regs, operandTypes[idx], parser.getNameLoc(),
1792 result.operands)
1793 .failed())
1794 return failure();
1795 }
1796
1797 // Resolve scale operands
1798 SmallVector<Type, 3> scaleTypes = {builder.getI32Type(), builder.getI16Type(),
1799 builder.getI16Type()};
1800 if (parser
1801 .resolveOperands(scaleAOperands, scaleTypes, parser.getNameLoc(),
1802 result.operands)
1803 .failed() ||
1804 parser
1805 .resolveOperands(scaleBOperands, scaleTypes, parser.getNameLoc(),
1806 result.operands)
1807 .failed())
1808 return failure();
1809
1810 // Add attributes
1811 result.addAttributes(namedAttributes);
1812 inferAndSetMultiplicandTypes(parser.getContext(), result.attributes,
1813 operandTypes);
1814
1815 result.addTypes(resultTypes);
1816 result.addAttribute(MmaBlockScaleOp::getOperandSegmentSizeAttr(),
1817 builder.getDenseI32ArrayAttr({
1818 static_cast<int32_t>(frags[0].regs.size()),
1819 static_cast<int32_t>(frags[1].regs.size()),
1820 static_cast<int32_t>(frags[2].regs.size()),
1821 1, // scaleAData
1822 1, // byteIdA
1823 1, // threadIdA
1824 1, // scaleBData
1825 1, // byteIdB
1826 1 // threadIdB
1827 }));
1828 return success();
1829}
1830
1831void MmaBlockScaleOp::build(
1832 OpBuilder &builder, OperationState &result, Type resultType,
1833 ValueRange operandA, ValueRange operandB, ValueRange operandC,
1834 Value scaleAData, Value byteIdA, Value threadIdA, Value scaleBData,
1835 Value byteIdB, Value threadIdB, ArrayRef<int64_t> shape,
1836 std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
1837 ScaleVecSize scaleVecSize, BlockScaleFormat blockScaleFormat,
1838 MMABlockScaleKind kind) {
1839 assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
1840
1841  addBlockScaleProperties<MmaBlockScaleOp>(builder, result, shape, scaleVecSize,
1842                                           blockScaleFormat, kind);
1843
1844 result.addOperands(operandA);
1845 result.addOperands(operandB);
1846 result.addOperands(operandC);
1847 result.addOperands(
1848 {scaleAData, byteIdA, threadIdA, scaleBData, byteIdB, threadIdB});
1849
1850 addInferredMultiplicandTypes(builder.getContext(), result, operandA, operandB,
1851 multiplicandPtxTypes);
1852
1853 result.addTypes(resultType);
1854 result.addAttribute(MmaBlockScaleOp::getOperandSegmentSizeAttr(),
1855 builder.getDenseI32ArrayAttr({
1856 static_cast<int32_t>(operandA.size()),
1857 static_cast<int32_t>(operandB.size()),
1858 static_cast<int32_t>(operandC.size()),
1859 1, // scaleAData
1860 1, // byteIdA
1861 1, // threadIdA
1862 1, // scaleBData
1863 1, // byteIdB
1864 1 // threadIdB
1865 }));
1866}
1867
1868NVVM::IDArgPair MmaBlockScaleOp::getIntrinsicIDAndArgs(
1869 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1870 auto curOp = cast<NVVM::MmaBlockScaleOp>(op);
1871
1872  llvm::SmallVector<llvm::Value *> args;
1873  // Add A, B, C operands
1874 for (Value operand : curOp.getOperandA())
1875 args.push_back(mt.lookupValue(operand));
1876 for (Value operand : curOp.getOperandB())
1877 args.push_back(mt.lookupValue(operand));
1878 for (Value operand : curOp.getOperandC())
1879 args.push_back(mt.lookupValue(operand));
1880
1881 // Add scale operands
1882 args.push_back(mt.lookupValue(curOp.getScaleAData()));
1883 args.push_back(mt.lookupValue(curOp.getByteIdA()));
1884 args.push_back(mt.lookupValue(curOp.getThreadIdA()));
1885 args.push_back(mt.lookupValue(curOp.getScaleBData()));
1886 args.push_back(mt.lookupValue(curOp.getByteIdB()));
1887 args.push_back(mt.lookupValue(curOp.getThreadIdB()));
1888
1889 unsigned intId = MmaBlockScaleOp::getIntrinsicID(
1890 curOp.getShape().getM(), curOp.getShape().getN(), curOp.getShape().getK(),
1891 *curOp.getMultiplicandAPtxType(), *curOp.getMultiplicandBPtxType(),
1892 inferPtxTypeFromResult(curOp), curOp.getScaleVecSize(),
1893 curOp.getBlockScaleFormat(), curOp.getKind());
1894
1895 return {intId, args};
1896}
1897
1898LogicalResult MmaBlockScaleOp::verify() {
1899 LogicalResult result = success();
1900 int m = getShape().getM();
1901 int n = getShape().getN();
1902 int k = getShape().getK();
1903
1904 if (m == 16 && n == 8 && k == 64) {
1905 if (getMultiplicandAPtxType() != NVVM::MMATypes::e2m1 ||
1906 getMultiplicandBPtxType() != NVVM::MMATypes::e2m1)
1907      result = emitOpError(
1908          "unsupported MMATypes attribute for mma.m16n8k64.(mxf4nvf4|mxf4)");
1909 if (getKind() == NVVM::MMABlockScaleKind::MXF4) {
1910 if (getScaleVecSize() != NVVM::ScaleVecSize::X2)
1911        result = emitOpError(
1912            "unsupported ScaleVecSize attribute for mma.m16n8k64.mxf4");
1913 if (getBlockScaleFormat() != NVVM::BlockScaleFormat::UE8M0)
1914        result = emitOpError(
1915            "unsupported BlockScaleFormat attribute for mma.m16n8k64.mxf4");
1916 } else if (getKind() == NVVM::MMABlockScaleKind::MXF4NVF4) {
1917 if (!((getScaleVecSize() == NVVM::ScaleVecSize::X2 &&
1918 getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0) ||
1919 (getScaleVecSize() == NVVM::ScaleVecSize::X4 &&
1920 getBlockScaleFormat() == NVVM::BlockScaleFormat::UE4M3)))
1921 result = emitOpError("unsupported ScaleVecSize and BlockScaleFormat "
1922 "attributes for mma.m16n8k64.mxf4nvf4");
1923 } else {
1924 result = emitOpError("unsupported Kind attribute for mma.m16n8k64");
1925 }
1926 } else if (m == 16 && n == 8 && k == 32) {
1927 if (!(getKind() == NVVM::MMABlockScaleKind::MXF8F6F4 &&
1928 getScaleVecSize() == NVVM::ScaleVecSize::X1 &&
1929 getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0))
1930 result =
1931 emitOpError("unsupported Kind, ScaleVecSize and BlockScaleFormat "
1932 "attributes for mma.m16n8k32");
1933 } else {
1934 result = emitOpError("unsupported Geom for mma with block scaling");
1935 }
1936 return result;
1937}
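// Summary of the attribute combinations accepted by the verifier above
// (derived from the checks; the PTX ISA remains the normative reference):
//   m16n8k64 : A/B element type must be e2m1
//              kind = mxf4     -> ScaleVecSize::X2 with BlockScaleFormat::UE8M0
//              kind = mxf4nvf4 -> (X2, UE8M0) or (X4, UE4M3)
//   m16n8k32 : kind = mxf8f6f4 with ScaleVecSize::X1 and BlockScaleFormat::UE8M0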
1938
1939//===----------------------------------------------------------------------===//
1940// MmaSpBlockScaleOp
1941//===----------------------------------------------------------------------===//
1942
1943void MmaSpBlockScaleOp::print(OpAsmPrinter &p) {
1944 SmallVector<Type, 4> regTypes;
1945 std::array<MMAOperandFragment, 3> frags{
1946 MMAOperandFragment("A", getMultiplicandAPtxTypeAttrName()),
1947 MMAOperandFragment("B", getMultiplicandBPtxTypeAttrName()),
1948 MMAOperandFragment("C", "")};
1949 SmallVector<StringRef, 4> ignoreAttrNames{
1950 mlir::NVVM::MmaSpBlockScaleOp::getOperandSegmentSizeAttr()};
1951
1952 processOperandFragments(*this, frags, regTypes, ignoreAttrNames);
1953
1954 // Print A, B, C operands
1955 for (const auto &frag : frags)
1956 printOperandList(p, frag.operandName, frag.regs);
1957
1958 // Print sparse-specific operands
1959 printOperandList(p, "sparseMetadata", {getSparseMetadata()});
1960 printOperandList(p, "selector", {getSparsitySelector()});
1961
1962 // Print scale operands
1963 printOperandList(p, "scaleA",
1964 {getScaleAData(), getByteIdA(), getThreadIdA()});
1965 printOperandList(p, "scaleB",
1966 {getScaleBData(), getByteIdB(), getThreadIdB()});
1967
1968 p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);
1969
1970 // Print type signature
1971 p << " : (";
1972 llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
1973 frags[1].regs[0].getType(),
1974 frags[2].regs[0].getType()},
1975 p);
1976 p << ")";
1977 p.printArrowTypeList(TypeRange{this->getRes().getType()});
1978}
1979
1980ParseResult MmaSpBlockScaleOp::parse(OpAsmParser &parser,
1981                                     OperationState &result) {
1982  struct LocalOperandFragment {
1983 std::optional<MMATypes> elemtype;
1984 SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
1985 };
1986
1987 Builder &builder = parser.getBuilder();
1988 std::array<LocalOperandFragment, 3> frags;
1989 NamedAttrList namedAttributes;
1990
1991 // Parse A[...] B[...] C[...]
1992 if (parseMmaOperand(parser, "A", frags[0].regs).failed() ||
1993 parseMmaOperand(parser, "B", frags[1].regs).failed() ||
1994 parseMmaOperand(parser, "C", frags[2].regs).failed())
1995 return failure();
1996
1997 // Parse sparse-specific operands
1998  SmallVector<OpAsmParser::UnresolvedOperand, 1> metadataOperands,
1999      selectorOperands;
2000 if (parseMmaOperand(parser, "sparseMetadata", metadataOperands).failed() ||
2001 parseMmaOperand(parser, "selector", selectorOperands).failed())
2002 return failure();
2003
2004 // Parse scale operands
2005 SmallVector<OpAsmParser::UnresolvedOperand, 3> scaleAOperands, scaleBOperands;
2006 if (parseMmaOperand(parser, "scaleA", scaleAOperands).failed() ||
2007 parseMmaOperand(parser, "scaleB", scaleBOperands).failed())
2008 return failure();
2009
2010 if (parser.parseOptionalAttrDict(namedAttributes).failed())
2011 return failure();
2012
2013 // Parse type signature
2014 SmallVector<Type, 3> operandTypes;
2015 if (parseMmaTypeSignature(parser, operandTypes).failed())
2016 return failure();
2017
2018 // Parse result type
2019 SmallVector<Type, 1> resultTypes;
2020 if (parser.parseArrowTypeList(resultTypes).failed())
2021 return failure();
2022
2023 // Infer element types and resolve operands
2024 for (const auto &[idx, frag] : llvm::enumerate(frags)) {
2025 frag.elemtype = MmaOp::inferOperandMMAType(operandTypes[idx],
2026 /*isAccumulator=*/idx >= 2);
2027 if (parser
2028 .resolveOperands(frag.regs, operandTypes[idx], parser.getNameLoc(),
2029 result.operands)
2030 .failed())
2031 return failure();
2032 }
2033
2034 // Resolve sparse metadata and selector
2035 Type i32Type = builder.getI32Type();
2036 if (parser
2037 .resolveOperands(metadataOperands, i32Type, parser.getNameLoc(),
2038 result.operands)
2039 .failed() ||
2040 parser
2041 .resolveOperands(selectorOperands, i32Type, parser.getNameLoc(),
2042 result.operands)
2043 .failed())
2044 return failure();
2045
2046 // Resolve scale operands
2047 SmallVector<Type, 3> scaleTypes = {i32Type, builder.getI16Type(),
2048 builder.getI16Type()};
2049 if (parser
2050 .resolveOperands(scaleAOperands, scaleTypes, parser.getNameLoc(),
2051 result.operands)
2052 .failed() ||
2053 parser
2054 .resolveOperands(scaleBOperands, scaleTypes, parser.getNameLoc(),
2055 result.operands)
2056 .failed())
2057 return failure();
2058
2059 // Add attributes
2060 result.addAttributes(namedAttributes);
2061 inferAndSetMultiplicandTypes(parser.getContext(), result.attributes,
2062 operandTypes);
2063
2064 // orderedMetadata is mandatory
2065 if (!result.attributes.get("orderedMetadata"))
2066 result.addAttribute("orderedMetadata", builder.getUnitAttr());
2067
2068 result.addTypes(resultTypes);
2069 result.addAttribute(MmaSpBlockScaleOp::getOperandSegmentSizeAttr(),
2070 builder.getDenseI32ArrayAttr({
2071 static_cast<int32_t>(frags[0].regs.size()),
2072 static_cast<int32_t>(frags[1].regs.size()),
2073 static_cast<int32_t>(frags[2].regs.size()),
2074 1, // sparseMetadata
2075 1, // sparsitySelector
2076 1, // scaleAData
2077 1, // byteIdA
2078 1, // threadIdA
2079 1, // scaleBData
2080 1, // byteIdB
2081 1 // threadIdB
2082 }));
2083 return success();
2084}
2085
2086void MmaSpBlockScaleOp::build(
2087 OpBuilder &builder, OperationState &result, Type resultType,
2088 ValueRange operandA, ValueRange operandB, ValueRange operandC,
2089 Value sparseMetadata, Value sparsitySelector, Value scaleAData,
2090 Value byteIdA, Value threadIdA, Value scaleBData, Value byteIdB,
2091 Value threadIdB, ArrayRef<int64_t> shape,
2092 std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
2093 ScaleVecSize scaleVecSize, BlockScaleFormat blockScaleFormat,
2094 MMABlockScaleKind kind) {
2095 assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
2096
2097  addBlockScaleProperties<MmaSpBlockScaleOp>(
2098      builder, result, shape, scaleVecSize, blockScaleFormat, kind);
2099 result.addAttribute("orderedMetadata", builder.getUnitAttr());
2100
2101 result.addOperands(operandA);
2102 result.addOperands(operandB);
2103 result.addOperands(operandC);
2104 result.addOperands({sparseMetadata, sparsitySelector, scaleAData, byteIdA,
2105 threadIdA, scaleBData, byteIdB, threadIdB});
2106
2107 addInferredMultiplicandTypes(builder.getContext(), result, operandA, operandB,
2108 multiplicandPtxTypes);
2109
2110 result.addTypes(resultType);
2111 result.addAttribute(MmaSpBlockScaleOp::getOperandSegmentSizeAttr(),
2112 builder.getDenseI32ArrayAttr({
2113 static_cast<int32_t>(operandA.size()),
2114 static_cast<int32_t>(operandB.size()),
2115 static_cast<int32_t>(operandC.size()),
2116 1, // sparseMetadata
2117 1, // sparsitySelector
2118 1, // scaleAData
2119 1, // byteIdA
2120 1, // threadIdA
2121 1, // scaleBData
2122 1, // byteIdB
2123 1 // threadIdB
2124 }));
2125}
2126
2127NVVM::IDArgPair MmaSpBlockScaleOp::getIntrinsicIDAndArgs(
2128 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
2129 auto curOp = cast<NVVM::MmaSpBlockScaleOp>(op);
2130
2131  llvm::SmallVector<llvm::Value *> args;
2132  // Add A, B, C operands
2133 for (Value operand : curOp.getOperandA())
2134 args.push_back(mt.lookupValue(operand));
2135 for (Value operand : curOp.getOperandB())
2136 args.push_back(mt.lookupValue(operand));
2137 for (Value operand : curOp.getOperandC())
2138 args.push_back(mt.lookupValue(operand));
2139
2140 // Add sparse metadata and selector
2141 args.push_back(mt.lookupValue(curOp.getSparseMetadata()));
2142 args.push_back(mt.lookupValue(curOp.getSparsitySelector()));
2143
2144 // Add scale operands
2145 args.push_back(mt.lookupValue(curOp.getScaleAData()));
2146 args.push_back(mt.lookupValue(curOp.getByteIdA()));
2147 args.push_back(mt.lookupValue(curOp.getThreadIdA()));
2148 args.push_back(mt.lookupValue(curOp.getScaleBData()));
2149 args.push_back(mt.lookupValue(curOp.getByteIdB()));
2150 args.push_back(mt.lookupValue(curOp.getThreadIdB()));
2151
2152 unsigned intId = MmaSpBlockScaleOp::getIntrinsicID(
2153 curOp.getShape().getM(), curOp.getShape().getN(), curOp.getShape().getK(),
2154 *curOp.getMultiplicandAPtxType(), *curOp.getMultiplicandBPtxType(),
2155 inferPtxTypeFromResult(curOp), curOp.getScaleVecSize(),
2156 curOp.getBlockScaleFormat(), curOp.getKind());
2157
2158 return {intId, args};
2159}
2160
2161LogicalResult MmaSpBlockScaleOp::verify() {
2162 // Check that orderedMetadata is present
2163 if (!getOrderedMetadata()) {
2164 return emitOpError("'orderedMetadata' attribute is mandatory");
2165 }
2166
2167 LogicalResult result = success();
2168 int m = getShape().getM();
2169 int n = getShape().getN();
2170 int k = getShape().getK();
2171
2172 if (m == 16 && n == 8 && k == 128) {
2173 if (getMultiplicandAPtxType() != NVVM::MMATypes::e2m1 ||
2174 getMultiplicandBPtxType() != NVVM::MMATypes::e2m1)
2175      result = emitOpError(
2176          "unsupported MMATypes attribute for mma.m16n8k128.(mxf4nvf4|mxf4)");
2177 if (getKind() == NVVM::MMABlockScaleKind::MXF4) {
2178 if (getScaleVecSize() != NVVM::ScaleVecSize::X2)
2179        result = emitOpError(
2180            "unsupported ScaleVecSize attribute for mma.m16n8k128.mxf4");
2181 if (getBlockScaleFormat() != NVVM::BlockScaleFormat::UE8M0)
2182        result = emitOpError(
2183            "unsupported BlockScaleFormat attribute for mma.m16n8k128.mxf4");
2184 } else if (getKind() == NVVM::MMABlockScaleKind::MXF4NVF4) {
2185 if (!((getScaleVecSize() == NVVM::ScaleVecSize::X2 &&
2186 getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0) ||
2187 (getScaleVecSize() == NVVM::ScaleVecSize::X4 &&
2188 getBlockScaleFormat() == NVVM::BlockScaleFormat::UE4M3)))
2189 result = emitOpError("unsupported ScaleVecSize and BlockScaleFormat "
2190 "attributes for mma.m16n8k128.mxf4nvf4");
2191 } else {
2192 result = emitOpError("unsupported Kind attribute for mma.m16n8k128");
2193 }
2194 } else if (m == 16 && n == 8 && k == 64) {
2195 if (!(getKind() == NVVM::MMABlockScaleKind::MXF8F6F4 &&
2196 getScaleVecSize() == NVVM::ScaleVecSize::X1 &&
2197 getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0))
2198 result =
2199 emitOpError("unsupported Kind, ScaleVecSize and BlockScaleFormat "
2200 "attributes for mma.m16n8k64");
2201 } else {
2202 result = emitOpError("unsupported Geom for sparse mma with block scaling");
2203 }
2204 return result;
2205}
2206
2207LogicalResult ShflOp::verify() {
2208 auto returnStructType = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
2209
2210 auto verifyTypeError = [&](Twine desc, Type expectedType,
2211 Type actualType) -> LogicalResult {
2212 return emitOpError("expected " + desc + " to be of type ")
2213 << expectedType << " but got " << actualType << " instead";
2214 };
2215
2216 if (returnStructType) {
2217 if (!getReturnValueAndIsValid())
2218 return emitOpError("\"return_value_and_is_valid\" attribute must be "
2219 "specified when the return type is a struct type");
2220
2221 if (returnStructType.getBody().size() != 2)
2222 return emitOpError("expected return type to be a two-element struct");
2223
2224 llvm::ArrayRef<Type> returnStruct = returnStructType.getBody();
2225 auto resultType = returnStruct[0];
2226 if (resultType != getVal().getType())
2227 return verifyTypeError("first element in the returned struct",
2228 getVal().getType(), resultType);
2229
2230 auto predicateType = returnStruct[1];
2231 if (!predicateType.isInteger(1))
2232 return verifyTypeError("second element in the returned struct",
2233 mlir::IntegerType::get(getContext(), 1),
2234 predicateType);
2235 } else {
2236 if (getReturnValueAndIsValid())
2237 return emitOpError("expected return type to be a two-element struct");
2238
2239 if (getType() != getVal().getType())
2240 return verifyTypeError("return type", getVal().getType(), getType());
2241 }
2242 return success();
2243}
2244
2245std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
2246 NVVM::MMAFrag frag, int nRow,
2247 int nCol,
2248 MLIRContext *context) {
2249 unsigned numberElements = 0;
2250 Type elementType;
2251 OpBuilder builder(context);
2252 Type f16x2 = VectorType::get(2, builder.getF16Type());
2253 if (type == NVVM::MMATypes::f16) {
2254 elementType = f16x2;
2255 if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
2256 numberElements = 8;
2257 else
2258 numberElements = 4;
2259 } else if (type == NVVM::MMATypes::f32) {
2260 elementType = builder.getF32Type();
2261 numberElements = 8;
2262 } else if (type == NVVM::MMATypes::f64) {
2263 elementType = builder.getF64Type();
2264 if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
2265 numberElements = 1;
2266 else
2267 numberElements = 2;
2268 } else if (type == NVVM::MMATypes::tf32) {
2269 elementType = builder.getI32Type();
2270 numberElements = 4;
2271 } else if (type == NVVM::MMATypes::s8 || type == NVVM::MMATypes::u8) {
2272 elementType = builder.getI32Type();
2273 int parallelSize = 0;
2274 if (frag == NVVM::MMAFrag::a)
2275 parallelSize = nRow;
2276 if (frag == NVVM::MMAFrag::b)
2277 parallelSize = nCol;
2278
2279 // m == 16 && n == 16 && k == 16
2280 if (parallelSize == 16)
2281 numberElements = 2;
2282 // m == 8 && n == 32 && k == 16 or m == 32 && n == 8 && k == 16
2283 else if (parallelSize == 8)
2284 numberElements = 1;
2285 else if (parallelSize == 32)
2286 numberElements = 4;
2287 } else if (type == NVVM::MMATypes::s32) {
2288 elementType = builder.getI32Type();
2289 numberElements = 8;
2290 }
2291 assert(numberElements != 0 && elementType != nullptr);
2292 return std::make_pair(elementType, numberElements);
2293}
2294
2295static std::pair<mlir::Type, unsigned>
2296inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n,
2297 int k, MLIRContext *context) {
2298 int nRow, nCol;
2299 if (frag == NVVM::MMAFrag::a) {
2300 nRow = m;
2301 nCol = k;
2302 } else if (frag == NVVM::MMAFrag::b) {
2303 nRow = k;
2304 nCol = n;
2305 } else {
2306 nRow = m;
2307 nCol = n;
2308 }
2309 assert(nRow && nCol);
2310 return inferMMAType(type, frag, nRow, nCol, context);
2311}
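// For reference, the (element type, count) pairs produced by inferMMAType:
//   f16   -> vector<2xf16>; 8 regs for A/B fragments, 4 for C/D
//   f32   -> f32; 8 regs
//   f64   -> f64; 1 reg for A/B, 2 for C/D
//   tf32  -> i32; 4 regs
//   s8/u8 -> i32; 1, 2 or 4 regs depending on the parallel dimension (8/16/32)
//   s32   -> i32; 8 regs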
2312
2313LogicalResult NVVM::WMMALoadOp::verify() {
2314 unsigned addressSpace =
2315 llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
2316 if (addressSpace != 0 && addressSpace != NVVMMemorySpace::Global &&
2317 addressSpace != NVVMMemorySpace::Shared)
2318 return emitOpError("expected source pointer in memory "
2319 "space 0, 1, 3");
2320
2321 if (NVVM::WMMALoadOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
2322 getEltype(), getFrag()) == 0)
2323 return emitOpError() << "invalid attribute combination";
2324 std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
2325 getEltype(), getFrag(), getM(), getN(), getK(), getContext());
2326 // Special case for f64 fragments
2327 Type f64Ty = Float64Type::get(getContext());
2328 if (typeInfo.first == f64Ty && typeInfo.second == 1) {
2329 if (getType() != f64Ty)
2330 return emitOpError("expected destination type to be f64");
2331 return success();
2332 }
2333 // Everything else is a struct
2334 Type dstType = LLVM::LLVMStructType::getLiteral(
2335 getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
2336 if (getType() != dstType)
2337 return emitOpError("expected destination type is a structure of ")
2338 << typeInfo.second << " elements of type " << typeInfo.first;
2339 return success();
2340}
2341
2342LogicalResult NVVM::WMMAStoreOp::verify() {
2343 unsigned addressSpace =
2344 llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
2345 if (addressSpace != 0 && addressSpace != NVVMMemorySpace::Global &&
2346 addressSpace != NVVMMemorySpace::Shared)
2347 return emitOpError("expected operands to be a source pointer in memory "
2348 "space 0, 1, 3");
2349
2350 if (NVVM::WMMAStoreOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
2351 getEltype()) == 0)
2352 return emitOpError() << "invalid attribute combination";
2353 std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
2354 getEltype(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
2355 if (getArgs().size() != typeInfo.second)
2356 return emitOpError() << "expected " << typeInfo.second << " data operands";
2357 if (llvm::any_of(getArgs(), [&typeInfo](Value operands) {
2358 return operands.getType() != typeInfo.first;
2359 }))
2360 return emitOpError() << "expected data operands of type " << typeInfo.first;
2361 return success();
2362}
2363
2364LogicalResult NVVM::WMMAMmaOp::verify() {
2365 if (NVVM::WMMAMmaOp::getIntrinsicID(getM(), getN(), getK(), getLayoutA(),
2366 getLayoutB(), getEltypeA(),
2367 getEltypeB()) == 0)
2368 return emitOpError() << "invalid attribute combination";
2369 std::pair<Type, unsigned> typeInfoA = inferMMATypeFromMNK(
2370 getEltypeA(), NVVM::MMAFrag::a, getM(), getN(), getK(), getContext());
2371 std::pair<Type, unsigned> typeInfoB = inferMMATypeFromMNK(
2372 getEltypeA(), NVVM::MMAFrag::b, getM(), getN(), getK(), getContext());
2373 std::pair<Type, unsigned> typeInfoC = inferMMATypeFromMNK(
2374 getEltypeB(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
2375 SmallVector<Type, 32> arguments;
2376 arguments.append(typeInfoA.second, typeInfoA.first);
2377 arguments.append(typeInfoB.second, typeInfoB.first);
2378 arguments.append(typeInfoC.second, typeInfoC.first);
2379 unsigned numArgs = arguments.size();
2380 if (getArgs().size() != numArgs)
2381 return emitOpError() << "expected " << numArgs << " arguments";
2382 for (unsigned i = 0; i < numArgs; i++) {
2383 if (getArgs()[i].getType() != arguments[i])
2384 return emitOpError() << "expected argument " << i << " to be of type "
2385 << arguments[i];
2386 }
2387 Type dstType = LLVM::LLVMStructType::getLiteral(
2388 getContext(), SmallVector<Type, 8>(typeInfoC.second, typeInfoC.first));
2389 if (getType() != dstType)
2390 return emitOpError("expected destination type is a structure of ")
2391 << typeInfoC.second << " elements of type " << typeInfoC.first;
2392 return success();
2393}
2394
2395LogicalResult NVVM::LdMatrixOp::verify() {
2396 uint32_t num = getNum(), m = getShape().getM(), n = getShape().getN();
2397 if (m == 8 && n == 8) {
2398 if (num != 1 && num != 2 && num != 4) {
2399 return emitOpError("expected num attribute to be 1, 2 or 4 for 8x8 "
2400 "matrix");
2401 }
2402 if (getEltType() != LdStMatrixEltType::B16) {
2403 return emitOpError("expected element type to be b16 for 8x8 matrix");
2404 }
2405 } else if (m == 8 && n == 16) {
2406 if (num != 1 && num != 2 && num != 4) {
2407 return emitOpError("expected num attribute to be 1, 2 or 4 for 8x16 "
2408 "matrix");
2409 }
2410 if (getLayout() != MMALayout::row) {
2411 return emitOpError("expected layout to be row for 8x16 matrix");
2412 }
2413 if (getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 &&
2414 getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) {
2415 return emitOpError("expected element type to be b8x16.b4x16_p64 or "
2416 "b8x16.b6x16_p32 for 8x16 matrix");
2417 }
2418 } else if (m == 16 && n == 16) {
2419 if (num != 1 && num != 2) {
2420 return emitOpError("expected num attribute to be 1 or 2 for 16x16 "
2421 "matrix");
2422 }
2423 if (getLayout() != MMALayout::col) {
2424 return emitOpError("expected layout to be col for 16x16 matrix");
2425 }
2426 if (getEltType() != LdStMatrixEltType::B8 &&
2427 getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 &&
2428 getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) {
2429 return emitOpError("expected element type to be b8, b8x16.b4x16_p64 or "
2430 "b8x16.b6x16_p32 for 16x16 matrix");
2431 }
2432 } else {
2433 return emitOpError("expected shape to be 8x8, 8x16 or 16x16");
2434 }
2435
2436 Type i32 = IntegerType::get(getContext(), 32);
2437 uint32_t numElements = (m == 16 && n == 16 ? num * 2 : num);
2438 if (numElements == 1 && getType() != i32)
2439 return emitOpError("expected destination type is i32");
2440 if (numElements == 2 || numElements == 4) {
2441 Type dstType = LLVM::LLVMStructType::getLiteral(
2442 getContext(), SmallVector<Type>(numElements, i32));
2443 if (getType() != dstType)
2444 return emitOpError("expected destination type is a structure of ")
2445 << numElements << " elements of type i32";
2446 }
2447
2448 return success();
2449}
2450
2451LogicalResult NVVM::StMatrixOp::verify() {
2452 int numMatrix = getSources().size();
2453 if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
2454 return emitOpError("expected num attribute to be 1, 2 or 4");
2455
2456 int m = getShape().getM(), n = getShape().getN();
2457 if (m == 8 && n == 8) {
2458 if (getEltType() != NVVM::LdStMatrixEltType::B16) {
2459 return emitOpError("expected element type to be B16 for 8x8 matrix");
2460 }
2461 } else if (m == 16 && n == 8) {
2462 if (getEltType() != NVVM::LdStMatrixEltType::B8) {
2463 return emitOpError("expected element type to be B8 for 16x8 matrix");
2464 }
2465 if (getLayout() != NVVM::MMALayout::col) {
2466 return emitOpError("expected layout to be col for 16x8 matrix");
2467 }
2468 } else {
2469 return emitOpError("expected shape to be 8x8 or 16x8");
2470 }
2471
2472 return success();
2473}
2474
2475static FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) {
2476 if (typeA == NVVM::WGMMATypes::tf32)
2477 return 8;
2478 if (typeA == NVVM::WGMMATypes::f16 || typeA == NVVM::WGMMATypes::bf16)
2479 return 16;
2480 if (typeA == NVVM::WGMMATypes::s8 || typeA == NVVM::WGMMATypes::u8)
2481 return 32;
2482 if (typeA == NVVM::WGMMATypes::e4m3 || typeA == NVVM::WGMMATypes::e5m2)
2483 return 32;
2484 if (typeA == NVVM::WGMMATypes::b1)
2485 return 256;
2486 return failure();
2487}
2488
2489static LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD,
2490 NVVM::WGMMATypes typeA,
2491 NVVM::WGMMATypes typeB) {
2492 switch (typeA) {
2493 case NVVM::WGMMATypes::f16:
2494 if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
2495 typeB == NVVM::WGMMATypes::f16)
2496 return success();
2497 break;
2498 case NVVM::WGMMATypes::tf32:
2499 if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32)
2500 return success();
2501 break;
2502 case NVVM::WGMMATypes::u8:
2503 case NVVM::WGMMATypes::s8:
2504 if (typeD == NVVM::WGMMATypes::s32 &&
2505 (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8))
2506 return success();
2507 break;
2508 case NVVM::WGMMATypes::b1:
2509 if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1)
2510 return success();
2511 break;
2512 case NVVM::WGMMATypes::bf16:
2513 if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
2514 typeB == NVVM::WGMMATypes::bf16)
2515 return success();
2516 break;
2517 case NVVM::WGMMATypes::e4m3:
2518 case NVVM::WGMMATypes::e5m2:
2519 if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
2520 (typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3))
2521 return success();
2522 break;
2523 case WGMMATypes::f32:
2524 case WGMMATypes::s32:
2525 llvm_unreachable("unsupported input types");
2526 break;
2527 }
2528 return failure();
2529}
2530
2531static LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA) {
2532 SmallVector<int> allowedN = {8, 16, 24, 32, 40, 48, 56, 64,
2533 72, 80, 88, 96, 104, 112, 120, 128,
2534 136, 144, 152, 160, 168, 176, 184, 192,
2535 200, 208, 216, 224, 232, 240, 248, 256};
2536 SmallVector<int> allowedNshort = {8, 16, 24, 32, 48, 64,
2537 80, 96, 112, 128, 144, 160,
2538 176, 192, 208, 224, 240, 256};
2539 switch (typeA) {
2540 case WGMMATypes::f16:
2541 case WGMMATypes::tf32:
2542 case WGMMATypes::bf16:
2543 case WGMMATypes::e4m3:
2544 case WGMMATypes::e5m2:
2545 if (llvm::is_contained(allowedN, sizeN))
2546 return success();
2547 break;
2548 case WGMMATypes::u8:
2549 case WGMMATypes::s8:
2550 case WGMMATypes::b1:
2551 if (llvm::is_contained(allowedNshort, sizeN))
2552 return success();
2553 break;
2554 case WGMMATypes::f32:
2555 case WGMMATypes::s32:
2556 llvm_unreachable("unsupported input types");
2557 break;
2558 }
2559 return failure();
2560}
2561
2562LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
2563 Value outValue = getResults();
2564 auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
2565 if (!stype)
2566 return emitOpError() << "expected results to be struct";
2567 int outputSize = stype.getBody().size();
2568 WGMMATypes typeD = getTypeD();
2569 WGMMATypes typeA = getTypeA();
2570 WGMMATypes typeB = getTypeB();
2571
2572 for (Type t : stype.getBody()) {
2573 if (t != stype.getBody().front())
2574 return emitOpError()
2575 << "all elements in struct must be same type but there is " << t;
2576 }
2577
2578 if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
2579 typeD != WGMMATypes::s32) {
2580 return emitOpError() << "does not support the given output type " << typeD;
2581 }
2582 if (typeD == WGMMATypes::s32 &&
2583 (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
2584 return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg";
2585 }
2586
2587 if (failed(isAllowedWGMMADataType(typeD, typeA, typeB))) {
2588 return emitOpError() << typeD << " += " << typeA << " * " << typeB
2589 << ", it is not supported.";
2590 }
2591
2592 // Check M
2593 if (getShape().getM() != 64)
2594 return emitOpError() << "shape 'm' must be 64";
2595
2596 // Check K
2597 FailureOr<int> allowedK = getAllowedSizeK(typeA);
2598 if (failed(allowedK) || allowedK.value() != getShape().getK())
2599 return emitOpError() << "shape 'k' must be " << allowedK.value()
2600 << " for input type " << typeA;
2601
2602 // Check N
2603 if (failed(isAllowedSizeN(getShape().getN(), typeA))) {
2604 return emitOpError() << "has input type " << typeA << " n is set to "
2605 << getShape().getN() << ", it is not supported.";
2606 }
2607
2608 // Check transpose (only available for f16/bf16)
2609  // Matrix A should be stored row-major and matrix B column-major.
2610  // Only f16/bf16 matrices can be stored in either column-major or row-major
2611  // order, by setting the transpose values (imm-trans-a, imm-trans-b) in PTX.
2612 if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
2613 (getLayoutA() == mlir::NVVM::MMALayout::col ||
2614 getLayoutB() == mlir::NVVM::MMALayout::row)) {
2615 return emitOpError()
2616 << "given layouts layout_a = " << getLayoutA()
2617 << " and layout_b = " << getLayoutB() << " for input types " << typeA
2618 << " and " << typeB
2619 << " requires transpose. However, this is only supported for: "
2620 << MMATypes::f16 << " and " << MMATypes::bf16;
2621 }
2622
2623 // Check result registers
2624 int expectedOutput = 0;
2625 if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32)
2626 expectedOutput = getShape().getN() / 2;
2627 if (typeD == WGMMATypes::f16)
2628 expectedOutput = getShape().getN() / 4;
2629 if (outputSize != expectedOutput) {
2630 return emitOpError() << "results " << expectedOutput
2631 << ", however output struct has " << outputSize
2632 << " elements";
2633 }
2634 // Check satfinite (only available for s32 accumulator)
2635 if (typeD != WGMMATypes::s32 &&
2636 getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
2637 NVVM::MMAIntOverflow::satfinite) {
2638 return emitOpError()
2639 << " `satfinite` can be only used with s32 accumulator, however "
2640 "the current accumulator is "
2641 << typeD;
2642 }
2643
2644 return success();
2645}
2646
2647std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
2648
2649 int m = getShape().getM(), n = getShape().getN(), k = getShape().getK();
2650 bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
2651
2652 StringRef outputTypeName = stringifyWGMMATypes(getTypeD());
2653
2654 int expectedOutputRegisters = 0;
2655 if (getTypeD() == WGMMATypes::f16)
2656 expectedOutputRegisters = getShape().getN() / 4;
2657 else
2658 expectedOutputRegisters = getShape().getN() / 2;
2659
2660 std::string ptx;
2661 llvm::raw_string_ostream ss(ptx);
2662
2663 ss << "{\n"
2664 ".reg .pred p;\n"
2665 "setp.ne.b32 p, $"
2666 << ((expectedOutputRegisters * 2) + 2)
2667 << ", 0;\n"
2668 "wgmma.mma_async.sync.aligned.m"
2669 << m << "n" << n << "k" << k << "." << outputTypeName << "." << getTypeA()
2670 << "." << getTypeB();
2671 if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
2672 NVVM::MMAIntOverflow::satfinite)
2673 ss << ".satfinite";
2674 ss << " {";
2675 int regCnt = 0;
2676 for (; regCnt < expectedOutputRegisters; ++regCnt) {
2677 ss << "$" << regCnt;
2678 if (regCnt != expectedOutputRegisters - 1)
2679 ss << ", ";
2680 }
2681
2682 ss << "},";
2683 // Need to map read/write registers correctly.
2684 regCnt = (regCnt * 2);
2685 ss << " $" << (regCnt) << ","
2686 << " $" << (regCnt + 1) << ","
2687 << " p";
2688 if (getTypeD() != WGMMATypes::s32) {
2689 ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4);
2690 }
2691 // Don't add transpose parameters unless needed.
2692 if (isF16) {
2693 ss << ", $" << (regCnt + 5) << ", $" << (regCnt + 6);
2694 }
2695 ss << ";\n"
2696 << "}\n";
2697 return ptx;
2698}
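// Illustrative example of the inline-PTX string built above, assuming
// m = 64, n = 16, k = 16, typeA = typeB = f16, typeD = f32 and no satfinite
// (whitespace condensed). With n/2 = 8 output registers, the predicate operand
// is $18 and the matrix descriptors are $16/$17:
//
//   {
//   .reg .pred p;
//   setp.ne.b32 p, $18, 0;
//   wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16
//       {$0, $1, $2, $3, $4, $5, $6, $7}, $16, $17, p, $19, $20, $21, $22;
//   }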
2699
2700bool NVVM::WgmmaMmaAsyncOp::getAsmValues(
2701 RewriterBase &rewriter,
2702 llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
2703 &asmValues) {
2704 bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
2705 if (getResults())
2706 asmValues.push_back({getResults(), mlir::NVVM::PTXRegisterMod::Write});
2707 if (getInouts())
2708 asmValues.push_back({getInouts(), mlir::NVVM::PTXRegisterMod::ReadWrite});
2709 asmValues.push_back({getDescriptorA(), mlir::NVVM::PTXRegisterMod::Read});
2710 asmValues.push_back({getDescriptorB(), mlir::NVVM::PTXRegisterMod::Read});
2711 asmValues.push_back({makeConstantI32(rewriter, static_cast<int>(getScaleD())),
2712                       mlir::NVVM::PTXRegisterMod::Read});
2713  if (getTypeD() != WGMMATypes::s32) {
2714 asmValues.push_back(
2715 {makeConstantI32(rewriter,
2716 getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
2717         mlir::NVVM::PTXRegisterMod::Read});
2718    asmValues.push_back(
2719 {makeConstantI32(rewriter,
2720 getScaleB() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
2721         mlir::NVVM::PTXRegisterMod::Read});
2722  }
2723 if (isF16) {
2724 asmValues.push_back(
2725 {makeConstantI32(rewriter, static_cast<int>(getLayoutA())),
2726         mlir::NVVM::PTXRegisterMod::Read});
2727    asmValues.push_back(
2728 {makeConstantI32(rewriter, 1 - static_cast<int>(getLayoutB())),
2729         mlir::NVVM::PTXRegisterMod::Read});
2730  }
2731 return true; // Has manual mapping
2732}
2733
2734LogicalResult NVVM::FenceSyncRestrictOp::verify() {
2735 if (getOrder() != NVVM::MemOrderKind::ACQUIRE &&
2736 getOrder() != NVVM::MemOrderKind::RELEASE)
2737 return emitOpError("only acquire and release semantics are supported");
2738 return success();
2739}
2740
2741LogicalResult NVVM::FenceProxyOp::verify() {
2742 if (getKind() == NVVM::ProxyKind::TENSORMAP)
2743 return emitOpError() << "tensormap proxy is not a supported proxy kind";
2744 if (getKind() == NVVM::ProxyKind::GENERIC)
2745 return emitOpError() << "generic proxy not a supported proxy kind";
2746 if (getKind() == NVVM::ProxyKind::async_shared && !getSpace().has_value()) {
2747 return emitOpError() << "async_shared fence requires space attribute";
2748 }
2749 if (getKind() != NVVM::ProxyKind::async_shared && getSpace().has_value()) {
2750 return emitOpError() << "only async_shared fence can have space attribute";
2751 }
2752 return success();
2753}
2754
2755LogicalResult NVVM::FenceProxyAcquireOp::verify() {
2756 if (getFromProxy() != NVVM::ProxyKind::GENERIC)
2757 return emitOpError("uni-directional proxies only support generic for "
2758 "from_proxy attribute");
2759
2760 if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
2761 return emitOpError("uni-directional proxies only support tensormap "
2762 "for to_proxy attribute");
2763 return success();
2764}
2765
2766LogicalResult NVVM::FenceProxyReleaseOp::verify() {
2767 if (getFromProxy() != NVVM::ProxyKind::GENERIC)
2768 return emitOpError("uni-directional proxies only support generic for "
2769 "from_proxy attribute");
2770
2771 if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
2772 return emitOpError("uni-directional proxies only support tensormap "
2773 "for to_proxy attribute");
2774 return success();
2775}
2776
2777LogicalResult NVVM::FenceProxySyncRestrictOp::verify() {
2778 if (getOrder() != NVVM::MemOrderKind::ACQUIRE &&
2779 getOrder() != NVVM::MemOrderKind::RELEASE)
2780 return emitOpError("only acquire and release semantics are supported");
2781
2782 if (getFromProxy() != NVVM::ProxyKind::GENERIC)
2783    return emitOpError("only generic is supported for from_proxy attribute");
2784
2785 if (getToProxy() != NVVM::ProxyKind::async)
2786 return emitOpError("only async is supported for to_proxy attribute");
2787 return success();
2788}
2789
2790LogicalResult NVVM::SetMaxRegisterOp::verify() {
2791 if (getRegCount() % 8)
2792 return emitOpError("new register size must be multiple of 8");
2793 if (getRegCount() < 24 || getRegCount() > 256)
2794 return emitOpError("new register size must be in between 24 to 256");
2795 return success();
2796}
2797
2798LogicalResult NVVM::BarrierOp::verify() {
2799 if (getNumberOfThreads() && !getBarrierId())
2800 return emitOpError(
2801 "barrier id is missing, it should be set between 0 to 15");
2802
2803 if (getBarrierId() && (getReductionOp() || getReductionPredicate()))
2804    return emitOpError("reductions are only available when id is 0");
2805
2806 if ((getReductionOp() && !getReductionPredicate()) ||
2807 (!getReductionOp() && getReductionPredicate()))
2808 return emitOpError("reduction predicate and reduction operation must be "
2809 "specified together");
2810
2811 return success();
2812}
2813
2814LogicalResult NVVM::Tcgen05CpOp::verify() {
2815 auto mc = getMulticast();
2816
2817 using SH = Tcgen05CpShape;
2818 using MC = Tcgen05CpMulticast;
2819 switch (getShape()) {
2820 case SH::SHAPE_128x256b:
2821 case SH::SHAPE_128x128b:
2822 case SH::SHAPE_4x256b:
2823 if (mc != MC::NONE)
2824 return emitError("Invalid multicast type for tcgen05.cp Op");
2825 break;
2826 case SH::SHAPE_64x128b:
2827 if (mc != MC::WARPX2_01_23 && mc != MC::WARPX2_02_13)
2828 return emitError("Shape 64x128b requires multicast warpx2_01_23 or "
2829 "warpx2_02_13 for tcgen05.cp Op");
2830 break;
2831 case SH::SHAPE_32x128b:
2832 if (mc != MC::WARPX4)
2833 return emitError(
2834 "Shape 32x128b requires multicast warpx4 for tcgen05.cp Op");
2835 break;
2836 }
2837 return success();
2838}
2839
2840LogicalResult NVVM::MatchSyncOp::verify() {
2841 if (getKind() == NVVM::MatchSyncKind::all) {
2842 auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
2843 if (!type || type.getBody().size() != 2 ||
2844 !type.getBody()[0].isInteger(32) || !type.getBody()[1].isInteger(1)) {
2845 return emitOpError("match.sync 'all' returns a two element struct with "
2846 "first element as i32 and second element as i1");
2847 }
2848 } else {
2849 if (!getType().isInteger(32)) {
2850 return emitOpError("match.sync 'any' returns an i32");
2851 }
2852 }
2853 return success();
2854}
2855
2856LogicalResult NVVM::VoteSyncOp::verify() {
2857 if (getKind() == NVVM::VoteSyncKind::ballot) {
2858 if (!getType().isInteger(32)) {
2859 return emitOpError("vote.sync 'ballot' returns an i32");
2860 }
2861 } else {
2862 if (!getType().isInteger(1)) {
2863 return emitOpError("vote.sync 'any', 'all' and 'uni' returns an i1");
2864 }
2865 }
2866 return success();
2867}
2868
2869LogicalResult NVVM::PrefetchOp::verify() {
2870 using MemSpace = NVVM::NVVMMemorySpace;
2871 using CacheLevel = NVVM::PrefetchCacheLevel;
2872
2873 unsigned addressSpace =
2874 llvm::cast<LLVM::LLVMPointerType>(getAddr().getType()).getAddressSpace();
2875 std::optional<NVVM::CacheEvictionPriority> evictPriority = getEvictPriority();
2876 std::optional<NVVM::PrefetchCacheLevel> cacheLevel = getCacheLevel();
2877
2878 if (getTensormap() && cacheLevel)
2879 return emitOpError("cannot specify both tensormap and cache level");
2880
2881 if (getTensormap()) {
2882 if (addressSpace != MemSpace::Generic &&
2883 addressSpace != MemSpace::Constant) {
2884 return emitOpError(
2885 "prefetch tensormap requires a generic or constant pointer");
2886 }
2887
2888 if (evictPriority) {
2889 return emitOpError(
2890 "prefetch tensormap does not support eviction priority");
2891 }
2892
2893 if (getInParamSpace() && addressSpace != MemSpace::Generic) {
2894 return emitOpError(
2895 "in_param_space can only be specified for a generic pointer");
2896 }
2897
2898 } else if (cacheLevel) {
2899 if (addressSpace != MemSpace::Generic && addressSpace != MemSpace::Global &&
2900 addressSpace != MemSpace::Local) {
2901 return emitOpError("prefetch to cache level requires a generic, global, "
2902 "or local pointer");
2903 }
2904
2905 if (getUniform()) {
2906 if (*cacheLevel != CacheLevel::L1) {
2907 return emitOpError(
2908 "unsupported cache level, the only supported uniform "
2909 "cache level is L1");
2910 }
2911
2912 if (addressSpace != MemSpace::Generic) {
2913 return emitOpError(
2914 "prefetch to uniform cache requires a generic pointer");
2915 }
2916 }
2917
2918 if (evictPriority) {
2919 if (*cacheLevel != CacheLevel::L2)
2920 return emitOpError(
2921 "cache eviction priority supported only for cache level L2");
2922
2923 if (addressSpace != MemSpace::Global)
2924 return emitOpError("cache eviction priority requires a global pointer");
2925
2926 if (*evictPriority != NVVM::CacheEvictionPriority::EvictNormal &&
2927 *evictPriority != NVVM::CacheEvictionPriority::EvictLast)
2928 return emitOpError(
2929 "unsupported cache eviction priority, only evict_last and "
2930 "evict_normal are supported");
2931 }
2932
2933 if (getPredicate())
2934 return emitOpError("predicate supported only on prefetch tensormap");
2935
2936 } else {
2937 return emitOpError(
2938 "requires specification of either cache level or tensormap");
2939 }
2940
2941 return success();
2942}
2943
2944LogicalResult NVVM::ClusterLaunchControlQueryCancelOp::verify() {
2945 switch (getQueryType()) {
2946 case NVVM::ClusterLaunchControlQueryType::IS_CANCELED:
2947 if (!getType().isInteger(1))
2948 return emitOpError("is_canceled query type returns an i1");
2949 break;
2950 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_X:
2951 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Y:
2952 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Z:
2953 if (!getType().isInteger(32)) {
2954 return emitOpError("get_first_cta_id_x, get_first_cta_id_y, "
2955 "get_first_cta_id_z query types return an i32");
2956 }
2957 break;
2958 }
2959 return success();
2960}
2961
2962LogicalResult NVVM::ReduxOp::verify() {
2963 mlir::Type reduxType = getType();
2964
2965 if (!reduxType.isF32()) {
2966 if (getAbs())
2967 return emitOpError("abs attribute is supported only for f32 type");
2968 if (getNan())
2969 return emitOpError("nan attribute is supported only for f32 type");
2970 }
2971
2972 NVVM::ReductionKind kind = getKind();
2973 switch (kind) {
2974 case NVVM::ReductionKind::ADD:
2975 case NVVM::ReductionKind::AND:
2976 case NVVM::ReductionKind::OR:
2977 case NVVM::ReductionKind::XOR:
2978 case NVVM::ReductionKind::MAX:
2979 case NVVM::ReductionKind::MIN:
2980 case NVVM::ReductionKind::UMAX:
2981 case NVVM::ReductionKind::UMIN:
2982 if (!reduxType.isInteger(32))
2983 return emitOpError("'")
2984 << kind << "' reduction kind unsupported with " << reduxType
2985 << " type. Only supported type is 'i32'.";
2986 break;
2987 case NVVM::ReductionKind::FMIN:
2988 case NVVM::ReductionKind::FMAX:
2989 if (!reduxType.isF32())
2990 return emitOpError("'")
2991 << kind << "' reduction kind unsupported with " << reduxType
2992 << " type. Only supported type is 'f32'.";
2993 break;
2994 }
2995
2996 return success();
2997}
2998
2999LogicalResult NVVM::TensormapReplaceOp::verify() {
3000 auto ord = getOrd();
3001 Value newVal = getNewValue();
3002 auto newValAttr = getNewValueAttr();
3003 auto fieldName = stringifyEnum(getField());
3004
3005 if (ord && !llvm::is_contained({NVVM::TensormapField::BOX_DIM,
3006 NVVM::TensormapField::GLOBAL_DIM,
3007 NVVM::TensormapField::GLOBAL_STRIDE,
3008 NVVM::TensormapField::ELEMENT_STRIDE},
3009 getField()))
3010 return emitOpError("ordinal is not supported for ")
3011 << fieldName << " field";
3012
3013 auto invalidNewVal = [&](llvm::Twine type) -> std::string {
3014 return llvm::Twine("new_value must be specified and must be an " + type +
3015 " for " + llvm::Twine(fieldName) + " field")
3016 .str();
3017 };
3018
3019 auto invalidNewValAttr = [&]() -> std::string {
3020 return (llvm::Twine(
3021 "new_value_attr must be specified and must be a valid ") +
3022 llvm::Twine(fieldName) + " attribute for " + fieldName + " field")
3023 .str();
3024 };
3025
3026 switch (getField()) {
3027 case NVVM::TensormapField::GLOBAL_ADDRESS:
3028 if (!(newVal && newVal.getType().isInteger(64)))
3029 return emitOpError(invalidNewVal("i64"));
3030 break;
3031 case NVVM::TensormapField::RANK:
3032 if (!(newVal && newVal.getType().isInteger(32)))
3033 return emitOpError(invalidNewVal("i32"));
3034 break;
3035 case NVVM::TensormapField::GLOBAL_STRIDE:
3036 if (!ord)
3037 return emitOpError("ordinal is required for global_stride field");
3038 if (!(newVal && newVal.getType().isInteger(64)))
3039 return emitOpError(invalidNewVal("i64"));
3040 break;
3041 case NVVM::TensormapField::BOX_DIM:
3042 case NVVM::TensormapField::GLOBAL_DIM:
3043 case NVVM::TensormapField::ELEMENT_STRIDE:
3044 if (!ord)
3045 return emitOpError("ordinal is required for ")
3046 << stringifyEnum(getField()) << " field";
3047 if (!(newVal && newVal.getType().isInteger(32)))
3048 return emitOpError(invalidNewVal("i32"));
3049 break;
3050 case NVVM::TensormapField::ELEMTYPE:
3051 if (!(newValAttr && llvm::isa<TensormapElemtypeAttr>(*newValAttr)))
3052 return emitOpError(invalidNewValAttr());
3053 break;
3054 case NVVM::TensormapField::INTERLEAVE_LAYOUT:
3055 if (!(newValAttr && llvm::isa<TensormapInterleaveLayoutAttr>(*newValAttr)))
3056 return emitOpError(invalidNewValAttr());
3057 break;
3058 case NVVM::TensormapField::SWIZZLE_MODE:
3059 if (!(newValAttr && llvm::isa<TensormapSwizzleModeAttr>(*newValAttr)))
3060 return emitOpError(invalidNewValAttr());
3061 break;
3062 case NVVM::TensormapField::SWIZZLE_ATOMICITY:
3063 if (!(newValAttr && llvm::isa<TensormapSwizzleAtomicityAttr>(*newValAttr)))
3064 return emitOpError(invalidNewValAttr());
3065 break;
3066 case NVVM::TensormapField::FILL_MODE:
3067 if (!(newValAttr && llvm::isa<TensormapFillModeAttr>(*newValAttr)))
3068 return emitOpError(invalidNewValAttr());
3069 break;
3070 }
3071
3072 return success();
3073}
3074
3075/// Packs the given `field` into the `result`.
3076/// The `result` is 64-bits and each `field` can be 32-bits or narrower.
3077static llvm::Value *
3078packValInto64Bits(llvm::IRBuilderBase &builder,
3079 llvm::Value *result, // the `result` (unset bits are zero)
3080 llvm::Value *field, // `field` to pack into `result`
3081 unsigned sizeInBits, // Size of `field` in bits
3082 unsigned start) { // Starting bit within `result`
3083 field = builder.CreateZExtOrBitCast(field, builder.getInt32Ty());
3084
3085 unsigned mask = (sizeInBits < 32 ? ((1u << sizeInBits) - 1) : 0xffffffffu);
3086 if (mask != 0xffffffffu)
3087 field = builder.CreateAnd(field, builder.getInt32(mask));
3088
3089 field = builder.CreateZExtOrBitCast(field, builder.getInt64Ty());
3090 field = builder.CreateShl(field, start);
3091
3092 return builder.CreateOr(result, field);
3093}
3094
3095void Tcgen05MmaSmemDescOp::createSmemDescriptor(Operation &op,
3096                                                LLVM::ModuleTranslation &mt,
3097                                                llvm::IRBuilderBase &builder) {
3098 auto thisOp = cast<NVVM::Tcgen05MmaSmemDescOp>(op);
3099 llvm::Value *smemDesc = builder.getInt64(0);
3100
3101 smemDesc = packValInto64Bits(builder, smemDesc,
3102 mt.lookupValue(thisOp.getStartAddr()), 14, 0);
3103 smemDesc = packValInto64Bits(
3104 builder, smemDesc, mt.lookupValue(thisOp.getLeadingDimOffset()), 14, 16);
3105 smemDesc = packValInto64Bits(
3106 builder, smemDesc, mt.lookupValue(thisOp.getStrideDimOffset()), 14, 32);
3107
3108 smemDesc = packValInto64Bits(builder, smemDesc, builder.getInt32(1), 3, 46);
3109 smemDesc = packValInto64Bits(builder, smemDesc,
3110 mt.lookupValue(thisOp.getBaseOffset()), 3, 49);
3111 smemDesc = packValInto64Bits(
3112 builder, smemDesc, mt.lookupValue(thisOp.getLeadingDimMode()), 1, 52);
3113 smemDesc = packValInto64Bits(builder, smemDesc,
3114 mt.lookupValue(thisOp.getSwizzleMode()), 3, 61);
3115
3116 mt.mapValue(thisOp.getRes()) = smemDesc;
3117}
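// Bit layout of the 64-bit shared-memory descriptor assembled above (the start
// positions are the last argument to each packValInto64Bits call):
//   [13:0]   start address            (14 bits)
//   [29:16]  leading dimension offset (14 bits)
//   [45:32]  stride dimension offset  (14 bits)
//   [48:46]  fixed constant 1         (3 bits)
//   [51:49]  matrix base offset       (3 bits)
//   [52]     leading dimension mode   (1 bit)
//   [63:61]  swizzle mode             (3 bits)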
3118
3119//===----------------------------------------------------------------------===//
3120// getPtx methods
3121//===----------------------------------------------------------------------===//
3122
3123std::string NVVM::MBarrierInitOp::getPtx() {
3124 bool isShared = isPtrInSharedCTASpace(getAddr());
3125 return isShared ? std::string("mbarrier.init.shared.b64 [%0], %1;")
3126 : std::string("mbarrier.init.b64 [%0], %1;");
3127}
3128
3129std::string NVVM::MBarrierArriveExpectTxOp::getPtx() {
3130 bool isShared = isPtrInSharedCTASpace(getAddr());
3131 return isShared
3132 ? std::string("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;")
3133 : std::string("mbarrier.arrive.expect_tx.b64 _, [%0], %1;");
3134}
3135
3136std::string NVVM::MBarrierTryWaitParityOp::getPtx() {
3137 bool isShared = isPtrInSharedCTASpace(getAddr());
3138 llvm::StringRef space = isShared ? ".shared" : "";
3139
3140 return llvm::formatv("{\n\t"
3141 ".reg .pred P1; \n\t"
3142 "LAB_WAIT: \n\t"
3143 "mbarrier.try_wait.parity{0}.b64 P1, [%0], %1, %2; \n\t"
3144 "@P1 bra.uni DONE; \n\t"
3145 "bra.uni LAB_WAIT; \n\t"
3146 "DONE: \n\t"
3147 "}",
3148 space);
3149}
3150
3151//===----------------------------------------------------------------------===//
3152// getIntrinsicID/getIntrinsicIDAndArgs methods
3153//===----------------------------------------------------------------------===//
3154
3155mlir::NVVM::IDArgPair NVVM::BarrierOp::getIntrinsicIDAndArgs(
3156 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3157 auto thisOp = cast<NVVM::BarrierOp>(op);
3158 llvm::Value *barrierId = thisOp.getBarrierId()
3159 ? mt.lookupValue(thisOp.getBarrierId())
3160 : builder.getInt32(0);
3161 llvm::Intrinsic::ID id;
3162 llvm::SmallVector<llvm::Value *> args = {barrierId};
3163 if (thisOp.getNumberOfThreads()) {
3164 id = llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_count;
3165 args.push_back(mt.lookupValue(thisOp.getNumberOfThreads()));
3166 } else if (thisOp.getReductionOp()) {
3167 switch (*thisOp.getReductionOp()) {
3168 case NVVM::BarrierReduction::AND:
3169 id = llvm::Intrinsic::nvvm_barrier_cta_red_and_aligned_all;
3170 break;
3171 case NVVM::BarrierReduction::OR:
3172 id = llvm::Intrinsic::nvvm_barrier_cta_red_or_aligned_all;
3173 break;
3174 case NVVM::BarrierReduction::POPC:
3175 id = llvm::Intrinsic::nvvm_barrier_cta_red_popc_aligned_all;
3176 break;
3177 }
3178 args.push_back(builder.CreateICmpNE(
3179 mt.lookupValue(thisOp.getReductionPredicate()), builder.getInt32(0)));
3180 } else {
3181 id = llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_all;
3182 }
3183
3184 return {id, std::move(args)};
3185}
3186
3187 mlir::NVVM::IDArgPair
3188 PMEventOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
3189 llvm::IRBuilderBase &builder) {
3190 auto thisOp = cast<NVVM::PMEventOp>(op);
3191 llvm::Type *i16Ty = llvm::Type::getInt16Ty(mt.getLLVMContext());
3192
3193 // With an explicit event-id, the mask is generated as (1 << event-id).
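// For example, an event-id of 3 yields the i16 mask 0b1000 (i.e. 8).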
3194 llvm::Value *maskVal;
3195 if (auto eventAttr = thisOp.getEventIdAttr()) {
3196 uint16_t mask = static_cast<uint16_t>(1u << eventAttr.getInt());
3197 maskVal = llvm::ConstantInt::get(i16Ty, mask);
3198 } else {
3199 maskVal =
3200 llvm::ConstantInt::get(i16Ty, thisOp.getMaskedEventIdAttr().getValue());
3201 }
3202
3203 return {llvm::Intrinsic::nvvm_pm_event_mask, {maskVal}};
3204}
3205
3206mlir::NVVM::IDArgPair MBarrierInitOp::getIntrinsicIDAndArgs(
3207 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3208 auto thisOp = cast<NVVM::MBarrierInitOp>(op);
3209 bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
3210 llvm::Intrinsic::ID id = isShared ? llvm::Intrinsic::nvvm_mbarrier_init_shared
3211 : llvm::Intrinsic::nvvm_mbarrier_init;
3212
3213 // Fill the Intrinsic Args
3214 llvm::SmallVector<llvm::Value *> args;
3215 args.push_back(mt.lookupValue(thisOp.getAddr()));
3216 args.push_back(mt.lookupValue(thisOp.getCount()));
3217
3218 return {id, std::move(args)};
3219}
3220
3221mlir::NVVM::IDArgPair MBarrierInvalOp::getIntrinsicIDAndArgs(
3222 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3223 auto thisOp = cast<NVVM::MBarrierInvalOp>(op);
3224 bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
3225 llvm::Intrinsic::ID id = isShared
3226 ? llvm::Intrinsic::nvvm_mbarrier_inval_shared
3227 : llvm::Intrinsic::nvvm_mbarrier_inval;
3228
3229 return {id, {mt.lookupValue(thisOp.getAddr())}};
3230}
3231
3232mlir::NVVM::IDArgPair MBarrierExpectTxOp::getIntrinsicIDAndArgs(
3233 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3234 auto thisOp = cast<NVVM::MBarrierExpectTxOp>(op);
3235
3236 bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr());
3237 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3238 // bit-0: Space
3239 // bit-1: Scope
3240 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3241
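// The intrinsic tables used here (and in the other mbarrier ops that reuse
// this encoding) are ordered so that the computed index selects:
//   0 -> scope=cta,     space=cta
//   1 -> scope=cta,     space=cluster
//   2 -> scope=cluster, space=cta
//   3 -> scope=cluster, space=cluster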
3242 static constexpr llvm::Intrinsic::ID IDs[] = {
3243 llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cta_space_cta,
3244 llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cta_space_cluster,
3245 llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cluster_space_cta,
3246 llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cluster_space_cluster};
3247
3248 // Fill the Intrinsic Args
3249 llvm::SmallVector<llvm::Value *> args;
3250 args.push_back(mt.lookupValue(thisOp.getAddr()));
3251 args.push_back(mt.lookupValue(thisOp.getTxcount()));
3252
3253 return {IDs[index], std::move(args)};
3254}
3255
3256mlir::NVVM::IDArgPair MBarrierCompleteTxOp::getIntrinsicIDAndArgs(
3257 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3258 auto thisOp = cast<NVVM::MBarrierCompleteTxOp>(op);
3259
3260 bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr());
3261 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3262 // bit-0: Space
3263 // bit-1: Scope
3264 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3265
3266 static constexpr llvm::Intrinsic::ID IDs[] = {
3267 llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cta_space_cta,
3268 llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cta_space_cluster,
3269 llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cluster_space_cta,
3270 llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cluster_space_cluster};
3271
3272 // Fill the Intrinsic Args
3273 llvm::SmallVector<llvm::Value *> args;
3274 args.push_back(mt.lookupValue(thisOp.getAddr()));
3275 args.push_back(mt.lookupValue(thisOp.getTxcount()));
3276
3277 return {IDs[index], std::move(args)};
3278}
3279
3280mlir::NVVM::IDArgPair MBarrierArriveOp::getIntrinsicIDAndArgs(
3281 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3282 auto thisOp = cast<NVVM::MBarrierArriveOp>(op);
3283
3284 bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr());
3285 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3286 // bit-0: Space
3287 // bit-1: Scope
3288 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3289
3290 static constexpr llvm::Intrinsic::ID IDs[] = {
3291 llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cta_space_cta,
3292 llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cta_space_cluster,
3293 llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cluster_space_cta,
3294 llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cluster_space_cluster};
3295 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3296 llvm::Intrinsic::nvvm_mbarrier_arrive_relaxed_scope_cta_space_cta,
3297 llvm::Intrinsic::nvvm_mbarrier_arrive_relaxed_scope_cta_space_cluster,
3298 llvm::Intrinsic::nvvm_mbarrier_arrive_relaxed_scope_cluster_space_cta,
3299 llvm::Intrinsic::
3300 nvvm_mbarrier_arrive_relaxed_scope_cluster_space_cluster};
3301 auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];
3302
3303 // Tidy-up the Intrinsic Args
3304 bool needCast = isPtrInGenericSpace(thisOp.getAddr());
3305 llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
3306 if (needCast)
3307 mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);
3308
3309 // The most basic form of mbarrier.arrive is supported on sm_80. It only
3310 // supports space=cta, scope=cta, non-relaxed semantics, and no explicit
3311 // count. Use the legacy intrinsic only for this combination.
3312 bool hasCount = static_cast<bool>(thisOp.getCount());
3313 if (!hasCount &&
3314 (id == llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cta_space_cta))
3315 return {llvm::Intrinsic::nvvm_mbarrier_arrive_shared, {mbar}};
3316
3317 // When count is not explicitly specified, the default is 1.
3318 llvm::LLVMContext &ctx = mt.getLLVMContext();
3319 llvm::Value *count =
3320 hasCount ? mt.lookupValue(thisOp.getCount())
3321 : llvm::ConstantInt::get(llvm::Type::getInt32Ty(ctx), 1);
3322 return {id, {mbar, count}};
3323}
3324
3325mlir::NVVM::IDArgPair MBarrierArriveDropOp::getIntrinsicIDAndArgs(
3326 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3327 auto thisOp = cast<NVVM::MBarrierArriveDropOp>(op);
3328
3329 bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr());
3330 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3331 // bit-0: Space
3332 // bit-1: Scope
3333 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3334
3335 static constexpr llvm::Intrinsic::ID IDs[] = {
3336 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cta_space_cta,
3337 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cta_space_cluster,
3338 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cluster_space_cta,
3339 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cluster_space_cluster};
3340 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3341 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_relaxed_scope_cta_space_cta,
3342 llvm::Intrinsic::
3343 nvvm_mbarrier_arrive_drop_relaxed_scope_cta_space_cluster,
3344 llvm::Intrinsic::
3345 nvvm_mbarrier_arrive_drop_relaxed_scope_cluster_space_cta,
3346 llvm::Intrinsic::
3347 nvvm_mbarrier_arrive_drop_relaxed_scope_cluster_space_cluster};
3348 auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];
3349
3350 // Tidy-up the Intrinsic Args
3351 bool needCast = isPtrInGenericSpace(thisOp.getAddr());
3352 llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
3353 if (needCast)
3354 mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);
3355
3356 // When count is not explicitly specified, the default is 1.
3357 llvm::LLVMContext &ctx = mt.getLLVMContext();
3358 bool hasCount = static_cast<bool>(thisOp.getCount());
3359 llvm::Value *count =
3360 hasCount ? mt.lookupValue(thisOp.getCount())
3361 : llvm::ConstantInt::get(llvm::Type::getInt32Ty(ctx), 1);
3362
3363 return {id, {mbar, count}};
3364}
3365
3366bool MBarrierArriveExpectTxOp::getAsmValues(
3367 RewriterBase &rewriter,
3368 llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
3369 &asmValues) {
3370 // Add all the operands, but not the attributes, to the asmValues list.
3371 // The attributes are only used to select the right intrinsic variants
3372 // during intrinsic lowering, so we ignore them when generating inline PTX.
3373 for (auto val : getOperands())
3374 asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read});
3375
3376 return false;
3377}
3378
3379mlir::NVVM::IDArgPair MBarrierArriveExpectTxOp::getIntrinsicIDAndArgs(
3380 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3381 auto thisOp = cast<NVVM::MBarrierArriveExpectTxOp>(op);
3382
3383 bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr());
3384 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3385 // bit-0: Space
3386 // bit-1: Scope
3387 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3388
3389 // clang-format off
3390 static constexpr llvm::Intrinsic::ID IDs[] = {
3391 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cta_space_cta,
3392 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cta_space_cluster,
3393 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cluster_space_cta,
3394 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cluster_space_cluster};
3395 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3396 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cta_space_cta,
3397 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cta_space_cluster,
3398 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cluster_space_cta,
3399 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cluster_space_cluster};
3400 // clang-format on
3401 auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];
3402
3403 // Tidy-up the Intrinsic Args
3404 llvm::Value *txcount = mt.lookupValue(thisOp.getTxcount());
3405 llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
3406 bool needCast = isPtrInGenericSpace(thisOp.getAddr());
3407 if (needCast)
3408 mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);
3409
3410 return {id, {mbar, txcount}};
3411}
3412
3413mlir::NVVM::IDArgPair MBarrierArriveDropExpectTxOp::getIntrinsicIDAndArgs(
3414 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3415 auto thisOp = cast<NVVM::MBarrierArriveDropExpectTxOp>(op);
3416
3417 bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr());
3418 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3419 // bit-0: Space
3420 // bit-1: Scope
3421 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3422
3423 // clang-format off
3424 static constexpr llvm::Intrinsic::ID IDs[] = {
3425 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cta_space_cta,
3426 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cta_space_cluster,
3427 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cluster_space_cta,
3428 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cluster_space_cluster};
3429 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3430 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cta_space_cta,
3431 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cta_space_cluster,
3432 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cluster_space_cta,
3433 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cluster_space_cluster};
3434 // clang-format on
3435 auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];
3436
3437 // Tidy-up the Intrinsic Args
3438 llvm::Value *txcount = mt.lookupValue(thisOp.getTxcount());
3439 llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
3440 bool needCast = isPtrInGenericSpace(thisOp.getAddr());
3441 if (needCast)
3442 mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);
3443
3444 return {id, {mbar, txcount}};
3445}
3446
3447mlir::NVVM::IDArgPair MBarrierArriveNocompleteOp::getIntrinsicIDAndArgs(
3448 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3449 auto thisOp = cast<NVVM::MBarrierArriveNocompleteOp>(op);
3450 bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
3451 llvm::Intrinsic::ID id =
3452 isShared ? llvm::Intrinsic::nvvm_mbarrier_arrive_noComplete_shared
3453 : llvm::Intrinsic::nvvm_mbarrier_arrive_noComplete;
3454 // Fill the Intrinsic Args
3455 llvm::SmallVector<llvm::Value *> args;
3456 args.push_back(mt.lookupValue(thisOp.getAddr()));
3457 args.push_back(mt.lookupValue(thisOp.getCount()));
3458
3459 return {id, std::move(args)};
3460}
3461
3462mlir::NVVM::IDArgPair MBarrierArriveDropNocompleteOp::getIntrinsicIDAndArgs(
3463 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3464 auto thisOp = cast<NVVM::MBarrierArriveDropNocompleteOp>(op);
3465 bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
3466 llvm::Intrinsic::ID id =
3467 isShared ? llvm::Intrinsic::nvvm_mbarrier_arrive_drop_noComplete_shared
3468 : llvm::Intrinsic::nvvm_mbarrier_arrive_drop_noComplete;
3469 // Fill the Intrinsic Args
3470 llvm::SmallVector<llvm::Value *> args;
3471 args.push_back(mt.lookupValue(thisOp.getAddr()));
3472 args.push_back(mt.lookupValue(thisOp.getCount()));
3473
3474 return {id, std::move(args)};
3475}
3476
3477mlir::NVVM::IDArgPair MBarrierTestWaitOp::getIntrinsicIDAndArgs(
3478 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3479 auto thisOp = cast<NVVM::MBarrierTestWaitOp>(op);
3480 bool isPhaseParity = thisOp.getStateOrPhase().getType().isInteger(32);
3481 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3482 // bit-0: isPhaseParity
3483 // bit-1: Scope
3484 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isPhaseParity ? 1 : 0);
3485
3486 // clang-format off
3487 static constexpr llvm::Intrinsic::ID IDs[] = {
3488 llvm::Intrinsic::nvvm_mbarrier_test_wait_scope_cta_space_cta,
3489 llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_scope_cta_space_cta,
3490 llvm::Intrinsic::nvvm_mbarrier_test_wait_scope_cluster_space_cta,
3491 llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_scope_cluster_space_cta};
3492 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3493 llvm::Intrinsic::nvvm_mbarrier_test_wait_relaxed_scope_cta_space_cta,
3494 llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_relaxed_scope_cta_space_cta,
3495 llvm::Intrinsic::nvvm_mbarrier_test_wait_relaxed_scope_cluster_space_cta,
3496 llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_relaxed_scope_cluster_space_cta};
3497 // clang-format on
3498 auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];
3499
3500 // Tidy-up the Intrinsic Args
3501 llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
3502 llvm::Value *input = mt.lookupValue(thisOp.getStateOrPhase());
3503 bool needCast = isPtrInGenericSpace(thisOp.getAddr());
3504 if (needCast)
3505 mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);
3506
3507 return {id, {mbar, input}};
3508}
3509
3510mlir::NVVM::IDArgPair MBarrierTryWaitOp::getIntrinsicIDAndArgs(
3511 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3512 auto thisOp = cast<NVVM::MBarrierTryWaitOp>(op);
3513 bool isPhaseParity = thisOp.getStateOrPhase().getType().isInteger(32);
3514 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3515 bool hasTicks = static_cast<bool>(thisOp.getTicks());
3516 // bit-0: isPhaseParity
3517 // bit-1: Scope
3518 // bit-2: hasTicks
3519 size_t index = ((hasTicks ? 1 : 0) << 2) | ((isClusterScope ? 1 : 0) << 1) |
3520 (isPhaseParity ? 1 : 0);
3521
3522 // clang-format off
3523 static constexpr llvm::Intrinsic::ID IDs[] = {
3524 llvm::Intrinsic::nvvm_mbarrier_try_wait_scope_cta_space_cta,
3525 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_scope_cta_space_cta,
3526 llvm::Intrinsic::nvvm_mbarrier_try_wait_scope_cluster_space_cta,
3527 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_scope_cluster_space_cta,
3528 llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_scope_cta_space_cta,
3529 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_scope_cta_space_cta,
3530 llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_scope_cluster_space_cta,
3531 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_scope_cluster_space_cta};
3532 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3533 llvm::Intrinsic::nvvm_mbarrier_try_wait_relaxed_scope_cta_space_cta,
3534 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_relaxed_scope_cta_space_cta,
3535 llvm::Intrinsic::nvvm_mbarrier_try_wait_relaxed_scope_cluster_space_cta,
3536 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_relaxed_scope_cluster_space_cta,
3537 llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_relaxed_scope_cta_space_cta,
3538 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_relaxed_scope_cta_space_cta,
3539 llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_relaxed_scope_cluster_space_cta,
3540 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_relaxed_scope_cluster_space_cta};
3541 // clang-format on
3542 auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];
3543
3544 // Tidy-up the mbarrier pointer
3545 llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
3546 bool needCast = isPtrInGenericSpace(thisOp.getAddr());
3547 if (needCast)
3548 mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);
3549
3550 // Fill the Intrinsic Args
3551 llvm::SmallVector<llvm::Value *> args;
3552 args.push_back(mbar);
3553 args.push_back(mt.lookupValue(thisOp.getStateOrPhase()));
3554 if (hasTicks)
3555 args.push_back(mt.lookupValue(thisOp.getTicks()));
3556
3557 return {id, std::move(args)};
3558}
3559
3560mlir::NVVM::IDArgPair CpAsyncMBarrierArriveOp::getIntrinsicIDAndArgs(
3561 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3562 auto thisOp = cast<NVVM::CpAsyncMBarrierArriveOp>(op);
3563 bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
3564
3565 llvm::Intrinsic::ID id;
3566 if (thisOp.getNoinc()) {
3567 id = isShared ? llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_noinc_shared
3568 : llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_noinc;
3569 } else {
3570 id = isShared ? llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_shared
3571 : llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive;
3572 }
3573
3574 return {id, {mt.lookupValue(thisOp.getAddr())}};
3575}
3576
3577#define CP_ASYNC_ID_IMPL(mod, size, suffix) \
3578 llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix
3579
3580#define GET_CP_ASYNC_ID(mod, size, has_cpsize) \
3581 has_cpsize ? CP_ASYNC_ID_IMPL(mod, size, _s) : CP_ASYNC_ID_IMPL(mod, size, )
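// Expansion sketch: GET_CP_ASYNC_ID(ca, 4, true) selects
// llvm::Intrinsic::nvvm_cp_async_ca_shared_global_4_s, while
// GET_CP_ASYNC_ID(cg, 16, false) selects
// llvm::Intrinsic::nvvm_cp_async_cg_shared_global_16.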
3582
3583llvm::Intrinsic::ID
3584CpAsyncOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
3585 llvm::SmallVectorImpl<llvm::Value *> &args) {
3586 llvm::Intrinsic::ID id;
3587
3588 auto cpAsyncOp = cast<NVVM::CpAsyncOp>(op);
3589 bool hasCpSize = static_cast<bool>(cpAsyncOp.getCpSize());
3590 switch (cpAsyncOp.getSize()) {
3591 case 4:
3592 id = GET_CP_ASYNC_ID(ca, 4, hasCpSize);
3593 break;
3594 case 8:
3595 id = GET_CP_ASYNC_ID(ca, 8, hasCpSize);
3596 break;
3597 case 16:
3598 id = (cpAsyncOp.getModifier() == NVVM::LoadCacheModifierKind::CG)
3599 ? GET_CP_ASYNC_ID(cg, 16, hasCpSize)
3600 : GET_CP_ASYNC_ID(ca, 16, hasCpSize);
3601 break;
3602 default:
3603 llvm_unreachable("Invalid copy size in CpAsyncOp.");
3604 }
3605
3606 // Fill the Intrinsic Args
3607 args.push_back(mt.lookupValue(cpAsyncOp.getDst()));
3608 args.push_back(mt.lookupValue(cpAsyncOp.getSrc()));
3609 if (hasCpSize)
3610 args.push_back(mt.lookupValue(cpAsyncOp.getCpSize()));
3611
3612 return id;
3613}
3614
3615mlir::NVVM::IDArgPair CpAsyncBulkPrefetchOp::getIntrinsicIDAndArgs(
3616 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3617 auto thisOp = cast<NVVM::CpAsyncBulkPrefetchOp>(op);
3618 llvm::SmallVector<llvm::Value *> args;
3619 llvm::Intrinsic::ID id = llvm::Intrinsic::nvvm_cp_async_bulk_prefetch_L2;
3620
3621 // Fill the Intrinsic Args
3622 args.push_back(mt.lookupValue(thisOp.getSrcMem()));
3623 args.push_back(mt.lookupValue(thisOp.getSize()));
3624
3625 mlir::Value cacheHint = thisOp.getL2CacheHint();
3626 const bool hasCacheHint = static_cast<bool>(cacheHint);
3627 llvm::Value *i64Unused =
3628 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
3629 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
3630 args.push_back(builder.getInt1(hasCacheHint));
3631
3632 return {id, std::move(args)};
3633}
3634
3635mlir::NVVM::IDArgPair CpAsyncBulkGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
3636 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3637 auto thisOp = cast<NVVM::CpAsyncBulkGlobalToSharedClusterOp>(op);
3638 llvm::SmallVector<llvm::Value *> args;
3639
3640 // Fill the Intrinsic Args: dst, mbar, src, size.
3641 args.push_back(mt.lookupValue(thisOp.getDstMem()));
3642 args.push_back(mt.lookupValue(thisOp.getMbar()));
3643 args.push_back(mt.lookupValue(thisOp.getSrcMem()));
3644 args.push_back(mt.lookupValue(thisOp.getSize()));
3645
3646 // Multicast mask for shared::cluster only, if available.
3647 mlir::Value multicastMask = thisOp.getMulticastMask();
3648 const bool hasMulticastMask = static_cast<bool>(multicastMask);
3649 const bool isSharedCTA = isPtrInSharedCTASpace(thisOp.getDstMem());
3650 if (!isSharedCTA) {
3651 llvm::Value *i16Unused = llvm::ConstantInt::get(builder.getInt16Ty(), 0);
3652 args.push_back(hasMulticastMask ? mt.lookupValue(multicastMask)
3653 : i16Unused);
3654 }
3655
3656 // Cache hint, if available.
3657 mlir::Value cacheHint = thisOp.getL2CacheHint();
3658 const bool hasCacheHint = static_cast<bool>(cacheHint);
3659 llvm::Value *i64Unused = llvm::ConstantInt::get(builder.getInt64Ty(), 0);
3660 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
3661
3662 // Flag arguments for multicast and cachehint.
3663 if (!isSharedCTA)
3664 args.push_back(builder.getInt1(hasMulticastMask));
3665 args.push_back(builder.getInt1(hasCacheHint));
3666
3667 llvm::Intrinsic::ID id =
3668 isSharedCTA
3669 ? llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cta
3670 : llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster;
3671
3672 return {id, std::move(args)};
3673}
3674
3675mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
3676 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3677 auto thisOp = cast<NVVM::CpAsyncBulkSharedCTAToGlobalOp>(op);
3678 llvm::SmallVector<llvm::Value *> args;
3679 llvm::Intrinsic::ID id =
3680 llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global;
3681
3682 // Fill the Intrinsic Args
3683 args.push_back(mt.lookupValue(thisOp.getDstMem()));
3684 args.push_back(mt.lookupValue(thisOp.getSrcMem()));
3685 args.push_back(mt.lookupValue(thisOp.getSize()));
3686
3687 mlir::Value cacheHint = thisOp.getL2CacheHint();
3688 const bool hasCacheHint = static_cast<bool>(cacheHint);
3689 llvm::Value *i64Unused =
3690 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
3691 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
3692 args.push_back(builder.getInt1(hasCacheHint));
3693
3694 // Choose the bytemask variant
3695 if (mlir::Value byteMask = thisOp.getByteMask()) {
3696 args.push_back(mt.lookupValue(byteMask));
3697 id = llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global_bytemask;
3698 }
3699
3700 return {id, std::move(args)};
3701}
3702
3703bool CpAsyncBulkTensorGlobalToSharedClusterOp::getAsmValues(
3704 RewriterBase &rewriter,
3705 llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
3706 &asmValues) {
3707 // Add all the operands, but not the attributes, to the asmValues list.
3708 // The attributes are only used to select the right intrinsic variants
3709 // during intrinsic lowering, so we ignore them when generating inline PTX.
3710 for (auto val : getOperands())
3711 asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read});
3712
3713 return false;
3714}
3715
3716 mlir::NVVM::IDArgPair
3717 CpAsyncBulkTensorGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
3718 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3719 auto thisOp = cast<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(op);
3720 const bool isCTAOnly = thisOp.getIsCTAOnly();
3721 llvm::SmallVector<llvm::Value *> args;
3722
3723 // Fill the Intrinsic Args
3724 args.push_back(mt.lookupValue(thisOp.getDstMem()));
3725 args.push_back(mt.lookupValue(thisOp.getMbar()));
3726 args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
3727
3728 // Coordinates and im2col-offsets
3729 for (mlir::Value v : thisOp.getCoordinates())
3730 args.push_back(mt.lookupValue(v));
3731 for (mlir::Value v : thisOp.getIm2colOffsets())
3732 args.push_back(mt.lookupValue(v));
3733
3734 // MulticastMask, if available
3735 mlir::Value mcMask = thisOp.getMulticastMask();
3736 const bool hasMC = static_cast<bool>(mcMask);
3737 llvm::Value *i16Zero =
3738 llvm::ConstantInt::get(llvm::Type::getInt16Ty(mt.getLLVMContext()), 0);
3739
3740 // CacheHint, if available
3741 mlir::Value cacheHint = thisOp.getL2CacheHint();
3742 const bool hasCacheHint = static_cast<bool>(cacheHint);
3743 llvm::Value *i64Zero =
3744 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
3745
3746 // Flag argument CTAGroup
3747 // CTA_1/2 is mapped to values 1 and 2 for the intrinsics.
3748 // Hence, the +1 to getGroup().
3749 const int32_t val =
3750 thisOp.getGroup() ? (static_cast<int32_t>(*thisOp.getGroup()) + 1) : 0;
3751 llvm::Value *cg =
3752 llvm::ConstantInt::get(llvm::Type::getInt32Ty(mt.getLLVMContext()), val);
3753
3754 if (!isCTAOnly) {
3755 // For shared::cluster, all the arguments that we build are applicable.
3756 args.push_back(hasMC ? mt.lookupValue(mcMask) : i16Zero);
3757 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Zero);
3758 args.push_back(builder.getInt1(hasMC));
3759 args.push_back(builder.getInt1(hasCacheHint));
3760 args.push_back(cg);
3761 } else {
3762 // For shared::cta, only cache-hint is applicable.
3763 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Zero);
3764 args.push_back(builder.getInt1(hasCacheHint));
3765 }
3766
3767 constexpr size_t numDims = 5; // 1D to 5D
3768 constexpr size_t numModes = 5; // Tile, Im2col, w, w_128, gather4
3769 using rowTy = std::array<llvm::Intrinsic::ID, numDims + 1>;
3770 using TableTy = std::array<rowTy, numModes>;
3771 static constexpr TableTy IDTable{
3772 {{notIntrinsic, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_1d,
3773 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_2d,
3774 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_3d,
3775 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_4d,
3776 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_5d},
3777 {notIntrinsic, notIntrinsic, notIntrinsic,
3778 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d,
3779 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d,
3780 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d},
3781 {notIntrinsic, notIntrinsic, notIntrinsic,
3782 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_3d,
3783 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_4d,
3784 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_5d},
3785 {notIntrinsic, notIntrinsic, notIntrinsic,
3786 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_3d,
3787 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_4d,
3788 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_5d},
3789 {notIntrinsic, notIntrinsic, notIntrinsic, notIntrinsic, notIntrinsic,
3790 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_gather4_2d}}};
3791
3792 static constexpr TableTy IDTableCTA{
3793 {{notIntrinsic,
3794 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_1d,
3795 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_2d,
3796 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_3d,
3797 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_4d,
3798 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_5d},
3799 {notIntrinsic, notIntrinsic, notIntrinsic,
3800 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_3d,
3801 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_4d,
3802 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_5d},
3803 {notIntrinsic, notIntrinsic, notIntrinsic,
3804 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_3d,
3805 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_4d,
3806 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_5d},
3807 {notIntrinsic, notIntrinsic, notIntrinsic,
3808 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_3d,
3809 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_4d,
3810 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_5d},
3811 {notIntrinsic, notIntrinsic, notIntrinsic, notIntrinsic, notIntrinsic,
3812 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_gather4_2d}}};
3813
3814 static_assert(
3815 (getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1) &&
3816 (getMaxEnumValForTMALoadMode() == std::size(IDTableCTA) - 1),
3817 "TMALoadModes must match number of rows in IDTable and IDTableCTA");
3818 size_t mode = static_cast<size_t>(thisOp.getMode());
3819 size_t dim = thisOp.getCoordinates().size();
3820 auto id = isCTAOnly ? IDTableCTA[mode][dim] : IDTable[mode][dim];
3821 assert(id != notIntrinsic &&
3822 "Invalid intrinsic for CpAsyncBulkTensorGlobalToSharedClusterOp.");
3823
3824 return {id, std::move(args)};
3825}
3826
3827mlir::NVVM::IDArgPair CpAsyncBulkTensorPrefetchOp::getIntrinsicIDAndArgs(
3828 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3829 auto thisOp = cast<NVVM::CpAsyncBulkTensorPrefetchOp>(op);
3830 llvm::SmallVector<llvm::Value *> args;
3831
3832 // Fill the Intrinsic Args
3833 args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
3834
3835 for (auto v : thisOp.getCoordinates())
3836 args.push_back(mt.lookupValue(v));
3837 for (auto v : thisOp.getIm2colOffsets())
3838 args.push_back(mt.lookupValue(v));
3839
3840 mlir::Value cacheHint = thisOp.getL2CacheHint();
3841 const bool hasCacheHint = static_cast<bool>(cacheHint);
3842 llvm::Value *i64Unused =
3843 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
3844 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
3845 args.push_back(builder.getInt1(hasCacheHint));
3846
3847 const unsigned NI = llvm::Intrinsic::not_intrinsic;
3848 static constexpr llvm::Intrinsic::ID IDTable[][6] = {
3849 {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d,
3850 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d,
3851 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d,
3852 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d,
3853 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d},
3854 {NI, NI, NI,
3855 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d,
3856 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d,
3857 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d},
3858 {NI, NI, NI,
3859 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_3d,
3860 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_4d,
3861 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_5d},
3862 {NI, NI, NI,
3863 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_3d,
3864 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_4d,
3865 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_5d},
3866 {NI, NI, NI, NI, NI,
3867 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_gather4_2d}};
3868
3869 static_assert(getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1,
3870 "TMALoadModes must match number of rows in IDTable");
3871 size_t mode = static_cast<size_t>(thisOp.getMode());
3872 size_t dim = thisOp.getCoordinates().size();
3873 llvm::Intrinsic::ID id = IDTable[mode][dim];
3874 if (id == llvm::Intrinsic::not_intrinsic)
3875 llvm_unreachable("Invalid intrinsic for CpAsyncBulkTensorPrefetchOp.");
3876
3877 return {id, std::move(args)};
3878}
3879
3880 mlir::NVVM::IDArgPair
3881 CpAsyncBulkTensorSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
3882 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3883 auto thisOp = cast<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(op);
3884 llvm::SmallVector<llvm::Value *> args;
3885
3886 // Fill the Intrinsic Args
3887 args.push_back(mt.lookupValue(thisOp.getSrcMem()));
3888 args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
3889
3890 for (auto v : thisOp.getCoordinates())
3891 args.push_back(mt.lookupValue(v));
3892
3893 mlir::Value cacheHint = thisOp.getL2CacheHint();
3894 const bool hasCacheHint = static_cast<bool>(cacheHint);
3895 llvm::Value *i64Unused =
3896 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
3897 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
3898 args.push_back(builder.getInt1(hasCacheHint));
3899
3900 const unsigned NI = llvm::Intrinsic::not_intrinsic;
3901 static constexpr llvm::Intrinsic::ID IDTable[][6] = {
3902 {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_1d,
3903 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_2d,
3904 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_3d,
3905 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_4d,
3906 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_5d},
3907 {NI, NI, NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_3d,
3908 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_4d,
3909 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_5d},
3910 {NI, NI, NI, NI, NI,
3911 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_scatter4_2d}};
3912
3913 static_assert(getMaxEnumValForTMAStoreMode() == std::size(IDTable) - 1,
3914 "TMAStoreModes must match number of rows in IDTable");
3915 size_t mode = static_cast<size_t>(thisOp.getMode());
3916 size_t dim = thisOp.getCoordinates().size();
3917 llvm::Intrinsic::ID id = IDTable[mode][dim];
3918 if (id == llvm::Intrinsic::not_intrinsic)
3919 llvm_unreachable(
3920 "Invalid intrinsic for CpAsyncBulkTensorSharedCTAToGlobalOp.");
3921
3922 return {id, std::move(args)};
3923}
3924
3925NVVM::IDArgPair CpAsyncBulkTensorReduceOp::getIntrinsicIDAndArgs(
3926 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3927 auto thisOp = cast<NVVM::CpAsyncBulkTensorReduceOp>(op);
3928 llvm::LLVMContext &ctx = mt.getLLVMContext();
3929
3931
3932 // Arguments to the intrinsic:
3933 // shared_mem_ptr, tmaDesc, tensorDims,
3934 // cache_hint (if applicable) and a boolean flag
3935 args.push_back(mt.lookupValue(thisOp.getSrcMem()));
3936 args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
3937
3938 for (Value v : thisOp.getCoordinates())
3939 args.push_back(mt.lookupValue(v));
3940
3941 mlir::Value cacheHint = thisOp.getL2CacheHint();
3942 const bool hasCacheHint = static_cast<bool>(cacheHint);
3943 llvm::Value *i64ZeroValue =
3944 llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), 0);
3945 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64ZeroValue);
3946 args.push_back(builder.getInt1(hasCacheHint));
3947
3948 const llvm::Intrinsic::ID notIntrinsic = llvm::Intrinsic::not_intrinsic;
3949
3950 constexpr unsigned numRedKinds = 8; // ADD, MIN, MAX, INC, DEC, AND, OR, XOR
3951 constexpr unsigned numLayouts = 2; // TILE, IM2COL
3952 constexpr unsigned maxDim = 5; // 1D to 5D
3953 using row = std::array<llvm::Intrinsic::ID, maxDim + 1>;
3954 using layoutTable = std::array<row, numLayouts>;
3955 using fullTable = std::array<layoutTable, numRedKinds>;
3956 static constexpr fullTable IDTable{
3957 {// RedTy::ADD
3958 {{{{notIntrinsic,
3959 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_1d,
3960 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_2d,
3961 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_3d,
3962 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_4d,
3963 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_5d}},
3964 {{notIntrinsic, notIntrinsic, notIntrinsic,
3965 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_3d,
3966 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_4d,
3967 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_5d}}}},
3968 // RedTy::MIN
3969 {{{{notIntrinsic,
3970 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_1d,
3971 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_2d,
3972 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_3d,
3973 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_4d,
3974 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_5d}},
3975 {{notIntrinsic, notIntrinsic, notIntrinsic,
3976 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_3d,
3977 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_4d,
3978 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_5d}}}},
3979 // RedTy::MAX
3980 {{{{notIntrinsic,
3981 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_1d,
3982 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_2d,
3983 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_3d,
3984 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_4d,
3985 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_5d}},
3986 {{notIntrinsic, notIntrinsic, notIntrinsic,
3987 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_3d,
3988 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_4d,
3989 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_5d}}}},
3990 // RedTy::INC
3991 {{{{notIntrinsic,
3992 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_1d,
3993 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_2d,
3994 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_3d,
3995 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_4d,
3996 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_5d}},
3997 {{notIntrinsic, notIntrinsic, notIntrinsic,
3998 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_3d,
3999 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_4d,
4000 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_5d}}}},
4001 // RedTy::DEC
4002 {{{{notIntrinsic,
4003 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_1d,
4004 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_2d,
4005 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_3d,
4006 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_4d,
4007 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_5d}},
4008 {{notIntrinsic, notIntrinsic, notIntrinsic,
4009 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_3d,
4010 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_4d,
4011 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_5d}}}},
4012 // RedTy::AND
4013 {{{{notIntrinsic,
4014 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_1d,
4015 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_2d,
4016 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_3d,
4017 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_4d,
4018 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_5d}},
4019 {{notIntrinsic, notIntrinsic, notIntrinsic,
4020 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_3d,
4021 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_4d,
4022 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_5d}}}},
4023 // RedTy::OR
4024 {{{{notIntrinsic,
4025 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_1d,
4026 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_2d,
4027 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_3d,
4028 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_4d,
4029 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_5d}},
4030 {{notIntrinsic, notIntrinsic, notIntrinsic,
4031 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_3d,
4032 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_4d,
4033 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_5d}}}},
4034 // RedTy::XOR
4035 {{{{notIntrinsic,
4036 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_1d,
4037 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_2d,
4038 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_3d,
4039 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_4d,
4040 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_5d}},
4041 {{notIntrinsic, notIntrinsic, notIntrinsic,
4042 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_3d,
4043 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_4d,
4044 llvm::Intrinsic::
4045 nvvm_cp_async_bulk_tensor_reduce_xor_im2col_5d}}}}}};
4046
4047 static_assert(getMaxEnumValForTMAReduxKind() == std::size(IDTable) - 1,
4048 "TMAReduxKinds must match number of rows in IDTable");
4049
4050 size_t redKind = static_cast<size_t>(thisOp.getRedKind());
4051 size_t mode = static_cast<size_t>(thisOp.getMode());
4052 size_t dim = thisOp.getCoordinates().size();
4053
4054 assert(redKind < IDTable.size() &&
4055 "Invalid redKind for CpAsyncBulkTensorReduceOp");
4056 assert(mode < IDTable[redKind].size() &&
4057 "Invalid mode for CpAsyncBulkTensorReduceOp");
4058 assert(dim < IDTable[redKind][mode].size() &&
4059 "Invalid dim for CpAsyncBulkTensorReduceOp");
4060
4061 llvm::Intrinsic::ID intrinsicID = IDTable[redKind][mode][dim];
4062
4063 assert(intrinsicID != notIntrinsic &&
4064 "Invalid intrinsic for CpAsyncBulkTensorReduceOp.");
4065
4066 return {intrinsicID, std::move(args)};
4067}
4068
4069#define _none
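// `_none` expands to nothing; passing it as the relu infix below drops that
// infix from the generated intrinsic name (used for the RNA variants, which
// have no relu flavor).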
4070
4071#define CVT_F2TF32_ID_IMPL(rnd, relu, sf) \
4072 hasRelu ? llvm::Intrinsic::nvvm_f2tf32_##rnd##relu##sf \
4073 : llvm::Intrinsic::nvvm_f2tf32_##rnd##sf
4074
4075#define GET_CVT_F2TF32_ID(rnd, relu, sf) \
4076 hasSatFinite ? CVT_F2TF32_ID_IMPL(rnd, relu, sf) \
4077 : CVT_F2TF32_ID_IMPL(rnd, relu, )
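// Expansion sketch: with hasRelu = true and hasSatFinite = false,
// GET_CVT_F2TF32_ID(rn, _relu, _satfinite) selects
// llvm::Intrinsic::nvvm_f2tf32_rn_relu; with hasRelu = false and
// hasSatFinite = true it selects llvm::Intrinsic::nvvm_f2tf32_rn_satfinite.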
4078
4079llvm::Intrinsic::ID
4080ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
4081 NVVM::SaturationMode sat, bool hasRelu) {
4082 using RndMode = NVVM::FPRoundingMode;
4083 bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
4084 switch (rnd) {
4085 case RndMode::RN:
4086 return GET_CVT_F2TF32_ID(rn, _relu, _satfinite);
4087 case RndMode::RZ:
4088 return GET_CVT_F2TF32_ID(rz, _relu, _satfinite);
4089 case RndMode::RNA:
4090 return GET_CVT_F2TF32_ID(rna, _none, _satfinite);
4091 default:
4092 llvm_unreachable("Invalid RoundingMode for CvtFloatToTF32Op");
4093 }
4094}
4095
4096 NVVM::IDArgPair
4097 ConvertF32x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF4x2Op op,
4098 LLVM::ModuleTranslation &mt,
4099 llvm::IRBuilderBase &builder) {
4100 llvm::SmallVector<llvm::Value *> args;
4101 args.push_back(mt.lookupValue(op.getA()));
4102 args.push_back(mt.lookupValue(op.getB()));
4103
4104 bool hasRelu = op.getRelu();
4105
4106 llvm::Intrinsic::ID intId =
4107 hasRelu ? llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_relu_satfinite
4108 : llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_satfinite;
4109
4110 return {intId, std::move(args)};
4111}
4112
4113#define GET_F32x2_TO_F6x2_ID(type, has_relu) \
4114 has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite \
4115 : llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite
4116
4117llvm::Intrinsic::ID ConvertF32x2ToF6x2Op::getIntrinsicID(mlir::Type dstTy,
4118 bool hasRelu) {
4119 return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
4120 .Case([&](mlir::Float6E2M3FNType) {
4121 return GET_F32x2_TO_F6x2_ID(e2m3x2, hasRelu);
4122 })
4123 .Case([&](mlir::Float6E3M2FNType) {
4124 return GET_F32x2_TO_F6x2_ID(e3m2x2, hasRelu);
4125 })
4126 .Default([](mlir::Type) {
4127 llvm_unreachable("Invalid conversion in ConvertF32x2ToF6x2Op");
4128 return llvm::Intrinsic::not_intrinsic;
4129 });
4130}
4131
4132#define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf) \
4133 has_satf ? llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd##_satfinite \
4134 : llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd
4135
4136#define GET_F32x2_TO_F8X2_S_ID(type, has_relu) \
4137 has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu \
4138 : llvm::Intrinsic::nvvm_ff_to_##type##_rn
4139
4140llvm::Intrinsic::ID
4141ConvertF32x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy, NVVM::FPRoundingMode rnd,
4142 NVVM::SaturationMode sat, bool hasRelu) {
4143 bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
4144 bool hasRoundingModeRZ = (rnd == NVVM::FPRoundingMode::RZ);
4145 bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP);
4146
4147 return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
4148 .Case([&](mlir::Float8E4M3FNType) {
4149 return GET_F32x2_TO_F8X2_S_ID(e4m3x2, hasRelu);
4150 })
4151 .Case([&](mlir::Float8E5M2Type) {
4152 return GET_F32x2_TO_F8X2_S_ID(e5m2x2, hasRelu);
4153 })
4154 .Case([&](mlir::Float8E8M0FNUType) {
4155 if (hasRoundingModeRZ)
4156 return GET_F32x2_TO_F8X2_US_ID(rz, hasSatFinite);
4157 else if (hasRoundingModeRP)
4158 return GET_F32x2_TO_F8X2_US_ID(rp, hasSatFinite);
4159
4160 llvm_unreachable("Invalid conversion in ConvertF32x2ToF8x2Op");
4161 })
4162 .Default([](mlir::Type) {
4163 llvm_unreachable("Invalid conversion in ConvertF32x2ToF8x2Op");
4164 return llvm::Intrinsic::not_intrinsic;
4165 });
4166}
4167
4168#define GET_F16x2_TO_F8X2_ID(type, has_relu) \
4169 has_relu ? llvm::Intrinsic::nvvm_f16x2_to_##type##_rn_relu \
4170 : llvm::Intrinsic::nvvm_f16x2_to_##type##_rn
4171
4172llvm::Intrinsic::ID ConvertF16x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy,
4173 bool hasRelu) {
4174 return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
4175 .Case([&](mlir::Float8E4M3FNType) {
4176 return GET_F16x2_TO_F8X2_ID(e4m3x2, hasRelu);
4177 })
4178 .Case([&](mlir::Float8E5M2Type) {
4179 return GET_F16x2_TO_F8X2_ID(e5m2x2, hasRelu);
4180 })
4181 .Default([](mlir::Type) {
4182 llvm_unreachable("Invalid conversion in ConvertF16x2ToF8x2Op");
4183 return llvm::Intrinsic::not_intrinsic;
4184 });
4185}
4186
4187#define GET_BF16X2_TO_F8X2_ID(rnd, has_satf) \
4188 has_satf ? llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd##_satfinite \
4189 : llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd
4190
4191llvm::Intrinsic::ID
4192ConvertBF16x2ToF8x2Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
4193 NVVM::SaturationMode sat) {
4194 bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
4195 switch (rnd) {
4196 case NVVM::FPRoundingMode::RZ:
4197 return GET_BF16X2_TO_F8X2_ID(rz, hasSatFinite);
4198 case NVVM::FPRoundingMode::RP:
4199 return GET_BF16X2_TO_F8X2_ID(rp, hasSatFinite);
4200 default:
4201 llvm_unreachable("Invalid rounding mode for CvtBF16x2ToF8x2Op");
4202 }
4203}
4204
4205NVVM::IDArgPair ConvertF8x2ToF16x2Op::getIntrinsicIDAndArgs(
4206 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4207 auto curOp = cast<NVVM::ConvertF8x2ToF16x2Op>(op);
4208
4209 bool hasRelu = curOp.getRelu();
4210
4211 llvm::Intrinsic::ID intId =
4213 .Case([&](Float8E4M3FNType type) {
4214 return hasRelu ? llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn_relu
4215 : llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn;
4216 })
4217 .Case([&](Float8E5M2Type type) {
4218 return hasRelu ? llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn_relu
4219 : llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn;
4220 })
4221 .Default([](mlir::Type type) {
4222 llvm_unreachable("Invalid type for ConvertF8x2ToF16x2Op");
4223 return llvm::Intrinsic::not_intrinsic;
4224 });
4225
4226 llvm::Value *packedI16 =
4227 builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
4228 llvm::Type::getInt16Ty(builder.getContext()));
4229
4230 return {intId, {packedI16}};
4231}
4232
4233NVVM::IDArgPair ConvertF8x2ToBF16x2Op::getIntrinsicIDAndArgs(
4234 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4235 auto curOp = cast<NVVM::ConvertF8x2ToBF16x2Op>(op);
4236
4237 llvm::Intrinsic::ID intId = llvm::Intrinsic::nvvm_ue8m0x2_to_bf16x2;
4238 llvm::Value *packedI16 =
4239 builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
4240 llvm::Type::getInt16Ty(builder.getContext()));
4241
4242 return {intId, {packedI16}};
4243}
4244
4245NVVM::IDArgPair ConvertF6x2ToF16x2Op::getIntrinsicIDAndArgs(
4246 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4247 auto curOp = cast<NVVM::ConvertF6x2ToF16x2Op>(op);
4248
4249 bool hasRelu = curOp.getRelu();
4250
4251 llvm::Intrinsic::ID intId =
4253 .Case([&](Float6E2M3FNType type) {
4254 return hasRelu ? llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn_relu
4255 : llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn;
4256 })
4257 .Case([&](Float6E3M2FNType type) {
4258 return hasRelu ? llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn_relu
4259 : llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn;
4260 })
4261 .Default([](mlir::Type type) {
4262 llvm_unreachable("Invalid type for ConvertF6x2ToF16x2Op");
4263 return llvm::Intrinsic::not_intrinsic;
4264 });
4265
4266 llvm::Value *packedI16 =
4267 builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
4268 llvm::Type::getInt16Ty(builder.getContext()));
4269
4270 return {intId, {packedI16}};
4271}
4272
4273NVVM::IDArgPair ConvertF4x2ToF16x2Op::getIntrinsicIDAndArgs(
4274 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4275 auto curOp = cast<NVVM::ConvertF4x2ToF16x2Op>(op);
4276
4277 bool hasRelu = curOp.getRelu();
4278
4279 llvm::Intrinsic::ID intId =
4281 .Case([&](Float4E2M1FNType type) {
4282 return hasRelu ? llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn_relu
4283 : llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn;
4284 })
4285 .Default([](mlir::Type type) {
4286 llvm_unreachable("Invalid type for ConvertF4x2ToF16x2Op");
4287 return llvm::Intrinsic::not_intrinsic;
4288 });
4289
4290 llvm::Value *extendedI16 =
4291 builder.CreateZExt(mt.lookupValue(curOp.getSrc()),
4292 llvm::Type::getInt16Ty(builder.getContext()));
4293
4294 return {intId, {extendedI16}};
4295}
4296
4297llvm::Intrinsic::ID
4298Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op,
4299 LLVM::ModuleTranslation &mt,
4300 llvm::SmallVectorImpl<llvm::Value *> &args) {
4301 auto curOp = cast<NVVM::Tcgen05AllocOp>(op);
4302 unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
4303 .getAddressSpace();
4304 bool isShared = as == NVVMMemorySpace::Shared;
4305 bool is2CTAMode = curOp.getGroup() == CTAGroupKind::CTA_2;
4306
4307 llvm::Intrinsic::ID id;
4308 if (isShared) {
4309 id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg2
4310 : llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg1;
4311 } else {
4312 id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_cg2
4313 : llvm::Intrinsic::nvvm_tcgen05_alloc_cg1;
4314 }
4315
4316 // Fill the Intrinsic Args
4317 args.push_back(mt.lookupValue(curOp.getAddr()));
4318 args.push_back(mt.lookupValue(curOp.getNCols()));
4319
4320 return id;
4321}
4322
4323llvm::Intrinsic::ID Tcgen05DeallocOp::getIntrinsicIDAndArgs(
4324 Operation &op, LLVM::ModuleTranslation &mt,
4325 llvm::SmallVectorImpl<llvm::Value *> &args) {
4326 auto curOp = cast<NVVM::Tcgen05DeallocOp>(op);
4327 auto id = (curOp.getGroup() == CTAGroupKind::CTA_1)
4328 ? llvm::Intrinsic::nvvm_tcgen05_dealloc_cg1
4329 : llvm::Intrinsic::nvvm_tcgen05_dealloc_cg2;
4330
4331 // Fill the Intrinsic Args
4332 args.push_back(mt.lookupValue(curOp.getTaddr()));
4333 args.push_back(mt.lookupValue(curOp.getNCols()));
4334
4335 return id;
4336}
4337
4338#define TCGEN05_COMMIT_IMPL(cg, is_shared, mc) \
4339 is_shared ? llvm::Intrinsic::nvvm_tcgen05_commit##mc##_shared##_##cg \
4340 : llvm::Intrinsic::nvvm_tcgen05_commit##mc##_##cg
4341
4342#define GET_TCGEN05_COMMIT_ID(cta_group, is_shared, has_mc) \
4343 has_mc ? TCGEN05_COMMIT_IMPL(cta_group, is_shared, _mc) \
4344 : TCGEN05_COMMIT_IMPL(cta_group, is_shared, )
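// Expansion sketch: GET_TCGEN05_COMMIT_ID(cg2, /*is_shared=*/true,
// /*has_mc=*/true) selects llvm::Intrinsic::nvvm_tcgen05_commit_mc_shared_cg2,
// while GET_TCGEN05_COMMIT_ID(cg1, /*is_shared=*/false, /*has_mc=*/false)
// selects llvm::Intrinsic::nvvm_tcgen05_commit_cg1.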
4345
4346llvm::Intrinsic::ID
4347Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
4348 LLVM::ModuleTranslation &mt,
4349 llvm::SmallVectorImpl<llvm::Value *> &args) {
4350 auto curOp = cast<NVVM::Tcgen05CommitOp>(op);
4351 unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
4352 .getAddressSpace();
4353 bool isShared = as == NVVMMemorySpace::Shared;
4354 bool hasMulticast = static_cast<bool>(curOp.getMulticastMask());
4355 bool is2CTAMode = curOp.getGroup() == CTAGroupKind::CTA_2;
4356
4357 llvm::Intrinsic::ID id =
4358 is2CTAMode ? GET_TCGEN05_COMMIT_ID(cg2, isShared, hasMulticast)
4359 : GET_TCGEN05_COMMIT_ID(cg1, isShared, hasMulticast);
4360
4361 // Fill the Intrinsic Args
4362 args.push_back(mt.lookupValue(curOp.getAddr()));
4363 if (hasMulticast)
4364 args.push_back(mt.lookupValue(curOp.getMulticastMask()));
4365
4366 return id;
4367}
4368
4369#define TCGEN05_CP_IMPL(shape_mc, src_fmt, cg) \
4370 llvm::Intrinsic::nvvm_tcgen05_cp##shape_mc##src_fmt##cg
4371
4372#define TCGEN05_CP_2CTA(shape_mc, src_fmt, is_2cta) \
4373 is_2cta ? TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg2) \
4374 : TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg1)
4375
4376#define GET_TCGEN05_CP_ID(shape_mc, src_fmt, is_2cta) \
4377 [&]() -> auto { \
4378 if ((src_fmt) == Tcgen05CpSrcFormat::B6x16_P32) \
4379 return TCGEN05_CP_2CTA(shape_mc, _b6x16_p32, is_2cta); \
4380 if ((src_fmt) == Tcgen05CpSrcFormat::B4x16_P64) \
4381 return TCGEN05_CP_2CTA(shape_mc, _b4x16_p64, is_2cta); \
4382 return TCGEN05_CP_2CTA(shape_mc, , is_2cta); \
4383 }()
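// Expansion sketch: for shape _128x256b with src_fmt == B6x16_P32 and
// is_2cta == false, GET_TCGEN05_CP_ID yields
// llvm::Intrinsic::nvvm_tcgen05_cp_128x256b_b6x16_p32_cg1; with the default
// source format and is_2cta == true it yields
// llvm::Intrinsic::nvvm_tcgen05_cp_128x256b_cg2.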
4384
4385 NVVM::IDArgPair
4386 ConvertF32x2ToF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF16x2Op &op,
4387 LLVM::ModuleTranslation &mt,
4388 llvm::IRBuilderBase &builder) {
4389 static constexpr llvm::Intrinsic::ID rndRNIds[] = {
4390 llvm::Intrinsic::nvvm_ff2f16x2_rn,
4391 llvm::Intrinsic::nvvm_ff2f16x2_rn_relu,
4392 llvm::Intrinsic::nvvm_ff2f16x2_rn_satfinite,
4393 llvm::Intrinsic::nvvm_ff2f16x2_rn_relu_satfinite,
4394 };
4395 static constexpr llvm::Intrinsic::ID rndRZIds[] = {
4396 llvm::Intrinsic::nvvm_ff2f16x2_rz,
4397 llvm::Intrinsic::nvvm_ff2f16x2_rz_relu,
4398 llvm::Intrinsic::nvvm_ff2f16x2_rz_satfinite,
4399 llvm::Intrinsic::nvvm_ff2f16x2_rz_relu_satfinite,
4400 };
4401 static constexpr llvm::Intrinsic::ID rndRSIds[] = {
4402 llvm::Intrinsic::nvvm_ff2f16x2_rs,
4403 llvm::Intrinsic::nvvm_ff2f16x2_rs_relu,
4404 llvm::Intrinsic::nvvm_ff2f16x2_rs_satfinite,
4405 llvm::Intrinsic::nvvm_ff2f16x2_rs_relu_satfinite,
4406 };
4407
4408 unsigned hasRelu = op.getRelu() ? 1 : 0;
4409 unsigned hasSatFinite =
4410 (op.getSat() == NVVM::SaturationMode::SATFINITE) ? 1 : 0;
4411 // idx: bit-0 - relu
4412 // bit-1 - satfinite
4413 unsigned idx = (hasSatFinite << 1) | hasRelu;
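// The intrinsic arrays above follow the same encoding:
// idx 0 -> base, 1 -> relu, 2 -> satfinite, 3 -> relu_satfinite.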
4414
4415 llvm::SmallVector<llvm::Value *> args;
4416 args.push_back(mt.lookupValue(op.getSrcHi()));
4417 args.push_back(mt.lookupValue(op.getSrcLo()));
4418 if (op.getRandomBits())
4419 args.push_back(mt.lookupValue(op.getRandomBits()));
4420
4421 switch (op.getRnd()) {
4422 case FPRoundingMode::RN:
4423 return {rndRNIds[idx], std::move(args)};
4424 case FPRoundingMode::RZ:
4425 return {rndRZIds[idx], std::move(args)};
4426 case FPRoundingMode::RS:
4427 return {rndRSIds[idx], std::move(args)};
4428 default:
4429 llvm_unreachable("Invalid rounding mode for ConvertF32x2ToF16x2Op");
4430 }
4431}
4432
4433 NVVM::IDArgPair
4434 ConvertF32x2ToBF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToBF16x2Op &op,
4435 LLVM::ModuleTranslation &mt,
4436 llvm::IRBuilderBase &builder) {
4437 static constexpr llvm::Intrinsic::ID rndRNIds[] = {
4438 llvm::Intrinsic::nvvm_ff2bf16x2_rn,
4439 llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu,
4440 llvm::Intrinsic::nvvm_ff2bf16x2_rn_satfinite,
4441 llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu_satfinite,
4442 };
4443 static constexpr llvm::Intrinsic::ID rndRZIds[] = {
4444 llvm::Intrinsic::nvvm_ff2bf16x2_rz,
4445 llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu,
4446 llvm::Intrinsic::nvvm_ff2bf16x2_rz_satfinite,
4447 llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu_satfinite,
4448 };
4449 static constexpr llvm::Intrinsic::ID rndRSIds[] = {
4450 llvm::Intrinsic::nvvm_ff2bf16x2_rs,
4451 llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu,
4452 llvm::Intrinsic::nvvm_ff2bf16x2_rs_satfinite,
4453 llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu_satfinite,
4454 };
4455
4456 unsigned hasRelu = op.getRelu() ? 1 : 0;
4457 unsigned hasSatFinite =
4458 (op.getSat() == NVVM::SaturationMode::SATFINITE) ? 1 : 0;
4459 // idx: bit-0 - relu
4460 // bit-1 - satfinite
4461 unsigned idx = (hasSatFinite << 1) | hasRelu;
4462
4463 llvm::SmallVector<llvm::Value *> args;
4464 args.push_back(mt.lookupValue(op.getSrcHi()));
4465 args.push_back(mt.lookupValue(op.getSrcLo()));
4466 if (op.getRandomBits())
4467 args.push_back(mt.lookupValue(op.getRandomBits()));
4468
4469 switch (op.getRnd()) {
4470 case FPRoundingMode::RN:
4471 return {rndRNIds[idx], std::move(args)};
4472 case FPRoundingMode::RZ:
4473 return {rndRZIds[idx], std::move(args)};
4474 case FPRoundingMode::RS:
4475 return {rndRSIds[idx], std::move(args)};
4476 default:
4477 llvm_unreachable("Invalid rounding mode for ConvertF32x2ToBF16x2Op");
4478 }
4479}
4480
4481llvm::Intrinsic::ID ConvertF32x4ToF8x4Op::getIntrinsicID() {
4482 mlir::Type dstTy = getDstTy();
4483 bool hasRelu = getRelu();
4484
4485 return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
4486 .Case([&](mlir::Float8E4M3FNType) {
4487 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite
4488 : llvm::Intrinsic::nvvm_f32x4_to_e4m3x4_rs_satfinite;
4489 })
4490 .Case([&](mlir::Float8E5M2Type) {
4491 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e5m2x4_rs_relu_satfinite
4492 : llvm::Intrinsic::nvvm_f32x4_to_e5m2x4_rs_satfinite;
4493 })
4494 .Default([](mlir::Type) {
4495 llvm_unreachable("Invalid F8 type in ConvertF32x4ToF8x4Op");
4496 return llvm::Intrinsic::not_intrinsic;
4497 });
4498}
4499
4500llvm::Intrinsic::ID ConvertF32x4ToF6x4Op::getIntrinsicID() {
4501 mlir::Type dstTy = getDstTy();
4502 bool hasRelu = getRelu();
4503
4504 return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
4505 .Case([&](mlir::Float6E2M3FNType) {
4506 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e2m3x4_rs_relu_satfinite
4507 : llvm::Intrinsic::nvvm_f32x4_to_e2m3x4_rs_satfinite;
4508 })
4509 .Case([&](mlir::Float6E3M2FNType) {
4510 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e3m2x4_rs_relu_satfinite
4511 : llvm::Intrinsic::nvvm_f32x4_to_e3m2x4_rs_satfinite;
4512 })
4513 .Default([](mlir::Type) {
4514 llvm_unreachable("Invalid F6 type in ConvertF32x4ToF6x4Op");
4515 return llvm::Intrinsic::not_intrinsic;
4516 });
4517}
4518
4519llvm::Intrinsic::ID ConvertF32x4ToF4x4Op::getIntrinsicID() {
4520 mlir::Type dstTy = getDstTy();
4521 bool hasRelu = getRelu();
4522
4523 return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
4524 .Case([&](mlir::Float4E2M1FNType) {
4525 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e2m1x4_rs_relu_satfinite
4526 : llvm::Intrinsic::nvvm_f32x4_to_e2m1x4_rs_satfinite;
4527 })
4528 .Default([](mlir::Type) {
4529 llvm_unreachable("Invalid F4 type in ConvertF32x4ToF4x4Op");
4530 return llvm::Intrinsic::not_intrinsic;
4531 });
4532}
4533
4534llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(Operation &op) {
4535 auto curOp = cast<NVVM::Tcgen05CpOp>(op);
4536 bool is2CTA = curOp.getGroup() == CTAGroupKind::CTA_2;
4537 auto srcFmt = curOp.getSrcFormat();
4538 auto mc = curOp.getMulticast();
4539
4540 switch (curOp.getShape()) {
4541 case Tcgen05CpShape::SHAPE_128x256b:
4542 return GET_TCGEN05_CP_ID(_128x256b, srcFmt, is2CTA);
4543 case Tcgen05CpShape::SHAPE_128x128b:
4544 return GET_TCGEN05_CP_ID(_128x128b, srcFmt, is2CTA);
4545 case Tcgen05CpShape::SHAPE_4x256b:
4546 return GET_TCGEN05_CP_ID(_4x256b, srcFmt, is2CTA);
4547 case Tcgen05CpShape::SHAPE_32x128b:
4548 return GET_TCGEN05_CP_ID(_32x128b_warpx4, srcFmt, is2CTA);
4549 case Tcgen05CpShape::SHAPE_64x128b:
4550 return (mc == Tcgen05CpMulticast::WARPX2_01_23)
4551 ? GET_TCGEN05_CP_ID(_64x128b_warpx2_01_23, srcFmt, is2CTA)
4552 : GET_TCGEN05_CP_ID(_64x128b_warpx2_02_13, srcFmt, is2CTA);
4553 }
4554 llvm_unreachable("Invalid shape in tcgen05 cp Op");
4555}
4556
4557 // Returns whether the given vector length is valid for the shape; the
4558 // checks model the table in the tcgen05.{ld, st} Op description.
4559static unsigned isValidVectorLength(NVVM::Tcgen05LdStShape shape,
4560 unsigned vecLen) {
4561 if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X128B)
4562 return vecLen >= 2;
4563 if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X256B)
4564 return vecLen >= 4;
4565 return true;
4566}
4567
4568LogicalResult Tcgen05LdOp::verify() {
4569 LogicalResult result = success();
4570 if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
4571 result = emitError("shape 16x32bx2 requires offset argument");
4572
4573 if (getShape() != NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && getOffset())
4574 result = emitError("offset argument is only supported for shape 16x32bx2");
4575
4576 auto resTy = getRes().getType();
4577 unsigned resLen = isa<VectorType>(resTy)
4578 ? llvm::cast<VectorType>(resTy).getNumElements()
4579 : 1;
4580 if (!isValidVectorLength(getShape(), resLen))
4581 result = emitError(llvm::formatv("invalid result type length {0} for shape "
4582 "{1} in tcgen05.ld Op",
4583 resLen, stringifyEnum(getShape())));
4584
4585 return result;
4586}
4587
4588LogicalResult Tcgen05StOp::verify() {
4589 LogicalResult result = success();
4590 if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
4591 result = emitError("shape 16x32bx2 requires offset argument");
4592
4593 auto valTy = getVal().getType();
4594 unsigned valLen = isa<VectorType>(valTy)
4595 ? llvm::cast<VectorType>(valTy).getNumElements()
4596 : 1;
4597 if (!isValidVectorLength(getShape(), valLen))
4598 result = emitError(llvm::formatv("invalid input length {0} for shape "
4599 "{1} in tcgen05.st Op",
4600 valLen, stringifyEnum(getShape())));
4601
4602 return result;
4603}
4604
4605/// Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might
4606/// have ConstantRangeAttr.
4607 static void nvvmInferResultRanges(Operation *op, Value result,
4608 ArrayRef<::mlir::ConstantIntRanges> argRanges,
4609 SetIntRangeFn setResultRanges) {
4610 if (auto rangeAttr = op->getAttrOfType<LLVM::ConstantRangeAttr>("range")) {
4611 setResultRanges(result, {rangeAttr.getLower(), rangeAttr.getUpper(),
4612 rangeAttr.getLower(), rangeAttr.getUpper()});
4613 } else {
4614 setResultRanges(result, IntegerValueRange::getMaxRange(result).getValue());
4615 }
4616}
4617
4618/// Verify the range attribute satisfies LLVM ConstantRange constructor
4619/// requirements for NVVM SpecialRangeableRegisterOp.
4620static LogicalResult
4621 verifyConstantRangeAttr(Operation *op,
4622 std::optional<LLVM::ConstantRangeAttr> rangeAttr) {
4623 if (!rangeAttr)
4624 return success();
4625
4626 const llvm::APInt &lower = rangeAttr->getLower();
4627 const llvm::APInt &upper = rangeAttr->getUpper();
4628
4629 // Check LLVM ConstantRange constructor condition
4630 if (lower == upper && !lower.isMaxValue() && !lower.isMinValue()) {
4631 unsigned bitWidth = lower.getBitWidth();
4632 llvm::APInt minVal = llvm::APInt::getMinValue(bitWidth);
4633 llvm::APInt maxVal = llvm::APInt::getMaxValue(bitWidth);
4634 return op->emitOpError(
4635 "invalid range attribute: Lower == Upper, but they aren't min (")
4636 << llvm::toString(minVal, 10, false) << ") or max ("
4637 << llvm::toString(maxVal, 10, false)
4638 << ") value! This is an invalid constant range.";
4639 }
4640
4641 return success();
4642}
4643
4644static llvm::Value *getAsPackedI32(llvm::Value *arg,
4645 llvm::IRBuilderBase &builder) {
4646 return builder.CreateBitCast(arg,
4647 llvm::Type::getInt32Ty(builder.getContext()));
4648}
4649
4650NVVM::IDArgPair DotAccumulate4WayOp::getIntrinsicIDAndArgs(
4651 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4652 auto curOp = cast<NVVM::DotAccumulate4WayOp>(op);
4653
4654 llvm::SmallVector<llvm::Value *> args;
4655 args.push_back(getAsPackedI32(mt.lookupValue(curOp.getA()), builder));
4656 args.push_back(getAsPackedI32(mt.lookupValue(curOp.getB()), builder));
4657 args.push_back(mt.lookupValue(curOp.getC()));
4658
4659 bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
4660 bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
4661 unsigned type = (isASigned << 1) | isBSigned;
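  // Index into `ids`: bit-1 is set when A is signed, bit-0 when B is signed,
  // matching the u/s suffix order of the intrinsics below.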
4662 const llvm::Intrinsic::ID ids[] = {
4663 llvm::Intrinsic::nvvm_idp4a_u_u,
4664 llvm::Intrinsic::nvvm_idp4a_u_s,
4665 llvm::Intrinsic::nvvm_idp4a_s_u,
4666 llvm::Intrinsic::nvvm_idp4a_s_s,
4667 };
4668 return {ids[type], args};
4669}
4670
4671NVVM::IDArgPair DotAccumulate2WayOp::getIntrinsicIDAndArgs(
4672 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4673 auto curOp = cast<NVVM::DotAccumulate2WayOp>(op);
4674
4675 llvm::SmallVector<llvm::Value *> args;
4676 args.push_back(getAsPackedI32(mt.lookupValue(curOp.getA()), builder));
4677 args.push_back(getAsPackedI32(mt.lookupValue(curOp.getB()), builder));
4678 args.push_back(builder.getInt1(curOp.getBHi()));
4679 args.push_back(mt.lookupValue(curOp.getC()));
4680
4681 bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
4682 bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
4683 unsigned type = (isASigned << 1) | isBSigned;
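  // Same encoding as the 4-way variant: bit-1 for A's signedness, bit-0 for B's.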
4684 const llvm::Intrinsic::ID ids[] = {
4685 llvm::Intrinsic::nvvm_idp2a_u_u,
4686 llvm::Intrinsic::nvvm_idp2a_u_s,
4687 llvm::Intrinsic::nvvm_idp2a_s_u,
4688 llvm::Intrinsic::nvvm_idp2a_s_s,
4689 };
4690 return {ids[type], args};
4691}
4692
4693static llvm::Value *getParamCastedAddr(llvm::Value *addr,
4694 llvm::IRBuilderBase &builder) {
4695 return builder.CreateAddrSpaceCast(
4696 addr,
4697 llvm::PointerType::get(builder.getContext(),
4698 llvm::NVPTXAS::AddressSpace::ADDRESS_SPACE_PARAM));
4699}
4700
4701 NVVM::IDArgPair
4702 PrefetchOp::getIntrinsicIDAndArgs(NVVM::PrefetchOp &op,
4703 LLVM::ModuleTranslation &mt,
4704 llvm::IRBuilderBase &builder) {
4705 using MemSpace = NVVM::NVVMMemorySpace;
4706 using CacheLevel = NVVM::PrefetchCacheLevel;
4707
4708 std::optional<NVVM::PrefetchCacheLevel> cacheLevel = op.getCacheLevel();
4709 std::optional<NVVM::CacheEvictionPriority> evictPriority =
4710 op.getEvictPriority();
4711 unsigned addressSpace =
4712 llvm::cast<LLVM::LLVMPointerType>(op.getAddr().getType())
4713 .getAddressSpace();
4714
4715 llvm::SmallVector<llvm::Value *> args;
4716 llvm::Value *addr = mt.lookupValue(op.getAddr());
4717 args.push_back(op.getInParamSpace() ? getParamCastedAddr(addr, builder)
4718 : addr);
4719
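  // Tensormap prefetches lower to a single dedicated intrinsic; cache level
  // and eviction priority are only consulted for the remaining cases.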
4720 if (op.getTensormap())
4721 return {llvm::Intrinsic::nvvm_prefetch_tensormap, args};
4722
4723 assert(cacheLevel && "expected cache level for non-tensormap prefetch");
4724
4725 if (op.getUniform() && *cacheLevel == CacheLevel::L1)
4726 return {llvm::Intrinsic::nvvm_prefetchu_L1, args};
4727
4728 if (evictPriority && *cacheLevel == CacheLevel::L2) {
4729 switch (*evictPriority) {
4730 case NVVM::CacheEvictionPriority::EvictLast:
4731 return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_last, args};
4732 case NVVM::CacheEvictionPriority::EvictNormal:
4733 return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_normal, args};
4734 default:
4735 llvm_unreachable("Invalid cache eviction priority");
4736 }
4737 }
4738
4739 switch (static_cast<MemSpace>(addressSpace)) {
4740 case MemSpace::Generic:
4741 return *cacheLevel == CacheLevel::L1
4742 ? NVVM::IDArgPair({llvm::Intrinsic::nvvm_prefetch_L1, args})
4743 : NVVM::IDArgPair({llvm::Intrinsic::nvvm_prefetch_L2, args});
4744 case MemSpace::Global:
4745 return *cacheLevel == CacheLevel::L1
4746 ? NVVM::IDArgPair(
4747 {llvm::Intrinsic::nvvm_prefetch_global_L1, args})
4748 : NVVM::IDArgPair(
4749 {llvm::Intrinsic::nvvm_prefetch_global_L2, args});
4750 case MemSpace::Local:
4751 return *cacheLevel == CacheLevel::L1
4752 ? NVVM::IDArgPair(
4753 {llvm::Intrinsic::nvvm_prefetch_local_L1, args})
4754 : NVVM::IDArgPair(
4755 {llvm::Intrinsic::nvvm_prefetch_local_L2, args});
4756 default:
4757 llvm_unreachable("Invalid pointer address space");
4758 }
4759}
4760
4761bool NVVM::InlinePtxOp::getAsmValues(
4762 RewriterBase &rewriter,
4763 llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
4764 &asmValues) {
4765 for (auto arg : getReadWriteArgs())
4766 asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::ReadWrite});
4767 for (auto arg : getResults())
4768 asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::Write});
4769 for (auto arg : getReadOnlyArgs())
4770 asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::Read});
4771 if (getPredicate())
4772 asmValues.push_back({getPredicate(), mlir::NVVM::PTXRegisterMod::Read});
4773 return false; // No manual mapping needed
4774}
4775
4776NVVM::IDArgPair ClusterLaunchControlTryCancelOp::getIntrinsicIDAndArgs(
4777 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4778 auto curOp = cast<NVVM::ClusterLaunchControlTryCancelOp>(op);
4779 llvm::SmallVector<llvm::Value *> args;
4780 args.push_back(mt.lookupValue(curOp.getSmemAddress()));
4781 args.push_back(mt.lookupValue(curOp.getMbarrier()));
4782
4783 llvm::Intrinsic::ID intrinsicID =
4784 curOp.getMulticast()
4785 ? llvm::Intrinsic::
4786 nvvm_clusterlaunchcontrol_try_cancel_async_multicast_shared
4787 : llvm::Intrinsic::nvvm_clusterlaunchcontrol_try_cancel_async_shared;
4788
4789 return {intrinsicID, args};
4790}
4791
4792NVVM::IDArgPair ClusterLaunchControlQueryCancelOp::getIntrinsicIDAndArgs(
4793 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4794 auto curOp = cast<NVVM::ClusterLaunchControlQueryCancelOp>(op);
4795 llvm::SmallVector<llvm::Value *> args;
4796 args.push_back(mt.lookupValue(curOp.getTryCancelResponse()));
4797
4798 llvm::Intrinsic::ID intrinsicID;
4799
4800 switch (curOp.getQueryType()) {
4801 case NVVM::ClusterLaunchControlQueryType::IS_CANCELED:
4802 intrinsicID =
4803 llvm::Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_is_canceled;
4804 break;
4805 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_X:
4806 intrinsicID = llvm::Intrinsic::
4807 nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_x;
4808 break;
4809 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Y:
4810 intrinsicID = llvm::Intrinsic::
4811 nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_y;
4812 break;
4813 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Z:
4814 intrinsicID = llvm::Intrinsic::
4815 nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_z;
4816 break;
4817 }
4818 return {intrinsicID, args};
4819}
4820
4821 NVVM::IDArgPair
4822 PermuteOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
4823 llvm::IRBuilderBase &builder) {
4824 auto thisOp = cast<NVVM::PermuteOp>(op);
4825 NVVM::PermuteMode mode = thisOp.getMode();
4826
4827 static constexpr llvm::Intrinsic::ID IDs[] = {
4828 llvm::Intrinsic::nvvm_prmt, llvm::Intrinsic::nvvm_prmt_f4e,
4829 llvm::Intrinsic::nvvm_prmt_b4e, llvm::Intrinsic::nvvm_prmt_rc8,
4830 llvm::Intrinsic::nvvm_prmt_ecl, llvm::Intrinsic::nvvm_prmt_ecr,
4831 llvm::Intrinsic::nvvm_prmt_rc16};
4832
4833 unsigned modeIndex = static_cast<unsigned>(mode);
4834 llvm::SmallVector<llvm::Value *> args;
4835 args.push_back(mt.lookupValue(thisOp.getLo()));
4836
4837 // Only the first 3 modes (Default, f4e, b4e) need the hi operand.
4838 if (modeIndex < 3)
4839 args.push_back(mt.lookupValue(thisOp.getHi()));
4840
4841 args.push_back(mt.lookupValue(thisOp.getSelector()));
4842
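  // The PermuteMode enum values are laid out in the same order as the IDs
  // table, so the mode value doubles as the table index.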
4843 return {IDs[modeIndex], args};
4844}
4845
4846mlir::NVVM::IDArgPair TensormapReplaceOp::getIntrinsicIDAndArgs(
4847 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4848 auto thisOp = cast<NVVM::TensormapReplaceOp>(op);
4849
4850 llvm::SmallVector<llvm::Value *> args;
4851 args.push_back(mt.lookupValue(thisOp.getAddr()));
4852 if (thisOp.getOrd())
4853 args.push_back(builder.getInt32(thisOp.getOrd().value()));
4854 if (thisOp.getNewValue())
4855 args.push_back(mt.lookupValue(thisOp.getNewValue()));
4856 if (auto attr = thisOp.getNewValueAttr()) {
4857 auto val =
4858 llvm::TypeSwitch<mlir::Attribute, unsigned>(attr)
4859 .Case<TensormapElemtypeAttr, TensormapInterleaveLayoutAttr,
4860 TensormapSwizzleModeAttr, TensormapSwizzleAtomicityAttr,
4861 TensormapFillModeAttr>([](auto attr) {
4862 return static_cast<unsigned>(attr.getValue());
4863 })
4864 .Default([](auto attr) {
4865 llvm_unreachable("Invalid attribute type");
4866 return 0;
4867 });
4868 args.push_back(builder.getInt32(val));
4869 }
4870
4871 static constexpr llvm::Intrinsic::ID IDs[] = {
4872 llvm::Intrinsic::nvvm_tensormap_replace_global_address,
4873 llvm::Intrinsic::nvvm_tensormap_replace_rank,
4874 llvm::Intrinsic::nvvm_tensormap_replace_box_dim,
4875 llvm::Intrinsic::nvvm_tensormap_replace_global_dim,
4876 llvm::Intrinsic::nvvm_tensormap_replace_global_stride,
4877 llvm::Intrinsic::nvvm_tensormap_replace_element_stride,
4878 llvm::Intrinsic::nvvm_tensormap_replace_elemtype,
4879 llvm::Intrinsic::nvvm_tensormap_replace_interleave_layout,
4880 llvm::Intrinsic::nvvm_tensormap_replace_swizzle_mode,
4881 llvm::Intrinsic::nvvm_tensormap_replace_swizzle_atomicity,
4882 llvm::Intrinsic::nvvm_tensormap_replace_fill_mode,
4883 };
4884
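  // The field attribute values follow the order of the IDs table above.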
4885 unsigned fieldIndex = static_cast<unsigned>(thisOp.getField());
4886
4887 return {IDs[fieldIndex], args};
4888}
4889
4890//===----------------------------------------------------------------------===//
4891// NVVM tcgen05.mma functions
4892//===----------------------------------------------------------------------===//
4893
4894 mlir::NVVM::IDArgPair
4895 Tcgen05MMAOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
4896 llvm::IRBuilderBase &builder) {
4897
4898 auto thisOp = cast<NVVM::Tcgen05MMAOp>(op);
4899 llvm::SmallVector<llvm::Value *> args;
4900
4901 args.push_back(mt.lookupValue(thisOp.getMatrixD()));
4902
4903 llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
4904 const bool isATensor = isa<llvm::PointerType>(A->getType());
4905 args.push_back(A);
4906
4907 args.push_back(mt.lookupValue(thisOp.getMatrixB()));
4908 args.push_back(mt.lookupValue(thisOp.getIdesc()));
4909 args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
4910
4911 using EnableAShiftArray = std::array<llvm::Intrinsic::ID, 2>;
4912 using CtaGroupArray = std::array<EnableAShiftArray, 2>;
4913 using IsATensorArray = std::array<CtaGroupArray, 2>;
4914 using HasScaleInputDArray = std::array<IsATensorArray, 2>;
4915 using HasDisableOutputLaneArray = std::array<HasScaleInputDArray, 2>;
4916
4917 // [hasDisableOutputLane][hasScaleInputD][isATensor][CtaGroup][EnableAShift]
4918 static constexpr HasDisableOutputLaneArray tcgen05MMAIDs = {
4919 { // without disable output lane
4920 {{// without scale input D
4921 {{
4922 // shared
4923 {{// cg1
4924 {llvm::Intrinsic::nvvm_tcgen05_mma_shared, notIntrinsic},
4925 // cg2
4926 {llvm::Intrinsic::nvvm_tcgen05_mma_shared, notIntrinsic}}},
4927 {{// tensor
4928 {
4929 // cg1
4930 llvm::Intrinsic::nvvm_tcgen05_mma_tensor,
4931 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_ashift,
4932 },
4933 {
4934 // cg2
4935 llvm::Intrinsic::nvvm_tcgen05_mma_tensor,
4936 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_ashift,
4937 }}},
4938 }},
4939 // with scale input D
4940 {{ // shared
4941 {{// cg1
4942 {llvm::Intrinsic::nvvm_tcgen05_mma_shared_scale_d, notIntrinsic},
4943 // cg2
4944 {llvm::Intrinsic::nvvm_tcgen05_mma_shared_scale_d, notIntrinsic}}},
4945 {{// tensor
4946 {
4947 // cg1
4948 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d,
4949 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_ashift,
4950 },
4951 {
4952 // cg2
4953 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d,
4954 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_ashift,
4955 }}}}}}},
4956 // with disable output lane
4957 {{ // without scale input D
4958 {{ // shared
4959 {{// cg1
4960 {llvm::Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg1,
4961 notIntrinsic},
4962 // cg2
4963 {llvm::Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg2,
4964 notIntrinsic}}},
4965 {{// cg1
4966 {
4967 llvm::Intrinsic::
4968 nvvm_tcgen05_mma_tensor_disable_output_lane_cg1,
4969 llvm::Intrinsic::
4970 nvvm_tcgen05_mma_tensor_disable_output_lane_cg1_ashift,
4971 },
4972 // cg2
4973 {
4974 llvm::Intrinsic::
4975 nvvm_tcgen05_mma_tensor_disable_output_lane_cg2,
4976 llvm::Intrinsic::
4977 nvvm_tcgen05_mma_tensor_disable_output_lane_cg2_ashift,
4978 }}}}},
4979 // with scale input D
4980 {{ // shared
4981 {{// cg1
4982 {llvm::Intrinsic::
4983 nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg1,
4984 notIntrinsic},
4985 // cg2
4986 {llvm::Intrinsic::
4987 nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg2,
4988 notIntrinsic}}},
4989 // tensor
4990 {{// cg1
4991 {llvm::Intrinsic::
4992 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1,
4993 llvm::Intrinsic::
4994 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1_ashift},
4995 // cg2
4996 {
4997 llvm::Intrinsic::
4998 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2,
4999 llvm::Intrinsic::
5000 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2_ashift,
5001 }}}}}}}}};
5002
5003 llvm::Value *ScaleInputD = mt.lookupValue(thisOp.getScaleInputD());
5004 bool hasScaleInputD = ScaleInputD != nullptr;
5005
5006 llvm::Value *DisableOutputLane =
5007 mt.lookupValue(thisOp.getDisableOutputLane());
5008 bool hasDisableOutputLane = DisableOutputLane != nullptr;
5009
5010 const unsigned ctaGroup =
5011 static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup()));
5012
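  // The CTA-group value is 1 or 2 here, so `ctaGroup - 1` selects the cg1/cg2
  // entry in the table.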
5013 llvm::Intrinsic::ID ID =
5014 tcgen05MMAIDs[hasDisableOutputLane][hasScaleInputD][isATensor]
5015 [ctaGroup - 1][thisOp.getAShift()];
5016
5017 assert(ID != notIntrinsic && "Invalid intrinsic for Tcgen05MMAOp.");
5018
5019 if (hasScaleInputD)
5020 args.push_back(ScaleInputD);
5021
5022 if (hasDisableOutputLane)
5023 args.push_back(DisableOutputLane);
5024
5025 args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
5026
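  // The disable_output_lane intrinsics already encode the CTA group in their
  // name (cg1/cg2), so the cta_group immediate is only appended otherwise.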
5027 if (!hasDisableOutputLane)
5028 args.push_back(builder.getInt32(ctaGroup));
5029
5030 args.push_back(
5031 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
5032
5033 return {ID, args};
5034}
5035
5036static LogicalResult
5037verifyTcgen05MMAOp(bool isATensor, mlir::Value disableOutputLane,
5038 NVVM::CTAGroupKind ctaGroup, bool hasAShift,
5039 NVVM::Tcgen05MMACollectorOp collectorOp, Location loc) {
5040
5041 if (disableOutputLane) {
5042 mlir::VectorType disableOutputLaneType =
5043 cast<mlir::VectorType>(disableOutputLane.getType());
5044 if ((ctaGroup == NVVM::CTAGroupKind::CTA_1 &&
5045 disableOutputLaneType.getNumElements() != 4) ||
5046 (ctaGroup == NVVM::CTAGroupKind::CTA_2 &&
5047 disableOutputLaneType.getNumElements() != 8))
5048 return emitError(loc) << "Disable Output Lane of length "
5049 << disableOutputLaneType.getNumElements()
5050 << " is incompatible with CtaGroupAttr";
5051 }
5052
5053 if (hasAShift && !isATensor)
5054 return emitError(
5055 loc, "A-shift can be applied only when matrix A is in tensor memory");
5056
5057 if (hasAShift && (collectorOp == Tcgen05MMACollectorOp::FILL ||
5058 collectorOp == Tcgen05MMACollectorOp::USE))
5059 return emitError(
5060 loc, "Cannot use collector buffer operation fill or use with ashift");
5061
5062 return success();
5063}
5064
5065LogicalResult Tcgen05MMAOp::verify() {
5066 return verifyTcgen05MMAOp(isa<LLVM::LLVMPointerType>(getMatrixA().getType()),
5067 getDisableOutputLane(), getCtaGroup(), getAShift(),
5068 getCollectorOp(), getLoc());
5069}
5070
5071//===----------------------------------------------------------------------===//
5072// NVVM tcgen05.mma.sp functions
5073//===----------------------------------------------------------------------===//
5074
5075mlir::NVVM::IDArgPair Tcgen05MMASparseOp::getIntrinsicIDAndArgs(
5076 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
5077
5078 auto thisOp = cast<NVVM::Tcgen05MMASparseOp>(op);
5079 llvm::SmallVector<llvm::Value *> args;
5080
5081 args.push_back(mt.lookupValue(thisOp.getMatrixD()));
5082
5083 llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
5084 bool isATensor = isa<llvm::PointerType>(A->getType());
5085 args.push_back(A);
5086
5087 args.push_back(mt.lookupValue(thisOp.getMatrixB()));
5088 args.push_back(mt.lookupValue(thisOp.getIdesc()));
5089 args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
5090 args.push_back(mt.lookupValue(thisOp.getSparseMetadata()));
5091
5092 using EnableAShiftArray = std::array<llvm::Intrinsic::ID, 2>;
5093 using CtaGroupArray = std::array<EnableAShiftArray, 2>;
5094 using IsATensorArray = std::array<CtaGroupArray, 2>;
5095 using HasScaleInputDArray = std::array<IsATensorArray, 2>;
5096 using HasDisableOutputLaneArray = std::array<HasScaleInputDArray, 2>;
5097
5098 // [hasDisableOutputLane][hasScaleInputD][isATensor][CtaGroup][EnableAShift]
5099 static constexpr HasDisableOutputLaneArray tcgen05MMASparseIDs = {
5100 { // without disable output lane
5101 {{// without scale input D
5102 {{
5103 // shared
5104 {{// cg1
5105 {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared, notIntrinsic},
5106 // cg2
5107 {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared, notIntrinsic}}},
5108 {{// tensor
5109 {
5110 // cg1
5111 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor,
5112 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_ashift,
5113 },
5114 {
5115 // cg2
5116 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor,
5117 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_ashift,
5118 }}},
5119 }},
5120 // with scale input D
5121 {{ // shared
5122 {{// cg1
5123 {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d,
5124 notIntrinsic},
5125 // cg2
5126 {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d,
5127 notIntrinsic}}},
5128 {{// tensor
5129 {
5130 // cg1
5131 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d,
5132 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_ashift,
5133 },
5134 {
5135 // cg2
5136 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d,
5137 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_ashift,
5138 }}}}}}},
5139 // with disable output lane
5140 {{ // without scale input D
5141 {{ // shared
5142 {{// cg1
5143 {llvm::Intrinsic::
5144 nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg1,
5145 notIntrinsic},
5146 // cg2
5147 {llvm::Intrinsic::
5148 nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg2,
5149 notIntrinsic}}},
5150 {{// cg1
5151 {
5152 llvm::Intrinsic::
5153 nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1,
5154 llvm::Intrinsic::
5155 nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1_ashift,
5156 },
5157 // cg2
5158 {
5159 llvm::Intrinsic::
5160 nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2,
5161 llvm::Intrinsic::
5162 nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2_ashift,
5163 }}}}},
5164 // with scale input D
5165 {{ // shared
5166 {{// cg1
5167 {llvm::Intrinsic::
5168 nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg1,
5169 notIntrinsic},
5170 // cg2
5171 {llvm::Intrinsic::
5172 nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg2,
5173 notIntrinsic}}},
5174 // tensor
5175 {{// cg1
5176 {llvm::Intrinsic::
5177 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1,
5178 llvm::Intrinsic::
5179 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1_ashift},
5180 // cg2
5181 {
5182 llvm::Intrinsic::
5183 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2,
5184 llvm::Intrinsic::
5185 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2_ashift,
5186 }}}}}}}}};
5187
5188 llvm::Value *ScaleInputD = mt.lookupValue(thisOp.getScaleInputD());
5189 bool hasScaleInputD = ScaleInputD != nullptr;
5190
5191 llvm::Value *DisableOutputLane =
5192 mt.lookupValue(thisOp.getDisableOutputLane());
5193 bool hasDisableOutputLane = DisableOutputLane != nullptr;
5194
5195 unsigned ctaGroup =
5196 static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup()));
5197
5198 llvm::Intrinsic::ID ID =
5199 tcgen05MMASparseIDs[hasDisableOutputLane][hasScaleInputD][isATensor]
5200 [ctaGroup - 1][thisOp.getAShift()];
5201
5202 assert(ID != notIntrinsic && "Invalid intrinsic for Tcgen05MMASparseOp.");
5203
5204 if (hasScaleInputD)
5205 args.push_back(ScaleInputD);
5206
5207 if (hasDisableOutputLane)
5208 args.push_back(DisableOutputLane);
5209
5210 args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
5211
5212 if (!hasDisableOutputLane)
5213 args.push_back(builder.getInt32(ctaGroup));
5214
5215 args.push_back(
5216 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
5217
5218 return {ID, args};
5219}
5220
5221LogicalResult Tcgen05MMASparseOp::verify() {
5222 return verifyTcgen05MMAOp(isa<LLVM::LLVMPointerType>(getMatrixA().getType()),
5223 getDisableOutputLane(), getCtaGroup(), getAShift(),
5224 getCollectorOp(), getLoc());
5225}
5226
5227//===----------------------------------------------------------------------===//
5228// NVVM tcgen05.mma.block_scale functions
5229//===----------------------------------------------------------------------===//
5230
5231mlir::NVVM::IDArgPair Tcgen05MMABlockScaleOp::getIntrinsicIDAndArgs(
5232 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
5233
5234 auto thisOp = cast<NVVM::Tcgen05MMABlockScaleOp>(op);
5235 llvm::SmallVector<llvm::Value *> args;
5236
5237 args.push_back(mt.lookupValue(thisOp.getMatrixD()));
5238
5239 llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
5240 bool isATensor = isa<llvm::PointerType>(A->getType());
5241 args.push_back(A);
5242
5243 args.push_back(mt.lookupValue(thisOp.getMatrixB()));
5244 args.push_back(mt.lookupValue(thisOp.getIdesc()));
5245 args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
5246 args.push_back(mt.lookupValue(thisOp.getScaleA()));
5247 args.push_back(mt.lookupValue(thisOp.getScaleB()));
5248 args.push_back(builder.getInt32(
5249 static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup()))));
5250 args.push_back(
5251 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
5252
5253 auto kind = thisOp.getKind();
5254 auto blockScale = thisOp.getBlockScale();
5255 llvm::Intrinsic::ID ID = [&]() {
5256 if (kind == NVVM::MMABlockScaleKind::MXF8F6F4) {
5257 if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
5258 return isATensor ? llvm::Intrinsic::
5259 nvvm_tcgen05_mma_tensor_mxf8f6f4_block_scale
5260 : llvm::Intrinsic::
5261 nvvm_tcgen05_mma_shared_mxf8f6f4_block_scale;
5262 } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
5263 return isATensor
5264 ? llvm::Intrinsic::
5265 nvvm_tcgen05_mma_tensor_mxf8f6f4_block_scale_block32
5266 : llvm::Intrinsic::
5267 nvvm_tcgen05_mma_shared_mxf8f6f4_block_scale_block32;
5268 }
5269 } else if (kind == NVVM::MMABlockScaleKind::MXF4) {
5270 if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
5271 return isATensor
5272 ? llvm::Intrinsic::nvvm_tcgen05_mma_tensor_mxf4_block_scale
5273 : llvm::Intrinsic::nvvm_tcgen05_mma_shared_mxf4_block_scale;
5274 } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
5275 return isATensor ? llvm::Intrinsic::
5276 nvvm_tcgen05_mma_tensor_mxf4_block_scale_block32
5277 : llvm::Intrinsic::
5278 nvvm_tcgen05_mma_shared_mxf4_block_scale_block32;
5279 }
5280 } else if (kind == NVVM::MMABlockScaleKind::MXF4NVF4) {
5281 if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
5282 return isATensor
5283 ? llvm::Intrinsic::
5284 nvvm_tcgen05_mma_tensor_mxf4nvf4_block_scale_block32
5285 : llvm::Intrinsic::
5286 nvvm_tcgen05_mma_shared_mxf4nvf4_block_scale_block32;
5287
5288 } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16) {
5289 return isATensor
5290 ? llvm::Intrinsic::
5291 nvvm_tcgen05_mma_tensor_mxf4nvf4_block_scale_block16
5292 : llvm::Intrinsic::
5293 nvvm_tcgen05_mma_shared_mxf4nvf4_block_scale_block16;
5294 }
5295 }
5296 llvm_unreachable("Invalid tcgen05.mma.block_scale attributes");
5297 }();
5298
5299 return {ID, args};
5300}
5301
5302static LogicalResult verifyTcgen05MMABlockScaleOp(
5303 NVVM::Tcgen05MMACollectorOp collectorOp, NVVM::MMABlockScaleKind kind,
5304 NVVM::Tcgen05MMABlockScale blockScale, Location loc) {
5305
5306 if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT &&
5307 kind == MMABlockScaleKind::MXF4NVF4)
5308 return emitError(loc, "mxf4nvf4 requires block scale attribute");
5309
5310 if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16 &&
5311 kind != MMABlockScaleKind::MXF4NVF4)
5312 return emitError(loc,
5313 llvm::formatv("{0} kind does not support block16 attribute",
5314 stringifyEnum(kind)));
5315
5316 return success();
5317}
5318
5319LogicalResult Tcgen05MMABlockScaleOp::verify() {
5320 return verifyTcgen05MMABlockScaleOp(getCollectorOp(), getKind(),
5321 getBlockScale(), getLoc());
5322}
5323
5324//===----------------------------------------------------------------------===//
5325// NVVM tcgen05.mma.sp.block_scale functions
5326//===----------------------------------------------------------------------===//
5327
5328mlir::NVVM::IDArgPair Tcgen05MMASparseBlockScaleOp::getIntrinsicIDAndArgs(
5329 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
5330
5331 auto thisOp = cast<NVVM::Tcgen05MMASparseBlockScaleOp>(op);
5332 llvm::SmallVector<llvm::Value *> args;
5333
5334 args.push_back(mt.lookupValue(thisOp.getMatrixD()));
5335
5336 llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
5337 bool isATensor = isa<llvm::PointerType>(A->getType());
5338 args.push_back(A);
5339
5340 args.push_back(mt.lookupValue(thisOp.getMatrixB()));
5341 args.push_back(mt.lookupValue(thisOp.getIdesc()));
5342 args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
5343 args.push_back(mt.lookupValue(thisOp.getSparseMetadata()));
5344 args.push_back(mt.lookupValue(thisOp.getScaleA()));
5345 args.push_back(mt.lookupValue(thisOp.getScaleB()));
5346 args.push_back(builder.getInt32(
5347 static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup()))));
5348 args.push_back(
5349 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
5350
5351 auto kind = thisOp.getKind();
5352 auto blockScale = thisOp.getBlockScale();
5353 llvm::Intrinsic::ID ID = [&]() {
5354 if (kind == NVVM::MMABlockScaleKind::MXF8F6F4) {
5355 if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
5356 return isATensor ? llvm::Intrinsic::
5357 nvvm_tcgen05_mma_sp_tensor_mxf8f6f4_block_scale
5358 : llvm::Intrinsic::
5359 nvvm_tcgen05_mma_sp_shared_mxf8f6f4_block_scale;
5360 } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
5361 return isATensor
5362 ? llvm::Intrinsic::
5363 nvvm_tcgen05_mma_sp_tensor_mxf8f6f4_block_scale_block32
5364 : llvm::Intrinsic::
5365 nvvm_tcgen05_mma_sp_shared_mxf8f6f4_block_scale_block32;
5366 }
5367 } else if (kind == NVVM::MMABlockScaleKind::MXF4) {
5368 if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
5369 return isATensor ? llvm::Intrinsic::
5370 nvvm_tcgen05_mma_sp_tensor_mxf4_block_scale
5371 : llvm::Intrinsic::
5372 nvvm_tcgen05_mma_sp_shared_mxf4_block_scale;
5373 } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
5374 return isATensor
5375 ? llvm::Intrinsic::
5376 nvvm_tcgen05_mma_sp_tensor_mxf4_block_scale_block32
5377 : llvm::Intrinsic::
5378 nvvm_tcgen05_mma_sp_shared_mxf4_block_scale_block32;
5379 }
5380 } else if (kind == NVVM::MMABlockScaleKind::MXF4NVF4) {
5381 if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
5382 return isATensor
5383 ? llvm::Intrinsic::
5384 nvvm_tcgen05_mma_sp_tensor_mxf4nvf4_block_scale_block32
5385 : llvm::Intrinsic::
5386 nvvm_tcgen05_mma_sp_shared_mxf4nvf4_block_scale_block32;
5387
5388 } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16) {
5389 return isATensor
5390 ? llvm::Intrinsic::
5391 nvvm_tcgen05_mma_sp_tensor_mxf4nvf4_block_scale_block16
5392 : llvm::Intrinsic::
5393 nvvm_tcgen05_mma_sp_shared_mxf4nvf4_block_scale_block16;
5394 }
5395 }
5396 llvm_unreachable("Invalid tcgen05.mma.sp.block_scale attributes");
5397 }();
5398
5399 return {ID, args};
5400}
5401
5402LogicalResult Tcgen05MMASparseBlockScaleOp::verify() {
5403 return verifyTcgen05MMABlockScaleOp(getCollectorOp(), getKind(),
5404 getBlockScale(), getLoc());
5405}
5406
5407//===----------------------------------------------------------------------===//
5408// NVVM tcgen05.mma.ws functions
5409//===----------------------------------------------------------------------===//
5410
5411mlir::NVVM::IDArgPair Tcgen05MMAWsOp::getIntrinsicIDAndArgs(
5412 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
5413
5414 auto thisOp = cast<NVVM::Tcgen05MMAWsOp>(op);
5415 llvm::SmallVector<llvm::Value *> args;
5416
5417 args.push_back(mt.lookupValue(thisOp.getMatrixD()));
5418
5419 llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
5420 bool isATensor = isa<llvm::PointerType>(A->getType());
5421 args.push_back(A);
5422
5423 args.push_back(mt.lookupValue(thisOp.getMatrixB()));
5424 args.push_back(mt.lookupValue(thisOp.getIdesc()));
5425 args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
5426
5427 mlir::Value ZeroColMask = thisOp.getZeroColMask();
5428 llvm::Intrinsic::ID ID = notIntrinsic;
5429 if (ZeroColMask) {
5430 args.push_back(mt.lookupValue(ZeroColMask));
5431 ID = isATensor ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_tensor_zero_col_mask
5432 : llvm::Intrinsic::nvvm_tcgen05_mma_ws_shared_zero_col_mask;
5433 } else
5434 ID = isATensor ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_tensor
5435 : llvm::Intrinsic::nvvm_tcgen05_mma_ws_shared;
5436
5437 args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
5438 args.push_back(
5439 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorBBuffer())));
5440 args.push_back(
5441 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
5442
5443 return {ID, args};
5444}
5445
5446//===----------------------------------------------------------------------===//
5447// NVVM tcgen05.mma.ws.sp functions
5448//===----------------------------------------------------------------------===//
5449
5450mlir::NVVM::IDArgPair Tcgen05MMAWsSparseOp::getIntrinsicIDAndArgs(
5451 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
5452
5453 auto thisOp = cast<NVVM::Tcgen05MMAWsSparseOp>(op);
5454 llvm::SmallVector<llvm::Value *> args;
5455
5456 args.push_back(mt.lookupValue(thisOp.getMatrixD()));
5457
5458 llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
5459 bool isATensor = isa<llvm::PointerType>(A->getType());
5460 args.push_back(A);
5461
5462 args.push_back(mt.lookupValue(thisOp.getMatrixB()));
5463 args.push_back(mt.lookupValue(thisOp.getIdesc()));
5464 args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
5465 args.push_back(mt.lookupValue(thisOp.getSparseMetadata()));
5466
5467 mlir::Value ZeroColMask = thisOp.getZeroColMask();
5468 llvm::Intrinsic::ID ID = notIntrinsic;
5469 if (ZeroColMask) {
5470 args.push_back(mt.lookupValue(ZeroColMask));
5471 ID = isATensor
5472 ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_tensor_zero_col_mask
5473 : llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_shared_zero_col_mask;
5474 } else
5475 ID = isATensor ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_tensor
5476 : llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_shared;
5477
5478 args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
5479 args.push_back(
5480 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorBBuffer())));
5481 args.push_back(
5482 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
5483
5484 return {ID, args};
5485}
5486
5487//===----------------------------------------------------------------------===//
5488// NVVM tcgen05.ld.red functions
5489//===----------------------------------------------------------------------===//
5490
5491#define TCGEN05LDRED(SHAPE, NUM, TYPE) \
5492 llvm::Intrinsic::nvvm_tcgen05_ld_red_##SHAPE##_##NUM##_##TYPE
5493
5494mlir::NVVM::IDArgPair NVVM::Tcgen05LdRedOp::getIntrinsicIDAndArgs(
5495 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
5496 auto thisOp = cast<NVVM::Tcgen05LdRedOp>(op);
5497 llvm::SmallVector<llvm::Value *> args;
5498
5499 mlir::VectorType VecResTy =
5500 cast<mlir::VectorType>(thisOp.getData().getType());
5501 unsigned Num = VecResTy.getNumElements();
5502 bool IsFloat = thisOp.getRedVal().getType().isF32();
5503
5504 llvm::Intrinsic::ID Shape32x32b[][2] = {
5506 {TCGEN05LDRED(32x32b, x2, i32), TCGEN05LDRED(32x32b, x2, f32)},
5507 {TCGEN05LDRED(32x32b, x4, i32), TCGEN05LDRED(32x32b, x4, f32)},
5508 {TCGEN05LDRED(32x32b, x8, i32), TCGEN05LDRED(32x32b, x8, f32)},
5509 {TCGEN05LDRED(32x32b, x16, i32), TCGEN05LDRED(32x32b, x16, f32)},
5510 {TCGEN05LDRED(32x32b, x32, i32), TCGEN05LDRED(32x32b, x32, f32)},
5511 {TCGEN05LDRED(32x32b, x64, i32), TCGEN05LDRED(32x32b, x64, f32)},
5512 {TCGEN05LDRED(32x32b, x128, i32), TCGEN05LDRED(32x32b, x128, f32)},
5513 };
5514
5515 llvm::Intrinsic::ID Shape16x32bx2[][2] = {
5517 {TCGEN05LDRED(16x32bx2, x2, i32), TCGEN05LDRED(16x32bx2, x2, f32)},
5518 {TCGEN05LDRED(16x32bx2, x4, i32), TCGEN05LDRED(16x32bx2, x4, f32)},
5519 {TCGEN05LDRED(16x32bx2, x8, i32), TCGEN05LDRED(16x32bx2, x8, f32)},
5520 {TCGEN05LDRED(16x32bx2, x16, i32), TCGEN05LDRED(16x32bx2, x16, f32)},
5521 {TCGEN05LDRED(16x32bx2, x32, i32), TCGEN05LDRED(16x32bx2, x32, f32)},
5522 {TCGEN05LDRED(16x32bx2, x64, i32), TCGEN05LDRED(16x32bx2, x64, f32)},
5523 {TCGEN05LDRED(16x32bx2, x128, i32), TCGEN05LDRED(16x32bx2, x128, f32)},
5524 };
5525
5526 NVVM::Tcgen05LdStShape shape = thisOp.getShape();
5527 unsigned ID = [&]() {
5528 // `Num` holds the vector length; log2 of `Num` gives the index into the
5529 // shape arrays above.
5530 unsigned idx = std::log2(Num);
5531 switch (shape) {
5532 case NVVM::Tcgen05LdStShape::SHAPE_32X32B:
5533 return Shape32x32b[idx][IsFloat];
5534 case NVVM::Tcgen05LdStShape::SHAPE_16X32BX2:
5535 return Shape16x32bx2[idx][IsFloat];
5536 default:
5537 llvm_unreachable("unhandled tcgen05.ld lowering");
5538 }
5539 }();
5540
5541 args.push_back(mt.lookupValue(thisOp.getAddr()));
5542
5543 if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2)
5544 args.push_back(mt.lookupValue(thisOp.getOffset()));
5545
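  // The reduction kind is encoded as 0 for min and 1 for max; verify()
  // rejects any other kind.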
5546 args.push_back(
5547 builder.getInt32(thisOp.getOp() == NVVM::ReductionKind::MIN ? 0 : 1));
5548
5549 if (IsFloat) {
5550 args.push_back(builder.getInt1(static_cast<unsigned>(thisOp.getAbs())));
5551 args.push_back(builder.getInt1(static_cast<unsigned>(thisOp.getNan())));
5552 }
5553 return {ID, args};
5554}
5555
5556LogicalResult Tcgen05LdRedOp::verify() {
5557 VectorType data = cast<VectorType>(getData().getType());
5558 Type redVal = getRedVal().getType();
5559
5560 if (data.getElementType() != redVal)
5561 return emitError(
5562 "type of reduction value and element type of vector data should match");
5563
5564 if (getOp() != NVVM::ReductionKind::MIN &&
5565 getOp() != NVVM::ReductionKind::MAX)
5566 return emitError("only min and max reduction kinds are supported");
5567
5568 if (redVal.isInteger() && (getAbs() || getNan())) {
5569 return emitError("abs or nan is only applicable for f32 type");
5570 }
5571 return success();
5572}
5573
5574//===----------------------------------------------------------------------===//
5575// NVVMDialect initialization, type parsing, and registration.
5576//===----------------------------------------------------------------------===//
5577
5578// TODO: This should be the llvm.nvvm dialect once this is supported.
5579void NVVMDialect::initialize() {
5580 addOperations<
5581#define GET_OP_LIST
5582#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
5583 >();
5584 addAttributes<
5585#define GET_ATTRDEF_LIST
5586#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
5587 >();
5588
5589 // Support unknown operations because not all NVVM operations are
5590 // registered.
5591 allowUnknownOperations();
5592 declarePromisedInterface<ConvertToLLVMPatternInterface, NVVMDialect>();
5593 declarePromisedInterface<gpu::TargetAttrInterface, NVVMTargetAttr>();
5594}
5595
5596LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
5597 NamedAttribute attr) {
5598 StringAttr attrName = attr.getName();
5599 // Kernel function attribute should be attached to functions.
5600 if (attrName == NVVMDialect::getKernelFuncAttrName()) {
5601 if (!isa<LLVM::LLVMFuncOp>(op)) {
5602 return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName()
5603 << "' attribute attached to unexpected op";
5604 }
5605 }
5606 // If maxntid / reqntid / cluster_dim exist, it must be an array with max 3
5607 // dim
5608 if (attrName == NVVMDialect::getMaxntidAttrName() ||
5609 attrName == NVVMDialect::getReqntidAttrName() ||
5610 attrName == NVVMDialect::getClusterDimAttrName()) {
5611 auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue());
5612 if (!values || values.empty() || values.size() > 3) {
5613 return op->emitError()
5614 << "'" << attrName
5615 << "' attribute must be integer array with maximum 3 index";
5616 }
5617 }
5618 // If minctasm / maxnreg / cluster_max_blocks exist, it must be an integer
5619 // attribute
5620 if (attrName == NVVMDialect::getMinctasmAttrName() ||
5621 attrName == NVVMDialect::getMaxnregAttrName() ||
5622 attrName == NVVMDialect::getClusterMaxBlocksAttrName()) {
5623 if (!llvm::dyn_cast<IntegerAttr>(attr.getValue())) {
5624 return op->emitError()
5625 << "'" << attrName << "' attribute must be integer constant";
5626 }
5627 }
5628 // blocksareclusters must be used along with reqntid and cluster_dim
5629 if (attrName == NVVMDialect::getBlocksAreClustersAttrName()) {
5630 if (!op->hasAttr(NVVMDialect::getReqntidAttrName()) ||
5631 !op->hasAttr(NVVMDialect::getClusterDimAttrName())) {
5632 return op->emitError()
5633 << "'" << attrName << "' attribute must be used along with "
5634 << "'" << NVVMDialect::getReqntidAttrName() << "' and "
5635 << "'" << NVVMDialect::getClusterDimAttrName() << "'";
5636 }
5637 }
5638
5639 return success();
5640}
5641
5642LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
5643 unsigned regionIndex,
5644 unsigned argIndex,
5645 NamedAttribute argAttr) {
5646 auto funcOp = dyn_cast<FunctionOpInterface>(op);
5647 if (!funcOp)
5648 return success();
5649
5650 bool isKernel = op->hasAttr(NVVMDialect::getKernelFuncAttrName());
5651 StringAttr attrName = argAttr.getName();
5652 if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
5653 if (!isKernel) {
5654 return op->emitError()
5655 << "'" << attrName
5656 << "' attribute must be present only on kernel arguments";
5657 }
5658 if (!isa<UnitAttr>(argAttr.getValue()))
5659 return op->emitError() << "'" << attrName << "' must be a unit attribute";
5660 if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName())) {
5661 return op->emitError()
5662 << "'" << attrName
5663 << "' attribute requires the argument to also have attribute '"
5664 << LLVM::LLVMDialect::getByValAttrName() << "'";
5665 }
5666 }
5667
5668 return success();
5669}
5670
5671//===----------------------------------------------------------------------===//
5672// NVVM Address Space Attr
5673//===----------------------------------------------------------------------===//
5674
5675unsigned NVVMMemorySpaceAttr::getAddressSpace() const {
5676 return static_cast<unsigned>(getValue());
5677}
5678
5679bool NVVMMemorySpaceAttr::isValidLoad(
5680 Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
5681 const ::mlir::DataLayout *dataLayout,
5682 function_ref<InFlightDiagnostic()> emitError) const {
5683 return LLVM::detail::isValidLoadStoreImpl(type, ordering, alignment,
5684 dataLayout, emitError);
5685}
5686
5687bool NVVMMemorySpaceAttr::isValidStore(
5688 Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
5689 const ::mlir::DataLayout *dataLayout,
5690 function_ref<InFlightDiagnostic()> emitError) const {
5691 return LLVM::detail::isValidLoadStoreImpl(type, ordering, alignment,
5692 dataLayout, emitError);
5693}
5694
5695bool NVVMMemorySpaceAttr::isValidAtomicOp(
5696 ptr::AtomicBinOp op, Type type, ptr::AtomicOrdering ordering,
5697 std::optional<int64_t> alignment, const ::mlir::DataLayout *dataLayout,
5698 function_ref<InFlightDiagnostic()> emitError) const {
5699 // TODO: update this method once `ptr.atomic_rmw` is implemented.
5700 assert(false && "unimplemented, see TODO in the source.");
5701 return false;
5702}
5703
5704bool NVVMMemorySpaceAttr::isValidAtomicXchg(
5705 Type type, ptr::AtomicOrdering successOrdering,
5706 ptr::AtomicOrdering failureOrdering, std::optional<int64_t> alignment,
5707 const ::mlir::DataLayout *dataLayout,
5708 function_ref<InFlightDiagnostic()> emitError) const {
5709 // TODO: update this method once `ptr.atomic_cmpxchg` is implemented.
5710 assert(false && "unimplemented, see TODO in the source.");
5711 return false;
5712}
5713
5714bool NVVMMemorySpaceAttr::isValidAddrSpaceCast(
5715 Type tgt, Type src, function_ref<InFlightDiagnostic()> emitError) const {
5716 // TODO: update this method once the `ptr.addrspace_cast` op is added to the
5717 // dialect.
5718 assert(false && "unimplemented, see TODO in the source.");
5719 return false;
5720}
5721
5722bool NVVMMemorySpaceAttr::isValidPtrIntCast(
5723 Type intLikeTy, Type ptrLikeTy,
5724 function_ref<InFlightDiagnostic()> emitError) const {
5725 // TODO: update this method once the int-cast ops are added to the `ptr`
5726 // dialect.
5727 assert(false && "unimplemented, see TODO in the source.");
5728 return false;
5729}
5730
5731//===----------------------------------------------------------------------===//
5732// NVVM target attribute.
5733//===----------------------------------------------------------------------===//
5734LogicalResult
5735NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
5736 int optLevel, StringRef triple, StringRef chip,
5737 StringRef features, DictionaryAttr flags,
5738 ArrayAttr files, bool verifyTarget) {
5739 if (optLevel < 0 || optLevel > 3) {
5740 emitError() << "The optimization level must be a number between 0 and 3.";
5741 return failure();
5742 }
5743 if (triple.empty()) {
5744 emitError() << "The target triple cannot be empty.";
5745 return failure();
5746 }
5747 if (chip.empty()) {
5748 emitError() << "The target chip cannot be empty.";
5749 return failure();
5750 }
5751 if (files && !llvm::all_of(files, [](::mlir::Attribute attr) {
5752 return mlir::isa_and_nonnull<StringAttr>(attr);
5753 })) {
5754 emitError() << "All the elements in the `link` array must be strings.";
5755 return failure();
5756 }
5757 return success();
5758}
5759
5760LogicalResult NVVMTargetAttr::verifyTarget(Operation *gpuModule) {
5761 if (!getVerifyTarget())
5762 return success();
5763
5764 auto gpuModuleOp = llvm::dyn_cast<gpu::GPUModuleOp>(gpuModule);
5765 if (!gpuModuleOp) {
5766 return emitError(gpuModule->getLoc(),
5767 "NVVM target attribute must be attached to a GPU module");
5768 }
5769
5770 const NVVMCheckSMVersion targetSMVersion =
5771 NVVMCheckSMVersion::getTargetSMVersionFromStr(getChip());
5772 if (!targetSMVersion.isMinimumSMVersion()) {
5773 return emitError(gpuModule->getLoc(),
5774 "Minimum NVVM target SM version is sm_20");
5775 }
5776
5777 if (gpuModuleOp
5778 ->walk([&](Operation *op) {
5779 if (auto reqOp = llvm::dyn_cast<NVVM::RequiresSMInterface>(op)) {
5780 const NVVMCheckSMVersion requirement =
5781 reqOp.getRequiredMinSMVersion();
5782 if (!requirement.isCompatibleWith(targetSMVersion)) {
5783 op->emitOpError() << "is not supported on " << getChip();
5784 return WalkResult::interrupt();
5785 }
5786 }
5787 return WalkResult::advance();
5788 })
5789 .wasInterrupted())
5790 return failure();
5791
5792 return success();
5793}
5794
5795#define GET_OP_CLASSES
5796#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
5797
5798#define GET_ATTRDEF_CLASSES
5799#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
for(Operation *op :ops)
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
ArrayAttr()
b getContext())
#define GET_TCGEN05_CP_ID(shape_mc, src_fmt, is_2cta)
static LogicalResult verifyTMALoadParams(size_t tensorDims, size_t numIm2colOff, TMALoadMode mode, Location loc)
static LogicalResult verifyTcgen05MMAOp(bool isATensor, mlir::Value disableOutputLane, NVVM::CTAGroupKind ctaGroup, bool hasAShift, NVVM::Tcgen05MMACollectorOp collectorOp, Location loc)
#define _none
static bool isPtrInAddrSpace(mlir::Value ptr, NVVMMemorySpace targetAS)
static bool isPtrInSharedCTASpace(mlir::Value ptr)
static LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA)
static llvm::nvvm::CTAGroupKind getNVVMCtaGroupKind(NVVM::CTAGroupKind ctaGroup)
static void addInferredMultiplicandTypes(MLIRContext *ctx, OperationState &result, ValueRange operandA, ValueRange operandB, std::optional< std::array< MMATypes, 2 > > multiplicandPtxTypes)
#define GET_CVT_F2TF32_ID(rnd, relu, sf)
static void addBlockScaleProperties(OpBuilder &builder, OperationState &result, ArrayRef< int64_t > shape, ScaleVecSize scaleVecSize, BlockScaleFormat blockScaleFormat, MMABlockScaleKind kind)
#define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf)
static llvm::Value * getParamCastedAddr(llvm::Value *addr, llvm::IRBuilderBase &builder)
static LogicalResult verifyTcgen05MMABlockScaleOp(NVVM::Tcgen05MMACollectorOp collectorOp, NVVM::MMABlockScaleKind kind, NVVM::Tcgen05MMABlockScale blockScale, Location loc)
static llvm::Value * packValInto64Bits(llvm::IRBuilderBase &builder, llvm::Value *result, llvm::Value *field, unsigned sizeInBits, unsigned start)
Packs the given field into the result.
static void printOperandList(OpAsmPrinter &p, StringRef name, ArrayRef< Value > operands)
#define GET_F32x2_TO_F6x2_ID(type, has_relu)
static llvm::Value * getAsPackedI32(llvm::Value *arg, llvm::IRBuilderBase &builder)
#define GET_F16x2_TO_F8X2_ID(type, has_relu)
static LogicalResult verifyMBarrierArriveLikeOp(Operation *op, Value addr, NVVM::MemScopeKind scope, Value retVal=nullptr)
static llvm::Value * castPtrToAddrSpace(llvm::IRBuilderBase &builder, llvm::Value *ptr, NVVMMemorySpace targetAS)
static LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD, NVVM::WGMMATypes typeA, NVVM::WGMMATypes typeB)
#define GET_BF16X2_TO_F8X2_ID(rnd, has_satf)
static void inferAndSetMultiplicandTypes(MLIRContext *ctx, NamedAttrList &attrs, const SmallVectorImpl< Type > &operandTypes)
static LogicalResult parseMmaOperand(OpAsmParser &parser, StringRef operandName, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &regs)
static std::pair< mlir::Type, unsigned > inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n, int k, MLIRContext *context)
static bool isInt8PtxType(MMATypes type)
#define TCGEN05LDRED(SHAPE, NUM, TYPE)
static bool isInt4PtxType(MMATypes type)
static bool isIntegerPtxType(MMATypes type)
#define GET_F32x2_TO_F8X2_S_ID(type, has_relu)
static MMATypes inferPtxTypeFromResult(OpTy op)
static LogicalResult verifyConstantRangeAttr(Operation *op, std::optional< LLVM::ConstantRangeAttr > rangeAttr)
Verify the range attribute satisfies LLVM ConstantRange constructor requirements for NVVM SpecialRang...
static LogicalResult parseMmaTypeSignature(OpAsmParser &parser, SmallVectorImpl< Type > &operandTypes)
static FailureOr< int > getAllowedSizeK(NVVM::WGMMATypes typeA)
static bool isPtrInSharedClusterSpace(mlir::Value ptr)
#define GET_CP_ASYNC_ID(mod, size, has_cpsize)
static unsigned isValidVectorLength(NVVM::Tcgen05LdStShape shape, unsigned vecLen)
#define GET_TCGEN05_COMMIT_ID(cta_group, is_shared, has_mc)
static LogicalResult verifyConvertF32x2ToFP16x2Op(Twine dstType, FPRoundingMode rnd, bool hasRandomBits, Operation *op)
static void nvvmInferResultRanges(Operation *op, Value result, ArrayRef<::mlir::ConstantIntRanges > argRanges, SetIntRangeFn setResultRanges)
Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might have ConstantRangeAttr.
static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims, bool isIm2Col, size_t numIm2ColOffsets, Location loc)
static bool isPtrInGenericSpace(mlir::Value ptr)
static void processOperandFragments(Op &op, std::array< MMAOperandFragment, 3 > &frags, SmallVectorImpl< Type > &regTypes, SmallVectorImpl< StringRef > &ignoreAttrNames)
static constexpr unsigned notIntrinsic
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
@ OptionalSquare
Square brackets supporting zero or more ops, or nothing.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
void printArrowTypeList(TypeRange &&types)
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerType getI16Type()
Definition Builders.cpp:65
UnitAttr getUnitAttr()
Definition Builders.cpp:102
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:167
IntegerType getI32Type()
Definition Builders.cpp:67
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
MLIRContext * getContext() const
Definition Builders.h:56
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition Builders.h:100
This class represents a diagnostic that is inflight and set to be reported.
static IntegerValueRange getMaxRange(Value value)
Create a maximal range ([0, uint_max(t)] / [int_min(t), int_max(t)]) range that is used to mark the v...
Implementation class for module translation.
llvm::Value * lookupValue(Value value) const
Finds an LLVM IR value corresponding to the given MLIR value.
void mapValue(Value mlir, llvm::Value *llvm)
Stores the mapping between an MLIR value and its LLVM IR counterpart.
llvm::LLVMContext & getLLVMContext() const
Returns the LLVM context in which the IR is being constructed.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
std::optional< NamedAttribute > getNamed(StringRef name) const
Return the specified named attribute if present, std::nullopt otherwise.
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
Attribute set(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
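A small sketch of querying and updating these attribute containers; the "alignment" attribute name and the helper are illustrative only.
// Read back a named attribute if present, then overwrite (or insert) it.
static void touchAlignmentAttr(Builder &b, NamedAttrList &attrs) {
  if (std::optional<NamedAttribute> named = attrs.getNamed("alignment")) {
    StringAttr name = named->getName();
    Attribute value = named->getValue();
    (void)name;
    (void)value;
  }
  attrs.set(b.getStringAttr("alignment"), b.getI32IntegerAttr(16));
}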
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
This class helps build Operations.
Definition Builders.h:209
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
AttrClass getAttrOfType(StringAttr name)
Definition Operation.h:550
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Definition Operation.h:560
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
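These accessors are what a generic, verifier-style check typically leans on. The sketch below is hypothetical; the "cache_hint" attribute name is an assumption used only for illustration.
// Accept ops without the attribute; otherwise require a 64-bit IntegerAttr.
static LogicalResult checkCacheHintAttr(Operation *op) {
  StringAttr name = StringAttr::get(op->getContext(), "cache_hint");
  if (!op->hasAttr(name))
    return success();
  if (auto intAttr = op->getAttrOfType<IntegerAttr>(name))
    if (intAttr.getType().isInteger(64))
      return success();
  return op->emitOpError("expected 'cache_hint' to be a 64-bit integer attribute");
}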
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isF64() const
Definition Types.cpp:41
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isF32() const
Definition Types.cpp:40
bool isInteger() const
Return true if this is an integer type (of any width).
Definition Types.cpp:58
bool isF16() const
Definition Types.cpp:38
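A minimal sketch of dispatching on these element-type predicates; the string tags returned here are made up for illustration and do not reflect any actual lowering rule in this dialect.
// Map a scalar MLIR type to an illustrative tag.
static StringRef getExampleTypeTag(Type type) {
  if (type.isF16())
    return "f16";
  if (type.isF32())
    return "f32";
  if (type.isF64())
    return "f64";
  if (type.isInteger())
    return "int";
  return "other";
}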
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
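WalkResult::advance and WalkResult::interrupt are the usual return values of a walk callback. The sketch below uses the Operation::walk overload that returns a WalkResult; the predicate is purely illustrative.
// Return true if any nested operation produces no results, stopping the walk
// at the first match.
static bool containsZeroResultOp(Operation *root) {
  WalkResult result = root->walk([](Operation *op) {
    if (op->getNumResults() == 0)
      return WalkResult::interrupt();
    return WalkResult::advance();
  });
  return result.wasInterrupted();
}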
bool isValidLoadStoreImpl(Type type, ptr::AtomicOrdering ordering, std::optional< int64_t > alignment, const ::mlir::DataLayout *dataLayout, function_ref< InFlightDiagnostic()> emitError)
Checks whether the given type is an LLVM type that can be loaded or stored.
Definition LLVMAttrs.cpp:60
SmallVector< int64_t, 4 > getCoordinates(ArrayRef< int64_t > basis, unsigned linearIndex)
@ Write
Write register with '=' modifier.
@ ReadWrite
ReadWrite register with '+' modifier.
@ Read
Read register with no modifier.
std::pair< mlir::Type, unsigned > inferMMAType(mlir::NVVM::MMATypes type, mlir::NVVM::MMAFrag frag, int nRow, int nCol, mlir::MLIRContext *context)
Return the element type and number of elements associated with a wmma matrix of given characteristics.
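A hedged example of calling inferMMAType; the particular fragment ("a"), element type (f16), and 16x16 tile chosen here are assumptions for illustration only.
// Query the per-thread register type and register count for one WMMA fragment.
static void queryExampleFragment(MLIRContext *context) {
  std::pair<Type, unsigned> typeAndCount =
      NVVM::inferMMAType(NVVM::MMATypes::f16, NVVM::MMAFrag::a,
                         /*nRow=*/16, /*nCol=*/16, context);
  Type regType = typeAndCount.first;
  unsigned numRegs = typeAndCount.second;
  (void)regType;
  (void)numRegs;
}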
std::pair< llvm::Intrinsic::ID, llvm::SmallVector< llvm::Value * > > IDArgPair
A pair type of LLVM's Intrinsic ID and args (which are llvm values).
Definition NVVMDialect.h:53
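IDArgPair is the return shape used by the intrinsic-selection helpers in this file: an intrinsic ID paired with the already-translated LLVM argument values. The sketch below is hypothetical; SomeOp, its getSrc() accessor, and the extra immediate argument are assumptions, and the intrinsic ID is passed in rather than hard-coded.
// Build an IDArgPair from an op's translated operands plus one immediate.
static NVVM::IDArgPair
getExampleIDAndArgs(SomeOp op, LLVM::ModuleTranslation &mt,
                    llvm::IRBuilderBase &builder, llvm::Intrinsic::ID id) {
  llvm::SmallVector<llvm::Value *> args;
  args.push_back(mt.lookupValue(op.getSrc()));
  args.push_back(builder.getInt32(0)); // illustrative immediate operand
  return {id, std::move(args)};
}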
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition Visitors.h:102
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:578
uint64_t getN(LevelType lt)
Definition Enums.h:442
uint64_t getM(LevelType lt)
Definition Enums.h:443
Include the generated interface declarations.
llvm::function_ref< void(Value, const ConstantIntRanges &)> SetIntRangeFn
The type of the setResultRanges callback provided to ops implementing InferIntRangeInterface.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:305
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:144
bool isCompatibleWith(const NVVMCheckSMVersion &targetSM) const
static const NVVMCheckSMVersion getTargetSMVersionFromStr(StringRef smVersionString)
This represents an operation in an abstracted form, suitable for use with the builder APIs.