MLIR 23.0.0git
NVVMDialect.cpp
Go to the documentation of this file.
1//===- NVVMDialect.cpp - NVVM IR Ops and Dialect registration -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the types and operation details for the NVVM IR dialect in
10// MLIR, and the LLVM IR dialect. It also registers the dialect.
11//
12// The NVVM dialect only contains GPU specific additions on top of the general
13// LLVM dialect.
14//
15//===----------------------------------------------------------------------===//
16
18
22#include "mlir/IR/Builders.h"
25#include "mlir/IR/Diagnostics.h"
27#include "mlir/IR/MLIRContext.h"
28#include "mlir/IR/Operation.h"
30#include "mlir/IR/Types.h"
31#include "llvm/ADT/STLExtras.h"
32#include "llvm/ADT/TypeSwitch.h"
33#include "llvm/IR/IRBuilder.h"
34#include "llvm/IR/NVVMIntrinsicUtils.h"
35#include "llvm/Support/Casting.h"
36#include "llvm/Support/FormatVariadic.h"
37#include "llvm/Support/NVPTXAddrSpace.h"
38#include "llvm/Support/raw_ostream.h"
39#include <cassert>
40#include <optional>
41#include <string>
42
43using namespace mlir;
44using namespace NVVM;
45
46#include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
47#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"
48
// Local shorthand for llvm::Intrinsic::not_intrinsic, used to mark ops that
// have no direct intrinsic mapping.
static constexpr unsigned notIntrinsic = llvm::Intrinsic::not_intrinsic;
50
51//===----------------------------------------------------------------------===//
52// Helper/Utility methods
53//===----------------------------------------------------------------------===//
54
55static bool isPtrInAddrSpace(mlir::Value ptr, NVVMMemorySpace targetAS) {
56 auto ptrTy = llvm::cast<LLVM::LLVMPointerType>(ptr.getType());
57 return ptrTy.getAddressSpace() == static_cast<unsigned>(targetAS);
58}
59
61 return isPtrInAddrSpace(ptr, NVVMMemorySpace::Generic);
62}
63
65 return isPtrInAddrSpace(ptr, NVVMMemorySpace::Shared);
66}
67
69 return isPtrInAddrSpace(ptr, NVVMMemorySpace::SharedCluster);
70}
71
72static llvm::Value *castPtrToAddrSpace(llvm::IRBuilderBase &builder,
73 llvm::Value *ptr,
74 NVVMMemorySpace targetAS) {
75 unsigned AS = static_cast<unsigned>(targetAS);
76 return builder.CreateAddrSpaceCast(
77 ptr, llvm::PointerType::get(builder.getContext(), AS));
78}
79
80// Helper method to convert CtaGroupKind in NVVM Dialect to CtaGroupKind in LLVM
81static llvm::nvvm::CTAGroupKind
82getNVVMCtaGroupKind(NVVM::CTAGroupKind ctaGroup) {
83 switch (ctaGroup) {
84 case NVVM::CTAGroupKind::CTA_1:
85 return llvm::nvvm::CTAGroupKind::CG_1;
86 case NVVM::CTAGroupKind::CTA_2:
87 return llvm::nvvm::CTAGroupKind::CG_2;
88 }
89 llvm_unreachable("unsupported cta_group value");
90}
91
92//===----------------------------------------------------------------------===//
93// Verifier methods
94//===----------------------------------------------------------------------===//
95
96// This verifier is shared among the following Ops:
97// CpAsyncBulkTensorSharedCTAToGlobalOp (TMA Store)
98// CpAsyncBulkTensorReduceOp (TMA Store-Reduce)
99static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims,
100 bool isIm2Col,
101 size_t numIm2ColOffsets,
102 Location loc) {
103 if (tensorDims < 1 || tensorDims > 5)
104 return emitError(loc, "expects coordinates between 1 to 5 dimension");
105
106 // For Im2Col mode, there are two constraints:
107 if (isIm2Col) {
108 // 1. Tensor must always be at least 3-d.
109 if (tensorDims < 3)
110 return emitError(
111 loc,
112 "to use im2col mode, the tensor has to be at least 3-dimensional");
113 // 2. When there are Im2ColOffsets, they must be (Dims - 2) in number.
114 if (numIm2ColOffsets && (tensorDims != (numIm2ColOffsets + 2)))
115 return emitError(
116 loc, "im2col offsets must be 2 less than number of coordinates");
117 }
118 return success();
119}
120
// Verifier for the TMA-store op: checks the feature subset of the inline-ptx
// lowering path, then the per-mode coordinate constraints.
LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
  TMAStoreMode mode = getMode();
  // We lower through inline-ptx when getPredicate() is true.
  // a) Only TILE mode is supported
  // b) Cache-hint is not supported
  if (getPredicate()) {
    if (mode != TMAStoreMode::TILE)
      return emitError("Inline-ptx lowering supported only for Tile mode.");
    if (getL2CacheHint())
      return emitError("Inline-ptx lowering unsupported with L2 cache-hint.");
  }

  size_t dims = getCoordinates().size();
  switch (mode) {
  case TMAStoreMode::TILE:
    return cpAsyncBulkTensorCommonVerifier(dims, false, 0, getLoc());
  case TMAStoreMode::IM2COL:
    return cpAsyncBulkTensorCommonVerifier(dims, true, 0, getLoc());
  case TMAStoreMode::TILE_SCATTER4:
    // Scatter4 requires exactly 5 coordinates.
    if (dims != 5)
      return emitError("Scatter4 mode expects 5 coordinates");
  }
  return success();
}
145
146LogicalResult CpAsyncOp::verify() {
147 if (getModifier() != LoadCacheModifierKind::CG &&
148 getModifier() != LoadCacheModifierKind::CA)
149 return emitError("Only CG and CA cache modifiers are supported.");
150 if (getSize() != 4 && getSize() != 8 && getSize() != 16)
151 return emitError("expected byte size to be either 4, 8 or 16.");
152 if (getModifier() == LoadCacheModifierKind::CG && getSize() != 16)
153 return emitError("CG cache modifier is only support for 16 bytes copy.");
154 return success();
155}
156
// This verify params can be shared across TMA Load and Prefetch Ops.
// Checks the coordinate count and each mode's expected im2col offset count.
static LogicalResult verifyTMALoadParams(size_t tensorDims, size_t numIm2colOff,
                                         TMALoadMode mode, Location loc) {
  if (tensorDims < 1 || tensorDims > 5)
    return emitError(loc, "expects coordinates between 1 to 5 dimension");

  // Common checks, parameterized by the offset count each mode expects.
  auto checkTMALoadParams = [&](TMALoadMode mode, bool isIm2col,
                                size_t expectedIm2colOff) -> LogicalResult {
    if (isIm2col && (tensorDims < 3))
      return emitError(loc)
             << "to use " << mode
             << " mode, the tensor has to be at least 3-dimensional";

    if (numIm2colOff != expectedIm2colOff)
      return emitError(loc) << " im2col offsets expected " << expectedIm2colOff
                            << " (provided " << numIm2colOff << ")";

    return success();
  };

  switch (mode) {
  case TMALoadMode::TILE:
    return checkTMALoadParams(mode, false, 0);
  case TMALoadMode::IM2COL:
    // Plain im2col carries (dims - 2) offsets.
    return checkTMALoadParams(mode, true, tensorDims - 2);
  case TMALoadMode::IM2COL_W:
  case TMALoadMode::IM2COL_W_128:
    return checkTMALoadParams(mode, true, 2);
  case TMALoadMode::TILE_GATHER4:
    // Gather4 always takes exactly 5 coordinates.
    return (tensorDims == 5)
               ? checkTMALoadParams(mode, false, 0)
               : emitError(loc, "Gather4 mode expects 5 coordinates");
  }
  return success();
}
192
193LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
194 return verifyTMALoadParams(getCoordinates().size(), getIm2colOffsets().size(),
195 getMode(), getLoc());
196}
197
// Verifier for the TMA-load op. The predicated form lowers through inline-asm
// and supports only a subset of the op's features; the intrinsic form checks
// the destination address space and the cta-only restrictions.
LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
  TMALoadMode mode = getMode();
  bool isCTAOnly = getIsCTAOnly();
  if (getPredicate()) { // Inline-asm based lowering
    if (isCTAOnly)
      return emitError("Predicate is supported only for shared::cluster mode.");
    if (mode != TMALoadMode::TILE && mode != TMALoadMode::IM2COL)
      return emitError(
          "Predicate is supported only for Tile and Im2col modes.");
  } else { // Intrinsics-based lowering
    // Destination must be shared (3) for cta-only, shared::cluster (7)
    // otherwise.
    NVVMMemorySpace expectedAS =
        isCTAOnly ? NVVMMemorySpace::Shared : NVVMMemorySpace::SharedCluster;
    unsigned AS = llvm::cast<LLVM::LLVMPointerType>(getDstMem().getType())
                      .getAddressSpace();
    if (AS != expectedAS)
      return emitError()
             << (isCTAOnly
                     ? "Shared::cta destination requires address-space 3."
                     : "Shared::cluster destination requires address-space 7.");
    // Checks specific to shared::cta mode
    if (isCTAOnly) {
      if (getMulticastMask())
        return emitError("Multicast is not supported with shared::cta mode.");
      if (getGroup())
        return emitError("CTAGroup is not supported with shared::cta mode.");
    }
  }

  return verifyTMALoadParams(getCoordinates().size(), getIm2colOffsets().size(),
                             getMode(), getLoc());
}
229
230LogicalResult CpAsyncBulkTensorReduceOp::verify() {
231 TMAStoreMode mode = getMode();
232 size_t dims = getCoordinates().size();
233 switch (mode) {
234 case TMAStoreMode::TILE:
235 return cpAsyncBulkTensorCommonVerifier(dims, false, 0, getLoc());
236 case TMAStoreMode::IM2COL:
237 return cpAsyncBulkTensorCommonVerifier(dims, true, 0, getLoc());
238 case TMAStoreMode::TILE_SCATTER4:
239 return emitError("Scatter mode unsupported for CpAsyncBulkTensorReduceOp");
240 }
241 return success();
242}
243
244LogicalResult CpAsyncBulkGlobalToSharedClusterOp::verify() {
245 bool isSharedCTA = isPtrInSharedCTASpace(getDstMem());
246 if (isSharedCTA && getMulticastMask())
247 return emitError("Multicast is not supported with shared::cta mode.");
248
249 return success();
250}
251
252static LogicalResult verifyMBarrierArriveLikeOp(Operation *op, Value addr,
253 NVVM::MemScopeKind scope,
254 Value retVal = nullptr) {
255 if (scope != NVVM::MemScopeKind::CTA && scope != NVVM::MemScopeKind::CLUSTER)
256 return op->emitError("mbarrier scope must be either CTA or Cluster");
257
258 bool isSharedCluster = isPtrInSharedClusterSpace(addr);
259 bool hasRetValue = static_cast<bool>(retVal);
260 if (isSharedCluster && hasRetValue)
261 return op->emitError(
262 "mbarrier in shared_cluster space cannot return any value");
263
264 return success();
265}
266
// Delegates to the shared mbarrier checks, including the optional result.
LogicalResult MBarrierArriveOp::verify() {
  return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope(),
                                    getRes());
}
271
// Delegates to the shared mbarrier checks, including the optional result.
LogicalResult MBarrierArriveDropOp::verify() {
  return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope(),
                                    getRes());
}
276
277LogicalResult MBarrierArriveExpectTxOp::verify() {
278 // The inline-ptx version of this Op does not support all features.
279 // With predicate, this Op lowers to inline-ptx. So, verify and
280 // error-out if there are unsupported features.
281 if (getPredicate()) {
282 if (getScope() != NVVM::MemScopeKind::CTA)
283 return emitError("mbarrier scope must be CTA when using predicate");
284
285 if (isPtrInSharedClusterSpace(getAddr()))
286 return emitError("mbarrier in shared_cluster space is not supported when "
287 "using predicate");
288
289 if (getRes())
290 return emitError("return-value is not supported when using predicate");
291
292 if (getRelaxed() == true)
293 return emitError("mbarrier with relaxed semantics is not supported when "
294 "using predicate");
295 }
296 return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope(),
297 getRes());
298}
299
// Delegates to the shared mbarrier checks, including the optional result.
LogicalResult MBarrierArriveDropExpectTxOp::verify() {
  return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope(),
                                    getRes());
}
304
// No result value; only the scope check applies.
LogicalResult MBarrierExpectTxOp::verify() {
  return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope());
}
308
// No result value; only the scope check applies.
LogicalResult MBarrierCompleteTxOp::verify() {
  return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope());
}
312
// No result value; only the scope check applies.
LogicalResult MBarrierTestWaitOp::verify() {
  return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope());
}
316
// No result value; only the scope check applies.
LogicalResult MBarrierTryWaitOp::verify() {
  return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope());
}
320
321LogicalResult ConvertFloatToTF32Op::verify() {
322 using RndMode = NVVM::FPRoundingMode;
323 switch (getRnd()) {
324 case RndMode::RNA:
325 if (getRelu())
326 return emitError("Relu not supported with rna rounding mode.");
327 break;
328 case RndMode::RN:
329 case RndMode::RZ:
330 break;
331 default:
332 return emitError(
333 "Only {rn,rz,rna} rounding modes supported for ConvertFloatToTF32Op.");
334 }
335 return success();
336}
337
338LogicalResult ConvertF32x2ToF6x2Op::verify() {
340
341 if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy())) {
342 return emitOpError("Only ")
343 << mlir::Float6E2M3FNType::get(ctx) << " and "
344 << mlir::Float6E3M2FNType::get(ctx)
345 << " types are supported for conversions from f32x2 to f6x2.";
346 }
347 return success();
348}
349
350LogicalResult ConvertF32x2ToF8x2Op::verify() {
351 using RndMode = NVVM::FPRoundingMode;
352 using SatMode = NVVM::SaturationMode;
353
354 bool isRoundingModeRN = getRnd() == RndMode::RN;
355 bool isRoundingModeRZ = getRnd() == RndMode::RZ;
356 bool isRoundingModeRP = getRnd() == RndMode::RP;
357 bool isSatFinite = getSat() == SatMode::SATFINITE;
358
359 bool hasRelu = getRelu();
360
362
364 .Case<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(
365 [&](mlir::Type) -> LogicalResult {
366 if (!isRoundingModeRN) {
367 return emitOpError("Only RN rounding mode is supported for "
368 "conversions from f32x2 to ")
369 << mlir::Float8E4M3FNType::get(ctx) << " and "
370 << mlir::Float8E5M2Type::get(ctx) << " types";
371 }
372 if (!isSatFinite) {
373 return emitOpError("Only SATFINITE saturation mode is supported "
374 "for conversions "
375 "from f32x2 to ")
376 << mlir::Float8E4M3FNType::get(ctx) << " and "
377 << mlir::Float8E5M2Type::get(ctx) << " types";
378 }
379 return success();
380 })
381 .Case<mlir::Float8E8M0FNUType>([&](mlir::Type) -> LogicalResult {
382 if (!(isRoundingModeRZ || isRoundingModeRP)) {
383 return emitOpError("Only RZ and RP rounding modes are supported for "
384 "conversions from f32x2 to ")
385 << mlir::Float8E8M0FNUType::get(ctx) << " type";
386 }
387 if (hasRelu) {
388 return emitOpError("relu not supported for conversions to ")
389 << mlir::Float8E8M0FNUType::get(ctx) << " type";
390 }
391 return success();
392 })
393 .Default([&](mlir::Type) {
394 return emitOpError("Only ")
395 << mlir::Float8E4M3FNType::get(ctx) << ", "
396 << mlir::Float8E5M2Type::get(ctx) << ", and "
397 << mlir::Float8E8M0FNUType::get(ctx)
398 << " types are "
399 "supported for conversions from f32x2 to f8x2";
400 });
401}
402
403LogicalResult ConvertF16x2ToF8x2Op::verify() {
405
406 if (!llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(getDstTy())) {
407 return emitOpError("Only ")
408 << mlir::Float8E4M3FNType::get(ctx) << " and "
409 << mlir::Float8E5M2Type::get(ctx)
410 << " types are supported for conversions from f16x2 to f8x2.";
411 }
412 return success();
413}
414
415LogicalResult ConvertBF16x2ToF8x2Op::verify() {
416 using RndMode = NVVM::FPRoundingMode;
417
418 if (!llvm::isa<mlir::Float8E8M0FNUType>(getDstTy()))
419 return emitOpError("Only ") << mlir::Float8E8M0FNUType::get(getContext())
420 << " type is supported for conversions from "
421 "bf16x2 to f8x2.";
422
423 auto rnd = getRnd();
424 if (rnd != RndMode::RZ && rnd != RndMode::RP)
425 return emitOpError("Only RZ and RP rounding modes are supported for "
426 "conversions from bf16x2 to f8x2.");
427
428 return success();
429}
430
431LogicalResult ConvertF32x2ToF4x2Op::verify() {
433
434 if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
435 return emitOpError("Only ")
436 << mlir::Float4E2M1FNType::get(ctx)
437 << " type is supported for conversions from f32x2 to f4x2.";
438
439 return success();
440}
441
442LogicalResult ConvertF8x2ToF16x2Op::verify() {
444
445 if (!llvm::isa<Float8E4M3FNType, Float8E5M2Type>(getSrcType()))
446 return emitOpError("Only ")
447 << mlir::Float8E4M3FNType::get(ctx) << " and "
448 << mlir::Float8E5M2Type::get(ctx)
449 << " types are supported for conversions from f8x2 to f16x2.";
450
451 return success();
452}
453
454LogicalResult ConvertF8x2ToBF16x2Op::verify() {
456 if (!llvm::isa<Float8E8M0FNUType>(getSrcType()))
457 return emitOpError("Only ")
458 << mlir::Float8E8M0FNUType::get(ctx)
459 << " type is supported for conversions from f8x2 to bf16x2.";
460
461 return success();
462}
463
464LogicalResult ConvertF6x2ToF16x2Op::verify() {
466
467 if (!llvm::isa<Float6E2M3FNType, Float6E3M2FNType>(getSrcType()))
468 return emitOpError("Only ")
469 << mlir::Float6E2M3FNType::get(ctx) << " and "
470 << mlir::Float6E3M2FNType::get(ctx)
471 << " types are supported for conversions from f6x2 to f16x2.";
472
473 return success();
474}
475
476LogicalResult ConvertF4x2ToF16x2Op::verify() {
478
479 if (!llvm::isa<Float4E2M1FNType>(getSrcType()))
480 return emitOpError("Only ")
481 << mlir::Float4E2M1FNType::get(ctx)
482 << " type is supported for conversions from f4x2 to f16x2.";
483
484 return success();
485}
486
487LogicalResult PermuteOp::verify() {
488 using Mode = NVVM::PermuteMode;
489 bool hasHi = static_cast<bool>(getHi());
490
491 switch (getMode()) {
492 case Mode::DEFAULT:
493 case Mode::F4E:
494 case Mode::B4E:
495 if (!hasHi)
496 return emitError("mode '") << getMode() << "' requires 'hi' operand.";
497 break;
498 case Mode::RC8:
499 case Mode::ECL:
500 case Mode::ECR:
501 case Mode::RC16:
502 if (hasHi)
503 return emitError("mode '")
504 << getMode() << "' does not accept 'hi' operand.";
505 break;
506 }
507
508 return success();
509}
510
511//===----------------------------------------------------------------------===//
512// Stochastic Rounding Conversion Ops
513//===----------------------------------------------------------------------===//
514
515static LogicalResult verifyConvertF32x2ToFP16x2Op(Twine dstType,
516 FPRoundingMode rnd,
517 bool hasRandomBits,
518 Operation *op) {
519 static constexpr FPRoundingMode validRndModes[] = {
520 FPRoundingMode::RN, FPRoundingMode::RZ, FPRoundingMode::RS};
521
522 if (!llvm::is_contained(validRndModes, rnd)) {
523 return op->emitOpError(
524 "Only RN, RZ, and RS rounding modes are supported for "
525 "conversions from f32x2 to ")
526 << dstType << ".";
527 }
528
529 if (rnd == FPRoundingMode::RS) {
530 if (!hasRandomBits) {
531 return op->emitOpError("random_bits is required for RS rounding mode.");
532 }
533 } else {
534 if (hasRandomBits) {
535 return op->emitOpError(
536 "random_bits not supported for RN and RZ rounding modes.");
537 }
538 }
539
540 return success();
541}
542
543LogicalResult ConvertF32x2ToF16x2Op::verify() {
544 return verifyConvertF32x2ToFP16x2Op("f16x2", getRnd(),
545 getRandomBits() ? true : false, *this);
546}
547
548LogicalResult ConvertF32x2ToBF16x2Op::verify() {
549 return verifyConvertF32x2ToFP16x2Op("bf16x2", getRnd(),
550 getRandomBits() ? true : false, *this);
551}
552
553LogicalResult ConvertF32x4ToF8x4Op::verify() {
555
556 if (!llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(getDstTy()))
557 return emitOpError("Only ")
558 << mlir::Float8E4M3FNType::get(ctx) << " and "
559 << mlir::Float8E5M2Type::get(ctx)
560 << " types are supported for conversions from f32x4 to f8x4.";
561
562 return success();
563}
564
565LogicalResult ConvertF32x4ToF6x4Op::verify() {
567
568 if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy()))
569 return emitOpError("Only ")
570 << mlir::Float6E2M3FNType::get(ctx) << " and "
571 << mlir::Float6E3M2FNType::get(ctx)
572 << " types are supported for conversions from f32x4 to f6x4.";
573
574 return success();
575}
576
577LogicalResult ConvertF32x4ToF4x4Op::verify() {
579
580 if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
581 return emitOpError("Only ") << mlir::Float4E2M1FNType::get(ctx)
582 << " type is supported for conversions from "
583 "f32x4 to f4x4.";
584
585 return success();
586}
587
588LogicalResult BulkStoreOp::verify() {
589 if (getInitVal() != 0)
590 return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
591 return success();
592}
593
594LogicalResult PMEventOp::verify() {
595 auto eventId = getEventId();
596 auto maskedEventId = getMaskedEventId();
597 if (!maskedEventId && !eventId) {
598 return emitOpError() << "either `id` or `mask` must be set";
599 }
600
601 if (maskedEventId && eventId) {
602 return emitOpError() << "`id` and `mask` cannot be set at the same time";
603 }
604
605 if (eventId) {
606 if (eventId < 0 || eventId > 15) {
607 return emitOpError() << "`id` must be between 0 and 15";
608 }
609 }
610
611 return llvm::success();
612}
613
// Given the element type of an operand and whether or not it is an accumulator,
// this function returns the PTX type (`NVVM::MMATypes`) that corresponds to the
// operand's element type. Returns std::nullopt when the PTX type cannot be
// deduced from the register type alone.
std::optional<mlir::NVVM::MMATypes>
MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) {
  auto half2Type =
      VectorType::get(2, Float16Type::get(operandElType.getContext()));
  if (operandElType.isF64())
    return NVVM::MMATypes::f64;
  if (operandElType.isF16() || operandElType == half2Type)
    return NVVM::MMATypes::f16;
  // f32 registers mean a true f32 accumulator but tf32 multiplicands.
  if (operandElType.isF32() && isAccumulator)
    return NVVM::MMATypes::f32;
  if (operandElType.isF32() && !isAccumulator)
    return NVVM::MMATypes::tf32;
  // Integer registers only determine the s32 accumulator type; multiplicand
  // integer widths are ambiguous, so they cannot be inferred here.
  if (llvm::isa<IntegerType>(operandElType)) {
    if (isAccumulator)
      return NVVM::MMATypes::s32;
    return std::nullopt;
  }

  // For struct-typed operands, recurse on the first member's type.
  if (auto structType = llvm::dyn_cast<LLVM::LLVMStructType>(operandElType)) {
    if (structType.getBody().empty())
      return std::nullopt;
    return inferOperandMMAType(structType.getBody()[0], isAccumulator);
  }

  return std::nullopt;
}
643
644static bool isInt4PtxType(MMATypes type) {
645 return (type == MMATypes::u4 || type == MMATypes::s4);
646}
647
648static bool isInt8PtxType(MMATypes type) {
649 return (type == MMATypes::u8 || type == MMATypes::s8);
650}
651
652static bool isIntegerPtxType(MMATypes type) {
653 return isInt4PtxType(type) || isInt8PtxType(type) || type == MMATypes::b1 ||
654 type == MMATypes::s32;
655}
656
// PTX type of the accumulator (operand segment 2), inferred from the first
// register's element type.
MMATypes MmaOp::accumPtxType() {
  std::optional<mlir::NVVM::MMATypes> val = inferOperandMMAType(
      getODSOperands(2).getTypes().front(), /*isAccumulator=*/true);
  assert(val.has_value() && "accumulator PTX type should always be inferrable");
  return val.value();
}
663
// PTX type of the result, inferred from the result type (treated as an
// accumulator, matching accumPtxType()).
MMATypes MmaOp::resultPtxType() {
  std::optional<mlir::NVVM::MMATypes> val =
      inferOperandMMAType(getResult().getType(), /*isAccumulator=*/true);
  assert(val.has_value() && "result PTX type should always be inferrable");
  return val.value();
}
670
// Custom printer: emits `A[...] B[...] C[...] attr-dict : (tyA, tyB, tyC) ->
// tyRes`, suppressing ptx-type attributes that the parser can re-infer.
void MmaOp::print(OpAsmPrinter &p) {
  SmallVector<Type, 4> regTypes;
  // Bookkeeping for one operand segment (A, B, or C).
  struct MMAOperandFragment {
    StringRef operandName;
    StringRef ptxTypeAttr;
    SmallVector<Value, 4> regs;
    explicit MMAOperandFragment(StringRef name, StringRef ptxTypeName)
        : operandName(name), ptxTypeAttr(ptxTypeName) {}
  };

  std::array<MMAOperandFragment, 3> frags{
      MMAOperandFragment("A", getMultiplicandAPtxTypeAttrName()),
      MMAOperandFragment("B", getMultiplicandBPtxTypeAttrName()),
      MMAOperandFragment("C", "")};
  // Attributes printed through dedicated syntax rather than the attr-dict.
  SmallVector<StringRef, 4> ignoreAttrNames{
      mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()};

  // Collect each segment's registers and decide which ptx-type attributes are
  // redundant (inferrable from the register types).
  for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
    auto &frag = frags[fragIdx];
    auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
    for (auto operandIdx = varOperandSpec.first;
         operandIdx < varOperandSpec.first + varOperandSpec.second;
         operandIdx++) {
      frag.regs.push_back(this->getOperand(operandIdx));
      if (operandIdx == 0) {
        regTypes.push_back(this->getOperand(operandIdx).getType());
      }
    }
    std::optional<MMATypes> inferredType = MmaOp::inferOperandMMAType(
        regTypes.back(), /*isAccumulator=*/fragIdx >= 2);
    if (inferredType)
      ignoreAttrNames.push_back(frag.ptxTypeAttr);
  }

  auto printMmaOperand = [&](const MMAOperandFragment &frag) -> void {
    p << " " << frag.operandName;
    p << "[";
    p.printOperands(frag.regs);
    p << "] ";
  };

  for (const auto &frag : frags) {
    printMmaOperand(frag);
  }

  p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);

  // Print the types of the operands and result.
  p << " : "
    << "(";
  llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
                                             frags[1].regs[0].getType(),
                                             frags[2].regs[0].getType()},
                        p);
  p << ")";
  p.printArrowTypeList(TypeRange{this->getRes().getType()});
}
728
/// Builder for MmaOp: adds the A/B/C operands, the mandatory `shape`
/// attribute, and the optional ptx-type / layout / overflow / b1 attributes.
/// When the multiplicand PTX types are not given they are inferred from the
/// first A/B operand; layouts default to row (A) and col (B).
void MmaOp::build(OpBuilder &builder, OperationState &result, Type resultType,
                  ValueRange operandA, ValueRange operandB, ValueRange operandC,
                  ArrayRef<int64_t> shape, std::optional<MMAB1Op> b1Op,
                  std::optional<MMAIntOverflow> intOverflow,
                  std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
                  std::optional<std::array<MMALayout, 2>> multiplicandLayouts) {

  assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
  MLIRContext *ctx = builder.getContext();
  result.addAttribute(
      "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));

  result.addOperands(operandA);
  result.addOperands(operandB);
  result.addOperands(operandC);

  // Explicit ptx types win; otherwise infer from the first register of each
  // multiplicand segment (inference can fail, in which case no attribute is
  // added).
  if (multiplicandPtxTypes) {
    result.addAttribute("multiplicandAPtxType",
                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
    result.addAttribute("multiplicandBPtxType",
                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
  } else {
    if (auto res = inferOperandMMAType(operandA[0].getType(), false))
      result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
    if (auto res = inferOperandMMAType(operandB[0].getType(), false))
      result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
  }

  if (multiplicandLayouts) {
    result.addAttribute("layoutA",
                        MMALayoutAttr::get(ctx, (*multiplicandLayouts)[0]));
    result.addAttribute("layoutB",
                        MMALayoutAttr::get(ctx, (*multiplicandLayouts)[1]));
  } else {
    result.addAttribute("layoutA", MMALayoutAttr::get(ctx, MMALayout::row));
    result.addAttribute("layoutB", MMALayoutAttr::get(ctx, MMALayout::col));
  }

  if (intOverflow.has_value())
    result.addAttribute("intOverflowBehavior",
                        MMAIntOverflowAttr::get(ctx, *intOverflow));
  if (b1Op.has_value())
    result.addAttribute("b1Op", MMAB1OpAttr::get(ctx, *b1Op));

  result.addTypes(resultType);
  // Record the variadic segment sizes so the op can recover A/B/C boundaries.
  result.addAttribute(
      MmaOp::getOperandSegmentSizeAttr(),
      builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()),
                                    static_cast<int32_t>(operandB.size()),
                                    static_cast<int32_t>(operandC.size())}));
}
780
781// <operation> :=
782// A `[` $operandA `]` B `[` $operandB `]` C `[` $operandC `]`
783// attr-dict : (type($operandA[0]), type($operandB[0]), type($operandC[0]))
784// `->` type($res)
785ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) {
786 struct MMAOperandFragment {
787 std::optional<MMATypes> elemtype;
788 SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
789 SmallVector<Type> regTypes;
790 };
791
792 Builder &builder = parser.getBuilder();
793 std::array<MMAOperandFragment, 4> frags;
794
795 NamedAttrList namedAttributes;
796
797 // A helper to parse the operand segments.
798 auto parseMmaOperand = [&](StringRef operandName,
799 MMAOperandFragment &frag) -> LogicalResult {
800 if (parser.parseKeyword(operandName).failed())
801 return failure();
802 if (parser
803 .parseOperandList(frag.regs, OpAsmParser::Delimiter::OptionalSquare)
804 .failed())
805 return failure();
806 return success();
807 };
808
809 // Parse the operand segments.
810 if (parseMmaOperand("A", frags[0]).failed())
811 return failure();
812 if (parseMmaOperand("B", frags[1]).failed())
813 return failure();
814 if (parseMmaOperand("C", frags[2]).failed())
815 return failure();
816
817 if (parser.parseOptionalAttrDict(namedAttributes).failed())
818 return failure();
819
820 // Parse the type specification and resolve operands.
821 SmallVector<Type, 3> operandTypes;
822 if (failed(parser.parseColon()))
823 return failure();
824 if (failed(parser.parseLParen()))
825 return failure();
826 if (failed(parser.parseTypeList(operandTypes)))
827 return failure();
828 if (failed(parser.parseRParen()))
829 if (operandTypes.size() != 3)
830 return parser.emitError(
831 parser.getNameLoc(),
832 "expected one type for each operand segment but got " +
833 Twine(operandTypes.size()) + " types");
834 for (const auto &iter : llvm::enumerate(operandTypes)) {
835 auto &frag = frags[iter.index()];
836 frag.regTypes.resize(frag.regs.size(), iter.value());
837 if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
838 parser.getNameLoc(), result.operands)))
839 return failure();
840 frag.elemtype = inferOperandMMAType(frag.regTypes[0],
841 /*isAccumulator*/ iter.index() < 2);
842 }
843
844 Type resultType;
845 if (parser.parseArrow() || parser.parseType(resultType))
846 return failure();
847 frags[3].elemtype = inferOperandMMAType(resultType, /*isAccumulator*/ true);
848
849 std::array<StringRef, 2> names{"multiplicandAPtxType",
850 "multiplicandBPtxType"};
851 for (unsigned idx = 0; idx < names.size(); idx++) {
852 const auto &frag = frags[idx];
853 std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
854 if (!frag.elemtype.has_value() && !attr.has_value()) {
855 return parser.emitError(
856 parser.getNameLoc(),
857 "attribute " + names[idx] +
858 " is not provided explicitly and cannot be inferred");
859 }
860 if (!attr.has_value())
861 result.addAttribute(
862 names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));
863 }
864
865 result.addTypes(resultType);
866 if (!namedAttributes.empty())
867 result.addAttributes(namedAttributes);
868 result.addAttribute(MmaOp::getOperandSegmentSizeAttr(),
869 builder.getDenseI32ArrayAttr({
870 static_cast<int32_t>(frags[0].regs.size()),
871 static_cast<int32_t>(frags[1].regs.size()),
872 static_cast<int32_t>(frags[2].regs.size()),
873 }));
874 return success();
875}
876
877LogicalResult MmaOp::verify() {
878 MLIRContext *context = getContext();
879 auto f16Ty = Float16Type::get(context);
880 auto i32Ty = IntegerType::get(context, 32);
881 auto f16x2Ty = VectorType::get(2, f16Ty);
882 auto f32Ty = Float32Type::get(context);
883 auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
884 context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
885
886 auto s32x4StructTy =
887 LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
888 auto f32x8StructTy =
889 LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
890 auto f16x2x2StructTy =
891 LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
892 auto f32x4StructTy =
893 LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
894 auto s32x2StructTy =
895 LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});
896
897 std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
898 getShapeAttr().getK()};
899
900 // These variables define the set of allowed data types for matrices A, B, C,
901 // and result.
902 using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
903 using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>;
904 AllowedShapes allowedShapes;
905 AllowedTypes expectedA;
906 AllowedTypes expectedB;
907 AllowedTypes expectedC;
908 SmallVector<Type> expectedResult;
909
910 // When M = 16, we just need to calculate the number of 8xk tiles, where
911 // k is a factor that depends on the data type.
912 if (mmaShape[0] == 16) {
913 int64_t kFactor;
914 Type multiplicandFragType;
915 switch (*getMultiplicandAPtxType()) {
916 case MMATypes::tf32:
917 kFactor = 4;
918 multiplicandFragType = i32Ty;
919 expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
920 context, {f32Ty, f32Ty, f32Ty, f32Ty}));
921 break;
922 case MMATypes::bf16:
923 kFactor = 8;
924 multiplicandFragType = i32Ty;
925 expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
926 context, {f32Ty, f32Ty, f32Ty, f32Ty}));
927 break;
928 case MMATypes::f16:
929 kFactor = 8;
930 multiplicandFragType = f16x2Ty;
931 expectedResult.push_back(f16x2x2StructTy);
932 expectedResult.push_back(f32x4StructTy);
933 break;
934 case MMATypes::s4:
935 case MMATypes::u4:
936 kFactor = 32;
937 break;
938 case MMATypes::b1:
939 kFactor = 128;
940 break;
941 case MMATypes::s8:
942 case MMATypes::u8:
943 kFactor = 16;
944 break;
945 default:
946 return emitError("invalid shape or multiplicand type: ")
947 << getMultiplicandAPtxType().value();
948 }
949
950 if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
951 expectedResult.push_back(s32x4StructTy);
952 expectedC.emplace_back(4, i32Ty);
953 multiplicandFragType = i32Ty;
954 } else {
955 expectedC.emplace_back(2, f16x2Ty);
956 expectedC.emplace_back(4, f32Ty);
957 }
958
959 int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
960 int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
961 expectedA.emplace_back(unitA, multiplicandFragType);
962 expectedB.emplace_back(unitB, multiplicandFragType);
963 allowedShapes.push_back({16, 8, kFactor});
964 allowedShapes.push_back({16, 8, kFactor * 2});
965
966 if (resultPtxType() != accumPtxType())
967 return emitOpError("ctype does not match dtype");
968 }
969
970 // In the M=8 case, there is only 1 possible case per data type.
971 if (mmaShape[0] == 8) {
972 if (*getMultiplicandAPtxType() == MMATypes::f16) {
973 expectedA.emplace_back(2, f16x2Ty);
974 expectedB.emplace_back(2, f16x2Ty);
975 expectedResult.push_back(f16x2x4StructTy);
976 expectedResult.push_back(f32x8StructTy);
977 expectedC.emplace_back(4, f16x2Ty);
978 expectedC.emplace_back(8, f32Ty);
979 allowedShapes.push_back({8, 8, 4});
980 }
981 if (*getMultiplicandAPtxType() == MMATypes::f64) {
982 Type f64Ty = Float64Type::get(context);
983 expectedA.emplace_back(1, f64Ty);
984 expectedB.emplace_back(1, f64Ty);
985 expectedC.emplace_back(2, f64Ty);
986 expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
987 context, SmallVector<Type>(2, f64Ty)));
988 allowedShapes.push_back({8, 8, 4});
989 }
990 if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
991 expectedA.push_back({i32Ty});
992 expectedB.push_back({i32Ty});
993 expectedC.push_back({i32Ty, i32Ty});
994 expectedResult.push_back(s32x2StructTy);
995 if (isInt4PtxType(getMultiplicandAPtxType().value()))
996 allowedShapes.push_back({8, 8, 32});
997 if (isInt8PtxType(getMultiplicandAPtxType().value()))
998 allowedShapes.push_back({8, 8, 16});
999 if (getMultiplicandAPtxType().value() == MMATypes::b1)
1000 allowedShapes.push_back({8, 8, 128});
1001 }
1002 }
1003
1004 std::string errorMessage;
1005 llvm::raw_string_ostream errorStream(errorMessage);
1006
1007 // Check that we matched an existing shape/dtype combination.
1008 if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
1009 !llvm::is_contained(allowedShapes, mmaShape)) {
1010 errorStream << "unimplemented variant for MMA shape <";
1011 llvm::interleaveComma(mmaShape, errorStream);
1012 errorStream << ">";
1013 return emitOpError(errorMessage);
1014 }
1015
1016 // Verify the operand types for segments of A, B, and C operands.
1017 std::array<StringRef, 3> operandNames{"A", "B", "C"};
1018 for (const auto &iter : llvm::enumerate(
1019 SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) {
1020 auto spec = this->getODSOperandIndexAndLength(iter.index());
1021 SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first,
1022 operand_type_begin() + spec.first +
1023 spec.second);
1024 bool match = llvm::is_contained(iter.value(), operandTySeg);
1025
1026 if (!match) {
1027 errorStream << "Could not match types for the "
1028 << operandNames[iter.index()]
1029 << " operands; expected one of ";
1030 for (const auto &x : iter.value()) {
1031 errorStream << x.size() << "x" << x[0] << " ";
1032 }
1033 errorStream << "but got ";
1034 llvm::interleaveComma(operandTySeg, errorStream);
1035 return emitOpError(errorMessage);
1036 }
1037 }
1038
1039 // Check the result type
1040 if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
1041 return expectedResultType == getResult().getType();
1042 })) {
1043 errorStream
1044 << "Could not match allowed types for the result; expected one of ";
1045 llvm::interleaveComma(expectedResult, errorStream);
1046 errorStream << " but got " << getResult().getType();
1047 return emitOpError(errorMessage);
1048 }
1049
1050 // Ensure that binary MMA variants have a b1 MMA operation defined.
1051 if (getMultiplicandAPtxType() == MMATypes::b1 && !getB1Op()) {
1052 return emitOpError("op requires " + getB1OpAttrName().strref() +
1053 " attribute");
1054 }
1055
1056 // Ensure int4/int8 MMA variants specify the accum overflow behavior
1057 // attribute.
1058 if (isInt4PtxType(*getMultiplicandAPtxType()) ||
1059 isInt8PtxType(*getMultiplicandAPtxType())) {
1060 if (!getIntOverflowBehavior())
1061 return emitOpError("op requires " +
1062 getIntOverflowBehaviorAttrName().strref() +
1063 " attribute");
1064 }
1065
1066 // Validate layout combinations. According to the operation description, most
1067 // MMA operations require layoutA=row and layoutB=col. Only m8n8k4 with f16
1068 // can use other layout combinations.
1069 bool isM8N8K4_F16 =
1070 (mmaShape[0] == 8 && mmaShape[1] == 8 && mmaShape[2] == 4 &&
1071 getMultiplicandAPtxType() == MMATypes::f16);
1072
1073 if (!isM8N8K4_F16) {
1074 // For all other shapes/types, layoutA must be row and layoutB must be col
1075 if (getLayoutA() != MMALayout::row || getLayoutB() != MMALayout::col) {
1076 return emitOpError("requires layoutA = #nvvm.mma_layout<row> and "
1077 "layoutB = #nvvm.mma_layout<col> for shape <")
1078 << mmaShape[0] << ", " << mmaShape[1] << ", " << mmaShape[2]
1079 << "> with element types " << *getMultiplicandAPtxType() << " and "
1080 << *getMultiplicandBPtxType()
1081 << ". Only m8n8k4 with f16 supports other layouts.";
1082 }
1083 }
1084
1085 return success();
1086}
1087
1088MMATypes MmaSpOp::accumPtxType() {
1089 std::optional<mlir::NVVM::MMATypes> val = MmaOp::inferOperandMMAType(
1090 getODSOperands(2).getTypes().front(), /*isAccumulator=*/true);
1091 assert(val.has_value() && "accumulator PTX type should always be inferrable");
1092 return val.value();
1093}
1094
1095MMATypes MmaSpOp::resultPtxType() {
1096 std::optional<mlir::NVVM::MMATypes> val =
1097 MmaOp::inferOperandMMAType(getResult().getType(), /*isAccumulator=*/true);
1098 assert(val.has_value() && "result PTX type should always be inferrable");
1099 return val.value();
1100}
1101
1103MmaSpOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
1104 llvm::IRBuilderBase &builder) {
1105 auto thisOp = cast<NVVM::MmaSpOp>(op);
1106
1107 // Get operands
1109 for (mlir::Value v : thisOp.getOperands())
1110 args.push_back(mt.lookupValue(v));
1111
1112 // Get intrinsic ID using the existing getIntrinsicID method
1113 auto intId = MmaSpOp::getIntrinsicID(
1114 thisOp.getShape().getM(), thisOp.getShape().getN(),
1115 thisOp.getShape().getK(), thisOp.getIntOverflowBehavior(),
1116 thisOp.getOrderedMetadata(), thisOp.getKind(),
1117 *thisOp.getMultiplicandAPtxType(), *thisOp.getMultiplicandBPtxType(),
1118 thisOp.accumPtxType(), thisOp.resultPtxType());
1119
1120 return {intId, args};
1121}
1122
1123void MmaSpOp::print(OpAsmPrinter &p) {
1124 SmallVector<Type, 4> regTypes;
1125 struct MMAOperandFragment {
1126 StringRef operandName;
1127 StringRef ptxTypeAttr;
1128 SmallVector<Value, 4> regs;
1129 explicit MMAOperandFragment(StringRef name, StringRef ptxTypeName)
1130 : operandName(name), ptxTypeAttr(ptxTypeName) {}
1131 };
1132
1133 std::array<MMAOperandFragment, 5> frags{
1134 MMAOperandFragment("A", getMultiplicandAPtxTypeAttrName()),
1135 MMAOperandFragment("B", getMultiplicandBPtxTypeAttrName()),
1136 MMAOperandFragment("C", ""), MMAOperandFragment("sparseMetadata", ""),
1137 MMAOperandFragment("selector", "")};
1138 SmallVector<StringRef, 4> ignoreAttrNames{
1139 mlir::NVVM::MmaSpOp::getOperandSegmentSizeAttr()};
1140
1141 // Handle variadic operands A, B, C
1142 for (unsigned fragIdx = 0; fragIdx < 3; fragIdx++) {
1143 auto &frag = frags[fragIdx];
1144 auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
1145 for (auto operandIdx = varOperandSpec.first;
1146 operandIdx < varOperandSpec.first + varOperandSpec.second;
1147 operandIdx++) {
1148 frag.regs.push_back(this->getOperand(operandIdx));
1149 if (operandIdx == varOperandSpec.first) {
1150 regTypes.push_back(this->getOperand(operandIdx).getType());
1151 }
1152 }
1153 std::optional<MMATypes> inferredType = MmaOp::inferOperandMMAType(
1154 regTypes.back(), /*isAccumulator=*/fragIdx >= 2);
1155 if (inferredType)
1156 ignoreAttrNames.push_back(frag.ptxTypeAttr);
1157 }
1158
1159 // Handle sparse metadata and selector (single operands)
1160 frags[3].regs.push_back(getSparseMetadata());
1161 frags[4].regs.push_back(getSparsitySelector());
1162
1163 auto printMmaSpOperand = [&](const MMAOperandFragment &frag) -> void {
1164 p << " " << frag.operandName;
1165 p << "[";
1166 p.printOperands(frag.regs);
1167 p << "]";
1168 };
1169
1170 for (const auto &frag : frags)
1171 printMmaSpOperand(frag);
1172
1173 p.printOptionalAttrDict((*this)->getAttrs(), ignoreAttrNames);
1174 p << " : ";
1175 p << "(";
1176 for (int i = 0; i < 3; ++i) {
1177 p << regTypes[i];
1178 if (i < 2)
1179 p << ", ";
1180 }
1181 p << ") -> " << getResult().getType();
1182}
1183
1184void MmaSpOp::build(
1185 OpBuilder &builder, OperationState &result, Type resultType,
1186 ValueRange operandA, ValueRange operandB, ValueRange operandC,
1187 Value sparseMetadata, Value sparsitySelector, ArrayRef<int64_t> shape,
1188 std::optional<MMAIntOverflow> intOverflow,
1189 std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes) {
1190
1191 assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
1192 MLIRContext *ctx = builder.getContext();
1193 result.addAttribute(
1194 "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));
1195
1196 result.addOperands(operandA);
1197 result.addOperands(operandB);
1198 result.addOperands(operandC);
1199 result.addOperands(sparseMetadata);
1200 result.addOperands(sparsitySelector);
1201
1202 if (multiplicandPtxTypes) {
1203 result.addAttribute("multiplicandAPtxType",
1204 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
1205 result.addAttribute("multiplicandBPtxType",
1206 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
1207 } else {
1208 if (auto res = MmaOp::inferOperandMMAType(operandA[0].getType(), false))
1209 result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
1210 if (auto res = MmaOp::inferOperandMMAType(operandB[0].getType(), false))
1211 result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
1212 }
1213
1214 if (intOverflow.has_value())
1215 result.addAttribute("intOverflowBehavior",
1216 MMAIntOverflowAttr::get(ctx, *intOverflow));
1217
1218 result.addTypes(resultType);
1219 result.addAttribute(
1220 MmaSpOp::getOperandSegmentSizeAttr(),
1221 builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()),
1222 static_cast<int32_t>(operandB.size()),
1223 static_cast<int32_t>(operandC.size()), 1,
1224 1})); // sparseMetadata and sparsitySelector
1225}
1226
/// Custom parser for nvvm.mma.sp.
///
/// Expected syntax:
///   A[...] B[...] C[...] sparseMetadata[...] selector[...]
///     {attr-dict} : (aRegTy, bRegTy, cRegTy) -> resTy
ParseResult MmaSpOp::parse(OpAsmParser &parser, OperationState &result) {
  // Per-segment parse state: the unresolved register operands, the types
  // they resolve to, and the PTX element type inferred from those types.
  struct MMAOperandFragment {
    std::optional<MMATypes> elemtype;
    SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
    SmallVector<Type> regTypes;
  };

  Builder &builder = parser.getBuilder();
  // Fragments 0-4 hold A, B, C, sparseMetadata, selector; fragment 5 is
  // used only to record the element type inferred from the result type.
  std::array<MMAOperandFragment, 6> frags; // A, B, C, sparseMetadata, selector

  NamedAttrList namedAttributes;

  // A helper to parse the operand segments.
  auto parseMmaSpOperand = [&](StringRef operandName,
                               MMAOperandFragment &frag) -> LogicalResult {
    if (parser.parseKeyword(operandName).failed())
      return failure();
    if (parser
            .parseOperandList(frag.regs, OpAsmParser::Delimiter::OptionalSquare)
            .failed())
      return failure();
    return success();
  };

  // Parse the operand segments.
  if (parseMmaSpOperand("A", frags[0]).failed())
    return failure();
  if (parseMmaSpOperand("B", frags[1]).failed())
    return failure();
  if (parseMmaSpOperand("C", frags[2]).failed())
    return failure();
  if (parseMmaSpOperand("sparseMetadata", frags[3]).failed())
    return failure();
  if (parseMmaSpOperand("selector", frags[4]).failed())
    return failure();

  if (parser.parseOptionalAttrDict(namedAttributes).failed())
    return failure();

  // Parse the type specification and resolve operands.
  // The signature lists exactly one type per register segment: (A, B, C).
  SmallVector<Type, 3> operandTypes;
  if (failed(parser.parseColon()))
    return failure();
  if (failed(parser.parseLParen()))
    return failure();
  if (failed(parser.parseTypeList(operandTypes)))
    return failure();
  if (failed(parser.parseRParen()))
    return failure();
  if (operandTypes.size() != 3)
    return parser.emitError(
        parser.getNameLoc(),
        "expected one type for each operand segment but got " +
            Twine(operandTypes.size()) + " types");
  // Resolve the A, B, C segments: every register in a segment has the
  // segment's single listed type. Operands are appended to result.operands
  // in segment order (A, B, C), matching the segment-size attribute below.
  for (const auto &iter : llvm::enumerate(operandTypes)) {
    auto &frag = frags[iter.index()];
    frag.regTypes.resize(frag.regs.size(), iter.value());
    if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
                                      parser.getNameLoc(), result.operands)))
      return failure();
    // Segment index >= 2 (C) is treated as an accumulator for inference.
    frag.elemtype =
        MmaOp::inferOperandMMAType(frag.regTypes[0],
                                   /*isAccumulator*/ iter.index() >= 2);
  }

  Type resultType;
  if (parser.parseArrow() || parser.parseType(resultType))
    return failure();
  // NOTE(review): the result's inferred element type is recorded here but
  // not consulted anywhere below in this function.
  frags[5].elemtype =
      MmaOp::inferOperandMMAType(resultType, /*isAccumulator*/ true);

  // Resolve sparse metadata and selector (assume i32 type)
  Type i32Type = builder.getIntegerType(32);
  if (parser
          .resolveOperands(frags[3].regs, i32Type, parser.getCurrentLocation(),
                           result.operands)
          .failed())
    return failure();
  if (parser
          .resolveOperands(frags[4].regs, i32Type, parser.getCurrentLocation(),
                           result.operands)
          .failed())
    return failure();

  // The multiplicand PTX-type attributes may be elided in the textual form
  // when they are inferable from the register types; reconstruct them here.
  std::array<StringRef, 2> names{"multiplicandAPtxType",
                                 "multiplicandBPtxType"};
  for (unsigned idx = 0; idx < names.size(); idx++) {
    const auto &frag = frags[idx];
    std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
    if (!frag.elemtype.has_value() && !attr.has_value()) {
      return parser.emitError(
          parser.getNameLoc(),
          "attribute " + names[idx] +
              " is not provided explicitly and cannot be inferred");
    }
    if (!attr.has_value())
      result.addAttribute(
          names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));
  }

  result.addTypes(resultType);
  if (!namedAttributes.empty())
    result.addAttributes(namedAttributes);
  // Record the operand segment sizes; metadata and selector are always 1.
  result.addAttribute(MmaSpOp::getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr({
                          static_cast<int32_t>(frags[0].regs.size()),
                          static_cast<int32_t>(frags[1].regs.size()),
                          static_cast<int32_t>(frags[2].regs.size()),
                          1, // sparseMetadata
                          1  // sparsitySelector
                      }));
  return success();
}
1340
/// Verifier for nvvm.mma.sp: checks that the tile shape, the multiplicand
/// PTX types, the operand-segment register counts, and the result type form
/// a supported sparse-MMA variant, and that the metadata/selector operands
/// are i32.
LogicalResult MmaSpOp::verify() {
  MLIRContext *context = getContext();
  // Scalar/vector element types used to describe register fragments.
  auto f16Ty = Float16Type::get(context);
  auto i32Ty = IntegerType::get(context, 32);
  auto f16x2Ty = VectorType::get(2, f16Ty);
  auto f32Ty = Float32Type::get(context);
  // Literal struct types describing the candidate result shapes.
  auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
      context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});

  auto s32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
  auto f32x8StructTy =
      LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
  auto f16x2x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
  auto f32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
  auto s32x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});

  std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
                                  getShapeAttr().getK()};

  // These variables define the set of allowed data types for matrices A, B, C,
  // and result.
  using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
  using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>;
  AllowedShapes allowedShapes;
  AllowedTypes expectedA;
  AllowedTypes expectedB;
  AllowedTypes expectedC;
  SmallVector<Type> expectedResult;

  // When M = 16, we just need to calculate the number of 8xk tiles, where
  // k is a factor that depends on the data type.
  if (mmaShape[0] == 16) {
    int64_t kFactor;
    Type multiplicandFragType;
    switch (*getMultiplicandAPtxType()) {
    case MMATypes::tf32:
      kFactor = 4;
      multiplicandFragType = i32Ty;
      expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
          context, {f32Ty, f32Ty, f32Ty, f32Ty}));
      // Sparse MMA supports m16n8k8 and m16n8k16 for tf32
      allowedShapes.push_back({16, 8, 8});
      allowedShapes.push_back({16, 8, 16});
      break;
    case MMATypes::bf16:
      kFactor = 8;
      multiplicandFragType = i32Ty;
      expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
          context, {f32Ty, f32Ty, f32Ty, f32Ty}));
      // Sparse MMA supports m16n8k16 and m16n8k32 for bf16
      allowedShapes.push_back({16, 8, 16});
      allowedShapes.push_back({16, 8, 32});
      break;
    case MMATypes::f16:
      kFactor = 8;
      multiplicandFragType = f16x2Ty;
      expectedResult.push_back(f16x2x2StructTy);
      expectedResult.push_back(f32x4StructTy);
      // Sparse MMA supports m16n8k16 and m16n8k32 for f16
      allowedShapes.push_back({16, 8, 16});
      allowedShapes.push_back({16, 8, 32});
      break;
    case MMATypes::s4:
    case MMATypes::u4:
      // multiplicandFragType is set in the integer branch below.
      kFactor = 32;
      // Sparse MMA supports m16n8k64 and m16n8k128 for s4/u4
      allowedShapes.push_back({16, 8, 64});
      allowedShapes.push_back({16, 8, 128});
      break;
    case MMATypes::s8:
    case MMATypes::u8:
      // multiplicandFragType is set in the integer branch below.
      kFactor = 16;
      // Sparse MMA supports m16n8k32 and m16n8k64 for s8/u8
      allowedShapes.push_back({16, 8, 32});
      allowedShapes.push_back({16, 8, 64});
      break;
    case MMATypes::e4m3:
    case MMATypes::e5m2:
    case MMATypes::e3m2:
    case MMATypes::e2m3:
    case MMATypes::e2m1:
      kFactor = 16;
      multiplicandFragType = i32Ty;
      expectedResult.push_back(f16x2x2StructTy);
      expectedResult.push_back(f32x4StructTy);
      // Sparse MMA supports m16n8k64 for FP8 types
      allowedShapes.push_back({16, 8, 64});
      break;
    default:
      return emitError("invalid shape or multiplicand type: ")
             << getMultiplicandAPtxType().value();
    }

    if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
      // Integer variants accumulate in s32 and pack registers in i32.
      expectedResult.push_back(s32x4StructTy);
      expectedC.emplace_back(4, i32Ty);
      multiplicandFragType = i32Ty;
    } else if (*getMultiplicandAPtxType() >= MMATypes::e4m3 &&
               *getMultiplicandAPtxType() <= MMATypes::e2m1) {
      // FP8 types
      // NOTE(review): relies on e4m3..e2m1 being declared contiguously in
      // MMATypes; body is currently identical to the generic else branch.
      expectedC.emplace_back(2, f16x2Ty);
      expectedC.emplace_back(4, f32Ty);
    } else {
      expectedC.emplace_back(2, f16x2Ty);
      expectedC.emplace_back(4, f32Ty);
    }

    // For sparse MMA, A operand is compressed (2:4 sparsity means half the
    // elements)
    int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor) / 2;
    int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
    expectedA.emplace_back(unitA, multiplicandFragType);
    expectedB.emplace_back(unitB, multiplicandFragType);

    if (resultPtxType() != accumPtxType())
      return emitOpError("ctype does not match dtype");
  }

  // In the M=8 case, there is only 1 possible case per data type.
  if (mmaShape[0] == 8) {
    if (*getMultiplicandAPtxType() == MMATypes::f16) {
      expectedA.emplace_back(2, f16x2Ty);
      expectedB.emplace_back(2, f16x2Ty);
      expectedResult.push_back(f16x2x4StructTy);
      expectedResult.push_back(f32x8StructTy);
      expectedC.emplace_back(4, f16x2Ty);
      expectedC.emplace_back(8, f32Ty);
      allowedShapes.push_back({8, 8, 4});
    }
    if (*getMultiplicandAPtxType() == MMATypes::f64) {
      Type f64Ty = Float64Type::get(context);
      expectedA.emplace_back(1, f64Ty);
      expectedB.emplace_back(1, f64Ty);
      expectedC.emplace_back(2, f64Ty);
      expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
          context, SmallVector<Type>(2, f64Ty)));
      allowedShapes.push_back({8, 8, 4});
    }
    if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
      expectedA.push_back({i32Ty});
      expectedB.push_back({i32Ty});
      expectedC.push_back({i32Ty, i32Ty});
      expectedResult.push_back(s32x2StructTy);
      if (isInt4PtxType(getMultiplicandAPtxType().value()))
        allowedShapes.push_back({8, 8, 32});
      if (isInt8PtxType(getMultiplicandAPtxType().value()))
        allowedShapes.push_back({8, 8, 16});
    }
  }

  std::string errorMessage;
  llvm::raw_string_ostream errorStream(errorMessage);

  // Check that we matched an existing shape/dtype combination.
  if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
      !llvm::is_contained(allowedShapes, mmaShape)) {
    errorStream << "unimplemented variant for MMA shape <";
    llvm::interleaveComma(mmaShape, errorStream);
    errorStream << ">";
    return emitOpError(errorMessage);
  }

  // Verify the operand types for segments of A, B, and C operands.
  std::array<StringRef, 3> operandNames{"A", "B", "C"};
  for (const auto &iter : llvm::enumerate(
           SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) {
    auto spec = this->getODSOperandIndexAndLength(iter.index());
    SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first,
                                      operand_type_begin() + spec.first +
                                          spec.second);
    bool match = llvm::is_contained(iter.value(), operandTySeg);

    if (!match) {
      errorStream << "Could not match types for the "
                  << operandNames[iter.index()]
                  << " operands; expected one of ";
      for (const auto &x : iter.value()) {
        errorStream << x.size() << "x" << x[0] << " ";
      }
      errorStream << "but got ";
      llvm::interleaveComma(operandTySeg, errorStream);
      return emitOpError(errorMessage);
    }
  }

  // Check the result type
  if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
        return expectedResultType == getResult().getType();
      })) {
    errorStream
        << "Could not match allowed types for the result; expected one of ";
    llvm::interleaveComma(expectedResult, errorStream);
    errorStream << " but got " << getResult().getType();
    return emitOpError(errorMessage);
  }

  // Ensure int4/int8 MMA variants specify the accum overflow behavior
  // attribute.
  if (isInt4PtxType(*getMultiplicandAPtxType()) ||
      isInt8PtxType(*getMultiplicandAPtxType())) {
    if (!getIntOverflowBehavior())
      return emitOpError("op requires " +
                         getIntOverflowBehaviorAttrName().strref() +
                         " attribute");
  }

  // Validate sparse metadata type (should be i32)
  if (!getSparseMetadata().getType().isInteger(32)) {
    return emitOpError() << "sparse metadata must be i32 type";
  }

  // Validate sparsity selector type (should be i32)
  if (!getSparsitySelector().getType().isInteger(32)) {
    return emitOpError() << "sparsity selector must be i32 type";
  }

  return success();
}
1563
1564//===----------------------------------------------------------------------===//
1565// MMA Block Scale Operations - Shared Helpers
1566//===----------------------------------------------------------------------===//
1567
namespace {
// Shared structure for MMA operand fragments (A, B, C): groups one printed
// operand segment with the name of the attribute that records its PTX
// element type.
struct MMAOperandFragment {
  StringRef operandName; // Keyword printed before the segment, e.g. "A".
  StringRef ptxTypeAttr; // Attribute name for the segment's PTX type; empty
                         // when the segment has no such attribute (e.g. C).
  SmallVector<Value, 4> regs; // SSA registers belonging to the segment.
  explicit MMAOperandFragment(StringRef name, StringRef ptxTypeName)
      : operandName(name), ptxTypeAttr(ptxTypeName) {}
};
} // namespace
1578
1579// Helper to print operand list in the format: name[operands]
1580static void printOperandList(OpAsmPrinter &p, StringRef name,
1581 ArrayRef<Value> operands) {
1582 p << " " << name << "[";
1583 p.printOperands(operands);
1584 p << "]";
1585}
1586
1587// Helper to parse operand list in the format: name[operands]
1588static LogicalResult
1589parseMmaOperand(OpAsmParser &parser, StringRef operandName,
1591 if (parser.parseKeyword(operandName).failed())
1592 return failure();
1594 .failed())
1595 return failure();
1596 return success();
1597}
1598
1599// Helper to process operand fragments and determine which attributes can be
1600// inferred
1601template <typename Op>
1602static void
1603processOperandFragments(Op &op, std::array<MMAOperandFragment, 3> &frags,
1604 SmallVectorImpl<Type> &regTypes,
1605 SmallVectorImpl<StringRef> &ignoreAttrNames) {
1606 for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
1607 auto &frag = frags[fragIdx];
1608 auto varOperandSpec = op.getODSOperandIndexAndLength(fragIdx);
1609 for (auto operandIdx = varOperandSpec.first;
1610 operandIdx < varOperandSpec.first + varOperandSpec.second;
1611 operandIdx++) {
1612 frag.regs.push_back(op.getOperand(operandIdx));
1613 if (fragIdx == 0 && operandIdx == varOperandSpec.first) {
1614 regTypes.push_back(op.getOperand(operandIdx).getType());
1615 }
1616 }
1617 if (fragIdx < 2) {
1618 regTypes.push_back(frag.regs[0].getType());
1619 }
1620 std::optional<MMATypes> inferredType =
1621 MmaOp::inferOperandMMAType(regTypes.back(),
1622 /*isAccumulator=*/fragIdx >= 2);
1623 if (inferredType)
1624 ignoreAttrNames.push_back(frag.ptxTypeAttr);
1625 }
1626}
1627
1628// Helper to parse type signature: (A_type, B_type, C_type)
1629static LogicalResult
1631 SmallVectorImpl<Type> &operandTypes) {
1632 if (parser.parseColon().failed() || parser.parseLParen().failed())
1633 return failure();
1634
1635 auto typeParser = [&]() {
1636 Type ty;
1637 if (parser.parseType(ty).failed())
1638 return failure();
1639 operandTypes.push_back(ty);
1640 return success();
1641 };
1642 if (parser.parseCommaSeparatedList(typeParser))
1643 return failure();
1644
1645 if (operandTypes.size() != 3)
1646 return parser.emitError(parser.getCurrentLocation(),
1647 "expected exactly 3 types");
1648
1649 return parser.parseRParen();
1650}
1651
1652// Helper to infer and set multiplicand PTX type attributes
1653static void
1655 const SmallVectorImpl<Type> &operandTypes) {
1656 if (!attrs.get("multiplicandAPtxType")) {
1657 if (auto inferredType =
1658 MmaOp::inferOperandMMAType(operandTypes[0], false)) {
1659 attrs.set("multiplicandAPtxType", MMATypesAttr::get(ctx, *inferredType));
1660 }
1661 }
1662 if (!attrs.get("multiplicandBPtxType")) {
1663 if (auto inferredType =
1664 MmaOp::inferOperandMMAType(operandTypes[1], false)) {
1665 attrs.set("multiplicandBPtxType", MMATypesAttr::get(ctx, *inferredType));
1666 }
1667 }
1668}
1669
1670// Helper to add common block scale properties
1671template <typename OpType>
1674 ScaleVecSize scaleVecSize,
1675 BlockScaleFormat blockScaleFormat,
1676 MMABlockScaleKind kind) {
1677 MLIRContext *ctx = builder.getContext();
1678 auto &properties = result.getOrAddProperties<typename OpType::Properties>();
1679 properties.setShape(
1680 builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));
1681 properties.setScaleVecSize(ScaleVecSizeAttr::get(ctx, scaleVecSize));
1682 properties.setBlockScaleFormat(
1683 BlockScaleFormatAttr::get(ctx, blockScaleFormat));
1684 properties.setKind(MMABlockScaleKindAttr::get(ctx, kind));
1685}
1686
1687// Helper to infer and add multiplicand PTX types to builder
1690 ValueRange operandB,
1691 std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes) {
1692 if (multiplicandPtxTypes) {
1693 result.addAttribute("multiplicandAPtxType",
1694 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
1695 result.addAttribute("multiplicandBPtxType",
1696 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
1697 } else {
1698 if (auto res = MmaOp::inferOperandMMAType(operandA[0].getType(), false))
1699 result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
1700 if (auto res = MmaOp::inferOperandMMAType(operandB[0].getType(), false))
1701 result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
1702 }
1703}
1704
1705// Template helper for common accumPtxType/resultPtxType implementation
1706template <typename OpTy>
1707static MMATypes inferPtxTypeFromResult(OpTy op) {
1708 return *MmaOp::inferOperandMMAType(
1709 cast<LLVM::LLVMStructType>(op.getRes().getType()).getBody()[0],
1710 /*isAccumulator=*/true);
1711}
1712
1713//===----------------------------------------------------------------------===//
1714// MmaBlockScaleOp
1715//===----------------------------------------------------------------------===//
1716
1717void MmaBlockScaleOp::print(OpAsmPrinter &p) {
1718 SmallVector<Type, 4> regTypes;
1719 std::array<MMAOperandFragment, 3> frags{
1720 MMAOperandFragment("A", getMultiplicandAPtxTypeAttrName()),
1721 MMAOperandFragment("B", getMultiplicandBPtxTypeAttrName()),
1722 MMAOperandFragment("C", "")};
1723 SmallVector<StringRef, 4> ignoreAttrNames{
1724 mlir::NVVM::MmaBlockScaleOp::getOperandSegmentSizeAttr()};
1725
1726 processOperandFragments(*this, frags, regTypes, ignoreAttrNames);
1727
1728 // Print A, B, C operands
1729 for (const auto &frag : frags)
1730 printOperandList(p, frag.operandName, frag.regs);
1731
1732 // Print scale operands
1733 printOperandList(p, "scaleA",
1734 {getScaleAData(), getByteIdA(), getThreadIdA()});
1735 printOperandList(p, "scaleB",
1736 {getScaleBData(), getByteIdB(), getThreadIdB()});
1737
1738 p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);
1739
1740 // Print type signature
1741 p << " : (";
1742 llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
1743 frags[1].regs[0].getType(),
1744 frags[2].regs[0].getType()},
1745 p);
1746 p << ")";
1747 p.printArrowTypeList(TypeRange{this->getRes().getType()});
1748}
1749
1750ParseResult MmaBlockScaleOp::parse(OpAsmParser &parser,
1752 struct LocalOperandFragment {
1753 std::optional<MMATypes> elemtype;
1754 SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
1755 };
1756
1757 Builder &builder = parser.getBuilder();
1758 std::array<LocalOperandFragment, 3> frags;
1759 NamedAttrList namedAttributes;
1760
1761 // Parse A[...] B[...] C[...]
1762 if (parseMmaOperand(parser, "A", frags[0].regs).failed() ||
1763 parseMmaOperand(parser, "B", frags[1].regs).failed() ||
1764 parseMmaOperand(parser, "C", frags[2].regs).failed())
1765 return failure();
1766
1767 // Parse scale operands: scaleA[...] scaleB[...]
1768 SmallVector<OpAsmParser::UnresolvedOperand, 3> scaleAOperands, scaleBOperands;
1769 if (parseMmaOperand(parser, "scaleA", scaleAOperands).failed() ||
1770 parseMmaOperand(parser, "scaleB", scaleBOperands).failed())
1771 return failure();
1772
1773 if (parser.parseOptionalAttrDict(namedAttributes).failed())
1774 return failure();
1775
1776 // Parse type signature
1777 SmallVector<Type, 3> operandTypes;
1778 if (parseMmaTypeSignature(parser, operandTypes).failed())
1779 return failure();
1780
1781 // Parse result type
1782 SmallVector<Type, 1> resultTypes;
1783 if (parser.parseArrowTypeList(resultTypes).failed())
1784 return failure();
1785
1786 // Infer element types and resolve operands
1787 for (const auto &[idx, frag] : llvm::enumerate(frags)) {
1788 frag.elemtype = MmaOp::inferOperandMMAType(operandTypes[idx],
1789 /*isAccumulator=*/idx >= 2);
1790 if (parser
1791 .resolveOperands(frag.regs, operandTypes[idx], parser.getNameLoc(),
1792 result.operands)
1793 .failed())
1794 return failure();
1795 }
1796
1797 // Resolve scale operands
1798 SmallVector<Type, 3> scaleTypes = {builder.getI32Type(), builder.getI16Type(),
1799 builder.getI16Type()};
1800 if (parser
1801 .resolveOperands(scaleAOperands, scaleTypes, parser.getNameLoc(),
1802 result.operands)
1803 .failed() ||
1804 parser
1805 .resolveOperands(scaleBOperands, scaleTypes, parser.getNameLoc(),
1806 result.operands)
1807 .failed())
1808 return failure();
1809
1810 // Add attributes
1811 result.addAttributes(namedAttributes);
1812 inferAndSetMultiplicandTypes(parser.getContext(), result.attributes,
1813 operandTypes);
1814
1815 result.addTypes(resultTypes);
1816 result.addAttribute(MmaBlockScaleOp::getOperandSegmentSizeAttr(),
1817 builder.getDenseI32ArrayAttr({
1818 static_cast<int32_t>(frags[0].regs.size()),
1819 static_cast<int32_t>(frags[1].regs.size()),
1820 static_cast<int32_t>(frags[2].regs.size()),
1821 1, // scaleAData
1822 1, // byteIdA
1823 1, // threadIdA
1824 1, // scaleBData
1825 1, // byteIdB
1826 1 // threadIdB
1827 }));
1828 return success();
1829}
1830
/// Builder for an NVVM MMA-with-block-scaling op.
///
/// \p operandA / \p operandB / \p operandC are the A, B and C register
/// fragments; scaleAData/byteIdA/threadIdA and scaleBData/byteIdB/threadIdB
/// are the six scalar block-scale operands appended after them.
/// \p shape must hold exactly (m, n, k).
/// \p multiplicandPtxTypes is forwarded to addInferredMultiplicandTypes,
/// which records the PTX element types of the A/B multiplicands.
void MmaBlockScaleOp::build(
    OpBuilder &builder, OperationState &result, Type resultType,
    ValueRange operandA, ValueRange operandB, ValueRange operandC,
    Value scaleAData, Value byteIdA, Value threadIdA, Value scaleBData,
    Value byteIdB, Value threadIdB, ArrayRef<int64_t> shape,
    std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
    ScaleVecSize scaleVecSize, BlockScaleFormat blockScaleFormat,
    MMABlockScaleKind kind) {
  assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");

  // NOTE(review): the opening line of this call (the helper that attaches the
  // shape/scaleVecSize/blockScaleFormat/kind attributes, cf. the parallel
  // MmaSpBlockScaleOp::build) is missing from this excerpt -- confirm against
  // the full source.
                              blockScaleFormat, kind);

  result.addOperands(operandA);
  result.addOperands(operandB);
  result.addOperands(operandC);
  // The scalar scale operands follow A/B/C in this exact order; it must match
  // the segment-size attribute below.
  result.addOperands(
      {scaleAData, byteIdA, threadIdA, scaleBData, byteIdB, threadIdB});

  addInferredMultiplicandTypes(builder.getContext(), result, operandA, operandB,
                               multiplicandPtxTypes);

  result.addTypes(resultType);
  // Variadic segment sizes: the three register-fragment lengths plus one
  // entry for each scalar scale operand.
  result.addAttribute(MmaBlockScaleOp::getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr({
                          static_cast<int32_t>(operandA.size()),
                          static_cast<int32_t>(operandB.size()),
                          static_cast<int32_t>(operandC.size()),
                          1, // scaleAData
                          1, // byteIdA
                          1, // threadIdA
                          1, // scaleBData
                          1, // byteIdB
                          1  // threadIdB
                      }));
}
1867
1868NVVM::IDArgPair MmaBlockScaleOp::getIntrinsicIDAndArgs(
1869 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1870 auto curOp = cast<NVVM::MmaBlockScaleOp>(op);
1871
1873 // Add A, B, C operands
1874 for (Value operand : curOp.getOperandA())
1875 args.push_back(mt.lookupValue(operand));
1876 for (Value operand : curOp.getOperandB())
1877 args.push_back(mt.lookupValue(operand));
1878 for (Value operand : curOp.getOperandC())
1879 args.push_back(mt.lookupValue(operand));
1880
1881 // Add scale operands
1882 args.push_back(mt.lookupValue(curOp.getScaleAData()));
1883 args.push_back(mt.lookupValue(curOp.getByteIdA()));
1884 args.push_back(mt.lookupValue(curOp.getThreadIdA()));
1885 args.push_back(mt.lookupValue(curOp.getScaleBData()));
1886 args.push_back(mt.lookupValue(curOp.getByteIdB()));
1887 args.push_back(mt.lookupValue(curOp.getThreadIdB()));
1888
1889 unsigned intId = MmaBlockScaleOp::getIntrinsicID(
1890 curOp.getShape().getM(), curOp.getShape().getN(), curOp.getShape().getK(),
1891 *curOp.getMultiplicandAPtxType(), *curOp.getMultiplicandBPtxType(),
1892 inferPtxTypeFromResult(curOp), curOp.getScaleVecSize(),
1893 curOp.getBlockScaleFormat(), curOp.getKind());
1894
1895 return {intId, args};
1896}
1897
1898LogicalResult MmaBlockScaleOp::verify() {
1899 LogicalResult result = success();
1900 int m = getShape().getM();
1901 int n = getShape().getN();
1902 int k = getShape().getK();
1903
1904 if (m == 16 && n == 8 && k == 64) {
1905 if (getMultiplicandAPtxType() != NVVM::MMATypes::e2m1 ||
1906 getMultiplicandBPtxType() != NVVM::MMATypes::e2m1)
1908 "unsupported MMATypes attribute for mma.m16n8k64.(mxf4nvf4|mxf4)");
1909 if (getKind() == NVVM::MMABlockScaleKind::MXF4) {
1910 if (getScaleVecSize() != NVVM::ScaleVecSize::X2)
1912 "unsupported ScaleVecSize attribute for mma.m16n8k64.mxf4");
1913 if (getBlockScaleFormat() != NVVM::BlockScaleFormat::UE8M0)
1915 "unsupported BlockScaleFormat attribute for mma.m16n8k64.mxf4");
1916 } else if (getKind() == NVVM::MMABlockScaleKind::MXF4NVF4) {
1917 if (!((getScaleVecSize() == NVVM::ScaleVecSize::X2 &&
1918 getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0) ||
1919 (getScaleVecSize() == NVVM::ScaleVecSize::X4 &&
1920 (getBlockScaleFormat() == NVVM::BlockScaleFormat::UE4M3 ||
1921 getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0))))
1922 result = emitOpError("unsupported ScaleVecSize and BlockScaleFormat "
1923 "attributes for mma.m16n8k64.mxf4nvf4");
1924 } else {
1925 result = emitOpError("unsupported Kind attribute for mma.m16n8k64");
1926 }
1927 } else if (m == 16 && n == 8 && k == 32) {
1928 if (!(getKind() == NVVM::MMABlockScaleKind::MXF8F6F4 &&
1929 getScaleVecSize() == NVVM::ScaleVecSize::X1 &&
1930 getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0))
1931 result =
1932 emitOpError("unsupported Kind, ScaleVecSize and BlockScaleFormat "
1933 "attributes for mma.m16n8k32");
1934 } else {
1935 result = emitOpError("unsupported Geom for mma with block scaling");
1936 }
1937 return result;
1938}
1939
1940//===----------------------------------------------------------------------===//
1941// MmaSpBlockScaleOp
1942//===----------------------------------------------------------------------===//
1943
1944void MmaSpBlockScaleOp::print(OpAsmPrinter &p) {
1945 SmallVector<Type, 4> regTypes;
1946 std::array<MMAOperandFragment, 3> frags{
1947 MMAOperandFragment("A", getMultiplicandAPtxTypeAttrName()),
1948 MMAOperandFragment("B", getMultiplicandBPtxTypeAttrName()),
1949 MMAOperandFragment("C", "")};
1950 SmallVector<StringRef, 4> ignoreAttrNames{
1951 mlir::NVVM::MmaSpBlockScaleOp::getOperandSegmentSizeAttr()};
1952
1953 processOperandFragments(*this, frags, regTypes, ignoreAttrNames);
1954
1955 // Print A, B, C operands
1956 for (const auto &frag : frags)
1957 printOperandList(p, frag.operandName, frag.regs);
1958
1959 // Print sparse-specific operands
1960 printOperandList(p, "sparseMetadata", {getSparseMetadata()});
1961 printOperandList(p, "selector", {getSparsitySelector()});
1962
1963 // Print scale operands
1964 printOperandList(p, "scaleA",
1965 {getScaleAData(), getByteIdA(), getThreadIdA()});
1966 printOperandList(p, "scaleB",
1967 {getScaleBData(), getByteIdB(), getThreadIdB()});
1968
1969 p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);
1970
1971 // Print type signature
1972 p << " : (";
1973 llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
1974 frags[1].regs[0].getType(),
1975 frags[2].regs[0].getType()},
1976 p);
1977 p << ")";
1978 p.printArrowTypeList(TypeRange{this->getRes().getType()});
1979}
1980
1981ParseResult MmaSpBlockScaleOp::parse(OpAsmParser &parser,
1983 struct LocalOperandFragment {
1984 std::optional<MMATypes> elemtype;
1985 SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
1986 };
1987
1988 Builder &builder = parser.getBuilder();
1989 std::array<LocalOperandFragment, 3> frags;
1990 NamedAttrList namedAttributes;
1991
1992 // Parse A[...] B[...] C[...]
1993 if (parseMmaOperand(parser, "A", frags[0].regs).failed() ||
1994 parseMmaOperand(parser, "B", frags[1].regs).failed() ||
1995 parseMmaOperand(parser, "C", frags[2].regs).failed())
1996 return failure();
1997
1998 // Parse sparse-specific operands
2000 selectorOperands;
2001 if (parseMmaOperand(parser, "sparseMetadata", metadataOperands).failed() ||
2002 parseMmaOperand(parser, "selector", selectorOperands).failed())
2003 return failure();
2004
2005 // Parse scale operands
2006 SmallVector<OpAsmParser::UnresolvedOperand, 3> scaleAOperands, scaleBOperands;
2007 if (parseMmaOperand(parser, "scaleA", scaleAOperands).failed() ||
2008 parseMmaOperand(parser, "scaleB", scaleBOperands).failed())
2009 return failure();
2010
2011 if (parser.parseOptionalAttrDict(namedAttributes).failed())
2012 return failure();
2013
2014 // Parse type signature
2015 SmallVector<Type, 3> operandTypes;
2016 if (parseMmaTypeSignature(parser, operandTypes).failed())
2017 return failure();
2018
2019 // Parse result type
2020 SmallVector<Type, 1> resultTypes;
2021 if (parser.parseArrowTypeList(resultTypes).failed())
2022 return failure();
2023
2024 // Infer element types and resolve operands
2025 for (const auto &[idx, frag] : llvm::enumerate(frags)) {
2026 frag.elemtype = MmaOp::inferOperandMMAType(operandTypes[idx],
2027 /*isAccumulator=*/idx >= 2);
2028 if (parser
2029 .resolveOperands(frag.regs, operandTypes[idx], parser.getNameLoc(),
2030 result.operands)
2031 .failed())
2032 return failure();
2033 }
2034
2035 // Resolve sparse metadata and selector
2036 Type i32Type = builder.getI32Type();
2037 if (parser
2038 .resolveOperands(metadataOperands, i32Type, parser.getNameLoc(),
2039 result.operands)
2040 .failed() ||
2041 parser
2042 .resolveOperands(selectorOperands, i32Type, parser.getNameLoc(),
2043 result.operands)
2044 .failed())
2045 return failure();
2046
2047 // Resolve scale operands
2048 SmallVector<Type, 3> scaleTypes = {i32Type, builder.getI16Type(),
2049 builder.getI16Type()};
2050 if (parser
2051 .resolveOperands(scaleAOperands, scaleTypes, parser.getNameLoc(),
2052 result.operands)
2053 .failed() ||
2054 parser
2055 .resolveOperands(scaleBOperands, scaleTypes, parser.getNameLoc(),
2056 result.operands)
2057 .failed())
2058 return failure();
2059
2060 // Add attributes
2061 result.addAttributes(namedAttributes);
2062 inferAndSetMultiplicandTypes(parser.getContext(), result.attributes,
2063 operandTypes);
2064
2065 // orderedMetadata is mandatory
2066 if (!result.attributes.get("orderedMetadata"))
2067 result.addAttribute("orderedMetadata", builder.getUnitAttr());
2068
2069 result.addTypes(resultTypes);
2070 result.addAttribute(MmaSpBlockScaleOp::getOperandSegmentSizeAttr(),
2071 builder.getDenseI32ArrayAttr({
2072 static_cast<int32_t>(frags[0].regs.size()),
2073 static_cast<int32_t>(frags[1].regs.size()),
2074 static_cast<int32_t>(frags[2].regs.size()),
2075 1, // sparseMetadata
2076 1, // sparsitySelector
2077 1, // scaleAData
2078 1, // byteIdA
2079 1, // threadIdA
2080 1, // scaleBData
2081 1, // byteIdB
2082 1 // threadIdB
2083 }));
2084 return success();
2085}
2086
/// Builder for the sparse NVVM MMA-with-block-scaling op.
///
/// In addition to the A/B/C fragments and six scale operands of the dense
/// variant, this appends the sparse metadata and sparsity selector operands,
/// and always sets the mandatory "orderedMetadata" unit attribute.
void MmaSpBlockScaleOp::build(
    OpBuilder &builder, OperationState &result, Type resultType,
    ValueRange operandA, ValueRange operandB, ValueRange operandC,
    Value sparseMetadata, Value sparsitySelector, Value scaleAData,
    Value byteIdA, Value threadIdA, Value scaleBData, Value byteIdB,
    Value threadIdB, ArrayRef<int64_t> shape,
    std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
    ScaleVecSize scaleVecSize, BlockScaleFormat blockScaleFormat,
    MMABlockScaleKind kind) {
  assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");

  // NOTE(review): the opening line of this call (the helper that attaches the
  // shape/scaleVecSize/blockScaleFormat/kind attributes) is missing from this
  // excerpt -- confirm against the full source.
      builder, result, shape, scaleVecSize, blockScaleFormat, kind);
  result.addAttribute("orderedMetadata", builder.getUnitAttr());

  result.addOperands(operandA);
  result.addOperands(operandB);
  result.addOperands(operandC);
  // Operand order must match the segment-size attribute below.
  result.addOperands({sparseMetadata, sparsitySelector, scaleAData, byteIdA,
                      threadIdA, scaleBData, byteIdB, threadIdB});

  addInferredMultiplicandTypes(builder.getContext(), result, operandA, operandB,
                               multiplicandPtxTypes);

  result.addTypes(resultType);
  // Variadic segment sizes: register-fragment lengths plus one entry for each
  // scalar operand.
  result.addAttribute(MmaSpBlockScaleOp::getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr({
                          static_cast<int32_t>(operandA.size()),
                          static_cast<int32_t>(operandB.size()),
                          static_cast<int32_t>(operandC.size()),
                          1, // sparseMetadata
                          1, // sparsitySelector
                          1, // scaleAData
                          1, // byteIdA
                          1, // threadIdA
                          1, // scaleBData
                          1, // byteIdB
                          1  // threadIdB
                      }));
}
2127
2128NVVM::IDArgPair MmaSpBlockScaleOp::getIntrinsicIDAndArgs(
2129 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
2130 auto curOp = cast<NVVM::MmaSpBlockScaleOp>(op);
2131
2133 // Add A, B, C operands
2134 for (Value operand : curOp.getOperandA())
2135 args.push_back(mt.lookupValue(operand));
2136 for (Value operand : curOp.getOperandB())
2137 args.push_back(mt.lookupValue(operand));
2138 for (Value operand : curOp.getOperandC())
2139 args.push_back(mt.lookupValue(operand));
2140
2141 // Add sparse metadata and selector
2142 args.push_back(mt.lookupValue(curOp.getSparseMetadata()));
2143 args.push_back(mt.lookupValue(curOp.getSparsitySelector()));
2144
2145 // Add scale operands
2146 args.push_back(mt.lookupValue(curOp.getScaleAData()));
2147 args.push_back(mt.lookupValue(curOp.getByteIdA()));
2148 args.push_back(mt.lookupValue(curOp.getThreadIdA()));
2149 args.push_back(mt.lookupValue(curOp.getScaleBData()));
2150 args.push_back(mt.lookupValue(curOp.getByteIdB()));
2151 args.push_back(mt.lookupValue(curOp.getThreadIdB()));
2152
2153 unsigned intId = MmaSpBlockScaleOp::getIntrinsicID(
2154 curOp.getShape().getM(), curOp.getShape().getN(), curOp.getShape().getK(),
2155 *curOp.getMultiplicandAPtxType(), *curOp.getMultiplicandBPtxType(),
2156 inferPtxTypeFromResult(curOp), curOp.getScaleVecSize(),
2157 curOp.getBlockScaleFormat(), curOp.getKind());
2158
2159 return {intId, args};
2160}
2161
2162LogicalResult MmaSpBlockScaleOp::verify() {
2163 // Check that orderedMetadata is present
2164 if (!getOrderedMetadata()) {
2165 return emitOpError("'orderedMetadata' attribute is mandatory");
2166 }
2167
2168 LogicalResult result = success();
2169 int m = getShape().getM();
2170 int n = getShape().getN();
2171 int k = getShape().getK();
2172
2173 if (m == 16 && n == 8 && k == 128) {
2174 if (getMultiplicandAPtxType() != NVVM::MMATypes::e2m1 ||
2175 getMultiplicandBPtxType() != NVVM::MMATypes::e2m1)
2177 "unsupported MMATypes attribute for mma.m16n8k128.(mxf4nvf4|mxf4)");
2178 if (getKind() == NVVM::MMABlockScaleKind::MXF4) {
2179 if (getScaleVecSize() != NVVM::ScaleVecSize::X2)
2181 "unsupported ScaleVecSize attribute for mma.m16n8k128.mxf4");
2182 if (getBlockScaleFormat() != NVVM::BlockScaleFormat::UE8M0)
2184 "unsupported BlockScaleFormat attribute for mma.m16n8k128.mxf4");
2185 } else if (getKind() == NVVM::MMABlockScaleKind::MXF4NVF4) {
2186 if (!((getScaleVecSize() == NVVM::ScaleVecSize::X2 &&
2187 getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0) ||
2188 (getScaleVecSize() == NVVM::ScaleVecSize::X4 &&
2189 (getBlockScaleFormat() == NVVM::BlockScaleFormat::UE4M3 ||
2190 getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0))))
2191 result = emitOpError("unsupported ScaleVecSize and BlockScaleFormat "
2192 "attributes for mma.m16n8k128.mxf4nvf4");
2193 } else {
2194 result = emitOpError("unsupported Kind attribute for mma.m16n8k128");
2195 }
2196 } else if (m == 16 && n == 8 && k == 64) {
2197 if (!(getKind() == NVVM::MMABlockScaleKind::MXF8F6F4 &&
2198 getScaleVecSize() == NVVM::ScaleVecSize::X1 &&
2199 getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0))
2200 result =
2201 emitOpError("unsupported Kind, ScaleVecSize and BlockScaleFormat "
2202 "attributes for mma.m16n8k64");
2203 } else {
2204 result = emitOpError("unsupported Geom for sparse mma with block scaling");
2205 }
2206 return result;
2207}
2208
2209LogicalResult ShflOp::verify() {
2210 auto returnStructType = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
2211
2212 auto verifyTypeError = [&](Twine desc, Type expectedType,
2213 Type actualType) -> LogicalResult {
2214 return emitOpError("expected " + desc + " to be of type ")
2215 << expectedType << " but got " << actualType << " instead";
2216 };
2217
2218 if (returnStructType) {
2219 if (!getReturnValueAndIsValid())
2220 return emitOpError("\"return_value_and_is_valid\" attribute must be "
2221 "specified when the return type is a struct type");
2222
2223 if (returnStructType.getBody().size() != 2)
2224 return emitOpError("expected return type to be a two-element struct");
2225
2226 llvm::ArrayRef<Type> returnStruct = returnStructType.getBody();
2227 auto resultType = returnStruct[0];
2228 if (resultType != getVal().getType())
2229 return verifyTypeError("first element in the returned struct",
2230 getVal().getType(), resultType);
2231
2232 auto predicateType = returnStruct[1];
2233 if (!predicateType.isInteger(1))
2234 return verifyTypeError("second element in the returned struct",
2235 mlir::IntegerType::get(getContext(), 1),
2236 predicateType);
2237 } else {
2238 if (getReturnValueAndIsValid())
2239 return emitOpError("expected return type to be a two-element struct");
2240
2241 if (getType() != getVal().getType())
2242 return verifyTypeError("return type", getVal().getType(), getType());
2243 }
2244 return success();
2245}
2246
2247std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
2248 NVVM::MMAFrag frag, int nRow,
2249 int nCol,
2250 MLIRContext *context) {
2251 unsigned numberElements = 0;
2252 Type elementType;
2253 OpBuilder builder(context);
2254 Type f16x2 = VectorType::get(2, builder.getF16Type());
2255 if (type == NVVM::MMATypes::f16) {
2256 elementType = f16x2;
2257 if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
2258 numberElements = 8;
2259 else
2260 numberElements = 4;
2261 } else if (type == NVVM::MMATypes::f32) {
2262 elementType = builder.getF32Type();
2263 numberElements = 8;
2264 } else if (type == NVVM::MMATypes::f64) {
2265 elementType = builder.getF64Type();
2266 if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
2267 numberElements = 1;
2268 else
2269 numberElements = 2;
2270 } else if (type == NVVM::MMATypes::tf32) {
2271 elementType = builder.getI32Type();
2272 numberElements = 4;
2273 } else if (type == NVVM::MMATypes::s8 || type == NVVM::MMATypes::u8) {
2274 elementType = builder.getI32Type();
2275 int parallelSize = 0;
2276 if (frag == NVVM::MMAFrag::a)
2277 parallelSize = nRow;
2278 if (frag == NVVM::MMAFrag::b)
2279 parallelSize = nCol;
2280
2281 // m == 16 && n == 16 && k == 16
2282 if (parallelSize == 16)
2283 numberElements = 2;
2284 // m == 8 && n == 32 && k == 16 or m == 32 && n == 8 && k == 16
2285 else if (parallelSize == 8)
2286 numberElements = 1;
2287 else if (parallelSize == 32)
2288 numberElements = 4;
2289 } else if (type == NVVM::MMATypes::s32) {
2290 elementType = builder.getI32Type();
2291 numberElements = 8;
2292 }
2293 assert(numberElements != 0 && elementType != nullptr);
2294 return std::make_pair(elementType, numberElements);
2295}
2296
2297static std::pair<mlir::Type, unsigned>
2298inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n,
2299 int k, MLIRContext *context) {
2300 int nRow, nCol;
2301 if (frag == NVVM::MMAFrag::a) {
2302 nRow = m;
2303 nCol = k;
2304 } else if (frag == NVVM::MMAFrag::b) {
2305 nRow = k;
2306 nCol = n;
2307 } else {
2308 nRow = m;
2309 nCol = n;
2310 }
2311 assert(nRow && nCol);
2312 return inferMMAType(type, frag, nRow, nCol, context);
2313}
2314
2315LogicalResult NVVM::WMMALoadOp::verify() {
2316 unsigned addressSpace =
2317 llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
2318 if (addressSpace != 0 && addressSpace != NVVMMemorySpace::Global &&
2319 addressSpace != NVVMMemorySpace::Shared)
2320 return emitOpError("expected source pointer in memory "
2321 "space 0, 1, 3");
2322
2323 if (NVVM::WMMALoadOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
2324 getEltype(), getFrag()) == 0)
2325 return emitOpError() << "invalid attribute combination";
2326 std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
2327 getEltype(), getFrag(), getM(), getN(), getK(), getContext());
2328 // Special case for f64 fragments
2329 Type f64Ty = Float64Type::get(getContext());
2330 if (typeInfo.first == f64Ty && typeInfo.second == 1) {
2331 if (getType() != f64Ty)
2332 return emitOpError("expected destination type to be f64");
2333 return success();
2334 }
2335 // Everything else is a struct
2336 Type dstType = LLVM::LLVMStructType::getLiteral(
2337 getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
2338 if (getType() != dstType)
2339 return emitOpError("expected destination type is a structure of ")
2340 << typeInfo.second << " elements of type " << typeInfo.first;
2341 return success();
2342}
2343
2344LogicalResult NVVM::WMMAStoreOp::verify() {
2345 unsigned addressSpace =
2346 llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
2347 if (addressSpace != 0 && addressSpace != NVVMMemorySpace::Global &&
2348 addressSpace != NVVMMemorySpace::Shared)
2349 return emitOpError("expected operands to be a source pointer in memory "
2350 "space 0, 1, 3");
2351
2352 if (NVVM::WMMAStoreOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
2353 getEltype()) == 0)
2354 return emitOpError() << "invalid attribute combination";
2355 std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
2356 getEltype(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
2357 if (getArgs().size() != typeInfo.second)
2358 return emitOpError() << "expected " << typeInfo.second << " data operands";
2359 if (llvm::any_of(getArgs(), [&typeInfo](Value operands) {
2360 return operands.getType() != typeInfo.first;
2361 }))
2362 return emitOpError() << "expected data operands of type " << typeInfo.first;
2363 return success();
2364}
2365
2366LogicalResult NVVM::WMMAMmaOp::verify() {
2367 if (NVVM::WMMAMmaOp::getIntrinsicID(getM(), getN(), getK(), getLayoutA(),
2368 getLayoutB(), getEltypeA(),
2369 getEltypeB()) == 0)
2370 return emitOpError() << "invalid attribute combination";
2371 std::pair<Type, unsigned> typeInfoA = inferMMATypeFromMNK(
2372 getEltypeA(), NVVM::MMAFrag::a, getM(), getN(), getK(), getContext());
2373 std::pair<Type, unsigned> typeInfoB = inferMMATypeFromMNK(
2374 getEltypeA(), NVVM::MMAFrag::b, getM(), getN(), getK(), getContext());
2375 std::pair<Type, unsigned> typeInfoC = inferMMATypeFromMNK(
2376 getEltypeB(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
2377 SmallVector<Type, 32> arguments;
2378 arguments.append(typeInfoA.second, typeInfoA.first);
2379 arguments.append(typeInfoB.second, typeInfoB.first);
2380 arguments.append(typeInfoC.second, typeInfoC.first);
2381 unsigned numArgs = arguments.size();
2382 if (getArgs().size() != numArgs)
2383 return emitOpError() << "expected " << numArgs << " arguments";
2384 for (unsigned i = 0; i < numArgs; i++) {
2385 if (getArgs()[i].getType() != arguments[i])
2386 return emitOpError() << "expected argument " << i << " to be of type "
2387 << arguments[i];
2388 }
2389 Type dstType = LLVM::LLVMStructType::getLiteral(
2390 getContext(), SmallVector<Type, 8>(typeInfoC.second, typeInfoC.first));
2391 if (getType() != dstType)
2392 return emitOpError("expected destination type is a structure of ")
2393 << typeInfoC.second << " elements of type " << typeInfoC.first;
2394 return success();
2395}
2396
2397LogicalResult NVVM::LdMatrixOp::verify() {
2398 uint32_t num = getNum(), m = getShape().getM(), n = getShape().getN();
2399 if (m == 8 && n == 8) {
2400 if (num != 1 && num != 2 && num != 4) {
2401 return emitOpError("expected num attribute to be 1, 2 or 4 for 8x8 "
2402 "matrix");
2403 }
2404 if (getEltType() != LdStMatrixEltType::B16) {
2405 return emitOpError("expected element type to be b16 for 8x8 matrix");
2406 }
2407 } else if (m == 8 && n == 16) {
2408 if (num != 1 && num != 2 && num != 4) {
2409 return emitOpError("expected num attribute to be 1, 2 or 4 for 8x16 "
2410 "matrix");
2411 }
2412 if (getLayout() != MMALayout::row) {
2413 return emitOpError("expected layout to be row for 8x16 matrix");
2414 }
2415 if (getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 &&
2416 getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) {
2417 return emitOpError("expected element type to be b8x16.b4x16_p64 or "
2418 "b8x16.b6x16_p32 for 8x16 matrix");
2419 }
2420 } else if (m == 16 && n == 16) {
2421 if (num != 1 && num != 2) {
2422 return emitOpError("expected num attribute to be 1 or 2 for 16x16 "
2423 "matrix");
2424 }
2425 if (getLayout() != MMALayout::col) {
2426 return emitOpError("expected layout to be col for 16x16 matrix");
2427 }
2428 if (getEltType() != LdStMatrixEltType::B8 &&
2429 getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 &&
2430 getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) {
2431 return emitOpError("expected element type to be b8, b8x16.b4x16_p64 or "
2432 "b8x16.b6x16_p32 for 16x16 matrix");
2433 }
2434 } else {
2435 return emitOpError("expected shape to be 8x8, 8x16 or 16x16");
2436 }
2437
2438 Type i32 = IntegerType::get(getContext(), 32);
2439 uint32_t numElements = (m == 16 && n == 16 ? num * 2 : num);
2440 if (numElements == 1 && getType() != i32)
2441 return emitOpError("expected destination type is i32");
2442 if (numElements == 2 || numElements == 4) {
2443 Type dstType = LLVM::LLVMStructType::getLiteral(
2444 getContext(), SmallVector<Type>(numElements, i32));
2445 if (getType() != dstType)
2446 return emitOpError("expected destination type is a structure of ")
2447 << numElements << " elements of type i32";
2448 }
2449
2450 return success();
2451}
2452
2453LogicalResult NVVM::StMatrixOp::verify() {
2454 int numMatrix = getSources().size();
2455 if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
2456 return emitOpError("expected num attribute to be 1, 2 or 4");
2457
2458 int m = getShape().getM(), n = getShape().getN();
2459 if (m == 8 && n == 8) {
2460 if (getEltType() != NVVM::LdStMatrixEltType::B16) {
2461 return emitOpError("expected element type to be B16 for 8x8 matrix");
2462 }
2463 } else if (m == 16 && n == 8) {
2464 if (getEltType() != NVVM::LdStMatrixEltType::B8) {
2465 return emitOpError("expected element type to be B8 for 16x8 matrix");
2466 }
2467 if (getLayout() != NVVM::MMALayout::col) {
2468 return emitOpError("expected layout to be col for 16x8 matrix");
2469 }
2470 } else {
2471 return emitOpError("expected shape to be 8x8 or 16x8");
2472 }
2473
2474 return success();
2475}
2476
2477static FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) {
2478 if (typeA == NVVM::WGMMATypes::tf32)
2479 return 8;
2480 if (typeA == NVVM::WGMMATypes::f16 || typeA == NVVM::WGMMATypes::bf16)
2481 return 16;
2482 if (typeA == NVVM::WGMMATypes::s8 || typeA == NVVM::WGMMATypes::u8)
2483 return 32;
2484 if (typeA == NVVM::WGMMATypes::e4m3 || typeA == NVVM::WGMMATypes::e5m2)
2485 return 32;
2486 if (typeA == NVVM::WGMMATypes::b1)
2487 return 256;
2488 return failure();
2489}
2490
2491static LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD,
2492 NVVM::WGMMATypes typeA,
2493 NVVM::WGMMATypes typeB) {
2494 switch (typeA) {
2495 case NVVM::WGMMATypes::f16:
2496 if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
2497 typeB == NVVM::WGMMATypes::f16)
2498 return success();
2499 break;
2500 case NVVM::WGMMATypes::tf32:
2501 if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32)
2502 return success();
2503 break;
2504 case NVVM::WGMMATypes::u8:
2505 case NVVM::WGMMATypes::s8:
2506 if (typeD == NVVM::WGMMATypes::s32 &&
2507 (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8))
2508 return success();
2509 break;
2510 case NVVM::WGMMATypes::b1:
2511 if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1)
2512 return success();
2513 break;
2514 case NVVM::WGMMATypes::bf16:
2515 if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
2516 typeB == NVVM::WGMMATypes::bf16)
2517 return success();
2518 break;
2519 case NVVM::WGMMATypes::e4m3:
2520 case NVVM::WGMMATypes::e5m2:
2521 if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
2522 (typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3))
2523 return success();
2524 break;
2525 case WGMMATypes::f32:
2526 case WGMMATypes::s32:
2527 llvm_unreachable("unsupported input types");
2528 break;
2529 }
2530 return failure();
2531}
2532
2533static LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA) {
2534 SmallVector<int> allowedN = {8, 16, 24, 32, 40, 48, 56, 64,
2535 72, 80, 88, 96, 104, 112, 120, 128,
2536 136, 144, 152, 160, 168, 176, 184, 192,
2537 200, 208, 216, 224, 232, 240, 248, 256};
2538 SmallVector<int> allowedNshort = {8, 16, 24, 32, 48, 64,
2539 80, 96, 112, 128, 144, 160,
2540 176, 192, 208, 224, 240, 256};
2541 switch (typeA) {
2542 case WGMMATypes::f16:
2543 case WGMMATypes::tf32:
2544 case WGMMATypes::bf16:
2545 case WGMMATypes::e4m3:
2546 case WGMMATypes::e5m2:
2547 if (llvm::is_contained(allowedN, sizeN))
2548 return success();
2549 break;
2550 case WGMMATypes::u8:
2551 case WGMMATypes::s8:
2552 case WGMMATypes::b1:
2553 if (llvm::is_contained(allowedNshort, sizeN))
2554 return success();
2555 break;
2556 case WGMMATypes::f32:
2557 case WGMMATypes::s32:
2558 llvm_unreachable("unsupported input types");
2559 break;
2560 }
2561 return failure();
2562}
2563
/// Verifier for wgmma.mma_async.
///
/// Checks, in order: the result is a homogeneous struct; the accumulator
/// type is f32/f16/s32; s32 accumulators do not use negated scales; the
/// (typeD, typeA, typeB) combination is supported; m == 64; k matches the
/// single legal value for typeA; n is an allowed size for typeA; transposed
/// layouts are only requested for f16/bf16 inputs; the result struct has the
/// expected number of registers; and `satfinite` is only used with an s32
/// accumulator. The first failing check produces the error.
LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
  Value outValue = getResults();
  auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
  if (!stype)
    return emitOpError() << "expected results to be struct";
  int outputSize = stype.getBody().size();
  WGMMATypes typeD = getTypeD();
  WGMMATypes typeA = getTypeA();
  WGMMATypes typeB = getTypeB();

  // All result registers must share one element type.
  for (Type t : stype.getBody()) {
    if (t != stype.getBody().front())
      return emitOpError()
             << "all elements in struct must be same type but there is " << t;
  }

  if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
      typeD != WGMMATypes::s32) {
    return emitOpError() << "does not support the given output type " << typeD;
  }
  if (typeD == WGMMATypes::s32 &&
      (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
    return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg";
  }

  if (failed(isAllowedWGMMADataType(typeD, typeA, typeB))) {
    return emitOpError() << typeD << " += " << typeA << " * " << typeB
                         << ", it is not supported.";
  }

  // Check M
  if (getShape().getM() != 64)
    return emitOpError() << "shape 'm' must be 64";

  // Check K
  // NOTE(review): if getAllowedSizeK fails, allowedK.value() in the message
  // below would assert; this appears unreachable because unsupported input
  // types are rejected by isAllowedWGMMADataType above -- confirm.
  FailureOr<int> allowedK = getAllowedSizeK(typeA);
  if (failed(allowedK) || allowedK.value() != getShape().getK())
    return emitOpError() << "shape 'k' must be " << allowedK.value()
                         << " for input type " << typeA;

  // Check N
  if (failed(isAllowedSizeN(getShape().getN(), typeA))) {
    return emitOpError() << "has input type " << typeA << " n is set to "
                         << getShape().getN() << ", it is not supported.";
  }

  // Check transpose (only available for f16/bf16)
  // Matrices A should be stored in row-major and B in column-major.
  // Only f16/bf16 matrices can be stored in either column-major or row-major
  // by setting the transpose value(imm-trans-a,imm-trans-b) in PTX code.
  if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
      (getLayoutA() == mlir::NVVM::MMALayout::col ||
       getLayoutB() == mlir::NVVM::MMALayout::row)) {
    return emitOpError()
           << "given layouts layout_a = " << getLayoutA()
           << " and layout_b = " << getLayoutB() << " for input types " << typeA
           << " and " << typeB
           << " requires transpose. However, this is only supported for: "
           << MMATypes::f16 << " and " << MMATypes::bf16;
  }

  // Check result registers: n/2 registers for 32-bit accumulators, n/4 for
  // f16 accumulators.
  int expectedOutput = 0;
  if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32)
    expectedOutput = getShape().getN() / 2;
  if (typeD == WGMMATypes::f16)
    expectedOutput = getShape().getN() / 4;
  if (outputSize != expectedOutput) {
    return emitOpError() << "results " << expectedOutput
                         << ", however output struct has " << outputSize
                         << " elements";
  }
  // Check satfinite (only available for s32 accumulator)
  if (typeD != WGMMATypes::s32 &&
      getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
          NVVM::MMAIntOverflow::satfinite) {
    return emitOpError()
           << " `satfinite` can be only used with s32 accumulator, however "
              "the current accumulator is "
           << typeD;
  }

  return success();
}
2648
/// Builds the inline-PTX string for this wgmma.mma_async op.
///
/// The emitted asm guards the accumulate-input flag through predicate `p`,
/// then issues `wgmma.mma_async.sync.aligned.m<M>n<N>k<K>.<D>.<A>.<B>` with
/// the output register list, the two descriptor operands, `p`, optionally
/// the scale operands (skipped for s32 accumulators) and optionally the
/// transpose immediates (f16/bf16 inputs only). Operand placeholders ($0,
/// $1, ...) are positional; their numbering must stay in sync with the
/// operand order produced by getAsmValues.
std::string NVVM::WgmmaMmaAsyncOp::getPtx() {

  int m = getShape().getM(), n = getShape().getN(), k = getShape().getK();
  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;

  StringRef outputTypeName = stringifyWGMMATypes(getTypeD());

  // f16 accumulators pack two values per register; 32-bit ones use n/2.
  int expectedOutputRegisters = 0;
  if (getTypeD() == WGMMATypes::f16)
    expectedOutputRegisters = getShape().getN() / 4;
  else
    expectedOutputRegisters = getShape().getN() / 2;

  std::string ptx;
  llvm::raw_string_ostream ss(ptx);

  // NOTE(review): the scale-D flag is read from operand index
  // 2*expectedOutputRegisters + 2 -- presumably because each output register
  // occupies both a write and a read slot in the inline-asm operand list;
  // confirm against getAsmValues.
  ss << "{\n"
        ".reg .pred p;\n"
        "setp.ne.b32 p, $"
     << ((expectedOutputRegisters * 2) + 2)
     << ", 0;\n"
        "wgmma.mma_async.sync.aligned.m"
     << m << "n" << n << "k" << k << "." << outputTypeName << "." << getTypeA()
     << "." << getTypeB();
  if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
      NVVM::MMAIntOverflow::satfinite)
    ss << ".satfinite";
  ss << " {";
  int regCnt = 0;
  for (; regCnt < expectedOutputRegisters; ++regCnt) {
    ss << "$" << regCnt;
    if (regCnt != expectedOutputRegisters - 1)
      ss << ", ";
  }

  ss << "},";
  // Need to map read/write registers correctly.
  regCnt = (regCnt * 2);
  ss << " $" << (regCnt) << ","
     << " $" << (regCnt + 1) << ","
     << " p";
  // Scale operands are only meaningful for floating-point accumulation.
  if (getTypeD() != WGMMATypes::s32) {
    ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4);
  }
  // Don't add transpose parameters unless needed.
  if (isF16) {
    ss << ", $" << (regCnt + 5) << ", $" << (regCnt + 6);
  }
  ss << ";\n"
     << "}\n";
  return ptx;
}
2701
2702bool NVVM::WgmmaMmaAsyncOp::getAsmValues(
2703 RewriterBase &rewriter,
2704 llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
2705 &asmValues) {
2706 bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
2707 if (getResults())
2708 asmValues.push_back({getResults(), mlir::NVVM::PTXRegisterMod::Write});
2709 if (getInouts())
2710 asmValues.push_back({getInouts(), mlir::NVVM::PTXRegisterMod::ReadWrite});
2711 asmValues.push_back({getDescriptorA(), mlir::NVVM::PTXRegisterMod::Read});
2712 asmValues.push_back({getDescriptorB(), mlir::NVVM::PTXRegisterMod::Read});
2713 asmValues.push_back({makeConstantI32(rewriter, static_cast<int>(getScaleD())),
2715 if (getTypeD() != WGMMATypes::s32) {
2716 asmValues.push_back(
2717 {makeConstantI32(rewriter,
2718 getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
2720 asmValues.push_back(
2721 {makeConstantI32(rewriter,
2722 getScaleB() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
2724 }
2725 if (isF16) {
2726 asmValues.push_back(
2727 {makeConstantI32(rewriter, static_cast<int>(getLayoutA())),
2729 asmValues.push_back(
2730 {makeConstantI32(rewriter, 1 - static_cast<int>(getLayoutB())),
2732 }
2733 return true; // Has manual mapping
2734}
2735
2736LogicalResult NVVM::FenceProxyOp::verify() {
2737 if (getKind() == NVVM::ProxyKind::async_shared && !getSpace().has_value()) {
2738 return emitOpError() << "async_shared fence requires space attribute";
2739 }
2740 if (getKind() != NVVM::ProxyKind::async_shared && getSpace().has_value()) {
2741 return emitOpError() << "only async_shared fence can have space attribute";
2742 }
2743 return success();
2744}
2745
2746LogicalResult NVVM::FenceProxyAcquireOp::verify() {
2747 if (getFromProxy() != NVVM::ProxyKind::GENERIC)
2748 return emitOpError("uni-directional proxies only support generic for "
2749 "from_proxy attribute");
2750
2751 if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
2752 return emitOpError("uni-directional proxies only support tensormap "
2753 "for to_proxy attribute");
2754 return success();
2755}
2756
2757LogicalResult NVVM::FenceProxyReleaseOp::verify() {
2758 if (getFromProxy() != NVVM::ProxyKind::GENERIC)
2759 return emitOpError("uni-directional proxies only support generic for "
2760 "from_proxy attribute");
2761
2762 if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
2763 return emitOpError("uni-directional proxies only support tensormap "
2764 "for to_proxy attribute");
2765 return success();
2766}
2767
2768LogicalResult NVVM::FenceProxySyncRestrictOp::verify() {
2769 if (getFromProxy() != NVVM::ProxyKind::GENERIC)
2770 return emitOpError("only generic is support for from_proxy attribute");
2771
2772 if (getToProxy() != NVVM::ProxyKind::async)
2773 return emitOpError("only async is supported for to_proxy attribute");
2774 return success();
2775}
2776
2777LogicalResult NVVM::SetMaxRegisterOp::verify() {
2778 if (getRegCount() % 8)
2779 return emitOpError("new register size must be multiple of 8");
2780 if (getRegCount() < 24 || getRegCount() > 256)
2781 return emitOpError("new register size must be in between 24 to 256");
2782 return success();
2783}
2784
2785LogicalResult NVVM::BarrierOp::verify() {
2786 if (getNumberOfThreads() && !getBarrierId())
2787 return emitOpError(
2788 "barrier id is missing, it should be set between 0 to 15");
2789
2790 if (getBarrierId() && (getReductionOp() || getReductionPredicate()))
2791 return emitOpError("reduction are only available when id is 0");
2792
2793 if ((getReductionOp() && !getReductionPredicate()) ||
2794 (!getReductionOp() && getReductionPredicate()))
2795 return emitOpError("reduction predicate and reduction operation must be "
2796 "specified together");
2797
2798 return success();
2799}
2800
2801LogicalResult NVVM::Tcgen05CpOp::verify() {
2802 auto mc = getMulticast();
2803
2804 using SH = Tcgen05CpShape;
2805 using MC = Tcgen05CpMulticast;
2806 switch (getShape()) {
2807 case SH::SHAPE_128x256b:
2808 case SH::SHAPE_128x128b:
2809 case SH::SHAPE_4x256b:
2810 if (mc != MC::NONE)
2811 return emitError("Invalid multicast type for tcgen05.cp Op");
2812 break;
2813 case SH::SHAPE_64x128b:
2814 if (mc != MC::WARPX2_01_23 && mc != MC::WARPX2_02_13)
2815 return emitError("Shape 64x128b requires multicast warpx2_01_23 or "
2816 "warpx2_02_13 for tcgen05.cp Op");
2817 break;
2818 case SH::SHAPE_32x128b:
2819 if (mc != MC::WARPX4)
2820 return emitError(
2821 "Shape 32x128b requires multicast warpx4 for tcgen05.cp Op");
2822 break;
2823 }
2824 return success();
2825}
2826
2827LogicalResult NVVM::MatchSyncOp::verify() {
2828 if (getKind() == NVVM::MatchSyncKind::all) {
2829 auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
2830 if (!type || type.getBody().size() != 2 ||
2831 !type.getBody()[0].isInteger(32) || !type.getBody()[1].isInteger(1)) {
2832 return emitOpError("match.sync 'all' returns a two element struct with "
2833 "first element as i32 and second element as i1");
2834 }
2835 } else {
2836 if (!getType().isInteger(32)) {
2837 return emitOpError("match.sync 'any' returns an i32");
2838 }
2839 }
2840 return success();
2841}
2842
2843LogicalResult NVVM::VoteSyncOp::verify() {
2844 if (getKind() == NVVM::VoteSyncKind::ballot) {
2845 if (!getType().isInteger(32)) {
2846 return emitOpError("vote.sync 'ballot' returns an i32");
2847 }
2848 } else {
2849 if (!getType().isInteger(1)) {
2850 return emitOpError("vote.sync 'any', 'all' and 'uni' returns an i1");
2851 }
2852 }
2853 return success();
2854}
2855
/// Verifier for nvvm.prefetch: exactly one of `tensormap` or `cache_level`
/// must be present, and the remaining modifiers (uniform, eviction priority,
/// predicate, in_param_space) are each legal only for specific combinations
/// of address space and cache level.
LogicalResult NVVM::PrefetchOp::verify() {
  using MemSpace = NVVM::NVVMMemorySpace;
  using CacheLevel = NVVM::PrefetchCacheLevel;

  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getAddr().getType()).getAddressSpace();
  std::optional<NVVM::CacheEvictionPriority> evictPriority = getEvictPriority();
  std::optional<NVVM::PrefetchCacheLevel> cacheLevel = getCacheLevel();

  // `tensormap` and `cache_level` are mutually exclusive forms of the op.
  if (getTensormap() && cacheLevel)
    return emitOpError("cannot specify both tensormap and cache level");

  if (getTensormap()) {
    // Tensormap form: generic/constant pointers only, no eviction priority,
    // and in_param_space is only meaningful for generic pointers.
    if (addressSpace != MemSpace::Generic &&
        addressSpace != MemSpace::Constant) {
      return emitOpError(
          "prefetch tensormap requires a generic or constant pointer");
    }

    if (evictPriority) {
      return emitOpError(
          "prefetch tensormap does not support eviction priority");
    }

    if (getInParamSpace() && addressSpace != MemSpace::Generic) {
      return emitOpError(
          "in_param_space can only be specified for a generic pointer");
    }

  } else if (cacheLevel) {
    // Cache-level form: generic/global/local pointers only.
    if (addressSpace != MemSpace::Generic && addressSpace != MemSpace::Global &&
        addressSpace != MemSpace::Local) {
      return emitOpError("prefetch to cache level requires a generic, global, "
                         "or local pointer");
    }

    // Uniform prefetch is only accepted for L1 with a generic address.
    if (getUniform()) {
      if (*cacheLevel != CacheLevel::L1) {
        return emitOpError(
            "unsupported cache level, the only supported uniform "
            "cache level is L1");
      }

      if (addressSpace != MemSpace::Generic) {
        return emitOpError(
            "prefetch to uniform cache requires a generic pointer");
      }
    }

    // Eviction priority: L2 + global pointer, and only the evict_normal /
    // evict_last priorities are accepted.
    if (evictPriority) {
      if (*cacheLevel != CacheLevel::L2)
        return emitOpError(
            "cache eviction priority supported only for cache level L2");

      if (addressSpace != MemSpace::Global)
        return emitOpError("cache eviction priority requires a global pointer");

      if (*evictPriority != NVVM::CacheEvictionPriority::EvictNormal &&
          *evictPriority != NVVM::CacheEvictionPriority::EvictLast)
        return emitOpError(
            "unsupported cache eviction priority, only evict_last and "
            "evict_normal are supported");
    }

    if (getPredicate())
      return emitOpError("predicate supported only on prefetch tensormap");

  } else {
    return emitOpError(
        "requires specification of either cache level or tensormap");
  }

  return success();
}
2930
2931LogicalResult NVVM::ClusterLaunchControlQueryCancelOp::verify() {
2932 switch (getQueryType()) {
2933 case NVVM::ClusterLaunchControlQueryType::IS_CANCELED:
2934 if (!getType().isInteger(1))
2935 return emitOpError("is_canceled query type returns an i1");
2936 break;
2937 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_X:
2938 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Y:
2939 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Z:
2940 if (!getType().isInteger(32)) {
2941 return emitOpError("get_first_cta_id_x, get_first_cta_id_y, "
2942 "get_first_cta_id_z query types return an i32");
2943 }
2944 break;
2945 }
2946 return success();
2947}
2948
2949LogicalResult NVVM::ReduxOp::verify() {
2950 mlir::Type reduxType = getType();
2951
2952 if (!reduxType.isF32()) {
2953 if (getAbs())
2954 return emitOpError("abs attribute is supported only for f32 type");
2955 if (getNan())
2956 return emitOpError("nan attribute is supported only for f32 type");
2957 }
2958
2959 NVVM::ReductionKind kind = getKind();
2960 switch (kind) {
2961 case NVVM::ReductionKind::ADD:
2962 case NVVM::ReductionKind::AND:
2963 case NVVM::ReductionKind::OR:
2964 case NVVM::ReductionKind::XOR:
2965 case NVVM::ReductionKind::MAX:
2966 case NVVM::ReductionKind::MIN:
2967 case NVVM::ReductionKind::UMAX:
2968 case NVVM::ReductionKind::UMIN:
2969 if (!reduxType.isInteger(32))
2970 return emitOpError("'")
2971 << kind << "' reduction kind unsupported with " << reduxType
2972 << " type. Only supported type is 'i32'.";
2973 break;
2974 case NVVM::ReductionKind::FMIN:
2975 case NVVM::ReductionKind::FMAX:
2976 if (!reduxType.isF32())
2977 return emitOpError("'")
2978 << kind << "' reduction kind unsupported with " << reduxType
2979 << " type. Only supported type is 'f32'.";
2980 break;
2981 }
2982
2983 return success();
2984}
2985
/// Verifier for nvvm.tensormap.replace: checks that an ordinal is given only
/// for the fields that accept one, and that the replacement is supplied via
/// the `new_value` operand (with the right integer width) for value fields,
/// or via `new_value_attr` (with the matching attribute type) for enum-like
/// fields.
LogicalResult NVVM::TensormapReplaceOp::verify() {
  auto ord = getOrd();
  Value newVal = getNewValue();
  auto newValAttr = getNewValueAttr();
  auto fieldName = stringifyEnum(getField());

  // Only these four fields accept an ordinal.
  if (ord && !llvm::is_contained({NVVM::TensormapField::BOX_DIM,
                                  NVVM::TensormapField::GLOBAL_DIM,
                                  NVVM::TensormapField::GLOBAL_STRIDE,
                                  NVVM::TensormapField::ELEMENT_STRIDE},
                                 getField()))
    return emitOpError("ordinal is not supported for ")
           << fieldName << " field";

  // Diagnostic builders shared by the per-field checks below.
  auto invalidNewVal = [&](llvm::Twine type) -> std::string {
    return llvm::Twine("new_value must be specified and must be an " + type +
                       " for " + llvm::Twine(fieldName) + " field")
        .str();
  };

  auto invalidNewValAttr = [&]() -> std::string {
    return (llvm::Twine(
                "new_value_attr must be specified and must be a valid ") +
            llvm::Twine(fieldName) + " attribute for " + fieldName + " field")
        .str();
  };

  switch (getField()) {
  // Value-carrying fields: require the operand with a specific width.
  case NVVM::TensormapField::GLOBAL_ADDRESS:
    if (!(newVal && newVal.getType().isInteger(64)))
      return emitOpError(invalidNewVal("i64"));
    break;
  case NVVM::TensormapField::RANK:
    if (!(newVal && newVal.getType().isInteger(32)))
      return emitOpError(invalidNewVal("i32"));
    break;
  case NVVM::TensormapField::GLOBAL_STRIDE:
    if (!ord)
      return emitOpError("ordinal is required for global_stride field");
    if (!(newVal && newVal.getType().isInteger(64)))
      return emitOpError(invalidNewVal("i64"));
    break;
  case NVVM::TensormapField::BOX_DIM:
  case NVVM::TensormapField::GLOBAL_DIM:
  case NVVM::TensormapField::ELEMENT_STRIDE:
    if (!ord)
      return emitOpError("ordinal is required for ")
             << stringifyEnum(getField()) << " field";
    if (!(newVal && newVal.getType().isInteger(32)))
      return emitOpError(invalidNewVal("i32"));
    break;
  // Enum-like fields: require the attribute form with the matching attr type.
  case NVVM::TensormapField::ELEMTYPE:
    if (!(newValAttr && llvm::isa<TensormapElemtypeAttr>(*newValAttr)))
      return emitOpError(invalidNewValAttr());
    break;
  case NVVM::TensormapField::INTERLEAVE_LAYOUT:
    if (!(newValAttr && llvm::isa<TensormapInterleaveLayoutAttr>(*newValAttr)))
      return emitOpError(invalidNewValAttr());
    break;
  case NVVM::TensormapField::SWIZZLE_MODE:
    if (!(newValAttr && llvm::isa<TensormapSwizzleModeAttr>(*newValAttr)))
      return emitOpError(invalidNewValAttr());
    break;
  case NVVM::TensormapField::SWIZZLE_ATOMICITY:
    if (!(newValAttr && llvm::isa<TensormapSwizzleAtomicityAttr>(*newValAttr)))
      return emitOpError(invalidNewValAttr());
    break;
  case NVVM::TensormapField::FILL_MODE:
    if (!(newValAttr && llvm::isa<TensormapFillModeAttr>(*newValAttr)))
      return emitOpError(invalidNewValAttr());
    break;
  }

  return success();
}
3061
/// Shared verifier for nvvm.add.f / nvvm.sub.f: enforces the per-result-type
/// restrictions on rounding mode, saturation, and FTZ for f64,
/// f16/vector<2xf16>, and bf16/vector<2xbf16> results.
template <typename OpType>
static LogicalResult verifyAddSubFOp(OpType op) {
  mlir::NVVM::FPRoundingMode rndMode = op.getRnd();
  mlir::NVVM::SaturationMode satMode = op.getSat();
  bool isFTZ = op.getFtz();

  // For vector results the checks apply to the element type.
  mlir::Type opType = op.getRes().getType();
  mlir::Type opBaseType = isa<VectorType>(opType)
                              ? cast<VectorType>(opType).getElementType()
                              : opType;

  if (opBaseType.isF64() && (satMode != NVVM::SaturationMode::NONE || isFTZ))
    return op.emitOpError("FTZ and saturation are not supported for "
                          "additions/subtractions involving f64 type");

  // An unset rounding mode (NONE) is also accepted for f16/bf16.
  if (opBaseType.isF16() && !(rndMode == NVVM::FPRoundingMode::RN ||
                              rndMode == NVVM::FPRoundingMode::NONE))
    return op.emitOpError("only RN rounding mode is supported for f16 and "
                          "vector<2xf16> additions/subtractions");

  if (opBaseType.isBF16()) {
    if (rndMode != NVVM::FPRoundingMode::RN &&
        rndMode != NVVM::FPRoundingMode::NONE)
      return op.emitOpError("only RN rounding mode is supported for bf16 and "
                            "vector<2xbf16> additions/subtractions");
    if (satMode != NVVM::SaturationMode::NONE || isFTZ)
      return op.emitOpError("FTZ and saturation are not supported for bf16 and "
                            "vector<2xbf16> additions/subtractions");
  }

  // FIXME: This is a temporary check disallowing lowering to add.rn.ftz.f16(x2)
  // PTX instructions since the corresponding LLVM intrinsic is missing. This
  // should be removed once the intrinsics for f16 addition (with FTZ only) are
  // available.
  if (opBaseType.isF16() && isFTZ && satMode == NVVM::SaturationMode::NONE)
    return op.emitOpError("FTZ with no saturation is not supported for f16 and "
                          "vector<2xf16> additions/subtractions");

  return success();
}
3102
// Defers to the shared add/sub FP verifier (rounding/saturation/FTZ rules).
LogicalResult NVVM::AddFOp::verify() { return verifyAddSubFOp<AddFOp>(*this); }
3104
// Defers to the shared add/sub FP verifier (rounding/saturation/FTZ rules).
LogicalResult NVVM::SubFOp::verify() { return verifyAddSubFOp<SubFOp>(*this); }
3106
/// Verifier for nvvm.fma: an explicit rounding mode is mandatory, and the
/// relu/oob/sat/ftz modifiers are each constrained by the result's element
/// type.
LogicalResult NVVM::FmaOp::verify() {
  auto opType = getRes().getType();
  mlir::NVVM::FPRoundingMode rndMode = getRnd();
  mlir::NVVM::SaturationMode satMode = getSat();
  bool isFTZ = getFtz();
  bool isRelu = getRelu();
  bool hasOOB = getOob();

  // For vector results the checks apply to the element type.
  auto getBaseFType = [](Type type) -> Type {
    if (isa<VectorType>(type))
      return cast<VectorType>(type).getElementType();
    return type;
  };

  auto opBaseType = getBaseFType(opType);

  if (rndMode == NVVM::FPRoundingMode::NONE)
    return emitOpError("rounding mode must be specified");

  // Modifier combinations that are rejected regardless of type.
  if (isRelu && satMode == NVVM::SaturationMode::SAT)
    return emitOpError("relu and saturation are not supported together");

  if (hasOOB && (satMode == NVVM::SaturationMode::SAT || isFTZ))
    return emitOpError("oob is not supported with saturation or FTZ");

  // relu/oob exist only for the half-precision forms.
  if (!(opBaseType.isF16() || opBaseType.isBF16()) && (isRelu || hasOOB))
    return emitOpError("relu and oob are only supported for f16 and bf16");

  if (opBaseType.isF64() && (satMode != NVVM::SaturationMode::NONE || isFTZ))
    return emitOpError("FTZ and saturation are not supported for f64 type");

  if (opBaseType.isF16() && rndMode != NVVM::FPRoundingMode::RN)
    return emitOpError(
        "only RN rounding mode is supported for f16 and vector<2xf16>");

  if (opBaseType.isBF16()) {
    if (rndMode != NVVM::FPRoundingMode::RN)
      return emitOpError(
          "only RN rounding mode is supported for bf16 and vector<2xbf16>");
    if (satMode != NVVM::SaturationMode::NONE || isFTZ)
      return emitOpError(
          "FTZ and saturation are not supported for bf16 and vector<2xbf16>");
  }

  return success();
}
3153
3154/// Packs the given `field` into the `result`.
3155/// The `result` is 64-bits and each `field` can be 32-bits or narrower.
3156static llvm::Value *
3157packValInto64Bits(llvm::IRBuilderBase &builder,
3158 llvm::Value *result, // the `result` (unset bits are zero)
3159 llvm::Value *field, // `field` to pack into `result`
3160 unsigned sizeInBits, // Size of `field` in bits
3161 unsigned start) { // Starting bit within `result`
3162 field = builder.CreateZExtOrBitCast(field, builder.getInt32Ty());
3163
3164 unsigned mask = (sizeInBits < 32 ? ((1u << sizeInBits) - 1) : 0xffffffffu);
3165 if (mask != 0xffffffffu)
3166 field = builder.CreateAnd(field, builder.getInt32(mask));
3167
3168 field = builder.CreateZExtOrBitCast(field, builder.getInt64Ty());
3169 field = builder.CreateShl(field, start);
3170
3171 return builder.CreateOr(result, field);
3172}
3173
3174void Tcgen05MmaSmemDescOp::createSmemDescriptor(Operation &op,
3176 llvm::IRBuilderBase &builder) {
3177 auto thisOp = cast<NVVM::Tcgen05MmaSmemDescOp>(op);
3178 llvm::Value *smemDesc = builder.getInt64(0);
3179
3180 smemDesc = packValInto64Bits(builder, smemDesc,
3181 mt.lookupValue(thisOp.getStartAddr()), 14, 0);
3182 smemDesc = packValInto64Bits(
3183 builder, smemDesc, mt.lookupValue(thisOp.getLeadingDimOffset()), 14, 16);
3184 smemDesc = packValInto64Bits(
3185 builder, smemDesc, mt.lookupValue(thisOp.getStrideDimOffset()), 14, 32);
3186
3187 smemDesc = packValInto64Bits(builder, smemDesc, builder.getInt32(1), 3, 46);
3188 smemDesc = packValInto64Bits(builder, smemDesc,
3189 mt.lookupValue(thisOp.getBaseOffset()), 3, 49);
3190 smemDesc = packValInto64Bits(
3191 builder, smemDesc, mt.lookupValue(thisOp.getLeadingDimMode()), 1, 52);
3192 smemDesc = packValInto64Bits(builder, smemDesc,
3193 mt.lookupValue(thisOp.getSwizzleMode()), 3, 61);
3194
3195 mt.mapValue(thisOp.getRes()) = smemDesc;
3196}
3197
3198//===----------------------------------------------------------------------===//
3199// getPtx methods
3200//===----------------------------------------------------------------------===//
3201
3202std::string NVVM::MBarrierInitOp::getPtx() {
3203 bool isShared = isPtrInSharedCTASpace(getAddr());
3204 return isShared ? std::string("mbarrier.init.shared.b64 [%0], %1;")
3205 : std::string("mbarrier.init.b64 [%0], %1;");
3206}
3207
3208std::string NVVM::MBarrierArriveExpectTxOp::getPtx() {
3209 bool isShared = isPtrInSharedCTASpace(getAddr());
3210 return isShared
3211 ? std::string("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;")
3212 : std::string("mbarrier.arrive.expect_tx.b64 _, [%0], %1;");
3213}
3214
/// Returns inline PTX that spins on mbarrier.try_wait.parity until the wait
/// succeeds: predicate P1 is set by the try_wait and breaks out of the
/// LAB_WAIT loop. `{0}` is replaced with ".shared" for shared-CTA addresses.
std::string NVVM::MBarrierTryWaitParityOp::getPtx() {
  bool isShared = isPtrInSharedCTASpace(getAddr());
  llvm::StringRef space = isShared ? ".shared" : "";

  return llvm::formatv("{\n\t"
                       ".reg .pred P1; \n\t"
                       "LAB_WAIT: \n\t"
                       "mbarrier.try_wait.parity{0}.b64 P1, [%0], %1, %2; \n\t"
                       "@P1 bra.uni DONE; \n\t"
                       "bra.uni LAB_WAIT; \n\t"
                       "DONE: \n\t"
                       "}",
                       space);
}
3229
3230//===----------------------------------------------------------------------===//
3231// Canonicalization patterns
3232//===----------------------------------------------------------------------===//
3233
3236
3237 LogicalResult matchAndRewrite(SubFOp op,
3238 PatternRewriter &rewriter) const override {
3239 Location loc = op.getLoc();
3240 Value negRhs =
3241 LLVM::FNegOp::create(rewriter, loc, op.getRhs().getType(), op.getRhs());
3242
3243 rewriter.replaceOpWithNewOp<AddFOp>(op, op.getType(), op.getLhs(), negRhs,
3244 op.getRnd(), op.getSat(), op.getFtz());
3245 return success();
3246 }
3247};
3248
// Registers the sub -> (fneg + add) canonicalization for nvvm.sub.f.
void SubFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                         MLIRContext *context) {
  patterns.add<ConvertFsubToFnegFadd>(context);
}
3253
3254//===----------------------------------------------------------------------===//
3255// getIntrinsicID/getIntrinsicIDAndArgs methods
3256//===----------------------------------------------------------------------===//
3257
/// Selects the barrier intrinsic variant: with an explicit thread count,
/// barrier.cta.sync.aligned.count; with a reduction, the matching
/// barrier.cta.red.<op>.aligned.all; otherwise plain
/// barrier.cta.sync.aligned.all. A missing barrier id defaults to 0.
mlir::NVVM::IDArgPair NVVM::BarrierOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::BarrierOp>(op);
  llvm::Value *barrierId = thisOp.getBarrierId()
                               ? mt.lookupValue(thisOp.getBarrierId())
                               : builder.getInt32(0);
  llvm::Intrinsic::ID id;
  llvm::SmallVector<llvm::Value *> args = {barrierId};
  if (thisOp.getNumberOfThreads()) {
    id = llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_count;
    args.push_back(mt.lookupValue(thisOp.getNumberOfThreads()));
  } else if (thisOp.getReductionOp()) {
    switch (*thisOp.getReductionOp()) {
    case NVVM::BarrierReduction::AND:
      id = llvm::Intrinsic::nvvm_barrier_cta_red_and_aligned_all;
      break;
    case NVVM::BarrierReduction::OR:
      id = llvm::Intrinsic::nvvm_barrier_cta_red_or_aligned_all;
      break;
    case NVVM::BarrierReduction::POPC:
      id = llvm::Intrinsic::nvvm_barrier_cta_red_popc_aligned_all;
      break;
    }
    // Normalize the i32 predicate operand to the i1 the reduction
    // intrinsics expect (non-zero == true).
    args.push_back(builder.CreateICmpNE(
        mt.lookupValue(thisOp.getReductionPredicate()), builder.getInt32(0)));
  } else {
    id = llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_all;
  }

  return {id, std::move(args)};
}
3289
3291PMEventOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
3292 llvm::IRBuilderBase &builder) {
3293 auto thisOp = cast<NVVM::PMEventOp>(op);
3294 llvm::Type *i16Ty = llvm::Type::getInt16Ty(mt.getLLVMContext());
3295
3296 // With event-id, mask is generated as (1 << event-id)
3297 llvm::Value *maskVal;
3298 if (auto eventAttr = thisOp.getEventIdAttr()) {
3299 uint16_t mask = static_cast<uint16_t>(1u << eventAttr.getInt());
3300 maskVal = llvm::ConstantInt::get(i16Ty, mask);
3301 } else {
3302 maskVal =
3303 llvm::ConstantInt::get(i16Ty, thisOp.getMaskedEventIdAttr().getValue());
3304 }
3305
3306 return {llvm::Intrinsic::nvvm_pm_event_mask, {maskVal}};
3307}
3308
3309mlir::NVVM::IDArgPair MBarrierInitOp::getIntrinsicIDAndArgs(
3310 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3311 auto thisOp = cast<NVVM::MBarrierInitOp>(op);
3312 bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
3313 llvm::Intrinsic::ID id = isShared ? llvm::Intrinsic::nvvm_mbarrier_init_shared
3314 : llvm::Intrinsic::nvvm_mbarrier_init;
3315
3316 // Fill the Intrinsic Args
3318 args.push_back(mt.lookupValue(thisOp.getAddr()));
3319 args.push_back(mt.lookupValue(thisOp.getCount()));
3320
3321 return {id, std::move(args)};
3322}
3323
3324mlir::NVVM::IDArgPair MBarrierInvalOp::getIntrinsicIDAndArgs(
3325 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3326 auto thisOp = cast<NVVM::MBarrierInvalOp>(op);
3327 bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
3328 llvm::Intrinsic::ID id = isShared
3329 ? llvm::Intrinsic::nvvm_mbarrier_inval_shared
3330 : llvm::Intrinsic::nvvm_mbarrier_inval;
3331
3332 return {id, {mt.lookupValue(thisOp.getAddr())}};
3333}
3334
3335mlir::NVVM::IDArgPair MBarrierExpectTxOp::getIntrinsicIDAndArgs(
3336 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3337 auto thisOp = cast<NVVM::MBarrierExpectTxOp>(op);
3338
3339 bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr());
3340 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3341 // bit-0: Space
3342 // bit-1: Scope
3343 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3344
3345 static constexpr llvm::Intrinsic::ID IDs[] = {
3346 llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cta_space_cta,
3347 llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cta_space_cluster,
3348 llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cluster_space_cta,
3349 llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cluster_space_cluster};
3350
3351 // Fill the Intrinsic Args
3353 args.push_back(mt.lookupValue(thisOp.getAddr()));
3354 args.push_back(mt.lookupValue(thisOp.getTxcount()));
3355
3356 return {IDs[index], std::move(args)};
3357}
3358
3359mlir::NVVM::IDArgPair MBarrierCompleteTxOp::getIntrinsicIDAndArgs(
3360 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3361 auto thisOp = cast<NVVM::MBarrierCompleteTxOp>(op);
3362
3363 bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr());
3364 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3365 // bit-0: Space
3366 // bit-1: Scope
3367 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3368
3369 static constexpr llvm::Intrinsic::ID IDs[] = {
3370 llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cta_space_cta,
3371 llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cta_space_cluster,
3372 llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cluster_space_cta,
3373 llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cluster_space_cluster};
3374
3375 // Fill the Intrinsic Args
3377 args.push_back(mt.lookupValue(thisOp.getAddr()));
3378 args.push_back(mt.lookupValue(thisOp.getTxcount()));
3379
3380 return {IDs[index], std::move(args)};
3381}
3382
/// Lowers nvvm.mbarrier.arrive: selects the intrinsic variant from the
/// pointer's shared space (CTA vs cluster), the memory scope, and the
/// `relaxed` flag; generic pointers are cast to the shared address space and
/// a missing count defaults to 1. The basic cta/cta, non-relaxed, no-count
/// combination lowers to the legacy sm_80 intrinsic.
mlir::NVVM::IDArgPair MBarrierArriveOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::MBarrierArriveOp>(op);

  bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr());
  bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
  // bit-0: Space
  // bit-1: Scope
  size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);

  static constexpr llvm::Intrinsic::ID IDs[] = {
      llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cta_space_cta,
      llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cta_space_cluster,
      llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cluster_space_cta,
      llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cluster_space_cluster};
  static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
      llvm::Intrinsic::nvvm_mbarrier_arrive_relaxed_scope_cta_space_cta,
      llvm::Intrinsic::nvvm_mbarrier_arrive_relaxed_scope_cta_space_cluster,
      llvm::Intrinsic::nvvm_mbarrier_arrive_relaxed_scope_cluster_space_cta,
      llvm::Intrinsic::
          nvvm_mbarrier_arrive_relaxed_scope_cluster_space_cluster};
  auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];

  // Tidy-up the Intrinsic Args
  bool needCast = isPtrInGenericSpace(thisOp.getAddr());
  llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
  if (needCast)
    mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);

  // We have the most basic mbarrier.arrive supported on sm_80.
  // It supports: Space=cta, scope=cta, No relaxed, No explicit count.
  // So, only for this combination use the legacy intrinsic.
  bool hasCount = static_cast<bool>(thisOp.getCount());
  if (!hasCount &&
      (id == llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cta_space_cta))
    return {llvm::Intrinsic::nvvm_mbarrier_arrive_shared, {mbar}};

  // When count is not explicitly specified, the default is 1.
  llvm::LLVMContext &ctx = mt.getLLVMContext();
  llvm::Value *count =
      hasCount ? mt.lookupValue(thisOp.getCount())
               : llvm::ConstantInt::get(llvm::Type::getInt32Ty(ctx), 1);
  return {id, {mbar, count}};
}
3427
/// Lowers nvvm.mbarrier.arrive_drop: same variant selection as
/// MBarrierArriveOp (space, scope, relaxed), with generic pointers cast to
/// shared and a default count of 1; there is no legacy-intrinsic special
/// case here.
mlir::NVVM::IDArgPair MBarrierArriveDropOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::MBarrierArriveDropOp>(op);

  bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr());
  bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
  // bit-0: Space
  // bit-1: Scope
  size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);

  static constexpr llvm::Intrinsic::ID IDs[] = {
      llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cta_space_cta,
      llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cta_space_cluster,
      llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cluster_space_cta,
      llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cluster_space_cluster};
  static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
      llvm::Intrinsic::nvvm_mbarrier_arrive_drop_relaxed_scope_cta_space_cta,
      llvm::Intrinsic::
          nvvm_mbarrier_arrive_drop_relaxed_scope_cta_space_cluster,
      llvm::Intrinsic::
          nvvm_mbarrier_arrive_drop_relaxed_scope_cluster_space_cta,
      llvm::Intrinsic::
          nvvm_mbarrier_arrive_drop_relaxed_scope_cluster_space_cluster};
  auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];

  // Tidy-up the Intrinsic Args
  bool needCast = isPtrInGenericSpace(thisOp.getAddr())
  llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
  if (needCast)
    mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);

  // When count is not explicitly specified, the default is 1.
  llvm::LLVMContext &ctx = mt.getLLVMContext();
  bool hasCount = static_cast<bool>(thisOp.getCount());
  llvm::Value *count =
      hasCount ? mt.lookupValue(thisOp.getCount())
               : llvm::ConstantInt::get(llvm::Type::getInt32Ty(ctx), 1);

  return {id, {mbar, count}};
}
3468
3469bool MBarrierArriveExpectTxOp::getAsmValues(
3470 RewriterBase &rewriter,
3471 llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
3472 &asmValues) {
3473 // Add all the operands but not the attrs to the asmValues list.
3474 // The attrs here are used to generate the right variants for
3475 // intrinsics-lowering. So, we ignore them while generating inline-PTX.
3476 for (auto val : getOperands())
3477 asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read});
3478
3479 return false;
3480}
3481
/// Lowers nvvm.mbarrier.arrive.expect_tx: selects the intrinsic variant from
/// the pointer's shared space (CTA vs cluster), the memory scope, and the
/// `relaxed` flag; generic pointers are cast to the shared address space.
mlir::NVVM::IDArgPair MBarrierArriveExpectTxOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::MBarrierArriveExpectTxOp>(op);

  bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr());
  bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
  // bit-0: Space
  // bit-1: Scope
  size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);

  // clang-format off
  static constexpr llvm::Intrinsic::ID IDs[] = {
    llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cta_space_cta,
    llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cta_space_cluster,
    llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cluster_space_cta,
    llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cluster_space_cluster};
  static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
    llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cta_space_cta,
    llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cta_space_cluster,
    llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cluster_space_cta,
    llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cluster_space_cluster};
  // clang-format on
  auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];

  // Tidy-up the Intrinsic Args
  llvm::Value *txcount = mt.lookupValue(thisOp.getTxcount());
  llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
  bool needCast = isPtrInGenericSpace(thisOp.getAddr());
  if (needCast)
    mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);

  return {id, {mbar, txcount}};
}
3515
3516mlir::NVVM::IDArgPair MBarrierArriveDropExpectTxOp::getIntrinsicIDAndArgs(
3517 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3518 auto thisOp = cast<NVVM::MBarrierArriveDropExpectTxOp>(op);
3519
3520 bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr());
3521 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3522 // bit-0: Space
3523 // bit-1: Scope
3524 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3525
3526 // clang-format off
3527 static constexpr llvm::Intrinsic::ID IDs[] = {
3528 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cta_space_cta,
3529 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cta_space_cluster,
3530 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cluster_space_cta,
3531 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cluster_space_cluster};
3532 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3533 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cta_space_cta,
3534 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cta_space_cluster,
3535 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cluster_space_cta,
3536 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cluster_space_cluster};
3537 // clang-format on
3538 auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];
3539
3540 // Tidy-up the Intrinsic Args
3541 llvm::Value *txcount = mt.lookupValue(thisOp.getTxcount());
3542 llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
3543 bool needCast = isPtrInGenericSpace(thisOp.getAddr());
3544 if (needCast)
3545 mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);
3546
3547 return {id, {mbar, txcount}};
3548}
3549
3550mlir::NVVM::IDArgPair MBarrierArriveNocompleteOp::getIntrinsicIDAndArgs(
3551 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3552 auto thisOp = cast<NVVM::MBarrierArriveNocompleteOp>(op);
3553 bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
3554 llvm::Intrinsic::ID id =
3555 isShared ? llvm::Intrinsic::nvvm_mbarrier_arrive_noComplete_shared
3556 : llvm::Intrinsic::nvvm_mbarrier_arrive_noComplete;
3557 // Fill the Intrinsic Args
3559 args.push_back(mt.lookupValue(thisOp.getAddr()));
3560 args.push_back(mt.lookupValue(thisOp.getCount()));
3561
3562 return {id, std::move(args)};
3563}
3564
3565mlir::NVVM::IDArgPair MBarrierArriveDropNocompleteOp::getIntrinsicIDAndArgs(
3566 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3567 auto thisOp = cast<NVVM::MBarrierArriveDropNocompleteOp>(op);
3568 bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
3569 llvm::Intrinsic::ID id =
3570 isShared ? llvm::Intrinsic::nvvm_mbarrier_arrive_drop_noComplete_shared
3571 : llvm::Intrinsic::nvvm_mbarrier_arrive_drop_noComplete;
3572 // Fill the Intrinsic Args
3574 args.push_back(mt.lookupValue(thisOp.getAddr()));
3575 args.push_back(mt.lookupValue(thisOp.getCount()));
3576
3577 return {id, std::move(args)};
3578}
3579
3580mlir::NVVM::IDArgPair MBarrierTestWaitOp::getIntrinsicIDAndArgs(
3581 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3582 auto thisOp = cast<NVVM::MBarrierTestWaitOp>(op);
3583 bool isPhaseParity = thisOp.getStateOrPhase().getType().isInteger(32);
3584 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3585 // bit-0: isPhaseParity
3586 // bit-1: Scope
3587 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isPhaseParity ? 1 : 0);
3588
3589 // clang-format off
3590 static constexpr llvm::Intrinsic::ID IDs[] = {
3591 llvm::Intrinsic::nvvm_mbarrier_test_wait_scope_cta_space_cta,
3592 llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_scope_cta_space_cta,
3593 llvm::Intrinsic::nvvm_mbarrier_test_wait_scope_cluster_space_cta,
3594 llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_scope_cluster_space_cta};
3595 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3596 llvm::Intrinsic::nvvm_mbarrier_test_wait_relaxed_scope_cta_space_cta,
3597 llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_relaxed_scope_cta_space_cta,
3598 llvm::Intrinsic::nvvm_mbarrier_test_wait_relaxed_scope_cluster_space_cta,
3599 llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_relaxed_scope_cluster_space_cta};
3600 // clang-format on
3601 auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];
3602
3603 // Tidy-up the Intrinsic Args
3604 llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
3605 llvm::Value *input = mt.lookupValue(thisOp.getStateOrPhase());
3606 bool needCast = isPtrInGenericSpace(thisOp.getAddr());
3607 if (needCast)
3608 mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);
3609
3610 return {id, {mbar, input}};
3611}
3612
3613mlir::NVVM::IDArgPair MBarrierTryWaitOp::getIntrinsicIDAndArgs(
3614 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3615 auto thisOp = cast<NVVM::MBarrierTryWaitOp>(op);
3616 bool isPhaseParity = thisOp.getStateOrPhase().getType().isInteger(32);
3617 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3618 bool hasTicks = static_cast<bool>(thisOp.getTicks());
3619 // bit-0: isPhaseParity
3620 // bit-1: Scope
3621 // bit-2: hasTicks
3622 size_t index = ((hasTicks ? 1 : 0) << 2) | ((isClusterScope ? 1 : 0) << 1) |
3623 (isPhaseParity ? 1 : 0);
3624
3625 // clang-format off
3626 static constexpr llvm::Intrinsic::ID IDs[] = {
3627 llvm::Intrinsic::nvvm_mbarrier_try_wait_scope_cta_space_cta,
3628 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_scope_cta_space_cta,
3629 llvm::Intrinsic::nvvm_mbarrier_try_wait_scope_cluster_space_cta,
3630 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_scope_cluster_space_cta,
3631 llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_scope_cta_space_cta,
3632 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_scope_cta_space_cta,
3633 llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_scope_cluster_space_cta,
3634 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_scope_cluster_space_cta};
3635 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3636 llvm::Intrinsic::nvvm_mbarrier_try_wait_relaxed_scope_cta_space_cta,
3637 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_relaxed_scope_cta_space_cta,
3638 llvm::Intrinsic::nvvm_mbarrier_try_wait_relaxed_scope_cluster_space_cta,
3639 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_relaxed_scope_cluster_space_cta,
3640 llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_relaxed_scope_cta_space_cta,
3641 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_relaxed_scope_cta_space_cta,
3642 llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_relaxed_scope_cluster_space_cta,
3643 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_relaxed_scope_cluster_space_cta};
3644 // clang-format on
3645 auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];
3646
3647 // Tidy-up the mbarrier pointer
3648 llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
3649 bool needCast = isPtrInGenericSpace(thisOp.getAddr());
3650 if (needCast)
3651 mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);
3652
3653 // Fill the Intrinsic Args
3655 args.push_back(mbar);
3656 args.push_back(mt.lookupValue(thisOp.getStateOrPhase()));
3657 if (hasTicks)
3658 args.push_back(mt.lookupValue(thisOp.getTicks()));
3659
3660 return {id, std::move(args)};
3661}
3662
3663mlir::NVVM::IDArgPair CpAsyncMBarrierArriveOp::getIntrinsicIDAndArgs(
3664 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3665 auto thisOp = cast<NVVM::CpAsyncMBarrierArriveOp>(op);
3666 bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
3667
3668 llvm::Intrinsic::ID id;
3669 if (thisOp.getNoinc()) {
3670 id = isShared ? llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_noinc_shared
3671 : llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_noinc;
3672 } else {
3673 id = isShared ? llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_shared
3674 : llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive;
3675 }
3676
3677 return {id, {mt.lookupValue(thisOp.getAddr())}};
3678}
3679
// Helper macros to compose the nvvm_cp_async_* intrinsic name from the cache
// modifier (ca/cg), copy size (4/8/16), and the optional `_s` suffix used by
// the variants that take an explicit cp-size operand.
#define CP_ASYNC_ID_IMPL(mod, size, suffix)                                    \
  llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix

