MLIR 23.0.0git
NVVMDialect.cpp
Go to the documentation of this file.
1//===- NVVMDialect.cpp - NVVM IR Ops and Dialect registration -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the types and operation details for the NVVM IR dialect in
10// MLIR, and the LLVM IR dialect. It also registers the dialect.
11//
12// The NVVM dialect only contains GPU specific additions on top of the general
13// LLVM dialect.
14//
15//===----------------------------------------------------------------------===//
16
18
22#include "mlir/IR/Builders.h"
25#include "mlir/IR/Diagnostics.h"
27#include "mlir/IR/MLIRContext.h"
28#include "mlir/IR/Operation.h"
30#include "mlir/IR/Types.h"
31#include "llvm/ADT/STLExtras.h"
32#include "llvm/ADT/TypeSwitch.h"
33#include "llvm/IR/IRBuilder.h"
34#include "llvm/IR/NVVMIntrinsicUtils.h"
35#include "llvm/Support/Casting.h"
36#include "llvm/Support/FormatVariadic.h"
37#include "llvm/Support/NVPTXAddrSpace.h"
38#include "llvm/Support/raw_ostream.h"
39#include <cassert>
40#include <optional>
41#include <string>
42
43using namespace mlir;
44using namespace NVVM;
45
46#include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
47#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"
48
49static constexpr unsigned notIntrinsic = llvm::Intrinsic::not_intrinsic;
50
51//===----------------------------------------------------------------------===//
52// Helper/Utility methods
53//===----------------------------------------------------------------------===//
54
55static bool isPtrInAddrSpace(mlir::Value ptr, NVVMMemorySpace targetAS) {
56 auto ptrTy = llvm::cast<LLVM::LLVMPointerType>(ptr.getType());
57 return ptrTy.getAddressSpace() == static_cast<unsigned>(targetAS);
58}
59
61 return isPtrInAddrSpace(ptr, NVVMMemorySpace::Generic);
62}
63
65 return isPtrInAddrSpace(ptr, NVVMMemorySpace::Shared);
66}
67
69 return isPtrInAddrSpace(ptr, NVVMMemorySpace::SharedCluster);
70}
71
72static llvm::Value *castPtrToAddrSpace(llvm::IRBuilderBase &builder,
73 llvm::Value *ptr,
74 NVVMMemorySpace targetAS) {
75 unsigned AS = static_cast<unsigned>(targetAS);
76 return builder.CreateAddrSpaceCast(
77 ptr, llvm::PointerType::get(builder.getContext(), AS));
78}
79
80// Helper method to convert CtaGroupKind in NVVM Dialect to CtaGroupKind in LLVM
81static llvm::nvvm::CTAGroupKind
82getNVVMCtaGroupKind(NVVM::CTAGroupKind ctaGroup) {
83 switch (ctaGroup) {
84 case NVVM::CTAGroupKind::CTA_1:
85 return llvm::nvvm::CTAGroupKind::CG_1;
86 case NVVM::CTAGroupKind::CTA_2:
87 return llvm::nvvm::CTAGroupKind::CG_2;
88 }
89 llvm_unreachable("unsupported cta_group value");
90}
91
92//===----------------------------------------------------------------------===//
93// Verifier methods
94//===----------------------------------------------------------------------===//
95
96// This verifier is shared among the following Ops:
97// CpAsyncBulkTensorSharedCTAToGlobalOp (TMA Store)
98// CpAsyncBulkTensorReduceOp (TMA Store-Reduce)
99static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims,
100 bool isIm2Col,
101 size_t numIm2ColOffsets,
102 Location loc) {
103 if (tensorDims < 1 || tensorDims > 5)
104 return emitError(loc, "expects coordinates between 1 to 5 dimension");
105
106 // For Im2Col mode, there are two constraints:
107 if (isIm2Col) {
108 // 1. Tensor must always be at least 3-d.
109 if (tensorDims < 3)
110 return emitError(
111 loc,
112 "to use im2col mode, the tensor has to be at least 3-dimensional");
113 // 2. When there are Im2ColOffsets, they must be (Dims - 2) in number.
114 if (numIm2ColOffsets && (tensorDims != (numIm2ColOffsets + 2)))
115 return emitError(
116 loc, "im2col offsets must be 2 less than number of coordinates");
117 }
118 return success();
119}
120
121LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
122 TMAStoreMode mode = getMode();
123 // We lower through inline-ptx when getPredicate() is true.
124 // a) Only TILE mode is supported
125 // b) Cache-hint is not supported
126 if (getPredicate()) {
127 if (mode != TMAStoreMode::TILE)
128 return emitError("Inline-ptx lowering supported only for Tile mode.");
129 if (getL2CacheHint())
130 return emitError("Inline-ptx lowering unsupported with L2 cache-hint.");
131 }
132
133 size_t dims = getCoordinates().size();
134 switch (mode) {
135 case TMAStoreMode::TILE:
136 return cpAsyncBulkTensorCommonVerifier(dims, false, 0, getLoc());
137 case TMAStoreMode::IM2COL:
138 return cpAsyncBulkTensorCommonVerifier(dims, true, 0, getLoc());
139 case TMAStoreMode::TILE_SCATTER4:
140 if (dims != 5)
141 return emitError("Scatter4 mode expects 5 coordinates");
142 }
143 return success();
144}
145
146LogicalResult CpAsyncOp::verify() {
147 if (getModifier() != LoadCacheModifierKind::CG &&
148 getModifier() != LoadCacheModifierKind::CA)
149 return emitError("Only CG and CA cache modifiers are supported.");
150 if (getSize() != 4 && getSize() != 8 && getSize() != 16)
151 return emitError("expected byte size to be either 4, 8 or 16.");
152 if (getModifier() == LoadCacheModifierKind::CG && getSize() != 16)
153 return emitError("CG cache modifier is only support for 16 bytes copy.");
154 return success();
155}
156
157// This verify params can be shared across TMA Load and Prefetch Ops.
158static LogicalResult verifyTMALoadParams(size_t tensorDims, size_t numIm2colOff,
159 TMALoadMode mode, Location loc) {
160 if (tensorDims < 1 || tensorDims > 5)
161 return emitError(loc, "expects coordinates between 1 to 5 dimension");
162
163 auto checkTMALoadParams = [&](TMALoadMode mode, bool isIm2col,
164 size_t expectedIm2colOff) -> LogicalResult {
165 if (isIm2col && (tensorDims < 3))
166 return emitError(loc)
167 << "to use " << mode
168 << " mode, the tensor has to be at least 3-dimensional";
169
170 if (numIm2colOff != expectedIm2colOff)
171 return emitError(loc) << " im2col offsets expected " << expectedIm2colOff
172 << " (provided " << numIm2colOff << ")";
173
174 return success();
175 };
176
177 switch (mode) {
178 case TMALoadMode::TILE:
179 return checkTMALoadParams(mode, false, 0);
180 case TMALoadMode::IM2COL:
181 return checkTMALoadParams(mode, true, tensorDims - 2);
182 case TMALoadMode::IM2COL_W:
183 case TMALoadMode::IM2COL_W_128:
184 return checkTMALoadParams(mode, true, 2);
185 case TMALoadMode::TILE_GATHER4:
186 return (tensorDims == 5)
187 ? checkTMALoadParams(mode, false, 0)
188 : emitError(loc, "Gather4 mode expects 5 coordinates");
189 }
190 return success();
191}
192
193LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
194 return verifyTMALoadParams(getCoordinates().size(), getIm2colOffsets().size(),
195 getMode(), getLoc());
196}
197
198LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
199 TMALoadMode mode = getMode();
200 bool isCTAOnly = getIsCTAOnly();
201 if (getPredicate()) { // Inline-asm based lowering
202 if (isCTAOnly)
203 return emitError("Predicate is supported only for shared::cluster mode.");
204 if (mode != TMALoadMode::TILE && mode != TMALoadMode::IM2COL)
205 return emitError(
206 "Predicate is supported only for Tile and Im2col modes.");
207 } else { // Intrinsics-based lowering
208 NVVMMemorySpace expectedAS =
209 isCTAOnly ? NVVMMemorySpace::Shared : NVVMMemorySpace::SharedCluster;
210 unsigned AS = llvm::cast<LLVM::LLVMPointerType>(getDstMem().getType())
211 .getAddressSpace();
212 if (AS != expectedAS)
213 return emitError()
214 << (isCTAOnly
215 ? "Shared::cta destination requires address-space 3."
216 : "Shared::cluster destination requires address-space 7.");
217 // Checks specific to shared::cta mode
218 if (isCTAOnly) {
219 if (getMulticastMask())
220 return emitError("Multicast is not supported with shared::cta mode.");
221 if (getGroup())
222 return emitError("CTAGroup is not supported with shared::cta mode.");
223 }
224 }
225
226 return verifyTMALoadParams(getCoordinates().size(), getIm2colOffsets().size(),
227 getMode(), getLoc());
228}
229
230LogicalResult CpAsyncBulkTensorReduceOp::verify() {
231 TMAStoreMode mode = getMode();
232 size_t dims = getCoordinates().size();
233 switch (mode) {
234 case TMAStoreMode::TILE:
235 return cpAsyncBulkTensorCommonVerifier(dims, false, 0, getLoc());
236 case TMAStoreMode::IM2COL:
237 return cpAsyncBulkTensorCommonVerifier(dims, true, 0, getLoc());
238 case TMAStoreMode::TILE_SCATTER4:
239 return emitError("Scatter mode unsupported for CpAsyncBulkTensorReduceOp");
240 }
241 return success();
242}
243
244LogicalResult CpAsyncBulkGlobalToSharedClusterOp::verify() {
245 bool isSharedCTA = isPtrInSharedCTASpace(getDstMem());
246 if (isSharedCTA && getMulticastMask())
247 return emitError("Multicast is not supported with shared::cta mode.");
248
249 return success();
250}
251
252static LogicalResult verifyMBarrierArriveLikeOp(Operation *op, Value addr,
253 NVVM::MemScopeKind scope,
254 Value retVal = nullptr) {
255 if (scope != NVVM::MemScopeKind::CTA && scope != NVVM::MemScopeKind::CLUSTER)
256 return op->emitError("mbarrier scope must be either CTA or Cluster");
257
258 bool isSharedCluster = isPtrInSharedClusterSpace(addr);
259 bool hasRetValue = static_cast<bool>(retVal);
260 if (isSharedCluster && hasRetValue)
261 return op->emitError(
262 "mbarrier in shared_cluster space cannot return any value");
263
264 return success();
265}
266
267LogicalResult MBarrierArriveOp::verify() {
268 return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope(),
269 getRes());
270}
271
272LogicalResult MBarrierArriveDropOp::verify() {
273 return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope(),
274 getRes());
275}
276
277LogicalResult MBarrierArriveExpectTxOp::verify() {
278 // The inline-ptx version of this Op does not support all features.
279 // With predicate, this Op lowers to inline-ptx. So, verify and
280 // error-out if there are unsupported features.
281 if (getPredicate()) {
282 if (getScope() != NVVM::MemScopeKind::CTA)
283 return emitError("mbarrier scope must be CTA when using predicate");
284
285 if (isPtrInSharedClusterSpace(getAddr()))
286 return emitError("mbarrier in shared_cluster space is not supported when "
287 "using predicate");
288
289 if (getRes())
290 return emitError("return-value is not supported when using predicate");
291
292 if (getRelaxed() == true)
293 return emitError("mbarrier with relaxed semantics is not supported when "
294 "using predicate");
295 }
296 return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope(),
297 getRes());
298}
299
300LogicalResult MBarrierArriveDropExpectTxOp::verify() {
301 return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope(),
302 getRes());
303}
304
305//===----------------------------------------------------------------------===//
306// inferReturnTypes for mbarrier arrive-like ops
307//===----------------------------------------------------------------------===//
308
309/// Only shared_cluster (ptr<7>) produces zero results; all other address
310/// spaces (including generic) return i64.
311static LogicalResult
313 SmallVectorImpl<Type> &inferredReturnTypes) {
314 if (!isPtrInSharedClusterSpace(addr))
315 inferredReturnTypes.push_back(IntegerType::get(context, 64));
316 return success();
317}
318
319LogicalResult
320MBarrierArriveOp::inferReturnTypes(MLIRContext *context,
321 std::optional<Location> location,
322 MBarrierArriveOp::Adaptor adaptor,
323 SmallVectorImpl<Type> &inferredReturnTypes) {
324 return inferMBarrierArriveResultTypes(context, adaptor.getAddr(),
325 inferredReturnTypes);
326}
327
328LogicalResult MBarrierArriveDropOp::inferReturnTypes(
329 MLIRContext *context, std::optional<Location> location,
330 MBarrierArriveDropOp::Adaptor adaptor,
331 SmallVectorImpl<Type> &inferredReturnTypes) {
332 return inferMBarrierArriveResultTypes(context, adaptor.getAddr(),
333 inferredReturnTypes);
334}
335
336LogicalResult MBarrierArriveExpectTxOp::inferReturnTypes(
337 MLIRContext *context, std::optional<Location> location,
338 MBarrierArriveExpectTxOp::Adaptor adaptor,
339 SmallVectorImpl<Type> &inferredReturnTypes) {
340 // Predicate forces no return value (inline PTX path).
341 // Note: predicate + shared_cluster is rejected by the verifier separately.
342 if (adaptor.getPredicate())
343 return success();
344 return inferMBarrierArriveResultTypes(context, adaptor.getAddr(),
345 inferredReturnTypes);
346}
347
348LogicalResult MBarrierArriveDropExpectTxOp::inferReturnTypes(
349 MLIRContext *context, std::optional<Location> location,
350 MBarrierArriveDropExpectTxOp::Adaptor adaptor,
351 SmallVectorImpl<Type> &inferredReturnTypes) {
352 return inferMBarrierArriveResultTypes(context, adaptor.getAddr(),
353 inferredReturnTypes);
354}
355
356/// For ops with optional results, allow the user to omit the result even when
357/// inference would produce one. This preserves backward compatibility: the
358/// result can be silently discarded (e.g., for fire-and-forget arrive ops).
360 TypeRange actual) {
361 if (actual.empty())
362 return true;
363 return inferred == actual;
364}
365
366bool MBarrierArriveOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
368}
369bool MBarrierArriveDropOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
371}
372bool MBarrierArriveExpectTxOp::isCompatibleReturnTypes(TypeRange l,
373 TypeRange r) {
375}
376bool MBarrierArriveDropExpectTxOp::isCompatibleReturnTypes(TypeRange l,
377 TypeRange r) {
379}
380
381LogicalResult MBarrierExpectTxOp::verify() {
382 return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope());
383}
384
385LogicalResult MBarrierCompleteTxOp::verify() {
386 return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope());
387}
388
389LogicalResult MBarrierTestWaitOp::verify() {
390 return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope());
391}
392
393LogicalResult MBarrierTryWaitOp::verify() {
394 return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope());
395}
396
397LogicalResult ConvertFloatToTF32Op::verify() {
398 using RndMode = NVVM::FPRoundingMode;
399 switch (getRnd()) {
400 case RndMode::RNA:
401 if (getRelu())
402 return emitError("Relu not supported with rna rounding mode.");
403 break;
404 case RndMode::RN:
405 case RndMode::RZ:
406 break;
407 default:
408 return emitError(
409 "Only {rn,rz,rna} rounding modes supported for ConvertFloatToTF32Op.");
410 }
411 return success();
412}
413
414LogicalResult ConvertF32x2ToF6x2Op::verify() {
416
417 if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy())) {
418 return emitOpError("Only ")
419 << mlir::Float6E2M3FNType::get(ctx) << " and "
420 << mlir::Float6E3M2FNType::get(ctx)
421 << " types are supported for conversions from f32x2 to f6x2.";
422 }
423 return success();
424}
425
426LogicalResult ConvertF32x2ToF8x2Op::verify() {
427 using RndMode = NVVM::FPRoundingMode;
428 using SatMode = NVVM::SaturationMode;
429
430 bool isRoundingModeRN = getRnd() == RndMode::RN;
431 bool isRoundingModeRZ = getRnd() == RndMode::RZ;
432 bool isRoundingModeRP = getRnd() == RndMode::RP;
433 bool isSatFinite = getSat() == SatMode::SATFINITE;
434
435 bool hasRelu = getRelu();
436
438
440 .Case<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(
441 [&](mlir::Type) -> LogicalResult {
442 if (!isRoundingModeRN) {
443 return emitOpError("Only RN rounding mode is supported for "
444 "conversions from f32x2 to ")
445 << mlir::Float8E4M3FNType::get(ctx) << " and "
446 << mlir::Float8E5M2Type::get(ctx) << " types";
447 }
448 if (!isSatFinite) {
449 return emitOpError("Only SATFINITE saturation mode is supported "
450 "for conversions "
451 "from f32x2 to ")
452 << mlir::Float8E4M3FNType::get(ctx) << " and "
453 << mlir::Float8E5M2Type::get(ctx) << " types";
454 }
455 return success();
456 })
457 .Case<mlir::Float8E8M0FNUType>([&](mlir::Type) -> LogicalResult {
458 if (!(isRoundingModeRZ || isRoundingModeRP)) {
459 return emitOpError("Only RZ and RP rounding modes are supported for "
460 "conversions from f32x2 to ")
461 << mlir::Float8E8M0FNUType::get(ctx) << " type";
462 }
463 if (hasRelu) {
464 return emitOpError("relu not supported for conversions to ")
465 << mlir::Float8E8M0FNUType::get(ctx) << " type";
466 }
467 return success();
468 })
469 .Default([&](mlir::Type) {
470 return emitOpError("Only ")
471 << mlir::Float8E4M3FNType::get(ctx) << ", "
472 << mlir::Float8E5M2Type::get(ctx) << ", and "
473 << mlir::Float8E8M0FNUType::get(ctx)
474 << " types are "
475 "supported for conversions from f32x2 to f8x2";
476 });
477}
478
479LogicalResult ConvertF16x2ToF8x2Op::verify() {
481
482 if (!llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(getDstTy())) {
483 return emitOpError("Only ")
484 << mlir::Float8E4M3FNType::get(ctx) << " and "
485 << mlir::Float8E5M2Type::get(ctx)
486 << " types are supported for conversions from f16x2 to f8x2.";
487 }
488 return success();
489}
490
491LogicalResult ConvertBF16x2ToF8x2Op::verify() {
492 using RndMode = NVVM::FPRoundingMode;
493 using SatMode = NVVM::SaturationMode;
494
495 bool isRoundingModeRN = getRnd() == RndMode::RN;
496 bool isRoundingModeRZ = getRnd() == RndMode::RZ;
497 bool isRoundingModeRP = getRnd() == RndMode::RP;
498 bool isSatFinite = getSat() == SatMode::SATFINITE;
499 bool hasRelu = getRelu();
500
502
504 .Case<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(
505 [&](mlir::Type) -> LogicalResult {
506 if (!isRoundingModeRN)
507 return emitOpError("Only RN rounding mode is supported for "
508 "conversions from bf16x2 to ")
509 << mlir::Float8E4M3FNType::get(ctx) << " and "
510 << mlir::Float8E5M2Type::get(ctx) << " types";
511 if (!isSatFinite)
512 return emitOpError("Only SATFINITE saturation mode is supported "
513 "for conversions from bf16x2 to ")
514 << mlir::Float8E4M3FNType::get(ctx) << " and "
515 << mlir::Float8E5M2Type::get(ctx) << " types";
516 return success();
517 })
518 .Case<mlir::Float8E8M0FNUType>([&](mlir::Type) -> LogicalResult {
519 if (!(isRoundingModeRZ || isRoundingModeRP))
520 return emitOpError("Only RZ and RP rounding modes are supported for "
521 "conversions from bf16x2 to ")
522 << mlir::Float8E8M0FNUType::get(ctx) << " type";
523 if (hasRelu)
524 return emitOpError("relu not supported for conversions to ")
525 << mlir::Float8E8M0FNUType::get(ctx) << " type";
526 return success();
527 })
528 .Default([&](mlir::Type) -> LogicalResult {
529 llvm_unreachable("Invalid conversion in ConvertBF16x2ToF8x2Op");
530 return failure();
531 });
532}
533
534LogicalResult ConvertF32x2ToF4x2Op::verify() {
536
537 if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
538 return emitOpError("Only ")
539 << mlir::Float4E2M1FNType::get(ctx)
540 << " type is supported for conversions from f32x2 to f4x2.";
541
542 return success();
543}
544
545LogicalResult ConvertF8x2ToF16x2Op::verify() {
547
548 if (!llvm::isa<Float8E4M3FNType, Float8E5M2Type>(getSrcType()))
549 return emitOpError("Only ")
550 << mlir::Float8E4M3FNType::get(ctx) << " and "
551 << mlir::Float8E5M2Type::get(ctx)
552 << " types are supported for conversions from f8x2 to f16x2.";
553
554 return success();
555}
556
557LogicalResult ConvertF8x2ToBF16x2Op::verify() {
559 if (!llvm::isa<Float8E8M0FNUType>(getSrcType()))
560 return emitOpError("Only ")
561 << mlir::Float8E8M0FNUType::get(ctx)
562 << " type is supported for conversions from f8x2 to bf16x2.";
563
564 return success();
565}
566
567LogicalResult ConvertF6x2ToF16x2Op::verify() {
569
570 if (!llvm::isa<Float6E2M3FNType, Float6E3M2FNType>(getSrcType()))
571 return emitOpError("Only ")
572 << mlir::Float6E2M3FNType::get(ctx) << " and "
573 << mlir::Float6E3M2FNType::get(ctx)
574 << " types are supported for conversions from f6x2 to f16x2.";
575
576 return success();
577}
578
579LogicalResult ConvertF4x2ToF16x2Op::verify() {
581
582 if (!llvm::isa<Float4E2M1FNType>(getSrcType()))
583 return emitOpError("Only ")
584 << mlir::Float4E2M1FNType::get(ctx)
585 << " type is supported for conversions from f4x2 to f16x2.";
586
587 return success();
588}
589
590LogicalResult PermuteOp::verify() {
591 using Mode = NVVM::PermuteMode;
592 bool hasHi = static_cast<bool>(getHi());
593
594 switch (getMode()) {
595 case Mode::DEFAULT:
596 case Mode::F4E:
597 case Mode::B4E:
598 if (!hasHi)
599 return emitError("mode '") << getMode() << "' requires 'hi' operand.";
600 break;
601 case Mode::RC8:
602 case Mode::ECL:
603 case Mode::ECR:
604 case Mode::RC16:
605 if (hasHi)
606 return emitError("mode '")
607 << getMode() << "' does not accept 'hi' operand.";
608 break;
609 }
610
611 return success();
612}
613
614//===----------------------------------------------------------------------===//
615// Stochastic Rounding Conversion Ops
616//===----------------------------------------------------------------------===//
617
618static LogicalResult verifyConvertF32x2ToFP16x2Op(Twine dstType,
619 FPRoundingMode rnd,
620 bool hasRandomBits,
621 Operation *op) {
622 static constexpr FPRoundingMode validRndModes[] = {
623 FPRoundingMode::RN, FPRoundingMode::RZ, FPRoundingMode::RS};
624
625 if (!llvm::is_contained(validRndModes, rnd)) {
626 return op->emitOpError(
627 "Only RN, RZ, and RS rounding modes are supported for "
628 "conversions from f32x2 to ")
629 << dstType << ".";
630 }
631
632 if (rnd == FPRoundingMode::RS) {
633 if (!hasRandomBits) {
634 return op->emitOpError("random_bits is required for RS rounding mode.");
635 }
636 } else {
637 if (hasRandomBits) {
638 return op->emitOpError(
639 "random_bits not supported for RN and RZ rounding modes.");
640 }
641 }
642
643 return success();
644}
645
646LogicalResult ConvertF32x2ToF16x2Op::verify() {
647 return verifyConvertF32x2ToFP16x2Op("f16x2", getRnd(),
648 getRandomBits() ? true : false, *this);
649}
650
651LogicalResult ConvertF32x2ToBF16x2Op::verify() {
652 return verifyConvertF32x2ToFP16x2Op("bf16x2", getRnd(),
653 getRandomBits() ? true : false, *this);
654}
655
656LogicalResult ConvertF32x4ToF8x4Op::verify() {
658
659 if (!llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(getDstTy()))
660 return emitOpError("Only ")
661 << mlir::Float8E4M3FNType::get(ctx) << " and "
662 << mlir::Float8E5M2Type::get(ctx)
663 << " types are supported for conversions from f32x4 to f8x4.";
664
665 return success();
666}
667
668LogicalResult ConvertF32x4ToF6x4Op::verify() {
670
671 if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy()))
672 return emitOpError("Only ")
673 << mlir::Float6E2M3FNType::get(ctx) << " and "
674 << mlir::Float6E3M2FNType::get(ctx)
675 << " types are supported for conversions from f32x4 to f6x4.";
676
677 return success();
678}
679
680LogicalResult ConvertF32x4ToF4x4Op::verify() {
682
683 if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
684 return emitOpError("Only ") << mlir::Float4E2M1FNType::get(ctx)
685 << " type is supported for conversions from "
686 "f32x4 to f4x4.";
687
688 return success();
689}
690
691LogicalResult BulkStoreOp::verify() {
692 if (getInitVal() != 0)
693 return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
694 return success();
695}
696
697LogicalResult PMEventOp::verify() {
698 auto eventId = getEventId();
699 auto maskedEventId = getMaskedEventId();
700 if (!maskedEventId && !eventId) {
701 return emitOpError() << "either `id` or `mask` must be set";
702 }
703
704 if (maskedEventId && eventId) {
705 return emitOpError() << "`id` and `mask` cannot be set at the same time";
706 }
707
708 if (eventId) {
709 if (eventId < 0 || eventId > 15) {
710 return emitOpError() << "`id` must be between 0 and 15";
711 }
712 }
713
714 return llvm::success();
715}
716
717// Given the element type of an operand and whether or not it is an accumulator,
718// this function returns the PTX type (`NVVM::MMATypes`) that corresponds to the
719// operand's element type.
720std::optional<mlir::NVVM::MMATypes>
721MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) {
722 auto half2Type =
723 VectorType::get(2, Float16Type::get(operandElType.getContext()));
724 if (operandElType.isF64())
725 return NVVM::MMATypes::f64;
726 if (operandElType.isF16() || operandElType == half2Type)
727 return NVVM::MMATypes::f16;
728 if (operandElType.isF32() && isAccumulator)
729 return NVVM::MMATypes::f32;
730 if (operandElType.isF32() && !isAccumulator)
731 return NVVM::MMATypes::tf32;
732 if (llvm::isa<IntegerType>(operandElType)) {
733 if (isAccumulator)
734 return NVVM::MMATypes::s32;
735 return std::nullopt;
736 }
737
738 if (auto structType = llvm::dyn_cast<LLVM::LLVMStructType>(operandElType)) {
739 if (structType.getBody().empty())
740 return std::nullopt;
741 return inferOperandMMAType(structType.getBody()[0], isAccumulator);
742 }
743
744 return std::nullopt;
745}
746
747static bool isInt4PtxType(MMATypes type) {
748 return (type == MMATypes::u4 || type == MMATypes::s4);
749}
750
751static bool isInt8PtxType(MMATypes type) {
752 return (type == MMATypes::u8 || type == MMATypes::s8);
753}
754
755static bool isIntegerPtxType(MMATypes type) {
756 return isInt4PtxType(type) || isInt8PtxType(type) || type == MMATypes::b1 ||
757 type == MMATypes::s32;
758}
759
760MMATypes MmaOp::accumPtxType() {
761 std::optional<mlir::NVVM::MMATypes> val = inferOperandMMAType(
762 getODSOperands(2).getTypes().front(), /*isAccumulator=*/true);
763 assert(val.has_value() && "accumulator PTX type should always be inferrable");
764 return val.value();
765}
766
767MMATypes MmaOp::resultPtxType() {
768 std::optional<mlir::NVVM::MMATypes> val =
769 inferOperandMMAType(getResult().getType(), /*isAccumulator=*/true);
770 assert(val.has_value() && "result PTX type should always be inferrable");
771 return val.value();
772}
773
774void MmaOp::print(OpAsmPrinter &p) {
775 SmallVector<Type, 4> regTypes;
776 struct MMAOperandFragment {
777 StringRef operandName;
778 StringRef ptxTypeAttr;
779 SmallVector<Value, 4> regs;
780 explicit MMAOperandFragment(StringRef name, StringRef ptxTypeName)
781 : operandName(name), ptxTypeAttr(ptxTypeName) {}
782 };
783
784 std::array<MMAOperandFragment, 3> frags{
785 MMAOperandFragment("A", getMultiplicandAPtxTypeAttrName()),
786 MMAOperandFragment("B", getMultiplicandBPtxTypeAttrName()),
787 MMAOperandFragment("C", "")};
788 SmallVector<StringRef, 4> ignoreAttrNames{
789 mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()};
790
791 for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
792 auto &frag = frags[fragIdx];
793 auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
794 for (auto operandIdx = varOperandSpec.first;
795 operandIdx < varOperandSpec.first + varOperandSpec.second;
796 operandIdx++) {
797 frag.regs.push_back(this->getOperand(operandIdx));
798 if (operandIdx == 0) {
799 regTypes.push_back(this->getOperand(operandIdx).getType());
800 }
801 }
802 std::optional<MMATypes> inferredType = MmaOp::inferOperandMMAType(
803 regTypes.back(), /*isAccumulator=*/fragIdx >= 2);
804 if (inferredType)
805 ignoreAttrNames.push_back(frag.ptxTypeAttr);
806 }
807
808 auto printMmaOperand = [&](const MMAOperandFragment &frag) -> void {
809 p << " " << frag.operandName;
810 p << "[";
811 p.printOperands(frag.regs);
812 p << "] ";
813 };
814
815 for (const auto &frag : frags) {
816 printMmaOperand(frag);
817 }
818
819 p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);
820
821 // Print the types of the operands and result.
822 p << " : "
823 << "(";
824 llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
825 frags[1].regs[0].getType(),
826 frags[2].regs[0].getType()},
827 p);
828 p << ")";
829 p.printArrowTypeList(TypeRange{this->getRes().getType()});
830}
831
832void MmaOp::build(OpBuilder &builder, OperationState &result, Type resultType,
833 ValueRange operandA, ValueRange operandB, ValueRange operandC,
834 ArrayRef<int64_t> shape, std::optional<MMAB1Op> b1Op,
835 std::optional<MMAIntOverflow> intOverflow,
836 std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
837 std::optional<std::array<MMALayout, 2>> multiplicandLayouts) {
838
839 assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
840 MLIRContext *ctx = builder.getContext();
841 result.addAttribute(
842 "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));
843
844 result.addOperands(operandA);
845 result.addOperands(operandB);
846 result.addOperands(operandC);
847
848 if (multiplicandPtxTypes) {
849 result.addAttribute("multiplicandAPtxType",
850 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
851 result.addAttribute("multiplicandBPtxType",
852 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
853 } else {
854 if (auto res = inferOperandMMAType(operandA[0].getType(), false))
855 result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
856 if (auto res = inferOperandMMAType(operandB[0].getType(), false))
857 result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
858 }
859
860 if (multiplicandLayouts) {
861 result.addAttribute("layoutA",
862 MMALayoutAttr::get(ctx, (*multiplicandLayouts)[0]));
863 result.addAttribute("layoutB",
864 MMALayoutAttr::get(ctx, (*multiplicandLayouts)[1]));
865 } else {
866 result.addAttribute("layoutA", MMALayoutAttr::get(ctx, MMALayout::row));
867 result.addAttribute("layoutB", MMALayoutAttr::get(ctx, MMALayout::col));
868 }
869
870 if (intOverflow.has_value())
871 result.addAttribute("intOverflowBehavior",
872 MMAIntOverflowAttr::get(ctx, *intOverflow));
873 if (b1Op.has_value())
874 result.addAttribute("b1Op", MMAB1OpAttr::get(ctx, *b1Op));
875
876 result.addTypes(resultType);
877 result.addAttribute(
878 MmaOp::getOperandSegmentSizeAttr(),
879 builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()),
880 static_cast<int32_t>(operandB.size()),
881 static_cast<int32_t>(operandC.size())}));
882}
883
884// <operation> :=
885// A `[` $operandA `]` B `[` $operandB `]` C `[` $operandC `]`
886// attr-dict : (type($operandA[0]), type($operandB[0]), type($operandC[0]))
887// `->` type($res)
888ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) {
889 struct MMAOperandFragment {
890 std::optional<MMATypes> elemtype;
891 SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
892 SmallVector<Type> regTypes;
893 };
894
895 Builder &builder = parser.getBuilder();
896 std::array<MMAOperandFragment, 4> frags;
897
898 NamedAttrList namedAttributes;
899
900 // A helper to parse the operand segments.
901 auto parseMmaOperand = [&](StringRef operandName,
902 MMAOperandFragment &frag) -> LogicalResult {
903 if (parser.parseKeyword(operandName).failed())
904 return failure();
905 if (parser
906 .parseOperandList(frag.regs, OpAsmParser::Delimiter::OptionalSquare)
907 .failed())
908 return failure();
909 return success();
910 };
911
912 // Parse the operand segments.
913 if (parseMmaOperand("A", frags[0]).failed())
914 return failure();
915 if (parseMmaOperand("B", frags[1]).failed())
916 return failure();
917 if (parseMmaOperand("C", frags[2]).failed())
918 return failure();
919
920 if (parser.parseOptionalAttrDict(namedAttributes).failed())
921 return failure();
922
923 // Parse the type specification and resolve operands.
924 SmallVector<Type, 3> operandTypes;
925 if (failed(parser.parseColon()))
926 return failure();
927 if (failed(parser.parseLParen()))
928 return failure();
929 if (failed(parser.parseTypeList(operandTypes)))
930 return failure();
931 if (failed(parser.parseRParen()))
932 if (operandTypes.size() != 3)
933 return parser.emitError(
934 parser.getNameLoc(),
935 "expected one type for each operand segment but got " +
936 Twine(operandTypes.size()) + " types");
937 for (const auto &iter : llvm::enumerate(operandTypes)) {
938 auto &frag = frags[iter.index()];
939 frag.regTypes.resize(frag.regs.size(), iter.value());
940 if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
941 parser.getNameLoc(), result.operands)))
942 return failure();
943 frag.elemtype = inferOperandMMAType(frag.regTypes[0],
944 /*isAccumulator*/ iter.index() < 2);
945 }
946
947 Type resultType;
948 if (parser.parseArrow() || parser.parseType(resultType))
949 return failure();
950 frags[3].elemtype = inferOperandMMAType(resultType, /*isAccumulator*/ true);
951
952 std::array<StringRef, 2> names{"multiplicandAPtxType",
953 "multiplicandBPtxType"};
954 for (unsigned idx = 0; idx < names.size(); idx++) {
955 const auto &frag = frags[idx];
956 std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
957 if (!frag.elemtype.has_value() && !attr.has_value()) {
958 return parser.emitError(
959 parser.getNameLoc(),
960 "attribute " + names[idx] +
961 " is not provided explicitly and cannot be inferred");
962 }
963 if (!attr.has_value())
964 result.addAttribute(
965 names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));
966 }
967
968 result.addTypes(resultType);
969 if (!namedAttributes.empty())
970 result.addAttributes(namedAttributes);
971 result.addAttribute(MmaOp::getOperandSegmentSizeAttr(),
972 builder.getDenseI32ArrayAttr({
973 static_cast<int32_t>(frags[0].regs.size()),
974 static_cast<int32_t>(frags[1].regs.size()),
975 static_cast<int32_t>(frags[2].regs.size()),
976 }));
977 return success();
978}
979
980LogicalResult MmaOp::verify() {
981 MLIRContext *context = getContext();
982 auto f16Ty = Float16Type::get(context);
983 auto i32Ty = IntegerType::get(context, 32);
984 auto f16x2Ty = VectorType::get(2, f16Ty);
985 auto f32Ty = Float32Type::get(context);
986 auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
987 context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
988
989 auto s32x4StructTy =
990 LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
991 auto f32x8StructTy =
992 LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
993 auto f16x2x2StructTy =
994 LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
995 auto f32x4StructTy =
996 LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
997 auto s32x2StructTy =
998 LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});
999
1000 std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
1001 getShapeAttr().getK()};
1002
1003 // These variables define the set of allowed data types for matrices A, B, C,
1004 // and result.
1005 using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
1006 using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>;
1007 AllowedShapes allowedShapes;
1008 AllowedTypes expectedA;
1009 AllowedTypes expectedB;
1010 AllowedTypes expectedC;
1011 SmallVector<Type> expectedResult;
1012
1013 // When M = 16, we just need to calculate the number of 8xk tiles, where
1014 // k is a factor that depends on the data type.
1015 if (mmaShape[0] == 16) {
1016 int64_t kFactor;
1017 Type multiplicandFragType;
1018 switch (*getMultiplicandAPtxType()) {
1019 case MMATypes::tf32:
1020 kFactor = 4;
1021 multiplicandFragType = i32Ty;
1022 expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
1023 context, {f32Ty, f32Ty, f32Ty, f32Ty}));
1024 break;
1025 case MMATypes::bf16:
1026 kFactor = 8;
1027 multiplicandFragType = i32Ty;
1028 expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
1029 context, {f32Ty, f32Ty, f32Ty, f32Ty}));
1030 break;
1031 case MMATypes::f16:
1032 kFactor = 8;
1033 multiplicandFragType = f16x2Ty;
1034 expectedResult.push_back(f16x2x2StructTy);
1035 expectedResult.push_back(f32x4StructTy);
1036 break;
1037 case MMATypes::s4:
1038 case MMATypes::u4:
1039 kFactor = 32;
1040 break;
1041 case MMATypes::b1:
1042 kFactor = 128;
1043 break;
1044 case MMATypes::s8:
1045 case MMATypes::u8:
1046 kFactor = 16;
1047 break;
1048 default:
1049 return emitError("invalid shape or multiplicand type: ")
1050 << getMultiplicandAPtxType().value();
1051 }
1052
1053 if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
1054 expectedResult.push_back(s32x4StructTy);
1055 expectedC.emplace_back(4, i32Ty);
1056 multiplicandFragType = i32Ty;
1057 } else {
1058 expectedC.emplace_back(2, f16x2Ty);
1059 expectedC.emplace_back(4, f32Ty);
1060 }
1061
1062 int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
1063 int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
1064 expectedA.emplace_back(unitA, multiplicandFragType);
1065 expectedB.emplace_back(unitB, multiplicandFragType);
1066 allowedShapes.push_back({16, 8, kFactor});
1067 allowedShapes.push_back({16, 8, kFactor * 2});
1068
1069 if (resultPtxType() != accumPtxType())
1070 return emitOpError("ctype does not match dtype");
1071 }
1072
1073 // In the M=8 case, there is only 1 possible case per data type.
1074 if (mmaShape[0] == 8) {
1075 if (*getMultiplicandAPtxType() == MMATypes::f16) {
1076 expectedA.emplace_back(2, f16x2Ty);
1077 expectedB.emplace_back(2, f16x2Ty);
1078 expectedResult.push_back(f16x2x4StructTy);
1079 expectedResult.push_back(f32x8StructTy);
1080 expectedC.emplace_back(4, f16x2Ty);
1081 expectedC.emplace_back(8, f32Ty);
1082 allowedShapes.push_back({8, 8, 4});
1083 }
1084 if (*getMultiplicandAPtxType() == MMATypes::f64) {
1085 Type f64Ty = Float64Type::get(context);
1086 expectedA.emplace_back(1, f64Ty);
1087 expectedB.emplace_back(1, f64Ty);
1088 expectedC.emplace_back(2, f64Ty);
1089 expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
1090 context, SmallVector<Type>(2, f64Ty)));
1091 allowedShapes.push_back({8, 8, 4});
1092 }
1093 if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
1094 expectedA.push_back({i32Ty});
1095 expectedB.push_back({i32Ty});
1096 expectedC.push_back({i32Ty, i32Ty});
1097 expectedResult.push_back(s32x2StructTy);
1098 if (isInt4PtxType(getMultiplicandAPtxType().value()))
1099 allowedShapes.push_back({8, 8, 32});
1100 if (isInt8PtxType(getMultiplicandAPtxType().value()))
1101 allowedShapes.push_back({8, 8, 16});
1102 if (getMultiplicandAPtxType().value() == MMATypes::b1)
1103 allowedShapes.push_back({8, 8, 128});
1104 }
1105 }
1106
1107 std::string errorMessage;
1108 llvm::raw_string_ostream errorStream(errorMessage);
1109
1110 // Check that we matched an existing shape/dtype combination.
1111 if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
1112 !llvm::is_contained(allowedShapes, mmaShape)) {
1113 errorStream << "unimplemented variant for MMA shape <";
1114 llvm::interleaveComma(mmaShape, errorStream);
1115 errorStream << ">";
1116 return emitOpError(errorMessage);
1117 }
1118
1119 // Verify the operand types for segments of A, B, and C operands.
1120 std::array<StringRef, 3> operandNames{"A", "B", "C"};
1121 for (const auto &iter : llvm::enumerate(
1122 SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) {
1123 auto spec = this->getODSOperandIndexAndLength(iter.index());
1124 SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first,
1125 operand_type_begin() + spec.first +
1126 spec.second);
1127 bool match = llvm::is_contained(iter.value(), operandTySeg);
1128
1129 if (!match) {
1130 errorStream << "Could not match types for the "
1131 << operandNames[iter.index()]
1132 << " operands; expected one of ";
1133 for (const auto &x : iter.value()) {
1134 errorStream << x.size() << "x" << x[0] << " ";
1135 }
1136 errorStream << "but got ";
1137 llvm::interleaveComma(operandTySeg, errorStream);
1138 return emitOpError(errorMessage);
1139 }
1140 }
1141
1142 // Check the result type
1143 if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
1144 return expectedResultType == getResult().getType();
1145 })) {
1146 errorStream
1147 << "Could not match allowed types for the result; expected one of ";
1148 llvm::interleaveComma(expectedResult, errorStream);
1149 errorStream << " but got " << getResult().getType();
1150 return emitOpError(errorMessage);
1151 }
1152
1153 // Ensure that binary MMA variants have a b1 MMA operation defined.
1154 if (getMultiplicandAPtxType() == MMATypes::b1 && !getB1Op()) {
1155 return emitOpError("op requires " + getB1OpAttrName().strref() +
1156 " attribute");
1157 }
1158
1159 // Ensure int4/int8 MMA variants specify the accum overflow behavior
1160 // attribute.
1161 if (isInt4PtxType(*getMultiplicandAPtxType()) ||
1162 isInt8PtxType(*getMultiplicandAPtxType())) {
1163 if (!getIntOverflowBehavior())
1164 return emitOpError("op requires " +
1165 getIntOverflowBehaviorAttrName().strref() +
1166 " attribute");
1167 }
1168
1169 // Validate layout combinations. According to the operation description, most
1170 // MMA operations require layoutA=row and layoutB=col. Only m8n8k4 with f16
1171 // can use other layout combinations.
1172 bool isM8N8K4_F16 =
1173 (mmaShape[0] == 8 && mmaShape[1] == 8 && mmaShape[2] == 4 &&
1174 getMultiplicandAPtxType() == MMATypes::f16);
1175
1176 if (!isM8N8K4_F16) {
1177 // For all other shapes/types, layoutA must be row and layoutB must be col
1178 if (getLayoutA() != MMALayout::row || getLayoutB() != MMALayout::col) {
1179 return emitOpError("requires layoutA = #nvvm.mma_layout<row> and "
1180 "layoutB = #nvvm.mma_layout<col> for shape <")
1181 << mmaShape[0] << ", " << mmaShape[1] << ", " << mmaShape[2]
1182 << "> with element types " << *getMultiplicandAPtxType() << " and "
1183 << *getMultiplicandBPtxType()
1184 << ". Only m8n8k4 with f16 supports other layouts.";
1185 }
1186 }
1187
1188 return success();
1189}
1190
1191MMATypes MmaSpOp::accumPtxType() {
1192 std::optional<mlir::NVVM::MMATypes> val = MmaOp::inferOperandMMAType(
1193 getODSOperands(2).getTypes().front(), /*isAccumulator=*/true);
1194 assert(val.has_value() && "accumulator PTX type should always be inferrable");
1195 return val.value();
1196}
1197
1198MMATypes MmaSpOp::resultPtxType() {
1199 std::optional<mlir::NVVM::MMATypes> val =
1200 MmaOp::inferOperandMMAType(getResult().getType(), /*isAccumulator=*/true);
1201 assert(val.has_value() && "result PTX type should always be inferrable");
1202 return val.value();
1203}
1204
// Collects what is needed to emit the LLVM intrinsic call for an MmaSpOp:
// every operand of `op` translated through `mt.lookupValue`, and the intrinsic
// ID selected by MmaSpOp::getIntrinsicID from the op's shape, overflow
// behavior, ordered-metadata flag, kind, multiplicand types and
// accumulator/result PTX types. Returns the pair {intrinsic ID, arguments}.
//
// `op` must be an NVVM::MmaSpOp (enforced by `cast`). Both multiplicand type
// attributes are dereferenced unconditionally, so they must be present.
// `builder` is not used by the visible body.
//
// NOTE(review): the return-type line and the declaration of `args` are not
// visible in this listing; confirm them against the original source.
1206MmaSpOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
1207 llvm::IRBuilderBase &builder) {
1208 auto thisOp = cast<NVVM::MmaSpOp>(op);
1209
1210 // Get operands
1212 for (mlir::Value v : thisOp.getOperands())
1213 args.push_back(mt.lookupValue(v));
1214
1215 // Get intrinsic ID using the existing getIntrinsicID method
1216 auto intId = MmaSpOp::getIntrinsicID(
1217 thisOp.getShape().getM(), thisOp.getShape().getN(),
1218 thisOp.getShape().getK(), thisOp.getIntOverflowBehavior(),
1219 thisOp.getOrderedMetadata(), thisOp.getKind(),
1220 *thisOp.getMultiplicandAPtxType(), *thisOp.getMultiplicandBPtxType(),
1221 thisOp.accumPtxType(), thisOp.resultPtxType());
1222
1223 return {intId, args};
1224}
1225
1226void MmaSpOp::print(OpAsmPrinter &p) {
1227 SmallVector<Type, 4> regTypes;
1228 struct MMAOperandFragment {
1229 StringRef operandName;
1230 StringRef ptxTypeAttr;
1231 SmallVector<Value, 4> regs;
1232 explicit MMAOperandFragment(StringRef name, StringRef ptxTypeName)
1233 : operandName(name), ptxTypeAttr(ptxTypeName) {}
1234 };
1235
1236 std::array<MMAOperandFragment, 5> frags{
1237 MMAOperandFragment("A", getMultiplicandAPtxTypeAttrName()),
1238 MMAOperandFragment("B", getMultiplicandBPtxTypeAttrName()),
1239 MMAOperandFragment("C", ""), MMAOperandFragment("sparseMetadata", ""),
1240 MMAOperandFragment("selector", "")};
1241 SmallVector<StringRef, 4> ignoreAttrNames{
1242 mlir::NVVM::MmaSpOp::getOperandSegmentSizeAttr()};
1243
1244 // Handle variadic operands A, B, C
1245 for (unsigned fragIdx = 0; fragIdx < 3; fragIdx++) {
1246 auto &frag = frags[fragIdx];
1247 auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
1248 for (auto operandIdx = varOperandSpec.first;
1249 operandIdx < varOperandSpec.first + varOperandSpec.second;
1250 operandIdx++) {
1251 frag.regs.push_back(this->getOperand(operandIdx));
1252 if (operandIdx == varOperandSpec.first) {
1253 regTypes.push_back(this->getOperand(operandIdx).getType());
1254 }
1255 }
1256 std::optional<MMATypes> inferredType = MmaOp::inferOperandMMAType(
1257 regTypes.back(), /*isAccumulator=*/fragIdx >= 2);
1258 if (inferredType)
1259 ignoreAttrNames.push_back(frag.ptxTypeAttr);
1260 }
1261
1262 // Handle sparse metadata and selector (single operands)
1263 frags[3].regs.push_back(getSparseMetadata());
1264 frags[4].regs.push_back(getSparsitySelector());
1265
1266 auto printMmaSpOperand = [&](const MMAOperandFragment &frag) -> void {
1267 p << " " << frag.operandName;
1268 p << "[";
1269 p.printOperands(frag.regs);
1270 p << "]";
1271 };
1272
1273 for (const auto &frag : frags)
1274 printMmaSpOperand(frag);
1275
1276 p.printOptionalAttrDict((*this)->getAttrs(), ignoreAttrNames);
1277 p << " : ";
1278 p << "(";
1279 for (int i = 0; i < 3; ++i) {
1280 p << regTypes[i];
1281 if (i < 2)
1282 p << ", ";
1283 }
1284 p << ") -> " << getResult().getType();
1285}
1286
1287void MmaSpOp::build(
1288 OpBuilder &builder, OperationState &result, Type resultType,
1289 ValueRange operandA, ValueRange operandB, ValueRange operandC,
1290 Value sparseMetadata, Value sparsitySelector, ArrayRef<int64_t> shape,
1291 std::optional<MMAIntOverflow> intOverflow,
1292 std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes) {
1293
1294 assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
1295 MLIRContext *ctx = builder.getContext();
1296 result.addAttribute(
1297 "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));
1298
1299 result.addOperands(operandA);
1300 result.addOperands(operandB);
1301 result.addOperands(operandC);
1302 result.addOperands(sparseMetadata);
1303 result.addOperands(sparsitySelector);
1304
1305 if (multiplicandPtxTypes) {
1306 result.addAttribute("multiplicandAPtxType",
1307 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
1308 result.addAttribute("multiplicandBPtxType",
1309 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
1310 } else {
1311 if (auto res = MmaOp::inferOperandMMAType(operandA[0].getType(), false))
1312 result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
1313 if (auto res = MmaOp::inferOperandMMAType(operandB[0].getType(), false))
1314 result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
1315 }
1316
1317 if (intOverflow.has_value())
1318 result.addAttribute("intOverflowBehavior",
1319 MMAIntOverflowAttr::get(ctx, *intOverflow));
1320
1321 result.addTypes(resultType);
1322 result.addAttribute(
1323 MmaSpOp::getOperandSegmentSizeAttr(),
1324 builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()),
1325 static_cast<int32_t>(operandB.size()),
1326 static_cast<int32_t>(operandC.size()), 1,
1327 1})); // sparseMetadata and sparsitySelector
1328}
1329
1330ParseResult MmaSpOp::parse(OpAsmParser &parser, OperationState &result) {
1331 struct MMAOperandFragment {
1332 std::optional<MMATypes> elemtype;
1333 SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
1334 SmallVector<Type> regTypes;
1335 };
1336
1337 Builder &builder = parser.getBuilder();
1338 std::array<MMAOperandFragment, 6> frags; // A, B, C, sparseMetadata, selector
1339
1340 NamedAttrList namedAttributes;
1341
1342 // A helper to parse the operand segments.
1343 auto parseMmaSpOperand = [&](StringRef operandName,
1344 MMAOperandFragment &frag) -> LogicalResult {
1345 if (parser.parseKeyword(operandName).failed())
1346 return failure();
1347 if (parser
1348 .parseOperandList(frag.regs, OpAsmParser::Delimiter::OptionalSquare)
1349 .failed())
1350 return failure();
1351 return success();
1352 };
1353
1354 // Parse the operand segments.
1355 if (parseMmaSpOperand("A", frags[0]).failed())
1356 return failure();
1357 if (parseMmaSpOperand("B", frags[1]).failed())
1358 return failure();
1359 if (parseMmaSpOperand("C", frags[2]).failed())
1360 return failure();
1361 if (parseMmaSpOperand("sparseMetadata", frags[3]).failed())
1362 return failure();
1363 if (parseMmaSpOperand("selector", frags[4]).failed())
1364 return failure();
1365
1366 if (parser.parseOptionalAttrDict(namedAttributes).failed())
1367 return failure();
1368
1369 // Parse the type specification and resolve operands.
1370 SmallVector<Type, 3> operandTypes;
1371 if (failed(parser.parseColon()))
1372 return failure();
1373 if (failed(parser.parseLParen()))
1374 return failure();
1375 if (failed(parser.parseTypeList(operandTypes)))
1376 return failure();
1377 if (failed(parser.parseRParen()))
1378 return failure();
1379 if (operandTypes.size() != 3)
1380 return parser.emitError(
1381 parser.getNameLoc(),
1382 "expected one type for each operand segment but got " +
1383 Twine(operandTypes.size()) + " types");
1384 for (const auto &iter : llvm::enumerate(operandTypes)) {
1385 auto &frag = frags[iter.index()];
1386 frag.regTypes.resize(frag.regs.size(), iter.value());
1387 if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
1388 parser.getNameLoc(), result.operands)))
1389 return failure();
1390 frag.elemtype =
1391 MmaOp::inferOperandMMAType(frag.regTypes[0],
1392 /*isAccumulator*/ iter.index() >= 2);
1393 }
1394
1395 Type resultType;
1396 if (parser.parseArrow() || parser.parseType(resultType))
1397 return failure();
1398 frags[5].elemtype =
1399 MmaOp::inferOperandMMAType(resultType, /*isAccumulator*/ true);
1400
1401 // Resolve sparse metadata and selector (assume i32 type)
1402 Type i32Type = builder.getIntegerType(32);
1403 if (parser
1404 .resolveOperands(frags[3].regs, i32Type, parser.getCurrentLocation(),
1405 result.operands)
1406 .failed())
1407 return failure();
1408 if (parser
1409 .resolveOperands(frags[4].regs, i32Type, parser.getCurrentLocation(),
1410 result.operands)
1411 .failed())
1412 return failure();
1413
1414 std::array<StringRef, 2> names{"multiplicandAPtxType",
1415 "multiplicandBPtxType"};
1416 for (unsigned idx = 0; idx < names.size(); idx++) {
1417 const auto &frag = frags[idx];
1418 std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
1419 if (!frag.elemtype.has_value() && !attr.has_value()) {
1420 return parser.emitError(
1421 parser.getNameLoc(),
1422 "attribute " + names[idx] +
1423 " is not provided explicitly and cannot be inferred");
1424 }
1425 if (!attr.has_value())
1426 result.addAttribute(
1427 names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));
1428 }
1429
1430 result.addTypes(resultType);
1431 if (!namedAttributes.empty())
1432 result.addAttributes(namedAttributes);
1433 result.addAttribute(MmaSpOp::getOperandSegmentSizeAttr(),
1434 builder.getDenseI32ArrayAttr({
1435 static_cast<int32_t>(frags[0].regs.size()),
1436 static_cast<int32_t>(frags[1].regs.size()),
1437 static_cast<int32_t>(frags[2].regs.size()),
1438 1, // sparseMetadata
1439 1 // sparsitySelector
1440 }));
1441 return success();
1442}
1443
1444LogicalResult MmaSpOp::verify() {
1445 MLIRContext *context = getContext();
1446 auto f16Ty = Float16Type::get(context);
1447 auto i32Ty = IntegerType::get(context, 32);
1448 auto f16x2Ty = VectorType::get(2, f16Ty);
1449 auto f32Ty = Float32Type::get(context);
1450 auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
1451 context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
1452
1453 auto s32x4StructTy =
1454 LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
1455 auto f32x8StructTy =
1456 LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
1457 auto f16x2x2StructTy =
1458 LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
1459 auto f32x4StructTy =
1460 LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
1461 auto s32x2StructTy =
1462 LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});
1463
1464 std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
1465 getShapeAttr().getK()};
1466
1467 // These variables define the set of allowed data types for matrices A, B, C,
1468 // and result.
1469 using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
1470 using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>;
1471 AllowedShapes allowedShapes;
1472 AllowedTypes expectedA;
1473 AllowedTypes expectedB;
1474 AllowedTypes expectedC;
1475 SmallVector<Type> expectedResult;
1476
1477 // When M = 16, we just need to calculate the number of 8xk tiles, where
1478 // k is a factor that depends on the data type.
1479 if (mmaShape[0] == 16) {
1480 int64_t kFactor;
1481 Type multiplicandFragType;
1482 switch (*getMultiplicandAPtxType()) {
1483 case MMATypes::tf32:
1484 kFactor = 4;
1485 multiplicandFragType = i32Ty;
1486 expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
1487 context, {f32Ty, f32Ty, f32Ty, f32Ty}));
1488 // Sparse MMA supports m16n8k8 and m16n8k16 for tf32
1489 allowedShapes.push_back({16, 8, 8});
1490 allowedShapes.push_back({16, 8, 16});
1491 break;
1492 case MMATypes::bf16:
1493 kFactor = 8;
1494 multiplicandFragType = i32Ty;
1495 expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
1496 context, {f32Ty, f32Ty, f32Ty, f32Ty}));
1497 // Sparse MMA supports m16n8k16 and m16n8k32 for bf16
1498 allowedShapes.push_back({16, 8, 16});
1499 allowedShapes.push_back({16, 8, 32});
1500 break;
1501 case MMATypes::f16:
1502 kFactor = 8;
1503 multiplicandFragType = f16x2Ty;
1504 expectedResult.push_back(f16x2x2StructTy);
1505 expectedResult.push_back(f32x4StructTy);
1506 // Sparse MMA supports m16n8k16 and m16n8k32 for f16
1507 allowedShapes.push_back({16, 8, 16});
1508 allowedShapes.push_back({16, 8, 32});
1509 break;
1510 case MMATypes::s4:
1511 case MMATypes::u4:
1512 kFactor = 32;
1513 // Sparse MMA supports m16n8k64 and m16n8k128 for s4/u4
1514 allowedShapes.push_back({16, 8, 64});
1515 allowedShapes.push_back({16, 8, 128});
1516 break;
1517 case MMATypes::s8:
1518 case MMATypes::u8:
1519 kFactor = 16;
1520 // Sparse MMA supports m16n8k32 and m16n8k64 for s8/u8
1521 allowedShapes.push_back({16, 8, 32});
1522 allowedShapes.push_back({16, 8, 64});
1523 break;
1524 case MMATypes::e4m3:
1525 case MMATypes::e5m2:
1526 case MMATypes::e3m2:
1527 case MMATypes::e2m3:
1528 case MMATypes::e2m1:
1529 kFactor = 16;
1530 multiplicandFragType = i32Ty;
1531 expectedResult.push_back(f16x2x2StructTy);
1532 expectedResult.push_back(f32x4StructTy);
1533 // Sparse MMA supports m16n8k64 for FP8 types
1534 allowedShapes.push_back({16, 8, 64});
1535 break;
1536 default:
1537 return emitError("invalid shape or multiplicand type: ")
1538 << getMultiplicandAPtxType().value();
1539 }
1540
1541 if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
1542 expectedResult.push_back(s32x4StructTy);
1543 expectedC.emplace_back(4, i32Ty);
1544 multiplicandFragType = i32Ty;
1545 } else if (*getMultiplicandAPtxType() >= MMATypes::e4m3 &&
1546 *getMultiplicandAPtxType() <= MMATypes::e2m1) {
1547 // FP8 types
1548 expectedC.emplace_back(2, f16x2Ty);
1549 expectedC.emplace_back(4, f32Ty);
1550 } else {
1551 expectedC.emplace_back(2, f16x2Ty);
1552 expectedC.emplace_back(4, f32Ty);
1553 }
1554
1555 // For sparse MMA, A operand is compressed (2:4 sparsity means half the
1556 // elements)
1557 int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor) / 2;
1558 int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
1559 expectedA.emplace_back(unitA, multiplicandFragType);
1560 expectedB.emplace_back(unitB, multiplicandFragType);
1561
1562 if (resultPtxType() != accumPtxType())
1563 return emitOpError("ctype does not match dtype");
1564 }
1565
1566 // In the M=8 case, there is only 1 possible case per data type.
1567 if (mmaShape[0] == 8) {
1568 if (*getMultiplicandAPtxType() == MMATypes::f16) {
1569 expectedA.emplace_back(2, f16x2Ty);
1570 expectedB.emplace_back(2, f16x2Ty);
1571 expectedResult.push_back(f16x2x4StructTy);
1572 expectedResult.push_back(f32x8StructTy);
1573 expectedC.emplace_back(4, f16x2Ty);
1574 expectedC.emplace_back(8, f32Ty);
1575 allowedShapes.push_back({8, 8, 4});
1576 }
1577 if (*getMultiplicandAPtxType() == MMATypes::f64) {
1578 Type f64Ty = Float64Type::get(context);
1579 expectedA.emplace_back(1, f64Ty);
1580 expectedB.emplace_back(1, f64Ty);
1581 expectedC.emplace_back(2, f64Ty);
1582 expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
1583 context, SmallVector<Type>(2, f64Ty)));
1584 allowedShapes.push_back({8, 8, 4});
1585 }
1586 if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
1587 expectedA.push_back({i32Ty});
1588 expectedB.push_back({i32Ty});
1589 expectedC.push_back({i32Ty, i32Ty});
1590 expectedResult.push_back(s32x2StructTy);
1591 if (isInt4PtxType(getMultiplicandAPtxType().value()))
1592 allowedShapes.push_back({8, 8, 32});
1593 if (isInt8PtxType(getMultiplicandAPtxType().value()))
1594 allowedShapes.push_back({8, 8, 16});
1595 }
1596 }
1597
1598 std::string errorMessage;
1599 llvm::raw_string_ostream errorStream(errorMessage);
1600
1601 // Check that we matched an existing shape/dtype combination.
1602 if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
1603 !llvm::is_contained(allowedShapes, mmaShape)) {
1604 errorStream << "unimplemented variant for MMA shape <";
1605 llvm::interleaveComma(mmaShape, errorStream);
1606 errorStream << ">";
1607 return emitOpError(errorMessage);
1608 }
1609
1610 // Verify the operand types for segments of A, B, and C operands.
1611 std::array<StringRef, 3> operandNames{"A", "B", "C"};
1612 for (const auto &iter : llvm::enumerate(
1613 SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) {
1614 auto spec = this->getODSOperandIndexAndLength(iter.index());
1615 SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first,
1616 operand_type_begin() + spec.first +
1617 spec.second);
1618 bool match = llvm::is_contained(iter.value(), operandTySeg);
1619
1620 if (!match) {
1621 errorStream << "Could not match types for the "
1622 << operandNames[iter.index()]
1623 << " operands; expected one of ";
1624 for (const auto &x : iter.value()) {
1625 errorStream << x.size() << "x" << x[0] << " ";
1626 }
1627 errorStream << "but got ";
1628 llvm::interleaveComma(operandTySeg, errorStream);
1629 return emitOpError(errorMessage);
1630 }
1631 }
1632
1633 // Check the result type
1634 if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
1635 return expectedResultType == getResult().getType();
1636 })) {
1637 errorStream
1638 << "Could not match allowed types for the result; expected one of ";
1639 llvm::interleaveComma(expectedResult, errorStream);
1640 errorStream << " but got " << getResult().getType();
1641 return emitOpError(errorMessage);
1642 }
1643
1644 // Ensure int4/int8 MMA variants specify the accum overflow behavior
1645 // attribute.
1646 if (isInt4PtxType(*getMultiplicandAPtxType()) ||
1647 isInt8PtxType(*getMultiplicandAPtxType())) {
1648 if (!getIntOverflowBehavior())
1649 return emitOpError("op requires " +
1650 getIntOverflowBehaviorAttrName().strref() +
1651 " attribute");
1652 }
1653
1654 // Validate sparse metadata type (should be i32)
1655 if (!getSparseMetadata().getType().isInteger(32)) {
1656 return emitOpError() << "sparse metadata must be i32 type";
1657 }
1658
1659 // Validate sparsity selector type (should be i32)
1660 if (!getSparsitySelector().getType().isInteger(32)) {
1661 return emitOpError() << "sparsity selector must be i32 type";
1662 }
1663
1664 return success();
1665}
1666
1667//===----------------------------------------------------------------------===//
1668// MMA Block Scale Operations - Shared Helpers
1669//===----------------------------------------------------------------------===//
1670
1671namespace {
1672// Shared structure for MMA operand fragments (A, B, C)
1673struct MMAOperandFragment {
1674 StringRef operandName;
1675 StringRef ptxTypeAttr;
1676 SmallVector<Value, 4> regs;
1677 explicit MMAOperandFragment(StringRef name, StringRef ptxTypeName)
1678 : operandName(name), ptxTypeAttr(ptxTypeName) {}
1679};
1680} // namespace
1681
1682// Helper to print operand list in the format: name[operands]
1683static void printOperandList(OpAsmPrinter &p, StringRef name,
1684 ArrayRef<Value> operands) {
1685 p << " " << name << "[";
1686 p.printOperands(operands);
1687 p << "]";
1688}
1689
1690// Helper to parse operand list in the format: name[operands]
//
// First requires the keyword `operandName`; returns failure if it is missing.
// Returns success only when the subsequent parse step also succeeds.
//
// NOTE(review): the last parameter and the parser call whose `.failed()` is
// tested are not visible in this listing. Presumably this is the same
// `parseOperandList(..., OpAsmParser::Delimiter::OptionalSquare)` call used by
// the lambda helpers in MmaOp::parse and MmaSpOp::parse; confirm against the
// original source.
1691static LogicalResult
1692parseMmaOperand(OpAsmParser &parser, StringRef operandName,
1694 if (parser.parseKeyword(operandName).failed())
1695 return failure();
1697 .failed())
1698 return failure();
1699 return success();
1700}
1701
1702// Helper to process operand fragments and determine which attributes can be
1703// inferred
1704template <typename Op>
1705static void
1706processOperandFragments(Op &op, std::array<MMAOperandFragment, 3> &frags,
1707 SmallVectorImpl<Type> &regTypes,
1708 SmallVectorImpl<StringRef> &ignoreAttrNames) {
1709 for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
1710 auto &frag = frags[fragIdx];
1711 auto varOperandSpec = op.getODSOperandIndexAndLength(fragIdx);
1712 for (auto operandIdx = varOperandSpec.first;
1713 operandIdx < varOperandSpec.first + varOperandSpec.second;
1714 operandIdx++) {
1715 frag.regs.push_back(op.getOperand(operandIdx));
1716 if (fragIdx == 0 && operandIdx == varOperandSpec.first) {
1717 regTypes.push_back(op.getOperand(operandIdx).getType());
1718 }
1719 }
1720 if (fragIdx < 2) {
1721 regTypes.push_back(frag.regs[0].getType());
1722 }
1723 std::optional<MMATypes> inferredType =
1724 MmaOp::inferOperandMMAType(regTypes.back(),
1725 /*isAccumulator=*/fragIdx >= 2);
1726 if (inferredType)
1727 ignoreAttrNames.push_back(frag.ptxTypeAttr);
1728 }
1729}
1730
1731// Helper to parse type signature: (A_type, B_type, C_type)
//
// Parses `:` `(`, then a comma-separated list of types that are appended to
// `operandTypes`, then `)`. Emits "expected exactly 3 types" at the current
// parser location when the resulting size of `operandTypes` is not 3; the
// count check runs before the closing parenthesis is parsed. Returns failure
// if any parse step fails.
//
// NOTE(review): the line carrying the function name and its leading
// parameter(s) is not visible in this listing; the body uses a `parser`
// object that is presumably declared there.
1732static LogicalResult
1734 SmallVectorImpl<Type> &operandTypes) {
1735 if (parser.parseColon().failed() || parser.parseLParen().failed())
1736 return failure();
1737
// Parses a single type and appends it to the output list.
1738 auto typeParser = [&]() {
1739 Type ty;
1740 if (parser.parseType(ty).failed())
1741 return failure();
1742 operandTypes.push_back(ty);
1743 return success();
1744 };
1745 if (parser.parseCommaSeparatedList(typeParser))
1746 return failure();
1747
1748 if (operandTypes.size() != 3)
1749 return parser.emitError(parser.getCurrentLocation(),
1750 "expected exactly 3 types");
1751
1752 return parser.parseRParen();
1753}
1754
1755// Helper to infer and set multiplicand PTX type attributes
//
// For "multiplicandAPtxType" and "multiplicandBPtxType" in turn: if `attrs`
// has no entry of that name, tries to infer the type from `operandTypes[0]`
// (for A) or `operandTypes[1]` (for B) with the accumulator flag off, and sets
// the attribute when inference succeeds. An existing entry is never
// overwritten, and nothing is set when inference fails.
//
// `operandTypes` is indexed at 0 and 1 without a size check here.
//
// NOTE(review): the line carrying the function name and the `ctx` / `attrs`
// parameters is not visible in this listing.
1756static void
1758 const SmallVectorImpl<Type> &operandTypes) {
1759 if (!attrs.get("multiplicandAPtxType")) {
1760 if (auto inferredType =
1761 MmaOp::inferOperandMMAType(operandTypes[0], false)) {
1762 attrs.set("multiplicandAPtxType", MMATypesAttr::get(ctx, *inferredType));
1763 }
1764 }
1765 if (!attrs.get("multiplicandBPtxType")) {
1766 if (auto inferredType =
1767 MmaOp::inferOperandMMAType(operandTypes[1], false)) {
1768 attrs.set("multiplicandBPtxType", MMATypesAttr::get(ctx, *inferredType));
1769 }
1770 }
1771}
1772
1773// Helper to add common block scale properties
//
// Obtains (creating if needed) the `OpType::Properties` stored in `result` and
// sets four of its members: the shape attribute built from `shape[0..2]`, the
// scale vector size, the block scale format, and the kind.
//
// `shape` is indexed at 0, 1 and 2 without a size check here.
//
// NOTE(review): the lines carrying the return type, function name and the
// leading parameters (`builder`, `result`, `shape`) are not visible in this
// listing.
1774template <typename OpType>
1777 ScaleVecSize scaleVecSize,
1778 BlockScaleFormat blockScaleFormat,
1779 MMABlockScaleKind kind) {
1780 MLIRContext *ctx = builder.getContext();
1781 auto &properties = result.getOrAddProperties<typename OpType::Properties>();
1782 properties.setShape(
1783 builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));
1784 properties.setScaleVecSize(ScaleVecSizeAttr::get(ctx, scaleVecSize));
1785 properties.setBlockScaleFormat(
1786 BlockScaleFormatAttr::get(ctx, blockScaleFormat));
1787 properties.setKind(MMABlockScaleKindAttr::get(ctx, kind));
1788}
1789
1790// Helper to infer and add multiplicand PTX types to builder
//
// When `multiplicandPtxTypes` is set, its two entries are added to `result`
// as the "multiplicandAPtxType" and "multiplicandBPtxType" attributes.
// Otherwise each attribute is added only if its type can be inferred (with
// the accumulator flag off) from the type of the first A / B operand.
//
// In the inference branch `operandA[0]` and `operandB[0]` are read without an
// emptiness check.
//
// NOTE(review): the lines carrying the return type, function name and the
// leading parameters (`ctx`, `result`, `operandA`) are not visible in this
// listing.
1793 ValueRange operandB,
1794 std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes) {
1795 if (multiplicandPtxTypes) {
1796 result.addAttribute("multiplicandAPtxType",
1797 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
1798 result.addAttribute("multiplicandBPtxType",
1799 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
1800 } else {
1801 if (auto res = MmaOp::inferOperandMMAType(operandA[0].getType(), false))
1802 result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
1803 if (auto res = MmaOp::inferOperandMMAType(operandB[0].getType(), false))
1804 result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
1805 }
1806}
1807
1808// Template helper for common accumPtxType/resultPtxType implementation
1809template <typename OpTy>
1810static MMATypes inferPtxTypeFromResult(OpTy op) {
1811 return *MmaOp::inferOperandMMAType(
1812 cast<LLVM::LLVMStructType>(op.getRes().getType()).getBody()[0],
1813 /*isAccumulator=*/true);
1814}
1815
1816//===----------------------------------------------------------------------===//
1817// MmaBlockScaleOp
1818//===----------------------------------------------------------------------===//
1819
1820void MmaBlockScaleOp::print(OpAsmPrinter &p) {
1821 SmallVector<Type, 4> regTypes;
1822 std::array<MMAOperandFragment, 3> frags{
1823 MMAOperandFragment("A", getMultiplicandAPtxTypeAttrName()),
1824 MMAOperandFragment("B", getMultiplicandBPtxTypeAttrName()),
1825 MMAOperandFragment("C", "")};
1826 SmallVector<StringRef, 4> ignoreAttrNames{
1827 mlir::NVVM::MmaBlockScaleOp::getOperandSegmentSizeAttr()};
1828
1829 processOperandFragments(*this, frags, regTypes, ignoreAttrNames);
1830
1831 // Print A, B, C operands
1832 for (const auto &frag : frags)
1833 printOperandList(p, frag.operandName, frag.regs);
1834
1835 // Print scale operands
1836 printOperandList(p, "scaleA",
1837 {getScaleAData(), getByteIdA(), getThreadIdA()});
1838 printOperandList(p, "scaleB",
1839 {getScaleBData(), getByteIdB(), getThreadIdB()});
1840
1841 p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);
1842
1843 // Print type signature
1844 p << " : (";
1845 llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
1846 frags[1].regs[0].getType(),
1847 frags[2].regs[0].getType()},
1848 p);
1849 p << ")";
1850 p.printArrowTypeList(TypeRange{this->getRes().getType()});
1851}
1852
1853ParseResult MmaBlockScaleOp::parse(OpAsmParser &parser,
1855 struct LocalOperandFragment {
1856 std::optional<MMATypes> elemtype;
1857 SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
1858 };
1859
1860 Builder &builder = parser.getBuilder();
1861 std::array<LocalOperandFragment, 3> frags;
1862 NamedAttrList namedAttributes;
1863
1864 // Parse A[...] B[...] C[...]
1865 if (parseMmaOperand(parser, "A", frags[0].regs).failed() ||
1866 parseMmaOperand(parser, "B", frags[1].regs).failed() ||
1867 parseMmaOperand(parser, "C", frags[2].regs).failed())
1868 return failure();
1869
1870 // Parse scale operands: scaleA[...] scaleB[...]
1871 SmallVector<OpAsmParser::UnresolvedOperand, 3> scaleAOperands, scaleBOperands;
1872 if (parseMmaOperand(parser, "scaleA", scaleAOperands).failed() ||
1873 parseMmaOperand(parser, "scaleB", scaleBOperands).failed())
1874 return failure();
1875
1876 if (parser.parseOptionalAttrDict(namedAttributes).failed())
1877 return failure();
1878
1879 // Parse type signature
1880 SmallVector<Type, 3> operandTypes;
1881 if (parseMmaTypeSignature(parser, operandTypes).failed())
1882 return failure();
1883
1884 // Parse result type
1885 SmallVector<Type, 1> resultTypes;
1886 if (parser.parseArrowTypeList(resultTypes).failed())
1887 return failure();
1888
1889 // Infer element types and resolve operands
1890 for (const auto &[idx, frag] : llvm::enumerate(frags)) {
1891 frag.elemtype = MmaOp::inferOperandMMAType(operandTypes[idx],
1892 /*isAccumulator=*/idx >= 2);
1893 if (parser
1894 .resolveOperands(frag.regs, operandTypes[idx], parser.getNameLoc(),
1895 result.operands)
1896 .failed())
1897 return failure();
1898 }
1899
1900 // Resolve scale operands
1901 SmallVector<Type, 3> scaleTypes = {builder.getI32Type(), builder.getI16Type(),
1902 builder.getI16Type()};
1903 if (parser
1904 .resolveOperands(scaleAOperands, scaleTypes, parser.getNameLoc(),
1905 result.operands)
1906 .failed() ||
1907 parser
1908 .resolveOperands(scaleBOperands, scaleTypes, parser.getNameLoc(),
1909 result.operands)
1910 .failed())
1911 return failure();
1912
1913 // Add attributes
1914 result.addAttributes(namedAttributes);
1915 inferAndSetMultiplicandTypes(parser.getContext(), result.attributes,
1916 operandTypes);
1917
1918 result.addTypes(resultTypes);
1919 result.addAttribute(MmaBlockScaleOp::getOperandSegmentSizeAttr(),
1920 builder.getDenseI32ArrayAttr({
1921 static_cast<int32_t>(frags[0].regs.size()),
1922 static_cast<int32_t>(frags[1].regs.size()),
1923 static_cast<int32_t>(frags[2].regs.size()),
1924 1, // scaleAData
1925 1, // byteIdA
1926 1, // threadIdA
1927 1, // scaleBData
1928 1, // byteIdB
1929 1 // threadIdB
1930 }));
1931 return success();
1932}
1933
// Builds the op into `result`.
//
// Operands are appended in the order A, B, C, followed by the six scale
// operands (scaleAData, byteIdA, threadIdA, scaleBData, byteIdB, threadIdB).
// `shape` must hold exactly three entries (m, n, k); this is asserted.
// `multiplicandPtxTypes`, when provided, supplies the A/B PTX types; otherwise
// they are inferred by addInferredMultiplicandTypes.
1934void MmaBlockScaleOp::build(
1935 OpBuilder &builder, OperationState &result, Type resultType,
1936 ValueRange operandA, ValueRange operandB, ValueRange operandC,
1937 Value scaleAData, Value byteIdA, Value threadIdA, Value scaleBData,
1938 Value byteIdB, Value threadIdB, ArrayRef<int64_t> shape,
1939 std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
1940 ScaleVecSize scaleVecSize, BlockScaleFormat blockScaleFormat,
1941 MMABlockScaleKind kind) {
1942 assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
1943
// NOTE(review): the first line of the call that the next line completes is
// absent here. Presumably it invokes the block-scale properties helper with
// `builder`, `result`, `shape` and `scaleVecSize` -- restore it from the
// original file and confirm.
1945 blockScaleFormat, kind);
1946
1947 result.addOperands(operandA);
1948 result.addOperands(operandB);
1949 result.addOperands(operandC);
1950 result.addOperands(
1951 {scaleAData, byteIdA, threadIdA, scaleBData, byteIdB, threadIdB});
1952
1953 addInferredMultiplicandTypes(builder.getContext(), result, operandA, operandB,
1954 multiplicandPtxTypes);
1955
1956 result.addTypes(resultType);
// Segment sizes follow the operand order above: the three variadic fragments,
// then one entry per scale operand.
1957 result.addAttribute(MmaBlockScaleOp::getOperandSegmentSizeAttr(),
1958 builder.getDenseI32ArrayAttr({
1959 static_cast<int32_t>(operandA.size()),
1960 static_cast<int32_t>(operandB.size()),
1961 static_cast<int32_t>(operandC.size()),
1962 1, // scaleAData
1963 1, // byteIdA
1964 1, // threadIdA
1965 1, // scaleBData
1966 1, // byteIdB
1967 1 // threadIdB
1968 }));
1969}
1970
1971NVVM::IDArgPair MmaBlockScaleOp::getIntrinsicIDAndArgs(
1972 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1973 auto curOp = cast<NVVM::MmaBlockScaleOp>(op);
1974
1976 // Add A, B, C operands
1977 for (Value operand : curOp.getOperandA())
1978 args.push_back(mt.lookupValue(operand));
1979 for (Value operand : curOp.getOperandB())
1980 args.push_back(mt.lookupValue(operand));
1981 for (Value operand : curOp.getOperandC())
1982 args.push_back(mt.lookupValue(operand));
1983
1984 // Add scale operands
1985 args.push_back(mt.lookupValue(curOp.getScaleAData()));
1986 args.push_back(mt.lookupValue(curOp.getByteIdA()));
1987 args.push_back(mt.lookupValue(curOp.getThreadIdA()));
1988 args.push_back(mt.lookupValue(curOp.getScaleBData()));
1989 args.push_back(mt.lookupValue(curOp.getByteIdB()));
1990 args.push_back(mt.lookupValue(curOp.getThreadIdB()));
1991
1992 unsigned intId = MmaBlockScaleOp::getIntrinsicID(
1993 curOp.getShape().getM(), curOp.getShape().getN(), curOp.getShape().getK(),
1994 *curOp.getMultiplicandAPtxType(), *curOp.getMultiplicandBPtxType(),
1995 inferPtxTypeFromResult(curOp), curOp.getScaleVecSize(),
1996 curOp.getBlockScaleFormat(), curOp.getKind());
1997
1998 return {intId, args};
1999}
2000
2001LogicalResult MmaBlockScaleOp::verify() {
2002 LogicalResult result = success();
2003 int m = getShape().getM();
2004 int n = getShape().getN();
2005 int k = getShape().getK();
2006
2007 if (m == 16 && n == 8 && k == 64) {
2008 if (getMultiplicandAPtxType() != NVVM::MMATypes::e2m1 ||
2009 getMultiplicandBPtxType() != NVVM::MMATypes::e2m1)
2011 "unsupported MMATypes attribute for mma.m16n8k64.(mxf4nvf4|mxf4)");
2012 if (getKind() == NVVM::MMABlockScaleKind::MXF4) {
2013 if (getScaleVecSize() != NVVM::ScaleVecSize::X2)
2015 "unsupported ScaleVecSize attribute for mma.m16n8k64.mxf4");
2016 if (getBlockScaleFormat() != NVVM::BlockScaleFormat::UE8M0)
2018 "unsupported BlockScaleFormat attribute for mma.m16n8k64.mxf4");
2019 } else if (getKind() == NVVM::MMABlockScaleKind::MXF4NVF4) {
2020 if (!((getScaleVecSize() == NVVM::ScaleVecSize::X2 &&
2021 getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0) ||
2022 (getScaleVecSize() == NVVM::ScaleVecSize::X4 &&
2023 (getBlockScaleFormat() == NVVM::BlockScaleFormat::UE4M3 ||
2024 getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0))))
2025 result = emitOpError("unsupported ScaleVecSize and BlockScaleFormat "
2026 "attributes for mma.m16n8k64.mxf4nvf4");
2027 } else {
2028 result = emitOpError("unsupported Kind attribute for mma.m16n8k64");
2029 }
2030 } else if (m == 16 && n == 8 && k == 32) {
2031 if (!(getKind() == NVVM::MMABlockScaleKind::MXF8F6F4 &&
2032 getScaleVecSize() == NVVM::ScaleVecSize::X1 &&
2033 getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0))
2034 result =
2035 emitOpError("unsupported Kind, ScaleVecSize and BlockScaleFormat "
2036 "attributes for mma.m16n8k32");
2037 } else {
2038 result = emitOpError("unsupported Geom for mma with block scaling");
2039 }
2040 return result;
2041}
2042
2043//===----------------------------------------------------------------------===//
2044// MmaSpBlockScaleOp
2045//===----------------------------------------------------------------------===//
2046
2047void MmaSpBlockScaleOp::print(OpAsmPrinter &p) {
2048 SmallVector<Type, 4> regTypes;
2049 std::array<MMAOperandFragment, 3> frags{
2050 MMAOperandFragment("A", getMultiplicandAPtxTypeAttrName()),
2051 MMAOperandFragment("B", getMultiplicandBPtxTypeAttrName()),
2052 MMAOperandFragment("C", "")};
2053 SmallVector<StringRef, 4> ignoreAttrNames{
2054 mlir::NVVM::MmaSpBlockScaleOp::getOperandSegmentSizeAttr()};
2055
2056 processOperandFragments(*this, frags, regTypes, ignoreAttrNames);
2057
2058 // Print A, B, C operands
2059 for (const auto &frag : frags)
2060 printOperandList(p, frag.operandName, frag.regs);
2061
2062 // Print sparse-specific operands
2063 printOperandList(p, "sparseMetadata", {getSparseMetadata()});
2064 printOperandList(p, "selector", {getSparsitySelector()});
2065
2066 // Print scale operands
2067 printOperandList(p, "scaleA",
2068 {getScaleAData(), getByteIdA(), getThreadIdA()});
2069 printOperandList(p, "scaleB",
2070 {getScaleBData(), getByteIdB(), getThreadIdB()});
2071
2072 p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);
2073
2074 // Print type signature
2075 p << " : (";
2076 llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
2077 frags[1].regs[0].getType(),
2078 frags[2].regs[0].getType()},
2079 p);
2080 p << ")";
2081 p.printArrowTypeList(TypeRange{this->getRes().getType()});
2082}
2083
2084ParseResult MmaSpBlockScaleOp::parse(OpAsmParser &parser,
2086 struct LocalOperandFragment {
2087 std::optional<MMATypes> elemtype;
2088 SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
2089 };
2090
2091 Builder &builder = parser.getBuilder();
2092 std::array<LocalOperandFragment, 3> frags;
2093 NamedAttrList namedAttributes;
2094
2095 // Parse A[...] B[...] C[...]
2096 if (parseMmaOperand(parser, "A", frags[0].regs).failed() ||
2097 parseMmaOperand(parser, "B", frags[1].regs).failed() ||
2098 parseMmaOperand(parser, "C", frags[2].regs).failed())
2099 return failure();
2100
2101 // Parse sparse-specific operands
2103 selectorOperands;
2104 if (parseMmaOperand(parser, "sparseMetadata", metadataOperands).failed() ||
2105 parseMmaOperand(parser, "selector", selectorOperands).failed())
2106 return failure();
2107
2108 // Parse scale operands
2109 SmallVector<OpAsmParser::UnresolvedOperand, 3> scaleAOperands, scaleBOperands;
2110 if (parseMmaOperand(parser, "scaleA", scaleAOperands).failed() ||
2111 parseMmaOperand(parser, "scaleB", scaleBOperands).failed())
2112 return failure();
2113
2114 if (parser.parseOptionalAttrDict(namedAttributes).failed())
2115 return failure();
2116
2117 // Parse type signature
2118 SmallVector<Type, 3> operandTypes;
2119 if (parseMmaTypeSignature(parser, operandTypes).failed())
2120 return failure();
2121
2122 // Parse result type
2123 SmallVector<Type, 1> resultTypes;
2124 if (parser.parseArrowTypeList(resultTypes).failed())
2125 return failure();
2126
2127 // Infer element types and resolve operands
2128 for (const auto &[idx, frag] : llvm::enumerate(frags)) {
2129 frag.elemtype = MmaOp::inferOperandMMAType(operandTypes[idx],
2130 /*isAccumulator=*/idx >= 2);
2131 if (parser
2132 .resolveOperands(frag.regs, operandTypes[idx], parser.getNameLoc(),
2133 result.operands)
2134 .failed())
2135 return failure();
2136 }
2137
2138 // Resolve sparse metadata and selector
2139 Type i32Type = builder.getI32Type();
2140 if (parser
2141 .resolveOperands(metadataOperands, i32Type, parser.getNameLoc(),
2142 result.operands)
2143 .failed() ||
2144 parser
2145 .resolveOperands(selectorOperands, i32Type, parser.getNameLoc(),
2146 result.operands)
2147 .failed())
2148 return failure();
2149
2150 // Resolve scale operands
2151 SmallVector<Type, 3> scaleTypes = {i32Type, builder.getI16Type(),
2152 builder.getI16Type()};
2153 if (parser
2154 .resolveOperands(scaleAOperands, scaleTypes, parser.getNameLoc(),
2155 result.operands)
2156 .failed() ||
2157 parser
2158 .resolveOperands(scaleBOperands, scaleTypes, parser.getNameLoc(),
2159 result.operands)
2160 .failed())
2161 return failure();
2162
2163 // Add attributes
2164 result.addAttributes(namedAttributes);
2165 inferAndSetMultiplicandTypes(parser.getContext(), result.attributes,
2166 operandTypes);
2167
2168 // orderedMetadata is mandatory
2169 if (!result.attributes.get("orderedMetadata"))
2170 result.addAttribute("orderedMetadata", builder.getUnitAttr());
2171
2172 result.addTypes(resultTypes);
2173 result.addAttribute(MmaSpBlockScaleOp::getOperandSegmentSizeAttr(),
2174 builder.getDenseI32ArrayAttr({
2175 static_cast<int32_t>(frags[0].regs.size()),
2176 static_cast<int32_t>(frags[1].regs.size()),
2177 static_cast<int32_t>(frags[2].regs.size()),
2178 1, // sparseMetadata
2179 1, // sparsitySelector
2180 1, // scaleAData
2181 1, // byteIdA
2182 1, // threadIdA
2183 1, // scaleBData
2184 1, // byteIdB
2185 1 // threadIdB
2186 }));
2187 return success();
2188}
2189
// Builds the sparse op into `result`.
//
// Operands are appended in the order A, B, C, followed by sparseMetadata,
// sparsitySelector and the six scale operands (scaleAData, byteIdA,
// threadIdA, scaleBData, byteIdB, threadIdB). `shape` must hold exactly three
// entries (m, n, k); this is asserted. The `orderedMetadata` unit attribute is
// always added. `multiplicandPtxTypes`, when provided, supplies the A/B PTX
// types; otherwise they are inferred by addInferredMultiplicandTypes.
2190void MmaSpBlockScaleOp::build(
2191 OpBuilder &builder, OperationState &result, Type resultType,
2192 ValueRange operandA, ValueRange operandB, ValueRange operandC,
2193 Value sparseMetadata, Value sparsitySelector, Value scaleAData,
2194 Value byteIdA, Value threadIdA, Value scaleBData, Value byteIdB,
2195 Value threadIdB, ArrayRef<int64_t> shape,
2196 std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
2197 ScaleVecSize scaleVecSize, BlockScaleFormat blockScaleFormat,
2198 MMABlockScaleKind kind) {
2199 assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
2200
// NOTE(review): the first line of the call that the next line completes is
// absent here. Presumably it names the block-scale properties helper
// instantiated for this op -- restore it from the original file and confirm.
2202 builder, result, shape, scaleVecSize, blockScaleFormat, kind);
2203 result.addAttribute("orderedMetadata", builder.getUnitAttr());
2204
2205 result.addOperands(operandA);
2206 result.addOperands(operandB);
2207 result.addOperands(operandC);
2208 result.addOperands({sparseMetadata, sparsitySelector, scaleAData, byteIdA,
2209 threadIdA, scaleBData, byteIdB, threadIdB});
2210
2211 addInferredMultiplicandTypes(builder.getContext(), result, operandA, operandB,
2212 multiplicandPtxTypes);
2213
2214 result.addTypes(resultType);
// Segment sizes follow the operand order above: the three variadic fragments,
// then one entry per single-value operand.
2215 result.addAttribute(MmaSpBlockScaleOp::getOperandSegmentSizeAttr(),
2216 builder.getDenseI32ArrayAttr({
2217 static_cast<int32_t>(operandA.size()),
2218 static_cast<int32_t>(operandB.size()),
2219 static_cast<int32_t>(operandC.size()),
2220 1, // sparseMetadata
2221 1, // sparsitySelector
2222 1, // scaleAData
2223 1, // byteIdA
2224 1, // threadIdA
2225 1, // scaleBData
2226 1, // byteIdB
2227 1 // threadIdB
2228 }));
2229}
2230
2231NVVM::IDArgPair MmaSpBlockScaleOp::getIntrinsicIDAndArgs(
2232 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
2233 auto curOp = cast<NVVM::MmaSpBlockScaleOp>(op);
2234
2236 // Add A, B, C operands
2237 for (Value operand : curOp.getOperandA())
2238 args.push_back(mt.lookupValue(operand));
2239 for (Value operand : curOp.getOperandB())
2240 args.push_back(mt.lookupValue(operand));
2241 for (Value operand : curOp.getOperandC())
2242 args.push_back(mt.lookupValue(operand));
2243
2244 // Add sparse metadata and selector
2245 args.push_back(mt.lookupValue(curOp.getSparseMetadata()));
2246 args.push_back(mt.lookupValue(curOp.getSparsitySelector()));
2247
2248 // Add scale operands
2249 args.push_back(mt.lookupValue(curOp.getScaleAData()));
2250 args.push_back(mt.lookupValue(curOp.getByteIdA()));
2251 args.push_back(mt.lookupValue(curOp.getThreadIdA()));
2252 args.push_back(mt.lookupValue(curOp.getScaleBData()));
2253 args.push_back(mt.lookupValue(curOp.getByteIdB()));
2254 args.push_back(mt.lookupValue(curOp.getThreadIdB()));
2255
2256 unsigned intId = MmaSpBlockScaleOp::getIntrinsicID(
2257 curOp.getShape().getM(), curOp.getShape().getN(), curOp.getShape().getK(),
2258 *curOp.getMultiplicandAPtxType(), *curOp.getMultiplicandBPtxType(),
2259 inferPtxTypeFromResult(curOp), curOp.getScaleVecSize(),
2260 curOp.getBlockScaleFormat(), curOp.getKind());
2261
2262 return {intId, args};
2263}
2264
2265LogicalResult MmaSpBlockScaleOp::verify() {
2266 // Check that orderedMetadata is present
2267 if (!getOrderedMetadata()) {
2268 return emitOpError("'orderedMetadata' attribute is mandatory");
2269 }
2270
2271 LogicalResult result = success();
2272 int m = getShape().getM();
2273 int n = getShape().getN();
2274 int k = getShape().getK();
2275
2276 if (m == 16 && n == 8 && k == 128) {
2277 if (getMultiplicandAPtxType() != NVVM::MMATypes::e2m1 ||
2278 getMultiplicandBPtxType() != NVVM::MMATypes::e2m1)
2280 "unsupported MMATypes attribute for mma.m16n8k128.(mxf4nvf4|mxf4)");
2281 if (getKind() == NVVM::MMABlockScaleKind::MXF4) {
2282 if (getScaleVecSize() != NVVM::ScaleVecSize::X2)
2284 "unsupported ScaleVecSize attribute for mma.m16n8k128.mxf4");
2285 if (getBlockScaleFormat() != NVVM::BlockScaleFormat::UE8M0)
2287 "unsupported BlockScaleFormat attribute for mma.m16n8k128.mxf4");
2288 } else if (getKind() == NVVM::MMABlockScaleKind::MXF4NVF4) {
2289 if (!((getScaleVecSize() == NVVM::ScaleVecSize::X2 &&
2290 getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0) ||
2291 (getScaleVecSize() == NVVM::ScaleVecSize::X4 &&
2292 (getBlockScaleFormat() == NVVM::BlockScaleFormat::UE4M3 ||
2293 getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0))))
2294 result = emitOpError("unsupported ScaleVecSize and BlockScaleFormat "
2295 "attributes for mma.m16n8k128.mxf4nvf4");
2296 } else {
2297 result = emitOpError("unsupported Kind attribute for mma.m16n8k128");
2298 }
2299 } else if (m == 16 && n == 8 && k == 64) {
2300 if (!(getKind() == NVVM::MMABlockScaleKind::MXF8F6F4 &&
2301 getScaleVecSize() == NVVM::ScaleVecSize::X1 &&
2302 getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0))
2303 result =
2304 emitOpError("unsupported Kind, ScaleVecSize and BlockScaleFormat "
2305 "attributes for mma.m16n8k64");
2306 } else {
2307 result = emitOpError("unsupported Geom for sparse mma with block scaling");
2308 }
2309 return result;
2310}
2311
2312LogicalResult ShflOp::verify() {
2313 auto returnStructType = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
2314
2315 auto verifyTypeError = [&](Twine desc, Type expectedType,
2316 Type actualType) -> LogicalResult {
2317 return emitOpError("expected " + desc + " to be of type ")
2318 << expectedType << " but got " << actualType << " instead";
2319 };
2320
2321 if (returnStructType) {
2322 if (!getReturnValueAndIsValid())
2323 return emitOpError("\"return_value_and_is_valid\" attribute must be "
2324 "specified when the return type is a struct type");
2325
2326 if (returnStructType.getBody().size() != 2)
2327 return emitOpError("expected return type to be a two-element struct");
2328
2329 llvm::ArrayRef<Type> returnStruct = returnStructType.getBody();
2330 auto resultType = returnStruct[0];
2331 if (resultType != getVal().getType())
2332 return verifyTypeError("first element in the returned struct",
2333 getVal().getType(), resultType);
2334
2335 auto predicateType = returnStruct[1];
2336 if (!predicateType.isInteger(1))
2337 return verifyTypeError("second element in the returned struct",
2338 mlir::IntegerType::get(getContext(), 1),
2339 predicateType);
2340 } else {
2341 if (getReturnValueAndIsValid())
2342 return emitOpError("expected return type to be a two-element struct");
2343
2344 if (getType() != getVal().getType())
2345 return verifyTypeError("return type", getVal().getType(), getType());
2346 }
2347 return success();
2348}
2349
2350LogicalResult
2351ShflOp::inferReturnTypes(MLIRContext *context, std::optional<Location> location,
2352 ShflOp::Adaptor adaptor,
2353 SmallVectorImpl<Type> &inferredReturnTypes) {
2354 Type valType = adaptor.getVal().getType();
2355 if (adaptor.getReturnValueAndIsValid())
2356 inferredReturnTypes.push_back(LLVM::LLVMStructType::getLiteral(
2357 context, {valType, IntegerType::get(context, 1)}));
2358 else
2359 inferredReturnTypes.push_back(valType);
2360 return success();
2361}
2362
2363std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
2364 NVVM::MMAFrag frag, int nRow,
2365 int nCol,
2366 MLIRContext *context) {
2367 unsigned numberElements = 0;
2368 Type elementType;
2369 OpBuilder builder(context);
2370 Type f16x2 = VectorType::get(2, builder.getF16Type());
2371 if (type == NVVM::MMATypes::f16) {
2372 elementType = f16x2;
2373 if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
2374 numberElements = 8;
2375 else
2376 numberElements = 4;
2377 } else if (type == NVVM::MMATypes::f32) {
2378 elementType = builder.getF32Type();
2379 numberElements = 8;
2380 } else if (type == NVVM::MMATypes::f64) {
2381 elementType = builder.getF64Type();
2382 if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
2383 numberElements = 1;
2384 else
2385 numberElements = 2;
2386 } else if (type == NVVM::MMATypes::tf32) {
2387 elementType = builder.getI32Type();
2388 numberElements = 4;
2389 } else if (type == NVVM::MMATypes::s8 || type == NVVM::MMATypes::u8) {
2390 elementType = builder.getI32Type();
2391 int parallelSize = 0;
2392 if (frag == NVVM::MMAFrag::a)
2393 parallelSize = nRow;
2394 if (frag == NVVM::MMAFrag::b)
2395 parallelSize = nCol;
2396
2397 // m == 16 && n == 16 && k == 16
2398 if (parallelSize == 16)
2399 numberElements = 2;
2400 // m == 8 && n == 32 && k == 16 or m == 32 && n == 8 && k == 16
2401 else if (parallelSize == 8)
2402 numberElements = 1;
2403 else if (parallelSize == 32)
2404 numberElements = 4;
2405 } else if (type == NVVM::MMATypes::s32) {
2406 elementType = builder.getI32Type();
2407 numberElements = 8;
2408 }
2409 assert(numberElements != 0 && elementType != nullptr);
2410 return std::make_pair(elementType, numberElements);
2411}
2412
2413static std::pair<mlir::Type, unsigned>
2414inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n,
2415 int k, MLIRContext *context) {
2416 int nRow, nCol;
2417 if (frag == NVVM::MMAFrag::a) {
2418 nRow = m;
2419 nCol = k;
2420 } else if (frag == NVVM::MMAFrag::b) {
2421 nRow = k;
2422 nCol = n;
2423 } else {
2424 nRow = m;
2425 nCol = n;
2426 }
2427 assert(nRow && nCol);
2428 return inferMMAType(type, frag, nRow, nCol, context);
2429}
2430
2431LogicalResult NVVM::WMMALoadOp::verify() {
2432 unsigned addressSpace =
2433 llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
2434 if (addressSpace != 0 && addressSpace != NVVMMemorySpace::Global &&
2435 addressSpace != NVVMMemorySpace::Shared)
2436 return emitOpError("expected source pointer in memory "
2437 "space 0, 1, 3");
2438
2439 if (NVVM::WMMALoadOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
2440 getEltype(), getFrag()) == 0)
2441 return emitOpError() << "invalid attribute combination";
2442 std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
2443 getEltype(), getFrag(), getM(), getN(), getK(), getContext());
2444 // Special case for f64 fragments
2445 Type f64Ty = Float64Type::get(getContext());
2446 if (typeInfo.first == f64Ty && typeInfo.second == 1) {
2447 if (getType() != f64Ty)
2448 return emitOpError("expected destination type to be f64");
2449 return success();
2450 }
2451 // Everything else is a struct
2452 Type dstType = LLVM::LLVMStructType::getLiteral(
2453 getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
2454 if (getType() != dstType)
2455 return emitOpError("expected destination type is a structure of ")
2456 << typeInfo.second << " elements of type " << typeInfo.first;
2457 return success();
2458}
2459
2460LogicalResult NVVM::WMMAStoreOp::verify() {
2461 unsigned addressSpace =
2462 llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
2463 if (addressSpace != 0 && addressSpace != NVVMMemorySpace::Global &&
2464 addressSpace != NVVMMemorySpace::Shared)
2465 return emitOpError("expected operands to be a source pointer in memory "
2466 "space 0, 1, 3");
2467
2468 if (NVVM::WMMAStoreOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
2469 getEltype()) == 0)
2470 return emitOpError() << "invalid attribute combination";
2471 std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
2472 getEltype(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
2473 if (getArgs().size() != typeInfo.second)
2474 return emitOpError() << "expected " << typeInfo.second << " data operands";
2475 if (llvm::any_of(getArgs(), [&typeInfo](Value operands) {
2476 return operands.getType() != typeInfo.first;
2477 }))
2478 return emitOpError() << "expected data operands of type " << typeInfo.first;
2479 return success();
2480}
2481
2482LogicalResult NVVM::WMMAMmaOp::verify() {
2483 if (NVVM::WMMAMmaOp::getIntrinsicID(getM(), getN(), getK(), getLayoutA(),
2484 getLayoutB(), getEltypeA(),
2485 getEltypeB()) == 0)
2486 return emitOpError() << "invalid attribute combination";
2487 std::pair<Type, unsigned> typeInfoA = inferMMATypeFromMNK(
2488 getEltypeA(), NVVM::MMAFrag::a, getM(), getN(), getK(), getContext());
2489 std::pair<Type, unsigned> typeInfoB = inferMMATypeFromMNK(
2490 getEltypeA(), NVVM::MMAFrag::b, getM(), getN(), getK(), getContext());
2491 std::pair<Type, unsigned> typeInfoC = inferMMATypeFromMNK(
2492 getEltypeB(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
2493 SmallVector<Type, 32> arguments;
2494 arguments.append(typeInfoA.second, typeInfoA.first);
2495 arguments.append(typeInfoB.second, typeInfoB.first);
2496 arguments.append(typeInfoC.second, typeInfoC.first);
2497 unsigned numArgs = arguments.size();
2498 if (getArgs().size() != numArgs)
2499 return emitOpError() << "expected " << numArgs << " arguments";
2500 for (unsigned i = 0; i < numArgs; i++) {
2501 if (getArgs()[i].getType() != arguments[i])
2502 return emitOpError() << "expected argument " << i << " to be of type "
2503 << arguments[i];
2504 }
2505 Type dstType = LLVM::LLVMStructType::getLiteral(
2506 getContext(), SmallVector<Type, 8>(typeInfoC.second, typeInfoC.first));
2507 if (getType() != dstType)
2508 return emitOpError("expected destination type is a structure of ")
2509 << typeInfoC.second << " elements of type " << typeInfoC.first;
2510 return success();
2511}
2512
2513LogicalResult NVVM::LdMatrixOp::verify() {
2514 uint32_t num = getNum(), m = getShape().getM(), n = getShape().getN();
2515 if (m == 8 && n == 8) {
2516 if (num != 1 && num != 2 && num != 4) {
2517 return emitOpError("expected num attribute to be 1, 2 or 4 for 8x8 "
2518 "matrix");
2519 }
2520 if (getEltType() != LdStMatrixEltType::B16) {
2521 return emitOpError("expected element type to be b16 for 8x8 matrix");
2522 }
2523 } else if (m == 8 && n == 16) {
2524 if (num != 1 && num != 2 && num != 4) {
2525 return emitOpError("expected num attribute to be 1, 2 or 4 for 8x16 "
2526 "matrix");
2527 }
2528 if (getLayout() != MMALayout::row) {
2529 return emitOpError("expected layout to be row for 8x16 matrix");
2530 }
2531 if (getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 &&
2532 getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) {
2533 return emitOpError("expected element type to be b8x16.b4x16_p64 or "
2534 "b8x16.b6x16_p32 for 8x16 matrix");
2535 }
2536 } else if (m == 16 && n == 16) {
2537 if (num != 1 && num != 2) {
2538 return emitOpError("expected num attribute to be 1 or 2 for 16x16 "
2539 "matrix");
2540 }
2541 if (getLayout() != MMALayout::col) {
2542 return emitOpError("expected layout to be col for 16x16 matrix");
2543 }
2544 if (getEltType() != LdStMatrixEltType::B8 &&
2545 getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 &&
2546 getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) {
2547 return emitOpError("expected element type to be b8, b8x16.b4x16_p64 or "
2548 "b8x16.b6x16_p32 for 16x16 matrix");
2549 }
2550 } else {
2551 return emitOpError("expected shape to be 8x8, 8x16 or 16x16");
2552 }
2553
2554 Type i32 = IntegerType::get(getContext(), 32);
2555 uint32_t numElements = (m == 16 && n == 16 ? num * 2 : num);
2556 if (numElements == 1 && getType() != i32)
2557 return emitOpError("expected destination type is i32");
2558 if (numElements == 2 || numElements == 4) {
2559 Type dstType = LLVM::LLVMStructType::getLiteral(
2560 getContext(), SmallVector<Type>(numElements, i32));
2561 if (getType() != dstType)
2562 return emitOpError("expected destination type is a structure of ")
2563 << numElements << " elements of type i32";
2564 }
2565
2566 return success();
2567}
2568
2569LogicalResult LdMatrixOp::inferReturnTypes(
2570 MLIRContext *context, std::optional<Location> location,
2571 LdMatrixOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
2572 uint32_t num = adaptor.getNum();
2573 uint32_t m = adaptor.getShape().getM();
2574 uint32_t n = adaptor.getShape().getN();
2575 uint32_t numElements = (m == 16 && n == 16) ? num * 2 : num;
2576
2577 Type i32 = IntegerType::get(context, 32);
2578 if (numElements == 1)
2579 inferredReturnTypes.push_back(i32);
2580 else
2581 inferredReturnTypes.push_back(LLVM::LLVMStructType::getLiteral(
2582 context, SmallVector<Type>(numElements, i32)));
2583 return success();
2584}
2585
2586LogicalResult NVVM::StMatrixOp::verify() {
2587 int numMatrix = getSources().size();
2588 if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
2589 return emitOpError("expected num attribute to be 1, 2 or 4");
2590
2591 int m = getShape().getM(), n = getShape().getN();
2592 if (m == 8 && n == 8) {
2593 if (getEltType() != NVVM::LdStMatrixEltType::B16) {
2594 return emitOpError("expected element type to be B16 for 8x8 matrix");
2595 }
2596 } else if (m == 16 && n == 8) {
2597 if (getEltType() != NVVM::LdStMatrixEltType::B8) {
2598 return emitOpError("expected element type to be B8 for 16x8 matrix");
2599 }
2600 if (getLayout() != NVVM::MMALayout::col) {
2601 return emitOpError("expected layout to be col for 16x8 matrix");
2602 }
2603 } else {
2604 return emitOpError("expected shape to be 8x8 or 16x8");
2605 }
2606
2607 return success();
2608}
2609
2610LogicalResult NVVM::MovMatrixOp::verify() {
2611 int m = getShape().getM(), n = getShape().getN();
2612 if (m != 8 || n != 8)
2613 return emitOpError("expected shape to be 8x8");
2614 if (getLayout() != NVVM::MMALayout::col)
2615 return emitOpError("expected layout to be col");
2616 if (getEltType() != NVVM::LdStMatrixEltType::B16)
2617 return emitOpError("expected element type to be b16");
2618 return success();
2619}
2620
2621static FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) {
2622 if (typeA == NVVM::WGMMATypes::tf32)
2623 return 8;
2624 if (typeA == NVVM::WGMMATypes::f16 || typeA == NVVM::WGMMATypes::bf16)
2625 return 16;
2626 if (typeA == NVVM::WGMMATypes::s8 || typeA == NVVM::WGMMATypes::u8)
2627 return 32;
2628 if (typeA == NVVM::WGMMATypes::e4m3 || typeA == NVVM::WGMMATypes::e5m2)
2629 return 32;
2630 if (typeA == NVVM::WGMMATypes::b1)
2631 return 256;
2632 return failure();
2633}
2634
2635static LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD,
2636 NVVM::WGMMATypes typeA,
2637 NVVM::WGMMATypes typeB) {
2638 switch (typeA) {
2639 case NVVM::WGMMATypes::f16:
2640 if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
2641 typeB == NVVM::WGMMATypes::f16)
2642 return success();
2643 break;
2644 case NVVM::WGMMATypes::tf32:
2645 if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32)
2646 return success();
2647 break;
2648 case NVVM::WGMMATypes::u8:
2649 case NVVM::WGMMATypes::s8:
2650 if (typeD == NVVM::WGMMATypes::s32 &&
2651 (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8))
2652 return success();
2653 break;
2654 case NVVM::WGMMATypes::b1:
2655 if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1)
2656 return success();
2657 break;
2658 case NVVM::WGMMATypes::bf16:
2659 if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
2660 typeB == NVVM::WGMMATypes::bf16)
2661 return success();
2662 break;
2663 case NVVM::WGMMATypes::e4m3:
2664 case NVVM::WGMMATypes::e5m2:
2665 if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
2666 (typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3))
2667 return success();
2668 break;
2669 case WGMMATypes::f32:
2670 case WGMMATypes::s32:
2671 llvm_unreachable("unsupported input types");
2672 break;
2673 }
2674 return failure();
2675}
2676
2677static LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA) {
2678 SmallVector<int> allowedN = {8, 16, 24, 32, 40, 48, 56, 64,
2679 72, 80, 88, 96, 104, 112, 120, 128,
2680 136, 144, 152, 160, 168, 176, 184, 192,
2681 200, 208, 216, 224, 232, 240, 248, 256};
2682 SmallVector<int> allowedNshort = {8, 16, 24, 32, 48, 64,
2683 80, 96, 112, 128, 144, 160,
2684 176, 192, 208, 224, 240, 256};
2685 switch (typeA) {
2686 case WGMMATypes::f16:
2687 case WGMMATypes::tf32:
2688 case WGMMATypes::bf16:
2689 case WGMMATypes::e4m3:
2690 case WGMMATypes::e5m2:
2691 if (llvm::is_contained(allowedN, sizeN))
2692 return success();
2693 break;
2694 case WGMMATypes::u8:
2695 case WGMMATypes::s8:
2696 case WGMMATypes::b1:
2697 if (llvm::is_contained(allowedNshort, sizeN))
2698 return success();
2699 break;
2700 case WGMMATypes::f32:
2701 case WGMMATypes::s32:
2702 llvm_unreachable("unsupported input types");
2703 break;
2704 }
2705 return failure();
2706}
2707
2708LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
2709 Value outValue = getResults();
2710 auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
2711 if (!stype)
2712 return emitOpError() << "expected results to be struct";
2713 int outputSize = stype.getBody().size();
2714 WGMMATypes typeD = getTypeD();
2715 WGMMATypes typeA = getTypeA();
2716 WGMMATypes typeB = getTypeB();
2717
2718 for (Type t : stype.getBody()) {
2719 if (t != stype.getBody().front())
2720 return emitOpError()
2721 << "all elements in struct must be same type but there is " << t;
2722 }
2723
2724 if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
2725 typeD != WGMMATypes::s32) {
2726 return emitOpError() << "does not support the given output type " << typeD;
2727 }
2728 if (typeD == WGMMATypes::s32 &&
2729 (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
2730 return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg";
2731 }
2732
2733 if (failed(isAllowedWGMMADataType(typeD, typeA, typeB))) {
2734 return emitOpError() << typeD << " += " << typeA << " * " << typeB
2735 << ", it is not supported.";
2736 }
2737
2738 // Check M
2739 if (getShape().getM() != 64)
2740 return emitOpError() << "shape 'm' must be 64";
2741
2742 // Check K
2743 FailureOr<int> allowedK = getAllowedSizeK(typeA);
2744 if (failed(allowedK) || allowedK.value() != getShape().getK())
2745 return emitOpError() << "shape 'k' must be " << allowedK.value()
2746 << " for input type " << typeA;
2747
2748 // Check N
2749 if (failed(isAllowedSizeN(getShape().getN(), typeA))) {
2750 return emitOpError() << "has input type " << typeA << " n is set to "
2751 << getShape().getN() << ", it is not supported.";
2752 }
2753
2754 // Check transpose (only available for f16/bf16)
2755 // Matrices A should be stored in row-major and B in column-major.
2756 // Only f16/bf16 matrices can be stored in either column-major or row-major
2757 // by setting the transpose value(imm-trans-a,imm-trans-b) in PTX code.
2758 if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
2759 (getLayoutA() == mlir::NVVM::MMALayout::col ||
2760 getLayoutB() == mlir::NVVM::MMALayout::row)) {
2761 return emitOpError()
2762 << "given layouts layout_a = " << getLayoutA()
2763 << " and layout_b = " << getLayoutB() << " for input types " << typeA
2764 << " and " << typeB
2765 << " requires transpose. However, this is only supported for: "
2766 << MMATypes::f16 << " and " << MMATypes::bf16;
2767 }
2768
2769 // Check result registers
2770 int expectedOutput = 0;
2771 if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32)
2772 expectedOutput = getShape().getN() / 2;
2773 if (typeD == WGMMATypes::f16)
2774 expectedOutput = getShape().getN() / 4;
2775 if (outputSize != expectedOutput) {
2776 return emitOpError() << "results " << expectedOutput
2777 << ", however output struct has " << outputSize
2778 << " elements";
2779 }
2780 // Check satfinite (only available for s32 accumulator)
2781 if (typeD != WGMMATypes::s32 &&
2782 getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
2783 NVVM::MMAIntOverflow::satfinite) {
2784 return emitOpError()
2785 << " `satfinite` can be only used with s32 accumulator, however "
2786 "the current accumulator is "
2787 << typeD;
2788 }
2789
2790 return success();
2791}
2792
2793std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
2794
2795 int m = getShape().getM(), n = getShape().getN(), k = getShape().getK();
2796 bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
2797
2798 StringRef outputTypeName = stringifyWGMMATypes(getTypeD());
2799
2800 int expectedOutputRegisters = 0;
2801 if (getTypeD() == WGMMATypes::f16)
2802 expectedOutputRegisters = getShape().getN() / 4;
2803 else
2804 expectedOutputRegisters = getShape().getN() / 2;
2805
2806 std::string ptx;
2807 llvm::raw_string_ostream ss(ptx);
2808
2809 ss << "{\n"
2810 ".reg .pred p;\n"
2811 "setp.ne.b32 p, $"
2812 << ((expectedOutputRegisters * 2) + 2)
2813 << ", 0;\n"
2814 "wgmma.mma_async.sync.aligned.m"
2815 << m << "n" << n << "k" << k << "." << outputTypeName << "." << getTypeA()
2816 << "." << getTypeB();
2817 if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
2818 NVVM::MMAIntOverflow::satfinite)
2819 ss << ".satfinite";
2820 ss << " {";
2821 int regCnt = 0;
2822 for (; regCnt < expectedOutputRegisters; ++regCnt) {
2823 ss << "$" << regCnt;
2824 if (regCnt != expectedOutputRegisters - 1)
2825 ss << ", ";
2826 }
2827
2828 ss << "},";
2829 // Need to map read/write registers correctly.
2830 regCnt = (regCnt * 2);
2831 ss << " $" << (regCnt) << ","
2832 << " $" << (regCnt + 1) << ","
2833 << " p";
2834 if (getTypeD() != WGMMATypes::s32) {
2835 ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4);
2836 }
2837 // Don't add transpose parameters unless needed.
2838 if (isF16) {
2839 ss << ", $" << (regCnt + 5) << ", $" << (regCnt + 6);
2840 }
2841 ss << ";\n"
2842 << "}\n";
2843 return ptx;
2844}
2845
2846bool NVVM::WgmmaMmaAsyncOp::getAsmValues(
2847 RewriterBase &rewriter,
2848 llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
2849 &asmValues) {
2850 bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
2851 if (getResults())
2852 asmValues.push_back({getResults(), mlir::NVVM::PTXRegisterMod::Write});
2853 if (getInouts())
2854 asmValues.push_back({getInouts(), mlir::NVVM::PTXRegisterMod::ReadWrite});
2855 asmValues.push_back({getDescriptorA(), mlir::NVVM::PTXRegisterMod::Read});
2856 asmValues.push_back({getDescriptorB(), mlir::NVVM::PTXRegisterMod::Read});
2857 asmValues.push_back({makeConstantI32(rewriter, static_cast<int>(getScaleD())),
2859 if (getTypeD() != WGMMATypes::s32) {
2860 asmValues.push_back(
2861 {makeConstantI32(rewriter,
2862 getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
2864 asmValues.push_back(
2865 {makeConstantI32(rewriter,
2866 getScaleB() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
2868 }
2869 if (isF16) {
2870 asmValues.push_back(
2871 {makeConstantI32(rewriter, static_cast<int>(getLayoutA())),
2873 asmValues.push_back(
2874 {makeConstantI32(rewriter, 1 - static_cast<int>(getLayoutB())),
2876 }
2877 return true; // Has manual mapping
2878}
2879
2880LogicalResult NVVM::FenceProxyOp::verify() {
2881 if (getKind() == NVVM::ProxyKind::async_shared && !getSpace().has_value()) {
2882 return emitOpError() << "async_shared fence requires space attribute";
2883 }
2884 if (getKind() != NVVM::ProxyKind::async_shared && getSpace().has_value()) {
2885 return emitOpError() << "only async_shared fence can have space attribute";
2886 }
2887 return success();
2888}
2889
2890LogicalResult NVVM::FenceProxyAcquireOp::verify() {
2891 if (getFromProxy() != NVVM::ProxyKind::GENERIC)
2892 return emitOpError("uni-directional proxies only support generic for "
2893 "from_proxy attribute");
2894
2895 if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
2896 return emitOpError("uni-directional proxies only support tensormap "
2897 "for to_proxy attribute");
2898 return success();
2899}
2900
2901LogicalResult NVVM::FenceProxyReleaseOp::verify() {
2902 if (getFromProxy() != NVVM::ProxyKind::GENERIC)
2903 return emitOpError("uni-directional proxies only support generic for "
2904 "from_proxy attribute");
2905
2906 if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
2907 return emitOpError("uni-directional proxies only support tensormap "
2908 "for to_proxy attribute");
2909 return success();
2910}
2911
2912LogicalResult NVVM::FenceProxySyncRestrictOp::verify() {
2913 if (getFromProxy() != NVVM::ProxyKind::GENERIC)
2914 return emitOpError("only generic is support for from_proxy attribute");
2915
2916 if (getToProxy() != NVVM::ProxyKind::async)
2917 return emitOpError("only async is supported for to_proxy attribute");
2918 return success();
2919}
2920
2921LogicalResult NVVM::SetMaxRegisterOp::verify() {
2922 if (getRegCount() % 8)
2923 return emitOpError("new register size must be multiple of 8");
2924 if (getRegCount() < 24 || getRegCount() > 256)
2925 return emitOpError("new register size must be in between 24 to 256");
2926 return success();
2927}
2928
2929LogicalResult NVVM::Tcgen05CpOp::verify() {
2930 auto mc = getMulticast();
2931
2932 using SH = Tcgen05CpShape;
2933 using MC = Tcgen05CpMulticast;
2934 switch (getShape()) {
2935 case SH::SHAPE_128x256b:
2936 case SH::SHAPE_128x128b:
2937 case SH::SHAPE_4x256b:
2938 if (mc != MC::NONE)
2939 return emitError("Invalid multicast type for tcgen05.cp Op");
2940 break;
2941 case SH::SHAPE_64x128b:
2942 if (mc != MC::WARPX2_01_23 && mc != MC::WARPX2_02_13)
2943 return emitError("Shape 64x128b requires multicast warpx2_01_23 or "
2944 "warpx2_02_13 for tcgen05.cp Op");
2945 break;
2946 case SH::SHAPE_32x128b:
2947 if (mc != MC::WARPX4)
2948 return emitError(
2949 "Shape 32x128b requires multicast warpx4 for tcgen05.cp Op");
2950 break;
2951 }
2952 return success();
2953}
2954
2955LogicalResult NVVM::MatchSyncOp::verify() {
2956 if (getKind() == NVVM::MatchSyncKind::all) {
2957 auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
2958 if (!type || type.getBody().size() != 2 ||
2959 !type.getBody()[0].isInteger(32) || !type.getBody()[1].isInteger(1)) {
2960 return emitOpError("match.sync 'all' returns a two element struct with "
2961 "first element as i32 and second element as i1");
2962 }
2963 } else {
2964 if (!getType().isInteger(32)) {
2965 return emitOpError("match.sync 'any' returns an i32");
2966 }
2967 }
2968 return success();
2969}
2970
2971LogicalResult MatchSyncOp::inferReturnTypes(
2972 MLIRContext *context, std::optional<Location> location,
2973 MatchSyncOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
2974 if (adaptor.getKind() == NVVM::MatchSyncKind::all)
2975 inferredReturnTypes.push_back(LLVM::LLVMStructType::getLiteral(
2976 context,
2977 {IntegerType::get(context, 32), IntegerType::get(context, 1)}));
2978 else
2979 inferredReturnTypes.push_back(IntegerType::get(context, 32));
2980 return success();
2981}
2982
2983LogicalResult NVVM::VoteSyncOp::verify() {
2984 if (getKind() == NVVM::VoteSyncKind::ballot) {
2985 if (!getType().isInteger(32)) {
2986 return emitOpError("vote.sync 'ballot' returns an i32");
2987 }
2988 } else {
2989 if (!getType().isInteger(1)) {
2990 return emitOpError("vote.sync 'any', 'all' and 'uni' returns an i1");
2991 }
2992 }
2993 return success();
2994}
2995
2996LogicalResult VoteSyncOp::inferReturnTypes(
2997 MLIRContext *context, std::optional<Location> location,
2998 VoteSyncOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
2999 unsigned width = adaptor.getKind() == NVVM::VoteSyncKind::ballot ? 32 : 1;
3000 inferredReturnTypes.push_back(IntegerType::get(context, width));
3001 return success();
3002}
3003
3004LogicalResult NVVM::PrefetchOp::verify() {
3005 using MemSpace = NVVM::NVVMMemorySpace;
3006 using CacheLevel = NVVM::PrefetchCacheLevel;
3007
3008 unsigned addressSpace =
3009 llvm::cast<LLVM::LLVMPointerType>(getAddr().getType()).getAddressSpace();
3010 std::optional<NVVM::CacheEvictionPriority> evictPriority = getEvictPriority();
3011 std::optional<NVVM::PrefetchCacheLevel> cacheLevel = getCacheLevel();
3012
3013 if (getTensormap() && cacheLevel)
3014 return emitOpError("cannot specify both tensormap and cache level");
3015
3016 if (getTensormap()) {
3017 if (addressSpace != MemSpace::Generic &&
3018 addressSpace != MemSpace::Constant) {
3019 return emitOpError(
3020 "prefetch tensormap requires a generic or constant pointer");
3021 }
3022
3023 if (evictPriority) {
3024 return emitOpError(
3025 "prefetch tensormap does not support eviction priority");
3026 }
3027
3028 if (getInParamSpace() && addressSpace != MemSpace::Generic) {
3029 return emitOpError(
3030 "in_param_space can only be specified for a generic pointer");
3031 }
3032
3033 } else if (cacheLevel) {
3034 if (addressSpace != MemSpace::Generic && addressSpace != MemSpace::Global &&
3035 addressSpace != MemSpace::Local) {
3036 return emitOpError("prefetch to cache level requires a generic, global, "
3037 "or local pointer");
3038 }
3039
3040 if (getUniform()) {
3041 if (*cacheLevel != CacheLevel::L1) {
3042 return emitOpError(
3043 "unsupported cache level, the only supported uniform "
3044 "cache level is L1");
3045 }
3046
3047 if (addressSpace != MemSpace::Generic) {
3048 return emitOpError(
3049 "prefetch to uniform cache requires a generic pointer");
3050 }
3051 }
3052
3053 if (evictPriority) {
3054 if (*cacheLevel != CacheLevel::L2)
3055 return emitOpError(
3056 "cache eviction priority supported only for cache level L2");
3057
3058 if (addressSpace != MemSpace::Global)
3059 return emitOpError("cache eviction priority requires a global pointer");
3060
3061 if (*evictPriority != NVVM::CacheEvictionPriority::EvictNormal &&
3062 *evictPriority != NVVM::CacheEvictionPriority::EvictLast)
3063 return emitOpError(
3064 "unsupported cache eviction priority, only evict_last and "
3065 "evict_normal are supported");
3066 }
3067
3068 if (getPredicate())
3069 return emitOpError("predicate supported only on prefetch tensormap");
3070
3071 } else {
3072 return emitOpError(
3073 "requires specification of either cache level or tensormap");
3074 }
3075
3076 return success();
3077}
3078
3079LogicalResult NVVM::ClusterLaunchControlQueryCancelOp::verify() {
3080 switch (getQueryType()) {
3081 case NVVM::ClusterLaunchControlQueryType::IS_CANCELED:
3082 if (!getType().isInteger(1))
3083 return emitOpError("is_canceled query type returns an i1");
3084 break;
3085 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_X:
3086 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Y:
3087 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Z:
3088 if (!getType().isInteger(32)) {
3089 return emitOpError("get_first_cta_id_x, get_first_cta_id_y, "
3090 "get_first_cta_id_z query types return an i32");
3091 }
3092 break;
3093 }
3094 return success();
3095}
3096
3097LogicalResult ClusterLaunchControlQueryCancelOp::inferReturnTypes(
3098 MLIRContext *context, std::optional<Location> location,
3099 ClusterLaunchControlQueryCancelOp::Adaptor adaptor,
3100 SmallVectorImpl<Type> &inferredReturnTypes) {
3101 unsigned width =
3102 adaptor.getQueryType() == NVVM::ClusterLaunchControlQueryType::IS_CANCELED
3103 ? 1
3104 : 32;
3105 inferredReturnTypes.push_back(IntegerType::get(context, width));
3106 return success();
3107}
3108
3109LogicalResult NVVM::ReduxOp::verify() {
3110 mlir::Type reduxType = getType();
3111
3112 if (!reduxType.isF32()) {
3113 if (getAbs())
3114 return emitOpError("abs attribute is supported only for f32 type");
3115 if (getNan())
3116 return emitOpError("nan attribute is supported only for f32 type");
3117 }
3118
3119 NVVM::ReductionKind kind = getKind();
3120 switch (kind) {
3121 case NVVM::ReductionKind::ADD:
3122 case NVVM::ReductionKind::AND:
3123 case NVVM::ReductionKind::OR:
3124 case NVVM::ReductionKind::XOR:
3125 case NVVM::ReductionKind::MAX:
3126 case NVVM::ReductionKind::MIN:
3127 case NVVM::ReductionKind::UMAX:
3128 case NVVM::ReductionKind::UMIN:
3129 if (!reduxType.isInteger(32))
3130 return emitOpError("'")
3131 << kind << "' reduction kind unsupported with " << reduxType
3132 << " type. Only supported type is 'i32'.";
3133 break;
3134 case NVVM::ReductionKind::FMIN:
3135 case NVVM::ReductionKind::FMAX:
3136 if (!reduxType.isF32())
3137 return emitOpError("'")
3138 << kind << "' reduction kind unsupported with " << reduxType
3139 << " type. Only supported type is 'f32'.";
3140 break;
3141 }
3142
3143 return success();
3144}
3145
3146LogicalResult NVVM::TensormapReplaceOp::verify() {
3147 auto ord = getOrd();
3148 Value newVal = getNewValue();
3149 auto newValAttr = getNewValueAttr();
3150 auto fieldName = stringifyEnum(getField());
3151
3152 if (ord && !llvm::is_contained({NVVM::TensormapField::BOX_DIM,
3153 NVVM::TensormapField::GLOBAL_DIM,
3154 NVVM::TensormapField::GLOBAL_STRIDE,
3155 NVVM::TensormapField::ELEMENT_STRIDE},
3156 getField()))
3157 return emitOpError("ordinal is not supported for ")
3158 << fieldName << " field";
3159
3160 auto invalidNewVal = [&](llvm::Twine type) -> std::string {
3161 return llvm::Twine("new_value must be specified and must be an " + type +
3162 " for " + llvm::Twine(fieldName) + " field")
3163 .str();
3164 };
3165
3166 auto invalidNewValAttr = [&]() -> std::string {
3167 return (llvm::Twine(
3168 "new_value_attr must be specified and must be a valid ") +
3169 llvm::Twine(fieldName) + " attribute for " + fieldName + " field")
3170 .str();
3171 };
3172
3173 switch (getField()) {
3174 case NVVM::TensormapField::GLOBAL_ADDRESS:
3175 if (!(newVal && newVal.getType().isInteger(64)))
3176 return emitOpError(invalidNewVal("i64"));
3177 break;
3178 case NVVM::TensormapField::RANK:
3179 if (!(newVal && newVal.getType().isInteger(32)))
3180 return emitOpError(invalidNewVal("i32"));
3181 break;
3182 case NVVM::TensormapField::GLOBAL_STRIDE:
3183 if (!ord)
3184 return emitOpError("ordinal is required for global_stride field");
3185 if (!(newVal && newVal.getType().isInteger(64)))
3186 return emitOpError(invalidNewVal("i64"));
3187 break;
3188 case NVVM::TensormapField::BOX_DIM:
3189 case NVVM::TensormapField::GLOBAL_DIM:
3190 case NVVM::TensormapField::ELEMENT_STRIDE:
3191 if (!ord)
3192 return emitOpError("ordinal is required for ")
3193 << stringifyEnum(getField()) << " field";
3194 if (!(newVal && newVal.getType().isInteger(32)))
3195 return emitOpError(invalidNewVal("i32"));
3196 break;
3197 case NVVM::TensormapField::ELEMTYPE:
3198 if (!(newValAttr && llvm::isa<TensormapElemtypeAttr>(*newValAttr)))
3199 return emitOpError(invalidNewValAttr());
3200 break;
3201 case NVVM::TensormapField::INTERLEAVE_LAYOUT:
3202 if (!(newValAttr && llvm::isa<TensormapInterleaveLayoutAttr>(*newValAttr)))
3203 return emitOpError(invalidNewValAttr());
3204 break;
3205 case NVVM::TensormapField::SWIZZLE_MODE:
3206 if (!(newValAttr && llvm::isa<TensormapSwizzleModeAttr>(*newValAttr)))
3207 return emitOpError(invalidNewValAttr());
3208 break;
3209 case NVVM::TensormapField::SWIZZLE_ATOMICITY:
3210 if (!(newValAttr && llvm::isa<TensormapSwizzleAtomicityAttr>(*newValAttr)))
3211 return emitOpError(invalidNewValAttr());
3212 break;
3213 case NVVM::TensormapField::FILL_MODE:
3214 if (!(newValAttr && llvm::isa<TensormapFillModeAttr>(*newValAttr)))
3215 return emitOpError(invalidNewValAttr());
3216 break;
3217 }
3218
3219 return success();
3220}
3221
3222template <typename OpType>
3223static LogicalResult verifyAddSubFOp(OpType op) {
3224 mlir::NVVM::FPRoundingMode rndMode = op.getRnd();
3225 mlir::NVVM::SaturationMode satMode = op.getSat();
3226 bool isFTZ = op.getFtz();
3227
3228 mlir::Type opType = op.getRes().getType();
3229 mlir::Type opBaseType = isa<VectorType>(opType)
3230 ? cast<VectorType>(opType).getElementType()
3231 : opType;
3232
3233 if (opBaseType.isF64() && (satMode != NVVM::SaturationMode::NONE || isFTZ))
3234 return op.emitOpError("FTZ and saturation are not supported for "
3235 "additions/subtractions involving f64 type");
3236
3237 if (opBaseType.isF16() && !(rndMode == NVVM::FPRoundingMode::RN ||
3238 rndMode == NVVM::FPRoundingMode::NONE))
3239 return op.emitOpError("only RN rounding mode is supported for f16 and "
3240 "vector<2xf16> additions/subtractions");
3241
3242 if (opBaseType.isBF16()) {
3243 if (rndMode != NVVM::FPRoundingMode::RN &&
3244 rndMode != NVVM::FPRoundingMode::NONE)
3245 return op.emitOpError("only RN rounding mode is supported for bf16 and "
3246 "vector<2xbf16> additions/subtractions");
3247 if (satMode != NVVM::SaturationMode::NONE || isFTZ)
3248 return op.emitOpError("FTZ and saturation are not supported for bf16 and "
3249 "vector<2xbf16> additions/subtractions");
3250 }
3251
3252 // FIXME: This is a temporary check disallowing lowering to add.rn.ftz.f16(x2)
3253 // PTX instructions since the corresponding LLVM intrinsic is missing. This
3254 // should be removed once the intrinsics for f16 addition (with FTZ only) are
3255 // available.
3256 if (opBaseType.isF16() && isFTZ && satMode == NVVM::SaturationMode::NONE)
3257 return op.emitOpError("FTZ with no saturation is not supported for f16 and "
3258 "vector<2xf16> additions/subtractions");
3259
3260 return success();
3261}
3262
3263LogicalResult NVVM::AddFOp::verify() { return verifyAddSubFOp<AddFOp>(*this); }
3264
3265LogicalResult NVVM::SubFOp::verify() { return verifyAddSubFOp<SubFOp>(*this); }
3266
3267LogicalResult NVVM::FmaOp::verify() {
3268 auto opType = getRes().getType();
3269 mlir::NVVM::FPRoundingMode rndMode = getRnd();
3270 mlir::NVVM::SaturationMode satMode = getSat();
3271 bool isFTZ = getFtz();
3272 bool isRelu = getRelu();
3273 bool hasOOB = getOob();
3274
3275 auto getBaseFType = [](Type type) -> Type {
3276 if (isa<VectorType>(type))
3277 return cast<VectorType>(type).getElementType();
3278 return type;
3279 };
3280
3281 auto opBaseType = getBaseFType(opType);
3282
3283 if (rndMode == NVVM::FPRoundingMode::NONE)
3284 return emitOpError("rounding mode must be specified");
3285
3286 if (isRelu && satMode == NVVM::SaturationMode::SAT)
3287 return emitOpError("relu and saturation are not supported together");
3288
3289 if (hasOOB && (satMode == NVVM::SaturationMode::SAT || isFTZ))
3290 return emitOpError("oob is not supported with saturation or FTZ");
3291
3292 if (!(opBaseType.isF16() || opBaseType.isBF16()) && (isRelu || hasOOB))
3293 return emitOpError("relu and oob are only supported for f16 and bf16");
3294
3295 if (opBaseType.isF64() && (satMode != NVVM::SaturationMode::NONE || isFTZ))
3296 return emitOpError("FTZ and saturation are not supported for f64 type");
3297
3298 if (opBaseType.isF16() && rndMode != NVVM::FPRoundingMode::RN)
3299 return emitOpError(
3300 "only RN rounding mode is supported for f16 and vector<2xf16>");
3301
3302 if (opBaseType.isBF16()) {
3303 if (rndMode != NVVM::FPRoundingMode::RN)
3304 return emitOpError(
3305 "only RN rounding mode is supported for bf16 and vector<2xbf16>");
3306 if (satMode != NVVM::SaturationMode::NONE || isFTZ)
3307 return emitOpError(
3308 "FTZ and saturation are not supported for bf16 and vector<2xbf16>");
3309 }
3310
3311 return success();
3312}
3313
3314LogicalResult NVVM::SqrtOp::verify() {
3315 if (getRnd() == NVVM::FPRoundingMode::NONE)
3316 return emitOpError("rounding mode cannot be None");
3317
3318 if (getRes().getType().isF64() && getFtz())
3319 return emitOpError("FTZ is not supported for f64");
3320
3321 return success();
3322}
3323
3324LogicalResult NVVM::DivFOp::verify() {
3325 bool isApprox = getApprox();
3326 bool isFull = getFull();
3327 bool isF64 = getRes().getType().isF64();
3328 bool isFtz = getFtz();
3329 NVVM::FPRoundingMode rndMode = getRnd();
3330
3331 if (isApprox && isFull)
3332 return emitOpError("'approx' and 'full' are mutually exclusive");
3333
3334 if (isApprox || isFull) {
3335 if (isF64)
3336 return emitOpError("'approx' and 'full' forms are f32-only");
3337 if (rndMode != NVVM::FPRoundingMode::NONE)
3338 return emitOpError(
3339 "'approx' and 'full' forms do not accept a rounding mode");
3340 return success();
3341 }
3342
3343 // Rounded form below.
3344 if (rndMode == NVVM::FPRoundingMode::NONE)
3345 return emitOpError("rounding mode cannot be None for the rounded divide");
3346 if (isF64 && isFtz)
3347 return emitOpError("FTZ is not supported for f64");
3348
3349 return success();
3350}
3351
3352/// Packs the given `field` into the `result`.
3353/// The `result` is 64-bits and each `field` can be 32-bits or narrower.
3354static llvm::Value *
3355packValInto64Bits(llvm::IRBuilderBase &builder,
3356 llvm::Value *result, // the `result` (unset bits are zero)
3357 llvm::Value *field, // `field` to pack into `result`
3358 unsigned sizeInBits, // Size of `field` in bits
3359 unsigned start) { // Starting bit within `result`
3360 field = builder.CreateZExtOrBitCast(field, builder.getInt32Ty());
3361
3362 unsigned mask = (sizeInBits < 32 ? ((1u << sizeInBits) - 1) : 0xffffffffu);
3363 if (mask != 0xffffffffu)
3364 field = builder.CreateAnd(field, builder.getInt32(mask));
3365
3366 field = builder.CreateZExtOrBitCast(field, builder.getInt64Ty());
3367 field = builder.CreateShl(field, start);
3368
3369 return builder.CreateOr(result, field);
3370}
3371
3372void Tcgen05MmaSmemDescOp::createSmemDescriptor(Operation &op,
3374 llvm::IRBuilderBase &builder) {
3375 auto thisOp = cast<NVVM::Tcgen05MmaSmemDescOp>(op);
3376 llvm::Value *smemDesc = builder.getInt64(0);
3377
3378 smemDesc = packValInto64Bits(builder, smemDesc,
3379 mt.lookupValue(thisOp.getStartAddr()), 14, 0);
3380 smemDesc = packValInto64Bits(
3381 builder, smemDesc, mt.lookupValue(thisOp.getLeadingDimOffset()), 14, 16);
3382 smemDesc = packValInto64Bits(
3383 builder, smemDesc, mt.lookupValue(thisOp.getStrideDimOffset()), 14, 32);
3384
3385 smemDesc = packValInto64Bits(builder, smemDesc, builder.getInt32(1), 3, 46);
3386 smemDesc = packValInto64Bits(builder, smemDesc,
3387 mt.lookupValue(thisOp.getBaseOffset()), 3, 49);
3388 smemDesc = packValInto64Bits(
3389 builder, smemDesc, mt.lookupValue(thisOp.getLeadingDimMode()), 1, 52);
3390 smemDesc = packValInto64Bits(builder, smemDesc,
3391 mt.lookupValue(thisOp.getSwizzleMode()), 3, 61);
3392
3393 mt.mapValue(thisOp.getRes()) = smemDesc;
3394}
3395
3396//===----------------------------------------------------------------------===//
3397// getPtx methods
3398//===----------------------------------------------------------------------===//
3399
3400std::string NVVM::MBarrierInitOp::getPtx() {
3401 bool isShared = isPtrInSharedCTASpace(getAddr());
3402 return isShared ? std::string("mbarrier.init.shared.b64 [%0], %1;")
3403 : std::string("mbarrier.init.b64 [%0], %1;");
3404}
3405
3406std::string NVVM::MBarrierArriveExpectTxOp::getPtx() {
3407 bool isShared = isPtrInSharedCTASpace(getAddr());
3408 return isShared
3409 ? std::string("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;")
3410 : std::string("mbarrier.arrive.expect_tx.b64 _, [%0], %1;");
3411}
3412
3413std::string NVVM::MBarrierTryWaitParityOp::getPtx() {
3414 bool isShared = isPtrInSharedCTASpace(getAddr());
3415 llvm::StringRef space = isShared ? ".shared" : "";
3416
3417 return llvm::formatv("{\n\t"
3418 ".reg .pred P1; \n\t"
3419 "LAB_WAIT: \n\t"
3420 "mbarrier.try_wait.parity{0}.b64 P1, [%0], %1, %2; \n\t"
3421 "@P1 bra.uni DONE; \n\t"
3422 "bra.uni LAB_WAIT; \n\t"
3423 "DONE: \n\t"
3424 "}",
3425 space);
3426}
3427
3428//===----------------------------------------------------------------------===//
3429// Canonicalization patterns
3430//===----------------------------------------------------------------------===//
3431
3434
3435 LogicalResult matchAndRewrite(SubFOp op,
3436 PatternRewriter &rewriter) const override {
3437 Location loc = op.getLoc();
3438 Value negRhs =
3439 LLVM::FNegOp::create(rewriter, loc, op.getRhs().getType(), op.getRhs());
3440
3441 rewriter.replaceOpWithNewOp<AddFOp>(op, op.getType(), op.getLhs(), negRhs,
3442 op.getRnd(), op.getSat(), op.getFtz());
3443 return success();
3444 }
3445};
3446
3447void SubFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
3448 MLIRContext *context) {
3449 patterns.add<ConvertFsubToFnegFadd>(context);
3450}
3451
3452//===----------------------------------------------------------------------===//
3453// getIntrinsicID/getIntrinsicIDAndArgs methods
3454//===----------------------------------------------------------------------===//
3455
3456mlir::NVVM::IDArgPair NVVM::BarrierOp::getIntrinsicIDAndArgs(
3457 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3458 auto thisOp = cast<NVVM::BarrierOp>(op);
3459 llvm::Value *barrierId = thisOp.getBarrierId()
3460 ? mt.lookupValue(thisOp.getBarrierId())
3461 : builder.getInt32(0);
3462 llvm::Intrinsic::ID id;
3463 llvm::SmallVector<llvm::Value *> args = {barrierId};
3464 if (thisOp.getNumberOfThreads()) {
3465 id = llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_count;
3466 args.push_back(mt.lookupValue(thisOp.getNumberOfThreads()));
3467 } else {
3468 id = llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_all;
3469 }
3470 return {id, std::move(args)};
3471}
3472
3473mlir::NVVM::IDArgPair NVVM::BarrierReductionOp::getIntrinsicIDAndArgs(
3474 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3475 auto thisOp = cast<NVVM::BarrierReductionOp>(op);
3476 llvm::Intrinsic::ID id;
3477 switch (thisOp.getReductionOp()) {
3478 case NVVM::BarrierReduction::AND:
3479 id = llvm::Intrinsic::nvvm_barrier_cta_red_and_aligned_all;
3480 break;
3481 case NVVM::BarrierReduction::OR:
3482 id = llvm::Intrinsic::nvvm_barrier_cta_red_or_aligned_all;
3483 break;
3484 case NVVM::BarrierReduction::POPC:
3485 id = llvm::Intrinsic::nvvm_barrier_cta_red_popc_aligned_all;
3486 break;
3487 }
3488 llvm::Value *barrierId = thisOp.getBarrierId()
3489 ? mt.lookupValue(thisOp.getBarrierId())
3490 : builder.getInt32(0);
3492 barrierId,
3493 builder.CreateICmpNE(mt.lookupValue(thisOp.getReductionPredicate()),
3494 builder.getInt32(0))};
3495 return {id, std::move(args)};
3496}
3497
3499CosOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
3500 llvm::IRBuilderBase &builder) {
3501 auto thisOp = cast<NVVM::CosOp>(op);
3502 llvm::Intrinsic::ID id = thisOp.getFtz()
3503 ? llvm::Intrinsic::nvvm_cos_approx_ftz_f
3504 : llvm::Intrinsic::nvvm_cos_approx_f;
3505 return {id, {mt.lookupValue(thisOp.getSrc())}};
3506}
3507
3509SinOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
3510 llvm::IRBuilderBase &builder) {
3511 auto thisOp = cast<NVVM::SinOp>(op);
3512 llvm::Intrinsic::ID id = thisOp.getFtz()
3513 ? llvm::Intrinsic::nvvm_sin_approx_ftz_f
3514 : llvm::Intrinsic::nvvm_sin_approx_f;
3515 return {id, {mt.lookupValue(thisOp.getSrc())}};
3516}
3517
3519Log2Op::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
3520 llvm::IRBuilderBase &builder) {
3521 auto thisOp = cast<NVVM::Log2Op>(op);
3522 llvm::Intrinsic::ID id = thisOp.getFtz()
3523 ? llvm::Intrinsic::nvvm_lg2_approx_ftz_f
3524 : llvm::Intrinsic::nvvm_lg2_approx_f;
3525 return {id, {mt.lookupValue(thisOp.getSrc())}};
3526}
3527
3529Ex2Op::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
3530 llvm::IRBuilderBase &builder) {
3531 auto thisOp = cast<NVVM::Ex2Op>(op);
3532 llvm::Intrinsic::ID id = thisOp.getFtz()
3533 ? llvm::Intrinsic::nvvm_ex2_approx_ftz
3534 : llvm::Intrinsic::nvvm_ex2_approx;
3535 return {id, {mt.lookupValue(thisOp.getSrc())}};
3536}
3537
3539RsqrtOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
3540 llvm::IRBuilderBase &builder) {
3541 auto thisOp = cast<NVVM::RsqrtOp>(op);
3542 Type t = thisOp.getRes().getType();
3543 bool isFtz = thisOp.getFtz();
3544
3545 llvm::Intrinsic::ID id = [&] {
3546 if (t.isF32()) {
3547 return isFtz ? llvm::Intrinsic::nvvm_rsqrt_approx_ftz_f
3548 : llvm::Intrinsic::nvvm_rsqrt_approx_f;
3549 }
3550 // f64
3551 return isFtz ? llvm::Intrinsic::nvvm_rsqrt_approx_ftz_d
3552 : llvm::Intrinsic::nvvm_rsqrt_approx_d;
3553 }();
3554
3555 return {id, {mt.lookupValue(thisOp.getSrc())}};
3556}
3557
3559SqrtOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
3560 llvm::IRBuilderBase &builder) {
3561 auto thisOp = cast<NVVM::SqrtOp>(op);
3562 Type t = thisOp.getRes().getType();
3563 NVVM::FPRoundingMode rndMode = thisOp.getRnd();
3564 bool isFtz = thisOp.getFtz();
3565
3566 // RM is one of RN/RM/RP/RZ (verifier rejects NONE).
3567 // Subtracting 1 maps RN=1..RZ=4 to 0..3.
3568 unsigned rndIndex = static_cast<unsigned>(rndMode) - 1;
3569
3570 static constexpr llvm::Intrinsic::ID f32IDs[] = {
3571 llvm::Intrinsic::nvvm_sqrt_rn_f,
3572 llvm::Intrinsic::nvvm_sqrt_rm_f,
3573 llvm::Intrinsic::nvvm_sqrt_rp_f,
3574 llvm::Intrinsic::nvvm_sqrt_rz_f,
3575 };
3576 static constexpr llvm::Intrinsic::ID f32FTZIDs[] = {
3577 llvm::Intrinsic::nvvm_sqrt_rn_ftz_f,
3578 llvm::Intrinsic::nvvm_sqrt_rm_ftz_f,
3579 llvm::Intrinsic::nvvm_sqrt_rp_ftz_f,
3580 llvm::Intrinsic::nvvm_sqrt_rz_ftz_f,
3581 };
3582 static constexpr llvm::Intrinsic::ID f64IDs[] = {
3583 llvm::Intrinsic::nvvm_sqrt_rn_d,
3584 llvm::Intrinsic::nvvm_sqrt_rm_d,
3585 llvm::Intrinsic::nvvm_sqrt_rp_d,
3586 llvm::Intrinsic::nvvm_sqrt_rz_d,
3587 };
3588
3589 llvm::Intrinsic::ID id =
3590 t.isF32() ? (isFtz ? f32FTZIDs[rndIndex] : f32IDs[rndIndex])
3591 : f64IDs[rndIndex];
3592
3593 return {id, {mt.lookupValue(thisOp.getSrc())}};
3594}
3595
3597SqrtApproxOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
3598 llvm::IRBuilderBase &builder) {
3599 auto thisOp = cast<NVVM::SqrtApproxOp>(op);
3600 llvm::Intrinsic::ID id = thisOp.getFtz()
3601 ? llvm::Intrinsic::nvvm_sqrt_approx_ftz_f
3602 : llvm::Intrinsic::nvvm_sqrt_approx_f;
3603 return {id, {mt.lookupValue(thisOp.getSrc())}};
3604}
3605
3607DivFOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
3608 llvm::IRBuilderBase &builder) {
3609 auto thisOp = cast<NVVM::DivFOp>(op);
3610 bool isFtz = thisOp.getFtz();
3611
3612 llvm::Intrinsic::ID id;
3613
3614 if (thisOp.getApprox()) {
3615 id = isFtz ? llvm::Intrinsic::nvvm_div_approx_ftz_f
3616 : llvm::Intrinsic::nvvm_div_approx_f;
3617 } else if (thisOp.getFull()) {
3618 // Intrinsic Naming quirk: int_nvvm_div_full has no `_f` suffix (unlike
3619 // approx).
3620 id = isFtz ? llvm::Intrinsic::nvvm_div_full_ftz
3621 : llvm::Intrinsic::nvvm_div_full;
3622 } else {
3623 // Rounded form — three 4-entry tables indexed by (rndMode - 1).
3624 unsigned rndIndex = static_cast<unsigned>(thisOp.getRnd()) - 1;
3625
3626 static constexpr llvm::Intrinsic::ID f32IDs[] = {
3627 llvm::Intrinsic::nvvm_div_rn_f,
3628 llvm::Intrinsic::nvvm_div_rm_f,
3629 llvm::Intrinsic::nvvm_div_rp_f,
3630 llvm::Intrinsic::nvvm_div_rz_f,
3631 };
3632 static constexpr llvm::Intrinsic::ID f32FTZIDs[] = {
3633 llvm::Intrinsic::nvvm_div_rn_ftz_f,
3634 llvm::Intrinsic::nvvm_div_rm_ftz_f,
3635 llvm::Intrinsic::nvvm_div_rp_ftz_f,
3636 llvm::Intrinsic::nvvm_div_rz_ftz_f,
3637 };
3638 static constexpr llvm::Intrinsic::ID f64IDs[] = {
3639 llvm::Intrinsic::nvvm_div_rn_d,
3640 llvm::Intrinsic::nvvm_div_rm_d,
3641 llvm::Intrinsic::nvvm_div_rp_d,
3642 llvm::Intrinsic::nvvm_div_rz_d,
3643 };
3644 Type t = thisOp.getRes().getType();
3645 id = t.isF32() ? (isFtz ? f32FTZIDs[rndIndex] : f32IDs[rndIndex])
3646 : f64IDs[rndIndex];
3647 }
3648
3649 return {id,
3650 {mt.lookupValue(thisOp.getLhs()), mt.lookupValue(thisOp.getRhs())}};
3651}
3652
3654PMEventOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
3655 llvm::IRBuilderBase &builder) {
3656 auto thisOp = cast<NVVM::PMEventOp>(op);
3657 llvm::Type *i16Ty = llvm::Type::getInt16Ty(mt.getLLVMContext());
3658
3659 // With event-id, mask is generated as (1 << event-id)
3660 llvm::Value *maskVal;
3661 if (auto eventAttr = thisOp.getEventIdAttr()) {
3662 uint16_t mask = static_cast<uint16_t>(1u << eventAttr.getInt());
3663 maskVal = llvm::ConstantInt::get(i16Ty, mask);
3664 } else {
3665 maskVal =
3666 llvm::ConstantInt::get(i16Ty, thisOp.getMaskedEventIdAttr().getValue());
3667 }
3668
3669 return {llvm::Intrinsic::nvvm_pm_event_mask, {maskVal}};
3670}
3671
3672mlir::NVVM::IDArgPair MBarrierInitOp::getIntrinsicIDAndArgs(
3673 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3674 auto thisOp = cast<NVVM::MBarrierInitOp>(op);
3675 bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
3676 llvm::Intrinsic::ID id = isShared ? llvm::Intrinsic::nvvm_mbarrier_init_shared
3677 : llvm::Intrinsic::nvvm_mbarrier_init;
3678
3679 // Fill the Intrinsic Args
3681 args.push_back(mt.lookupValue(thisOp.getAddr()));
3682 args.push_back(mt.lookupValue(thisOp.getCount()));
3683
3684 return {id, std::move(args)};
3685}
3686
3687mlir::NVVM::IDArgPair MBarrierInvalOp::getIntrinsicIDAndArgs(
3688 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3689 auto thisOp = cast<NVVM::MBarrierInvalOp>(op);
3690 bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
3691 llvm::Intrinsic::ID id = isShared
3692 ? llvm::Intrinsic::nvvm_mbarrier_inval_shared
3693 : llvm::Intrinsic::nvvm_mbarrier_inval;
3694
3695 return {id, {mt.lookupValue(thisOp.getAddr())}};
3696}
3697
3698mlir::NVVM::IDArgPair MBarrierExpectTxOp::getIntrinsicIDAndArgs(
3699 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3700 auto thisOp = cast<NVVM::MBarrierExpectTxOp>(op);
3701
3702 bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr());
3703 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3704 // bit-0: Space
3705 // bit-1: Scope
3706 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3707
3708 static constexpr llvm::Intrinsic::ID IDs[] = {
3709 llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cta_space_cta,
3710 llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cta_space_cluster,
3711 llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cluster_space_cta,
3712 llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cluster_space_cluster};
3713
3714 // Fill the Intrinsic Args
3716 args.push_back(mt.lookupValue(thisOp.getAddr()));
3717 args.push_back(mt.lookupValue(thisOp.getTxcount()));
3718
3719 return {IDs[index], std::move(args)};
3720}
3721
3722mlir::NVVM::IDArgPair MBarrierCompleteTxOp::getIntrinsicIDAndArgs(
3723 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3724 auto thisOp = cast<NVVM::MBarrierCompleteTxOp>(op);
3725
3726 bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr());
3727 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3728 // bit-0: Space
3729 // bit-1: Scope
3730 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3731
3732 static constexpr llvm::Intrinsic::ID IDs[] = {
3733 llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cta_space_cta,
3734 llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cta_space_cluster,
3735 llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cluster_space_cta,
3736 llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cluster_space_cluster};
3737
3738 // Fill the Intrinsic Args
3740 args.push_back(mt.lookupValue(thisOp.getAddr()));
3741 args.push_back(mt.lookupValue(thisOp.getTxcount()));
3742
3743 return {IDs[index], std::move(args)};
3744}
3745
3746mlir::NVVM::IDArgPair MBarrierArriveOp::getIntrinsicIDAndArgs(
3747 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3748 auto thisOp = cast<NVVM::MBarrierArriveOp>(op);
3749
3750 bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr());
3751 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3752 // bit-0: Space
3753 // bit-1: Scope
3754 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3755
3756 static constexpr llvm::Intrinsic::ID IDs[] = {
3757 llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cta_space_cta,
3758 llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cta_space_cluster,
3759 llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cluster_space_cta,
3760 llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cluster_space_cluster};
3761 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3762 llvm::Intrinsic::nvvm_mbarrier_arrive_relaxed_scope_cta_space_cta,
3763 llvm::Intrinsic::nvvm_mbarrier_arrive_relaxed_scope_cta_space_cluster,
3764 llvm::Intrinsic::nvvm_mbarrier_arrive_relaxed_scope_cluster_space_cta,
3765 llvm::Intrinsic::
3766 nvvm_mbarrier_arrive_relaxed_scope_cluster_space_cluster};
3767 auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];
3768
3769 // Tidy-up the Intrinsic Args
3770 bool needCast = isPtrInGenericSpace(thisOp.getAddr());
3771 llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
3772 if (needCast)
3773 mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);
3774
3775 // We have the most basic mbarrier.arrive supported on sm_80.
3776 // It supports: Space=cta, scope=cta, No relaxed, No explicit count.
3777 // So, only for this combination use the legacy intrinsic.
3778 bool hasCount = static_cast<bool>(thisOp.getCount());
3779 if (!hasCount &&
3780 (id == llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cta_space_cta))
3781 return {llvm::Intrinsic::nvvm_mbarrier_arrive_shared, {mbar}};
3782
3783 // When count is not explicitly specified, the default is 1.
3784 llvm::LLVMContext &ctx = mt.getLLVMContext();
3785 llvm::Value *count =
3786 hasCount ? mt.lookupValue(thisOp.getCount())
3787 : llvm::ConstantInt::get(llvm::Type::getInt32Ty(ctx), 1);
3788 return {id, {mbar, count}};
3789}
3790
3791mlir::NVVM::IDArgPair MBarrierArriveDropOp::getIntrinsicIDAndArgs(
3792 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3793 auto thisOp = cast<NVVM::MBarrierArriveDropOp>(op);
3794
3795 bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr());
3796 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3797 // bit-0: Space
3798 // bit-1: Scope
3799 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3800
3801 static constexpr llvm::Intrinsic::ID IDs[] = {
3802 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cta_space_cta,
3803 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cta_space_cluster,
3804 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cluster_space_cta,
3805 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cluster_space_cluster};
3806 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3807 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_relaxed_scope_cta_space_cta,
3808 llvm::Intrinsic::
3809 nvvm_mbarrier_arrive_drop_relaxed_scope_cta_space_cluster,
3810 llvm::Intrinsic::
3811 nvvm_mbarrier_arrive_drop_relaxed_scope_cluster_space_cta,
3812 llvm::Intrinsic::
3813 nvvm_mbarrier_arrive_drop_relaxed_scope_cluster_space_cluster};
3814 auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];
3815
3816 // Tidy-up the Intrinsic Args
3817 bool needCast = isPtrInGenericSpace(thisOp.getAddr());
3818 llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
3819 if (needCast)
3820 mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);
3821
3822 // When count is not explicitly specified, the default is 1.
3823 llvm::LLVMContext &ctx = mt.getLLVMContext();
3824 bool hasCount = static_cast<bool>(thisOp.getCount());
3825 llvm::Value *count =
3826 hasCount ? mt.lookupValue(thisOp.getCount())
3827 : llvm::ConstantInt::get(llvm::Type::getInt32Ty(ctx), 1);
3828
3829 return {id, {mbar, count}};
3830}
3831
3832bool MBarrierArriveExpectTxOp::getAsmValues(
3833 RewriterBase &rewriter,
3834 llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
3835 &asmValues) {
3836 // Add all the operands but not the attrs to the asmValues list.
3837 // The attrs here are used to generate the right variants for
3838 // intrinsics-lowering. So, we ignore them while generating inline-PTX.
3839 for (auto val : getOperands())
3840 asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read});
3841
3842 return false;
3843}
3844
3845mlir::NVVM::IDArgPair MBarrierArriveExpectTxOp::getIntrinsicIDAndArgs(
3846 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3847 auto thisOp = cast<NVVM::MBarrierArriveExpectTxOp>(op);
3848
3849 bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr());
3850 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3851 // bit-0: Space
3852 // bit-1: Scope
3853 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3854
3855 // clang-format off
3856 static constexpr llvm::Intrinsic::ID IDs[] = {
3857 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cta_space_cta,
3858 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cta_space_cluster,
3859 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cluster_space_cta,
3860 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cluster_space_cluster};
3861 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3862 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cta_space_cta,
3863 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cta_space_cluster,
3864 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cluster_space_cta,
3865 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cluster_space_cluster};
3866 // clang-format on
3867 auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];
3868
3869 // Tidy-up the Intrinsic Args
3870 llvm::Value *txcount = mt.lookupValue(thisOp.getTxcount());
3871 llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
3872 bool needCast = isPtrInGenericSpace(thisOp.getAddr());
3873 if (needCast)
3874 mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);
3875
3876 return {id, {mbar, txcount}};
3877}
3878
3879mlir::NVVM::IDArgPair MBarrierArriveDropExpectTxOp::getIntrinsicIDAndArgs(
3880 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3881 auto thisOp = cast<NVVM::MBarrierArriveDropExpectTxOp>(op);
3882
3883 bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr());
3884 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3885 // bit-0: Space
3886 // bit-1: Scope
3887 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3888
3889 // clang-format off
3890 static constexpr llvm::Intrinsic::ID IDs[] = {
3891 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cta_space_cta,
3892 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cta_space_cluster,
3893 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cluster_space_cta,
3894 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cluster_space_cluster};
3895 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3896 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cta_space_cta,
3897 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cta_space_cluster,
3898 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cluster_space_cta,
3899 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cluster_space_cluster};
3900 // clang-format on
3901 auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];
3902
3903 // Tidy-up the Intrinsic Args
3904 llvm::Value *txcount = mt.lookupValue(thisOp.getTxcount());
3905 llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
3906 bool needCast = isPtrInGenericSpace(thisOp.getAddr());
3907 if (needCast)
3908 mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);
3909
3910 return {id, {mbar, txcount}};
3911}
3912
3913mlir::NVVM::IDArgPair MBarrierArriveNocompleteOp::getIntrinsicIDAndArgs(
3914 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3915 auto thisOp = cast<NVVM::MBarrierArriveNocompleteOp>(op);
3916 bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
3917 llvm::Intrinsic::ID id =
3918 isShared ? llvm::Intrinsic::nvvm_mbarrier_arrive_noComplete_shared
3919 : llvm::Intrinsic::nvvm_mbarrier_arrive_noComplete;
3920 // Fill the Intrinsic Args
3922 args.push_back(mt.lookupValue(thisOp.getAddr()));
3923 args.push_back(mt.lookupValue(thisOp.getCount()));
3924
3925 return {id, std::move(args)};
3926}
3927
3928mlir::NVVM::IDArgPair MBarrierArriveDropNocompleteOp::getIntrinsicIDAndArgs(
3929 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3930 auto thisOp = cast<NVVM::MBarrierArriveDropNocompleteOp>(op);
3931 bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
3932 llvm::Intrinsic::ID id =
3933 isShared ? llvm::Intrinsic::nvvm_mbarrier_arrive_drop_noComplete_shared
3934 : llvm::Intrinsic::nvvm_mbarrier_arrive_drop_noComplete;
3935 // Fill the Intrinsic Args
3937 args.push_back(mt.lookupValue(thisOp.getAddr()));
3938 args.push_back(mt.lookupValue(thisOp.getCount()));
3939
3940 return {id, std::move(args)};
3941}
3942
3943mlir::NVVM::IDArgPair MBarrierTestWaitOp::getIntrinsicIDAndArgs(
3944 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3945 auto thisOp = cast<NVVM::MBarrierTestWaitOp>(op);
3946 bool isPhaseParity = thisOp.getStateOrPhase().getType().isInteger(32);
3947 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3948 // bit-0: isPhaseParity
3949 // bit-1: Scope
3950 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isPhaseParity ? 1 : 0);
3951
3952 // clang-format off
3953 static constexpr llvm::Intrinsic::ID IDs[] = {
3954 llvm::Intrinsic::nvvm_mbarrier_test_wait_scope_cta_space_cta,
3955 llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_scope_cta_space_cta,
3956 llvm::Intrinsic::nvvm_mbarrier_test_wait_scope_cluster_space_cta,
3957 llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_scope_cluster_space_cta};
3958 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3959 llvm::Intrinsic::nvvm_mbarrier_test_wait_relaxed_scope_cta_space_cta,
3960 llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_relaxed_scope_cta_space_cta,
3961 llvm::Intrinsic::nvvm_mbarrier_test_wait_relaxed_scope_cluster_space_cta,
3962 llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_relaxed_scope_cluster_space_cta};
3963 // clang-format on
3964 auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];
3965
3966 // Tidy-up the Intrinsic Args
3967 llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
3968 llvm::Value *input = mt.lookupValue(thisOp.getStateOrPhase());
3969 bool needCast = isPtrInGenericSpace(thisOp.getAddr());
3970 if (needCast)
3971 mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);
3972
3973 return {id, {mbar, input}};
3974}
3975
3976mlir::NVVM::IDArgPair MBarrierTryWaitOp::getIntrinsicIDAndArgs(
3977 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3978 auto thisOp = cast<NVVM::MBarrierTryWaitOp>(op);
3979 bool isPhaseParity = thisOp.getStateOrPhase().getType().isInteger(32);
3980 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3981 bool hasTicks = static_cast<bool>(thisOp.getTicks());
3982 // bit-0: isPhaseParity
3983 // bit-1: Scope
3984 // bit-2: hasTicks
3985 size_t index = ((hasTicks ? 1 : 0) << 2) | ((isClusterScope ? 1 : 0) << 1) |
3986 (isPhaseParity ? 1 : 0);
3987
3988 // clang-format off
3989 static constexpr llvm::Intrinsic::ID IDs[] = {
3990 llvm::Intrinsic::nvvm_mbarrier_try_wait_scope_cta_space_cta,
3991 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_scope_cta_space_cta,
3992 llvm::Intrinsic::nvvm_mbarrier_try_wait_scope_cluster_space_cta,
3993 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_scope_cluster_space_cta,
3994 llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_scope_cta_space_cta,
3995 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_scope_cta_space_cta,
3996 llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_scope_cluster_space_cta,
3997 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_scope_cluster_space_cta};
3998 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3999 llvm::Intrinsic::nvvm_mbarrier_try_wait_relaxed_scope_cta_space_cta,
4000 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_relaxed_scope_cta_space_cta,
4001 llvm::Intrinsic::nvvm_mbarrier_try_wait_relaxed_scope_cluster_space_cta,
4002 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_relaxed_scope_cluster_space_cta,
4003 llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_relaxed_scope_cta_space_cta,
4004 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_relaxed_scope_cta_space_cta,
4005 llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_relaxed_scope_cluster_space_cta,
4006 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_relaxed_scope_cluster_space_cta};
4007 // clang-format on
4008 auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];
4009
4010 // Tidy-up the mbarrier pointer
4011 llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
4012 bool needCast = isPtrInGenericSpace(thisOp.getAddr());
4013 if (needCast)
4014 mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);
4015
4016 // Fill the Intrinsic Args
4018 args.push_back(mbar);
4019 args.push_back(mt.lookupValue(thisOp.getStateOrPhase()));
4020 if (hasTicks)
4021 args.push_back(mt.lookupValue(thisOp.getTicks()));
4022
4023 return {id, std::move(args)};
4024}
4025
4026mlir::NVVM::IDArgPair CpAsyncMBarrierArriveOp::getIntrinsicIDAndArgs(
4027 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4028 auto thisOp = cast<NVVM::CpAsyncMBarrierArriveOp>(op);
4029 bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
4030
4031 llvm::Intrinsic::ID id;
4032 if (thisOp.getNoinc()) {
4033 id = isShared ? llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_noinc_shared
4034 : llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_noinc;
4035 } else {
4036 id = isShared ? llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_shared
4037 : llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive;
4038 }
4039
4040 return {id, {mt.lookupValue(thisOp.getAddr())}};
4041}
4042
4044MovMatrixOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
4045 llvm::IRBuilderBase &builder) {
4046 auto thisOp = cast<NVVM::MovMatrixOp>(op);
4047 return {llvm::Intrinsic::nvvm_movmatrix_sync_aligned_m8n8_trans_b16,
4048 {mt.lookupValue(thisOp.getSrc())}};
4049}
4050
// Pastes together the name of a cp.async shared<-global intrinsic from the
// cache modifier (`mod`), the copy size (`size`) and an optional `suffix`.
#define CP_ASYNC_ID_IMPL(mod, size, suffix)                                    \
  llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix

// Selects between the `_s` intrinsic (when `has_cpsize` is true) and the
// unsuffixed one. The condition and the whole expansion are parenthesized so
// that the macro behaves as a single expression wherever it is used.
#define GET_CP_ASYNC_ID(mod, size, has_cpsize)                                 \
  ((has_cpsize) ? CP_ASYNC_ID_IMPL(mod, size, _s)                              \
                : CP_ASYNC_ID_IMPL(mod, size, ))
4056
4057llvm::Intrinsic::ID
4058CpAsyncOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
4060 llvm::Intrinsic::ID id;
4061
4062 auto cpAsyncOp = cast<NVVM::CpAsyncOp>(op);
4063 bool hasCpSize = static_cast<bool>(cpAsyncOp.getCpSize());
4064 switch (cpAsyncOp.getSize()) {
4065 case 4:
4066 id = GET_CP_ASYNC_ID(ca, 4, hasCpSize);
4067 break;
4068 case 8:
4069 id = GET_CP_ASYNC_ID(ca, 8, hasCpSize);
4070 break;
4071 case 16:
4072 id = (cpAsyncOp.getModifier() == NVVM::LoadCacheModifierKind::CG)
4073 ? GET_CP_ASYNC_ID(cg, 16, hasCpSize)
4074 : GET_CP_ASYNC_ID(ca, 16, hasCpSize);
4075 break;
4076 default:
4077 llvm_unreachable("Invalid copy size in CpAsyncOp.");
4078 }
4079
4080 // Fill the Intrinsic Args
4081 args.push_back(mt.lookupValue(cpAsyncOp.getDst()));
4082 args.push_back(mt.lookupValue(cpAsyncOp.getSrc()));
4083 if (hasCpSize)
4084 args.push_back(mt.lookupValue(cpAsyncOp.getCpSize()));
4085
4086 return id;
4087}
4088
4089mlir::NVVM::IDArgPair CpAsyncBulkPrefetchOp::getIntrinsicIDAndArgs(
4090 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4091 auto thisOp = cast<NVVM::CpAsyncBulkPrefetchOp>(op);
4093 llvm::Intrinsic::ID id = llvm::Intrinsic::nvvm_cp_async_bulk_prefetch_L2;
4094
4095 // Fill the Intrinsic Args
4096 args.push_back(mt.lookupValue(thisOp.getSrcMem()));
4097 args.push_back(mt.lookupValue(thisOp.getSize()));
4098
4099 mlir::Value cacheHint = thisOp.getL2CacheHint();
4100 const bool hasCacheHint = static_cast<bool>(cacheHint);
4101 llvm::Value *i64Unused =
4102 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
4103 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
4104 args.push_back(builder.getInt1(hasCacheHint));
4105
4106 return {id, std::move(args)};
4107}
4108
4109mlir::NVVM::IDArgPair CpAsyncBulkGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
4110 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4111 auto thisOp = cast<NVVM::CpAsyncBulkGlobalToSharedClusterOp>(op);
4113
4114 // Fill the Intrinsic Args: dst, mbar, src, size.
4115 args.push_back(mt.lookupValue(thisOp.getDstMem()));
4116 args.push_back(mt.lookupValue(thisOp.getMbar()));
4117 args.push_back(mt.lookupValue(thisOp.getSrcMem()));
4118 args.push_back(mt.lookupValue(thisOp.getSize()));
4119
4120 // Multicast mask for shared::cluster only, if available.
4121 mlir::Value multicastMask = thisOp.getMulticastMask();
4122 const bool hasMulticastMask = static_cast<bool>(multicastMask);
4123 const bool isSharedCTA = isPtrInSharedCTASpace(thisOp.getDstMem());
4124 if (!isSharedCTA) {
4125 llvm::Value *i16Unused = llvm::ConstantInt::get(builder.getInt16Ty(), 0);
4126 args.push_back(hasMulticastMask ? mt.lookupValue(multicastMask)
4127 : i16Unused);
4128 }
4129
4130 // Cache hint, if available.
4131 mlir::Value cacheHint = thisOp.getL2CacheHint();
4132 const bool hasCacheHint = static_cast<bool>(cacheHint);
4133 llvm::Value *i64Unused = llvm::ConstantInt::get(builder.getInt64Ty(), 0);
4134 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
4135
4136 // Flag arguments for multicast and cachehint.
4137 if (!isSharedCTA)
4138 args.push_back(builder.getInt1(hasMulticastMask));
4139 args.push_back(builder.getInt1(hasCacheHint));
4140
4141 llvm::Intrinsic::ID id =
4142 isSharedCTA
4143 ? llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cta
4144 : llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster;
4145
4146 return {id, std::move(args)};
4147}
4148
4149mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
4150 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4151 auto thisOp = cast<NVVM::CpAsyncBulkSharedCTAToGlobalOp>(op);
4153 llvm::Intrinsic::ID id =
4154 llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global;
4155
4156 // Fill the Intrinsic Args
4157 args.push_back(mt.lookupValue(thisOp.getDstMem()));
4158 args.push_back(mt.lookupValue(thisOp.getSrcMem()));
4159 args.push_back(mt.lookupValue(thisOp.getSize()));
4160
4161 mlir::Value cacheHint = thisOp.getL2CacheHint();
4162 const bool hasCacheHint = static_cast<bool>(cacheHint);
4163 llvm::Value *i64Unused =
4164 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
4165 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
4166 args.push_back(builder.getInt1(hasCacheHint));
4167
4168 // Choose the bytemask variant
4169 if (mlir::Value byteMask = thisOp.getByteMask()) {
4170 args.push_back(mt.lookupValue(byteMask));
4171 id = llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global_bytemask;
4172 }
4173
4174 return {id, std::move(args)};
4175}
4176
4177bool CpAsyncBulkTensorGlobalToSharedClusterOp::getAsmValues(
4178 RewriterBase &rewriter,
4179 llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
4180 &asmValues) {
4181 // Add all the operands but not the attrs to the asmValues list.
4182 // The attrs here are used to generate the right variants for
4183 // intrinsics-lowering. So, we ignore them while generating inline-PTX.
4184 for (auto val : getOperands())
4185 asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read});
4186
4187 return false;
4188}
4189
4191CpAsyncBulkTensorGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
4192 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4193 auto thisOp = cast<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(op);
4194 const bool isCTAOnly = thisOp.getIsCTAOnly();
4196
4197 // Fill the Intrinsic Args
4198 args.push_back(mt.lookupValue(thisOp.getDstMem()));
4199 args.push_back(mt.lookupValue(thisOp.getMbar()));
4200 args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
4201
4202 // Coordinates and im2col-offsets
4203 for (mlir::Value v : thisOp.getCoordinates())
4204 args.push_back(mt.lookupValue(v));
4205 for (mlir::Value v : thisOp.getIm2colOffsets())
4206 args.push_back(mt.lookupValue(v));
4207
4208 // MulticastMask, if available
4209 mlir::Value mcMask = thisOp.getMulticastMask();
4210 const bool hasMC = static_cast<bool>(mcMask);
4211 llvm::Value *i16Zero =
4212 llvm::ConstantInt::get(llvm::Type::getInt16Ty(mt.getLLVMContext()), 0);
4213
4214 // CacheHint, if available
4215 mlir::Value cacheHint = thisOp.getL2CacheHint();
4216 const bool hasCacheHint = static_cast<bool>(cacheHint);
4217 llvm::Value *i64Zero =
4218 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
4219
4220 // Flag argument CTAGroup
4221 // CTA_1/2 is mapped to values 1 and 2 for the intrinsics.
4222 // Hence, the +1 to getGroup().
4223 const int32_t val =
4224 thisOp.getGroup() ? (static_cast<int32_t>(*thisOp.getGroup()) + 1) : 0;
4225 llvm::Value *cg =
4226 llvm::ConstantInt::get(llvm::Type::getInt32Ty(mt.getLLVMContext()), val);
4227
4228 if (!isCTAOnly) {
4229 // For shared::cluster, all the arguments that we build are applicable.
4230 args.push_back(hasMC ? mt.lookupValue(mcMask) : i16Zero);
4231 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Zero);
4232 args.push_back(builder.getInt1(hasMC));
4233 args.push_back(builder.getInt1(hasCacheHint));
4234 args.push_back(cg);
4235 } else {
4236 // For shared::cta, only cache-hint is applicable.
4237 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Zero);
4238 args.push_back(builder.getInt1(hasCacheHint));
4239 }
4240
4241 constexpr size_t numDims = 5; // 1D to 5D
4242 constexpr size_t numModes = 5; // Tile, Im2col, w, w_128, gather4
4243 using rowTy = std::array<llvm::Intrinsic::ID, numDims + 1>;
4244 using TableTy = std::array<rowTy, numModes>;
4245 static constexpr TableTy IDTable{
4246 {{notIntrinsic, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_1d,
4247 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_2d,
4248 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_3d,
4249 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_4d,
4250 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_5d},
4252 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d,
4253 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d,
4254 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d},
4256 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_3d,
4257 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_4d,
4258 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_5d},
4260 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_3d,
4261 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_4d,
4262 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_5d},
4264 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_gather4_2d}}};
4265
4266 static constexpr TableTy IDTableCTA{
4267 {{notIntrinsic,
4268 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_1d,
4269 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_2d,
4270 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_3d,
4271 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_4d,
4272 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_5d},
4274 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_3d,
4275 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_4d,
4276 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_5d},
4278 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_3d,
4279 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_4d,
4280 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_5d},
4282 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_3d,
4283 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_4d,
4284 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_5d},
4286 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_gather4_2d}}};
4287
4288 static_assert(
4289 (getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1) &&
4290 (getMaxEnumValForTMALoadMode() == std::size(IDTableCTA) - 1),
4291 "TMALoadModes must match number of rows in IDTable and IDTableCTA");
4292 size_t mode = static_cast<size_t>(thisOp.getMode());
4293 size_t dim = thisOp.getCoordinates().size();
4294 auto id = isCTAOnly ? IDTableCTA[mode][dim] : IDTable[mode][dim];
4295 assert(id != notIntrinsic &&
4296 "Invalid intrinsic for CpAsyncBulkTensorGlobalToSharedClusterOp.");
4297
4298 return {id, std::move(args)};
4299}
4300
// Selects the LLVM intrinsic for a CpAsyncBulkTensorPrefetchOp and collects
// the operands for the call.
// Operands are appended in this order: TMA descriptor, coordinates, im2col
// offsets, L2 cache-hint value, and an i1 flag saying whether a cache hint
// was supplied.
// NOTE(review): `args` is used below but its declaration is not visible in
// this listing (a line appears to be missing right after the cast) - confirm.
4301mlir::NVVM::IDArgPair CpAsyncBulkTensorPrefetchOp::getIntrinsicIDAndArgs(
4302 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4303 auto thisOp = cast<NVVM::CpAsyncBulkTensorPrefetchOp>(op);
4305
4306 // Fill the Intrinsic Args
4307 args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
4308
4309 for (auto v : thisOp.getCoordinates())
4310 args.push_back(mt.lookupValue(v));
4311 for (auto v : thisOp.getIm2colOffsets())
4312 args.push_back(mt.lookupValue(v));
4313
     // The cache hint is optional: when it is absent an i64 zero placeholder
     // is passed instead and the trailing i1 flag is false.
4314 mlir::Value cacheHint = thisOp.getL2CacheHint();
4315 const bool hasCacheHint = static_cast<bool>(cacheHint);
4316 llvm::Value *i64Unused =
4317 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
4318 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
4319 args.push_back(builder.getInt1(hasCacheHint));
4320
     // Lookup table indexed as IDTable[mode][number of coordinates].
     // NI marks (mode, dimension) combinations that have no intrinsic.
4321 const unsigned NI = llvm::Intrinsic::not_intrinsic;
4322 static constexpr llvm::Intrinsic::ID IDTable[][6] = {
4323 {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d,
4324 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d,
4325 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d,
4326 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d,
4327 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d},
4328 {NI, NI, NI,
4329 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d,
4330 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d,
4331 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d},
4332 {NI, NI, NI,
4333 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_3d,
4334 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_4d,
4335 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_5d},
4336 {NI, NI, NI,
4337 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_3d,
4338 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_4d,
4339 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_5d},
4340 {NI, NI, NI, NI, NI,
4341 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_gather4_2d}};
4342
     // Compile-time guard: the table must have one row per TMALoadMode value.
4343 static_assert(getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1,
4344 "TMALoadModes must match number of rows in IDTable");
4345 size_t mode = static_cast<size_t>(thisOp.getMode());
4346 size_t dim = thisOp.getCoordinates().size();
4347 llvm::Intrinsic::ID id = IDTable[mode][dim];
     // Hitting an NI slot is treated as a programming error, not a user error.
4348 if (id == llvm::Intrinsic::not_intrinsic)
4349 llvm_unreachable("Invalid intrinsic for CpAsyncBulkTensorPrefetchOp.");
4350
4351 return {id, std::move(args)};
4352}
4353
// Selects the LLVM intrinsic for a CpAsyncBulkTensorSharedCTAToGlobalOp and
// collects the operands for the call.
// Operands are appended in this order: source memory pointer, TMA
// descriptor, coordinates, L2 cache-hint value, and an i1 flag saying whether
// a cache hint was supplied.
// NOTE(review): the return-type line above the qualified name and the
// declaration of `args` are not visible in this listing - confirm against the
// real file.
4355CpAsyncBulkTensorSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
4356 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4357 auto thisOp = cast<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(op);
4359
4360 // Fill the Intrinsic Args
4361 args.push_back(mt.lookupValue(thisOp.getSrcMem()));
4362 args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
4363
4364 for (auto v : thisOp.getCoordinates())
4365 args.push_back(mt.lookupValue(v));
4366
     // The cache hint is optional: when it is absent an i64 zero placeholder
     // is passed instead and the trailing i1 flag is false.
4367 mlir::Value cacheHint = thisOp.getL2CacheHint();
4368 const bool hasCacheHint = static_cast<bool>(cacheHint);
4369 llvm::Value *i64Unused =
4370 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
4371 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
4372 args.push_back(builder.getInt1(hasCacheHint));
4373
     // Lookup table indexed as IDTable[mode][number of coordinates].
     // NI marks (mode, dimension) combinations that have no intrinsic; the
     // last row has its single valid entry in column 5.
4374 const unsigned NI = llvm::Intrinsic::not_intrinsic;
4375 static constexpr llvm::Intrinsic::ID IDTable[][6] = {
4376 {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_1d,
4377 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_2d,
4378 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_3d,
4379 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_4d,
4380 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_5d},
4381 {NI, NI, NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_3d,
4382 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_4d,
4383 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_5d},
4384 {NI, NI, NI, NI, NI,
4385 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_scatter4_2d}};
4386
     // Compile-time guard: the table must have one row per TMAStoreMode value.
4387 static_assert(getMaxEnumValForTMAStoreMode() == std::size(IDTable) - 1,
4388 "TMAStoreModes must match number of rows in IDTable");
4389 size_t mode = static_cast<size_t>(thisOp.getMode());
4390 size_t dim = thisOp.getCoordinates().size();
4391 llvm::Intrinsic::ID id = IDTable[mode][dim];
     // Hitting an NI slot is treated as a programming error, not a user error.