#define GET_CP_ASYNC_ID(mod, size, has_cpsize)                                 \
  has_cpsize ? CP_ASYNC_ID_IMPL(mod, size, _s) : CP_ASYNC_ID_IMPL(mod, size, )
3685
3686llvm::Intrinsic::ID
3687CpAsyncOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
3689 llvm::Intrinsic::ID id;
3690
3691 auto cpAsyncOp = cast<NVVM::CpAsyncOp>(op);
3692 bool hasCpSize = static_cast<bool>(cpAsyncOp.getCpSize());
3693 switch (cpAsyncOp.getSize()) {
3694 case 4:
3695 id = GET_CP_ASYNC_ID(ca, 4, hasCpSize);
3696 break;
3697 case 8:
3698 id = GET_CP_ASYNC_ID(ca, 8, hasCpSize);
3699 break;
3700 case 16:
3701 id = (cpAsyncOp.getModifier() == NVVM::LoadCacheModifierKind::CG)
3702 ? GET_CP_ASYNC_ID(cg, 16, hasCpSize)
3703 : GET_CP_ASYNC_ID(ca, 16, hasCpSize);
3704 break;
3705 default:
3706 llvm_unreachable("Invalid copy size in CpAsyncOp.");
3707 }
3708
3709 // Fill the Intrinsic Args
3710 args.push_back(mt.lookupValue(cpAsyncOp.getDst()));
3711 args.push_back(mt.lookupValue(cpAsyncOp.getSrc()));
3712 if (hasCpSize)
3713 args.push_back(mt.lookupValue(cpAsyncOp.getCpSize()));
3714
3715 return id;
3716}
3717
3718mlir::NVVM::IDArgPair CpAsyncBulkPrefetchOp::getIntrinsicIDAndArgs(
3719 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3720 auto thisOp = cast<NVVM::CpAsyncBulkPrefetchOp>(op);
3722 llvm::Intrinsic::ID id = llvm::Intrinsic::nvvm_cp_async_bulk_prefetch_L2;
3723
3724 // Fill the Intrinsic Args
3725 args.push_back(mt.lookupValue(thisOp.getSrcMem()));
3726 args.push_back(mt.lookupValue(thisOp.getSize()));
3727
3728 mlir::Value cacheHint = thisOp.getL2CacheHint();
3729 const bool hasCacheHint = static_cast<bool>(cacheHint);
3730 llvm::Value *i64Unused =
3731 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
3732 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
3733 args.push_back(builder.getInt1(hasCacheHint));
3734
3735 return {id, std::move(args)};
3736}
3737
3738mlir::NVVM::IDArgPair CpAsyncBulkGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
3739 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3740 auto thisOp = cast<NVVM::CpAsyncBulkGlobalToSharedClusterOp>(op);
3742
3743 // Fill the Intrinsic Args: dst, mbar, src, size.
3744 args.push_back(mt.lookupValue(thisOp.getDstMem()));
3745 args.push_back(mt.lookupValue(thisOp.getMbar()));
3746 args.push_back(mt.lookupValue(thisOp.getSrcMem()));
3747 args.push_back(mt.lookupValue(thisOp.getSize()));
3748
3749 // Multicast mask for shared::cluster only, if available.
3750 mlir::Value multicastMask = thisOp.getMulticastMask();
3751 const bool hasMulticastMask = static_cast<bool>(multicastMask);
3752 const bool isSharedCTA = isPtrInSharedCTASpace(thisOp.getDstMem());
3753 if (!isSharedCTA) {
3754 llvm::Value *i16Unused = llvm::ConstantInt::get(builder.getInt16Ty(), 0);
3755 args.push_back(hasMulticastMask ? mt.lookupValue(multicastMask)
3756 : i16Unused);
3757 }
3758
3759 // Cache hint, if available.
3760 mlir::Value cacheHint = thisOp.getL2CacheHint();
3761 const bool hasCacheHint = static_cast<bool>(cacheHint);
3762 llvm::Value *i64Unused = llvm::ConstantInt::get(builder.getInt64Ty(), 0);
3763 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
3764
3765 // Flag arguments for multicast and cachehint.
3766 if (!isSharedCTA)
3767 args.push_back(builder.getInt1(hasMulticastMask));
3768 args.push_back(builder.getInt1(hasCacheHint));
3769
3770 llvm::Intrinsic::ID id =
3771 isSharedCTA
3772 ? llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cta
3773 : llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster;
3774
3775 return {id, std::move(args)};
3776}
3777
3778mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
3779 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3780 auto thisOp = cast<NVVM::CpAsyncBulkSharedCTAToGlobalOp>(op);
3782 llvm::Intrinsic::ID id =
3783 llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global;
3784
3785 // Fill the Intrinsic Args
3786 args.push_back(mt.lookupValue(thisOp.getDstMem()));
3787 args.push_back(mt.lookupValue(thisOp.getSrcMem()));
3788 args.push_back(mt.lookupValue(thisOp.getSize()));
3789
3790 mlir::Value cacheHint = thisOp.getL2CacheHint();
3791 const bool hasCacheHint = static_cast<bool>(cacheHint);
3792 llvm::Value *i64Unused =
3793 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
3794 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
3795 args.push_back(builder.getInt1(hasCacheHint));
3796
3797 // Choose the bytemask variant
3798 if (mlir::Value byteMask = thisOp.getByteMask()) {
3799 args.push_back(mt.lookupValue(byteMask));
3800 id = llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global_bytemask;
3801 }
3802
3803 return {id, std::move(args)};
3804}
3805
3806bool CpAsyncBulkTensorGlobalToSharedClusterOp::getAsmValues(
3807 RewriterBase &rewriter,
3808 llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
3809 &asmValues) {
3810 // Add all the operands but not the attrs to the asmValues list.
3811 // The attrs here are used to generate the right variants for
3812 // intrinsics-lowering. So, we ignore them while generating inline-PTX.
3813 for (auto val : getOperands())
3814 asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read});
3815
3816 return false;
3817}
3818
3820CpAsyncBulkTensorGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
3821 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3822 auto thisOp = cast<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(op);
3823 const bool isCTAOnly = thisOp.getIsCTAOnly();
3825
3826 // Fill the Intrinsic Args
3827 args.push_back(mt.lookupValue(thisOp.getDstMem()));
3828 args.push_back(mt.lookupValue(thisOp.getMbar()));
3829 args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
3830
3831 // Coordinates and im2col-offsets
3832 for (mlir::Value v : thisOp.getCoordinates())
3833 args.push_back(mt.lookupValue(v));
3834 for (mlir::Value v : thisOp.getIm2colOffsets())
3835 args.push_back(mt.lookupValue(v));
3836
3837 // MulticastMask, if available
3838 mlir::Value mcMask = thisOp.getMulticastMask();
3839 const bool hasMC = static_cast<bool>(mcMask);
3840 llvm::Value *i16Zero =
3841 llvm::ConstantInt::get(llvm::Type::getInt16Ty(mt.getLLVMContext()), 0);
3842
3843 // CacheHint, if available
3844 mlir::Value cacheHint = thisOp.getL2CacheHint();
3845 const bool hasCacheHint = static_cast<bool>(cacheHint);
3846 llvm::Value *i64Zero =
3847 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
3848
3849 // Flag argument CTAGroup
3850 // CTA_1/2 is mapped to values 1 and 2 for the intrinsics.
3851 // Hence, the +1 to getGroup().
3852 const int32_t val =
3853 thisOp.getGroup() ? (static_cast<int32_t>(*thisOp.getGroup()) + 1) : 0;
3854 llvm::Value *cg =
3855 llvm::ConstantInt::get(llvm::Type::getInt32Ty(mt.getLLVMContext()), val);
3856
3857 if (!isCTAOnly) {
3858 // For shared::cluster, all the arguments that we build are applicable.
3859 args.push_back(hasMC ? mt.lookupValue(mcMask) : i16Zero);
3860 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Zero);
3861 args.push_back(builder.getInt1(hasMC));
3862 args.push_back(builder.getInt1(hasCacheHint));
3863 args.push_back(cg);
3864 } else {
3865 // For shared::cta, only cache-hint is applicable.
3866 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Zero);
3867 args.push_back(builder.getInt1(hasCacheHint));
3868 }
3869
3870 constexpr size_t numDims = 5; // 1D to 5D
3871 constexpr size_t numModes = 5; // Tile, Im2col, w, w_128, gather4
3872 using rowTy = std::array<llvm::Intrinsic::ID, numDims + 1>;
3873 using TableTy = std::array<rowTy, numModes>;
3874 static constexpr TableTy IDTable{
3875 {{notIntrinsic, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_1d,
3876 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_2d,
3877 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_3d,
3878 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_4d,
3879 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_5d},
3881 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d,
3882 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d,
3883 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d},
3885 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_3d,
3886 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_4d,
3887 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_5d},
3889 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_3d,
3890 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_4d,
3891 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_5d},
3893 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_gather4_2d}}};
3894
3895 static constexpr TableTy IDTableCTA{
3896 {{notIntrinsic,
3897 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_1d,
3898 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_2d,
3899 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_3d,
3900 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_4d,
3901 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_5d},
3903 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_3d,
3904 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_4d,
3905 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_5d},
3907 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_3d,
3908 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_4d,
3909 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_5d},
3911 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_3d,
3912 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_4d,
3913 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_5d},
3915 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_gather4_2d}}};
3916
3917 static_assert(
3918 (getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1) &&
3919 (getMaxEnumValForTMALoadMode() == std::size(IDTableCTA) - 1),
3920 "TMALoadModes must match number of rows in IDTable and IDTableCTA");
3921 size_t mode = static_cast<size_t>(thisOp.getMode());
3922 size_t dim = thisOp.getCoordinates().size();
3923 auto id = isCTAOnly ? IDTableCTA[mode][dim] : IDTable[mode][dim];
3924 assert(id != notIntrinsic &&
3925 "Invalid intrinsic for CpAsyncBulkTensorGlobalToSharedClusterOp.");
3926
3927 return {id, std::move(args)};
3928}
3929
3930mlir::NVVM::IDArgPair CpAsyncBulkTensorPrefetchOp::getIntrinsicIDAndArgs(
3931 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3932 auto thisOp = cast<NVVM::CpAsyncBulkTensorPrefetchOp>(op);
3934
3935 // Fill the Intrinsic Args
3936 args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
3937
3938 for (auto v : thisOp.getCoordinates())
3939 args.push_back(mt.lookupValue(v));
3940 for (auto v : thisOp.getIm2colOffsets())
3941 args.push_back(mt.lookupValue(v));
3942
3943 mlir::Value cacheHint = thisOp.getL2CacheHint();
3944 const bool hasCacheHint = static_cast<bool>(cacheHint);
3945 llvm::Value *i64Unused =
3946 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
3947 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
3948 args.push_back(builder.getInt1(hasCacheHint));
3949
3950 const unsigned NI = llvm::Intrinsic::not_intrinsic;
3951 static constexpr llvm::Intrinsic::ID IDTable[][6] = {
3952 {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d,
3953 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d,
3954 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d,
3955 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d,
3956 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d},
3957 {NI, NI, NI,
3958 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d,
3959 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d,
3960 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d},
3961 {NI, NI, NI,
3962 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_3d,
3963 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_4d,
3964 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_5d},
3965 {NI, NI, NI,
3966 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_3d,
3967 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_4d,
3968 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_5d},
3969 {NI, NI, NI, NI, NI,
3970 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_gather4_2d}};
3971
3972 static_assert(getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1,
3973 "TMALoadModes must match number of rows in IDTable");
3974 size_t mode = static_cast<size_t>(thisOp.getMode());
3975 size_t dim = thisOp.getCoordinates().size();
3976 llvm::Intrinsic::ID id = IDTable[mode][dim];
3977 if (id == llvm::Intrinsic::not_intrinsic)
3978 llvm_unreachable("Invalid intrinsic for CpAsyncBulkTensorPrefetchOp.");
3979
3980 return {id, std::move(args)};
3981}
3982
3984CpAsyncBulkTensorSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
3985 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3986 auto thisOp = cast<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(op);
3988
3989 // Fill the Intrinsic Args
3990 args.push_back(mt.lookupValue(thisOp.getSrcMem()));
3991 args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
3992
3993 for (auto v : thisOp.getCoordinates())
3994 args.push_back(mt.lookupValue(v));
3995
3996 mlir::Value cacheHint = thisOp.getL2CacheHint();
3997 const bool hasCacheHint = static_cast<bool>(cacheHint);
3998 llvm::Value *i64Unused =
3999 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
4000 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
4001 args.push_back(builder.getInt1(hasCacheHint));
4002
4003 const unsigned NI = llvm::Intrinsic::not_intrinsic;
4004 static constexpr llvm::Intrinsic::ID IDTable[][6] = {
4005 {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_1d,
4006 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_2d,
4007 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_3d,
4008 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_4d,
4009 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_5d},
4010 {NI, NI, NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_3d,
4011 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_4d,
4012 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_5d},
4013 {NI, NI, NI, NI, NI,
4014 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_scatter4_2d}};
4015
4016 static_assert(getMaxEnumValForTMAStoreMode() == std::size(IDTable) - 1,
4017 "TMAStoreModes must match number of rows in IDTable");
4018 size_t mode = static_cast<size_t>(thisOp.getMode());
4019 size_t dim = thisOp.getCoordinates().size();
4020 llvm::Intrinsic::ID id = IDTable[mode][dim];
4021 if (id == llvm::Intrinsic::not_intrinsic)
4022 llvm_unreachable(
4023 "Invalid intrinsic for CpAsyncBulkTensorSharedCTAToGlobalOp.");
4024
4025 return {id, std::move(args)};
4026}
4027
4028NVVM::IDArgPair CpAsyncBulkTensorReduceOp::getIntrinsicIDAndArgs(
4029 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4030 auto thisOp = cast<NVVM::CpAsyncBulkTensorReduceOp>(op);
4031 llvm::LLVMContext &ctx = mt.getLLVMContext();
4032
4034
4035 // Arguments to the intrinsic:
4036 // shared_mem_ptr, tmaDesc, tensorDims
4037 // cache_hint(if applicable) and flag(boolean)
4038 args.push_back(mt.lookupValue(thisOp.getSrcMem()));
4039 args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
4040
4041 for (Value v : thisOp.getCoordinates())
4042 args.push_back(mt.lookupValue(v));
4043
4044 mlir::Value cacheHint = thisOp.getL2CacheHint();
4045 const bool hasCacheHint = static_cast<bool>(cacheHint);
4046 llvm::Value *i64ZeroValue =
4047 llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), 0);
4048 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64ZeroValue);
4049 args.push_back(builder.getInt1(hasCacheHint));
4050
4051 const llvm::Intrinsic::ID notIntrinsic = llvm::Intrinsic::not_intrinsic;
4052
4053 constexpr unsigned numRedKinds = 8; // ADD, MIN, MAX, INC, DEC, AND, OR, XOR
4054 constexpr unsigned numLayouts = 2; // TILE, IM2COL
4055 constexpr unsigned maxDim = 5; // 1D to 5D
4056 using row = std::array<llvm::Intrinsic::ID, maxDim + 1>;
4057 using layoutTable = std::array<row, numLayouts>;
4058 using fullTable = std::array<layoutTable, numRedKinds>;
4059 static constexpr fullTable IDTable{
4060 {// RedTy::ADD
4061 {{{{notIntrinsic,
4062 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_1d,
4063 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_2d,
4064 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_3d,
4065 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_4d,
4066 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_5d}},
4068 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_3d,
4069 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_4d,
4070 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_5d}}}},
4071 // RedTy::MIN
4072 {{{{notIntrinsic,
4073 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_1d,
4074 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_2d,
4075 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_3d,
4076 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_4d,
4077 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_5d}},
4079 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_3d,
4080 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_4d,
4081 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_5d}}}},
4082 // RedTy::MAX
4083 {{{{notIntrinsic,
4084 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_1d,
4085 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_2d,
4086 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_3d,
4087 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_4d,
4088 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_5d}},
4090 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_3d,
4091 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_4d,
4092 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_5d}}}},
4093 // RedTy::INC
4094 {{{{notIntrinsic,
4095 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_1d,
4096 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_2d,
4097 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_3d,
4098 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_4d,
4099 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_5d}},
4101 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_3d,
4102 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_4d,
4103 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_5d}}}},
4104 // RedTy::DEC
4105 {{{{notIntrinsic,
4106 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_1d,
4107 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_2d,
4108 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_3d,
4109 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_4d,
4110 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_5d}},
4112 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_3d,
4113 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_4d,
4114 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_5d}}}},
4115 // RedTy::AND
4116 {{{{notIntrinsic,
4117 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_1d,
4118 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_2d,
4119 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_3d,
4120 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_4d,
4121 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_5d}},
4123 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_3d,
4124 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_4d,
4125 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_5d}}}},
4126 // RedTy::OR
4127 {{{{notIntrinsic,
4128 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_1d,
4129 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_2d,
4130 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_3d,
4131 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_4d,
4132 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_5d}},
4134 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_3d,
4135 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_4d,
4136 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_5d}}}},
4137 // RedTy::XOR
4138 {{{{notIntrinsic,
4139 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_1d,
4140 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_2d,
4141 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_3d,
4142 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_4d,
4143 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_5d}},
4145 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_3d,
4146 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_4d,
4147 llvm::Intrinsic::
4148 nvvm_cp_async_bulk_tensor_reduce_xor_im2col_5d}}}}}};
4149
4150 static_assert(getMaxEnumValForTMAReduxKind() == std::size(IDTable) - 1,
4151 "TMAReduxKinds must match number of rows in IDTable");
4152
4153 size_t redKind = static_cast<size_t>(thisOp.getRedKind());
4154 size_t mode = static_cast<size_t>(thisOp.getMode());
4155 size_t dim = thisOp.getCoordinates().size();
4156
4157 assert(redKind < IDTable.size() &&
4158 "Invalid redKind for CpAsyncBulkTensorReduceOp");
4159 assert(mode < IDTable[redKind].size() &&
4160 "Invalid mode for CpAsyncBulkTensorReduceOp");
4161 assert(dim < IDTable[redKind][mode].size() &&
4162 "Invalid dim for CpAsyncBulkTensorReduceOp");
4163
4164 llvm::Intrinsic::ID intrinsicID = IDTable[redKind][mode][dim];
4165
4166 assert(intrinsicID != notIntrinsic &&
4167 "Invalid intrinsic for CpAsyncBulkTensorReduceOp.");
4168
4169 return {intrinsicID, std::move(args)};
4170}
4171
// `_none` expands to nothing so the macros below can splice an empty token
// into the intrinsic name (e.g. the rna variant has no `_relu` form).
#define _none

// Compose the nvvm_f2tf32_* intrinsic name from rounding mode, optional
// relu suffix, and optional satfinite suffix, dispatching at expansion time
// on the in-scope `hasRelu` / `hasSatFinite` flags.
#define CVT_F2TF32_ID_IMPL(rnd, relu, sf)                                      \
  hasRelu ? llvm::Intrinsic::nvvm_f2tf32_##rnd##relu##sf                       \
          : llvm::Intrinsic::nvvm_f2tf32_##rnd##sf

#define GET_CVT_F2TF32_ID(rnd, relu, sf)                                       \
  hasSatFinite ? CVT_F2TF32_ID_IMPL(rnd, relu, sf)                             \
               : CVT_F2TF32_ID_IMPL(rnd, relu, )
4181
4182llvm::Intrinsic::ID
4183ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
4184 NVVM::SaturationMode sat, bool hasRelu) {
4185 using RndMode = NVVM::FPRoundingMode;
4186 bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
4187 switch (rnd) {
4188 case RndMode::RN:
4189 return GET_CVT_F2TF32_ID(rn, _relu, _satfinite);
4190 case RndMode::RZ:
4191 return GET_CVT_F2TF32_ID(rz, _relu, _satfinite);
4192 case RndMode::RNA:
4193 return GET_CVT_F2TF32_ID(rna, _none, _satfinite);
4194 default:
4195 llvm_unreachable("Invalid RoundingMode for CvtFloatToTF32Op");
4196 }
4197}
4198
// Returns the {intrinsic, args} pair that packs the two f32 operands (a, b)
// into an e2m1x2 (f4x2) value, choosing the relu variant when op.getRelu()
// is set; both variants saturate (satfinite).
// NOTE(review): this listing dropped a few source lines here (the IDArgPair
// return type, the `mt` parameter and the `args` declaration) -- verify
// against the full source.
4200ConvertF32x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF4x2Op op,
4202 llvm::IRBuilderBase &builder) {
4204 args.push_back(mt.lookupValue(op.getA()));
4205 args.push_back(mt.lookupValue(op.getB()));
4206
4207 bool hasRelu = op.getRelu();
4208
4209 llvm::Intrinsic::ID intId =
4210 hasRelu ? llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_relu_satfinite
4211 : llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_satfinite;
4212
4213 return {intId, std::move(args)};
4214}
4215
// Builds the ff -> f6x2 intrinsic name from the destination-type token,
// picking the `_relu` flavor when `has_relu` is true (always satfinite).
4216#define GET_F32x2_TO_F6x2_ID(type, has_relu)                                   \
4217  has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite            \
4218           : llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite
4219
// Selects the f32x2 -> f6x2 conversion intrinsic from the destination
// element type (e2m3 or e3m2); any other type is invalid here.
// NOTE(review): the leading TypeSwitch line is missing from this dump.
4220llvm::Intrinsic::ID ConvertF32x2ToF6x2Op::getIntrinsicID(mlir::Type dstTy,
4221 bool hasRelu) {
4223 .Case([&](mlir::Float6E2M3FNType) {
4224 return GET_F32x2_TO_F6x2_ID(e2m3x2, hasRelu);
4225 })
4226 .Case([&](mlir::Float6E3M2FNType) {
4227 return GET_F32x2_TO_F6x2_ID(e3m2x2, hasRelu);
4228 })
4229 .Default([](mlir::Type) {
4230 llvm_unreachable("Invalid conversion in ConvertF32x2ToF6x2Op");
4231 return llvm::Intrinsic::not_intrinsic;
4232 });
4233}
4234
// Name builders for the ff -> f8x2 intrinsics: the "US" form is for the
// unsigned e8m0 type (rounding token + optional `_satfinite`), the "S" form
// for the signed f8 types (fixed rn rounding + optional `_relu`).
4235#define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf)                                 \
4236  has_satf ? llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd##_satfinite             \
4237           : llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd
4238
4239#define GET_F32x2_TO_F8X2_S_ID(type, has_relu)                                 \
4240  has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu                      \
4241           : llvm::Intrinsic::nvvm_ff_to_##type##_rn
4242
// Selects the f32x2 -> f8x2 conversion intrinsic from the destination type:
// e4m3/e5m2 use rn rounding with an optional relu; e8m0 supports only RZ or
// RP rounding with an optional satfinite.
// NOTE(review): the leading TypeSwitch line is missing from this dump.
4243llvm::Intrinsic::ID
4244ConvertF32x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy, NVVM::FPRoundingMode rnd,
4245 NVVM::SaturationMode sat, bool hasRelu) {
4246 bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
4247 bool hasRoundingModeRZ = (rnd == NVVM::FPRoundingMode::RZ);
4248 bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP);
4249
4251 .Case([&](mlir::Float8E4M3FNType) {
4252 return GET_F32x2_TO_F8X2_S_ID(e4m3x2, hasRelu);
4253 })
4254 .Case([&](mlir::Float8E5M2Type) {
4255 return GET_F32x2_TO_F8X2_S_ID(e5m2x2, hasRelu);
4256 })
4257 .Case([&](mlir::Float8E8M0FNUType) {
4258 if (hasRoundingModeRZ)
4259 return GET_F32x2_TO_F8X2_US_ID(rz, hasSatFinite);
4260 else if (hasRoundingModeRP)
4261 return GET_F32x2_TO_F8X2_US_ID(rp, hasSatFinite);
4262
4263 llvm_unreachable("Invalid conversion in ConvertF32x2ToF8x2Op");
4264 })
4265 .Default([](mlir::Type) {
4266 llvm_unreachable("Invalid conversion in ConvertF32x2ToF8x2Op");
4267 return llvm::Intrinsic::not_intrinsic;
4268 });
4269}
4270
// Builds the f16x2 -> f8x2 intrinsic name for the destination-type token,
// with an optional `_relu` suffix.
4271#define GET_F16x2_TO_F8X2_ID(type, has_relu)                                   \
4272  has_relu ? llvm::Intrinsic::nvvm_f16x2_to_##type##_rn_relu                   \
4273           : llvm::Intrinsic::nvvm_f16x2_to_##type##_rn
4274
// Selects the f16x2 -> f8x2 conversion intrinsic from the destination type
// (e4m3 or e5m2); any other type is invalid.
// NOTE(review): the leading TypeSwitch line is missing from this dump.
4275llvm::Intrinsic::ID ConvertF16x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy,
4276 bool hasRelu) {
4278 .Case([&](mlir::Float8E4M3FNType) {
4279 return GET_F16x2_TO_F8X2_ID(e4m3x2, hasRelu);
4280 })
4281 .Case([&](mlir::Float8E5M2Type) {
4282 return GET_F16x2_TO_F8X2_ID(e5m2x2, hasRelu);
4283 })
4284 .Default([](mlir::Type) {
4285 llvm_unreachable("Invalid conversion in ConvertF16x2ToF8x2Op");
4286 return llvm::Intrinsic::not_intrinsic;
4287 });
4288}
4289
// Builds the bf16x2 -> ue8m0x2 intrinsic name from a rounding-mode token,
// appending `_satfinite` when `has_satf` is true.
4290#define GET_BF16X2_TO_F8X2_ID(rnd, has_satf)                                   \
4291  has_satf ? llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd##_satfinite         \
4292           : llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd
4293
4294llvm::Intrinsic::ID
4295ConvertBF16x2ToF8x2Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
4296 NVVM::SaturationMode sat) {
4297 bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
4298 switch (rnd) {
4299 case NVVM::FPRoundingMode::RZ:
4300 return GET_BF16X2_TO_F8X2_ID(rz, hasSatFinite);
4301 case NVVM::FPRoundingMode::RP:
4302 return GET_BF16X2_TO_F8X2_ID(rp, hasSatFinite);
4303 default:
4304 llvm_unreachable("Invalid rounding mode for CvtBF16x2ToF8x2Op");
4305 }
4306}
4307
// Lowers ConvertF8x2ToF16x2Op: selects the e4m3x2/e5m2x2 -> f16x2 intrinsic
// (relu variant when requested) and passes the packed f8 pair as a single
// i16 value via bitcast.
// NOTE(review): the TypeSwitch header line is missing from this dump.
4308NVVM::IDArgPair ConvertF8x2ToF16x2Op::getIntrinsicIDAndArgs(
4309 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4310 auto curOp = cast<NVVM::ConvertF8x2ToF16x2Op>(op);
4311
4312 bool hasRelu = curOp.getRelu();
4313
4314 llvm::Intrinsic::ID intId =
4316 .Case([&](Float8E4M3FNType type) {
4317 return hasRelu ? llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn_relu
4318 : llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn;
4319 })
4320 .Case([&](Float8E5M2Type type) {
4321 return hasRelu ? llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn_relu
4322 : llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn;
4323 })
4324 .Default([](mlir::Type type) {
4325 llvm_unreachable("Invalid type for ConvertF8x2ToF16x2Op");
4326 return llvm::Intrinsic::not_intrinsic;
4327 });
4328
// The intrinsic takes the two packed f8 lanes as one i16.
4329 llvm::Value *packedI16 =
4330 builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
4331 llvm::Type::getInt16Ty(builder.getContext()));
4332
4333 return {intId, {packedI16}};
4334}
4335
4336NVVM::IDArgPair ConvertF8x2ToBF16x2Op::getIntrinsicIDAndArgs(
4337 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4338 auto curOp = cast<NVVM::ConvertF8x2ToBF16x2Op>(op);
4339
4340 llvm::Intrinsic::ID intId = llvm::Intrinsic::nvvm_ue8m0x2_to_bf16x2;
4341 llvm::Value *packedI16 =
4342 builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
4343 llvm::Type::getInt16Ty(builder.getContext()));
4344
4345 return {intId, {packedI16}};
4346}
4347
// Lowers ConvertF6x2ToF16x2Op: selects the e2m3x2/e3m2x2 -> f16x2 intrinsic
// (relu variant when requested); the packed f6 pair travels as an i16.
// NOTE(review): the TypeSwitch header line is missing from this dump.
4348NVVM::IDArgPair ConvertF6x2ToF16x2Op::getIntrinsicIDAndArgs(
4349 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4350 auto curOp = cast<NVVM::ConvertF6x2ToF16x2Op>(op);
4351
4352 bool hasRelu = curOp.getRelu();
4353
4354 llvm::Intrinsic::ID intId =
4356 .Case([&](Float6E2M3FNType type) {
4357 return hasRelu ? llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn_relu
4358 : llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn;
4359 })
4360 .Case([&](Float6E3M2FNType type) {
4361 return hasRelu ? llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn_relu
4362 : llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn;
4363 })
4364 .Default([](mlir::Type type) {
4365 llvm_unreachable("Invalid type for ConvertF6x2ToF16x2Op");
4366 return llvm::Intrinsic::not_intrinsic;
4367 });
4368
4369 llvm::Value *packedI16 =
4370 builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
4371 llvm::Type::getInt16Ty(builder.getContext()));
4372
4373 return {intId, {packedI16}};
4374}
4375
// Lowers ConvertF4x2ToF16x2Op: only e2m1x2 -> f16x2 exists (plus its relu
// variant). The packed f4 pair is zero-extended (not bitcast) to the i16
// the intrinsic expects.
// NOTE(review): the TypeSwitch header line is missing from this dump.
4376NVVM::IDArgPair ConvertF4x2ToF16x2Op::getIntrinsicIDAndArgs(
4377 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4378 auto curOp = cast<NVVM::ConvertF4x2ToF16x2Op>(op);
4379
4380 bool hasRelu = curOp.getRelu();
4381
4382 llvm::Intrinsic::ID intId =
4384 .Case([&](Float4E2M1FNType type) {
4385 return hasRelu ? llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn_relu
4386 : llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn;
4387 })
4388 .Default([](mlir::Type type) {
4389 llvm_unreachable("Invalid type for ConvertF4x2ToF16x2Op");
4390 return llvm::Intrinsic::not_intrinsic;
4391 });
4392
4393 llvm::Value *extendedI16 =
4394 builder.CreateZExt(mt.lookupValue(curOp.getSrc()),
4395 llvm::Type::getInt16Ty(builder.getContext()));
4396
4397 return {intId, {extendedI16}};
4398}
4399
// Lowers Tcgen05AllocOp: picks the alloc intrinsic from the address space of
// the destination pointer (shared vs. generic) and the CTA group (cg1/cg2),
// and appends {addr, nCols} as intrinsic operands.
// NOTE(review): the `mt` / `args` parameter lines are missing from this dump.
4400llvm::Intrinsic::ID
4401Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op,
4404 auto curOp = cast<NVVM::Tcgen05AllocOp>(op);
4405 unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
4406 .getAddressSpace();
4407 bool isShared = as == NVVMMemorySpace::Shared;
4408 bool is2CTAMode = curOp.getGroup() == CTAGroupKind::CTA_2;
4409
4410 llvm::Intrinsic::ID id;
4411 if (isShared) {
4412 id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg2
4413 : llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg1;
4414 } else {
4415 id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_cg2
4416 : llvm::Intrinsic::nvvm_tcgen05_alloc_cg1;
4417 }
4418
4419 // Fill the Intrinsic Args
4420 args.push_back(mt.lookupValue(curOp.getAddr()))
4421 args.push_back(mt.lookupValue(curOp.getNCols()));
4422
4423 return id;
4424}
4425
// Lowers Tcgen05DeallocOp: cg1 vs. cg2 dealloc intrinsic, with operands
// {taddr, nCols}.
// NOTE(review): the `mt` / `args` parameter lines are missing from this dump.
4426llvm::Intrinsic::ID Tcgen05DeallocOp::getIntrinsicIDAndArgs(
4429 auto curOp = cast<NVVM::Tcgen05DeallocOp>(op);
4430 auto id = (curOp.getGroup() == CTAGroupKind::CTA_1)
4431 ? llvm::Intrinsic::nvvm_tcgen05_dealloc_cg1
4432 : llvm::Intrinsic::nvvm_tcgen05_dealloc_cg2;
4433
4434 // Fill the Intrinsic Args
4435 args.push_back(mt.lookupValue(curOp.getTaddr()));
4436 args.push_back(mt.lookupValue(curOp.getNCols()));
4437
4438 return id;
4439}
4440
// Name builders for the tcgen05.commit intrinsics: `mc` selects the `_mc`
// (multicast) flavor, `is_shared` the `_shared` flavor, `cg` the CTA group.
4441#define TCGEN05_COMMIT_IMPL(cg, is_shared, mc)                                 \
4442  is_shared ? llvm::Intrinsic::nvvm_tcgen05_commit##mc##_shared##_##cg         \
4443            : llvm::Intrinsic::nvvm_tcgen05_commit##mc##_##cg
4444
4445#define GET_TCGEN05_COMMIT_ID(cta_group, is_shared, has_mc)                    \
4446  has_mc ? TCGEN05_COMMIT_IMPL(cta_group, is_shared, _mc)                      \
4447         : TCGEN05_COMMIT_IMPL(cta_group, is_shared, )
4448
// Lowers Tcgen05CommitOp: the intrinsic is chosen from the mbarrier pointer's
// address space (shared vs. generic), whether a multicast mask is present,
// and the CTA group; operands are {addr[, multicastMask]}.
// NOTE(review): the `mt` / `args` parameter lines are missing from this dump.
4449llvm::Intrinsic::ID
4450Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
4453 auto curOp = cast<NVVM::Tcgen05CommitOp>(op);
4454 unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
4455 .getAddressSpace();
4456 bool isShared = as == NVVMMemorySpace::Shared;
4457 bool hasMulticast = static_cast<bool>(curOp.getMulticastMask());
4458 bool is2CTAMode = curOp.getGroup() == CTAGroupKind::CTA_2;
4459
4460 llvm::Intrinsic::ID id =
4461 is2CTAMode ? GET_TCGEN05_COMMIT_ID(cg2, isShared, hasMulticast)
4462 : GET_TCGEN05_COMMIT_ID(cg1, isShared, hasMulticast);
4463
4464 // Fill the Intrinsic Args
4465 args.push_back(mt.lookupValue(curOp.getAddr()));
4466 if (hasMulticast)
4467 args.push_back(mt.lookupValue(curOp.getMulticastMask()));
4468
4469 return id;
4470}
4471
// Macros composing llvm.nvvm.tcgen05.cp.* intrinsic names from a shape /
// multicast token, an optional source-format token, and the CTA-group
// suffix. GET_TCGEN05_CP_ID wraps the format dispatch in an immediately
// invoked lambda so it can be used as a single expression.
4472#define TCGEN05_CP_IMPL(shape_mc, src_fmt, cg)                                 \
4473  llvm::Intrinsic::nvvm_tcgen05_cp##shape_mc##src_fmt##cg
4474
4475#define TCGEN05_CP_2CTA(shape_mc, src_fmt, is_2cta)                            \
4476  is_2cta ? TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg2)                           \
4477          : TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg1)
4478
4479#define GET_TCGEN05_CP_ID(shape_mc, src_fmt, is_2cta)                          \
4480  [&]() -> auto {                                                              \
4481    if ((src_fmt) == Tcgen05CpSrcFormat::B6x16_P32)                            \
4482      return TCGEN05_CP_2CTA(shape_mc, _b6x16_p32, is_2cta);                   \
4483    if ((src_fmt) == Tcgen05CpSrcFormat::B4x16_P64)                            \
4484      return TCGEN05_CP_2CTA(shape_mc, _b4x16_p64, is_2cta);                   \
4485    return TCGEN05_CP_2CTA(shape_mc, , is_2cta);                               \
4486  }()
4487
// Lowers ConvertF32x2ToF16x2Op: one intrinsic table per rounding mode
// (RN/RZ/RS), indexed by bit-0 = relu and bit-1 = satfinite. Operands are
// {srcHi, srcLo} plus the optional randomBits value (RS rounding).
// NOTE(review): the IDArgPair return type, `mt` parameter and `args`
// declaration lines are missing from this dump.
4489ConvertF32x2ToF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF16x2Op &op,
4491 llvm::IRBuilderBase &builder) {
4492 static constexpr llvm::Intrinsic::ID rndRNIds[] = {
4493 llvm::Intrinsic::nvvm_ff2f16x2_rn,
4494 llvm::Intrinsic::nvvm_ff2f16x2_rn_relu,
4495 llvm::Intrinsic::nvvm_ff2f16x2_rn_satfinite,
4496 llvm::Intrinsic::nvvm_ff2f16x2_rn_relu_satfinite,
4497 };
4498 static constexpr llvm::Intrinsic::ID rndRZIds[] = {
4499 llvm::Intrinsic::nvvm_ff2f16x2_rz,
4500 llvm::Intrinsic::nvvm_ff2f16x2_rz_relu,
4501 llvm::Intrinsic::nvvm_ff2f16x2_rz_satfinite,
4502 llvm::Intrinsic::nvvm_ff2f16x2_rz_relu_satfinite,
4503 };
4504 static constexpr llvm::Intrinsic::ID rndRSIds[] = {
4505 llvm::Intrinsic::nvvm_ff2f16x2_rs,
4506 llvm::Intrinsic::nvvm_ff2f16x2_rs_relu,
4507 llvm::Intrinsic::nvvm_ff2f16x2_rs_satfinite,
4508 llvm::Intrinsic::nvvm_ff2f16x2_rs_relu_satfinite,
4509 };
4510
4511 unsigned hasRelu = op.getRelu() ? 1 : 0;
4512 unsigned hasSatFinite =
4513 (op.getSat() == NVVM::SaturationMode::SATFINITE) ? 1 : 0;
4514 // idx: bit-0 - relu
4515 // bit-1 - satfinite
4516 unsigned idx = (hasSatFinite << 1) | hasRelu;
4517
4519 args.push_back(mt.lookupValue(op.getSrcHi()));
4520 args.push_back(mt.lookupValue(op.getSrcLo()));
4521 if (op.getRandomBits())
4522 args.push_back(mt.lookupValue(op.getRandomBits()));
4523
4524 switch (op.getRnd()) {
4525 case FPRoundingMode::RN:
4526 return {rndRNIds[idx], std::move(args)};
4527 case FPRoundingMode::RZ:
4528 return {rndRZIds[idx], std::move(args)};
4529 case FPRoundingMode::RS:
4530 return {rndRSIds[idx], std::move(args)};
4531 default:
4532 llvm_unreachable("Invalid rounding mode for ConvertF32x2ToF16x2Op");
4533 }
4534}
4535
// Lowers ConvertF32x2ToBF16x2Op; mirrors the f16x2 variant above but with
// the ff2bf16x2 tables. Index: bit-0 = relu, bit-1 = satfinite. Operands are
// {srcHi, srcLo} plus the optional randomBits value (RS rounding).
// NOTE(review): the IDArgPair return type, `mt` parameter and `args`
// declaration lines are missing from this dump.
4537ConvertF32x2ToBF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToBF16x2Op &op,
4539 llvm::IRBuilderBase &builder) {
4540 static constexpr llvm::Intrinsic::ID rndRNIds[] = {
4541 llvm::Intrinsic::nvvm_ff2bf16x2_rn,
4542 llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu,
4543 llvm::Intrinsic::nvvm_ff2bf16x2_rn_satfinite,
4544 llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu_satfinite,
4545 };
4546 static constexpr llvm::Intrinsic::ID rndRZIds[] = {
4547 llvm::Intrinsic::nvvm_ff2bf16x2_rz,
4548 llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu,
4549 llvm::Intrinsic::nvvm_ff2bf16x2_rz_satfinite,
4550 llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu_satfinite,
4551 };
4552 static constexpr llvm::Intrinsic::ID rndRSIds[] = {
4553 llvm::Intrinsic::nvvm_ff2bf16x2_rs,
4554 llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu,
4555 llvm::Intrinsic::nvvm_ff2bf16x2_rs_satfinite,
4556 llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu_satfinite,
4557 };
4558
4559 unsigned hasRelu = op.getRelu() ? 1 : 0;
4560 unsigned hasSatFinite =
4561 (op.getSat() == NVVM::SaturationMode::SATFINITE) ? 1 : 0;
4562 // idx: bit-0 - relu
4563 // bit-1 - satfinite
4564 unsigned idx = (hasSatFinite << 1) | hasRelu;
4565
4567 args.push_back(mt.lookupValue(op.getSrcHi()));
4568 args.push_back(mt.lookupValue(op.getSrcLo()));
4569 if (op.getRandomBits())
4570 args.push_back(mt.lookupValue(op.getRandomBits()));
4571
4572 switch (op.getRnd()) {
4573 case FPRoundingMode::RN:
4574 return {rndRNIds[idx], std::move(args)};
4575 case FPRoundingMode::RZ:
4576 return {rndRZIds[idx], std::move(args)};
4577 case FPRoundingMode::RS:
4578 return {rndRSIds[idx], std::move(args)};
4579 default:
4580 llvm_unreachable("Invalid rounding mode for ConvertF32x2ToBF16x2Op");
4581 }
4582}
4583
// Selects the f32x4 -> f8x4 (rs rounding, satfinite) intrinsic from the
// destination element type (e4m3 or e5m2), with an optional relu.
// NOTE(review): the leading TypeSwitch line is missing from this dump.
4584llvm::Intrinsic::ID ConvertF32x4ToF8x4Op::getIntrinsicID() {
4585 mlir::Type dstTy = getDstTy();
4586 bool hasRelu = getRelu();
4587
4589 .Case([&](mlir::Float8E4M3FNType) {
4590 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite
4591 : llvm::Intrinsic::nvvm_f32x4_to_e4m3x4_rs_satfinite;
4592 })
4593 .Case([&](mlir::Float8E5M2Type) {
4594 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e5m2x4_rs_relu_satfinite
4595 : llvm::Intrinsic::nvvm_f32x4_to_e5m2x4_rs_satfinite;
4596 })
4597 .Default([](mlir::Type) {
4598 llvm_unreachable("Invalid F8 type in ConvertF32x4ToF8x4Op");
4599 return llvm::Intrinsic::not_intrinsic;
4600 });
4601}
4602
// Selects the f32x4 -> f6x4 (rs rounding, satfinite) intrinsic from the
// destination element type (e2m3 or e3m2), with an optional relu.
// NOTE(review): the leading TypeSwitch line is missing from this dump.
4603llvm::Intrinsic::ID ConvertF32x4ToF6x4Op::getIntrinsicID() {
4604 mlir::Type dstTy = getDstTy();
4605 bool hasRelu = getRelu();
4606
4608 .Case([&](mlir::Float6E2M3FNType) {
4609 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e2m3x4_rs_relu_satfinite
4610 : llvm::Intrinsic::nvvm_f32x4_to_e2m3x4_rs_satfinite;
4611 })
4612 .Case([&](mlir::Float6E3M2FNType) {
4613 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e3m2x4_rs_relu_satfinite
4614 : llvm::Intrinsic::nvvm_f32x4_to_e3m2x4_rs_satfinite;
4615 })
4616 .Default([](mlir::Type) {
4617 llvm_unreachable("Invalid F6 type in ConvertF32x4ToF6x4Op");
4618 return llvm::Intrinsic::not_intrinsic;
4619 });
4620}
4621
// Selects the f32x4 -> f4x4 (rs rounding, satfinite) intrinsic; only the
// e2m1 destination type exists, with an optional relu.
// NOTE(review): the leading TypeSwitch line is missing from this dump.
4622llvm::Intrinsic::ID ConvertF32x4ToF4x4Op::getIntrinsicID() {
4623 mlir::Type dstTy = getDstTy();
4624 bool hasRelu = getRelu();
4625
4627 .Case([&](mlir::Float4E2M1FNType) {
4628 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e2m1x4_rs_relu_satfinite
4629 : llvm::Intrinsic::nvvm_f32x4_to_e2m1x4_rs_satfinite;
4630 })
4631 .Default([](mlir::Type) {
4632 llvm_unreachable("Invalid F4 type in ConvertF32x4ToF4x4Op");
4633 return llvm::Intrinsic::not_intrinsic;
4634 });
4635}
4636
4637llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(Operation &op) {
4638 auto curOp = cast<NVVM::Tcgen05CpOp>(op);
4639 bool is2CTA = curOp.getGroup() == CTAGroupKind::CTA_2;
4640 auto srcFmt = curOp.getSrcFormat();
4641 auto mc = curOp.getMulticast();
4642
4643 switch (curOp.getShape()) {
4644 case Tcgen05CpShape::SHAPE_128x256b:
4645 return GET_TCGEN05_CP_ID(_128x256b, srcFmt, is2CTA);
4646 case Tcgen05CpShape::SHAPE_128x128b:
4647 return GET_TCGEN05_CP_ID(_128x128b, srcFmt, is2CTA);
4648 case Tcgen05CpShape::SHAPE_4x256b:
4649 return GET_TCGEN05_CP_ID(_4x256b, srcFmt, is2CTA);
4650 case Tcgen05CpShape::SHAPE_32x128b:
4651 return GET_TCGEN05_CP_ID(_32x128b_warpx4, srcFmt, is2CTA);
4652 case Tcgen05CpShape::SHAPE_64x128b:
4653 return (mc == Tcgen05CpMulticast::WARPX2_01_23)
4654 ? GET_TCGEN05_CP_ID(_64x128b_warpx2_01_23, srcFmt, is2CTA)
4655 : GET_TCGEN05_CP_ID(_64x128b_warpx2_02_13, srcFmt, is2CTA);
4656 }
4657 llvm_unreachable("Invalid shape in tcgen05 cp Op");
4658}
4659
4660// Returns the valid vector length for a given shape and vector length, the
4661// function models the table mentioned in the tcgen05.{ld, st} Op description
4662static unsigned isValidVectorLength(NVVM::Tcgen05LdStShape shape,
4663 unsigned vecLen) {
4664 if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X128B)
4665 return vecLen >= 2;
4666 if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X256B)
4667 return vecLen >= 4;
4668 return true;
4669}
4670
4671LogicalResult Tcgen05LdOp::verify() {
4672 LogicalResult result = success();
4673 if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
4674 result = emitError("shape 16x32bx2 requires offset argument");
4675
4676 if (getShape() != NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && getOffset())
4677 result = emitError("offset argument is only supported for shape 16x32bx2");
4678
4679 auto resTy = getRes().getType();
4680 unsigned resLen = isa<VectorType>(resTy)
4681 ? llvm::cast<VectorType>(resTy).getNumElements()
4682 : 1;
4683 if (!isValidVectorLength(getShape(), resLen))
4684 result = emitError(llvm::formatv("invalid result type length {0} for shape "
4685 "{1} in tcgen05.ld Op",
4686 resLen, stringifyEnum(getShape())));
4687
4688 return result;
4689}
4690
4691LogicalResult Tcgen05StOp::verify() {
4692 LogicalResult result = success();
4693 if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
4694 result = emitError("shape 16x32bx2 requires offset argument");
4695
4696 auto valTy = getVal().getType();
4697 unsigned valLen = isa<VectorType>(valTy)
4698 ? llvm::cast<VectorType>(valTy).getNumElements()
4699 : 1;
4700 if (!isValidVectorLength(getShape(), valLen))
4701 result = emitError(llvm::formatv("invalid input length {0} for shape "
4702 "{1} in tcgen05.st Op",
4703 valLen, stringifyEnum(getShape())));
4704
4705 return result;
4706}
4707
4708/// Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might
4709/// have ConstantRangeAttr.
// If the op carries a "range" ConstantRangeAttr, that range is reported for
// both the unsigned and signed bounds; otherwise the result gets the widest
// possible range for its type.
// NOTE(review): the function signature lines are missing from this dump.
4712 SetIntRangeFn setResultRanges) {
4713 if (auto rangeAttr = op->getAttrOfType<LLVM::ConstantRangeAttr>("range")) {
4714 setResultRanges(result, {rangeAttr.getLower(), rangeAttr.getUpper(),
4715 rangeAttr.getLower(), rangeAttr.getUpper()});
4716 } else {
4717 setResultRanges(result, IntegerValueRange::getMaxRange(result).getValue());
4718 }
4719}
4720
4721/// Verify the range attribute satisfies LLVM ConstantRange constructor
4722/// requirements for NVVM SpecialRangeableRegisterOp.
// A missing attribute is trivially valid. LLVM's ConstantRange(Lower, Upper)
// only allows Lower == Upper when both are the min or the max value (full /
// empty set encodings), so that case is rejected with a diagnostic.
// NOTE(review): the line carrying the function name and first parameter is
// missing from this dump.
4723static LogicalResult
4725 std::optional<LLVM::ConstantRangeAttr> rangeAttr) {
4726 if (!rangeAttr)
4727 return success();
4728
4729 const llvm::APInt &lower = rangeAttr->getLower();
4730 const llvm::APInt &upper = rangeAttr->getUpper();
4731
4732 // Check LLVM ConstantRange constructor condition
4733 if (lower == upper && !lower.isMaxValue() && !lower.isMinValue()) {
4734 unsigned bitWidth = lower.getBitWidth();
4735 llvm::APInt minVal = llvm::APInt::getMinValue(bitWidth);
4736 llvm::APInt maxVal = llvm::APInt::getMaxValue(bitWidth);
4737 return op->emitOpError(
4738 "invalid range attribute: Lower == Upper, but they aren't min (")
4739 << llvm::toString(minVal, 10, false) << ") or max ("
4740 << llvm::toString(maxVal, 10, false)
4741 << ") value! This is an invalid constant range.";
4742 }
4743
4744 return success();
4745}
4746
4747static llvm::Value *getAsPackedI32(llvm::Value *arg,
4748 llvm::IRBuilderBase &builder) {
4749 return builder.CreateBitCast(arg,
4750 llvm::Type::getInt32Ty(builder.getContext()));
4751}
4752
// Lowers DotAccumulate4WayOp: operands are {a as i32, b as i32, c}; the
// idp4a variant is chosen by signedness, with bit-1 = a signed and
// bit-0 = b signed indexing the `ids` table.
// NOTE(review): the `args` declaration line is missing from this dump.
4753NVVM::IDArgPair DotAccumulate4WayOp::getIntrinsicIDAndArgs(
4754 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4755 auto curOp = cast<NVVM::DotAccumulate4WayOp>(op);
4756
4758 args.push_back(getAsPackedI32(mt.lookupValue(curOp.getA()), builder));
4759 args.push_back(getAsPackedI32(mt.lookupValue(curOp.getB()), builder));
4760 args.push_back(mt.lookupValue(curOp.getC()));
4761
4762 bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
4763 bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
4764 unsigned type = (isASigned << 1) | isBSigned;
4765 const llvm::Intrinsic::ID ids[] = {
4766 llvm::Intrinsic::nvvm_idp4a_u_u,
4767 llvm::Intrinsic::nvvm_idp4a_u_s,
4768 llvm::Intrinsic::nvvm_idp4a_s_u,
4769 llvm::Intrinsic::nvvm_idp4a_s_s,
4770 };
4771 return {ids[type], args};
4772}
4773
// Lowers DotAccumulate2WayOp: operands are {a as i32, b as i32, bHi flag, c};
// the idp2a variant is chosen by signedness, with bit-1 = a signed and
// bit-0 = b signed indexing the `ids` table.
// NOTE(review): the `args` declaration line is missing from this dump.
4774NVVM::IDArgPair DotAccumulate2WayOp::getIntrinsicIDAndArgs(
4775 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4776 auto curOp = cast<NVVM::DotAccumulate2WayOp>(op);
4777
4779 args.push_back(getAsPackedI32(mt.lookupValue(curOp.getA()), builder));
4780 args.push_back(getAsPackedI32(mt.lookupValue(curOp.getB()), builder));
4781 args.push_back(builder.getInt1(curOp.getBHi()));
4782 args.push_back(mt.lookupValue(curOp.getC()));
4783
4784 bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
4785 bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
4786 unsigned type = (isASigned << 1) | isBSigned;
4787 const llvm::Intrinsic::ID ids[] = {
4788 llvm::Intrinsic::nvvm_idp2a_u_u,
4789 llvm::Intrinsic::nvvm_idp2a_u_s,
4790 llvm::Intrinsic::nvvm_idp2a_s_u,
4791 llvm::Intrinsic::nvvm_idp2a_s_s,
4792 };
4793 return {ids[type], args};
4794}
4795
4796static llvm::Value *getParamCastedAddr(llvm::Value *addr,
4797 llvm::IRBuilderBase &builder) {
4798 return builder.CreateAddrSpaceCast(
4799 addr,
4800 llvm::PointerType::get(builder.getContext(),
4801 llvm::NVPTXAS::AddressSpace::ADDRESS_SPACE_PARAM));
4802}
4803
// Lowers PrefetchOp. Dispatch order: tensormap prefetch first; then the
// uniform L1 variant; then the L2 eviction-priority variants; finally a
// plain L1/L2 prefetch chosen by the pointer's address space. The address is
// addrspace-cast to the param space when requested.
// NOTE(review): several lines (return type, `mt` parameter, `args`
// declaration, parts of the IDArgPair returns) are missing from this dump.
4805PrefetchOp::getIntrinsicIDAndArgs(NVVM::PrefetchOp &op,
4807 llvm::IRBuilderBase &builder) {
4808 using MemSpace = NVVM::NVVMMemorySpace;
4809 using CacheLevel = NVVM::PrefetchCacheLevel;
4810
4811 std::optional<NVVM::PrefetchCacheLevel> cacheLevel = op.getCacheLevel();
4812 std::optional<NVVM::CacheEvictionPriority> evictPriority =
4813 op.getEvictPriority();
4814 unsigned addressSpace =
4815 llvm::cast<LLVM::LLVMPointerType>(op.getAddr().getType())
4816 .getAddressSpace();
4817
4819 llvm::Value *addr = mt.lookupValue(op.getAddr());
4820 args.push_back(op.getInParamSpace() ? getParamCastedAddr(addr, builder)
4821 : addr);
4822
4823 if (op.getTensormap())
4824 return {llvm::Intrinsic::nvvm_prefetch_tensormap, args};
4825
4826 assert(cacheLevel && "expected cache level for non-tensormap prefetch");
4827
4828 if (op.getUniform() && *cacheLevel == CacheLevel::L1)
4829 return {llvm::Intrinsic::nvvm_prefetchu_L1, args};
4830
4831 if (evictPriority && *cacheLevel == CacheLevel::L2) {
4832 switch (*evictPriority) {
4833 case NVVM::CacheEvictionPriority::EvictLast:
4834 return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_last, args};
4835 case NVVM::CacheEvictionPriority::EvictNormal:
4836 return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_normal, args};
4837 default:
4838 llvm_unreachable("Invalid cache eviction priority");
4839 }
4840 }
4841
4842 switch (static_cast<MemSpace>(addressSpace)) {
4843 case MemSpace::Generic:
4844 return *cacheLevel == CacheLevel::L1
4845 ? NVVM::IDArgPair({llvm::Intrinsic::nvvm_prefetch_L1, args})
4846 : NVVM::IDArgPair({llvm::Intrinsic::nvvm_prefetch_L2, args});
4847 case MemSpace::Global:
4848 return *cacheLevel == CacheLevel::L1
4850 {llvm::Intrinsic::nvvm_prefetch_global_L1, args})
4851 : NVVM::IDArgPair(
4852 {llvm::Intrinsic::nvvm_prefetch_global_L2, args});
4853 case MemSpace::Local:
4854 return *cacheLevel == CacheLevel::L1
4856 {llvm::Intrinsic::nvvm_prefetch_local_L1, args})
4857 : NVVM::IDArgPair(
4858 {llvm::Intrinsic::nvvm_prefetch_local_L2, args});
4859 default:
4860 llvm_unreachable("Invalid pointer address space")
4861 }
4862}
4863
4864bool NVVM::InlinePtxOp::getAsmValues(
4865 RewriterBase &rewriter,
4866 llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
4867 &asmValues) {
4868 for (auto arg : getReadWriteArgs())
4869 asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::ReadWrite});
4870 for (auto arg : getResults())
4871 asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::Write});
4872 for (auto arg : getReadOnlyArgs())
4873 asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::Read});
4874 if (getPredicate())
4875 asmValues.push_back({getPredicate(), mlir::NVVM::PTXRegisterMod::Read});
4876 return false; // No manual mapping needed
4877}
4878
// Lowers ClusterLaunchControlTryCancelOp: operands are {smemAddress,
// mbarrier}; the multicast flag selects the multicast intrinsic variant.
// NOTE(review): the `args` declaration line is missing from this dump.
4879NVVM::IDArgPair ClusterLaunchControlTryCancelOp::getIntrinsicIDAndArgs(
4880 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4881 auto curOp = cast<NVVM::ClusterLaunchControlTryCancelOp>(op);
4883 args.push_back(mt.lookupValue(curOp.getSmemAddress()));
4884 args.push_back(mt.lookupValue(curOp.getMbarrier()));
4885
4886 llvm::Intrinsic::ID intrinsicID =
4887 curOp.getMulticast()
4888 ? llvm::Intrinsic::
4889 nvvm_clusterlaunchcontrol_try_cancel_async_multicast_shared
4890 : llvm::Intrinsic::nvvm_clusterlaunchcontrol_try_cancel_async_shared;
4891
4892 return {intrinsicID, args};
4893}
4894
// Lowers ClusterLaunchControlQueryCancelOp: the single operand is the
// try-cancel response; the query type selects which query intrinsic to call.
// NOTE(review): the `args` declaration line is missing from this dump.
4895NVVM::IDArgPair ClusterLaunchControlQueryCancelOp::getIntrinsicIDAndArgs(
4896 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4897 auto curOp = cast<NVVM::ClusterLaunchControlQueryCancelOp>(op);
4899 args.push_back(mt.lookupValue(curOp.getTryCancelResponse()));
4900
4901 llvm::Intrinsic::ID intrinsicID;
4902
4903 switch (curOp.getQueryType()) {
4904 case NVVM::ClusterLaunchControlQueryType::IS_CANCELED:
4905 intrinsicID =
4906 llvm::Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_is_canceled;
4907 break;
4908 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_X:
4909 intrinsicID = llvm::Intrinsic::
4910 nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_x;
4911 break;
4912 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Y:
4913 intrinsicID = llvm::Intrinsic::
4914 nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_y;
4915 break;
4916 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Z:
4917 intrinsicID = llvm::Intrinsic::
4918 nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_z;
4919 break;
4920 }
4921 return {intrinsicID, args};
4922}
4923
// Lowers PermuteOp: the mode enum indexes a table of nvvm.prmt intrinsic
// variants. Only the first three modes (default, f4e, b4e) take the `hi`
// operand; all modes take `lo` and the selector.
// NOTE(review): the IDArgPair return type and the `args` declaration lines
// are missing from this dump.
4925PermuteOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
4926 llvm::IRBuilderBase &builder) {
4927 auto thisOp = cast<NVVM::PermuteOp>(op);
4928 NVVM::PermuteMode mode = thisOp.getMode();
4929
4930 static constexpr llvm::Intrinsic::ID IDs[] = {
4931 llvm::Intrinsic::nvvm_prmt, llvm::Intrinsic::nvvm_prmt_f4e,
4932 llvm::Intrinsic::nvvm_prmt_b4e, llvm::Intrinsic::nvvm_prmt_rc8,
4933 llvm::Intrinsic::nvvm_prmt_ecl, llvm::Intrinsic::nvvm_prmt_ecr,
4934 llvm::Intrinsic::nvvm_prmt_rc16};
4935
4936 unsigned modeIndex = static_cast<unsigned>(mode);
4938 args.push_back(mt.lookupValue(thisOp.getLo()));
4939
4940 // Only first 3 modes (Default, f4e, b4e) need the hi operand.
4941 if (modeIndex < 3)
4942 args.push_back(mt.lookupValue(thisOp.getHi()));
4943
4944 args.push_back(mt.lookupValue(thisOp.getSelector()));
4945
4946 return {IDs[modeIndex], args};
4947}
4948
// Lowers TensormapReplaceOp: operands are the tensormap address, then the
// optional ordinal (as i32), then either the optional new value or the
// enum-attribute payload lowered to an i32. The field enum indexes the IDs
// table of tensormap.replace.* intrinsics.
// NOTE(review): the `args` declaration and the attribute TypeSwitch header
// lines are missing from this dump.
4949mlir::NVVM::IDArgPair TensormapReplaceOp::getIntrinsicIDAndArgs(
4950 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4951 auto thisOp = cast<NVVM::TensormapReplaceOp>(op);
4952
4954 args.push_back(mt.lookupValue(thisOp.getAddr()));
4955 if (thisOp.getOrd())
4956 args.push_back(builder.getInt32(thisOp.getOrd().value()));
4957 if (thisOp.getNewValue())
4958 args.push_back(mt.lookupValue(thisOp.getNewValue()));
// Enum-valued fields arrive as attributes; their raw enum value is passed
// to the intrinsic as an i32.
4959 if (auto attr = thisOp.getNewValueAttr()) {
4960 auto val =
4962 .Case<TensormapElemtypeAttr, TensormapInterleaveLayoutAttr,
4963 TensormapSwizzleModeAttr, TensormapSwizzleAtomicityAttr,
4964 TensormapFillModeAttr>([](auto attr) {
4965 return static_cast<unsigned>(attr.getValue());
4966 })
4967 .Default([](auto attr) {
4968 llvm_unreachable("Invalid attribute type");
4969 return 0;
4970 });
4971 args.push_back(builder.getInt32(val));
4972 }
4973
4974 static constexpr llvm::Intrinsic::ID IDs[] = {
4975 llvm::Intrinsic::nvvm_tensormap_replace_global_address,
4976 llvm::Intrinsic::nvvm_tensormap_replace_rank,
4977 llvm::Intrinsic::nvvm_tensormap_replace_box_dim,
4978 llvm::Intrinsic::nvvm_tensormap_replace_global_dim,
4979 llvm::Intrinsic::nvvm_tensormap_replace_global_stride,
4980 llvm::Intrinsic::nvvm_tensormap_replace_element_stride,
4981 llvm::Intrinsic::nvvm_tensormap_replace_elemtype,
4982 llvm::Intrinsic::nvvm_tensormap_replace_interleave_layout,
4983 llvm::Intrinsic::nvvm_tensormap_replace_swizzle_mode,
4984 llvm::Intrinsic::nvvm_tensormap_replace_swizzle_atomicity,
4985 llvm::Intrinsic::nvvm_tensormap_replace_fill_mode,
4986 };
4987
4988 unsigned fieldIndex = static_cast<unsigned>(thisOp.getField());
4989
4990 return {IDs[fieldIndex], args};
4991}
4992
4993//===----------------------------------------------------------------------===//
4994// NVVM tcgen05.mma functions
4995//===----------------------------------------------------------------------===//
4996
4998Tcgen05MMAOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
4999 llvm::IRBuilderBase &builder) {
5000
5001 auto thisOp = cast<NVVM::Tcgen05MMAOp>(op);
5003
5004 args.push_back(mt.lookupValue(thisOp.getMatrixD()));
5005
5006 llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
5007 const bool isATensor = isa<llvm::PointerType>(A->getType());
5008 args.push_back(A);
5009
5010 args.push_back(mt.lookupValue(thisOp.getMatrixB()));
5011 args.push_back(mt.lookupValue(thisOp.getIdesc()));
5012 args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
5013
5014 using EnableAShiftArray = std::array<llvm::Intrinsic::ID, 2>;
5015 using CtaGroupArray = std::array<EnableAShiftArray, 2>;
5016 using IsATensorArray = std::array<CtaGroupArray, 2>;
5017 using HasScaleInputDArray = std::array<IsATensorArray, 2>;
5018 using HasDisableOutputLaneArray = std::array<HasScaleInputDArray, 2>;
5019
5020 // [hasDisableOutputLane][hasScaleInputD][isATensor][CtaGroup][EnableAShift]
5021 static constexpr HasDisableOutputLaneArray tcgen05MMAIDs = {
5022 { // without diable output lane
5023 {{// without scale input D
5024 {{
5025 // shared
5026 {{// cg1
5027 {llvm::Intrinsic::nvvm_tcgen05_mma_shared, notIntrinsic},
5028 // cg2
5029 {llvm::Intrinsic::nvvm_tcgen05_mma_shared, notIntrinsic}}},
5030 {{// tensor
5031 {
5032 // cg1
5033 llvm::Intrinsic::nvvm_tcgen05_mma_tensor,
5034 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_ashift,
5035 },
5036 {
5037 // cg2
5038 llvm::Intrinsic::nvvm_tcgen05_mma_tensor,
5039 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_ashift,
5040 }}},
5041 }},
5042 // with scale input D
5043 {{ // shared
5044 {{// cg1
5045 {llvm::Intrinsic::nvvm_tcgen05_mma_shared_scale_d, notIntrinsic},
5046 // cg2
5047 {llvm::Intrinsic::nvvm_tcgen05_mma_shared_scale_d, notIntrinsic}}},
5048 {{// tensor
5049 {
5050 // cg1
5051 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d,
5052 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_ashift,
5053 },
5054 {
5055 // cg2
5056 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d,
5057 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_ashift,
5058 }}}}}}},
5059 // with disable output lane
5060 {{ // without scale input D
5061 {{ // shared
5062 {{// cg1
5063 {llvm::Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg1,
5064 notIntrinsic},
5065 // cg2
5066 {llvm::Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg2,
5067 notIntrinsic}}},
5068 {{// cg1
5069 {
5070 llvm::Intrinsic::
5071 nvvm_tcgen05_mma_tensor_disable_output_lane_cg1,
5072 llvm::Intrinsic::
5073 nvvm_tcgen05_mma_tensor_disable_output_lane_cg1_ashift,
5074 },
5075 // cg2
5076 {
5077 llvm::Intrinsic::
5078 nvvm_tcgen05_mma_tensor_disable_output_lane_cg2,
5079 llvm::Intrinsic::
5080 nvvm_tcgen05_mma_tensor_disable_output_lane_cg2_ashift,
5081 }}}}},
5082 // with scale input D
5083 {{ // shared
5084 {{// cg1
5085 {llvm::Intrinsic::
5086 nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg1,
5087 notIntrinsic},
5088 // cg2
5089 {llvm::Intrinsic::
5090 nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg2,
5091 notIntrinsic}}},
5092 // tensor
5093 {{// cg1
5094 {llvm::Intrinsic::
5095 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1,
5096 llvm::Intrinsic::
5097 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1_ashift},
5098 // cg2
5099 {
5100 llvm::Intrinsic::
5101 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2,
5102 llvm::Intrinsic::
5103 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2_ashift,
5104 }}}}}}}}};
5105
5106 llvm::Value *ScaleInputD = mt.lookupValue(thisOp.getScaleInputD());
5107 bool hasScaleInputD = ScaleInputD != nullptr;
5108
5109 llvm::Value *DisableOutputLane =
5110 mt.lookupValue(thisOp.getDisableOutputLane());
5111 bool hasDisableOutputLane = DisableOutputLane != nullptr;
5112
5113 const unsigned ctaGroup =
5114 static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup()));
5115
5116 llvm::Intrinsic::ID ID =
5117 tcgen05MMAIDs[hasDisableOutputLane][hasScaleInputD][isATensor]
5118 [ctaGroup - 1][thisOp.getAShift()];
5119
5120 assert(ID != notIntrinsic && "Invalid intrinsic for Tcgen05MMAOp.");
5121
5122 if (hasScaleInputD)
5123 args.push_back(ScaleInputD);
5124
5125 if (hasDisableOutputLane)
5126 args.push_back(DisableOutputLane);
5127
5128 args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
5129
5130 if (!hasDisableOutputLane)
5131 args.push_back(builder.getInt32(ctaGroup));
5132
5133 args.push_back(
5134 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
5135
5136 return {ID, args};
5137}
5138
5139static LogicalResult
5140verifyTcgen05MMAOp(bool isATensor, mlir::Value disableOutputLane,
5141 NVVM::CTAGroupKind ctaGroup, bool hasAShift,
5142 NVVM::Tcgen05MMACollectorOp collectorOp, Location loc) {
5143
5144 if (disableOutputLane) {
5145 mlir::VectorType disableOutputLaneType =
5146 cast<mlir::VectorType>(disableOutputLane.getType());
5147 if ((ctaGroup == NVVM::CTAGroupKind::CTA_1 &&
5148 disableOutputLaneType.getNumElements() != 4) ||
5149 (ctaGroup == NVVM::CTAGroupKind::CTA_2 &&
5150 disableOutputLaneType.getNumElements() != 8))
5151 return emitError(loc) << "Disable Output Lane of length "
5152 << disableOutputLaneType.getNumElements()
5153 << " is incompatible with CtaGroupAttr";
5154 }
5155
5156 if (hasAShift && !isATensor)
5157 return emitError(
5158 loc, "A-shift can be applied only when matrix A is in tensor memory");
5159
5160 if (hasAShift == true && (collectorOp == Tcgen05MMACollectorOp::FILL ||
5161 collectorOp == Tcgen05MMACollectorOp::USE))
5162 return emitError(
5163 loc, "Cannot use collector buffer operation fill or use with ashift");
5164
5165 return success();
5166}
5167
5168LogicalResult Tcgen05MMAOp::verify() {
5169 return verifyTcgen05MMAOp(isa<LLVM::LLVMPointerType>(getMatrixA().getType()),
5170 getDisableOutputLane(), getCtaGroup(), getAShift(),
5171 getCollectorOp(), getLoc());
5172}
5173
//===----------------------------------------------------------------------===//
// NVVM tcgen05.mma.sp functions
//===----------------------------------------------------------------------===//