4392 if (id == llvm::Intrinsic::not_intrinsic)
4393 llvm_unreachable(
4394 "Invalid intrinsic for CpAsyncBulkTensorSharedCTAToGlobalOp.");
4395
4396 return {id, std::move(args)};
4397}
4398
// Selects the LLVM intrinsic for a CpAsyncBulkTensorReduceOp and collects the
// operands for the call.
// The intrinsic is looked up in a three-level table indexed by reduction
// kind, then mode, then number of coordinates.
// NOTE(review): the declaration of `args` and the opening line of every
// im2col row in the table are not visible in this listing - confirm against
// the real file before editing the table.
4399NVVM::IDArgPair CpAsyncBulkTensorReduceOp::getIntrinsicIDAndArgs(
4400 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4401 auto thisOp = cast<NVVM::CpAsyncBulkTensorReduceOp>(op);
4402 llvm::LLVMContext &ctx = mt.getLLVMContext();
4403
4405
4406 // Arguments to the intrinsic:
4407 // shared_mem_ptr, tmaDesc, tensorDims
4408 // cache_hint(if applicable) and flag(boolean)
4409 args.push_back(mt.lookupValue(thisOp.getSrcMem()));
4410 args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
4411
4412 for (Value v : thisOp.getCoordinates())
4413 args.push_back(mt.lookupValue(v));
4414
     // When no cache hint is given, an i64 zero placeholder is passed and the
     // trailing i1 flag is false.
4415 mlir::Value cacheHint = thisOp.getL2CacheHint();
4416 const bool hasCacheHint = static_cast<bool>(cacheHint);
4417 llvm::Value *i64ZeroValue =
4418 llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), 0);
4419 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64ZeroValue);
4420 args.push_back(builder.getInt1(hasCacheHint));
4421
     // Local alias; it shadows the file-level `notIntrinsic` constant.