5178mlir::NVVM::IDArgPair Tcgen05MMASparseOp::getIntrinsicIDAndArgs(
5179 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
5180
5181 auto thisOp = cast<NVVM::Tcgen05MMASparseOp>(op);
5183
5184 args.push_back(mt.lookupValue(thisOp.getMatrixD()));
5185
5186 llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
5187 bool isATensor = isa<llvm::PointerType>(A->getType());
5188 args.push_back(A);
5189
5190 args.push_back(mt.lookupValue(thisOp.getMatrixB()));
5191 args.push_back(mt.lookupValue(thisOp.getIdesc()));
5192 args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
5193 args.push_back(mt.lookupValue(thisOp.getSparseMetadata()));
5194
5195 using EnableAShiftArray = std::array<llvm::Intrinsic::ID, 2>;
5196 using CtaGroupArray = std::array<EnableAShiftArray, 2>;
5197 using IsATensorArray = std::array<CtaGroupArray, 2>;
5198 using HasScaleInputDArray = std::array<IsATensorArray, 2>;
5199 using HasDisableOutputLaneArray = std::array<HasScaleInputDArray, 2>;
5200
5201 // [hasDisableOutputLane][hasScaleInputD][isATensor][CtaGroup][EnableAShift]
5202 static constexpr HasDisableOutputLaneArray tcgen05MMASparseIDs = {
5203 { // without diable output lane
5204 {{// without scale input D
5205 {{
5206 // shared
5207 {{// cg1
5208 {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared, notIntrinsic},
5209 // cg2
5210 {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared, notIntrinsic}}},
5211 {{// tensor
5212 {
5213 // cg1
5214 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor,
5215 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_ashift,
5216 },
5217 {
5218 // cg2
5219 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor,
5220 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_ashift,
5221 }}},
5222 }},
5223 // with scale input D
5224 {{ // shared
5225 {{// cg1
5226 {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d,
5227 notIntrinsic},
5228 // cg2
5229 {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d,
5230 notIntrinsic}}},
5231 {{// tensor
5232 {
5233 // cg1
5234 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d,
5235 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_ashift,
5236 },
5237 {
5238 // cg2
5239 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d,
5240 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_ashift,
5241 }}}}}}},
5242 // with disable output lane
5243 {{ // without scale input D
5244 {{ // shared
5245 {{// cg1
5246 {llvm::Intrinsic::
5247 nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg1,
5248 notIntrinsic},
5249 // cg2
5250 {llvm::Intrinsic::
5251 nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg2,
5252 notIntrinsic}}},
5253 {{// cg1
5254 {
5255 llvm::Intrinsic::
5256 nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1,
5257 llvm::Intrinsic::
5258 nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1_ashift,
5259 },
5260 // cg2
5261 {
5262 llvm::Intrinsic::
5263 nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2,
5264 llvm::Intrinsic::
5265 nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2_ashift,
5266 }}}}},
5267 // with scale input D
5268 {{ // shared
5269 {{// cg1
5270 {llvm::Intrinsic::
5271 nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg1,
5272 notIntrinsic},
5273 // cg2
5274 {llvm::Intrinsic::
5275 nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg2,
5276 notIntrinsic}}},
5277 // tensor
5278 {{// cg1
5279 {llvm::Intrinsic::
5280 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1,
5281 llvm::Intrinsic::
5282 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1_ashift},
5283 // cg2
5284 {
5285 llvm::Intrinsic::
5286 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2,
5287 llvm::Intrinsic::
5288 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2_ashift,
5289 }}}}}}}}};
5290
5291 llvm::Value *ScaleInputD = mt.lookupValue(thisOp.getScaleInputD());
5292 bool hasScaleInputD = ScaleInputD != nullptr;
5293
5294 llvm::Value *DisableOutputLane =
5295 mt.lookupValue(thisOp.getDisableOutputLane());
5296 bool hasDisableOutputLane = DisableOutputLane != nullptr;
5297
5298 unsigned ctaGroup =
5299 static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup()));
5300
5301 llvm::Intrinsic::ID ID =
5302 tcgen05MMASparseIDs[hasDisableOutputLane][hasScaleInputD][isATensor]
5303 [ctaGroup - 1][thisOp.getAShift()];
5304
5305 assert(ID != notIntrinsic && "Invalid intrinsic for Tcgen05MMASparseOp.");
5306
5307 if (hasScaleInputD)
5308 args.push_back(ScaleInputD);
5309
5310 if (hasDisableOutputLane)
5311 args.push_back(DisableOutputLane);
5312
5313 args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
5314
5315 if (!hasDisableOutputLane)
5316 args.push_back(builder.getInt32(ctaGroup));
5317
5318 args.push_back(
5319 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
5320
5321 return {ID, args};
5322}
5323
5324LogicalResult Tcgen05MMASparseOp::verify() {
5325 return verifyTcgen05MMAOp(isa<LLVM::LLVMPointerType>(getMatrixA().getType()),
5326 getDisableOutputLane(), getCtaGroup(), getAShift(),
5327 getCollectorOp(), getLoc());
5328}
5329
//===----------------------------------------------------------------------===//
// NVVM tcgen05.mma.block_scale functions
//===----------------------------------------------------------------------===//

5334mlir::NVVM::IDArgPair Tcgen05MMABlockScaleOp::getIntrinsicIDAndArgs(
5335 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
5336
5337 auto thisOp = cast<NVVM::Tcgen05MMABlockScaleOp>(op);
5339
5340 args.push_back(mt.lookupValue(thisOp.getMatrixD()));
5341
5342 llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
5343 bool isATensor = isa<llvm::PointerType>(A->getType());
5344 args.push_back(A);
5345
5346 args.push_back(mt.lookupValue(thisOp.getMatrixB()));
5347 args.push_back(mt.lookupValue(thisOp.getIdesc()));
5348 args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
5349 args.push_back(mt.lookupValue(thisOp.getScaleA()));
5350 args.push_back(mt.lookupValue(thisOp.getScaleB()));
5351 args.push_back(builder.getInt32(
5352 static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup()))));
5353 args.push_back(
5354 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
5355
5356 auto kind = thisOp.getKind();
5357 auto blockScale = thisOp.getBlockScale();
5358 llvm::Intrinsic::ID ID = [&]() {
5359 if (kind == NVVM::Tcgen05MMAKind::MXF8F6F4) {
5360 if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
5361 return isATensor ? llvm::Intrinsic::
5362 nvvm_tcgen05_mma_tensor_mxf8f6f4_block_scale
5363 : llvm::Intrinsic::
5364 nvvm_tcgen05_mma_shared_mxf8f6f4_block_scale;
5365 } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
5366 return isATensor
5367 ? llvm::Intrinsic::
5368 nvvm_tcgen05_mma_tensor_mxf8f6f4_block_scale_block32
5369 : llvm::Intrinsic::
5370 nvvm_tcgen05_mma_shared_mxf8f6f4_block_scale_block32;
5371 }
5372 } else if (kind == NVVM::Tcgen05MMAKind::MXF4) {
5373 if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
5374 return isATensor
5375 ? llvm::Intrinsic::nvvm_tcgen05_mma_tensor_mxf4_block_scale
5376 : llvm::Intrinsic::nvvm_tcgen05_mma_shared_mxf4_block_scale;
5377 } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
5378 return isATensor ? llvm::Intrinsic::
5379 nvvm_tcgen05_mma_tensor_mxf4_block_scale_block32
5380 : llvm::Intrinsic::
5381 nvvm_tcgen05_mma_shared_mxf4_block_scale_block32;
5382 }
5383 } else if (kind == NVVM::Tcgen05MMAKind::MXF4NVF4) {
5384 if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
5385 return isATensor
5386 ? llvm::Intrinsic::
5387 nvvm_tcgen05_mma_tensor_mxf4nvf4_block_scale_block32
5388 : llvm::Intrinsic::
5389 nvvm_tcgen05_mma_shared_mxf4nvf4_block_scale_block32;
5390
5391 } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16) {
5392 return isATensor
5393 ? llvm::Intrinsic::
5394 nvvm_tcgen05_mma_tensor_mxf4nvf4_block_scale_block16
5395 : llvm::Intrinsic::
5396 nvvm_tcgen05_mma_shared_mxf4nvf4_block_scale_block16;
5397 }
5398 }
5399 llvm_unreachable("Invalid tcgen05.mma.block_scale attributes");
5400 }();
5401
5402 return {ID, args};
5403}
5404
5405static LogicalResult verifyTcgen05MMABlockScaleOp(
5406 NVVM::Tcgen05MMACollectorOp collectorOp, NVVM::Tcgen05MMAKind kind,
5407 NVVM::Tcgen05MMABlockScale blockScale, Location loc) {
5408 if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT &&
5409 kind == NVVM::Tcgen05MMAKind::MXF4NVF4)
5410 return emitError(loc, "mxf4nvf4 requires block scale attribute");
5411
5412 if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16 &&
5413 kind != NVVM::Tcgen05MMAKind::MXF4NVF4)
5414 return emitError(loc,
5415 llvm::formatv("{} kind does not support block16 attribute",
5416 stringifyEnum(kind)));
5417
5418 return success();
5419}
5420
5421LogicalResult Tcgen05MMABlockScaleOp::verify() {
5422 return verifyTcgen05MMABlockScaleOp(getCollectorOp(), getKind(),
5423 getBlockScale(), getLoc());
5424}
5425
//===----------------------------------------------------------------------===//
// NVVM tcgen05.mma.sp.block_scale functions
//===----------------------------------------------------------------------===//