4422 const llvm::Intrinsic::ID notIntrinsic = llvm::Intrinsic::not_intrinsic;
4423
     // Table shape: [numRedKinds][numLayouts][maxDim + 1]. Column 0 is never
     // a valid intrinsic because coordinates start at 1D.
4424 constexpr unsigned numRedKinds = 8; // ADD, MIN, MAX, INC, DEC, AND, OR, XOR
4425 constexpr unsigned numLayouts = 2; // TILE, IM2COL
4426 constexpr unsigned maxDim = 5; // 1D to 5D
4427 using row = std::array<llvm::Intrinsic::ID, maxDim + 1>;
4428 using layoutTable = std::array<row, numLayouts>;
4429 using fullTable = std::array<layoutTable, numRedKinds>;
4430 static constexpr fullTable IDTable{
4431 {// RedTy::ADD
4432 {{{{notIntrinsic,
4433 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_1d,
4434 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_2d,
4435 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_3d,
4436 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_4d,
4437 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_5d}},
4439 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_3d,
4440 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_4d,
4441 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_5d}}}},
4442 // RedTy::MIN
4443 {{{{notIntrinsic,
4444 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_1d,
4445 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_2d,
4446 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_3d,
4447 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_4d,
4448 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_5d}},
4450 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_3d,
4451 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_4d,
4452 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_5d}}}},
4453 // RedTy::MAX
4454 {{{{notIntrinsic,
4455 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_1d,
4456 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_2d,
4457 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_3d,
4458 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_4d,
4459 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_5d}},
4461 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_3d,
4462 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_4d,
4463 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_5d}}}},
4464 // RedTy::INC
4465 {{{{notIntrinsic,
4466 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_1d,
4467 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_2d,
4468 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_3d,
4469 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_4d,
4470 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_5d}},
4472 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_3d,
4473 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_4d,
4474 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_5d}}}},
4475 // RedTy::DEC
4476 {{{{notIntrinsic,
4477 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_1d,
4478 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_2d,
4479 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_3d,
4480 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_4d,
4481 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_5d}},
4483 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_3d,
4484 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_4d,
4485 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_5d}}}},
4486 // RedTy::AND
4487 {{{{notIntrinsic,
4488 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_1d,
4489 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_2d,
4490 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_3d,
4491 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_4d,
4492 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_5d}},
4494 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_3d,
4495 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_4d,
4496 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_5d}}}},
4497 // RedTy::OR
4498 {{{{notIntrinsic,
4499 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_1d,
4500 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_2d,
4501 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_3d,
4502 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_4d,
4503 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_5d}},
4505 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_3d,
4506 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_4d,
4507 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_5d}}}},
4508 // RedTy::XOR
4509 {{{{notIntrinsic,
4510 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_1d,
4511 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_2d,
4512 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_3d,
4513 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_4d,
4514 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_5d}},
4516 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_3d,
4517 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_4d,
4518 llvm::Intrinsic::
4519 nvvm_cp_async_bulk_tensor_reduce_xor_im2col_5d}}}}}};
4520
     // Compile-time guard: one outer table entry per TMAReduxKind value.