5430mlir::NVVM::IDArgPair Tcgen05MMASparseBlockScaleOp::getIntrinsicIDAndArgs(
5431 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
5432
5433 auto thisOp = cast<NVVM::Tcgen05MMASparseBlockScaleOp>(op);
5435
5436 args.push_back(mt.lookupValue(thisOp.getMatrixD()));
5437
5438 llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
5439 bool isATensor = isa<llvm::PointerType>(A->getType());
5440 args.push_back(A);
5441
5442 args.push_back(mt.lookupValue(thisOp.getMatrixB()));
5443 args.push_back(mt.lookupValue(thisOp.getIdesc()));
5444 args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
5445 args.push_back(mt.lookupValue(thisOp.getSparseMetadata()));
5446 args.push_back(mt.lookupValue(thisOp.getScaleA()));
5447 args.push_back(mt.lookupValue(thisOp.getScaleB()));
5448 args.push_back(builder.getInt32(
5449 static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup()))));
5450 args.push_back(
5451 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
5452
5453 auto kind = thisOp.getKind();
5454 auto blockScale = thisOp.getBlockScale();
5455 llvm::Intrinsic::ID ID = [&]() {
5456 if (kind == NVVM::Tcgen05MMAKind::MXF8F6F4) {
5457 if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
5458 return isATensor ? llvm::Intrinsic::
5459 nvvm_tcgen05_mma_sp_tensor_mxf8f6f4_block_scale
5460 : llvm::Intrinsic::
5461 nvvm_tcgen05_mma_sp_shared_mxf8f6f4_block_scale;
5462 } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
5463 return isATensor
5464 ? llvm::Intrinsic::
5465 nvvm_tcgen05_mma_sp_tensor_mxf8f6f4_block_scale_block32
5466 : llvm::Intrinsic::
5467 nvvm_tcgen05_mma_sp_shared_mxf8f6f4_block_scale_block32;
5468 }
5469 } else if (kind == NVVM::Tcgen05MMAKind::MXF4) {
5470 if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
5471 return isATensor ? llvm::Intrinsic::
5472 nvvm_tcgen05_mma_sp_tensor_mxf4_block_scale
5473 : llvm::Intrinsic::
5474 nvvm_tcgen05_mma_sp_shared_mxf4_block_scale;
5475 } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
5476 return isATensor
5477 ? llvm::Intrinsic::
5478 nvvm_tcgen05_mma_sp_tensor_mxf4_block_scale_block32
5479 : llvm::Intrinsic::
5480 nvvm_tcgen05_mma_sp_shared_mxf4_block_scale_block32;
5481 }
5482 } else if (kind == NVVM::Tcgen05MMAKind::MXF4NVF4) {
5483 if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
5484 return isATensor
5485 ? llvm::Intrinsic::
5486 nvvm_tcgen05_mma_sp_tensor_mxf4nvf4_block_scale_block32
5487 : llvm::Intrinsic::
5488 nvvm_tcgen05_mma_sp_shared_mxf4nvf4_block_scale_block32;
5489
5490 } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16) {
5491 return isATensor
5492 ? llvm::Intrinsic::
5493 nvvm_tcgen05_mma_sp_tensor_mxf4nvf4_block_scale_block16
5494 : llvm::Intrinsic::
5495 nvvm_tcgen05_mma_sp_shared_mxf4nvf4_block_scale_block16;
5496 }
5497 }
5498 llvm_unreachable("Invalid tcgen05.mma.sp.block_scale attributes");
5499 }();
5500
5501 return {ID, args};
5502}
5503
5504LogicalResult Tcgen05MMASparseBlockScaleOp::verify() {
5505 return verifyTcgen05MMABlockScaleOp(getCollectorOp(), getKind(),
5506 getBlockScale(), getLoc());
5507}
5508
//===----------------------------------------------------------------------===//
// NVVM tcgen05.mma.ws functions
//===----------------------------------------------------------------------===//