4521 static_assert(getMaxEnumValForTMAReduxKind() == std::size(IDTable) - 1,
4522 "TMAReduxKinds must match number of rows in IDTable");
4523
4524 size_t redKind = static_cast<size_t>(thisOp.getRedKind());
4525 size_t mode = static_cast<size_t>(thisOp.getMode());
4526 size_t dim = thisOp.getCoordinates().size();
4527
     // Bounds are only checked through asserts, i.e. in builds where asserts
     // are enabled.
4528 assert(redKind < IDTable.size() &&
4529 "Invalid redKind for CpAsyncBulkTensorReduceOp");
4530 assert(mode < IDTable[redKind].size() &&
4531 "Invalid mode for CpAsyncBulkTensorReduceOp");
4532 assert(dim < IDTable[redKind][mode].size() &&
4533 "Invalid dim for CpAsyncBulkTensorReduceOp");
4534
4535 llvm::Intrinsic::ID intrinsicID = IDTable[redKind][mode][dim];
4536
4537 assert(intrinsicID != notIntrinsic &&
4538 "Invalid intrinsic for CpAsyncBulkTensorReduceOp.");
4539
4540 return {intrinsicID, std::move(args)};
4541}
4542
// Expands to nothing. It is passed as the `relu` suffix argument of
// GET_CVT_F2TF32_ID for a rounding mode whose intrinsic names carry no
// "_relu" component.
#define _none

// Chooses between llvm::Intrinsic::nvvm_f2tf32_<rnd><relu><sf> and
// llvm::Intrinsic::nvvm_f2tf32_<rnd><sf>, depending on a `hasRelu` variable
// that must be in scope where the macro is expanded.
// The replacement list is fully parenthesized so the expansion keeps its
// meaning when it is embedded in a larger expression.
#define CVT_F2TF32_ID_IMPL(rnd, relu, sf)                                      \
  ((hasRelu) ? llvm::Intrinsic::nvvm_f2tf32_##rnd##relu##sf                    \
             : llvm::Intrinsic::nvvm_f2tf32_##rnd##sf)

// Chooses between the `sf`-suffixed and the unsuffixed intrinsic family,
// depending on a `hasSatFinite` variable that must be in scope where the
// macro is expanded, and then defers to CVT_F2TF32_ID_IMPL for the relu
// choice.
#define GET_CVT_F2TF32_ID(rnd, relu, sf)                                       \
  ((hasSatFinite) ? CVT_F2TF32_ID_IMPL(rnd, relu, sf)                          \
                  : CVT_F2TF32_ID_IMPL(rnd, relu, ))
4552
4553llvm::Intrinsic::ID
4554ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
4555 NVVM::SaturationMode sat, bool hasRelu) {
4556 using RndMode = NVVM::FPRoundingMode;
4557 bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
4558 switch (rnd) {
4559 case RndMode::RN:
4560 return GET_CVT_F2TF32_ID(rn, _relu, _satfinite);
4561 case RndMode::RZ:
4562 return GET_CVT_F2TF32_ID(rz, _relu, _satfinite);
4563 case RndMode::RNA:
4564 return GET_CVT_F2TF32_ID(rna, _none, _satfinite);
4565 default:
4566 llvm_unreachable("Invalid RoundingMode for CvtFloatToTF32Op");
4567 }
4568}
4569
// Picks the ff-to-e2m1x2 intrinsic (relu or non-relu variant) for a
// ConvertF32x2ToF4x2Op and passes the op's A and B operands, in that order.
// NOTE(review): several header lines are missing from this listing (the
// return type, the parameter that declares `mt`, and the declaration of
// `args`) - confirm against the real file.
4571ConvertF32x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF4x2Op op,
4573 llvm::IRBuilderBase &builder) {
4575 args.push_back(mt.lookupValue(op.getA()));
4576 args.push_back(mt.lookupValue(op.getB()));
4577
4578 bool hasRelu = op.getRelu();
4579
     // Only the round-to-nearest, satfinite forms are used here.
4580 llvm::Intrinsic::ID intId =
4581 hasRelu ? llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_relu_satfinite
4582 : llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_satfinite;
4583
4584 return {intId, std::move(args)};
4585}
4586
// Chooses llvm::Intrinsic::nvvm_ff_to_<type>_rn_relu_satfinite when
// `has_relu` is true and llvm::Intrinsic::nvvm_ff_to_<type>_rn_satfinite
// otherwise.
// Both the condition argument and the whole replacement list are
// parenthesized so the macro is safe to use inside larger expressions.
#define GET_F32x2_TO_F6x2_ID(type, has_relu)                                   \
  ((has_relu) ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite         \
              : llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite)
4590
// Maps the destination element type to the matching ff-to-f6x2 intrinsic:
// Float6E2M3FN selects the e2m3x2 family and Float6E3M2FN the e3m2x2 family.
// `hasRelu` selects the relu variant within the family. Any other type is a
// programming error.
// NOTE(review): the line that starts the dispatch expression (before the
// first `.Case`) is not visible in this listing - confirm.
4591llvm::Intrinsic::ID ConvertF32x2ToF6x2Op::getIntrinsicID(mlir::Type dstTy,
4592 bool hasRelu) {
4594 .Case([&](mlir::Float6E2M3FNType) {
4595 return GET_F32x2_TO_F6x2_ID(e2m3x2, hasRelu);
4596 })
4597 .Case([&](mlir::Float6E3M2FNType) {
4598 return GET_F32x2_TO_F6x2_ID(e3m2x2, hasRelu);
4599 })
4600 .Default([](mlir::Type) {
4601 llvm_unreachable("Invalid conversion in ConvertF32x2ToF6x2Op");
     // The return after llvm_unreachable fixes the lambda's return type.
4602 return llvm::Intrinsic::not_intrinsic;
4603 });
4604}
4605
// Picks the f16x2-to-e2m1x2 intrinsic (relu or non-relu variant) for a
// ConvertF16x2ToF4x2Op and passes the op's source operand as the only
// argument.
// If the destination type is not Float4E2M1FN, the returned ID stays
// llvm::Intrinsic::not_intrinsic; no error is raised here.
// NOTE(review): the return type, the parameter that declares `mt`, and the
// declaration of `args` are not visible in this listing - confirm.
4607ConvertF16x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF16x2ToF4x2Op &op,
4609 llvm::IRBuilderBase &builder) {
4610 mlir::Type dstTy = op.getDstTy();
4611 bool hasRelu = op.getRelu();
4612
4613 llvm::Intrinsic::ID intId = llvm::Intrinsic::not_intrinsic;
4614
4615 if (llvm::isa<mlir::Float4E2M1FNType>(dstTy))
4616 intId = hasRelu ? llvm::Intrinsic::nvvm_f16x2_to_e2m1x2_rn_relu_satfinite
4617 : llvm::Intrinsic::nvvm_f16x2_to_e2m1x2_rn_satfinite;
4618
4620 args.push_back(mt.lookupValue(op.getSrc()));
4621
4622 return {intId, std::move(args)};
4623}
4624
// Picks the bf16x2-to-e2m1x2 intrinsic (relu or non-relu variant) for a
// ConvertBF16x2ToF4x2Op and passes the op's source operand as the only
// argument.
// If the destination type is not Float4E2M1FN, the returned ID stays
// llvm::Intrinsic::not_intrinsic; no error is raised here.
// NOTE(review): the return type, the parameter that declares `mt`, and the
// declaration of `args` are not visible in this listing - confirm.
4626ConvertBF16x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertBF16x2ToF4x2Op &op,
4628 llvm::IRBuilderBase &builder) {
4629 mlir::Type dstTy = op.getDstTy();
4630 bool hasRelu = op.getRelu();
4631
4632 llvm::Intrinsic::ID intId = llvm::Intrinsic::not_intrinsic;
4633
4634 if (llvm::isa<mlir::Float4E2M1FNType>(dstTy))
4635 intId = hasRelu ? llvm::Intrinsic::nvvm_bf16x2_to_e2m1x2_rn_relu_satfinite
4636 : llvm::Intrinsic::nvvm_bf16x2_to_e2m1x2_rn_satfinite;
4637
4639 args.push_back(mt.lookupValue(op.getSrc()));
4640
4641 return {intId, std::move(args)};
4642}
4643
// Maps the destination element type to the matching f16x2-to-f6x2 intrinsic:
// Float6E2M3FN selects the e2m3x2 family and Float6E3M2FN the e3m2x2 family.
// `hasRelu` selects the relu variant within the family. Any other type is a
// programming error.
// NOTE(review): the line that starts the dispatch expression (before the
// first `.Case`) is not visible in this listing - confirm.
4644llvm::Intrinsic::ID ConvertF16x2ToF6x2Op::getIntrinsicID(mlir::Type dstTy,
4645 bool hasRelu) {
4647 .Case<mlir::Float6E2M3FNType>([&](mlir::Float6E2M3FNType) {
4648 return hasRelu ? llvm::Intrinsic::nvvm_f16x2_to_e2m3x2_rn_relu_satfinite
4649 : llvm::Intrinsic::nvvm_f16x2_to_e2m3x2_rn_satfinite;
4650 })
4651 .Case<mlir::Float6E3M2FNType>([&](mlir::Float6E3M2FNType) {
4652 return hasRelu ? llvm::Intrinsic::nvvm_f16x2_to_e3m2x2_rn_relu_satfinite
4653 : llvm::Intrinsic::nvvm_f16x2_to_e3m2x2_rn_satfinite;
4654 })
4655 .Default([](mlir::Type) {
4656 llvm_unreachable("Invalid conversion in ConvertF16x2ToF6x2Op");
     // The return after llvm_unreachable fixes the lambda's return type.
4657 return llvm::Intrinsic::not_intrinsic;
4658 });
4659}
4660
// Maps the destination element type to the matching bf16x2-to-f6x2
// intrinsic: Float6E2M3FN selects the e2m3x2 family and Float6E3M2FN the
// e3m2x2 family. `hasRelu` selects the relu variant within the family. Any
// other type is a programming error.
// NOTE(review): the line that starts the dispatch expression (before the
// first `.Case`) is not visible in this listing - confirm.
4661llvm::Intrinsic::ID ConvertBF16x2ToF6x2Op::getIntrinsicID(mlir::Type dstTy,
4662 bool hasRelu) {
4664 .Case<mlir::Float6E2M3FNType>([&](mlir::Float6E2M3FNType) {
4665 return hasRelu
4666 ? llvm::Intrinsic::nvvm_bf16x2_to_e2m3x2_rn_relu_satfinite
4667 : llvm::Intrinsic::nvvm_bf16x2_to_e2m3x2_rn_satfinite;
4668 })
4669 .Case<mlir::Float6E3M2FNType>([&](mlir::Float6E3M2FNType) {
4670 return hasRelu
4671 ? llvm::Intrinsic::nvvm_bf16x2_to_e3m2x2_rn_relu_satfinite
4672 : llvm::Intrinsic::nvvm_bf16x2_to_e3m2x2_rn_satfinite;
4673 })
4674 .Default([](mlir::Type) {
4675 llvm_unreachable("Invalid conversion in ConvertBF16x2ToF6x2Op");
     // The return after llvm_unreachable fixes the lambda's return type.
4676 return llvm::Intrinsic::not_intrinsic;
4677 });
4678}
4679
// Chooses llvm::Intrinsic::nvvm_ff_to_ue8m0x2_<rnd>_satfinite when `has_satf`
// is true and llvm::Intrinsic::nvvm_ff_to_ue8m0x2_<rnd> otherwise.
// The condition argument and the replacement list are parenthesized so the
// macro is safe to use inside larger expressions.
#define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf)                                 \
  ((has_satf) ? llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd##_satfinite          \
              : llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd)

// Chooses llvm::Intrinsic::nvvm_ff_to_<type>_rn_relu when `has_relu` is true
// and llvm::Intrinsic::nvvm_ff_to_<type>_rn otherwise.
// Parenthesized for the same reason as GET_F32x2_TO_F8X2_US_ID.
#define GET_F32x2_TO_F8X2_S_ID(type, has_relu)                                 \
  ((has_relu) ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu                   \
              : llvm::Intrinsic::nvvm_ff_to_##type##_rn)
4687
// Maps destination type, rounding mode, saturation mode and relu flag to the
// matching ff-to-f8x2 intrinsic.
//  - Float8E4M3FN / Float8E5M2: only `hasRelu` matters (rn forms).
//  - Float8E8M0FNU: only RZ and RP rounding are handled, each with an
//    optional satfinite variant; any other rounding mode is a programming
//    error.
// Any other destination type is a programming error.
// NOTE(review): the line that starts the dispatch expression (before the
// first `.Case`) is not visible in this listing - confirm.
4688llvm::Intrinsic::ID
4689ConvertF32x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy, NVVM::FPRoundingMode rnd,
4690 NVVM::SaturationMode sat, bool hasRelu) {
4691 bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
4692 bool hasRoundingModeRZ = (rnd == NVVM::FPRoundingMode::RZ);
4693 bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP);
4694
4696 .Case([&](mlir::Float8E4M3FNType) {
4697 return GET_F32x2_TO_F8X2_S_ID(e4m3x2, hasRelu);
4698 })
4699 .Case([&](mlir::Float8E5M2Type) {
4700 return GET_F32x2_TO_F8X2_S_ID(e5m2x2, hasRelu);
4701 })
4702 .Case([&](mlir::Float8E8M0FNUType) {
4703 if (hasRoundingModeRZ)
4704 return GET_F32x2_TO_F8X2_US_ID(rz, hasSatFinite);
4705 else if (hasRoundingModeRP)
4706 return GET_F32x2_TO_F8X2_US_ID(rp, hasSatFinite);
4707
4708 llvm_unreachable("Invalid conversion in ConvertF32x2ToF8x2Op");
4709 })
4710 .Default([](mlir::Type) {
4711 llvm_unreachable("Invalid conversion in ConvertF32x2ToF8x2Op");
     // The return after llvm_unreachable fixes the lambda's return type.
4712 return llvm::Intrinsic::not_intrinsic;
4713 });
4714}
4715
// Chooses llvm::Intrinsic::nvvm_f16x2_to_<type>_rn_relu when `has_relu` is
// true and llvm::Intrinsic::nvvm_f16x2_to_<type>_rn otherwise.
// The condition argument and the replacement list are parenthesized so the
// macro is safe to use inside larger expressions.
#define GET_F16x2_TO_F8X2_ID(type, has_relu)                                   \
  ((has_relu) ? llvm::Intrinsic::nvvm_f16x2_to_##type##_rn_relu                \
              : llvm::Intrinsic::nvvm_f16x2_to_##type##_rn)
4719
// Maps the destination element type to the matching f16x2-to-f8x2 intrinsic:
// Float8E4M3FN selects the e4m3x2 family and Float8E5M2 the e5m2x2 family.
// `hasRelu` selects the relu variant within the family. Any other type is a
// programming error.
// NOTE(review): the line that starts the dispatch expression (before the
// first `.Case`) is not visible in this listing - confirm.
4720llvm::Intrinsic::ID ConvertF16x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy,
4721 bool hasRelu) {
4723 .Case([&](mlir::Float8E4M3FNType) {
4724 return GET_F16x2_TO_F8X2_ID(e4m3x2, hasRelu);
4725 })
4726 .Case([&](mlir::Float8E5M2Type) {
4727 return GET_F16x2_TO_F8X2_ID(e5m2x2, hasRelu);
4728 })
4729 .Default([](mlir::Type) {
4730 llvm_unreachable("Invalid conversion in ConvertF16x2ToF8x2Op");
     // The return after llvm_unreachable fixes the lambda's return type.
4731 return llvm::Intrinsic::not_intrinsic;
4732 });
4733}
4734
// Maps destination type, rounding mode, saturation mode and relu flag to the
// matching bf16x2-to-f8x2 intrinsic.
//  - Float8E4M3FN / Float8E5M2: only `hasRelu` matters; the rn + satfinite
//    forms are always used.
//  - Float8E8M0FNU: looked up in `ue8m0x2IDs` by (satfinite, RP) bits.
// Any other destination type is a programming error.
// NOTE(review): the line that starts the dispatch expression (before the
// first `.Case`) is not visible in this listing - confirm.
4735llvm::Intrinsic::ID
4736ConvertBF16x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy,
4737 NVVM::FPRoundingMode rnd,
4738 NVVM::SaturationMode sat, bool hasRelu) {
4739 bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
4740
     // Index layout: bit 0 = rounding mode is RP, bit 1 = satfinite.
4741 static constexpr llvm::Intrinsic::ID ue8m0x2IDs[] = {
4742 llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_rz,
4743 llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_rp,
4744 llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_rz_satfinite,
4745 llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_rp_satfinite,
4746 };
4747
4749 .Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
4750 return hasRelu
4751 ? llvm::Intrinsic::nvvm_bf16x2_to_e4m3x2_rn_relu_satfinite
4752 : llvm::Intrinsic::nvvm_bf16x2_to_e4m3x2_rn_satfinite;
4753 })
4754 .Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
4755 return hasRelu
4756 ? llvm::Intrinsic::nvvm_bf16x2_to_e5m2x2_rn_relu_satfinite
4757 : llvm::Intrinsic::nvvm_bf16x2_to_e5m2x2_rn_satfinite;
4758 })
4759 .Case<mlir::Float8E8M0FNUType>([&](mlir::Float8E8M0FNUType) {
     // Every rounding mode other than RP selects an rz entry here.
4760 bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP);
4761 unsigned index = (hasSatFinite << 1) | hasRoundingModeRP;
4762 return ue8m0x2IDs[index];
4763 })
4764 .Default([](mlir::Type) {
4765 llvm_unreachable("Invalid conversion in ConvertBF16x2ToF8x2Op");
     // The return after llvm_unreachable fixes the lambda's return type.
4766 return llvm::Intrinsic::not_intrinsic;
4767 });
4768}
4769
// Picks the f8x2-to-f16x2 intrinsic for a ConvertF8x2ToF16x2Op and builds its
// single argument.
// Float8E4M3FN selects the e4m3x2 family and Float8E5M2 the e5m2x2 family;
// the op's relu flag selects the variant. Any other type is a programming
// error.
// NOTE(review): the line that starts the dispatch expression (before the
// first `.Case`), including the value being switched on, is not visible in
// this listing - confirm.
4770NVVM::IDArgPair ConvertF8x2ToF16x2Op::getIntrinsicIDAndArgs(
4771 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4772 auto curOp = cast<NVVM::ConvertF8x2ToF16x2Op>(op);
4773
4774 bool hasRelu = curOp.getRelu();
4775
4776 llvm::Intrinsic::ID intId =
4778 .Case([&](Float8E4M3FNType type) {
4779 return hasRelu ? llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn_relu
4780 : llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn;
4781 })
4782 .Case([&](Float8E5M2Type type) {
4783 return hasRelu ? llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn_relu
4784 : llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn;
4785 })
4786 .Default([](mlir::Type type) {
4787 llvm_unreachable("Invalid type for ConvertF8x2ToF16x2Op");
4788 return llvm::Intrinsic::not_intrinsic;
4789 });
4790
     // The translated source value is bitcast to i16 before being passed on.
4791 llvm::Value *packedI16 =
4792 builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
4793 llvm::Type::getInt16Ty(builder.getContext()));
4794
4795 return {intId, {packedI16}};
4796}
4797
4798NVVM::IDArgPair ConvertF8x2ToBF16x2Op::getIntrinsicIDAndArgs(
4799 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4800 auto curOp = cast<NVVM::ConvertF8x2ToBF16x2Op>(op);
4801
4802 llvm::Intrinsic::ID intId = llvm::Intrinsic::nvvm_ue8m0x2_to_bf16x2;
4803 llvm::Value *packedI16 =
4804 builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
4805 llvm::Type::getInt16Ty(builder.getContext()));
4806
4807 return {intId, {packedI16}};
4808}
4809
// Picks the f6x2-to-f16x2 intrinsic for a ConvertF6x2ToF16x2Op and builds its
// single argument.
// Float6E2M3FN selects the e2m3x2 family and Float6E3M2FN the e3m2x2 family;
// the op's relu flag selects the variant. Any other type is a programming
// error.
// NOTE(review): the line that starts the dispatch expression (before the
// first `.Case`), including the value being switched on, is not visible in
// this listing - confirm.
4810NVVM::IDArgPair ConvertF6x2ToF16x2Op::getIntrinsicIDAndArgs(
4811 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4812 auto curOp = cast<NVVM::ConvertF6x2ToF16x2Op>(op);
4813
4814 bool hasRelu = curOp.getRelu();
4815
4816 llvm::Intrinsic::ID intId =
4818 .Case([&](Float6E2M3FNType type) {
4819 return hasRelu ? llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn_relu
4820 : llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn;
4821 })
4822 .Case([&](Float6E3M2FNType type) {
4823 return hasRelu ? llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn_relu
4824 : llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn;
4825 })
4826 .Default([](mlir::Type type) {
4827 llvm_unreachable("Invalid type for ConvertF6x2ToF16x2Op");
4828 return llvm::Intrinsic::not_intrinsic;
4829 });
4830
     // The translated source value is bitcast to i16 before being passed on.
4831 llvm::Value *packedI16 =
4832 builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
4833 llvm::Type::getInt16Ty(builder.getContext()));
4834
4835 return {intId, {packedI16}};
4836}
4837
// Picks the e2m1x2-to-f16x2 intrinsic (relu or non-relu variant) for a
// ConvertF4x2ToF16x2Op and builds its single argument.
// Only Float4E2M1FN is handled; any other type is a programming error.
// NOTE(review): the line that starts the dispatch expression (before the
// `.Case`), including the value being switched on, is not visible in this
// listing - confirm.
4838NVVM::IDArgPair ConvertF4x2ToF16x2Op::getIntrinsicIDAndArgs(
4839 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4840 auto curOp = cast<NVVM::ConvertF4x2ToF16x2Op>(op);
4841
4842 bool hasRelu = curOp.getRelu();
4843
4844 llvm::Intrinsic::ID intId =
4846 .Case([&](Float4E2M1FNType type) {
4847 return hasRelu ? llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn_relu
4848 : llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn;
4849 })
4850 .Default([](mlir::Type type) {
4851 llvm_unreachable("Invalid type for ConvertF4x2ToF16x2Op");
4852 return llvm::Intrinsic::not_intrinsic;
4853 });
4854
     // The translated source value is zero-extended (not bitcast) to i16.
4855 llvm::Value *extendedI16 =
4856 builder.CreateZExt(mt.lookupValue(curOp.getSrc()),
4857 llvm::Type::getInt16Ty(builder.getContext()));
4858
4859 return {intId, {extendedI16}};
4860}
4861
// Picks the ff-to-s2f6x2 intrinsic (relu or non-relu variant) for a
// ConvertF32x2ToS2F6x2Op and collects its operands: A, B, then the scale
// factor.
// When the op has no scale-factor operand, the i16 constant 0x7f7f is passed
// in its place.
// NOTE(review): the declaration of `args` is not visible in this listing -
// confirm.
4862NVVM::IDArgPair ConvertF32x2ToS2F6x2Op::getIntrinsicIDAndArgs(
4863 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4864 auto thisOp = cast<NVVM::ConvertF32x2ToS2F6x2Op>(op);
4865 bool hasRelu = thisOp.getRelu();
4866 bool hasScale = static_cast<bool>(thisOp.getScaleFactor());
4867
4868 llvm::Intrinsic::ID id =
4869 hasRelu
4870 ? llvm::Intrinsic::nvvm_ff_to_s2f6x2_rn_relu_satfinite_scale_n2_ue8m0
4871 : llvm::Intrinsic::nvvm_ff_to_s2f6x2_rn_satfinite_scale_n2_ue8m0;
4872
4873 // Fill the Intrinsic Args
4875 args.push_back(mt.lookupValue(thisOp.getA()));
4876 args.push_back(mt.lookupValue(thisOp.getB()));
4877 args.push_back(hasScale ? mt.lookupValue(thisOp.getScaleFactor())
4878 : builder.getInt16(0x7f7f));
4879 return {id, std::move(args)};
4880}
4881
// Picks the bf16x2-to-s2f6x2 intrinsic (relu or non-relu variant) for a
// ConvertBF16x2ToS2F6x2Op and collects its operands: the source, then the
// scale factor.
// When the op has no scale-factor operand, the i16 constant 0x7f7f is passed
// in its place.
// NOTE(review): the declaration of `args` is not visible in this listing -
// confirm.
4882NVVM::IDArgPair ConvertBF16x2ToS2F6x2Op::getIntrinsicIDAndArgs(
4883 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4884 auto thisOp = cast<NVVM::ConvertBF16x2ToS2F6x2Op>(op);
4885 bool hasRelu = thisOp.getRelu();
4886 bool hasScale = static_cast<bool>(thisOp.getScaleFactor());
4887
4888 llvm::Intrinsic::ID id =
4889 hasRelu
4890 ? llvm::Intrinsic::
4891 nvvm_bf16x2_to_s2f6x2_rn_relu_satfinite_scale_n2_ue8m0
4892 : llvm::Intrinsic::nvvm_bf16x2_to_s2f6x2_rn_satfinite_scale_n2_ue8m0;
4893
4894 // Fill the Intrinsic Args
4896 args.push_back(mt.lookupValue(thisOp.getSrc()));
4897 args.push_back(hasScale ? mt.lookupValue(thisOp.getScaleFactor())
4898 : builder.getInt16(0x7f7f));
4899 return {id, std::move(args)};
4900}
4901
// Picks the s2f6x2-to-bf16x2 intrinsic for a ConvertS2F6x2ToBF16x2Op and
// collects its operands: the source bitcast to i16, then the scale factor.
// The intrinsic is chosen from `ids` by the op's relu and satfinite settings.
// When the op has no scale-factor operand, the i16 constant 0x7f7f is passed
// in its place.
// NOTE(review): the declaration of `args` is not visible in this listing -
// confirm.
4902NVVM::IDArgPair ConvertS2F6x2ToBF16x2Op::getIntrinsicIDAndArgs(
4903 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4904 auto thisOp = cast<NVVM::ConvertS2F6x2ToBF16x2Op>(op);
4905 bool hasRelu = thisOp.getRelu();
4906 bool hasScale = static_cast<bool>(thisOp.getScaleFactor());
4907 bool hasSat = thisOp.getSat() == NVVM::SaturationMode::SATFINITE;
4908
     // Index layout: bit 0 = relu, bit 1 = satfinite.
4909 static constexpr llvm::Intrinsic::ID ids[] = {
4910 llvm::Intrinsic::nvvm_s2f6x2_to_bf16x2_rn_scale_n2_ue8m0,
4911 llvm::Intrinsic::nvvm_s2f6x2_to_bf16x2_rn_relu_scale_n2_ue8m0,
4912 llvm::Intrinsic::nvvm_s2f6x2_to_bf16x2_rn_satfinite_scale_n2_ue8m0,
4913 llvm::Intrinsic::nvvm_s2f6x2_to_bf16x2_rn_relu_satfinite_scale_n2_ue8m0,
4914 };
4915
4916 unsigned idx = (hasSat << 1) | hasRelu;
4917
4918 // Fill the Intrinsic Args
4920 llvm::Value *packedI16 =
4921 builder.CreateBitCast(mt.lookupValue(thisOp.getSrc()),
4922 llvm::Type::getInt16Ty(builder.getContext()));
4923 args.push_back(packedI16);
4924 args.push_back(hasScale ? mt.lookupValue(thisOp.getScaleFactor())
4925 : builder.getInt16(0x7f7f));
4926
4927 return {ids[idx], std::move(args)};
4928}
4929
// Picks the tcgen05 alloc intrinsic for a Tcgen05AllocOp and appends its
// operands (address, then number of columns) to `args`.
// The "_shared" variant is used when the address pointer's address space
// equals NVVMMemorySpace::Shared; the cg2 variant is used when the op's group
// is CTAGroupKind::CTA_2, otherwise cg1.
// NOTE(review): the remaining parameter lines (which must declare `mt` and
// `args`) are not visible in this listing; `args` looks like an
// out-parameter because only the ID is returned - confirm.
4930llvm::Intrinsic::ID
4931Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op,
4934 auto curOp = cast<NVVM::Tcgen05AllocOp>(op);
4935 unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
4936 .getAddressSpace();
4937 bool isShared = as == NVVMMemorySpace::Shared;
4938 bool is2CTAMode = curOp.getGroup() == CTAGroupKind::CTA_2;
4939
4940 llvm::Intrinsic::ID id;
4941 if (isShared) {
4942 id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg2
4943 : llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg1;
4944 } else {
4945 id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_cg2
4946 : llvm::Intrinsic::nvvm_tcgen05_alloc_cg1;
4947 }
4948
4949 // Fill the Intrinsic Args
4950 args.push_back(mt.lookupValue(curOp.getAddr()));
4951 args.push_back(mt.lookupValue(curOp.getNCols()));
4952
4953 return id;
4954}
4955
// Picks the tcgen05 dealloc intrinsic for a Tcgen05DeallocOp and appends its
// operands (tensor address, then number of columns) to `args`.
// cg1 is used when the op's group is CTAGroupKind::CTA_1; every other group
// value selects cg2.
// NOTE(review): the parameter lines (which must declare `op`, `mt` and
// `args`) are not visible in this listing; `args` looks like an
// out-parameter because only the ID is returned - confirm.
4956llvm::Intrinsic::ID Tcgen05DeallocOp::getIntrinsicIDAndArgs(
4959 auto curOp = cast<NVVM::Tcgen05DeallocOp>(op);
4960 auto id = (curOp.getGroup() == CTAGroupKind::CTA_1)
4961 ? llvm::Intrinsic::nvvm_tcgen05_dealloc_cg1
4962 : llvm::Intrinsic::nvvm_tcgen05_dealloc_cg2;
4963
4964 // Fill the Intrinsic Args
4965 args.push_back(mt.lookupValue(curOp.getTaddr()));
4966 args.push_back(mt.lookupValue(curOp.getNCols()));
4967
4968 return id;
4969}
4970
// Chooses llvm::Intrinsic::nvvm_tcgen05_commit<mc>_shared_<cg> when
// `is_shared` is true and llvm::Intrinsic::nvvm_tcgen05_commit<mc>_<cg>
// otherwise. `mc` is a name suffix and may be empty.
// The condition argument and the replacement list are parenthesized so the
// macro is safe to use inside larger expressions.
#define TCGEN05_COMMIT_IMPL(cg, is_shared, mc)                                 \
  ((is_shared) ? llvm::Intrinsic::nvvm_tcgen05_commit##mc##_shared##_##cg      \
               : llvm::Intrinsic::nvvm_tcgen05_commit##mc##_##cg)

// Chooses the "_mc" intrinsic family when `has_mc` is true and the plain
// family otherwise, then defers to TCGEN05_COMMIT_IMPL for the shared choice.
#define GET_TCGEN05_COMMIT_ID(cta_group, is_shared, has_mc)                    \
  ((has_mc) ? TCGEN05_COMMIT_IMPL(cta_group, is_shared, _mc)                   \
            : TCGEN05_COMMIT_IMPL(cta_group, is_shared, ))
4978
// Picks the tcgen05 commit intrinsic for a Tcgen05CommitOp and appends its
// operands to `args`: the address, followed by the multicast mask only when
// the op has one.
// The intrinsic variant depends on three things: whether the address
// pointer's address space equals NVVMMemorySpace::Shared, whether a multicast
// mask is present, and whether the op's group is CTAGroupKind::CTA_2.
// NOTE(review): the remaining parameter lines (which must declare `mt` and
// `args`) are not visible in this listing; `args` looks like an
// out-parameter because only the ID is returned - confirm.
4979llvm::Intrinsic::ID
4980Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
4983 auto curOp = cast<NVVM::Tcgen05CommitOp>(op);
4984 unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
4985 .getAddressSpace();
4986 bool isShared = as == NVVMMemorySpace::Shared;
4987 bool hasMulticast = static_cast<bool>(curOp.getMulticastMask());
4988 bool is2CTAMode = curOp.getGroup() == CTAGroupKind::CTA_2;
4989
4990 llvm::Intrinsic::ID id =
4991 is2CTAMode ? GET_TCGEN05_COMMIT_ID(cg2, isShared, hasMulticast)
4992 : GET_TCGEN05_COMMIT_ID(cg1, isShared, hasMulticast);
4993
4994 // Fill the Intrinsic Args
4995 args.push_back(mt.lookupValue(curOp.getAddr()));
4996 if (hasMulticast)
4997 args.push_back(mt.lookupValue(curOp.getMulticastMask()));
4998
4999 return id;
5000}
5001
// Builds the intrinsic name llvm::Intrinsic::nvvm_tcgen05_cp<shape_mc>
// <src_fmt><cg> by token pasting. `src_fmt` is a name suffix and may be
// empty.
#define TCGEN05_CP_IMPL(shape_mc, src_fmt, cg)                                 \
  llvm::Intrinsic::nvvm_tcgen05_cp##shape_mc##src_fmt##cg

// Chooses the "_cg2" intrinsic when `is_2cta` is true and the "_cg1" one
// otherwise.
// The condition argument and the replacement list are parenthesized so the
// macro is safe to use inside larger expressions.
#define TCGEN05_CP_2CTA(shape_mc, src_fmt, is_2cta)                            \
  ((is_2cta) ? TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg2)                        \
             : TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg1))

// Expands to an immediately-invoked lambda that returns the intrinsic for the
// given shape. `src_fmt` is a runtime Tcgen05CpSrcFormat value here:
// B6x16_P32 and B4x16_P64 select the matching name suffix, and every other
// value selects the unsuffixed intrinsic.
#define GET_TCGEN05_CP_ID(shape_mc, src_fmt, is_2cta)                          \
  [&]() -> auto {                                                              \
    if ((src_fmt) == Tcgen05CpSrcFormat::B6x16_P32)                            \
      return TCGEN05_CP_2CTA(shape_mc, _b6x16_p32, is_2cta);                   \
    if ((src_fmt) == Tcgen05CpSrcFormat::B4x16_P64)                            \
      return TCGEN05_CP_2CTA(shape_mc, _b4x16_p64, is_2cta);                   \
    return TCGEN05_CP_2CTA(shape_mc, , is_2cta);                               \
  }()