5513mlir::NVVM::IDArgPair Tcgen05MMAWsOp::getIntrinsicIDAndArgs(
5514 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
5515
5516 auto thisOp = cast<NVVM::Tcgen05MMAWsOp>(op);
5518
5519 args.push_back(mt.lookupValue(thisOp.getMatrixD()));
5520
5521 llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
5522 bool isATensor = isa<llvm::PointerType>(A->getType());
5523 args.push_back(A);
5524
5525 args.push_back(mt.lookupValue(thisOp.getMatrixB()));
5526 args.push_back(mt.lookupValue(thisOp.getIdesc()));
5527 args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
5528
5529 mlir::Value ZeroColMask = thisOp.getZeroColMask();
5530 llvm::Intrinsic::ID ID = notIntrinsic;
5531 if (ZeroColMask) {
5532 args.push_back(mt.lookupValue(ZeroColMask));
5533 ID = isATensor ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_tensor_zero_col_mask
5534 : llvm::Intrinsic::nvvm_tcgen05_mma_ws_shared_zero_col_mask;
5535 } else
5536 ID = isATensor ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_tensor
5537 : llvm::Intrinsic::nvvm_tcgen05_mma_ws_shared;
5538
5539 args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
5540 args.push_back(
5541 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorBBuffer())));
5542 args.push_back(
5543 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
5544
5545 return {ID, args};
5546}
5547
//===----------------------------------------------------------------------===//
// NVVM tcgen05.mma.ws.sp functions
//===----------------------------------------------------------------------===//