5017
// Picks the ff2f16x2 intrinsic for a ConvertF32x2ToF16x2Op and collects its
// operands: the high source, the low source, and the random-bits operand when
// the op has one.
// There is one ID table per supported rounding mode (RN, RZ, RS); inside a
// table the entry is chosen by the relu and satfinite settings. Any other
// rounding mode is a programming error.
// NOTE(review): the return type, the parameter that declares `mt`, and the
// declaration of `args` are not visible in this listing - confirm.
5019ConvertF32x2ToF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF16x2Op &op,
5021 llvm::IRBuilderBase &builder) {
5022 static constexpr llvm::Intrinsic::ID rndRNIds[] = {
5023 llvm::Intrinsic::nvvm_ff2f16x2_rn,
5024 llvm::Intrinsic::nvvm_ff2f16x2_rn_relu,
5025 llvm::Intrinsic::nvvm_ff2f16x2_rn_satfinite,
5026 llvm::Intrinsic::nvvm_ff2f16x2_rn_relu_satfinite,
5027 };
5028 static constexpr llvm::Intrinsic::ID rndRZIds[] = {
5029 llvm::Intrinsic::nvvm_ff2f16x2_rz,
5030 llvm::Intrinsic::nvvm_ff2f16x2_rz_relu,
5031 llvm::Intrinsic::nvvm_ff2f16x2_rz_satfinite,
5032 llvm::Intrinsic::nvvm_ff2f16x2_rz_relu_satfinite,
5033 };
5034 static constexpr llvm::Intrinsic::ID rndRSIds[] = {
5035 llvm::Intrinsic::nvvm_ff2f16x2_rs,
5036 llvm::Intrinsic::nvvm_ff2f16x2_rs_relu,
5037 llvm::Intrinsic::nvvm_ff2f16x2_rs_satfinite,
5038 llvm::Intrinsic::nvvm_ff2f16x2_rs_relu_satfinite,
5039 };
5040
5041 unsigned hasRelu = op.getRelu() ? 1 : 0;
5042 unsigned hasSatFinite =
5043 (op.getSat() == NVVM::SaturationMode::SATFINITE) ? 1 : 0;
5044 // idx: bit-0 - relu
5045 // bit-1 - satfinite
5046 unsigned idx = (hasSatFinite << 1) | hasRelu;
5047
5049 args.push_back(mt.lookupValue(op.getSrcHi()));
5050 args.push_back(mt.lookupValue(op.getSrcLo()));
     // The random-bits operand is appended whenever it is present; this
     // function does not check it against the rounding mode.
5051 if (op.getRandomBits())
5052 args.push_back(mt.lookupValue(op.getRandomBits()));
5053
5054 switch (op.getRnd()) {
5055 case FPRoundingMode::RN:
5056 return {rndRNIds[idx], std::move(args)};
5057 case FPRoundingMode::RZ:
5058 return {rndRZIds[idx], std::move(args)};
5059 case FPRoundingMode::RS:
5060 return {rndRSIds[idx], std::move(args)};
5061 default:
5062 llvm_unreachable("Invalid rounding mode for ConvertF32x2ToF16x2Op");
5063 }
5064}
5065
// Picks the ff2bf16x2 intrinsic for a ConvertF32x2ToBF16x2Op and collects its
// operands: the high source, the low source, and the random-bits operand when
// the op has one.
// There is one ID table per supported rounding mode (RN, RZ, RS); inside a
// table the entry is chosen by the relu and satfinite settings. Any other
// rounding mode is a programming error.
// NOTE(review): the return type, the parameter that declares `mt`, and the
// declaration of `args` are not visible in this listing - confirm.
5067ConvertF32x2ToBF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToBF16x2Op &op,
5069 llvm::IRBuilderBase &builder) {
5070 static constexpr llvm::Intrinsic::ID rndRNIds[] = {
5071 llvm::Intrinsic::nvvm_ff2bf16x2_rn,
5072 llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu,
5073 llvm::Intrinsic::nvvm_ff2bf16x2_rn_satfinite,
5074 llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu_satfinite,
5075 };
5076 static constexpr llvm::Intrinsic::ID rndRZIds[] = {
5077 llvm::Intrinsic::nvvm_ff2bf16x2_rz,
5078 llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu,
5079 llvm::Intrinsic::nvvm_ff2bf16x2_rz_satfinite,
5080 llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu_satfinite,
5081 };
5082 static constexpr llvm::Intrinsic::ID rndRSIds[] = {
5083 llvm::Intrinsic::nvvm_ff2bf16x2_rs,
5084 llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu,
5085 llvm::Intrinsic::nvvm_ff2bf16x2_rs_satfinite,
5086 llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu_satfinite,
5087 };
5088
5089 unsigned hasRelu = op.getRelu() ? 1 : 0;
5090 unsigned hasSatFinite =
5091 (op.getSat() == NVVM::SaturationMode::SATFINITE) ? 1 : 0;
5092 // idx: bit-0 - relu
5093 // bit-1 - satfinite
5094 unsigned idx = (hasSatFinite << 1) | hasRelu;
5095
5097 args.push_back(mt.lookupValue(op.getSrcHi()));
5098 args.push_back(mt.lookupValue(op.getSrcLo()));
     // The random-bits operand is appended whenever it is present; this
     // function does not check it against the rounding mode.
5099 if (op.getRandomBits())
5100 args.push_back(mt.lookupValue(op.getRandomBits()));
5101
5102 switch (op.getRnd()) {
5103 case FPRoundingMode::RN:
5104 return {rndRNIds[idx], std::move(args)};
5105 case FPRoundingMode::RZ:
5106 return {rndRZIds[idx], std::move(args)};
5107 case FPRoundingMode::RS:
5108 return {rndRSIds[idx], std::move(args)};
5109 default:
5110 llvm_unreachable("Invalid rounding mode for ConvertF32x2ToBF16x2Op");
5111 }
5112}
5113
5114llvm::Intrinsic::ID ConvertF32x4ToF8x4Op::getIntrinsicID() {
5115 mlir::Type dstTy = getDstTy();
5116 bool hasRelu = getRelu();
5117
5119 .Case([&](mlir::Float8E4M3FNType) {
5120 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite
5121 : llvm::Intrinsic::nvvm_f32x4_to_e4m3x4_rs_satfinite;
5122 })
5123 .Case([&](mlir::Float8E5M2Type) {
5124 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e5m2x4_rs_relu_satfinite
5125 : llvm::Intrinsic::nvvm_f32x4_to_e5m2x4_rs_satfinite;
5126 })
5127 .Default([](mlir::Type) {
5128 llvm_unreachable("Invalid F8 type in ConvertF32x4ToF8x4Op");
5129 return llvm::Intrinsic::not_intrinsic;
5130 });
5131}
5132
5133llvm::Intrinsic::ID ConvertF32x4ToF6x4Op::getIntrinsicID() {
5134 mlir::Type dstTy = getDstTy();
5135 bool hasRelu = getRelu();
5136
5138 .Case([&](mlir::Float6E2M3FNType) {
5139 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e2m3x4_rs_relu_satfinite
5140 : llvm::Intrinsic::nvvm_f32x4_to_e2m3x4_rs_satfinite;
5141 })
5142 .Case([&](mlir::Float6E3M2FNType) {
5143 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e3m2x4_rs_relu_satfinite
5144 : llvm::Intrinsic::nvvm_f32x4_to_e3m2x4_rs_satfinite;
5145 })
5146 .Default([](mlir::Type) {
5147 llvm_unreachable("Invalid F6 type in ConvertF32x4ToF6x4Op");
5148 return llvm::Intrinsic::not_intrinsic;
5149 });
5150}
5151
5152llvm::Intrinsic::ID ConvertF32x4ToF4x4Op::getIntrinsicID() {
5153 mlir::Type dstTy = getDstTy();
5154 bool hasRelu = getRelu();
5155
5157 .Case([&](mlir::Float4E2M1FNType) {
5158 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e2m1x4_rs_relu_satfinite
5159 : llvm::Intrinsic::nvvm_f32x4_to_e2m1x4_rs_satfinite;
5160 })
5161 .Default([](mlir::Type) {
5162 llvm_unreachable("Invalid F4 type in ConvertF32x4ToF4x4Op");
5163 return llvm::Intrinsic::not_intrinsic;
5164 });
5165}
5166
5167llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(Operation &op) {
5168 auto curOp = cast<NVVM::Tcgen05CpOp>(op);
5169 bool is2CTA = curOp.getGroup() == CTAGroupKind::CTA_2;
5170 auto srcFmt = curOp.getSrcFormat();
5171 auto mc = curOp.getMulticast();
5172
5173 switch (curOp.getShape()) {
5174 case Tcgen05CpShape::SHAPE_128x256b:
5175 return GET_TCGEN05_CP_ID(_128x256b, srcFmt, is2CTA);
5176 case Tcgen05CpShape::SHAPE_128x128b:
5177 return GET_TCGEN05_CP_ID(_128x128b, srcFmt, is2CTA);
5178 case Tcgen05CpShape::SHAPE_4x256b:
5179 return GET_TCGEN05_CP_ID(_4x256b, srcFmt, is2CTA);
5180 case Tcgen05CpShape::SHAPE_32x128b:
5181 return GET_TCGEN05_CP_ID(_32x128b_warpx4, srcFmt, is2CTA);
5182 case Tcgen05CpShape::SHAPE_64x128b:
5183 return (mc == Tcgen05CpMulticast::WARPX2_01_23)
5184 ? GET_TCGEN05_CP_ID(_64x128b_warpx2_01_23, srcFmt, is2CTA)
5185 : GET_TCGEN05_CP_ID(_64x128b_warpx2_02_13, srcFmt, is2CTA);
5186 }
5187 llvm_unreachable("Invalid shape in tcgen05 cp Op");
5188}
5189
5190// Returns the valid vector length for a given shape and vector length, the
5191// function models the table mentioned in the tcgen05.{ld, st} Op description
5192static unsigned isValidVectorLength(NVVM::Tcgen05LdStShape shape,
5193 unsigned vecLen) {
5194 if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X128B)
5195 return vecLen >= 2;
5196 if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X256B)
5197 return vecLen >= 4;
5198 return true;
5199}
5200
5201LogicalResult Tcgen05LdOp::verify() {
5202 LogicalResult result = success();
5203 if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
5204 result = emitError("shape 16x32bx2 requires offset argument");
5205
5206 if (getShape() != NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && getOffset())
5207 result = emitError("offset argument is only supported for shape 16x32bx2");
5208
5209 auto resTy = getRes().getType();
5210 unsigned resLen = isa<VectorType>(resTy)
5211 ? llvm::cast<VectorType>(resTy).getNumElements()
5212 : 1;
5213 if (!isValidVectorLength(getShape(), resLen))
5214 result = emitError(llvm::formatv("invalid result type length {0} for shape "
5215 "{1} in tcgen05.ld Op",
5216 resLen, stringifyEnum(getShape())));
5217
5218 return result;
5219}
5220
5221LogicalResult Tcgen05StOp::verify() {
5222 LogicalResult result = success();
5223 if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
5224 result = emitError("shape 16x32bx2 requires offset argument");
5225
5226 auto valTy = getVal().getType();
5227 unsigned valLen = isa<VectorType>(valTy)
5228 ? llvm::cast<VectorType>(valTy).getNumElements()
5229 : 1;
5230 if (!isValidVectorLength(getShape(), valLen))
5231 result = emitError(llvm::formatv("invalid input length {0} for shape "
5232 "{1} in tcgen05.st Op",
5233 valLen, stringifyEnum(getShape())));
5234
5235 return result;
5236}
5237
5238/// Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might
5239/// have ConstantRangeAttr.
5242 SetIntRangeFn setResultRanges) {
5243 if (auto rangeAttr = op->getAttrOfType<LLVM::ConstantRangeAttr>("range")) {
5244 setResultRanges(result, {rangeAttr.getLower(), rangeAttr.getUpper(),
5245 rangeAttr.getLower(), rangeAttr.getUpper()});
5246 } else {
5247 setResultRanges(result, IntegerValueRange::getMaxRange(result).getValue());
5248 }
5249}
5250
5251/// Verify the range attribute satisfies LLVM ConstantRange constructor
5252/// requirements for NVVM SpecialRangeableRegisterOp.
5253static LogicalResult
5255 std::optional<LLVM::ConstantRangeAttr> rangeAttr) {
5256 if (!rangeAttr)
5257 return success();
5258
5259 const llvm::APInt &lower = rangeAttr->getLower();
5260 const llvm::APInt &upper = rangeAttr->getUpper();
5261
5262 // Check LLVM ConstantRange constructor condition
5263 if (lower == upper && !lower.isMaxValue() && !lower.isMinValue()) {
5264 unsigned bitWidth = lower.getBitWidth();
5265 llvm::APInt minVal = llvm::APInt::getMinValue(bitWidth);
5266 llvm::APInt maxVal = llvm::APInt::getMaxValue(bitWidth);
5267 return op->emitOpError(
5268 "invalid range attribute: Lower == Upper, but they aren't min (")
5269 << llvm::toString(minVal, 10, false) << ") or max ("
5270 << llvm::toString(maxVal, 10, false)
5271 << ") value! This is an invalid constant range.";
5272 }
5273
5274 return success();
5275}
5276
5277static llvm::Value *getAsPackedI32(llvm::Value *arg,
5278 llvm::IRBuilderBase &builder) {
5279 return builder.CreateBitCast(arg,
5280 llvm::Type::getInt32Ty(builder.getContext()));
5281}
5282
5283NVVM::IDArgPair DotAccumulate4WayOp::getIntrinsicIDAndArgs(
5284 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
5285 auto curOp = cast<NVVM::DotAccumulate4WayOp>(op);
5286
5288 args.push_back(getAsPackedI32(mt.lookupValue(curOp.getA()), builder));
5289 args.push_back(getAsPackedI32(mt.lookupValue(curOp.getB()), builder));
5290 args.push_back(mt.lookupValue(curOp.getC()));
5291
5292 bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
5293 bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
5294 unsigned type = (isASigned << 1) | isBSigned;
5295 const llvm::Intrinsic::ID ids[] = {
5296 llvm::Intrinsic::nvvm_idp4a_u_u,
5297 llvm::Intrinsic::nvvm_idp4a_u_s,
5298 llvm::Intrinsic::nvvm_idp4a_s_u,
5299 llvm::Intrinsic::nvvm_idp4a_s_s,
5300 };
5301 return {ids[type], args};
5302}
5303
5304NVVM::IDArgPair DotAccumulate2WayOp::getIntrinsicIDAndArgs(
5305 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
5306 auto curOp = cast<NVVM::DotAccumulate2WayOp>(op);
5307
5309 args.push_back(getAsPackedI32(mt.lookupValue(curOp.getA()), builder));
5310 args.push_back(getAsPackedI32(mt.lookupValue(curOp.getB()), builder));
5311 args.push_back(builder.getInt1(curOp.getBHi()));
5312 args.push_back(mt.lookupValue(curOp.getC()));
5313
5314 bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
5315 bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
5316 unsigned type = (isASigned << 1) | isBSigned;
5317 const llvm::Intrinsic::ID ids[] = {
5318 llvm::Intrinsic::nvvm_idp2a_u_u,
5319 llvm::Intrinsic::nvvm_idp2a_u_s,
5320 llvm::Intrinsic::nvvm_idp2a_s_u,
5321 llvm::Intrinsic::nvvm_idp2a_s_s,
5322 };
5323 return {ids[type], args};
5324}
5325
5326static llvm::Value *getParamCastedAddr(llvm::Value *addr,
5327 llvm::IRBuilderBase &builder) {
5328 return builder.CreateAddrSpaceCast(
5329 addr, builder.getPtrTy(llvm::NVPTXAS::ADDRESS_SPACE_ENTRY_PARAM));
5330}
5331
5333PrefetchOp::getIntrinsicIDAndArgs(NVVM::PrefetchOp &op,
5335 llvm::IRBuilderBase &builder) {
5336 using MemSpace = NVVM::NVVMMemorySpace;
5337 using CacheLevel = NVVM::PrefetchCacheLevel;
5338
5339 std::optional<NVVM::PrefetchCacheLevel> cacheLevel = op.getCacheLevel();
5340 std::optional<NVVM::CacheEvictionPriority> evictPriority =
5341 op.getEvictPriority();
5342 unsigned addressSpace =
5343 llvm::cast<LLVM::LLVMPointerType>(op.getAddr().getType())
5344 .getAddressSpace();
5345
5347 llvm::Value *addr = mt.lookupValue(op.getAddr());
5348 args.push_back(op.getInParamSpace() ? getParamCastedAddr(addr, builder)
5349 : addr);
5350
5351 if (op.getTensormap())
5352 return {llvm::Intrinsic::nvvm_prefetch_tensormap, args};
5353
5354 assert(cacheLevel && "expected cache level for non-tensormap prefetch");
5355
5356 if (op.getUniform() && *cacheLevel == CacheLevel::L1)
5357 return {llvm::Intrinsic::nvvm_prefetchu_L1, args};
5358
5359 if (evictPriority && *cacheLevel == CacheLevel::L2) {
5360 switch (*evictPriority) {
5361 case NVVM::CacheEvictionPriority::EvictLast:
5362 return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_last, args};
5363 case NVVM::CacheEvictionPriority::EvictNormal:
5364 return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_normal, args};
5365 default:
5366 llvm_unreachable("Invalid cache eviction priority");
5367 }
5368 }
5369
5370 switch (static_cast<MemSpace>(addressSpace)) {
5371 case MemSpace::Generic:
5372 return *cacheLevel == CacheLevel::L1
5373 ? NVVM::IDArgPair({llvm::Intrinsic::nvvm_prefetch_L1, args})
5374 : NVVM::IDArgPair({llvm::Intrinsic::nvvm_prefetch_L2, args});
5375 case MemSpace::Global:
5376 return *cacheLevel == CacheLevel::L1
5378 {llvm::Intrinsic::nvvm_prefetch_global_L1, args})
5379 : NVVM::IDArgPair(
5380 {llvm::Intrinsic::nvvm_prefetch_global_L2, args});
5381 case MemSpace::Local:
5382 return *cacheLevel == CacheLevel::L1
5384 {llvm::Intrinsic::nvvm_prefetch_local_L1, args})
5385 : NVVM::IDArgPair(
5386 {llvm::Intrinsic::nvvm_prefetch_local_L2, args});
5387 default:
5388 llvm_unreachable("Invalid pointer address space");
5389 }
5390}
5391
5392bool NVVM::InlinePtxOp::getAsmValues(
5393 RewriterBase &rewriter,
5394 llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
5395 &asmValues) {
5396 for (auto arg : getReadWriteArgs())
5397 asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::ReadWrite});
5398 for (auto arg : getResults())
5399 asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::Write});
5400 for (auto arg : getReadOnlyArgs())
5401 asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::Read});
5402 if (getPredicate())
5403 asmValues.push_back({getPredicate(), mlir::NVVM::PTXRegisterMod::Read});
5404 return false; // No manual mapping needed
5405}
5406
5407NVVM::IDArgPair ClusterLaunchControlTryCancelOp::getIntrinsicIDAndArgs(
5408 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
5409 auto curOp = cast<NVVM::ClusterLaunchControlTryCancelOp>(op);
5411 args.push_back(mt.lookupValue(curOp.getSmemAddress()));
5412 args.push_back(mt.lookupValue(curOp.getMbarrier()));
5413
5414 llvm::Intrinsic::ID intrinsicID =
5415 curOp.getMulticast()
5416 ? llvm::Intrinsic::
5417 nvvm_clusterlaunchcontrol_try_cancel_async_multicast_shared
5418 : llvm::Intrinsic::nvvm_clusterlaunchcontrol_try_cancel_async_shared;
5419
5420 return {intrinsicID, args};
5421}
5422
5423NVVM::IDArgPair ClusterLaunchControlQueryCancelOp::getIntrinsicIDAndArgs(
5424 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
5425 auto curOp = cast<NVVM::ClusterLaunchControlQueryCancelOp>(op);
5427 args.push_back(mt.lookupValue(curOp.getTryCancelResponse()));
5428
5429 llvm::Intrinsic::ID intrinsicID;
5430
5431 switch (curOp.getQueryType()) {
5432 case NVVM::ClusterLaunchControlQueryType::IS_CANCELED:
5433 intrinsicID =
5434 llvm::Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_is_canceled;
5435 break;
5436 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_X:
5437 intrinsicID = llvm::Intrinsic::
5438 nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_x;
5439 break;
5440 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Y:
5441 intrinsicID = llvm::Intrinsic::
5442 nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_y;
5443 break;
5444 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Z:
5445 intrinsicID = llvm::Intrinsic::
5446 nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_z;
5447 break;
5448 }
5449 return {intrinsicID, args};
5450}
5451
5453PermuteOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
5454 llvm::IRBuilderBase &builder) {
5455 auto thisOp = cast<NVVM::PermuteOp>(op);
5456 NVVM::PermuteMode mode = thisOp.getMode();
5457
5458 static constexpr llvm::Intrinsic::ID IDs[] = {
5459 llvm::Intrinsic::nvvm_prmt, llvm::Intrinsic::nvvm_prmt_f4e,
5460 llvm::Intrinsic::nvvm_prmt_b4e, llvm::Intrinsic::nvvm_prmt_rc8,
5461 llvm::Intrinsic::nvvm_prmt_ecl, llvm::Intrinsic::nvvm_prmt_ecr,
5462 llvm::Intrinsic::nvvm_prmt_rc16};
5463
5464 unsigned modeIndex = static_cast<unsigned>(mode);
5466 args.push_back(mt.lookupValue(thisOp.getLo()));
5467
5468 // Only first 3 modes (Default, f4e, b4e) need the hi operand.
5469 if (modeIndex < 3)
5470 args.push_back(mt.lookupValue(thisOp.getHi()));
5471
5472 args.push_back(mt.lookupValue(thisOp.getSelector()));
5473
5474 return {IDs[modeIndex], args};
5475}
5476
5477mlir::NVVM::IDArgPair TensormapReplaceOp::getIntrinsicIDAndArgs(
5478 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
5479 auto thisOp = cast<NVVM::TensormapReplaceOp>(op);
5480
5482 args.push_back(mt.lookupValue(thisOp.getAddr()));
5483 if (thisOp.getOrd())
5484 args.push_back(builder.getInt32(thisOp.getOrd().value()));
5485 if (thisOp.getNewValue())
5486 args.push_back(mt.lookupValue(thisOp.getNewValue()));
5487 if (auto attr = thisOp.getNewValueAttr()) {
5488 auto val =
5490 .Case<TensormapElemtypeAttr, TensormapInterleaveLayoutAttr,
5491 TensormapSwizzleModeAttr, TensormapSwizzleAtomicityAttr,
5492 TensormapFillModeAttr>([](auto attr) {
5493 return static_cast<unsigned>(attr.getValue());
5494 })
5495 .Default([](auto attr) {
5496 llvm_unreachable("Invalid attribute type");
5497 return 0;
5498 });
5499 args.push_back(builder.getInt32(val));
5500 }
5501
5502 static constexpr llvm::Intrinsic::ID IDs[] = {
5503 llvm::Intrinsic::nvvm_tensormap_replace_global_address,
5504 llvm::Intrinsic::nvvm_tensormap_replace_rank,
5505 llvm::Intrinsic::nvvm_tensormap_replace_box_dim,
5506 llvm::Intrinsic::nvvm_tensormap_replace_global_dim,
5507 llvm::Intrinsic::nvvm_tensormap_replace_global_stride,
5508 llvm::Intrinsic::nvvm_tensormap_replace_element_stride,
5509 llvm::Intrinsic::nvvm_tensormap_replace_elemtype,
5510 llvm::Intrinsic::nvvm_tensormap_replace_interleave_layout,
5511 llvm::Intrinsic::nvvm_tensormap_replace_swizzle_mode,
5512 llvm::Intrinsic::nvvm_tensormap_replace_swizzle_atomicity,
5513 llvm::Intrinsic::nvvm_tensormap_replace_fill_mode,
5514 };
5515
5516 unsigned fieldIndex = static_cast<unsigned>(thisOp.getField());
5517
5518 return {IDs[fieldIndex], args};
5519}
5520
5521//===----------------------------------------------------------------------===//
5522// NVVM tcgen05.mma functions
5523//===----------------------------------------------------------------------===//
5524
5526Tcgen05MMAOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
5527 llvm::IRBuilderBase &builder) {
5528
5529 auto thisOp = cast<NVVM::Tcgen05MMAOp>(op);
5531
5532 args.push_back(mt.lookupValue(thisOp.getMatrixD()));
5533
5534 llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
5535 const bool isATensor = isa<llvm::PointerType>(A->getType());
5536 args.push_back(A);
5537
5538 args.push_back(mt.lookupValue(thisOp.getMatrixB()));
5539 args.push_back(mt.lookupValue(thisOp.getIdesc()));
5540 args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
5541
5542 using EnableAShiftArray = std::array<llvm::Intrinsic::ID, 2>;
5543 using CtaGroupArray = std::array<EnableAShiftArray, 2>;
5544 using IsATensorArray = std::array<CtaGroupArray, 2>;
5545 using HasScaleInputDArray = std::array<IsATensorArray, 2>;
5546 using HasDisableOutputLaneArray = std::array<HasScaleInputDArray, 2>;
5547
5548 // [hasDisableOutputLane][hasScaleInputD][isATensor][CtaGroup][EnableAShift]
5549 static constexpr HasDisableOutputLaneArray tcgen05MMAIDs = {
5550 { // without diable output lane
5551 {{// without scale input D
5552 {{
5553 // shared
5554 {{// cg1
5555 {llvm::Intrinsic::nvvm_tcgen05_mma_shared, notIntrinsic},
5556 // cg2
5557 {llvm::Intrinsic::nvvm_tcgen05_mma_shared, notIntrinsic}}},
5558 {{// tensor
5559 {
5560 // cg1
5561 llvm::Intrinsic::nvvm_tcgen05_mma_tensor,
5562 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_ashift,
5563 },
5564 {
5565 // cg2
5566 llvm::Intrinsic::nvvm_tcgen05_mma_tensor,
5567 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_ashift,
5568 }}},
5569 }},
5570 // with scale input D
5571 {{ // shared
5572 {{// cg1
5573 {llvm::Intrinsic::nvvm_tcgen05_mma_shared_scale_d, notIntrinsic},
5574 // cg2
5575 {llvm::Intrinsic::nvvm_tcgen05_mma_shared_scale_d, notIntrinsic}}},
5576 {{// tensor
5577 {
5578 // cg1
5579 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d,
5580 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_ashift,
5581 },
5582 {
5583 // cg2
5584 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d,
5585 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_ashift,
5586 }}}}}}},
5587 // with disable output lane
5588 {{ // without scale input D
5589 {{ // shared
5590 {{// cg1
5591 {llvm::Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg1,
5592 notIntrinsic},
5593 // cg2
5594 {llvm::Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg2,
5595 notIntrinsic}}},
5596 {{// cg1
5597 {
5598 llvm::Intrinsic::
5599 nvvm_tcgen05_mma_tensor_disable_output_lane_cg1,
5600 llvm::Intrinsic::
5601 nvvm_tcgen05_mma_tensor_disable_output_lane_cg1_ashift,
5602 },
5603 // cg2
5604 {
5605 llvm::Intrinsic::
5606 nvvm_tcgen05_mma_tensor_disable_output_lane_cg2,
5607 llvm::Intrinsic::
5608 nvvm_tcgen05_mma_tensor_disable_output_lane_cg2_ashift,
5609 }}}}},
5610 // with scale input D
5611 {{ // shared
5612 {{// cg1
5613 {llvm::Intrinsic::
5614 nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg1,
5615 notIntrinsic},
5616 // cg2
5617 {llvm::Intrinsic::
5618 nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg2,
5619 notIntrinsic}}},
5620 // tensor
5621 {{// cg1
5622 {llvm::Intrinsic::
5623 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1,
5624 llvm::Intrinsic::
5625 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1_ashift},
5626 // cg2
5627 {
5628 llvm::Intrinsic::
5629 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2,
5630 llvm::Intrinsic::
5631 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2_ashift,
5632 }}}}}}}}};
5633
5634 llvm::Value *ScaleInputD = mt.lookupValue(thisOp.getScaleInputD());
5635 bool hasScaleInputD = ScaleInputD != nullptr;
5636
5637 llvm::Value *DisableOutputLane =
5638 mt.lookupValue(thisOp.getDisableOutputLane());
5639 bool hasDisableOutputLane = DisableOutputLane != nullptr;
5640
5641 const unsigned ctaGroup =
5642 static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup()));
5643
5644 llvm::Intrinsic::ID ID =
5645 tcgen05MMAIDs[hasDisableOutputLane][hasScaleInputD][isATensor]
5646 [ctaGroup - 1][thisOp.getAShift()];
5647
5648 assert(ID != notIntrinsic && "Invalid intrinsic for Tcgen05MMAOp.");
5649
5650 if (hasScaleInputD)
5651 args.push_back(ScaleInputD);
5652
5653 if (hasDisableOutputLane)
5654 args.push_back(DisableOutputLane);
5655
5656 args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
5657
5658 if (!hasDisableOutputLane)
5659 args.push_back(builder.getInt32(ctaGroup));
5660
5661 args.push_back(
5662 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
5663
5664 return {ID, args};
5665}
5666
5667static LogicalResult
5668verifyTcgen05MMAOp(bool isATensor, mlir::Value disableOutputLane,
5669 NVVM::CTAGroupKind ctaGroup, bool hasAShift,
5670 NVVM::Tcgen05MMACollectorOp collectorOp, Location loc) {
5671
5672 if (disableOutputLane) {
5673 mlir::VectorType disableOutputLaneType =
5674 cast<mlir::VectorType>(disableOutputLane.getType());
5675 if ((ctaGroup == NVVM::CTAGroupKind::CTA_1 &&
5676 disableOutputLaneType.getNumElements() != 4) ||
5677 (ctaGroup == NVVM::CTAGroupKind::CTA_2 &&
5678 disableOutputLaneType.getNumElements() != 8))
5679 return emitError(loc) << "Disable Output Lane of length "
5680 << disableOutputLaneType.getNumElements()
5681 << " is incompatible with CtaGroupAttr";
5682 }
5683
5684 if (hasAShift && !isATensor)
5685 return emitError(
5686 loc, "A-shift can be applied only when matrix A is in tensor memory");
5687
5688 if (hasAShift == true && (collectorOp == Tcgen05MMACollectorOp::FILL ||
5689 collectorOp == Tcgen05MMACollectorOp::USE))
5690 return emitError(
5691 loc, "Cannot use collector buffer operation fill or use with ashift");
5692
5693 return success();
5694}
5695
5696LogicalResult Tcgen05MMAOp::verify() {
5697 return verifyTcgen05MMAOp(isa<LLVM::LLVMPointerType>(getMatrixA().getType()),
5698 getDisableOutputLane(), getCtaGroup(), getAShift(),
5699 getCollectorOp(), getLoc());
5700}
5701
5702//===----------------------------------------------------------------------===//
5703// NVVM tcgen05.mma.sp functions
5704//===----------------------------------------------------------------------===//
5705
5706mlir::NVVM::IDArgPair Tcgen05MMASparseOp::getIntrinsicIDAndArgs(
5707 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
5708
5709 auto thisOp = cast<NVVM::Tcgen05MMASparseOp>(op);
5711
5712 args.push_back(mt.lookupValue(thisOp.getMatrixD()));
5713
5714 llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
5715 bool isATensor = isa<llvm::PointerType>(A->getType());
5716 args.push_back(A);
5717
5718 args.push_back(mt.lookupValue(thisOp.getMatrixB()));
5719 args.push_back(mt.lookupValue(thisOp.getIdesc()));
5720 args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
5721 args.push_back(mt.lookupValue(thisOp.getSparseMetadata()));
5722
5723 using EnableAShiftArray = std::array<llvm::Intrinsic::ID, 2>;
5724 using CtaGroupArray = std::array<EnableAShiftArray, 2>;
5725 using IsATensorArray = std::array<CtaGroupArray, 2>;
5726 using HasScaleInputDArray = std::array<IsATensorArray, 2>;
5727 using HasDisableOutputLaneArray = std::array<HasScaleInputDArray, 2>;
5728
5729 // [hasDisableOutputLane][hasScaleInputD][isATensor][CtaGroup][EnableAShift]
5730 static constexpr HasDisableOutputLaneArray tcgen05MMASparseIDs = {
5731 { // without diable output lane
5732 {{// without scale input D
5733 {{
5734 // shared
5735 {{// cg1
5736 {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared, notIntrinsic},
5737 // cg2
5738 {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared, notIntrinsic}}},
5739 {{// tensor
5740 {
5741 // cg1
5742 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor,
5743 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_ashift,
5744 },
5745 {
5746 // cg2
5747 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor,
5748 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_ashift,
5749 }}},
5750 }},
5751 // with scale input D
5752 {{ // shared
5753 {{// cg1
5754 {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d,
5755 notIntrinsic},
5756 // cg2
5757 {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d,
5758 notIntrinsic}}},
5759 {{// tensor
5760 {
5761 // cg1
5762 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d,
5763 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_ashift,
5764 },
5765 {
5766 // cg2
5767 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d,
5768 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_ashift,
5769 }}}}}}},
5770 // with disable output lane
5771 {{ // without scale input D
5772 {{ // shared
5773 {{// cg1
5774 {llvm::Intrinsic::
5775 nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg1,
5776 notIntrinsic},
5777 // cg2
5778 {llvm::Intrinsic::
5779 nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg2,
5780 notIntrinsic}}},
5781 {{// cg1
5782 {
5783 llvm::Intrinsic::
5784 nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1,
5785 llvm::Intrinsic::
5786 nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1_ashift,
5787 },
5788 // cg2
5789 {
5790 llvm::Intrinsic::
5791 nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2,
5792 llvm::Intrinsic::
5793 nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2_ashift,
5794 }}}}},
5795 // with scale input D
5796 {{ // shared
5797 {{// cg1
5798 {llvm::Intrinsic::
5799 nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg1,
5800 notIntrinsic},
5801 // cg2
5802 {llvm::Intrinsic::
5803 nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg2,
5804 notIntrinsic}}},
5805 // tensor
5806 {{// cg1
5807 {llvm::Intrinsic::
5808 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1,
5809 llvm::Intrinsic::
5810 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1_ashift},
5811 // cg2
5812 {
5813 llvm::Intrinsic::
5814 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2,
5815 llvm::Intrinsic::
5816 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2_ashift,
5817 }}}}}}}}};
5818
5819 llvm::Value *ScaleInputD = mt.lookupValue(thisOp.getScaleInputD());
5820 bool hasScaleInputD = ScaleInputD != nullptr;
5821
5822 llvm::Value *DisableOutputLane =
5823 mt.lookupValue(thisOp.getDisableOutputLane());
5824 bool hasDisableOutputLane = DisableOutputLane != nullptr;
5825
5826 unsigned ctaGroup =
5827 static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup()));
5828
5829 llvm::Intrinsic::ID ID =
5830 tcgen05MMASparseIDs[hasDisableOutputLane][hasScaleInputD][isATensor]
5831 [ctaGroup - 1][thisOp.getAShift()];
5832
5833 assert(ID != notIntrinsic && "Invalid intrinsic for Tcgen05MMASparseOp.");
5834
5835 if (hasScaleInputD)
5836 args.push_back(ScaleInputD);
5837
5838 if (hasDisableOutputLane)
5839 args.push_back(DisableOutputLane);
5840
5841 args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
5842
5843 if (!hasDisableOutputLane)
5844 args.push_back(builder.getInt32(ctaGroup));
5845
5846 args.push_back(
5847 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
5848
5849 return {ID, args};
5850}
5851
5852LogicalResult Tcgen05MMASparseOp::verify() {
5853 return verifyTcgen05MMAOp(isa<LLVM::LLVMPointerType>(getMatrixA().getType()),
5854 getDisableOutputLane(), getCtaGroup(), getAShift(),
5855 getCollectorOp(), getLoc());
5856}
5857
5858//===----------------------------------------------------------------------===//
5859// NVVM tcgen05.mma.block_scale functions
5860//===----------------------------------------------------------------------===//
5861
5862mlir::NVVM::IDArgPair Tcgen05MMABlockScaleOp::getIntrinsicIDAndArgs(
5863 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
5864
5865 auto thisOp = cast<NVVM::Tcgen05MMABlockScaleOp>(op);
5867
5868 args.push_back(mt.lookupValue(thisOp.getMatrixD()));
5869
5870 llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
5871 bool isATensor = isa<llvm::PointerType>(A->getType());
5872 args.push_back(A);
5873
5874 args.push_back(mt.lookupValue(thisOp.getMatrixB()));
5875 args.push_back(mt.lookupValue(thisOp.getIdesc()));
5876 args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
5877 args.push_back(mt.lookupValue(thisOp.getScaleA()));
5878 args.push_back(mt.lookupValue(thisOp.getScaleB()));
5879 args.push_back(builder.getInt32(
5880 static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup()))));
5881 args.push_back(
5882 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
5883
5884 auto kind = thisOp.getKind();
5885 auto blockScale = thisOp.getBlockScale();
5886 llvm::Intrinsic::ID ID = [&]() {
5887 if (kind == NVVM::Tcgen05MMAKind::MXF8F6F4) {
5888 if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
5889 return isATensor ? llvm::Intrinsic::
5890 nvvm_tcgen05_mma_tensor_mxf8f6f4_block_scale
5891 : llvm::Intrinsic::
5892 nvvm_tcgen05_mma_shared_mxf8f6f4_block_scale;
5893 } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
5894 return isATensor
5895 ? llvm::Intrinsic::
5896 nvvm_tcgen05_mma_tensor_mxf8f6f4_block_scale_block32
5897 : llvm::Intrinsic::
5898 nvvm_tcgen05_mma_shared_mxf8f6f4_block_scale_block32;
5899 }
5900 } else if (kind == NVVM::Tcgen05MMAKind::MXF4) {
5901 if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
5902 return isATensor
5903 ? llvm::Intrinsic::nvvm_tcgen05_mma_tensor_mxf4_block_scale
5904 : llvm::Intrinsic::nvvm_tcgen05_mma_shared_mxf4_block_scale;
5905 } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
5906 return isATensor ? llvm::Intrinsic::
5907 nvvm_tcgen05_mma_tensor_mxf4_block_scale_block32
5908 : llvm::Intrinsic::
5909 nvvm_tcgen05_mma_shared_mxf4_block_scale_block32;
5910 }
5911 } else if (kind == NVVM::Tcgen05MMAKind::MXF4NVF4) {
5912 if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
5913 return isATensor
5914 ? llvm::Intrinsic::
5915 nvvm_tcgen05_mma_tensor_mxf4nvf4_block_scale_block32
5916 : llvm::Intrinsic::
5917 nvvm_tcgen05_mma_shared_mxf4nvf4_block_scale_block32;
5918
5919 } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16) {
5920 return isATensor
5921 ? llvm::Intrinsic::
5922 nvvm_tcgen05_mma_tensor_mxf4nvf4_block_scale_block16
5923 : llvm::Intrinsic::
5924 nvvm_tcgen05_mma_shared_mxf4nvf4_block_scale_block16;
5925 }
5926 }
5927 llvm_unreachable("Invalid tcgen05.mma.block_scale attributes");
5928 }();
5929
5930 return {ID, args};
5931}
5932
5933static LogicalResult verifyTcgen05MMABlockScaleOp(
5934 NVVM::Tcgen05MMACollectorOp collectorOp, NVVM::Tcgen05MMAKind kind,
5935 NVVM::Tcgen05MMABlockScale blockScale, Location loc) {
5936 if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT &&
5937 kind == NVVM::Tcgen05MMAKind::MXF4NVF4)
5938 return emitError(loc, "mxf4nvf4 requires block scale attribute");
5939
5940 if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16 &&
5941 kind != NVVM::Tcgen05MMAKind::MXF4NVF4)
5942 return emitError(loc,
5943 llvm::formatv("{} kind does not support block16 attribute",
5944 stringifyEnum(kind)));
5945
5946 return success();
5947}
5948
5949LogicalResult Tcgen05MMABlockScaleOp::verify() {
5950 return verifyTcgen05MMABlockScaleOp(getCollectorOp(), getKind(),
5951 getBlockScale(), getLoc());
5952}
5953
5954//===----------------------------------------------------------------------===//
5955// NVVM tcgen05.mma.sp.block_scale functions
5956//===----------------------------------------------------------------------===//
5957
5958mlir::NVVM::IDArgPair Tcgen05MMASparseBlockScaleOp::getIntrinsicIDAndArgs(
5959 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
5960
5961 auto thisOp = cast<NVVM::Tcgen05MMASparseBlockScaleOp>(op);
5963
5964 args.push_back(mt.lookupValue(thisOp.getMatrixD()));
5965
5966 llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
5967 bool isATensor = isa<llvm::PointerType>(A->getType());
5968 args.push_back(A);
5969
5970 args.push_back(mt.lookupValue(thisOp.getMatrixB()));
5971 args.push_back(mt.lookupValue(thisOp.getIdesc()));
5972 args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
5973 args.push_back(mt.lookupValue(thisOp.getSparseMetadata()));
5974 args.push_back(mt.lookupValue(thisOp.getScaleA()));
5975 args.push_back(mt.lookupValue(thisOp.getScaleB()));
5976 args.push_back(builder.getInt32(
5977 static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup()))));
5978 args.push_back(
5979 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
5980
5981 auto kind = thisOp.getKind();
5982 auto blockScale = thisOp.getBlockScale();
5983 llvm::Intrinsic::ID ID = [&]() {
5984 if (kind == NVVM::Tcgen05MMAKind::MXF8F6F4) {
5985 if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
5986 return isATensor ? llvm::Intrinsic::
5987 nvvm_tcgen05_mma_sp_tensor_mxf8f6f4_block_scale
5988 : llvm::Intrinsic::
5989 nvvm_tcgen05_mma_sp_shared_mxf8f6f4_block_scale;
5990 } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
5991 return isATensor
5992 ? llvm::Intrinsic::
5993 nvvm_tcgen05_mma_sp_tensor_mxf8f6f4_block_scale_block32
5994 : llvm::Intrinsic::
5995 nvvm_tcgen05_mma_sp_shared_mxf8f6f4_block_scale_block32;
5996 }
5997 } else if (kind == NVVM::Tcgen05MMAKind::MXF4) {
5998 if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
5999 return isATensor ? llvm::Intrinsic::
6000 nvvm_tcgen05_mma_sp_tensor_mxf4_block_scale
6001 : llvm::Intrinsic::
6002 nvvm_tcgen05_mma_sp_shared_mxf4_block_scale;
6003 } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
6004 return isATensor
6005 ? llvm::Intrinsic::
6006 nvvm_tcgen05_mma_sp_tensor_mxf4_block_scale_block32
6007 : llvm::Intrinsic::
6008 nvvm_tcgen05_mma_sp_shared_mxf4_block_scale_block32;
6009 }
6010 } else if (kind == NVVM::Tcgen05MMAKind::MXF4NVF4) {
6011 if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
6012 return isATensor
6013 ? llvm::Intrinsic::
6014 nvvm_tcgen05_mma_sp_tensor_mxf4nvf4_block_scale_block32
6015 : llvm::Intrinsic::
6016 nvvm_tcgen05_mma_sp_shared_mxf4nvf4_block_scale_block32;
6017
6018 } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16) {
6019 return isATensor
6020 ? llvm::Intrinsic::
6021 nvvm_tcgen05_mma_sp_tensor_mxf4nvf4_block_scale_block16
6022 : llvm::Intrinsic::
6023 nvvm_tcgen05_mma_sp_shared_mxf4nvf4_block_scale_block16;
6024 }
6025 }
6026 llvm_unreachable("Invalid tcgen05.mma.sp.block_scale attributes");
6027 }();
6028
6029 return {ID, args};
6030}
6031
6032LogicalResult Tcgen05MMASparseBlockScaleOp::verify() {
6033 return verifyTcgen05MMABlockScaleOp(getCollectorOp(), getKind(),
6034 getBlockScale(), getLoc());
6035}
6036
6037//===----------------------------------------------------------------------===//
6038// NVVM tcgen05.mma.ws functions
6039//===----------------------------------------------------------------------===//
6040
6041mlir::NVVM::IDArgPair Tcgen05MMAWsOp::getIntrinsicIDAndArgs(
6042 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
6043
6044 auto thisOp = cast<NVVM::Tcgen05MMAWsOp>(op);
6046
6047 args.push_back(mt.lookupValue(thisOp.getMatrixD()));
6048
6049 llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
6050 bool isATensor = isa<llvm::PointerType>(A->getType());
6051 args.push_back(A);
6052
6053 args.push_back(mt.lookupValue(thisOp.getMatrixB()));
6054 args.push_back(mt.lookupValue(thisOp.getIdesc()));
6055 args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
6056
6057 mlir::Value ZeroColMask = thisOp.getZeroColMask();
6058 llvm::Intrinsic::ID ID = notIntrinsic;
6059 if (ZeroColMask) {
6060 args.push_back(mt.lookupValue(ZeroColMask));
6061 ID = isATensor ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_tensor_zero_col_mask
6062 : llvm::Intrinsic::nvvm_tcgen05_mma_ws_shared_zero_col_mask;
6063 } else
6064 ID = isATensor ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_tensor
6065 : llvm::Intrinsic::nvvm_tcgen05_mma_ws_shared;
6066
6067 args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
6068 args.push_back(
6069 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorBBuffer())));
6070 args.push_back(
6071 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
6072
6073 return {ID, args};
6074}
6075
6076//===----------------------------------------------------------------------===//
6077// NVVM tcgen05.mma.ws.sp functions
6078//===----------------------------------------------------------------------===//
6079
6080mlir::NVVM::IDArgPair Tcgen05MMAWsSparseOp::getIntrinsicIDAndArgs(
6081 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
6082
6083 auto thisOp = cast<NVVM::Tcgen05MMAWsSparseOp>(op);
6085
6086 args.push_back(mt.lookupValue(thisOp.getMatrixD()));
6087
6088 llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
6089 bool isATensor = isa<llvm::PointerType>(A->getType());
6090 args.push_back(A);
6091
6092 args.push_back(mt.lookupValue(thisOp.getMatrixB()));
6093 args.push_back(mt.lookupValue(thisOp.getIdesc()));
6094 args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
6095 args.push_back(mt.lookupValue(thisOp.getSparseMetadata()));
6096
6097 mlir::Value ZeroColMask = thisOp.getZeroColMask();
6098 llvm::Intrinsic::ID ID = notIntrinsic;
6099 if (ZeroColMask) {
6100 args.push_back(mt.lookupValue(ZeroColMask));
6101 ID = isATensor
6102 ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_tensor_zero_col_mask
6103 : llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_shared_zero_col_mask;
6104 } else
6105 ID = isATensor ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_tensor
6106 : llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_shared;
6107
6108 args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
6109 args.push_back(
6110 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorBBuffer())));
6111 args.push_back(
6112 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
6113
6114 return {ID, args};
6115}
6116
6117//===----------------------------------------------------------------------===//
6118// NVVM tcgen05.ld.red functions
6119//===----------------------------------------------------------------------===//
6120
6121#define TCGEN05LDRED(SHAPE, NUM, TYPE) \
6122 llvm::Intrinsic::nvvm_tcgen05_ld_red_##SHAPE##_##NUM##_##TYPE
6123
6124mlir::NVVM::IDArgPair NVVM::Tcgen05LdRedOp::getIntrinsicIDAndArgs(
6125 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
6126 auto thisOp = cast<NVVM::Tcgen05LdRedOp>(op);
6128
6129 mlir::VectorType VecResTy =
6130 cast<mlir::VectorType>(thisOp.getData().getType());
6131 unsigned Num = VecResTy.getNumElements();
6132 bool IsFloat = thisOp.getRedVal().getType().isF32();
6133
6134 llvm::Intrinsic::ID Shape32x32b[][2] = {
6136 {TCGEN05LDRED(32x32b, x2, i32), TCGEN05LDRED(32x32b, x2, f32)},
6137 {TCGEN05LDRED(32x32b, x4, i32), TCGEN05LDRED(32x32b, x4, f32)},
6138 {TCGEN05LDRED(32x32b, x8, i32), TCGEN05LDRED(32x32b, x8, f32)},
6139 {TCGEN05LDRED(32x32b, x16, i32), TCGEN05LDRED(32x32b, x16, f32)},
6140 {TCGEN05LDRED(32x32b, x32, i32), TCGEN05LDRED(32x32b, x32, f32)},
6141 {TCGEN05LDRED(32x32b, x64, i32), TCGEN05LDRED(32x32b, x64, f32)},
6142 {TCGEN05LDRED(32x32b, x128, i32), TCGEN05LDRED(32x32b, x128, f32)},
6143 };
6144
6145 llvm::Intrinsic::ID Shape16x32bx2[][2] = {
6147 {TCGEN05LDRED(16x32bx2, x2, i32), TCGEN05LDRED(16x32bx2, x2, f32)},
6148 {TCGEN05LDRED(16x32bx2, x4, i32), TCGEN05LDRED(16x32bx2, x4, f32)},
6149 {TCGEN05LDRED(16x32bx2, x8, i32), TCGEN05LDRED(16x32bx2, x8, f32)},
6150 {TCGEN05LDRED(16x32bx2, x16, i32), TCGEN05LDRED(16x32bx2, x16, f32)},
6151 {TCGEN05LDRED(16x32bx2, x32, i32), TCGEN05LDRED(16x32bx2, x32, f32)},
6152 {TCGEN05LDRED(16x32bx2, x64, i32), TCGEN05LDRED(16x32bx2, x64, f32)},
6153 {TCGEN05LDRED(16x32bx2, x128, i32), TCGEN05LDRED(16x32bx2, x128, f32)},
6154 };
6155
6156 NVVM::Tcgen05LdStShape shape = thisOp.getShape();
6157 unsigned ID = [&]() {
6158 // `num` contains the length of vector and log2 of `num` returns the index
6159 // into the shape array
6160 unsigned idx = std::log2(Num);
6161 switch (shape) {
6162 case NVVM::Tcgen05LdStShape::SHAPE_32X32B:
6163 return Shape32x32b[idx][IsFloat];
6164 case NVVM::Tcgen05LdStShape::SHAPE_16X32BX2:
6165 return Shape16x32bx2[idx][IsFloat];
6166 default:
6167 llvm_unreachable("unhandled tcgen05.ld lowering");
6168 }
6169 }();
6170
6171 args.push_back(mt.lookupValue(thisOp.getAddr()));
6172
6173 if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2)
6174 args.push_back(mt.lookupValue(thisOp.getOffset()));
6175
6176 args.push_back(
6177 builder.getInt32(thisOp.getOp() == NVVM::ReductionKind::MIN ? 0 : 1));
6178
6179 if (IsFloat) {
6180 args.push_back(builder.getInt1(static_cast<unsigned>(thisOp.getAbs())));
6181 args.push_back(builder.getInt1(static_cast<unsigned>(thisOp.getNan())));
6182 }
6183 return {ID, args};
6184}
6185
6186LogicalResult Tcgen05LdRedOp::verify() {
6187 VectorType data = cast<VectorType>(getData().getType());
6188 Type redVal = getRedVal().getType();
6189
6190 if (data.getElementType() != redVal)
6191 return emitError(
6192 "type of reduction value and element type of vector data should match");
6193
6194 if (getOp() != NVVM::ReductionKind::MIN &&
6195 getOp() != NVVM::ReductionKind::MAX)
6196 return emitError("only min and max reduction kinds are supported");
6197
6198 if (redVal.isInteger() && (getAbs() || getNan())) {
6199 return emitError("abs or nan is only applicable for f32 type");
6200 }
6201 return success();
6202}
6203
6204//===----------------------------------------------------------------------===//
6205// NVVMDialect initialization, type parsing, and registration.
6206//===----------------------------------------------------------------------===//
6207
6208// TODO: This should be the llvm.nvvm dialect once this is supported.
6209void NVVMDialect::initialize() {
6210 addOperations<
6211#define GET_OP_LIST
6212#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
6213 >();
6214 addAttributes<
6215#define GET_ATTRDEF_LIST
6216#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
6217 >();
6218
6219 // Support unknown operations because not all NVVM operations are
6220 // registered.
6221 allowUnknownOperations();
6222 declarePromisedInterface<ConvertToLLVMPatternInterface, NVVMDialect>();
6223 declarePromisedInterface<gpu::TargetAttrInterface, NVVMTargetAttr>();
6224}
6225
6226LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
6227 NamedAttribute attr) {
6228 StringAttr attrName = attr.getName();
6229 // Kernel function attribute should be attached to functions.
6230 if (attrName == NVVMDialect::getKernelFuncAttrName()) {
6231 if (!isa<LLVM::LLVMFuncOp>(op)) {
6232 return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName()
6233 << "' attribute attached to unexpected op";
6234 }
6235 }
6236 // If maxntid / reqntid / cluster_dim exist, it must be an array with max 3
6237 // dim
6238 if (attrName == NVVMDialect::getMaxntidAttrName() ||
6239 attrName == NVVMDialect::getReqntidAttrName() ||
6240 attrName == NVVMDialect::getClusterDimAttrName()) {
6241 auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue());
6242 if (!values || values.empty() || values.size() > 3) {
6243 return op->emitError()
6244 << "'" << attrName
6245 << "' attribute must be integer array with maximum 3 index";
6246 }
6247 }
6248 // If minctasm / maxnreg / cluster_max_blocks exist, it must be an integer
6249 // attribute
6250 if (attrName == NVVMDialect::getMinctasmAttrName() ||
6251 attrName == NVVMDialect::getMaxnregAttrName() ||
6252 attrName == NVVMDialect::getClusterMaxBlocksAttrName()) {
6253 if (!llvm::dyn_cast<IntegerAttr>(attr.getValue())) {
6254 return op->emitError()
6255 << "'" << attrName << "' attribute must be integer constant";
6256 }
6257 }
6258 // blocksareclusters must be used along with reqntid and cluster_dim
6259 if (attrName == NVVMDialect::getBlocksAreClustersAttrName()) {
6260 if (!op->hasAttr(NVVMDialect::getReqntidAttrName()) ||
6261 !op->hasAttr(NVVMDialect::getClusterDimAttrName())) {
6262 return op->emitError()
6263 << "'" << attrName << "' attribute must be used along with "
6264 << "'" << NVVMDialect::getReqntidAttrName() << "' and "
6265 << "'" << NVVMDialect::getClusterDimAttrName() << "'";
6266 }
6267 }
6268
6269 return success();
6270}
6271
6272LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
6273 unsigned regionIndex,
6274 unsigned argIndex,
6275 NamedAttribute argAttr) {
6276 auto funcOp = dyn_cast<FunctionOpInterface>(op);
6277 if (!funcOp)
6278 return success();
6279
6280 bool isKernel = op->hasAttr(NVVMDialect::getKernelFuncAttrName());
6281 StringAttr attrName = argAttr.getName();
6282 if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
6283 if (!isKernel) {
6284 return op->emitError()
6285 << "'" << attrName
6286 << "' attribute must be present only on kernel arguments";
6287 }
6288 if (!isa<UnitAttr>(argAttr.getValue()))
6289 return op->emitError() << "'" << attrName << "' must be a unit attribute";
6290 if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName())) {
6291 return op->emitError()
6292 << "'" << attrName
6293 << "' attribute requires the argument to also have attribute '"
6294 << LLVM::LLVMDialect::getByValAttrName() << "'";
6295 }
6296 }
6297
6298 return success();
6299}
6300
6301//===----------------------------------------------------------------------===//
6302// NVVM Address Space Attr
6303//===----------------------------------------------------------------------===//
6304
6305unsigned NVVMMemorySpaceAttr::getAddressSpace() const {
6306 return static_cast<unsigned>(getValue());
6307}
6308
6309bool NVVMMemorySpaceAttr::isValidLoad(
6310 Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
6311 const ::mlir::DataLayout *dataLayout,
6313 return LLVM::detail::isValidLoadStoreImpl(type, ordering, alignment,
6314 dataLayout, emitError);
6315}
6316
6317bool NVVMMemorySpaceAttr::isValidStore(
6318 Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
6319 const ::mlir::DataLayout *dataLayout,
6321 return LLVM::detail::isValidLoadStoreImpl(type, ordering, alignment,
6322 dataLayout, emitError);
6323}
6324
6325bool NVVMMemorySpaceAttr::isValidAtomicOp(
6326 ptr::AtomicBinOp op, Type type, ptr::AtomicOrdering ordering,
6327 std::optional<int64_t> alignment, const ::mlir::DataLayout *dataLayout,
6329 // TODO: update this method once `ptr.atomic_rmw` is implemented.
6330 assert(false && "unimplemented, see TODO in the source.");
6331 return false;
6332}
6333
6334bool NVVMMemorySpaceAttr::isValidAtomicXchg(
6335 Type type, ptr::AtomicOrdering successOrdering,
6336 ptr::AtomicOrdering failureOrdering, std::optional<int64_t> alignment,
6337 const ::mlir::DataLayout *dataLayout,
6339 // TODO: update this method once `ptr.atomic_cmpxchg` is implemented.
6340 assert(false && "unimplemented, see TODO in the source.");
6341 return false;
6342}
6343
6344bool NVVMMemorySpaceAttr::isValidAddrSpaceCast(
6345 Type tgt, Type src, function_ref<InFlightDiagnostic()> emitError) const {
6346 // TODO: update this method once the `ptr.addrspace_cast` op is added to the
6347 // dialect.
6348 assert(false && "unimplemented, see TODO in the source.");
6349 return false;
6350}
6351
6352bool NVVMMemorySpaceAttr::isValidPtrIntCast(
6353 Type intLikeTy, Type ptrLikeTy,
6355 // TODO: update this method once the int-cast ops are added to the `ptr`
6356 // dialect.
6357 assert(false && "unimplemented, see TODO in the source.");
6358 return false;
6359}
6360
6361//===----------------------------------------------------------------------===//
6362// NVVM target attribute.
6363//===----------------------------------------------------------------------===//
6364LogicalResult
6365NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
6366 int optLevel, StringRef triple, StringRef chip,
6367 StringRef features, DictionaryAttr flags,
6368 ArrayAttr files, bool verifyTarget) {
6369 if (optLevel < 0 || optLevel > 3) {
6370 emitError() << "The optimization level must be a number between 0 and 3.";
6371 return failure();
6372 }
6373 if (triple.empty()) {
6374 emitError() << "The target triple cannot be empty.";
6375 return failure();
6376 }
6377 if (chip.empty()) {
6378 emitError() << "The target chip cannot be empty.";
6379 return failure();
6380 }
6381 if (files && !llvm::all_of(files, [](::mlir::Attribute attr) {
6382 return mlir::isa_and_nonnull<StringAttr>(attr);
6383 })) {
6384 emitError() << "All the elements in the `link` array must be strings.";
6385 return failure();
6386 }
6387 return success();
6388}
6389
6390LogicalResult NVVMTargetAttr::verifyTarget(Operation *gpuModule) {
6391 if (!getVerifyTarget())
6392 return success();
6393
6394 auto gpuModuleOp = llvm::dyn_cast<gpu::GPUModuleOp>(gpuModule);
6395 if (!gpuModuleOp) {
6396 return emitError(gpuModule->getLoc(),
6397 "NVVM target attribute must be attached to a GPU module");
6398 }
6399
6400 const unsigned targetFullSmVersion =
6402 if (!NVVMCheckSMVersion::isMinimumSMVersion(targetFullSmVersion)) {
6403 return emitError(gpuModule->getLoc(),
6404 "Minimum NVVM target SM version is sm_20");
6405 }
6406
6407 if (gpuModuleOp
6408 ->walk([&](Operation *op) {
6409 if (auto reqOp = llvm::dyn_cast<NVVM::RequiresSMInterface>(op)) {
6410 const NVVMCheckSMVersion requirement =
6411 reqOp.getRequiredMinSMVersion();
6412 if (!requirement.isCompatibleWith(targetFullSmVersion)) {
6413 op->emitOpError() << "is not supported on " << getChip();
6414 return WalkResult::interrupt();
6415 }
6416 }
6417 return WalkResult::advance();
6418 })
6419 .wasInterrupted())
6420 return failure();
6421
6422 return success();
6423}
6424
6425#define GET_OP_CLASSES
6426#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
6427
6428#define GET_ATTRDEF_CLASSES
6429#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
for(Operation *op :ops)
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
ArrayAttr()
b getContext())
#define GET_TCGEN05_CP_ID(shape_mc, src_fmt, is_2cta)
static LogicalResult verifyTMALoadParams(size_t tensorDims, size_t numIm2colOff, TMALoadMode mode, Location loc)
static LogicalResult verifyTcgen05MMAOp(bool isATensor, mlir::Value disableOutputLane, NVVM::CTAGroupKind ctaGroup, bool hasAShift, NVVM::Tcgen05MMACollectorOp collectorOp, Location loc)
#define _none
static bool isPtrInAddrSpace(mlir::Value ptr, NVVMMemorySpace targetAS)
static bool isCompatibleReturnTypesOptionalResult(TypeRange inferred, TypeRange actual)
For ops with optional results, allow the user to omit the result even when inference would produce on...
static bool isPtrInSharedCTASpace(mlir::Value ptr)
static LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA)
static llvm::nvvm::CTAGroupKind getNVVMCtaGroupKind(NVVM::CTAGroupKind ctaGroup)
static void addInferredMultiplicandTypes(MLIRContext *ctx, OperationState &result, ValueRange operandA, ValueRange operandB, std::optional< std::array< MMATypes, 2 > > multiplicandPtxTypes)
#define GET_CVT_F2TF32_ID(rnd, relu, sf)
static void addBlockScaleProperties(OpBuilder &builder, OperationState &result, ArrayRef< int64_t > shape, ScaleVecSize scaleVecSize, BlockScaleFormat blockScaleFormat, MMABlockScaleKind kind)
#define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf)
static llvm::Value * getParamCastedAddr(llvm::Value *addr, llvm::IRBuilderBase &builder)
static LogicalResult verifyAddSubFOp(OpType op)
static LogicalResult verifyTcgen05MMABlockScaleOp(NVVM::Tcgen05MMACollectorOp collectorOp, NVVM::Tcgen05MMAKind kind, NVVM::Tcgen05MMABlockScale blockScale, Location loc)
static llvm::Value * packValInto64Bits(llvm::IRBuilderBase &builder, llvm::Value *result, llvm::Value *field, unsigned sizeInBits, unsigned start)
Packs the given field into the result.
static void printOperandList(OpAsmPrinter &p, StringRef name, ArrayRef< Value > operands)
#define GET_F32x2_TO_F6x2_ID(type, has_relu)
static llvm::Value * getAsPackedI32(llvm::Value *arg, llvm::IRBuilderBase &builder)
#define GET_F16x2_TO_F8X2_ID(type, has_relu)
static LogicalResult verifyMBarrierArriveLikeOp(Operation *op, Value addr, NVVM::MemScopeKind scope, Value retVal=nullptr)
static llvm::Value * castPtrToAddrSpace(llvm::IRBuilderBase &builder, llvm::Value *ptr, NVVMMemorySpace targetAS)
static LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD, NVVM::WGMMATypes typeA, NVVM::WGMMATypes typeB)
static void inferAndSetMultiplicandTypes(MLIRContext *ctx, NamedAttrList &attrs, const SmallVectorImpl< Type > &operandTypes)
static LogicalResult parseMmaOperand(OpAsmParser &parser, StringRef operandName, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &regs)
static std::pair< mlir::Type, unsigned > inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n, int k, MLIRContext *context)
static bool isInt8PtxType(MMATypes type)
#define TCGEN05LDRED(SHAPE, NUM, TYPE)
static bool isInt4PtxType(MMATypes type)
static bool isIntegerPtxType(MMATypes type)
#define GET_F32x2_TO_F8X2_S_ID(type, has_relu)
static MMATypes inferPtxTypeFromResult(OpTy op)
static LogicalResult verifyConstantRangeAttr(Operation *op, std::optional< LLVM::ConstantRangeAttr > rangeAttr)
Verify the range attribute satisfies LLVM ConstantRange constructor requirements for NVVM SpecialRang...
static LogicalResult parseMmaTypeSignature(OpAsmParser &parser, SmallVectorImpl< Type > &operandTypes)
static FailureOr< int > getAllowedSizeK(NVVM::WGMMATypes typeA)
static bool isPtrInSharedClusterSpace(mlir::Value ptr)
#define GET_CP_ASYNC_ID(mod, size, has_cpsize)
static unsigned isValidVectorLength(NVVM::Tcgen05LdStShape shape, unsigned vecLen)
#define GET_TCGEN05_COMMIT_ID(cta_group, is_shared, has_mc)
static LogicalResult verifyConvertF32x2ToFP16x2Op(Twine dstType, FPRoundingMode rnd, bool hasRandomBits, Operation *op)
static void nvvmInferResultRanges(Operation *op, Value result, ArrayRef<::mlir::ConstantIntRanges > argRanges, SetIntRangeFn setResultRanges)
Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might have ConstantRangeAttr.
static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims, bool isIm2Col, size_t numIm2ColOffsets, Location loc)
static bool isPtrInGenericSpace(mlir::Value ptr)
static void processOperandFragments(Op &op, std::array< MMAOperandFragment, 3 > &frags, SmallVectorImpl< Type > &regTypes, SmallVectorImpl< StringRef > &ignoreAttrNames)
static constexpr unsigned notIntrinsic
static LogicalResult inferMBarrierArriveResultTypes(MLIRContext *context, Value addr, SmallVectorImpl< Type > &inferredReturnTypes)
Only shared_cluster (ptr<7>) produces zero results; all other address spaces (including generic) retu...
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
@ OptionalSquare
Square brackets supporting zero or more ops, or nothing.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
void printArrowTypeList(TypeRange &&types)
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerType getI16Type()
Definition Builders.cpp:65
UnitAttr getUnitAttr()
Definition Builders.cpp:102
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:167
IntegerType getI32Type()
Definition Builders.cpp:67
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
MLIRContext * getContext() const
Definition Builders.h:56
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition Builders.h:100
This class represents a diagnostic that is inflight and set to be reported.
static IntegerValueRange getMaxRange(Value value)
Create a maximal range ([0, uint_max(t)] / [int_min(t), int_max(t)]) range that is used to mark the v...
Implementation class for module translation.
llvm::Value * lookupValue(Value value) const
Finds an LLVM IR value corresponding to the given MLIR value.
void mapValue(Value mlir, llvm::Value *llvm)
Stores the mapping between an MLIR value and its LLVM IR counterpart.
llvm::LLVMContext & getLLVMContext() const
Returns the LLVM context in which the IR is being constructed.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
std::optional< NamedAttribute > getNamed(StringRef name) const
Return the specified named attribute if present, std::nullopt otherwise.
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
Attribute set(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
This class helps build Operations.
Definition Builders.h:209
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
AttrClass getAttrOfType(StringAttr name)
Definition Operation.h:575
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Definition Operation.h:585
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:240
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isF64() const
Definition Types.cpp:41
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isF32() const
Definition Types.cpp:40
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
bool isF16() const
Definition Types.cpp:38
bool isBF16() const
Definition Types.cpp:37
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
bool isValidLoadStoreImpl(Type type, ptr::AtomicOrdering ordering, std::optional< int64_t > alignment, const ::mlir::DataLayout *dataLayout, function_ref< InFlightDiagnostic()> emitError)
Checks whether the given type is an LLVM type that can be loaded or stored.
Definition LLVMAttrs.cpp:60
SmallVector< int64_t, 4 > getCoordinates(ArrayRef< int64_t > basis, unsigned linearIndex)
@ Write
Write register with '=' modifier.
@ ReadWrite
ReadWrite register with '+' modifier.
@ Read
Read register with no modifier.
std::pair< mlir::Type, unsigned > inferMMAType(mlir::NVVM::MMATypes type, mlir::NVVM::MMAFrag frag, int nRow, int nCol, mlir::MLIRContext *context)
Return the element type and number of elements associated with a wmma matrix of given chracteristics.
std::pair< llvm::Intrinsic::ID, llvm::SmallVector< llvm::Value * > > IDArgPair
A pair type of LLVM's Intrinsic ID and args (which are llvm values).
Definition NVVMDialect.h:53
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition Visitors.h:102
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
uint64_t getN(LevelType lt)
Definition Enums.h:442
uint64_t getM(LevelType lt)
Definition Enums.h:443
Include the generated interface declarations.
llvm::function_ref< void(Value, const ConstantIntRanges &)> SetIntRangeFn
The type of the setResultRanges callback provided to ops implementing InferIntRangeInterface.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
LogicalResult matchAndRewrite(SubFOp op, PatternRewriter &rewriter) const override
static bool isMinimumSMVersion(unsigned fullSmVersion)
static unsigned getTargetFullSmVersionFromStr(StringRef smVersionString)
bool isCompatibleWith(const unsigned &targetFullSmVersion) const
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.