5552mlir::NVVM::IDArgPair Tcgen05MMAWsSparseOp::getIntrinsicIDAndArgs(
5553 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
5554
5555 auto thisOp = cast<NVVM::Tcgen05MMAWsSparseOp>(op);
5557
5558 args.push_back(mt.lookupValue(thisOp.getMatrixD()));
5559
5560 llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
5561 bool isATensor = isa<llvm::PointerType>(A->getType());
5562 args.push_back(A);
5563
5564 args.push_back(mt.lookupValue(thisOp.getMatrixB()));
5565 args.push_back(mt.lookupValue(thisOp.getIdesc()));
5566 args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
5567 args.push_back(mt.lookupValue(thisOp.getSparseMetadata()));
5568
5569 mlir::Value ZeroColMask = thisOp.getZeroColMask();
5570 llvm::Intrinsic::ID ID = notIntrinsic;
5571 if (ZeroColMask) {
5572 args.push_back(mt.lookupValue(ZeroColMask));
5573 ID = isATensor
5574 ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_tensor_zero_col_mask
5575 : llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_shared_zero_col_mask;
5576 } else
5577 ID = isATensor ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_tensor
5578 : llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_shared;
5579
5580 args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
5581 args.push_back(
5582 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorBBuffer())));
5583 args.push_back(
5584 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
5585
5586 return {ID, args};
5587}
5588
//===----------------------------------------------------------------------===//
// NVVM tcgen05.ld.red functions
//===----------------------------------------------------------------------===//

// Expands to the intrinsic ID for a tcgen05.ld.red variant given its shape,
// vector length suffix and element type (i32 or f32).
#define TCGEN05LDRED(SHAPE, NUM, TYPE)                                         \
  llvm::Intrinsic::nvvm_tcgen05_ld_red_##SHAPE##_##NUM##_##TYPE
5595
5596mlir::NVVM::IDArgPair NVVM::Tcgen05LdRedOp::getIntrinsicIDAndArgs(
5597 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
5598 auto thisOp = cast<NVVM::Tcgen05LdRedOp>(op);
5600
5601 mlir::VectorType VecResTy =
5602 cast<mlir::VectorType>(thisOp.getData().getType());
5603 unsigned Num = VecResTy.getNumElements();
5604 bool IsFloat = thisOp.getRedVal().getType().isF32();
5605
5606 llvm::Intrinsic::ID Shape32x32b[][2] = {
5608 {TCGEN05LDRED(32x32b, x2, i32), TCGEN05LDRED(32x32b, x2, f32)},
5609 {TCGEN05LDRED(32x32b, x4, i32), TCGEN05LDRED(32x32b, x4, f32)},
5610 {TCGEN05LDRED(32x32b, x8, i32), TCGEN05LDRED(32x32b, x8, f32)},
5611 {TCGEN05LDRED(32x32b, x16, i32), TCGEN05LDRED(32x32b, x16, f32)},
5612 {TCGEN05LDRED(32x32b, x32, i32), TCGEN05LDRED(32x32b, x32, f32)},
5613 {TCGEN05LDRED(32x32b, x64, i32), TCGEN05LDRED(32x32b, x64, f32)},
5614 {TCGEN05LDRED(32x32b, x128, i32), TCGEN05LDRED(32x32b, x128, f32)},
5615 };
5616
5617 llvm::Intrinsic::ID Shape16x32bx2[][2] = {
5619 {TCGEN05LDRED(16x32bx2, x2, i32), TCGEN05LDRED(16x32bx2, x2, f32)},
5620 {TCGEN05LDRED(16x32bx2, x4, i32), TCGEN05LDRED(16x32bx2, x4, f32)},
5621 {TCGEN05LDRED(16x32bx2, x8, i32), TCGEN05LDRED(16x32bx2, x8, f32)},
5622 {TCGEN05LDRED(16x32bx2, x16, i32), TCGEN05LDRED(16x32bx2, x16, f32)},
5623 {TCGEN05LDRED(16x32bx2, x32, i32), TCGEN05LDRED(16x32bx2, x32, f32)},
5624 {TCGEN05LDRED(16x32bx2, x64, i32), TCGEN05LDRED(16x32bx2, x64, f32)},
5625 {TCGEN05LDRED(16x32bx2, x128, i32), TCGEN05LDRED(16x32bx2, x128, f32)},
5626 };
5627
5628 NVVM::Tcgen05LdStShape shape = thisOp.getShape();
5629 unsigned ID = [&]() {
5630 // `num` contains the length of vector and log2 of `num` returns the index
5631 // into the shape array
5632 unsigned idx = std::log2(Num);
5633 switch (shape) {
5634 case NVVM::Tcgen05LdStShape::SHAPE_32X32B:
5635 return Shape32x32b[idx][IsFloat];
5636 case NVVM::Tcgen05LdStShape::SHAPE_16X32BX2:
5637 return Shape16x32bx2[idx][IsFloat];
5638 default:
5639 llvm_unreachable("unhandled tcgen05.ld lowering");
5640 }
5641 }();
5642
5643 args.push_back(mt.lookupValue(thisOp.getAddr()));
5644
5645 if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2)
5646 args.push_back(mt.lookupValue(thisOp.getOffset()));
5647
5648 args.push_back(
5649 builder.getInt32(thisOp.getOp() == NVVM::ReductionKind::MIN ? 0 : 1));
5650
5651 if (IsFloat) {
5652 args.push_back(builder.getInt1(static_cast<unsigned>(thisOp.getAbs())));
5653 args.push_back(builder.getInt1(static_cast<unsigned>(thisOp.getNan())));
5654 }
5655 return {ID, args};
5656}
5657
5658LogicalResult Tcgen05LdRedOp::verify() {
5659 VectorType data = cast<VectorType>(getData().getType());
5660 Type redVal = getRedVal().getType();
5661
5662 if (data.getElementType() != redVal)
5663 return emitError(
5664 "type of reduction value and element type of vector data should match");
5665
5666 if (getOp() != NVVM::ReductionKind::MIN &&
5667 getOp() != NVVM::ReductionKind::MAX)
5668 return emitError("only min and max reduction kinds are supported");
5669
5670 if (redVal.isInteger() && (getAbs() || getNan())) {
5671 return emitError("abs or nan is only applicable for f32 type");
5672 }
5673 return success();
5674}
5675
//===----------------------------------------------------------------------===//
// NVVMDialect initialization, type parsing, and registration.
//===----------------------------------------------------------------------===//

5680// TODO: This should be the llvm.nvvm dialect once this is supported.
5681void NVVMDialect::initialize() {
5682 addOperations<
5683#define GET_OP_LIST
5684#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
5685 >();
5686 addAttributes<
5687#define GET_ATTRDEF_LIST
5688#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
5689 >();
5690
5691 // Support unknown operations because not all NVVM operations are
5692 // registered.
5693 allowUnknownOperations();
5694 declarePromisedInterface<ConvertToLLVMPatternInterface, NVVMDialect>();
5695 declarePromisedInterface<gpu::TargetAttrInterface, NVVMTargetAttr>();
5696}
5697
5698LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
5699 NamedAttribute attr) {
5700 StringAttr attrName = attr.getName();
5701 // Kernel function attribute should be attached to functions.
5702 if (attrName == NVVMDialect::getKernelFuncAttrName()) {
5703 if (!isa<LLVM::LLVMFuncOp>(op)) {
5704 return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName()
5705 << "' attribute attached to unexpected op";
5706 }
5707 }
5708 // If maxntid / reqntid / cluster_dim exist, it must be an array with max 3
5709 // dim
5710 if (attrName == NVVMDialect::getMaxntidAttrName() ||
5711 attrName == NVVMDialect::getReqntidAttrName() ||
5712 attrName == NVVMDialect::getClusterDimAttrName()) {
5713 auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue());
5714 if (!values || values.empty() || values.size() > 3) {
5715 return op->emitError()
5716 << "'" << attrName
5717 << "' attribute must be integer array with maximum 3 index";
5718 }
5719 }
5720 // If minctasm / maxnreg / cluster_max_blocks exist, it must be an integer
5721 // attribute
5722 if (attrName == NVVMDialect::getMinctasmAttrName() ||
5723 attrName == NVVMDialect::getMaxnregAttrName() ||
5724 attrName == NVVMDialect::getClusterMaxBlocksAttrName()) {
5725 if (!llvm::dyn_cast<IntegerAttr>(attr.getValue())) {
5726 return op->emitError()
5727 << "'" << attrName << "' attribute must be integer constant";
5728 }
5729 }
5730 // blocksareclusters must be used along with reqntid and cluster_dim
5731 if (attrName == NVVMDialect::getBlocksAreClustersAttrName()) {
5732 if (!op->hasAttr(NVVMDialect::getReqntidAttrName()) ||
5733 !op->hasAttr(NVVMDialect::getClusterDimAttrName())) {
5734 return op->emitError()
5735 << "'" << attrName << "' attribute must be used along with "
5736 << "'" << NVVMDialect::getReqntidAttrName() << "' and "
5737 << "'" << NVVMDialect::getClusterDimAttrName() << "'";
5738 }
5739 }
5740
5741 return success();
5742}
5743
5744LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
5745 unsigned regionIndex,
5746 unsigned argIndex,
5747 NamedAttribute argAttr) {
5748 auto funcOp = dyn_cast<FunctionOpInterface>(op);
5749 if (!funcOp)
5750 return success();
5751
5752 bool isKernel = op->hasAttr(NVVMDialect::getKernelFuncAttrName());
5753 StringAttr attrName = argAttr.getName();
5754 if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
5755 if (!isKernel) {
5756 return op->emitError()
5757 << "'" << attrName
5758 << "' attribute must be present only on kernel arguments";
5759 }
5760 if (!isa<UnitAttr>(argAttr.getValue()))
5761 return op->emitError() << "'" << attrName << "' must be a unit attribute";
5762 if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName())) {
5763 return op->emitError()
5764 << "'" << attrName
5765 << "' attribute requires the argument to also have attribute '"
5766 << LLVM::LLVMDialect::getByValAttrName() << "'";
5767 }
5768 }
5769
5770 return success();
5771}
5772
//===----------------------------------------------------------------------===//
// NVVM Address Space Attr
//===----------------------------------------------------------------------===//

5777unsigned NVVMMemorySpaceAttr::getAddressSpace() const {
5778 return static_cast<unsigned>(getValue());
5779}
5780
5781bool NVVMMemorySpaceAttr::isValidLoad(
5782 Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
5783 const ::mlir::DataLayout *dataLayout,
5785 return LLVM::detail::isValidLoadStoreImpl(type, ordering, alignment,
5786 dataLayout, emitError);
5787}
5788
5789bool NVVMMemorySpaceAttr::isValidStore(
5790 Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
5791 const ::mlir::DataLayout *dataLayout,
5793 return LLVM::detail::isValidLoadStoreImpl(type, ordering, alignment,
5794 dataLayout, emitError);
5795}
5796
5797bool NVVMMemorySpaceAttr::isValidAtomicOp(
5798 ptr::AtomicBinOp op, Type type, ptr::AtomicOrdering ordering,
5799 std::optional<int64_t> alignment, const ::mlir::DataLayout *dataLayout,
5801 // TODO: update this method once `ptr.atomic_rmw` is implemented.
5802 assert(false && "unimplemented, see TODO in the source.");
5803 return false;
5804}
5805
5806bool NVVMMemorySpaceAttr::isValidAtomicXchg(
5807 Type type, ptr::AtomicOrdering successOrdering,
5808 ptr::AtomicOrdering failureOrdering, std::optional<int64_t> alignment,
5809 const ::mlir::DataLayout *dataLayout,
5811 // TODO: update this method once `ptr.atomic_cmpxchg` is implemented.
5812 assert(false && "unimplemented, see TODO in the source.");
5813 return false;
5814}
5815
5816bool NVVMMemorySpaceAttr::isValidAddrSpaceCast(
5817 Type tgt, Type src, function_ref<InFlightDiagnostic()> emitError) const {
5818 // TODO: update this method once the `ptr.addrspace_cast` op is added to the
5819 // dialect.
5820 assert(false && "unimplemented, see TODO in the source.");
5821 return false;
5822}
5823
5824bool NVVMMemorySpaceAttr::isValidPtrIntCast(
5825 Type intLikeTy, Type ptrLikeTy,
5827 // TODO: update this method once the int-cast ops are added to the `ptr`
5828 // dialect.
5829 assert(false && "unimplemented, see TODO in the source.");
5830 return false;
5831}
5832
//===----------------------------------------------------------------------===//
// NVVM target attribute.
//===----------------------------------------------------------------------===//
5836LogicalResult
5837NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
5838 int optLevel, StringRef triple, StringRef chip,
5839 StringRef features, DictionaryAttr flags,
5840 ArrayAttr files, bool verifyTarget) {
5841 if (optLevel < 0 || optLevel > 3) {
5842 emitError() << "The optimization level must be a number between 0 and 3.";
5843 return failure();
5844 }
5845 if (triple.empty()) {
5846 emitError() << "The target triple cannot be empty.";
5847 return failure();
5848 }
5849 if (chip.empty()) {
5850 emitError() << "The target chip cannot be empty.";
5851 return failure();
5852 }
5853 if (files && !llvm::all_of(files, [](::mlir::Attribute attr) {
5854 return mlir::isa_and_nonnull<StringAttr>(attr);
5855 })) {
5856 emitError() << "All the elements in the `link` array must be strings.";
5857 return failure();
5858 }
5859 return success();
5860}
5861
5862LogicalResult NVVMTargetAttr::verifyTarget(Operation *gpuModule) {
5863 if (!getVerifyTarget())
5864 return success();
5865
5866 auto gpuModuleOp = llvm::dyn_cast<gpu::GPUModuleOp>(gpuModule);
5867 if (!gpuModuleOp) {
5868 return emitError(gpuModule->getLoc(),
5869 "NVVM target attribute must be attached to a GPU module");
5870 }
5871
5872 const NVVMCheckSMVersion targetSMVersion =
5874 if (!targetSMVersion.isMinimumSMVersion()) {
5875 return emitError(gpuModule->getLoc(),
5876 "Minimum NVVM target SM version is sm_20");
5877 }
5878
5879 if (gpuModuleOp
5880 ->walk([&](Operation *op) {
5881 if (auto reqOp = llvm::dyn_cast<NVVM::RequiresSMInterface>(op)) {
5882 const NVVMCheckSMVersion requirement =
5883 reqOp.getRequiredMinSMVersion();
5884 if (!requirement.isCompatibleWith(targetSMVersion)) {
5885 op->emitOpError() << "is not supported on " << getChip();
5886 return WalkResult::interrupt();
5887 }
5888 }
5889 return WalkResult::advance();
5890 })
5891 .wasInterrupted())
5892 return failure();
5893
5894 return success();
5895}
5896
5897#define GET_OP_CLASSES
5898#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
5899
5900#define GET_ATTRDEF_CLASSES
5901#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
for(Operation *op :ops)
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
ArrayAttr()
b getContext())
#define GET_TCGEN05_CP_ID(shape_mc, src_fmt, is_2cta)
static LogicalResult verifyTMALoadParams(size_t tensorDims, size_t numIm2colOff, TMALoadMode mode, Location loc)
static LogicalResult verifyTcgen05MMAOp(bool isATensor, mlir::Value disableOutputLane, NVVM::CTAGroupKind ctaGroup, bool hasAShift, NVVM::Tcgen05MMACollectorOp collectorOp, Location loc)
#define _none
static bool isPtrInAddrSpace(mlir::Value ptr, NVVMMemorySpace targetAS)
static bool isPtrInSharedCTASpace(mlir::Value ptr)
static LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA)
static llvm::nvvm::CTAGroupKind getNVVMCtaGroupKind(NVVM::CTAGroupKind ctaGroup)
static void addInferredMultiplicandTypes(MLIRContext *ctx, OperationState &result, ValueRange operandA, ValueRange operandB, std::optional< std::array< MMATypes, 2 > > multiplicandPtxTypes)
#define GET_CVT_F2TF32_ID(rnd, relu, sf)
static void addBlockScaleProperties(OpBuilder &builder, OperationState &result, ArrayRef< int64_t > shape, ScaleVecSize scaleVecSize, BlockScaleFormat blockScaleFormat, MMABlockScaleKind kind)
#define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf)
static llvm::Value * getParamCastedAddr(llvm::Value *addr, llvm::IRBuilderBase &builder)
static LogicalResult verifyAddSubFOp(OpType op)
static LogicalResult verifyTcgen05MMABlockScaleOp(NVVM::Tcgen05MMACollectorOp collectorOp, NVVM::Tcgen05MMAKind kind, NVVM::Tcgen05MMABlockScale blockScale, Location loc)
static llvm::Value * packValInto64Bits(llvm::IRBuilderBase &builder, llvm::Value *result, llvm::Value *field, unsigned sizeInBits, unsigned start)
Packs the given field into the result.
static void printOperandList(OpAsmPrinter &p, StringRef name, ArrayRef< Value > operands)
#define GET_F32x2_TO_F6x2_ID(type, has_relu)
static llvm::Value * getAsPackedI32(llvm::Value *arg, llvm::IRBuilderBase &builder)
#define GET_F16x2_TO_F8X2_ID(type, has_relu)
static LogicalResult verifyMBarrierArriveLikeOp(Operation *op, Value addr, NVVM::MemScopeKind scope, Value retVal=nullptr)
static llvm::Value * castPtrToAddrSpace(llvm::IRBuilderBase &builder, llvm::Value *ptr, NVVMMemorySpace targetAS)
static LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD, NVVM::WGMMATypes typeA, NVVM::WGMMATypes typeB)
#define GET_BF16X2_TO_F8X2_ID(rnd, has_satf)
static void inferAndSetMultiplicandTypes(MLIRContext *ctx, NamedAttrList &attrs, const SmallVectorImpl< Type > &operandTypes)
static LogicalResult parseMmaOperand(OpAsmParser &parser, StringRef operandName, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &regs)
static std::pair< mlir::Type, unsigned > inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n, int k, MLIRContext *context)
static bool isInt8PtxType(MMATypes type)
#define TCGEN05LDRED(SHAPE, NUM, TYPE)
static bool isInt4PtxType(MMATypes type)
static bool isIntegerPtxType(MMATypes type)
#define GET_F32x2_TO_F8X2_S_ID(type, has_relu)
static MMATypes inferPtxTypeFromResult(OpTy op)
static LogicalResult verifyConstantRangeAttr(Operation *op, std::optional< LLVM::ConstantRangeAttr > rangeAttr)
Verify the range attribute satisfies LLVM ConstantRange constructor requirements for NVVM SpecialRang...
static LogicalResult parseMmaTypeSignature(OpAsmParser &parser, SmallVectorImpl< Type > &operandTypes)
static FailureOr< int > getAllowedSizeK(NVVM::WGMMATypes typeA)
static bool isPtrInSharedClusterSpace(mlir::Value ptr)
#define GET_CP_ASYNC_ID(mod, size, has_cpsize)
static unsigned isValidVectorLength(NVVM::Tcgen05LdStShape shape, unsigned vecLen)
#define GET_TCGEN05_COMMIT_ID(cta_group, is_shared, has_mc)
static LogicalResult verifyConvertF32x2ToFP16x2Op(Twine dstType, FPRoundingMode rnd, bool hasRandomBits, Operation *op)
static void nvvmInferResultRanges(Operation *op, Value result, ArrayRef<::mlir::ConstantIntRanges > argRanges, SetIntRangeFn setResultRanges)
Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might have ConstantRangeAttr.
static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims, bool isIm2Col, size_t numIm2ColOffsets, Location loc)
static bool isPtrInGenericSpace(mlir::Value ptr)
static void processOperandFragments(Op &op, std::array< MMAOperandFragment, 3 > &frags, SmallVectorImpl< Type > &regTypes, SmallVectorImpl< StringRef > &ignoreAttrNames)
static constexpr unsigned notIntrinsic
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
@ OptionalSquare
Square brackets supporting zero or more ops, or nothing.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
void printArrowTypeList(TypeRange &&types)
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerType getI16Type()
Definition Builders.cpp:65
UnitAttr getUnitAttr()
Definition Builders.cpp:102
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:167
IntegerType getI32Type()
Definition Builders.cpp:67
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
MLIRContext * getContext() const
Definition Builders.h:56
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition Builders.h:100
This class represents a diagnostic that is inflight and set to be reported.
static IntegerValueRange getMaxRange(Value value)
Create a maximal range ([0, uint_max(t)] / [int_min(t), int_max(t)]) range that is used to mark the v...
Implementation class for module translation.
llvm::Value * lookupValue(Value value) const
Finds an LLVM IR value corresponding to the given MLIR value.
void mapValue(Value mlir, llvm::Value *llvm)
Stores the mapping between an MLIR value and its LLVM IR counterpart.
llvm::LLVMContext & getLLVMContext() const
Returns the LLVM context in which the IR is being constructed.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
std::optional< NamedAttribute > getNamed(StringRef name) const
Return the specified named attribute if present, std::nullopt otherwise.
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
Attribute set(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
This class helps build Operations.
Definition Builders.h:209
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
AttrClass getAttrOfType(StringAttr name)
Definition Operation.h:579
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Definition Operation.h:589
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:244
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isF64() const
Definition Types.cpp:41
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isF32() const
Definition Types.cpp:40
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
bool isF16() const
Definition Types.cpp:38
bool isBF16() const
Definition Types.cpp:37
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
bool isValidLoadStoreImpl(Type type, ptr::AtomicOrdering ordering, std::optional< int64_t > alignment, const ::mlir::DataLayout *dataLayout, function_ref< InFlightDiagnostic()> emitError)
Checks whether the given type is an LLVM type that can be loaded or stored.
Definition LLVMAttrs.cpp:60
SmallVector< int64_t, 4 > getCoordinates(ArrayRef< int64_t > basis, unsigned linearIndex)
@ Write
Write register with '=' modifier.
@ ReadWrite
ReadWrite register with '+' modifier.
@ Read
Read register with no modifier.
std::pair< mlir::Type, unsigned > inferMMAType(mlir::NVVM::MMATypes type, mlir::NVVM::MMAFrag frag, int nRow, int nCol, mlir::MLIRContext *context)
Return the element type and number of elements associated with a wmma matrix of given chracteristics.
std::pair< llvm::Intrinsic::ID, llvm::SmallVector< llvm::Value * > > IDArgPair
A pair type of LLVM's Intrinsic ID and args (which are llvm values).
Definition NVVMDialect.h:53
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition Visitors.h:102
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
uint64_t getN(LevelType lt)
Definition Enums.h:442
uint64_t getM(LevelType lt)
Definition Enums.h:443
Include the generated interface declarations.
llvm::function_ref< void(Value, const ConstantIntRanges &)> SetIntRangeFn
The type of the setResultRanges callback provided to ops implementing InferIntRangeInterface.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:144
LogicalResult matchAndRewrite(SubFOp op, PatternRewriter &rewriter) const override
bool isCompatibleWith(const NVVMCheckSMVersion &targetSM) const
static const NVVMCheckSMVersion getTargetSMVersionFromStr(StringRef smVersionString)
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.