MLIR 23.0.0git
NVVMDialect.cpp
Go to the documentation of this file.
1//===- NVVMDialect.cpp - NVVM IR Ops and Dialect registration -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the types and operation details for the NVVM IR dialect in
10// MLIR, and the LLVM IR dialect. It also registers the dialect.
11//
12// The NVVM dialect only contains GPU specific additions on top of the general
13// LLVM dialect.
14//
15//===----------------------------------------------------------------------===//
16
18
22#include "mlir/IR/Builders.h"
25#include "mlir/IR/Diagnostics.h"
27#include "mlir/IR/MLIRContext.h"
28#include "mlir/IR/Operation.h"
30#include "mlir/IR/Types.h"
31#include "llvm/ADT/STLExtras.h"
32#include "llvm/ADT/TypeSwitch.h"
33#include "llvm/IR/IRBuilder.h"
34#include "llvm/IR/NVVMIntrinsicUtils.h"
35#include "llvm/Support/Casting.h"
36#include "llvm/Support/FormatVariadic.h"
37#include "llvm/Support/NVPTXAddrSpace.h"
38#include "llvm/Support/raw_ostream.h"
39#include <cassert>
40#include <optional>
41#include <string>
42
43using namespace mlir;
44using namespace NVVM;
45
46#include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
47#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"
48
49static constexpr unsigned notIntrinsic = llvm::Intrinsic::not_intrinsic;
50
51//===----------------------------------------------------------------------===//
52// Helper/Utility methods
53//===----------------------------------------------------------------------===//
54
55static bool isPtrInAddrSpace(mlir::Value ptr, NVVMMemorySpace targetAS) {
56 auto ptrTy = llvm::cast<LLVM::LLVMPointerType>(ptr.getType());
57 return ptrTy.getAddressSpace() == static_cast<unsigned>(targetAS);
58}
59
61 return isPtrInAddrSpace(ptr, NVVMMemorySpace::Generic);
62}
63
65 return isPtrInAddrSpace(ptr, NVVMMemorySpace::Shared);
66}
67
69 return isPtrInAddrSpace(ptr, NVVMMemorySpace::SharedCluster);
70}
71
72static llvm::Value *castPtrToAddrSpace(llvm::IRBuilderBase &builder,
73 llvm::Value *ptr,
74 NVVMMemorySpace targetAS) {
75 unsigned AS = static_cast<unsigned>(targetAS);
76 return builder.CreateAddrSpaceCast(
77 ptr, llvm::PointerType::get(builder.getContext(), AS));
78}
79
80// Helper method to convert CtaGroupKind in NVVM Dialect to CtaGroupKind in LLVM
81static llvm::nvvm::CTAGroupKind
82getNVVMCtaGroupKind(NVVM::CTAGroupKind ctaGroup) {
83 switch (ctaGroup) {
84 case NVVM::CTAGroupKind::CTA_1:
85 return llvm::nvvm::CTAGroupKind::CG_1;
86 case NVVM::CTAGroupKind::CTA_2:
87 return llvm::nvvm::CTAGroupKind::CG_2;
88 }
89 llvm_unreachable("unsupported cta_group value");
90}
91
92//===----------------------------------------------------------------------===//
93// Verifier methods
94//===----------------------------------------------------------------------===//
95
96// This verifier is shared among the following Ops:
97// CpAsyncBulkTensorSharedCTAToGlobalOp (TMA Store)
98// CpAsyncBulkTensorReduceOp (TMA Store-Reduce)
99static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims,
100 bool isIm2Col,
101 size_t numIm2ColOffsets,
102 Location loc) {
103 if (tensorDims < 1 || tensorDims > 5)
104 return emitError(loc, "expects coordinates between 1 to 5 dimension");
105
106 // For Im2Col mode, there are two constraints:
107 if (isIm2Col) {
108 // 1. Tensor must always be at least 3-d.
109 if (tensorDims < 3)
110 return emitError(
111 loc,
112 "to use im2col mode, the tensor has to be at least 3-dimensional");
113 // 2. When there are Im2ColOffsets, they must be (Dims - 2) in number.
114 if (numIm2ColOffsets && (tensorDims != (numIm2ColOffsets + 2)))
115 return emitError(
116 loc, "im2col offsets must be 2 less than number of coordinates");
117 }
118 return success();
119}
120
121LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
122 TMAStoreMode mode = getMode();
123 // We lower through inline-ptx when getPredicate() is true.
124 // a) Only TILE mode is supported
125 // b) Cache-hint is not supported
126 if (getPredicate()) {
127 if (mode != TMAStoreMode::TILE)
128 return emitError("Inline-ptx lowering supported only for Tile mode.");
129 if (getL2CacheHint())
130 return emitError("Inline-ptx lowering unsupported with L2 cache-hint.");
131 }
132
133 size_t dims = getCoordinates().size();
134 switch (mode) {
135 case TMAStoreMode::TILE:
136 return cpAsyncBulkTensorCommonVerifier(dims, false, 0, getLoc());
137 case TMAStoreMode::IM2COL:
138 return cpAsyncBulkTensorCommonVerifier(dims, true, 0, getLoc());
139 case TMAStoreMode::TILE_SCATTER4:
140 if (dims != 5)
141 return emitError("Scatter4 mode expects 5 coordinates");
142 }
143 return success();
144}
145
146LogicalResult CpAsyncOp::verify() {
147 if (getModifier() != LoadCacheModifierKind::CG &&
148 getModifier() != LoadCacheModifierKind::CA)
149 return emitError("Only CG and CA cache modifiers are supported.");
150 if (getSize() != 4 && getSize() != 8 && getSize() != 16)
151 return emitError("expected byte size to be either 4, 8 or 16.");
152 if (getModifier() == LoadCacheModifierKind::CG && getSize() != 16)
153 return emitError("CG cache modifier is only support for 16 bytes copy.");
154 return success();
155}
156
157// This verify params can be shared across TMA Load and Prefetch Ops.
158static LogicalResult verifyTMALoadParams(size_t tensorDims, size_t numIm2colOff,
159 TMALoadMode mode, Location loc) {
160 if (tensorDims < 1 || tensorDims > 5)
161 return emitError(loc, "expects coordinates between 1 to 5 dimension");
162
163 auto checkTMALoadParams = [&](TMALoadMode mode, bool isIm2col,
164 size_t expectedIm2colOff) -> LogicalResult {
165 if (isIm2col && (tensorDims < 3))
166 return emitError(loc)
167 << "to use " << mode
168 << " mode, the tensor has to be at least 3-dimensional";
169
170 if (numIm2colOff != expectedIm2colOff)
171 return emitError(loc) << " im2col offsets expected " << expectedIm2colOff
172 << " (provided " << numIm2colOff << ")";
173
174 return success();
175 };
176
177 switch (mode) {
178 case TMALoadMode::TILE:
179 return checkTMALoadParams(mode, false, 0);
180 case TMALoadMode::IM2COL:
181 return checkTMALoadParams(mode, true, tensorDims - 2);
182 case TMALoadMode::IM2COL_W:
183 case TMALoadMode::IM2COL_W_128:
184 return checkTMALoadParams(mode, true, 2);
185 case TMALoadMode::TILE_GATHER4:
186 return (tensorDims == 5)
187 ? checkTMALoadParams(mode, false, 0)
188 : emitError(loc, "Gather4 mode expects 5 coordinates");
189 }
190 return success();
191}
192
193LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
194 return verifyTMALoadParams(getCoordinates().size(), getIm2colOffsets().size(),
195 getMode(), getLoc());
196}
197
198LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
199 TMALoadMode mode = getMode();
200 bool isCTAOnly = getIsCTAOnly();
201 if (getPredicate()) { // Inline-asm based lowering
202 if (isCTAOnly)
203 return emitError("Predicate is supported only for shared::cluster mode.");
204 if (mode != TMALoadMode::TILE && mode != TMALoadMode::IM2COL)
205 return emitError(
206 "Predicate is supported only for Tile and Im2col modes.");
207 } else { // Intrinsics-based lowering
208 NVVMMemorySpace expectedAS =
209 isCTAOnly ? NVVMMemorySpace::Shared : NVVMMemorySpace::SharedCluster;
210 unsigned AS = llvm::cast<LLVM::LLVMPointerType>(getDstMem().getType())
211 .getAddressSpace();
212 if (AS != expectedAS)
213 return emitError()
214 << (isCTAOnly
215 ? "Shared::cta destination requires address-space 3."
216 : "Shared::cluster destination requires address-space 7.");
217 // Checks specific to shared::cta mode
218 if (isCTAOnly) {
219 if (getMulticastMask())
220 return emitError("Multicast is not supported with shared::cta mode.");
221 if (getGroup())
222 return emitError("CTAGroup is not supported with shared::cta mode.");
223 }
224 }
225
226 return verifyTMALoadParams(getCoordinates().size(), getIm2colOffsets().size(),
227 getMode(), getLoc());
228}
229
230LogicalResult CpAsyncBulkTensorReduceOp::verify() {
231 TMAStoreMode mode = getMode();
232 size_t dims = getCoordinates().size();
233 switch (mode) {
234 case TMAStoreMode::TILE:
235 return cpAsyncBulkTensorCommonVerifier(dims, false, 0, getLoc());
236 case TMAStoreMode::IM2COL:
237 return cpAsyncBulkTensorCommonVerifier(dims, true, 0, getLoc());
238 case TMAStoreMode::TILE_SCATTER4:
239 return emitError("Scatter mode unsupported for CpAsyncBulkTensorReduceOp");
240 }
241 return success();
242}
243
244LogicalResult CpAsyncBulkGlobalToSharedClusterOp::verify() {
245 bool isSharedCTA = isPtrInSharedCTASpace(getDstMem());
246 if (isSharedCTA && getMulticastMask())
247 return emitError("Multicast is not supported with shared::cta mode.");
248
249 return success();
250}
251
252static LogicalResult verifyMBarrierArriveLikeOp(Operation *op, Value addr,
253 NVVM::MemScopeKind scope,
254 Value retVal = nullptr) {
255 if (scope != NVVM::MemScopeKind::CTA && scope != NVVM::MemScopeKind::CLUSTER)
256 return op->emitError("mbarrier scope must be either CTA or Cluster");
257
258 bool isSharedCluster = isPtrInSharedClusterSpace(addr);
259 bool hasRetValue = static_cast<bool>(retVal);
260 if (isSharedCluster && hasRetValue)
261 return op->emitError(
262 "mbarrier in shared_cluster space cannot return any value");
263
264 return success();
265}
266
267LogicalResult MBarrierArriveOp::verify() {
268 return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope(),
269 getRes());
270}
271
272LogicalResult MBarrierArriveDropOp::verify() {
273 return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope(),
274 getRes());
275}
276
277LogicalResult MBarrierArriveExpectTxOp::verify() {
278 // The inline-ptx version of this Op does not support all features.
279 // With predicate, this Op lowers to inline-ptx. So, verify and
280 // error-out if there are unsupported features.
281 if (getPredicate()) {
282 if (getScope() != NVVM::MemScopeKind::CTA)
283 return emitError("mbarrier scope must be CTA when using predicate");
284
285 if (isPtrInSharedClusterSpace(getAddr()))
286 return emitError("mbarrier in shared_cluster space is not supported when "
287 "using predicate");
288
289 if (getRes())
290 return emitError("return-value is not supported when using predicate");
291
292 if (getRelaxed() == true)
293 return emitError("mbarrier with relaxed semantics is not supported when "
294 "using predicate");
295 }
296 return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope(),
297 getRes());
298}
299
300LogicalResult MBarrierArriveDropExpectTxOp::verify() {
301 return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope(),
302 getRes());
303}
304
305//===----------------------------------------------------------------------===//
306// inferReturnTypes for mbarrier arrive-like ops
307//===----------------------------------------------------------------------===//
308
309/// Only shared_cluster (ptr<7>) produces zero results; all other address
310/// spaces (including generic) return i64.
311static LogicalResult
313 SmallVectorImpl<Type> &inferredReturnTypes) {
314 if (!isPtrInSharedClusterSpace(addr))
315 inferredReturnTypes.push_back(IntegerType::get(context, 64));
316 return success();
317}
318
319LogicalResult
320MBarrierArriveOp::inferReturnTypes(MLIRContext *context,
321 std::optional<Location> location,
322 MBarrierArriveOp::Adaptor adaptor,
323 SmallVectorImpl<Type> &inferredReturnTypes) {
324 return inferMBarrierArriveResultTypes(context, adaptor.getAddr(),
325 inferredReturnTypes);
326}
327
328LogicalResult MBarrierArriveDropOp::inferReturnTypes(
329 MLIRContext *context, std::optional<Location> location,
330 MBarrierArriveDropOp::Adaptor adaptor,
331 SmallVectorImpl<Type> &inferredReturnTypes) {
332 return inferMBarrierArriveResultTypes(context, adaptor.getAddr(),
333 inferredReturnTypes);
334}
335
336LogicalResult MBarrierArriveExpectTxOp::inferReturnTypes(
337 MLIRContext *context, std::optional<Location> location,
338 MBarrierArriveExpectTxOp::Adaptor adaptor,
339 SmallVectorImpl<Type> &inferredReturnTypes) {
340 // Predicate forces no return value (inline PTX path).
341 // Note: predicate + shared_cluster is rejected by the verifier separately.
342 if (adaptor.getPredicate())
343 return success();
344 return inferMBarrierArriveResultTypes(context, adaptor.getAddr(),
345 inferredReturnTypes);
346}
347
348LogicalResult MBarrierArriveDropExpectTxOp::inferReturnTypes(
349 MLIRContext *context, std::optional<Location> location,
350 MBarrierArriveDropExpectTxOp::Adaptor adaptor,
351 SmallVectorImpl<Type> &inferredReturnTypes) {
352 return inferMBarrierArriveResultTypes(context, adaptor.getAddr(),
353 inferredReturnTypes);
354}
355
356/// For ops with optional results, allow the user to omit the result even when
357/// inference would produce one. This preserves backward compatibility: the
358/// result can be silently discarded (e.g., for fire-and-forget arrive ops).
360 TypeRange actual) {
361 if (actual.empty())
362 return true;
363 return inferred == actual;
364}
365
366bool MBarrierArriveOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
368}
369bool MBarrierArriveDropOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
371}
372bool MBarrierArriveExpectTxOp::isCompatibleReturnTypes(TypeRange l,
373 TypeRange r) {
375}
376bool MBarrierArriveDropExpectTxOp::isCompatibleReturnTypes(TypeRange l,
377 TypeRange r) {
379}
380
381LogicalResult MBarrierExpectTxOp::verify() {
382 return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope());
383}
384
385LogicalResult MBarrierCompleteTxOp::verify() {
386 return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope());
387}
388
389LogicalResult MBarrierTestWaitOp::verify() {
390 return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope());
391}
392
393LogicalResult MBarrierTryWaitOp::verify() {
394 return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope());
395}
396
397LogicalResult ConvertFloatToTF32Op::verify() {
398 using RndMode = NVVM::FPRoundingMode;
399 switch (getRnd()) {
400 case RndMode::RNA:
401 if (getRelu())
402 return emitError("Relu not supported with rna rounding mode.");
403 break;
404 case RndMode::RN:
405 case RndMode::RZ:
406 break;
407 default:
408 return emitError(
409 "Only {rn,rz,rna} rounding modes supported for ConvertFloatToTF32Op.");
410 }
411 return success();
412}
413
414LogicalResult ConvertF32x2ToF6x2Op::verify() {
416
417 if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy())) {
418 return emitOpError("Only ")
419 << mlir::Float6E2M3FNType::get(ctx) << " and "
420 << mlir::Float6E3M2FNType::get(ctx)
421 << " types are supported for conversions from f32x2 to f6x2.";
422 }
423 return success();
424}
425
426LogicalResult ConvertF32x2ToF8x2Op::verify() {
427 using RndMode = NVVM::FPRoundingMode;
428 using SatMode = NVVM::SaturationMode;
429
430 bool isRoundingModeRN = getRnd() == RndMode::RN;
431 bool isRoundingModeRZ = getRnd() == RndMode::RZ;
432 bool isRoundingModeRP = getRnd() == RndMode::RP;
433 bool isSatFinite = getSat() == SatMode::SATFINITE;
434
435 bool hasRelu = getRelu();
436
438
440 .Case<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(
441 [&](mlir::Type) -> LogicalResult {
442 if (!isRoundingModeRN) {
443 return emitOpError("Only RN rounding mode is supported for "
444 "conversions from f32x2 to ")
445 << mlir::Float8E4M3FNType::get(ctx) << " and "
446 << mlir::Float8E5M2Type::get(ctx) << " types";
447 }
448 if (!isSatFinite) {
449 return emitOpError("Only SATFINITE saturation mode is supported "
450 "for conversions "
451 "from f32x2 to ")
452 << mlir::Float8E4M3FNType::get(ctx) << " and "
453 << mlir::Float8E5M2Type::get(ctx) << " types";
454 }
455 return success();
456 })
457 .Case<mlir::Float8E8M0FNUType>([&](mlir::Type) -> LogicalResult {
458 if (!(isRoundingModeRZ || isRoundingModeRP)) {
459 return emitOpError("Only RZ and RP rounding modes are supported for "
460 "conversions from f32x2 to ")
461 << mlir::Float8E8M0FNUType::get(ctx) << " type";
462 }
463 if (hasRelu) {
464 return emitOpError("relu not supported for conversions to ")
465 << mlir::Float8E8M0FNUType::get(ctx) << " type";
466 }
467 return success();
468 })
469 .Default([&](mlir::Type) {
470 return emitOpError("Only ")
471 << mlir::Float8E4M3FNType::get(ctx) << ", "
472 << mlir::Float8E5M2Type::get(ctx) << ", and "
473 << mlir::Float8E8M0FNUType::get(ctx)
474 << " types are "
475 "supported for conversions from f32x2 to f8x2";
476 });
477}
478
479LogicalResult ConvertF16x2ToF8x2Op::verify() {
481
482 if (!llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(getDstTy())) {
483 return emitOpError("Only ")
484 << mlir::Float8E4M3FNType::get(ctx) << " and "
485 << mlir::Float8E5M2Type::get(ctx)
486 << " types are supported for conversions from f16x2 to f8x2.";
487 }
488 return success();
489}
490
491LogicalResult ConvertBF16x2ToF8x2Op::verify() {
492 using RndMode = NVVM::FPRoundingMode;
493 using SatMode = NVVM::SaturationMode;
494
495 bool isRoundingModeRN = getRnd() == RndMode::RN;
496 bool isRoundingModeRZ = getRnd() == RndMode::RZ;
497 bool isRoundingModeRP = getRnd() == RndMode::RP;
498 bool isSatFinite = getSat() == SatMode::SATFINITE;
499 bool hasRelu = getRelu();
500
502
504 .Case<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(
505 [&](mlir::Type) -> LogicalResult {
506 if (!isRoundingModeRN)
507 return emitOpError("Only RN rounding mode is supported for "
508 "conversions from bf16x2 to ")
509 << mlir::Float8E4M3FNType::get(ctx) << " and "
510 << mlir::Float8E5M2Type::get(ctx) << " types";
511 if (!isSatFinite)
512 return emitOpError("Only SATFINITE saturation mode is supported "
513 "for conversions from bf16x2 to ")
514 << mlir::Float8E4M3FNType::get(ctx) << " and "
515 << mlir::Float8E5M2Type::get(ctx) << " types";
516 return success();
517 })
518 .Case<mlir::Float8E8M0FNUType>([&](mlir::Type) -> LogicalResult {
519 if (!(isRoundingModeRZ || isRoundingModeRP))
520 return emitOpError("Only RZ and RP rounding modes are supported for "
521 "conversions from bf16x2 to ")
522 << mlir::Float8E8M0FNUType::get(ctx) << " type";
523 if (hasRelu)
524 return emitOpError("relu not supported for conversions to ")
525 << mlir::Float8E8M0FNUType::get(ctx) << " type";
526 return success();
527 })
528 .Default([&](mlir::Type) -> LogicalResult {
529 llvm_unreachable("Invalid conversion in ConvertBF16x2ToF8x2Op");
530 return failure();
531 });
532}
533
534LogicalResult ConvertF32x2ToF4x2Op::verify() {
536
537 if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
538 return emitOpError("Only ")
539 << mlir::Float4E2M1FNType::get(ctx)
540 << " type is supported for conversions from f32x2 to f4x2.";
541
542 return success();
543}
544
545LogicalResult ConvertF8x2ToBF16x2Op::verify() {
547 if (llvm::isa<Float8E8M0FNUType>(getSrcType())) {
548 if (getSat() != SaturationMode::NONE)
549 return emitOpError(
550 "Only NONE saturation mode is supported for conversions from ")
551 << Float8E8M0FNUType::get(ctx) << " type";
552 if (getScaleFactor())
553 return emitOpError("scaleFactor not supported for conversions from ")
554 << Float8E8M0FNUType::get(ctx) << " type";
555 if (getRelu())
556 return emitOpError("relu not supported for conversions from ")
557 << Float8E8M0FNUType::get(ctx) << " type";
558 }
559
560 return success();
561}
562
563LogicalResult PermuteOp::verify() {
564 using Mode = NVVM::PermuteMode;
565 bool hasHi = static_cast<bool>(getHi());
566
567 switch (getMode()) {
568 case Mode::DEFAULT:
569 case Mode::F4E:
570 case Mode::B4E:
571 if (!hasHi)
572 return emitError("mode '") << getMode() << "' requires 'hi' operand.";
573 break;
574 case Mode::RC8:
575 case Mode::ECL:
576 case Mode::ECR:
577 case Mode::RC16:
578 if (hasHi)
579 return emitError("mode '")
580 << getMode() << "' does not accept 'hi' operand.";
581 break;
582 }
583
584 return success();
585}
586
587//===----------------------------------------------------------------------===//
588// Stochastic Rounding Conversion Ops
589//===----------------------------------------------------------------------===//
590
591static LogicalResult verifyConvertF32x2ToFP16x2Op(Twine dstType,
592 FPRoundingMode rnd,
593 bool hasRandomBits,
594 Operation *op) {
595 static constexpr FPRoundingMode validRndModes[] = {
596 FPRoundingMode::RN, FPRoundingMode::RZ, FPRoundingMode::RS};
597
598 if (!llvm::is_contained(validRndModes, rnd)) {
599 return op->emitOpError(
600 "Only RN, RZ, and RS rounding modes are supported for "
601 "conversions from f32x2 to ")
602 << dstType << ".";
603 }
604
605 if (rnd == FPRoundingMode::RS) {
606 if (!hasRandomBits) {
607 return op->emitOpError("random_bits is required for RS rounding mode.");
608 }
609 } else {
610 if (hasRandomBits) {
611 return op->emitOpError(
612 "random_bits not supported for RN and RZ rounding modes.");
613 }
614 }
615
616 return success();
617}
618
619LogicalResult ConvertF32x2ToF16x2Op::verify() {
620 return verifyConvertF32x2ToFP16x2Op("f16x2", getRnd(),
621 getRandomBits() ? true : false, *this);
622}
623
624LogicalResult ConvertF32x2ToBF16x2Op::verify() {
625 return verifyConvertF32x2ToFP16x2Op("bf16x2", getRnd(),
626 getRandomBits() ? true : false, *this);
627}
628
629LogicalResult ConvertF32x4ToF8x4Op::verify() {
631
632 if (!llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(getDstTy()))
633 return emitOpError("Only ")
634 << mlir::Float8E4M3FNType::get(ctx) << " and "
635 << mlir::Float8E5M2Type::get(ctx)
636 << " types are supported for conversions from f32x4 to f8x4.";
637
638 return success();
639}
640
641LogicalResult ConvertF32x4ToF6x4Op::verify() {
643
644 if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy()))
645 return emitOpError("Only ")
646 << mlir::Float6E2M3FNType::get(ctx) << " and "
647 << mlir::Float6E3M2FNType::get(ctx)
648 << " types are supported for conversions from f32x4 to f6x4.";
649
650 return success();
651}
652
653LogicalResult ConvertF32x4ToF4x4Op::verify() {
655
656 if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
657 return emitOpError("Only ") << mlir::Float4E2M1FNType::get(ctx)
658 << " type is supported for conversions from "
659 "f32x4 to f4x4.";
660
661 return success();
662}
663
664LogicalResult BulkStoreOp::verify() {
665 if (getInitVal() != 0)
666 return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
667 return success();
668}
669
670LogicalResult PMEventOp::verify() {
671 auto eventId = getEventId();
672 auto maskedEventId = getMaskedEventId();
673 if (!maskedEventId && !eventId) {
674 return emitOpError() << "either `id` or `mask` must be set";
675 }
676
677 if (maskedEventId && eventId) {
678 return emitOpError() << "`id` and `mask` cannot be set at the same time";
679 }
680
681 if (eventId) {
682 if (eventId < 0 || eventId > 15) {
683 return emitOpError() << "`id` must be between 0 and 15";
684 }
685 }
686
687 return llvm::success();
688}
689
690// Given the element type of an operand and whether or not it is an accumulator,
691// this function returns the PTX type (`NVVM::MMATypes`) that corresponds to the
692// operand's element type.
693std::optional<mlir::NVVM::MMATypes>
694MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) {
695 auto half2Type =
696 VectorType::get(2, Float16Type::get(operandElType.getContext()));
697 if (operandElType.isF64())
698 return NVVM::MMATypes::f64;
699 if (operandElType.isF16() || operandElType == half2Type)
700 return NVVM::MMATypes::f16;
701 if (operandElType.isF32() && isAccumulator)
702 return NVVM::MMATypes::f32;
703 if (operandElType.isF32() && !isAccumulator)
704 return NVVM::MMATypes::tf32;
705 if (llvm::isa<IntegerType>(operandElType)) {
706 if (isAccumulator)
707 return NVVM::MMATypes::s32;
708 return std::nullopt;
709 }
710
711 if (auto structType = llvm::dyn_cast<LLVM::LLVMStructType>(operandElType)) {
712 if (structType.getBody().empty())
713 return std::nullopt;
714 return inferOperandMMAType(structType.getBody()[0], isAccumulator);
715 }
716
717 return std::nullopt;
718}
719
720static bool isInt4PtxType(MMATypes type) {
721 return (type == MMATypes::u4 || type == MMATypes::s4);
722}
723
724static bool isInt8PtxType(MMATypes type) {
725 return (type == MMATypes::u8 || type == MMATypes::s8);
726}
727
728static bool isIntegerPtxType(MMATypes type) {
729 return isInt4PtxType(type) || isInt8PtxType(type) || type == MMATypes::b1 ||
730 type == MMATypes::s32;
731}
732
733MMATypes MmaOp::accumPtxType() {
734 std::optional<mlir::NVVM::MMATypes> val = inferOperandMMAType(
735 getODSOperands(2).getTypes().front(), /*isAccumulator=*/true);
736 assert(val.has_value() && "accumulator PTX type should always be inferrable");
737 return val.value();
738}
739
740MMATypes MmaOp::resultPtxType() {
741 std::optional<mlir::NVVM::MMATypes> val =
742 inferOperandMMAType(getResult().getType(), /*isAccumulator=*/true);
743 assert(val.has_value() && "result PTX type should always be inferrable");
744 return val.value();
745}
746
747void MmaOp::print(OpAsmPrinter &p) {
748 SmallVector<Type, 4> regTypes;
749 struct MMAOperandFragment {
750 StringRef operandName;
751 StringRef ptxTypeAttr;
752 SmallVector<Value, 4> regs;
753 explicit MMAOperandFragment(StringRef name, StringRef ptxTypeName)
754 : operandName(name), ptxTypeAttr(ptxTypeName) {}
755 };
756
757 std::array<MMAOperandFragment, 3> frags{
758 MMAOperandFragment("A", getMultiplicandAPtxTypeAttrName()),
759 MMAOperandFragment("B", getMultiplicandBPtxTypeAttrName()),
760 MMAOperandFragment("C", "")};
761 SmallVector<StringRef, 4> ignoreAttrNames{
762 mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()};
763
764 for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
765 auto &frag = frags[fragIdx];
766 auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
767 for (auto operandIdx = varOperandSpec.first;
768 operandIdx < varOperandSpec.first + varOperandSpec.second;
769 operandIdx++) {
770 frag.regs.push_back(this->getOperand(operandIdx));
771 if (operandIdx == 0) {
772 regTypes.push_back(this->getOperand(operandIdx).getType());
773 }
774 }
775 std::optional<MMATypes> inferredType = MmaOp::inferOperandMMAType(
776 regTypes.back(), /*isAccumulator=*/fragIdx >= 2);
777 if (inferredType)
778 ignoreAttrNames.push_back(frag.ptxTypeAttr);
779 }
780
781 auto printMmaOperand = [&](const MMAOperandFragment &frag) -> void {
782 p << " " << frag.operandName;
783 p << "[";
784 p.printOperands(frag.regs);
785 p << "] ";
786 };
787
788 for (const auto &frag : frags) {
789 printMmaOperand(frag);
790 }
791
792 p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);
793
794 // Print the types of the operands and result.
795 p << " : " << "(";
796 llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
797 frags[1].regs[0].getType(),
798 frags[2].regs[0].getType()},
799 p);
800 p << ")";
801 p.printArrowTypeList(TypeRange{this->getRes().getType()});
802}
803
804void MmaOp::build(OpBuilder &builder, OperationState &result, Type resultType,
805 ValueRange operandA, ValueRange operandB, ValueRange operandC,
806 ArrayRef<int64_t> shape, std::optional<MMAB1Op> b1Op,
807 std::optional<MMAIntOverflow> intOverflow,
808 std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
809 std::optional<std::array<MMALayout, 2>> multiplicandLayouts) {
810
811 assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
812 MLIRContext *ctx = builder.getContext();
813 result.addAttribute(
814 "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));
815
816 result.addOperands(operandA);
817 result.addOperands(operandB);
818 result.addOperands(operandC);
819
820 if (multiplicandPtxTypes) {
821 result.addAttribute("multiplicandAPtxType",
822 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
823 result.addAttribute("multiplicandBPtxType",
824 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
825 } else {
826 if (auto res = inferOperandMMAType(operandA[0].getType(), false))
827 result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
828 if (auto res = inferOperandMMAType(operandB[0].getType(), false))
829 result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
830 }
831
832 if (multiplicandLayouts) {
833 result.addAttribute("layoutA",
834 MMALayoutAttr::get(ctx, (*multiplicandLayouts)[0]));
835 result.addAttribute("layoutB",
836 MMALayoutAttr::get(ctx, (*multiplicandLayouts)[1]));
837 } else {
838 result.addAttribute("layoutA", MMALayoutAttr::get(ctx, MMALayout::row));
839 result.addAttribute("layoutB", MMALayoutAttr::get(ctx, MMALayout::col));
840 }
841
842 if (intOverflow.has_value())
843 result.addAttribute("intOverflowBehavior",
844 MMAIntOverflowAttr::get(ctx, *intOverflow));
845 if (b1Op.has_value())
846 result.addAttribute("b1Op", MMAB1OpAttr::get(ctx, *b1Op));
847
848 result.addTypes(resultType);
849 result.addAttribute(
850 MmaOp::getOperandSegmentSizeAttr(),
851 builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()),
852 static_cast<int32_t>(operandB.size()),
853 static_cast<int32_t>(operandC.size())}));
854}
855
856// <operation> :=
857// A `[` $operandA `]` B `[` $operandB `]` C `[` $operandC `]`
858// attr-dict : (type($operandA[0]), type($operandB[0]), type($operandC[0]))
859// `->` type($res)
860ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) {
861 struct MMAOperandFragment {
862 std::optional<MMATypes> elemtype;
863 SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
864 SmallVector<Type> regTypes;
865 };
866
867 Builder &builder = parser.getBuilder();
868 std::array<MMAOperandFragment, 4> frags;
869
870 NamedAttrList namedAttributes;
871
872 // A helper to parse the operand segments.
873 auto parseMmaOperand = [&](StringRef operandName,
874 MMAOperandFragment &frag) -> LogicalResult {
875 if (parser.parseKeyword(operandName).failed())
876 return failure();
877 if (parser
878 .parseOperandList(frag.regs, OpAsmParser::Delimiter::OptionalSquare)
879 .failed())
880 return failure();
881 return success();
882 };
883
884 // Parse the operand segments.
885 if (parseMmaOperand("A", frags[0]).failed())
886 return failure();
887 if (parseMmaOperand("B", frags[1]).failed())
888 return failure();
889 if (parseMmaOperand("C", frags[2]).failed())
890 return failure();
891
892 if (parser.parseOptionalAttrDict(namedAttributes).failed())
893 return failure();
894
895 // Parse the type specification and resolve operands.
896 SmallVector<Type, 3> operandTypes;
897 if (failed(parser.parseColon()))
898 return failure();
899 if (failed(parser.parseLParen()))
900 return failure();
901 if (failed(parser.parseTypeList(operandTypes)))
902 return failure();
903 if (failed(parser.parseRParen()))
904 if (operandTypes.size() != 3)
905 return parser.emitError(
906 parser.getNameLoc(),
907 "expected one type for each operand segment but got " +
908 Twine(operandTypes.size()) + " types");
909 for (const auto &iter : llvm::enumerate(operandTypes)) {
910 auto &frag = frags[iter.index()];
911 frag.regTypes.resize(frag.regs.size(), iter.value());
912 if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
913 parser.getNameLoc(), result.operands)))
914 return failure();
915 frag.elemtype = inferOperandMMAType(frag.regTypes[0],
916 /*isAccumulator*/ iter.index() < 2);
917 }
918
919 Type resultType;
920 if (parser.parseArrow() || parser.parseType(resultType))
921 return failure();
922 frags[3].elemtype = inferOperandMMAType(resultType, /*isAccumulator*/ true);
923
924 std::array<StringRef, 2> names{"multiplicandAPtxType",
925 "multiplicandBPtxType"};
926 for (unsigned idx = 0; idx < names.size(); idx++) {
927 const auto &frag = frags[idx];
928 std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
929 if (!frag.elemtype.has_value() && !attr.has_value()) {
930 return parser.emitError(
931 parser.getNameLoc(),
932 "attribute " + names[idx] +
933 " is not provided explicitly and cannot be inferred");
934 }
935 if (!attr.has_value())
936 result.addAttribute(
937 names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));
938 }
939
940 result.addTypes(resultType);
941 if (!namedAttributes.empty())
942 result.addAttributes(namedAttributes);
943 result.addAttribute(MmaOp::getOperandSegmentSizeAttr(),
944 builder.getDenseI32ArrayAttr({
945 static_cast<int32_t>(frags[0].regs.size()),
946 static_cast<int32_t>(frags[1].regs.size()),
947 static_cast<int32_t>(frags[2].regs.size()),
948 }));
949 return success();
950}
951
952LogicalResult MmaOp::verify() {
953 MLIRContext *context = getContext();
954 auto f16Ty = Float16Type::get(context);
955 auto i32Ty = IntegerType::get(context, 32);
956 auto f16x2Ty = VectorType::get(2, f16Ty);
957 auto f32Ty = Float32Type::get(context);
958 auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
959 context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
960
961 auto s32x4StructTy =
962 LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
963 auto f32x8StructTy =
964 LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
965 auto f16x2x2StructTy =
966 LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
967 auto f32x4StructTy =
968 LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
969 auto s32x2StructTy =
970 LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});
971
972 std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
973 getShapeAttr().getK()};
974
975 // These variables define the set of allowed data types for matrices A, B, C,
976 // and result.
977 using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
978 using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>;
979 AllowedShapes allowedShapes;
980 AllowedTypes expectedA;
981 AllowedTypes expectedB;
982 AllowedTypes expectedC;
983 SmallVector<Type> expectedResult;
984
985 // When M = 16, we just need to calculate the number of 8xk tiles, where
986 // k is a factor that depends on the data type.
987 if (mmaShape[0] == 16) {
988 int64_t kFactor;
989 Type multiplicandFragType;
990 switch (*getMultiplicandAPtxType()) {
991 case MMATypes::tf32:
992 kFactor = 4;
993 multiplicandFragType = i32Ty;
994 expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
995 context, {f32Ty, f32Ty, f32Ty, f32Ty}));
996 break;
997 case MMATypes::bf16:
998 kFactor = 8;
999 multiplicandFragType = i32Ty;
1000 expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
1001 context, {f32Ty, f32Ty, f32Ty, f32Ty}));
1002 break;
1003 case MMATypes::f16:
1004 kFactor = 8;
1005 multiplicandFragType = f16x2Ty;
1006 expectedResult.push_back(f16x2x2StructTy);
1007 expectedResult.push_back(f32x4StructTy);
1008 break;
1009 case MMATypes::s4:
1010 case MMATypes::u4:
1011 kFactor = 32;
1012 break;
1013 case MMATypes::b1:
1014 kFactor = 128;
1015 break;
1016 case MMATypes::s8:
1017 case MMATypes::u8:
1018 kFactor = 16;
1019 break;
1020 default:
1021 return emitError("invalid shape or multiplicand type: ")
1022 << getMultiplicandAPtxType().value();
1023 }
1024
1025 if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
1026 expectedResult.push_back(s32x4StructTy);
1027 expectedC.emplace_back(4, i32Ty);
1028 multiplicandFragType = i32Ty;
1029 } else {
1030 expectedC.emplace_back(2, f16x2Ty);
1031 expectedC.emplace_back(4, f32Ty);
1032 }
1033
1034 int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
1035 int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
1036 expectedA.emplace_back(unitA, multiplicandFragType);
1037 expectedB.emplace_back(unitB, multiplicandFragType);
1038 allowedShapes.push_back({16, 8, kFactor});
1039 allowedShapes.push_back({16, 8, kFactor * 2});
1040
1041 if (resultPtxType() != accumPtxType())
1042 return emitOpError("ctype does not match dtype");
1043 }
1044
1045 // In the M=8 case, there is only 1 possible case per data type.
1046 if (mmaShape[0] == 8) {
1047 if (*getMultiplicandAPtxType() == MMATypes::f16) {
1048 expectedA.emplace_back(2, f16x2Ty);
1049 expectedB.emplace_back(2, f16x2Ty);
1050 expectedResult.push_back(f16x2x4StructTy);
1051 expectedResult.push_back(f32x8StructTy);
1052 expectedC.emplace_back(4, f16x2Ty);
1053 expectedC.emplace_back(8, f32Ty);
1054 allowedShapes.push_back({8, 8, 4});
1055 }
1056 if (*getMultiplicandAPtxType() == MMATypes::f64) {
1057 Type f64Ty = Float64Type::get(context);
1058 expectedA.emplace_back(1, f64Ty);
1059 expectedB.emplace_back(1, f64Ty);
1060 expectedC.emplace_back(2, f64Ty);
1061 expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
1062 context, SmallVector<Type>(2, f64Ty)));
1063 allowedShapes.push_back({8, 8, 4});
1064 }
1065 if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
1066 expectedA.push_back({i32Ty});
1067 expectedB.push_back({i32Ty});
1068 expectedC.push_back({i32Ty, i32Ty});
1069 expectedResult.push_back(s32x2StructTy);
1070 if (isInt4PtxType(getMultiplicandAPtxType().value()))
1071 allowedShapes.push_back({8, 8, 32});
1072 if (isInt8PtxType(getMultiplicandAPtxType().value()))
1073 allowedShapes.push_back({8, 8, 16});
1074 if (getMultiplicandAPtxType().value() == MMATypes::b1)
1075 allowedShapes.push_back({8, 8, 128});
1076 }
1077 }
1078
1079 std::string errorMessage;
1080 llvm::raw_string_ostream errorStream(errorMessage);
1081
1082 // Check that we matched an existing shape/dtype combination.
1083 if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
1084 !llvm::is_contained(allowedShapes, mmaShape)) {
1085 errorStream << "unimplemented variant for MMA shape <";
1086 llvm::interleaveComma(mmaShape, errorStream);
1087 errorStream << ">";
1088 return emitOpError(errorMessage);
1089 }
1090
1091 // Verify the operand types for segments of A, B, and C operands.
1092 std::array<StringRef, 3> operandNames{"A", "B", "C"};
1093 for (const auto &iter : llvm::enumerate(
1094 SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) {
1095 auto spec = this->getODSOperandIndexAndLength(iter.index());
1096 SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first,
1097 operand_type_begin() + spec.first +
1098 spec.second);
1099 bool match = llvm::is_contained(iter.value(), operandTySeg);
1100
1101 if (!match) {
1102 errorStream << "Could not match types for the "
1103 << operandNames[iter.index()]
1104 << " operands; expected one of ";
1105 for (const auto &x : iter.value()) {
1106 errorStream << x.size() << "x" << x[0] << " ";
1107 }
1108 errorStream << "but got ";
1109 llvm::interleaveComma(operandTySeg, errorStream);
1110 return emitOpError(errorMessage);
1111 }
1112 }
1113
1114 // Check the result type
1115 if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
1116 return expectedResultType == getResult().getType();
1117 })) {
1118 errorStream
1119 << "Could not match allowed types for the result; expected one of ";
1120 llvm::interleaveComma(expectedResult, errorStream);
1121 errorStream << " but got " << getResult().getType();
1122 return emitOpError(errorMessage);
1123 }
1124
1125 // Ensure that binary MMA variants have a b1 MMA operation defined.
1126 if (getMultiplicandAPtxType() == MMATypes::b1 && !getB1Op()) {
1127 return emitOpError("op requires " + getB1OpAttrName().strref() +
1128 " attribute");
1129 }
1130
1131 // Ensure int4/int8 MMA variants specify the accum overflow behavior
1132 // attribute.
1133 if (isInt4PtxType(*getMultiplicandAPtxType()) ||
1134 isInt8PtxType(*getMultiplicandAPtxType())) {
1135 if (!getIntOverflowBehavior())
1136 return emitOpError("op requires " +
1137 getIntOverflowBehaviorAttrName().strref() +
1138 " attribute");
1139 }
1140
1141 // Validate layout combinations. According to the operation description, most
1142 // MMA operations require layoutA=row and layoutB=col. Only m8n8k4 with f16
1143 // can use other layout combinations.
1144 bool isM8N8K4_F16 =
1145 (mmaShape[0] == 8 && mmaShape[1] == 8 && mmaShape[2] == 4 &&
1146 getMultiplicandAPtxType() == MMATypes::f16);
1147
1148 if (!isM8N8K4_F16) {
1149 // For all other shapes/types, layoutA must be row and layoutB must be col
1150 if (getLayoutA() != MMALayout::row || getLayoutB() != MMALayout::col) {
1151 return emitOpError("requires layoutA = #nvvm.mma_layout<row> and "
1152 "layoutB = #nvvm.mma_layout<col> for shape <")
1153 << mmaShape[0] << ", " << mmaShape[1] << ", " << mmaShape[2]
1154 << "> with element types " << *getMultiplicandAPtxType() << " and "
1155 << *getMultiplicandBPtxType()
1156 << ". Only m8n8k4 with f16 supports other layouts.";
1157 }
1158 }
1159
1160 return success();
1161}
1162
1163MMATypes MmaSpOp::accumPtxType() {
1164 std::optional<mlir::NVVM::MMATypes> val = MmaOp::inferOperandMMAType(
1165 getODSOperands(2).getTypes().front(), /*isAccumulator=*/true);
1166 assert(val.has_value() && "accumulator PTX type should always be inferrable");
1167 return val.value();
1168}
1169
1170MMATypes MmaSpOp::resultPtxType() {
1171 std::optional<mlir::NVVM::MMATypes> val =
1172 MmaOp::inferOperandMMAType(getResult().getType(), /*isAccumulator=*/true);
1173 assert(val.has_value() && "result PTX type should always be inferrable");
1174 return val.value();
1175}
1176
// Gathers what is needed to emit the intrinsic call for an nvvm.mma.sp op:
// the intrinsic ID selected from the op's shape, attributes and PTX types,
// and one argument per op operand, in operand order.
//
// `op` must be an NVVM::MmaSpOp (it is cast unconditionally); `mt` maps each
// MLIR operand to the value passed to the intrinsic; `builder` is not used in
// the visible body.
//
// NOTE(review): the return-type line and the declaration of `args` are not
// visible in this excerpt; confirm both against the full file.
1178MmaSpOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
1179 llvm::IRBuilderBase &builder) {
1180 auto thisOp = cast<NVVM::MmaSpOp>(op);
1181
1182 // Get operands
// Every operand is looked up through the module translation and appended to
// `args` in the order the op lists them.
1184 for (mlir::Value v : thisOp.getOperands())
1185 args.push_back(mt.lookupValue(v));
1186
1187 // Get intrinsic ID using the existing getIntrinsicID method
// Both multiplicand PTX type attributes are dereferenced without a check, so
// they must be present on the op by the time this runs.
1188 auto intId = MmaSpOp::getIntrinsicID(
1189 thisOp.getShape().getM(), thisOp.getShape().getN(),
1190 thisOp.getShape().getK(), thisOp.getIntOverflowBehavior(),
1191 thisOp.getOrderedMetadata(), thisOp.getKind(),
1192 *thisOp.getMultiplicandAPtxType(), *thisOp.getMultiplicandBPtxType(),
1193 thisOp.accumPtxType(), thisOp.resultPtxType());
1194
1195 return {intId, args};
1196}
1197
1198void MmaSpOp::print(OpAsmPrinter &p) {
1199 SmallVector<Type, 4> regTypes;
1200 struct MMAOperandFragment {
1201 StringRef operandName;
1202 StringRef ptxTypeAttr;
1203 SmallVector<Value, 4> regs;
1204 explicit MMAOperandFragment(StringRef name, StringRef ptxTypeName)
1205 : operandName(name), ptxTypeAttr(ptxTypeName) {}
1206 };
1207
1208 std::array<MMAOperandFragment, 5> frags{
1209 MMAOperandFragment("A", getMultiplicandAPtxTypeAttrName()),
1210 MMAOperandFragment("B", getMultiplicandBPtxTypeAttrName()),
1211 MMAOperandFragment("C", ""), MMAOperandFragment("sparseMetadata", ""),
1212 MMAOperandFragment("selector", "")};
1213 SmallVector<StringRef, 4> ignoreAttrNames{
1214 mlir::NVVM::MmaSpOp::getOperandSegmentSizeAttr()};
1215
1216 // Handle variadic operands A, B, C
1217 for (unsigned fragIdx = 0; fragIdx < 3; fragIdx++) {
1218 auto &frag = frags[fragIdx];
1219 auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
1220 for (auto operandIdx = varOperandSpec.first;
1221 operandIdx < varOperandSpec.first + varOperandSpec.second;
1222 operandIdx++) {
1223 frag.regs.push_back(this->getOperand(operandIdx));
1224 if (operandIdx == varOperandSpec.first) {
1225 regTypes.push_back(this->getOperand(operandIdx).getType());
1226 }
1227 }
1228 std::optional<MMATypes> inferredType = MmaOp::inferOperandMMAType(
1229 regTypes.back(), /*isAccumulator=*/fragIdx >= 2);
1230 if (inferredType)
1231 ignoreAttrNames.push_back(frag.ptxTypeAttr);
1232 }
1233
1234 // Handle sparse metadata and selector (single operands)
1235 frags[3].regs.push_back(getSparseMetadata());
1236 frags[4].regs.push_back(getSparsitySelector());
1237
1238 auto printMmaSpOperand = [&](const MMAOperandFragment &frag) -> void {
1239 p << " " << frag.operandName;
1240 p << "[";
1241 p.printOperands(frag.regs);
1242 p << "]";
1243 };
1244
1245 for (const auto &frag : frags)
1246 printMmaSpOperand(frag);
1247
1248 p.printOptionalAttrDict((*this)->getAttrs(), ignoreAttrNames);
1249 p << " : ";
1250 p << "(";
1251 for (int i = 0; i < 3; ++i) {
1252 p << regTypes[i];
1253 if (i < 2)
1254 p << ", ";
1255 }
1256 p << ") -> " << getResult().getType();
1257}
1258
1259void MmaSpOp::build(
1260 OpBuilder &builder, OperationState &result, Type resultType,
1261 ValueRange operandA, ValueRange operandB, ValueRange operandC,
1262 Value sparseMetadata, Value sparsitySelector, ArrayRef<int64_t> shape,
1263 std::optional<MMAIntOverflow> intOverflow,
1264 std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes) {
1265
1266 assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
1267 MLIRContext *ctx = builder.getContext();
1268 result.addAttribute(
1269 "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));
1270
1271 result.addOperands(operandA);
1272 result.addOperands(operandB);
1273 result.addOperands(operandC);
1274 result.addOperands(sparseMetadata);
1275 result.addOperands(sparsitySelector);
1276
1277 if (multiplicandPtxTypes) {
1278 result.addAttribute("multiplicandAPtxType",
1279 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
1280 result.addAttribute("multiplicandBPtxType",
1281 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
1282 } else {
1283 if (auto res = MmaOp::inferOperandMMAType(operandA[0].getType(), false))
1284 result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
1285 if (auto res = MmaOp::inferOperandMMAType(operandB[0].getType(), false))
1286 result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
1287 }
1288
1289 if (intOverflow.has_value())
1290 result.addAttribute("intOverflowBehavior",
1291 MMAIntOverflowAttr::get(ctx, *intOverflow));
1292
1293 result.addTypes(resultType);
1294 result.addAttribute(
1295 MmaSpOp::getOperandSegmentSizeAttr(),
1296 builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()),
1297 static_cast<int32_t>(operandB.size()),
1298 static_cast<int32_t>(operandC.size()), 1,
1299 1})); // sparseMetadata and sparsitySelector
1300}
1301
1302ParseResult MmaSpOp::parse(OpAsmParser &parser, OperationState &result) {
1303 struct MMAOperandFragment {
1304 std::optional<MMATypes> elemtype;
1305 SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
1306 SmallVector<Type> regTypes;
1307 };
1308
1309 Builder &builder = parser.getBuilder();
1310 std::array<MMAOperandFragment, 6> frags; // A, B, C, sparseMetadata, selector
1311
1312 NamedAttrList namedAttributes;
1313
1314 // A helper to parse the operand segments.
1315 auto parseMmaSpOperand = [&](StringRef operandName,
1316 MMAOperandFragment &frag) -> LogicalResult {
1317 if (parser.parseKeyword(operandName).failed())
1318 return failure();
1319 if (parser
1320 .parseOperandList(frag.regs, OpAsmParser::Delimiter::OptionalSquare)
1321 .failed())
1322 return failure();
1323 return success();
1324 };
1325
1326 // Parse the operand segments.
1327 if (parseMmaSpOperand("A", frags[0]).failed())
1328 return failure();
1329 if (parseMmaSpOperand("B", frags[1]).failed())
1330 return failure();
1331 if (parseMmaSpOperand("C", frags[2]).failed())
1332 return failure();
1333 if (parseMmaSpOperand("sparseMetadata", frags[3]).failed())
1334 return failure();
1335 if (parseMmaSpOperand("selector", frags[4]).failed())
1336 return failure();
1337
1338 if (parser.parseOptionalAttrDict(namedAttributes).failed())
1339 return failure();
1340
1341 // Parse the type specification and resolve operands.
1342 SmallVector<Type, 3> operandTypes;
1343 if (failed(parser.parseColon()))
1344 return failure();
1345 if (failed(parser.parseLParen()))
1346 return failure();
1347 if (failed(parser.parseTypeList(operandTypes)))
1348 return failure();
1349 if (failed(parser.parseRParen()))
1350 return failure();
1351 if (operandTypes.size() != 3)
1352 return parser.emitError(
1353 parser.getNameLoc(),
1354 "expected one type for each operand segment but got " +
1355 Twine(operandTypes.size()) + " types");
1356 for (const auto &iter : llvm::enumerate(operandTypes)) {
1357 auto &frag = frags[iter.index()];
1358 frag.regTypes.resize(frag.regs.size(), iter.value());
1359 if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
1360 parser.getNameLoc(), result.operands)))
1361 return failure();
1362 frag.elemtype =
1363 MmaOp::inferOperandMMAType(frag.regTypes[0],
1364 /*isAccumulator*/ iter.index() >= 2);
1365 }
1366
1367 Type resultType;
1368 if (parser.parseArrow() || parser.parseType(resultType))
1369 return failure();
1370 frags[5].elemtype =
1371 MmaOp::inferOperandMMAType(resultType, /*isAccumulator*/ true);
1372
1373 // Resolve sparse metadata and selector (assume i32 type)
1374 Type i32Type = builder.getIntegerType(32);
1375 if (parser
1376 .resolveOperands(frags[3].regs, i32Type, parser.getCurrentLocation(),
1377 result.operands)
1378 .failed())
1379 return failure();
1380 if (parser
1381 .resolveOperands(frags[4].regs, i32Type, parser.getCurrentLocation(),
1382 result.operands)
1383 .failed())
1384 return failure();
1385
1386 std::array<StringRef, 2> names{"multiplicandAPtxType",
1387 "multiplicandBPtxType"};
1388 for (unsigned idx = 0; idx < names.size(); idx++) {
1389 const auto &frag = frags[idx];
1390 std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
1391 if (!frag.elemtype.has_value() && !attr.has_value()) {
1392 return parser.emitError(
1393 parser.getNameLoc(),
1394 "attribute " + names[idx] +
1395 " is not provided explicitly and cannot be inferred");
1396 }
1397 if (!attr.has_value())
1398 result.addAttribute(
1399 names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));
1400 }
1401
1402 result.addTypes(resultType);
1403 if (!namedAttributes.empty())
1404 result.addAttributes(namedAttributes);
1405 result.addAttribute(MmaSpOp::getOperandSegmentSizeAttr(),
1406 builder.getDenseI32ArrayAttr({
1407 static_cast<int32_t>(frags[0].regs.size()),
1408 static_cast<int32_t>(frags[1].regs.size()),
1409 static_cast<int32_t>(frags[2].regs.size()),
1410 1, // sparseMetadata
1411 1 // sparsitySelector
1412 }));
1413 return success();
1414}
1415
1416LogicalResult MmaSpOp::verify() {
1417 MLIRContext *context = getContext();
1418 auto f16Ty = Float16Type::get(context);
1419 auto i32Ty = IntegerType::get(context, 32);
1420 auto f16x2Ty = VectorType::get(2, f16Ty);
1421 auto f32Ty = Float32Type::get(context);
1422 auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
1423 context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
1424
1425 auto s32x4StructTy =
1426 LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
1427 auto f32x8StructTy =
1428 LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
1429 auto f16x2x2StructTy =
1430 LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
1431 auto f32x4StructTy =
1432 LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
1433 auto s32x2StructTy =
1434 LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});
1435
1436 std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
1437 getShapeAttr().getK()};
1438
1439 // These variables define the set of allowed data types for matrices A, B, C,
1440 // and result.
1441 using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
1442 using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>;
1443 AllowedShapes allowedShapes;
1444 AllowedTypes expectedA;
1445 AllowedTypes expectedB;
1446 AllowedTypes expectedC;
1447 SmallVector<Type> expectedResult;
1448
1449 // When M = 16, we just need to calculate the number of 8xk tiles, where
1450 // k is a factor that depends on the data type.
1451 if (mmaShape[0] == 16) {
1452 int64_t kFactor;
1453 Type multiplicandFragType;
1454 switch (*getMultiplicandAPtxType()) {
1455 case MMATypes::tf32:
1456 kFactor = 4;
1457 multiplicandFragType = i32Ty;
1458 expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
1459 context, {f32Ty, f32Ty, f32Ty, f32Ty}));
1460 // Sparse MMA supports m16n8k8 and m16n8k16 for tf32
1461 allowedShapes.push_back({16, 8, 8});
1462 allowedShapes.push_back({16, 8, 16});
1463 break;
1464 case MMATypes::bf16:
1465 kFactor = 8;
1466 multiplicandFragType = i32Ty;
1467 expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
1468 context, {f32Ty, f32Ty, f32Ty, f32Ty}));
1469 // Sparse MMA supports m16n8k16 and m16n8k32 for bf16
1470 allowedShapes.push_back({16, 8, 16});
1471 allowedShapes.push_back({16, 8, 32});
1472 break;
1473 case MMATypes::f16:
1474 kFactor = 8;
1475 multiplicandFragType = f16x2Ty;
1476 expectedResult.push_back(f16x2x2StructTy);
1477 expectedResult.push_back(f32x4StructTy);
1478 // Sparse MMA supports m16n8k16 and m16n8k32 for f16
1479 allowedShapes.push_back({16, 8, 16});
1480 allowedShapes.push_back({16, 8, 32});
1481 break;
1482 case MMATypes::s4:
1483 case MMATypes::u4:
1484 kFactor = 32;
1485 // Sparse MMA supports m16n8k64 and m16n8k128 for s4/u4
1486 allowedShapes.push_back({16, 8, 64});
1487 allowedShapes.push_back({16, 8, 128});
1488 break;
1489 case MMATypes::s8:
1490 case MMATypes::u8:
1491 kFactor = 16;
1492 // Sparse MMA supports m16n8k32 and m16n8k64 for s8/u8
1493 allowedShapes.push_back({16, 8, 32});
1494 allowedShapes.push_back({16, 8, 64});
1495 break;
1496 case MMATypes::e4m3:
1497 case MMATypes::e5m2:
1498 case MMATypes::e3m2:
1499 case MMATypes::e2m3:
1500 case MMATypes::e2m1:
1501 kFactor = 16;
1502 multiplicandFragType = i32Ty;
1503 expectedResult.push_back(f16x2x2StructTy);
1504 expectedResult.push_back(f32x4StructTy);
1505 // Sparse MMA supports m16n8k64 for FP8 types
1506 allowedShapes.push_back({16, 8, 64});
1507 break;
1508 default:
1509 return emitError("invalid shape or multiplicand type: ")
1510 << getMultiplicandAPtxType().value();
1511 }
1512
1513 if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
1514 expectedResult.push_back(s32x4StructTy);
1515 expectedC.emplace_back(4, i32Ty);
1516 multiplicandFragType = i32Ty;
1517 } else if (*getMultiplicandAPtxType() >= MMATypes::e4m3 &&
1518 *getMultiplicandAPtxType() <= MMATypes::e2m1) {
1519 // FP8 types
1520 expectedC.emplace_back(2, f16x2Ty);
1521 expectedC.emplace_back(4, f32Ty);
1522 } else {
1523 expectedC.emplace_back(2, f16x2Ty);
1524 expectedC.emplace_back(4, f32Ty);
1525 }
1526
1527 // For sparse MMA, A operand is compressed (2:4 sparsity means half the
1528 // elements)
1529 int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor) / 2;
1530 int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
1531 expectedA.emplace_back(unitA, multiplicandFragType);
1532 expectedB.emplace_back(unitB, multiplicandFragType);
1533
1534 if (resultPtxType() != accumPtxType())
1535 return emitOpError("ctype does not match dtype");
1536 }
1537
1538 // In the M=8 case, there is only 1 possible case per data type.
1539 if (mmaShape[0] == 8) {
1540 if (*getMultiplicandAPtxType() == MMATypes::f16) {
1541 expectedA.emplace_back(2, f16x2Ty);
1542 expectedB.emplace_back(2, f16x2Ty);
1543 expectedResult.push_back(f16x2x4StructTy);
1544 expectedResult.push_back(f32x8StructTy);
1545 expectedC.emplace_back(4, f16x2Ty);
1546 expectedC.emplace_back(8, f32Ty);
1547 allowedShapes.push_back({8, 8, 4});
1548 }
1549 if (*getMultiplicandAPtxType() == MMATypes::f64) {
1550 Type f64Ty = Float64Type::get(context);
1551 expectedA.emplace_back(1, f64Ty);
1552 expectedB.emplace_back(1, f64Ty);
1553 expectedC.emplace_back(2, f64Ty);
1554 expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
1555 context, SmallVector<Type>(2, f64Ty)));
1556 allowedShapes.push_back({8, 8, 4});
1557 }
1558 if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
1559 expectedA.push_back({i32Ty});
1560 expectedB.push_back({i32Ty});
1561 expectedC.push_back({i32Ty, i32Ty});
1562 expectedResult.push_back(s32x2StructTy);
1563 if (isInt4PtxType(getMultiplicandAPtxType().value()))
1564 allowedShapes.push_back({8, 8, 32});
1565 if (isInt8PtxType(getMultiplicandAPtxType().value()))
1566 allowedShapes.push_back({8, 8, 16});
1567 }
1568 }
1569
1570 std::string errorMessage;
1571 llvm::raw_string_ostream errorStream(errorMessage);
1572
1573 // Check that we matched an existing shape/dtype combination.
1574 if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
1575 !llvm::is_contained(allowedShapes, mmaShape)) {
1576 errorStream << "unimplemented variant for MMA shape <";
1577 llvm::interleaveComma(mmaShape, errorStream);
1578 errorStream << ">";
1579 return emitOpError(errorMessage);
1580 }
1581
1582 // Verify the operand types for segments of A, B, and C operands.
1583 std::array<StringRef, 3> operandNames{"A", "B", "C"};
1584 for (const auto &iter : llvm::enumerate(
1585 SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) {
1586 auto spec = this->getODSOperandIndexAndLength(iter.index());
1587 SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first,
1588 operand_type_begin() + spec.first +
1589 spec.second);
1590 bool match = llvm::is_contained(iter.value(), operandTySeg);
1591
1592 if (!match) {
1593 errorStream << "Could not match types for the "
1594 << operandNames[iter.index()]
1595 << " operands; expected one of ";
1596 for (const auto &x : iter.value()) {
1597 errorStream << x.size() << "x" << x[0] << " ";
1598 }
1599 errorStream << "but got ";
1600 llvm::interleaveComma(operandTySeg, errorStream);
1601 return emitOpError(errorMessage);
1602 }
1603 }
1604
1605 // Check the result type
1606 if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
1607 return expectedResultType == getResult().getType();
1608 })) {
1609 errorStream
1610 << "Could not match allowed types for the result; expected one of ";
1611 llvm::interleaveComma(expectedResult, errorStream);
1612 errorStream << " but got " << getResult().getType();
1613 return emitOpError(errorMessage);
1614 }
1615
1616 // Ensure int4/int8 MMA variants specify the accum overflow behavior
1617 // attribute.
1618 if (isInt4PtxType(*getMultiplicandAPtxType()) ||
1619 isInt8PtxType(*getMultiplicandAPtxType())) {
1620 if (!getIntOverflowBehavior())
1621 return emitOpError("op requires " +
1622 getIntOverflowBehaviorAttrName().strref() +
1623 " attribute");
1624 }
1625
1626 // Validate sparse metadata type (should be i32)
1627 if (!getSparseMetadata().getType().isInteger(32)) {
1628 return emitOpError() << "sparse metadata must be i32 type";
1629 }
1630
1631 // Validate sparsity selector type (should be i32)
1632 if (!getSparsitySelector().getType().isInteger(32)) {
1633 return emitOpError() << "sparsity selector must be i32 type";
1634 }
1635
1636 return success();
1637}
1638
1639//===----------------------------------------------------------------------===//
1640// MMA Block Scale Operations - Shared Helpers
1641//===----------------------------------------------------------------------===//
1642
1643namespace {
1644// Shared structure for MMA operand fragments (A, B, C)
1645struct MMAOperandFragment {
1646 StringRef operandName;
1647 StringRef ptxTypeAttr;
1648 SmallVector<Value, 4> regs;
1649 explicit MMAOperandFragment(StringRef name, StringRef ptxTypeName)
1650 : operandName(name), ptxTypeAttr(ptxTypeName) {}
1651};
1652} // namespace
1653
1654// Helper to print operand list in the format: name[operands]
1655static void printOperandList(OpAsmPrinter &p, StringRef name,
1656 ArrayRef<Value> operands) {
1657 p << " " << name << "[";
1658 p.printOperands(operands);
1659 p << "]";
1660}
1661
1662// Helper to parse operand list in the format: name[operands]
// Parses one named operand segment: the keyword `operandName`, then an
// operand list. Returns failure as soon as either step fails and success
// otherwise.
//
// NOTE(review): two lines of this helper are not visible in this excerpt: the
// remaining parameter(s) after `operandName`, and the call whose `.failed()`
// result is tested below. Call sites in this file pass a SmallVector of
// OpAsmParser::UnresolvedOperand as the third argument, and the sibling
// lambdas in the parse methods call parseOperandList with
// Delimiter::OptionalSquare at this point; confirm against the full file.
1663static LogicalResult
1664parseMmaOperand(OpAsmParser &parser, StringRef operandName,
1666 if (parser.parseKeyword(operandName).failed())
1667 return failure();
1669 .failed())
1670 return failure();
1671 return success();
1672}
1673
1674// Helper to process operand fragments and determine which attributes can be
1675// inferred
1676template <typename Op>
1677static void
1678processOperandFragments(Op &op, std::array<MMAOperandFragment, 3> &frags,
1679 SmallVectorImpl<Type> &regTypes,
1680 SmallVectorImpl<StringRef> &ignoreAttrNames) {
1681 for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
1682 auto &frag = frags[fragIdx];
1683 auto varOperandSpec = op.getODSOperandIndexAndLength(fragIdx);
1684 for (auto operandIdx = varOperandSpec.first;
1685 operandIdx < varOperandSpec.first + varOperandSpec.second;
1686 operandIdx++) {
1687 frag.regs.push_back(op.getOperand(operandIdx));
1688 if (fragIdx == 0 && operandIdx == varOperandSpec.first) {
1689 regTypes.push_back(op.getOperand(operandIdx).getType());
1690 }
1691 }
1692 if (fragIdx < 2) {
1693 regTypes.push_back(frag.regs[0].getType());
1694 }
1695 std::optional<MMATypes> inferredType =
1696 MmaOp::inferOperandMMAType(regTypes.back(),
1697 /*isAccumulator=*/fragIdx >= 2);
1698 if (inferredType)
1699 ignoreAttrNames.push_back(frag.ptxTypeAttr);
1700 }
1701}
1702
1703// Helper to parse type signature: (A_type, B_type, C_type)
// Parses `: (type, type, type)` and appends the parsed types to
// `operandTypes`. Returns failure on any syntax error, or with the diagnostic
// "expected exactly 3 types" when `operandTypes` does not hold three entries
// once the list has been read.
//
// NOTE(review): the line carrying this helper's name and its first
// parameter(s) is not visible in this excerpt; the body uses a `parser`
// object with OpAsmParser-style methods. Confirm the signature against the
// full file.
1704static LogicalResult
1706 SmallVectorImpl<Type> &operandTypes) {
1707 if (parser.parseColon().failed() || parser.parseLParen().failed())
1708 return failure();
1709
// Element callback for the comma-separated list: parse one type and record it.
1710 auto typeParser = [&]() {
1711 Type ty;
1712 if (parser.parseType(ty).failed())
1713 return failure();
1714 operandTypes.push_back(ty);
1715 return success();
1716 };
1717 if (parser.parseCommaSeparatedList(typeParser))
1718 return failure();
1719
// The count is checked on the whole vector, so entries that were already in
// `operandTypes` before the call are counted too. The check runs before the
// closing parenthesis is consumed.
1720 if (operandTypes.size() != 3)
1721 return parser.emitError(parser.getCurrentLocation(),
1722 "expected exactly 3 types");
1723
1724 return parser.parseRParen();
1725}
1726
1727// Helper to infer and set multiplicand PTX type attributes
// Fills in the "multiplicandAPtxType" and "multiplicandBPtxType" entries of
// `attrs` when they are absent. The A type is inferred from operandTypes[0]
// and the B type from operandTypes[1], both with the accumulator flag set to
// false. An entry that is already present is left untouched, and an entry
// whose type cannot be inferred stays unset.
//
// NOTE(review): the line carrying this helper's name and its first
// parameter(s) is not visible in this excerpt; the body uses `ctx` and
// `attrs`. `operandTypes` is indexed at 0 and 1 without a size check, so the
// caller presumably guarantees at least two entries; confirm.
1728static void
1730 const SmallVectorImpl<Type> &operandTypes) {
1731 if (!attrs.get("multiplicandAPtxType")) {
1732 if (auto inferredType =
1733 MmaOp::inferOperandMMAType(operandTypes[0], false)) {
1734 attrs.set("multiplicandAPtxType", MMATypesAttr::get(ctx, *inferredType));
1735 }
1736 }
1737 if (!attrs.get("multiplicandBPtxType")) {
1738 if (auto inferredType =
1739 MmaOp::inferOperandMMAType(operandTypes[1], false)) {
1740 attrs.set("multiplicandBPtxType", MMATypesAttr::get(ctx, *inferredType));
1741 }
1742 }
1743}
1744
1745// Helper to add common block scale properties
// Stores the attributes shared by the block-scale MMA ops in the
// `OpType::Properties` of the operation state: the MMA shape built from
// shape[0..2], the scale vector size, the block scale format and the kind.
//
// NOTE(review): the lines carrying this helper's name and its leading
// parameters are not visible in this excerpt; the body uses `builder`,
// `result` and `shape`. `shape` is indexed at 0, 1 and 2 without a size
// check. Confirm the signature against the full file.
1746template <typename OpType>
1749 ScaleVecSize scaleVecSize,
1750 BlockScaleFormat blockScaleFormat,
1751 MMABlockScaleKind kind) {
1752 MLIRContext *ctx = builder.getContext();
// Fetches the op's properties from `result`, creating them if they do not
// exist yet.
1753 auto &properties = result.getOrAddProperties<typename OpType::Properties>();
1754 properties.setShape(
1755 builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));
1756 properties.setScaleVecSize(ScaleVecSizeAttr::get(ctx, scaleVecSize));
1757 properties.setBlockScaleFormat(
1758 BlockScaleFormatAttr::get(ctx, blockScaleFormat));
1759 properties.setKind(MMABlockScaleKindAttr::get(ctx, kind));
1760}
1761
1762// Helper to infer and add multiplicand PTX types to builder
1765 ValueRange operandB,
1766 std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes) {
// Adds the "multiplicandAPtxType" and "multiplicandBPtxType" attributes to
// `result`. Explicit types are used when `multiplicandPtxTypes` holds a value
// (index 0 for A, index 1 for B). Otherwise each type is inferred from the
// first value of operandA / operandB with the accumulator flag set to false,
// and an attribute is skipped when inference fails.
//
// NOTE(review): the lines carrying this helper's name and its leading
// parameters are not visible in this excerpt; the body uses `ctx`, `result`
// and `operandA`. The inference path reads operandA[0] and operandB[0]
// without an emptiness check.
1767 if (multiplicandPtxTypes) {
1768 result.addAttribute("multiplicandAPtxType",
1769 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
1770 result.addAttribute("multiplicandBPtxType",
1771 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
1772 } else {
1773 if (auto res = MmaOp::inferOperandMMAType(operandA[0].getType(), false))
1774 result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
1775 if (auto res = MmaOp::inferOperandMMAType(operandB[0].getType(), false))
1776 result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
1777 }
1778}
1779
1780// Template helper for common accumPtxType/resultPtxType implementation
1781template <typename OpTy>
1782static MMATypes inferPtxTypeFromResult(OpTy op) {
1783 return *MmaOp::inferOperandMMAType(
1784 cast<LLVM::LLVMStructType>(op.getRes().getType()).getBody()[0],
1785 /*isAccumulator=*/true);
1786}
1787
1788//===----------------------------------------------------------------------===//
1789// MmaBlockScaleOp
1790//===----------------------------------------------------------------------===//
1791
1792void MmaBlockScaleOp::print(OpAsmPrinter &p) {
1793 SmallVector<Type, 4> regTypes;
1794 std::array<MMAOperandFragment, 3> frags{
1795 MMAOperandFragment("A", getMultiplicandAPtxTypeAttrName()),
1796 MMAOperandFragment("B", getMultiplicandBPtxTypeAttrName()),
1797 MMAOperandFragment("C", "")};
1798 SmallVector<StringRef, 4> ignoreAttrNames{
1799 mlir::NVVM::MmaBlockScaleOp::getOperandSegmentSizeAttr()};
1800
1801 processOperandFragments(*this, frags, regTypes, ignoreAttrNames);
1802
1803 // Print A, B, C operands
1804 for (const auto &frag : frags)
1805 printOperandList(p, frag.operandName, frag.regs);
1806
1807 // Print scale operands
1808 printOperandList(p, "scaleA",
1809 {getScaleAData(), getByteIdA(), getThreadIdA()});
1810 printOperandList(p, "scaleB",
1811 {getScaleBData(), getByteIdB(), getThreadIdB()});
1812
1813 p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);
1814
1815 // Print type signature
1816 p << " : (";
1817 llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
1818 frags[1].regs[0].getType(),
1819 frags[2].regs[0].getType()},
1820 p);
1821 p << ")";
1822 p.printArrowTypeList(TypeRange{this->getRes().getType()});
1823}
1824
// Custom parser. Consumes, in order: the A, B and C operand groups, the
// scaleA and scaleB operand groups, an optional attribute dictionary, the
// operand type signature and the arrow-prefixed result type list. Operands
// are appended to `result.operands` in that same order, and an operand
// segment size attribute is attached at the end.
// NOTE(review): the line following the first one is not visible in this
// listing; `result` is used below and is presumably declared there - confirm.
1825ParseResult MmaBlockScaleOp::parse(OpAsmParser &parser,
1827 struct LocalOperandFragment {
1828 std::optional<MMATypes> elemtype;
1829 SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
1830 };
1831
1832 Builder &builder = parser.getBuilder();
1833 std::array<LocalOperandFragment, 3> frags;
1834 NamedAttrList namedAttributes;
1835
1836 // Parse A[...] B[...] C[...]
1837 if (parseMmaOperand(parser, "A", frags[0].regs).failed() ||
1838 parseMmaOperand(parser, "B", frags[1].regs).failed() ||
1839 parseMmaOperand(parser, "C", frags[2].regs).failed())
1840 return failure();
1841
1842 // Parse scale operands: scaleA[...] scaleB[...]
1843 SmallVector<OpAsmParser::UnresolvedOperand, 3> scaleAOperands, scaleBOperands;
1844 if (parseMmaOperand(parser, "scaleA", scaleAOperands).failed() ||
1845 parseMmaOperand(parser, "scaleB", scaleBOperands).failed())
1846 return failure();
1847
1848 if (parser.parseOptionalAttrDict(namedAttributes).failed())
1849 return failure();
1850
1851 // Parse type signature
1852 SmallVector<Type, 3> operandTypes;
1853 if (parseMmaTypeSignature(parser, operandTypes).failed())
1854 return failure();
1855
1856 // Parse result type
1857 SmallVector<Type, 1> resultTypes;
1858 if (parser.parseArrowTypeList(resultTypes).failed())
1859 return failure();
1860
1861 // Infer element types and resolve operands
// Only the third fragment (index 2, "C") is inferred as an accumulator. All
// registers of a fragment are resolved against that fragment's single
// signature type.
1862 for (const auto &[idx, frag] : llvm::enumerate(frags)) {
1863 frag.elemtype = MmaOp::inferOperandMMAType(operandTypes[idx],
1864 /*isAccumulator=*/idx >= 2);
1865 if (parser
1866 .resolveOperands(frag.regs, operandTypes[idx], parser.getNameLoc(),
1867 result.operands)
1868 .failed())
1869 return failure();
1870 }
1871
1872 // Resolve scale operands
// Both scale groups are resolved against the same (i32, i16, i16) type list.
1873 SmallVector<Type, 3> scaleTypes = {builder.getI32Type(), builder.getI16Type(),
1874 builder.getI16Type()};
1875 if (parser
1876 .resolveOperands(scaleAOperands, scaleTypes, parser.getNameLoc(),
1877 result.operands)
1878 .failed() ||
1879 parser
1880 .resolveOperands(scaleBOperands, scaleTypes, parser.getNameLoc(),
1881 result.operands)
1882 .failed())
1883 return failure();
1884
1885 // Add attributes
1886 result.addAttributes(namedAttributes);
1887 inferAndSetMultiplicandTypes(parser.getContext(), result.attributes,
1888 operandTypes);
1889
1890 result.addTypes(resultTypes);
// Segment sizes: the parsed A/B/C register counts followed by one entry per
// scale operand.
1891 result.addAttribute(MmaBlockScaleOp::getOperandSegmentSizeAttr(),
1892 builder.getDenseI32ArrayAttr({
1893 static_cast<int32_t>(frags[0].regs.size()),
1894 static_cast<int32_t>(frags[1].regs.size()),
1895 static_cast<int32_t>(frags[2].regs.size()),
1896 1, // scaleAData
1897 1, // byteIdA
1898 1, // threadIdA
1899 1, // scaleBData
1900 1, // byteIdB
1901 1 // threadIdB
1902 }));
1903 return success();
1904}
1905
// Builder. Adds the A, B and C operand ranges followed by the six scale
// operands, attaches the multiplicand PTX type attributes (explicit when
// `multiplicandPtxTypes` is provided, inferred otherwise), the result type and
// the operand segment size attribute. `shape` must hold exactly three entries
// (m, n, k), which is asserted.
1906void MmaBlockScaleOp::build(
1907 OpBuilder &builder, OperationState &result, Type resultType,
1908 ValueRange operandA, ValueRange operandB, ValueRange operandC,
1909 Value scaleAData, Value byteIdA, Value threadIdA, Value scaleBData,
1910 Value byteIdB, Value threadIdB, ArrayRef<int64_t> shape,
1911 std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
1912 ScaleVecSize scaleVecSize, BlockScaleFormat blockScaleFormat,
1913 MMABlockScaleKind kind) {
1914 assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
1915
// NOTE(review): the line that opens the call completed below is not visible
// in this listing; it presumably forwards shape/scaleVecSize to the common
// block scale property helper - confirm against the file.
1917 blockScaleFormat, kind);
1918
1919 result.addOperands(operandA);
1920 result.addOperands(operandB);
1921 result.addOperands(operandC);
1922 result.addOperands(
1923 {scaleAData, byteIdA, threadIdA, scaleBData, byteIdB, threadIdB});
1924
1925 addInferredMultiplicandTypes(builder.getContext(), result, operandA, operandB,
1926 multiplicandPtxTypes);
1927
1928 result.addTypes(resultType);
// Segment sizes follow the operand order used above: A, B, C, then one entry
// per scale operand.
1929 result.addAttribute(MmaBlockScaleOp::getOperandSegmentSizeAttr(),
1930 builder.getDenseI32ArrayAttr({
1931 static_cast<int32_t>(operandA.size()),
1932 static_cast<int32_t>(operandB.size()),
1933 static_cast<int32_t>(operandC.size()),
1934 1, // scaleAData
1935 1, // byteIdA
1936 1, // threadIdA
1937 1, // scaleBData
1938 1, // byteIdB
1939 1 // threadIdB
1940 }));
1941}
1942
// Returns the intrinsic ID together with the translated argument list for
// `op`, which must be an NVVM::MmaBlockScaleOp (enforced by `cast`). The
// arguments are the values looked up through `mt` for the A, B and C operands
// followed by the six scale operands. The `builder` parameter is not used in
// the visible lines.
1943NVVM::IDArgPair MmaBlockScaleOp::getIntrinsicIDAndArgs(
1944 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1945 auto curOp = cast<NVVM::MmaBlockScaleOp>(op);
1946
// NOTE(review): the declaration of `args` is not visible in this listing;
// it is presumably a container of translated LLVM values - confirm.
1948 // Add A, B, C operands
1949 for (Value operand : curOp.getOperandA())
1950 args.push_back(mt.lookupValue(operand));
1951 for (Value operand : curOp.getOperandB())
1952 args.push_back(mt.lookupValue(operand));
1953 for (Value operand : curOp.getOperandC())
1954 args.push_back(mt.lookupValue(operand));
1955
1956 // Add scale operands
1957 args.push_back(mt.lookupValue(curOp.getScaleAData()));
1958 args.push_back(mt.lookupValue(curOp.getByteIdA()));
1959 args.push_back(mt.lookupValue(curOp.getThreadIdA()));
1960 args.push_back(mt.lookupValue(curOp.getScaleBData()));
1961 args.push_back(mt.lookupValue(curOp.getByteIdB()));
1962 args.push_back(mt.lookupValue(curOp.getThreadIdB()));
1963
// Both multiplicand PTX type accessors are dereferenced without a check, so
// this code relies on them holding a value at this point.
1964 unsigned intId = MmaBlockScaleOp::getIntrinsicID(
1965 curOp.getShape().getM(), curOp.getShape().getN(), curOp.getShape().getK(),
1966 *curOp.getMultiplicandAPtxType(), *curOp.getMultiplicandBPtxType(),
1967 inferPtxTypeFromResult(curOp), curOp.getScaleVecSize(),
1968 curOp.getBlockScaleFormat(), curOp.getKind());
1969
1970 return {intId, args};
1971}
1972
// Verifies the combination of shape, kind, scale vector size, block scale
// format and multiplicand types.
//   m16n8k64: both multiplicand types must be e2m1;
//             kind MXF4 requires ScaleVecSize X2 and BlockScaleFormat UE8M0;
//             kind MXF4NVF4 requires (X2, UE8M0) or (X4, UE4M3 | UE8M0);
//             any other kind is rejected.
//   m16n8k32: requires kind MXF8F6F4, ScaleVecSize X1 and UE8M0.
//   Any other shape is rejected.
// In the visible branches a failing check assigns its diagnostic to `result`,
// which is returned at the end.
// NOTE(review): three lines that open a statement ending in a string literal
// are not visible in this listing (marked below) - confirm against the file.
1973LogicalResult MmaBlockScaleOp::verify() {
1974 LogicalResult result = success();
1975 int m = getShape().getM();
1976 int n = getShape().getN();
1977 int k = getShape().getK();
1978
1979 if (m == 16 && n == 8 && k == 64) {
1980 if (getMultiplicandAPtxType() != NVVM::MMATypes::e2m1 ||
1981 getMultiplicandBPtxType() != NVVM::MMATypes::e2m1)
// (statement-opening line not visible)
1983 "unsupported MMATypes attribute for mma.m16n8k64.(mxf4nvf4|mxf4)");
1984 if (getKind() == NVVM::MMABlockScaleKind::MXF4) {
1985 if (getScaleVecSize() != NVVM::ScaleVecSize::X2)
// (statement-opening line not visible)
1987 "unsupported ScaleVecSize attribute for mma.m16n8k64.mxf4");
1988 if (getBlockScaleFormat() != NVVM::BlockScaleFormat::UE8M0)
// (statement-opening line not visible)
1990 "unsupported BlockScaleFormat attribute for mma.m16n8k64.mxf4");
1991 } else if (getKind() == NVVM::MMABlockScaleKind::MXF4NVF4) {
1992 if (!((getScaleVecSize() == NVVM::ScaleVecSize::X2 &&
1993 getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0) ||
1994 (getScaleVecSize() == NVVM::ScaleVecSize::X4 &&
1995 (getBlockScaleFormat() == NVVM::BlockScaleFormat::UE4M3 ||
1996 getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0))))
1997 result = emitOpError("unsupported ScaleVecSize and BlockScaleFormat "
1998 "attributes for mma.m16n8k64.mxf4nvf4");
1999 } else {
2000 result = emitOpError("unsupported Kind attribute for mma.m16n8k64");
2001 }
2002 } else if (m == 16 && n == 8 && k == 32) {
2003 if (!(getKind() == NVVM::MMABlockScaleKind::MXF8F6F4 &&
2004 getScaleVecSize() == NVVM::ScaleVecSize::X1 &&
2005 getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0))
2006 result =
2007 emitOpError("unsupported Kind, ScaleVecSize and BlockScaleFormat "
2008 "attributes for mma.m16n8k32");
2009 } else {
2010 result = emitOpError("unsupported Geom for mma with block scaling");
2011 }
2012 return result;
2013}
2014
2015//===----------------------------------------------------------------------===//
2016// MmaSpBlockScaleOp
2017//===----------------------------------------------------------------------===//
2018
2019void MmaSpBlockScaleOp::print(OpAsmPrinter &p) {
2020 SmallVector<Type, 4> regTypes;
2021 std::array<MMAOperandFragment, 3> frags{
2022 MMAOperandFragment("A", getMultiplicandAPtxTypeAttrName()),
2023 MMAOperandFragment("B", getMultiplicandBPtxTypeAttrName()),
2024 MMAOperandFragment("C", "")};
2025 SmallVector<StringRef, 4> ignoreAttrNames{
2026 mlir::NVVM::MmaSpBlockScaleOp::getOperandSegmentSizeAttr()};
2027
2028 processOperandFragments(*this, frags, regTypes, ignoreAttrNames);
2029
2030 // Print A, B, C operands
2031 for (const auto &frag : frags)
2032 printOperandList(p, frag.operandName, frag.regs);
2033
2034 // Print sparse-specific operands
2035 printOperandList(p, "sparseMetadata", {getSparseMetadata()});
2036 printOperandList(p, "selector", {getSparsitySelector()});
2037
2038 // Print scale operands
2039 printOperandList(p, "scaleA",
2040 {getScaleAData(), getByteIdA(), getThreadIdA()});
2041 printOperandList(p, "scaleB",
2042 {getScaleBData(), getByteIdB(), getThreadIdB()});
2043
2044 p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);
2045
2046 // Print type signature
2047 p << " : (";
2048 llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
2049 frags[1].regs[0].getType(),
2050 frags[2].regs[0].getType()},
2051 p);
2052 p << ")";
2053 p.printArrowTypeList(TypeRange{this->getRes().getType()});
2054}
2055
// Custom parser. Consumes, in order: the A, B and C operand groups, the
// sparseMetadata and selector groups, the scaleA and scaleB groups, an
// optional attribute dictionary, the operand type signature and the
// arrow-prefixed result type list. Operands are appended to `result.operands`
// in that same order; the "orderedMetadata" unit attribute is added when the
// parsed dictionary did not supply it.
// NOTE(review): two lines are not visible in this listing: the one following
// the first line (`result` is presumably declared there) and the one opening
// the declaration of `metadataOperands` / `selectorOperands` - confirm.
2056ParseResult MmaSpBlockScaleOp::parse(OpAsmParser &parser,
2058 struct LocalOperandFragment {
2059 std::optional<MMATypes> elemtype;
2060 SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
2061 };
2062
2063 Builder &builder = parser.getBuilder();
2064 std::array<LocalOperandFragment, 3> frags;
2065 NamedAttrList namedAttributes;
2066
2067 // Parse A[...] B[...] C[...]
2068 if (parseMmaOperand(parser, "A", frags[0].regs).failed() ||
2069 parseMmaOperand(parser, "B", frags[1].regs).failed() ||
2070 parseMmaOperand(parser, "C", frags[2].regs).failed())
2071 return failure();
2072
2073 // Parse sparse-specific operands
2075 selectorOperands;
2076 if (parseMmaOperand(parser, "sparseMetadata", metadataOperands).failed() ||
2077 parseMmaOperand(parser, "selector", selectorOperands).failed())
2078 return failure();
2079
2080 // Parse scale operands
2081 SmallVector<OpAsmParser::UnresolvedOperand, 3> scaleAOperands, scaleBOperands;
2082 if (parseMmaOperand(parser, "scaleA", scaleAOperands).failed() ||
2083 parseMmaOperand(parser, "scaleB", scaleBOperands).failed())
2084 return failure();
2085
2086 if (parser.parseOptionalAttrDict(namedAttributes).failed())
2087 return failure();
2088
2089 // Parse type signature
2090 SmallVector<Type, 3> operandTypes;
2091 if (parseMmaTypeSignature(parser, operandTypes).failed())
2092 return failure();
2093
2094 // Parse result type
2095 SmallVector<Type, 1> resultTypes;
2096 if (parser.parseArrowTypeList(resultTypes).failed())
2097 return failure();
2098
2099 // Infer element types and resolve operands
// Only the third fragment (index 2, "C") is inferred as an accumulator. All
// registers of a fragment are resolved against that fragment's single
// signature type.
2100 for (const auto &[idx, frag] : llvm::enumerate(frags)) {
2101 frag.elemtype = MmaOp::inferOperandMMAType(operandTypes[idx],
2102 /*isAccumulator=*/idx >= 2);
2103 if (parser
2104 .resolveOperands(frag.regs, operandTypes[idx], parser.getNameLoc(),
2105 result.operands)
2106 .failed())
2107 return failure();
2108 }
2109
2110 // Resolve sparse metadata and selector
// Both are resolved as i32.
2111 Type i32Type = builder.getI32Type();
2112 if (parser
2113 .resolveOperands(metadataOperands, i32Type, parser.getNameLoc(),
2114 result.operands)
2115 .failed() ||
2116 parser
2117 .resolveOperands(selectorOperands, i32Type, parser.getNameLoc(),
2118 result.operands)
2119 .failed())
2120 return failure();
2121
2122 // Resolve scale operands
// Both scale groups are resolved against the same (i32, i16, i16) type list.
2123 SmallVector<Type, 3> scaleTypes = {i32Type, builder.getI16Type(),
2124 builder.getI16Type()};
2125 if (parser
2126 .resolveOperands(scaleAOperands, scaleTypes, parser.getNameLoc(),
2127 result.operands)
2128 .failed() ||
2129 parser
2130 .resolveOperands(scaleBOperands, scaleTypes, parser.getNameLoc(),
2131 result.operands)
2132 .failed())
2133 return failure();
2134
2135 // Add attributes
2136 result.addAttributes(namedAttributes);
2137 inferAndSetMultiplicandTypes(parser.getContext(), result.attributes,
2138 operandTypes);
2139
2140 // orderedMetadata is mandatory
2141 if (!result.attributes.get("orderedMetadata"))
2142 result.addAttribute("orderedMetadata", builder.getUnitAttr());
2143
2144 result.addTypes(resultTypes);
// Segment sizes: the parsed A/B/C register counts followed by one entry per
// sparse and scale operand.
2145 result.addAttribute(MmaSpBlockScaleOp::getOperandSegmentSizeAttr(),
2146 builder.getDenseI32ArrayAttr({
2147 static_cast<int32_t>(frags[0].regs.size()),
2148 static_cast<int32_t>(frags[1].regs.size()),
2149 static_cast<int32_t>(frags[2].regs.size()),
2150 1, // sparseMetadata
2151 1, // sparsitySelector
2152 1, // scaleAData
2153 1, // byteIdA
2154 1, // threadIdA
2155 1, // scaleBData
2156 1, // byteIdB
2157 1 // threadIdB
2158 }));
2159 return success();
2160}
2161
// Builder. Always attaches the "orderedMetadata" unit attribute, adds the A,
// B and C operand ranges followed by the sparse metadata, the sparsity
// selector and the six scale operands, then attaches the multiplicand PTX
// type attributes (explicit when `multiplicandPtxTypes` is provided, inferred
// otherwise), the result type and the operand segment size attribute. `shape`
// must hold exactly three entries (m, n, k), which is asserted.
2162void MmaSpBlockScaleOp::build(
2163 OpBuilder &builder, OperationState &result, Type resultType,
2164 ValueRange operandA, ValueRange operandB, ValueRange operandC,
2165 Value sparseMetadata, Value sparsitySelector, Value scaleAData,
2166 Value byteIdA, Value threadIdA, Value scaleBData, Value byteIdB,
2167 Value threadIdB, ArrayRef<int64_t> shape,
2168 std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
2169 ScaleVecSize scaleVecSize, BlockScaleFormat blockScaleFormat,
2170 MMABlockScaleKind kind) {
2171 assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
2172
// NOTE(review): the line that opens the call completed below is not visible
// in this listing; it presumably names the common block scale property
// helper - confirm against the file.
2174 builder, result, shape, scaleVecSize, blockScaleFormat, kind);
2175 result.addAttribute("orderedMetadata", builder.getUnitAttr());
2176
2177 result.addOperands(operandA);
2178 result.addOperands(operandB);
2179 result.addOperands(operandC);
2180 result.addOperands({sparseMetadata, sparsitySelector, scaleAData, byteIdA,
2181 threadIdA, scaleBData, byteIdB, threadIdB});
2182
2183 addInferredMultiplicandTypes(builder.getContext(), result, operandA, operandB,
2184 multiplicandPtxTypes);
2185
2186 result.addTypes(resultType);
// Segment sizes follow the operand order used above.
2187 result.addAttribute(MmaSpBlockScaleOp::getOperandSegmentSizeAttr(),
2188 builder.getDenseI32ArrayAttr({
2189 static_cast<int32_t>(operandA.size()),
2190 static_cast<int32_t>(operandB.size()),
2191 static_cast<int32_t>(operandC.size()),
2192 1, // sparseMetadata
2193 1, // sparsitySelector
2194 1, // scaleAData
2195 1, // byteIdA
2196 1, // threadIdA
2197 1, // scaleBData
2198 1, // byteIdB
2199 1 // threadIdB
2200 }));
2201}
2202
// Returns the intrinsic ID together with the translated argument list for
// `op`, which must be an NVVM::MmaSpBlockScaleOp (enforced by `cast`). The
// arguments are the values looked up through `mt` for the A, B and C
// operands, then the sparse metadata and selector, then the six scale
// operands. The `builder` parameter is not used in the visible lines.
2203NVVM::IDArgPair MmaSpBlockScaleOp::getIntrinsicIDAndArgs(
2204 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
2205 auto curOp = cast<NVVM::MmaSpBlockScaleOp>(op);
2206
// NOTE(review): the declaration of `args` is not visible in this listing;
// it is presumably a container of translated LLVM values - confirm.
2208 // Add A, B, C operands
2209 for (Value operand : curOp.getOperandA())
2210 args.push_back(mt.lookupValue(operand));
2211 for (Value operand : curOp.getOperandB())
2212 args.push_back(mt.lookupValue(operand));
2213 for (Value operand : curOp.getOperandC())
2214 args.push_back(mt.lookupValue(operand));
2215
2216 // Add sparse metadata and selector
2217 args.push_back(mt.lookupValue(curOp.getSparseMetadata()));
2218 args.push_back(mt.lookupValue(curOp.getSparsitySelector()));
2219
2220 // Add scale operands
2221 args.push_back(mt.lookupValue(curOp.getScaleAData()));
2222 args.push_back(mt.lookupValue(curOp.getByteIdA()));
2223 args.push_back(mt.lookupValue(curOp.getThreadIdA()));
2224 args.push_back(mt.lookupValue(curOp.getScaleBData()));
2225 args.push_back(mt.lookupValue(curOp.getByteIdB()));
2226 args.push_back(mt.lookupValue(curOp.getThreadIdB()));
2227
// Both multiplicand PTX type accessors are dereferenced without a check, so
// this code relies on them holding a value at this point.
2228 unsigned intId = MmaSpBlockScaleOp::getIntrinsicID(
2229 curOp.getShape().getM(), curOp.getShape().getN(), curOp.getShape().getK(),
2230 *curOp.getMultiplicandAPtxType(), *curOp.getMultiplicandBPtxType(),
2231 inferPtxTypeFromResult(curOp), curOp.getScaleVecSize(),
2232 curOp.getBlockScaleFormat(), curOp.getKind());
2233
2234 return {intId, args};
2235}
2236
// Verifies the sparse block-scaled MMA. A missing orderedMetadata attribute
// is reported immediately. Otherwise the combination of shape, kind, scale
// vector size, block scale format and multiplicand types is checked:
//   m16n8k128: both multiplicand types must be e2m1;
//              kind MXF4 requires ScaleVecSize X2 and BlockScaleFormat UE8M0;
//              kind MXF4NVF4 requires (X2, UE8M0) or (X4, UE4M3 | UE8M0);
//              any other kind is rejected.
//   m16n8k64:  requires kind MXF8F6F4, ScaleVecSize X1 and UE8M0.
//   Any other shape is rejected.
// In the visible branches a failing shape/kind check assigns its diagnostic
// to `result`, which is returned at the end.
// NOTE(review): three lines that open a statement ending in a string literal
// are not visible in this listing (marked below) - confirm against the file.
2237LogicalResult MmaSpBlockScaleOp::verify() {
2238 // Check that orderedMetadata is present
2239 if (!getOrderedMetadata()) {
2240 return emitOpError("'orderedMetadata' attribute is mandatory");
2241 }
2242
2243 LogicalResult result = success();
2244 int m = getShape().getM();
2245 int n = getShape().getN();
2246 int k = getShape().getK();
2247
2248 if (m == 16 && n == 8 && k == 128) {
2249 if (getMultiplicandAPtxType() != NVVM::MMATypes::e2m1 ||
2250 getMultiplicandBPtxType() != NVVM::MMATypes::e2m1)
// (statement-opening line not visible)
2252 "unsupported MMATypes attribute for mma.m16n8k128.(mxf4nvf4|mxf4)");
2253 if (getKind() == NVVM::MMABlockScaleKind::MXF4) {
2254 if (getScaleVecSize() != NVVM::ScaleVecSize::X2)
// (statement-opening line not visible)
2256 "unsupported ScaleVecSize attribute for mma.m16n8k128.mxf4");
2257 if (getBlockScaleFormat() != NVVM::BlockScaleFormat::UE8M0)
// (statement-opening line not visible)
2259 "unsupported BlockScaleFormat attribute for mma.m16n8k128.mxf4");
2260 } else if (getKind() == NVVM::MMABlockScaleKind::MXF4NVF4) {
2261 if (!((getScaleVecSize() == NVVM::ScaleVecSize::X2 &&
2262 getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0) ||
2263 (getScaleVecSize() == NVVM::ScaleVecSize::X4 &&
2264 (getBlockScaleFormat() == NVVM::BlockScaleFormat::UE4M3 ||
2265 getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0))))
2266 result = emitOpError("unsupported ScaleVecSize and BlockScaleFormat "
2267 "attributes for mma.m16n8k128.mxf4nvf4");
2268 } else {
2269 result = emitOpError("unsupported Kind attribute for mma.m16n8k128");
2270 }
2271 } else if (m == 16 && n == 8 && k == 64) {
2272 if (!(getKind() == NVVM::MMABlockScaleKind::MXF8F6F4 &&
2273 getScaleVecSize() == NVVM::ScaleVecSize::X1 &&
2274 getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0))
2275 result =
2276 emitOpError("unsupported Kind, ScaleVecSize and BlockScaleFormat "
2277 "attributes for mma.m16n8k64");
2278 } else {
2279 result = emitOpError("unsupported Geom for sparse mma with block scaling");
2280 }
2281 return result;
2282}
2283
2284LogicalResult ShflOp::verify() {
2285 auto returnStructType = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
2286
2287 auto verifyTypeError = [&](Twine desc, Type expectedType,
2288 Type actualType) -> LogicalResult {
2289 return emitOpError("expected " + desc + " to be of type ")
2290 << expectedType << " but got " << actualType << " instead";
2291 };
2292
2293 if (returnStructType) {
2294 if (!getReturnValueAndIsValid())
2295 return emitOpError("\"return_value_and_is_valid\" attribute must be "
2296 "specified when the return type is a struct type");
2297
2298 if (returnStructType.getBody().size() != 2)
2299 return emitOpError("expected return type to be a two-element struct");
2300
2301 llvm::ArrayRef<Type> returnStruct = returnStructType.getBody();
2302 auto resultType = returnStruct[0];
2303 if (resultType != getVal().getType())
2304 return verifyTypeError("first element in the returned struct",
2305 getVal().getType(), resultType);
2306
2307 auto predicateType = returnStruct[1];
2308 if (!predicateType.isInteger(1))
2309 return verifyTypeError("second element in the returned struct",
2310 mlir::IntegerType::get(getContext(), 1),
2311 predicateType);
2312 } else {
2313 if (getReturnValueAndIsValid())
2314 return emitOpError("expected return type to be a two-element struct");
2315
2316 if (getType() != getVal().getType())
2317 return verifyTypeError("return type", getVal().getType(), getType());
2318 }
2319 return success();
2320}
2321
2322LogicalResult
2323ShflOp::inferReturnTypes(MLIRContext *context, std::optional<Location> location,
2324 ShflOp::Adaptor adaptor,
2325 SmallVectorImpl<Type> &inferredReturnTypes) {
2326 Type valType = adaptor.getVal().getType();
2327 if (adaptor.getReturnValueAndIsValid())
2328 inferredReturnTypes.push_back(LLVM::LLVMStructType::getLiteral(
2329 context, {valType, IntegerType::get(context, 1)}));
2330 else
2331 inferredReturnTypes.push_back(valType);
2332 return success();
2333}
2334
2335std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
2336 NVVM::MMAFrag frag, int nRow,
2337 int nCol,
2338 MLIRContext *context) {
2339 unsigned numberElements = 0;
2340 Type elementType;
2341 OpBuilder builder(context);
2342 Type f16x2 = VectorType::get(2, builder.getF16Type());
2343 if (type == NVVM::MMATypes::f16) {
2344 elementType = f16x2;
2345 if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
2346 numberElements = 8;
2347 else
2348 numberElements = 4;
2349 } else if (type == NVVM::MMATypes::f32) {
2350 elementType = builder.getF32Type();
2351 numberElements = 8;
2352 } else if (type == NVVM::MMATypes::f64) {
2353 elementType = builder.getF64Type();
2354 if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
2355 numberElements = 1;
2356 else
2357 numberElements = 2;
2358 } else if (type == NVVM::MMATypes::tf32) {
2359 elementType = builder.getI32Type();
2360 numberElements = 4;
2361 } else if (type == NVVM::MMATypes::s8 || type == NVVM::MMATypes::u8) {
2362 elementType = builder.getI32Type();
2363 int parallelSize = 0;
2364 if (frag == NVVM::MMAFrag::a)
2365 parallelSize = nRow;
2366 if (frag == NVVM::MMAFrag::b)
2367 parallelSize = nCol;
2368
2369 // m == 16 && n == 16 && k == 16
2370 if (parallelSize == 16)
2371 numberElements = 2;
2372 // m == 8 && n == 32 && k == 16 or m == 32 && n == 8 && k == 16
2373 else if (parallelSize == 8)
2374 numberElements = 1;
2375 else if (parallelSize == 32)
2376 numberElements = 4;
2377 } else if (type == NVVM::MMATypes::s32) {
2378 elementType = builder.getI32Type();
2379 numberElements = 8;
2380 }
2381 assert(numberElements != 0 && elementType != nullptr);
2382 return std::make_pair(elementType, numberElements);
2383}
2384
2385static std::pair<mlir::Type, unsigned>
2386inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n,
2387 int k, MLIRContext *context) {
2388 int nRow, nCol;
2389 if (frag == NVVM::MMAFrag::a) {
2390 nRow = m;
2391 nCol = k;
2392 } else if (frag == NVVM::MMAFrag::b) {
2393 nRow = k;
2394 nCol = n;
2395 } else {
2396 nRow = m;
2397 nCol = n;
2398 }
2399 assert(nRow && nCol);
2400 return inferMMAType(type, frag, nRow, nCol, context);
2401}
2402
2403LogicalResult NVVM::WMMALoadOp::verify() {
2404 unsigned addressSpace =
2405 llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
2406 if (addressSpace != 0 && addressSpace != NVVMMemorySpace::Global &&
2407 addressSpace != NVVMMemorySpace::Shared)
2408 return emitOpError("expected source pointer in memory "
2409 "space 0, 1, 3");
2410
2411 if (NVVM::WMMALoadOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
2412 getEltype(), getFrag()) == 0)
2413 return emitOpError() << "invalid attribute combination";
2414 std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
2415 getEltype(), getFrag(), getM(), getN(), getK(), getContext());
2416 // Special case for f64 fragments
2417 Type f64Ty = Float64Type::get(getContext());
2418 if (typeInfo.first == f64Ty && typeInfo.second == 1) {
2419 if (getType() != f64Ty)
2420 return emitOpError("expected destination type to be f64");
2421 return success();
2422 }
2423 // Everything else is a struct
2424 Type dstType = LLVM::LLVMStructType::getLiteral(
2425 getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
2426 if (getType() != dstType)
2427 return emitOpError("expected destination type is a structure of ")
2428 << typeInfo.second << " elements of type " << typeInfo.first;
2429 return success();
2430}
2431
2432LogicalResult NVVM::WMMAStoreOp::verify() {
2433 unsigned addressSpace =
2434 llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
2435 if (addressSpace != 0 && addressSpace != NVVMMemorySpace::Global &&
2436 addressSpace != NVVMMemorySpace::Shared)
2437 return emitOpError("expected operands to be a source pointer in memory "
2438 "space 0, 1, 3");
2439
2440 if (NVVM::WMMAStoreOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
2441 getEltype()) == 0)
2442 return emitOpError() << "invalid attribute combination";
2443 std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
2444 getEltype(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
2445 if (getArgs().size() != typeInfo.second)
2446 return emitOpError() << "expected " << typeInfo.second << " data operands";
2447 if (llvm::any_of(getArgs(), [&typeInfo](Value operands) {
2448 return operands.getType() != typeInfo.first;
2449 }))
2450 return emitOpError() << "expected data operands of type " << typeInfo.first;
2451 return success();
2452}
2453
2454LogicalResult NVVM::WMMAMmaOp::verify() {
2455 if (NVVM::WMMAMmaOp::getIntrinsicID(getM(), getN(), getK(), getLayoutA(),
2456 getLayoutB(), getEltypeA(),
2457 getEltypeB()) == 0)
2458 return emitOpError() << "invalid attribute combination";
2459 std::pair<Type, unsigned> typeInfoA = inferMMATypeFromMNK(
2460 getEltypeA(), NVVM::MMAFrag::a, getM(), getN(), getK(), getContext());
2461 std::pair<Type, unsigned> typeInfoB = inferMMATypeFromMNK(
2462 getEltypeA(), NVVM::MMAFrag::b, getM(), getN(), getK(), getContext());
2463 std::pair<Type, unsigned> typeInfoC = inferMMATypeFromMNK(
2464 getEltypeB(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
2465 SmallVector<Type, 32> arguments;
2466 arguments.append(typeInfoA.second, typeInfoA.first);
2467 arguments.append(typeInfoB.second, typeInfoB.first);
2468 arguments.append(typeInfoC.second, typeInfoC.first);
2469 unsigned numArgs = arguments.size();
2470 if (getArgs().size() != numArgs)
2471 return emitOpError() << "expected " << numArgs << " arguments";
2472 for (unsigned i = 0; i < numArgs; i++) {
2473 if (getArgs()[i].getType() != arguments[i])
2474 return emitOpError() << "expected argument " << i << " to be of type "
2475 << arguments[i];
2476 }
2477 Type dstType = LLVM::LLVMStructType::getLiteral(
2478 getContext(), SmallVector<Type, 8>(typeInfoC.second, typeInfoC.first));
2479 if (getType() != dstType)
2480 return emitOpError("expected destination type is a structure of ")
2481 << typeInfoC.second << " elements of type " << typeInfoC.first;
2482 return success();
2483}
2484
2485LogicalResult NVVM::LdMatrixOp::verify() {
2486 uint32_t num = getNum(), m = getShape().getM(), n = getShape().getN();
2487 if (m == 8 && n == 8) {
2488 if (num != 1 && num != 2 && num != 4) {
2489 return emitOpError("expected num attribute to be 1, 2 or 4 for 8x8 "
2490 "matrix");
2491 }
2492 if (getEltType() != LdStMatrixEltType::B16) {
2493 return emitOpError("expected element type to be b16 for 8x8 matrix");
2494 }
2495 } else if (m == 8 && n == 16) {
2496 if (num != 1 && num != 2 && num != 4) {
2497 return emitOpError("expected num attribute to be 1, 2 or 4 for 8x16 "
2498 "matrix");
2499 }
2500 if (getLayout() != MMALayout::row) {
2501 return emitOpError("expected layout to be row for 8x16 matrix");
2502 }
2503 if (getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 &&
2504 getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) {
2505 return emitOpError("expected element type to be b8x16.b4x16_p64 or "
2506 "b8x16.b6x16_p32 for 8x16 matrix");
2507 }
2508 } else if (m == 16 && n == 16) {
2509 if (num != 1 && num != 2) {
2510 return emitOpError("expected num attribute to be 1 or 2 for 16x16 "
2511 "matrix");
2512 }
2513 if (getLayout() != MMALayout::col) {
2514 return emitOpError("expected layout to be col for 16x16 matrix");
2515 }
2516 if (getEltType() != LdStMatrixEltType::B8 &&
2517 getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 &&
2518 getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) {
2519 return emitOpError("expected element type to be b8, b8x16.b4x16_p64 or "
2520 "b8x16.b6x16_p32 for 16x16 matrix");
2521 }
2522 } else {
2523 return emitOpError("expected shape to be 8x8, 8x16 or 16x16");
2524 }
2525
2526 Type i32 = IntegerType::get(getContext(), 32);
2527 uint32_t numElements = (m == 16 && n == 16 ? num * 2 : num);
2528 if (numElements == 1 && getType() != i32)
2529 return emitOpError("expected destination type is i32");
2530 if (numElements == 2 || numElements == 4) {
2531 Type dstType = LLVM::LLVMStructType::getLiteral(
2532 getContext(), SmallVector<Type>(numElements, i32));
2533 if (getType() != dstType)
2534 return emitOpError("expected destination type is a structure of ")
2535 << numElements << " elements of type i32";
2536 }
2537
2538 return success();
2539}
2540
2541LogicalResult LdMatrixOp::inferReturnTypes(
2542 MLIRContext *context, std::optional<Location> location,
2543 LdMatrixOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
2544 uint32_t num = adaptor.getNum();
2545 uint32_t m = adaptor.getShape().getM();
2546 uint32_t n = adaptor.getShape().getN();
2547 uint32_t numElements = (m == 16 && n == 16) ? num * 2 : num;
2548
2549 Type i32 = IntegerType::get(context, 32);
2550 if (numElements == 1)
2551 inferredReturnTypes.push_back(i32);
2552 else
2553 inferredReturnTypes.push_back(LLVM::LLVMStructType::getLiteral(
2554 context, SmallVector<Type>(numElements, i32)));
2555 return success();
2556}
2557
2558LogicalResult NVVM::StMatrixOp::verify() {
2559 int numMatrix = getSources().size();
2560 if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
2561 return emitOpError("expected num attribute to be 1, 2 or 4");
2562
2563 int m = getShape().getM(), n = getShape().getN();
2564 if (m == 8 && n == 8) {
2565 if (getEltType() != NVVM::LdStMatrixEltType::B16) {
2566 return emitOpError("expected element type to be B16 for 8x8 matrix");
2567 }
2568 } else if (m == 16 && n == 8) {
2569 if (getEltType() != NVVM::LdStMatrixEltType::B8) {
2570 return emitOpError("expected element type to be B8 for 16x8 matrix");
2571 }
2572 if (getLayout() != NVVM::MMALayout::col) {
2573 return emitOpError("expected layout to be col for 16x8 matrix");
2574 }
2575 } else {
2576 return emitOpError("expected shape to be 8x8 or 16x8");
2577 }
2578
2579 return success();
2580}
2581
2582LogicalResult NVVM::MovMatrixOp::verify() {
2583 int m = getShape().getM(), n = getShape().getN();
2584 if (m != 8 || n != 8)
2585 return emitOpError("expected shape to be 8x8");
2586 if (getLayout() != NVVM::MMALayout::col)
2587 return emitOpError("expected layout to be col");
2588 if (getEltType() != NVVM::LdStMatrixEltType::B16)
2589 return emitOpError("expected element type to be b16");
2590 return success();
2591}
2592
2593static FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) {
2594 if (typeA == NVVM::WGMMATypes::tf32)
2595 return 8;
2596 if (typeA == NVVM::WGMMATypes::f16 || typeA == NVVM::WGMMATypes::bf16)
2597 return 16;
2598 if (typeA == NVVM::WGMMATypes::s8 || typeA == NVVM::WGMMATypes::u8)
2599 return 32;
2600 if (typeA == NVVM::WGMMATypes::e4m3 || typeA == NVVM::WGMMATypes::e5m2)
2601 return 32;
2602 if (typeA == NVVM::WGMMATypes::b1)
2603 return 256;
2604 return failure();
2605}
2606
2607static LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD,
2608 NVVM::WGMMATypes typeA,
2609 NVVM::WGMMATypes typeB) {
2610 switch (typeA) {
2611 case NVVM::WGMMATypes::f16:
2612 if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
2613 typeB == NVVM::WGMMATypes::f16)
2614 return success();
2615 break;
2616 case NVVM::WGMMATypes::tf32:
2617 if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32)
2618 return success();
2619 break;
2620 case NVVM::WGMMATypes::u8:
2621 case NVVM::WGMMATypes::s8:
2622 if (typeD == NVVM::WGMMATypes::s32 &&
2623 (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8))
2624 return success();
2625 break;
2626 case NVVM::WGMMATypes::b1:
2627 if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1)
2628 return success();
2629 break;
2630 case NVVM::WGMMATypes::bf16:
2631 if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
2632 typeB == NVVM::WGMMATypes::bf16)
2633 return success();
2634 break;
2635 case NVVM::WGMMATypes::e4m3:
2636 case NVVM::WGMMATypes::e5m2:
2637 if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
2638 (typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3))
2639 return success();
2640 break;
2641 case WGMMATypes::f32:
2642 case WGMMATypes::s32:
2643 llvm_unreachable("unsupported input types");
2644 break;
2645 }
2646 return failure();
2647}
2648
2649static LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA) {
2650 SmallVector<int> allowedN = {8, 16, 24, 32, 40, 48, 56, 64,
2651 72, 80, 88, 96, 104, 112, 120, 128,
2652 136, 144, 152, 160, 168, 176, 184, 192,
2653 200, 208, 216, 224, 232, 240, 248, 256};
2654 SmallVector<int> allowedNshort = {8, 16, 24, 32, 48, 64,
2655 80, 96, 112, 128, 144, 160,
2656 176, 192, 208, 224, 240, 256};
2657 switch (typeA) {
2658 case WGMMATypes::f16:
2659 case WGMMATypes::tf32:
2660 case WGMMATypes::bf16:
2661 case WGMMATypes::e4m3:
2662 case WGMMATypes::e5m2:
2663 if (llvm::is_contained(allowedN, sizeN))
2664 return success();
2665 break;
2666 case WGMMATypes::u8:
2667 case WGMMATypes::s8:
2668 case WGMMATypes::b1:
2669 if (llvm::is_contained(allowedNshort, sizeN))
2670 return success();
2671 break;
2672 case WGMMATypes::f32:
2673 case WGMMATypes::s32:
2674 llvm_unreachable("unsupported input types");
2675 break;
2676 }
2677 return failure();
2678}
2679
2680LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
2681 Value outValue = getResults();
2682 auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
2683 if (!stype)
2684 return emitOpError() << "expected results to be struct";
2685 int outputSize = stype.getBody().size();
2686 WGMMATypes typeD = getTypeD();
2687 WGMMATypes typeA = getTypeA();
2688 WGMMATypes typeB = getTypeB();
2689
2690 for (Type t : stype.getBody()) {
2691 if (t != stype.getBody().front())
2692 return emitOpError()
2693 << "all elements in struct must be same type but there is " << t;
2694 }
2695
2696 if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
2697 typeD != WGMMATypes::s32) {
2698 return emitOpError() << "does not support the given output type " << typeD;
2699 }
2700 if (typeD == WGMMATypes::s32 &&
2701 (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
2702 return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg";
2703 }
2704
2705 if (failed(isAllowedWGMMADataType(typeD, typeA, typeB))) {
2706 return emitOpError() << typeD << " += " << typeA << " * " << typeB
2707 << ", it is not supported.";
2708 }
2709
2710 // Check M
2711 if (getShape().getM() != 64)
2712 return emitOpError() << "shape 'm' must be 64";
2713
2714 // Check K
2715 FailureOr<int> allowedK = getAllowedSizeK(typeA);
2716 if (failed(allowedK) || allowedK.value() != getShape().getK())
2717 return emitOpError() << "shape 'k' must be " << allowedK.value()
2718 << " for input type " << typeA;
2719
2720 // Check N
2721 if (failed(isAllowedSizeN(getShape().getN(), typeA))) {
2722 return emitOpError() << "has input type " << typeA << " n is set to "
2723 << getShape().getN() << ", it is not supported.";
2724 }
2725
2726 // Check transpose (only available for f16/bf16)
2727 // Matrices A should be stored in row-major and B in column-major.
2728 // Only f16/bf16 matrices can be stored in either column-major or row-major
2729 // by setting the transpose value(imm-trans-a,imm-trans-b) in PTX code.
2730 if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
2731 (getLayoutA() == mlir::NVVM::MMALayout::col ||
2732 getLayoutB() == mlir::NVVM::MMALayout::row)) {
2733 return emitOpError()
2734 << "given layouts layout_a = " << getLayoutA()
2735 << " and layout_b = " << getLayoutB() << " for input types " << typeA
2736 << " and " << typeB
2737 << " requires transpose. However, this is only supported for: "
2738 << MMATypes::f16 << " and " << MMATypes::bf16;
2739 }
2740
2741 // Check result registers
2742 int expectedOutput = 0;
2743 if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32)
2744 expectedOutput = getShape().getN() / 2;
2745 if (typeD == WGMMATypes::f16)
2746 expectedOutput = getShape().getN() / 4;
2747 if (outputSize != expectedOutput) {
2748 return emitOpError() << "results " << expectedOutput
2749 << ", however output struct has " << outputSize
2750 << " elements";
2751 }
2752 // Check satfinite (only available for s32 accumulator)
2753 if (typeD != WGMMATypes::s32 &&
2754 getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
2755 NVVM::MMAIntOverflow::satfinite) {
2756 return emitOpError()
2757 << " `satfinite` can be only used with s32 accumulator, however "
2758 "the current accumulator is "
2759 << typeD;
2760 }
2761
2762 return success();
2763}
2764
2765std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
2766
2767 int m = getShape().getM(), n = getShape().getN(), k = getShape().getK();
2768 bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
2769
2770 StringRef outputTypeName = stringifyWGMMATypes(getTypeD());
2771
2772 int expectedOutputRegisters = 0;
2773 if (getTypeD() == WGMMATypes::f16)
2774 expectedOutputRegisters = getShape().getN() / 4;
2775 else
2776 expectedOutputRegisters = getShape().getN() / 2;
2777
2778 std::string ptx;
2779 llvm::raw_string_ostream ss(ptx);
2780
2781 ss << "{\n"
2782 ".reg .pred p;\n"
2783 "setp.ne.b32 p, $"
2784 << ((expectedOutputRegisters * 2) + 2)
2785 << ", 0;\n"
2786 "wgmma.mma_async.sync.aligned.m"
2787 << m << "n" << n << "k" << k << "." << outputTypeName << "." << getTypeA()
2788 << "." << getTypeB();
2789 if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
2790 NVVM::MMAIntOverflow::satfinite)
2791 ss << ".satfinite";
2792 ss << " {";
2793 int regCnt = 0;
2794 for (; regCnt < expectedOutputRegisters; ++regCnt) {
2795 ss << "$" << regCnt;
2796 if (regCnt != expectedOutputRegisters - 1)
2797 ss << ", ";
2798 }
2799
2800 ss << "},";
2801 // Need to map read/write registers correctly.
2802 regCnt = (regCnt * 2);
2803 ss << " $" << (regCnt) << "," << " $" << (regCnt + 1) << "," << " p";
2804 if (getTypeD() != WGMMATypes::s32) {
2805 ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4);
2806 }
2807 // Don't add transpose parameters unless needed.
2808 if (isF16) {
2809 ss << ", $" << (regCnt + 5) << ", $" << (regCnt + 6);
2810 }
2811 ss << ";\n"
2812 << "}\n";
2813 return ptx;
2814}
2815
2816bool NVVM::WgmmaMmaAsyncOp::getAsmValues(
2817 RewriterBase &rewriter,
2818 llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
2819 &asmValues) {
2820 bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
2821 if (getResults())
2822 asmValues.push_back({getResults(), mlir::NVVM::PTXRegisterMod::Write});
2823 if (getInouts())
2824 asmValues.push_back({getInouts(), mlir::NVVM::PTXRegisterMod::ReadWrite});
2825 asmValues.push_back({getDescriptorA(), mlir::NVVM::PTXRegisterMod::Read});
2826 asmValues.push_back({getDescriptorB(), mlir::NVVM::PTXRegisterMod::Read});
2827 asmValues.push_back({makeConstantI32(rewriter, static_cast<int>(getScaleD())),
2829 if (getTypeD() != WGMMATypes::s32) {
2830 asmValues.push_back(
2831 {makeConstantI32(rewriter,
2832 getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
2834 asmValues.push_back(
2835 {makeConstantI32(rewriter,
2836 getScaleB() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
2838 }
2839 if (isF16) {
2840 asmValues.push_back(
2841 {makeConstantI32(rewriter, static_cast<int>(getLayoutA())),
2843 asmValues.push_back(
2844 {makeConstantI32(rewriter, 1 - static_cast<int>(getLayoutB())),
2846 }
2847 return true; // Has manual mapping
2848}
2849
2850LogicalResult NVVM::FenceProxyOp::verify() {
2851 if (getKind() == NVVM::ProxyKind::async_shared && !getSpace().has_value()) {
2852 return emitOpError() << "async_shared fence requires space attribute";
2853 }
2854 if (getKind() != NVVM::ProxyKind::async_shared && getSpace().has_value()) {
2855 return emitOpError() << "only async_shared fence can have space attribute";
2856 }
2857 return success();
2858}
2859
2860LogicalResult NVVM::FenceProxyAcquireOp::verify() {
2861 if (getFromProxy() != NVVM::ProxyKind::GENERIC)
2862 return emitOpError("uni-directional proxies only support generic for "
2863 "from_proxy attribute");
2864
2865 if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
2866 return emitOpError("uni-directional proxies only support tensormap "
2867 "for to_proxy attribute");
2868 return success();
2869}
2870
2871LogicalResult NVVM::FenceProxyReleaseOp::verify() {
2872 if (getFromProxy() != NVVM::ProxyKind::GENERIC)
2873 return emitOpError("uni-directional proxies only support generic for "
2874 "from_proxy attribute");
2875
2876 if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
2877 return emitOpError("uni-directional proxies only support tensormap "
2878 "for to_proxy attribute");
2879 return success();
2880}
2881
2882LogicalResult NVVM::FenceProxySyncRestrictOp::verify() {
2883 if (getFromProxy() != NVVM::ProxyKind::GENERIC)
2884 return emitOpError("only generic is support for from_proxy attribute");
2885
2886 if (getToProxy() != NVVM::ProxyKind::async)
2887 return emitOpError("only async is supported for to_proxy attribute");
2888 return success();
2889}
2890
2891LogicalResult NVVM::SetMaxRegisterOp::verify() {
2892 if (getRegCount() % 8)
2893 return emitOpError("new register size must be multiple of 8");
2894 if (getRegCount() < 24 || getRegCount() > 256)
2895 return emitOpError("new register size must be in between 24 to 256");
2896 return success();
2897}
2898
2899LogicalResult NVVM::Tcgen05CpOp::verify() {
2900 auto mc = getMulticast();
2901
2902 using SH = Tcgen05CpShape;
2903 using MC = Tcgen05CpMulticast;
2904 switch (getShape()) {
2905 case SH::SHAPE_128x256b:
2906 case SH::SHAPE_128x128b:
2907 case SH::SHAPE_4x256b:
2908 if (mc != MC::NONE)
2909 return emitError("Invalid multicast type for tcgen05.cp Op");
2910 break;
2911 case SH::SHAPE_64x128b:
2912 if (mc != MC::WARPX2_01_23 && mc != MC::WARPX2_02_13)
2913 return emitError("Shape 64x128b requires multicast warpx2_01_23 or "
2914 "warpx2_02_13 for tcgen05.cp Op");
2915 break;
2916 case SH::SHAPE_32x128b:
2917 if (mc != MC::WARPX4)
2918 return emitError(
2919 "Shape 32x128b requires multicast warpx4 for tcgen05.cp Op");
2920 break;
2921 }
2922 return success();
2923}
2924
2925LogicalResult NVVM::MatchSyncOp::verify() {
2926 if (getKind() == NVVM::MatchSyncKind::all) {
2927 auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
2928 if (!type || type.getBody().size() != 2 ||
2929 !type.getBody()[0].isInteger(32) || !type.getBody()[1].isInteger(1)) {
2930 return emitOpError("match.sync 'all' returns a two element struct with "
2931 "first element as i32 and second element as i1");
2932 }
2933 } else {
2934 if (!getType().isInteger(32)) {
2935 return emitOpError("match.sync 'any' returns an i32");
2936 }
2937 }
2938 return success();
2939}
2940
2941LogicalResult MatchSyncOp::inferReturnTypes(
2942 MLIRContext *context, std::optional<Location> location,
2943 MatchSyncOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
2944 if (adaptor.getKind() == NVVM::MatchSyncKind::all)
2945 inferredReturnTypes.push_back(LLVM::LLVMStructType::getLiteral(
2946 context,
2947 {IntegerType::get(context, 32), IntegerType::get(context, 1)}));
2948 else
2949 inferredReturnTypes.push_back(IntegerType::get(context, 32));
2950 return success();
2951}
2952
2953LogicalResult NVVM::VoteSyncOp::verify() {
2954 if (getKind() == NVVM::VoteSyncKind::ballot) {
2955 if (!getType().isInteger(32)) {
2956 return emitOpError("vote.sync 'ballot' returns an i32");
2957 }
2958 } else {
2959 if (!getType().isInteger(1)) {
2960 return emitOpError("vote.sync 'any', 'all' and 'uni' returns an i1");
2961 }
2962 }
2963 return success();
2964}
2965
2966LogicalResult VoteSyncOp::inferReturnTypes(
2967 MLIRContext *context, std::optional<Location> location,
2968 VoteSyncOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
2969 unsigned width = adaptor.getKind() == NVVM::VoteSyncKind::ballot ? 32 : 1;
2970 inferredReturnTypes.push_back(IntegerType::get(context, width));
2971 return success();
2972}
2973
2974LogicalResult NVVM::PrefetchOp::verify() {
2975 using MemSpace = NVVM::NVVMMemorySpace;
2976 using CacheLevel = NVVM::PrefetchCacheLevel;
2977
2978 unsigned addressSpace =
2979 llvm::cast<LLVM::LLVMPointerType>(getAddr().getType()).getAddressSpace();
2980 std::optional<NVVM::CacheEvictionPriority> evictPriority = getEvictPriority();
2981 std::optional<NVVM::PrefetchCacheLevel> cacheLevel = getCacheLevel();
2982
2983 if (getTensormap() && cacheLevel)
2984 return emitOpError("cannot specify both tensormap and cache level");
2985
2986 if (getTensormap()) {
2987 if (addressSpace != MemSpace::Generic &&
2988 addressSpace != MemSpace::Constant) {
2989 return emitOpError(
2990 "prefetch tensormap requires a generic or constant pointer");
2991 }
2992
2993 if (evictPriority) {
2994 return emitOpError(
2995 "prefetch tensormap does not support eviction priority");
2996 }
2997
2998 if (getInParamSpace() && addressSpace != MemSpace::Generic) {
2999 return emitOpError(
3000 "in_param_space can only be specified for a generic pointer");
3001 }
3002
3003 } else if (cacheLevel) {
3004 if (addressSpace != MemSpace::Generic && addressSpace != MemSpace::Global &&
3005 addressSpace != MemSpace::Local) {
3006 return emitOpError("prefetch to cache level requires a generic, global, "
3007 "or local pointer");
3008 }
3009
3010 if (getUniform()) {
3011 if (*cacheLevel != CacheLevel::L1) {
3012 return emitOpError(
3013 "unsupported cache level, the only supported uniform "
3014 "cache level is L1");
3015 }
3016
3017 if (addressSpace != MemSpace::Generic) {
3018 return emitOpError(
3019 "prefetch to uniform cache requires a generic pointer");
3020 }
3021 }
3022
3023 if (evictPriority) {
3024 if (*cacheLevel != CacheLevel::L2)
3025 return emitOpError(
3026 "cache eviction priority supported only for cache level L2");
3027
3028 if (addressSpace != MemSpace::Global)
3029 return emitOpError("cache eviction priority requires a global pointer");
3030
3031 if (*evictPriority != NVVM::CacheEvictionPriority::EvictNormal &&
3032 *evictPriority != NVVM::CacheEvictionPriority::EvictLast)
3033 return emitOpError(
3034 "unsupported cache eviction priority, only evict_last and "
3035 "evict_normal are supported");
3036 }
3037
3038 if (getPredicate())
3039 return emitOpError("predicate supported only on prefetch tensormap");
3040
3041 } else {
3042 return emitOpError(
3043 "requires specification of either cache level or tensormap");
3044 }
3045
3046 return success();
3047}
3048
3049LogicalResult NVVM::ClusterLaunchControlQueryCancelOp::verify() {
3050 switch (getQueryType()) {
3051 case NVVM::ClusterLaunchControlQueryType::IS_CANCELED:
3052 if (!getType().isInteger(1))
3053 return emitOpError("is_canceled query type returns an i1");
3054 break;
3055 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_X:
3056 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Y:
3057 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Z:
3058 if (!getType().isInteger(32)) {
3059 return emitOpError("get_first_cta_id_x, get_first_cta_id_y, "
3060 "get_first_cta_id_z query types return an i32");
3061 }
3062 break;
3063 }
3064 return success();
3065}
3066
3067LogicalResult ClusterLaunchControlQueryCancelOp::inferReturnTypes(
3068 MLIRContext *context, std::optional<Location> location,
3069 ClusterLaunchControlQueryCancelOp::Adaptor adaptor,
3070 SmallVectorImpl<Type> &inferredReturnTypes) {
3071 unsigned width =
3072 adaptor.getQueryType() == NVVM::ClusterLaunchControlQueryType::IS_CANCELED
3073 ? 1
3074 : 32;
3075 inferredReturnTypes.push_back(IntegerType::get(context, width));
3076 return success();
3077}
3078
3079LogicalResult NVVM::ReduxOp::verify() {
3080 mlir::Type reduxType = getType();
3081
3082 if (!reduxType.isF32()) {
3083 if (getAbs())
3084 return emitOpError("abs attribute is supported only for f32 type");
3085 if (getNan())
3086 return emitOpError("nan attribute is supported only for f32 type");
3087 }
3088
3089 NVVM::ReductionKind kind = getKind();
3090 switch (kind) {
3091 case NVVM::ReductionKind::ADD:
3092 case NVVM::ReductionKind::AND:
3093 case NVVM::ReductionKind::OR:
3094 case NVVM::ReductionKind::XOR:
3095 case NVVM::ReductionKind::MAX:
3096 case NVVM::ReductionKind::MIN:
3097 case NVVM::ReductionKind::UMAX:
3098 case NVVM::ReductionKind::UMIN:
3099 if (!reduxType.isInteger(32))
3100 return emitOpError("'")
3101 << kind << "' reduction kind unsupported with " << reduxType
3102 << " type. Only supported type is 'i32'.";
3103 break;
3104 case NVVM::ReductionKind::FMIN:
3105 case NVVM::ReductionKind::FMAX:
3106 if (!reduxType.isF32())
3107 return emitOpError("'")
3108 << kind << "' reduction kind unsupported with " << reduxType
3109 << " type. Only supported type is 'f32'.";
3110 break;
3111 }
3112
3113 return success();
3114}
3115
3116LogicalResult NVVM::TensormapReplaceOp::verify() {
3117 auto ord = getOrd();
3118 Value newVal = getNewValue();
3119 auto newValAttr = getNewValueAttr();
3120 auto fieldName = stringifyEnum(getField());
3121
3122 if (ord && !llvm::is_contained({NVVM::TensormapField::BOX_DIM,
3123 NVVM::TensormapField::GLOBAL_DIM,
3124 NVVM::TensormapField::GLOBAL_STRIDE,
3125 NVVM::TensormapField::ELEMENT_STRIDE},
3126 getField()))
3127 return emitOpError("ordinal is not supported for ")
3128 << fieldName << " field";
3129
3130 auto invalidNewVal = [&](llvm::Twine type) -> std::string {
3131 return llvm::Twine("new_value must be specified and must be an " + type +
3132 " for " + llvm::Twine(fieldName) + " field")
3133 .str();
3134 };
3135
3136 auto invalidNewValAttr = [&]() -> std::string {
3137 return (llvm::Twine(
3138 "new_value_attr must be specified and must be a valid ") +
3139 llvm::Twine(fieldName) + " attribute for " + fieldName + " field")
3140 .str();
3141 };
3142
3143 switch (getField()) {
3144 case NVVM::TensormapField::GLOBAL_ADDRESS:
3145 if (!(newVal && newVal.getType().isInteger(64)))
3146 return emitOpError(invalidNewVal("i64"));
3147 break;
3148 case NVVM::TensormapField::RANK:
3149 if (!(newVal && newVal.getType().isInteger(32)))
3150 return emitOpError(invalidNewVal("i32"));
3151 break;
3152 case NVVM::TensormapField::GLOBAL_STRIDE:
3153 if (!ord)
3154 return emitOpError("ordinal is required for global_stride field");
3155 if (!(newVal && newVal.getType().isInteger(64)))
3156 return emitOpError(invalidNewVal("i64"));
3157 break;
3158 case NVVM::TensormapField::BOX_DIM:
3159 case NVVM::TensormapField::GLOBAL_DIM:
3160 case NVVM::TensormapField::ELEMENT_STRIDE:
3161 if (!ord)
3162 return emitOpError("ordinal is required for ")
3163 << stringifyEnum(getField()) << " field";
3164 if (!(newVal && newVal.getType().isInteger(32)))
3165 return emitOpError(invalidNewVal("i32"));
3166 break;
3167 case NVVM::TensormapField::ELEMTYPE:
3168 if (!(newValAttr && llvm::isa<TensormapElemtypeAttr>(*newValAttr)))
3169 return emitOpError(invalidNewValAttr());
3170 break;
3171 case NVVM::TensormapField::INTERLEAVE_LAYOUT:
3172 if (!(newValAttr && llvm::isa<TensormapInterleaveLayoutAttr>(*newValAttr)))
3173 return emitOpError(invalidNewValAttr());
3174 break;
3175 case NVVM::TensormapField::SWIZZLE_MODE:
3176 if (!(newValAttr && llvm::isa<TensormapSwizzleModeAttr>(*newValAttr)))
3177 return emitOpError(invalidNewValAttr());
3178 break;
3179 case NVVM::TensormapField::SWIZZLE_ATOMICITY:
3180 if (!(newValAttr && llvm::isa<TensormapSwizzleAtomicityAttr>(*newValAttr)))
3181 return emitOpError(invalidNewValAttr());
3182 break;
3183 case NVVM::TensormapField::FILL_MODE:
3184 if (!(newValAttr && llvm::isa<TensormapFillModeAttr>(*newValAttr)))
3185 return emitOpError(invalidNewValAttr());
3186 break;
3187 }
3188
3189 return success();
3190}
3191
3192template <typename OpType>
3193static LogicalResult verifyAddSubFOp(OpType op) {
3194 mlir::NVVM::FPRoundingMode rndMode = op.getRnd();
3195 mlir::NVVM::SaturationMode satMode = op.getSat();
3196 bool isFTZ = op.getFtz();
3197
3198 mlir::Type opType = op.getRes().getType();
3199 mlir::Type opBaseType = isa<VectorType>(opType)
3200 ? cast<VectorType>(opType).getElementType()
3201 : opType;
3202
3203 if (opBaseType.isF64() && (satMode != NVVM::SaturationMode::NONE || isFTZ))
3204 return op.emitOpError("FTZ and saturation are not supported for "
3205 "additions/subtractions involving f64 type");
3206
3207 if (opBaseType.isF16() && !(rndMode == NVVM::FPRoundingMode::RN ||
3208 rndMode == NVVM::FPRoundingMode::NONE))
3209 return op.emitOpError("only RN rounding mode is supported for f16 and "
3210 "vector<2xf16> additions/subtractions");
3211
3212 if (opBaseType.isBF16()) {
3213 if (rndMode != NVVM::FPRoundingMode::RN &&
3214 rndMode != NVVM::FPRoundingMode::NONE)
3215 return op.emitOpError("only RN rounding mode is supported for bf16 and "
3216 "vector<2xbf16> additions/subtractions");
3217 if (satMode != NVVM::SaturationMode::NONE || isFTZ)
3218 return op.emitOpError("FTZ and saturation are not supported for bf16 and "
3219 "vector<2xbf16> additions/subtractions");
3220 }
3221
3222 // FIXME: This is a temporary check disallowing lowering to add.rn.ftz.f16(x2)
3223 // PTX instructions since the corresponding LLVM intrinsic is missing. This
3224 // should be removed once the intrinsics for f16 addition (with FTZ only) are
3225 // available.
3226 if (opBaseType.isF16() && isFTZ && satMode == NVVM::SaturationMode::NONE)
3227 return op.emitOpError("FTZ with no saturation is not supported for f16 and "
3228 "vector<2xf16> additions/subtractions");
3229
3230 return success();
3231}
3232
3233LogicalResult NVVM::AddFOp::verify() { return verifyAddSubFOp<AddFOp>(*this); }
3234
3235LogicalResult NVVM::SubFOp::verify() { return verifyAddSubFOp<SubFOp>(*this); }
3236
3237LogicalResult NVVM::FmaOp::verify() {
3238 auto opType = getRes().getType();
3239 mlir::NVVM::FPRoundingMode rndMode = getRnd();
3240 mlir::NVVM::SaturationMode satMode = getSat();
3241 bool isFTZ = getFtz();
3242 bool isRelu = getRelu();
3243 bool hasOOB = getOob();
3244
3245 auto getBaseFType = [](Type type) -> Type {
3246 if (isa<VectorType>(type))
3247 return cast<VectorType>(type).getElementType();
3248 return type;
3249 };
3250
3251 auto opBaseType = getBaseFType(opType);
3252
3253 if (rndMode == NVVM::FPRoundingMode::NONE)
3254 return emitOpError("rounding mode must be specified");
3255
3256 if (isRelu && satMode == NVVM::SaturationMode::SAT)
3257 return emitOpError("relu and saturation are not supported together");
3258
3259 if (hasOOB && (satMode == NVVM::SaturationMode::SAT || isFTZ))
3260 return emitOpError("oob is not supported with saturation or FTZ");
3261
3262 if (!(opBaseType.isF16() || opBaseType.isBF16()) && (isRelu || hasOOB))
3263 return emitOpError("relu and oob are only supported for f16 and bf16");
3264
3265 if (opBaseType.isF64() && (satMode != NVVM::SaturationMode::NONE || isFTZ))
3266 return emitOpError("FTZ and saturation are not supported for f64 type");
3267
3268 if (opBaseType.isF16() && rndMode != NVVM::FPRoundingMode::RN)
3269 return emitOpError(
3270 "only RN rounding mode is supported for f16 and vector<2xf16>");
3271
3272 if (opBaseType.isBF16()) {
3273 if (rndMode != NVVM::FPRoundingMode::RN)
3274 return emitOpError(
3275 "only RN rounding mode is supported for bf16 and vector<2xbf16>");
3276 if (satMode != NVVM::SaturationMode::NONE || isFTZ)
3277 return emitOpError(
3278 "FTZ and saturation are not supported for bf16 and vector<2xbf16>");
3279 }
3280
3281 return success();
3282}
3283
3284LogicalResult NVVM::SqrtOp::verify() {
3285 if (getRnd() == NVVM::FPRoundingMode::NONE)
3286 return emitOpError("rounding mode cannot be None");
3287
3288 if (getRes().getType().isF64() && getFtz())
3289 return emitOpError("FTZ is not supported for f64");
3290
3291 return success();
3292}
3293
3294LogicalResult NVVM::DivFOp::verify() {
3295 bool isApprox = getApprox();
3296 bool isFull = getFull();
3297 bool isF64 = getRes().getType().isF64();
3298 bool isFtz = getFtz();
3299 NVVM::FPRoundingMode rndMode = getRnd();
3300
3301 if (isApprox && isFull)
3302 return emitOpError("'approx' and 'full' are mutually exclusive");
3303
3304 if (isApprox || isFull) {
3305 if (isF64)
3306 return emitOpError("'approx' and 'full' forms are f32-only");
3307 if (rndMode != NVVM::FPRoundingMode::NONE)
3308 return emitOpError(
3309 "'approx' and 'full' forms do not accept a rounding mode");
3310 return success();
3311 }
3312
3313 // Rounded form below.
3314 if (rndMode == NVVM::FPRoundingMode::NONE)
3315 return emitOpError("rounding mode cannot be None for the rounded divide");
3316 if (isF64 && isFtz)
3317 return emitOpError("FTZ is not supported for f64");
3318
3319 return success();
3320}
3321
3322/// Packs the given `field` into the `result`.
3323/// The `result` is 64-bits and each `field` can be 32-bits or narrower.
3324static llvm::Value *
3325packValInto64Bits(llvm::IRBuilderBase &builder,
3326 llvm::Value *result, // the `result` (unset bits are zero)
3327 llvm::Value *field, // `field` to pack into `result`
3328 unsigned sizeInBits, // Size of `field` in bits
3329 unsigned start) { // Starting bit within `result`
3330 field = builder.CreateZExtOrBitCast(field, builder.getInt32Ty());
3331
3332 unsigned mask = (sizeInBits < 32 ? ((1u << sizeInBits) - 1) : 0xffffffffu);
3333 if (mask != 0xffffffffu)
3334 field = builder.CreateAnd(field, builder.getInt32(mask));
3335
3336 field = builder.CreateZExtOrBitCast(field, builder.getInt64Ty());
3337 field = builder.CreateShl(field, start);
3338
3339 return builder.CreateOr(result, field);
3340}
3341
3342void Tcgen05MmaSmemDescOp::createSmemDescriptor(Operation &op,
3344 llvm::IRBuilderBase &builder) {
3345 auto thisOp = cast<NVVM::Tcgen05MmaSmemDescOp>(op);
3346 llvm::Value *smemDesc = builder.getInt64(0);
3347
3348 smemDesc = packValInto64Bits(builder, smemDesc,
3349 mt.lookupValue(thisOp.getStartAddr()), 14, 0);
3350 smemDesc = packValInto64Bits(
3351 builder, smemDesc, mt.lookupValue(thisOp.getLeadingDimOffset()), 14, 16);
3352 smemDesc = packValInto64Bits(
3353 builder, smemDesc, mt.lookupValue(thisOp.getStrideDimOffset()), 14, 32);
3354
3355 smemDesc = packValInto64Bits(builder, smemDesc, builder.getInt32(1), 3, 46);
3356 smemDesc = packValInto64Bits(builder, smemDesc,
3357 mt.lookupValue(thisOp.getBaseOffset()), 3, 49);
3358 smemDesc = packValInto64Bits(
3359 builder, smemDesc, mt.lookupValue(thisOp.getLeadingDimMode()), 1, 52);
3360 smemDesc = packValInto64Bits(builder, smemDesc,
3361 mt.lookupValue(thisOp.getSwizzleMode()), 3, 61);
3362
3363 mt.mapValue(thisOp.getRes()) = smemDesc;
3364}
3365
3366//===----------------------------------------------------------------------===//
3367// getPtx methods
3368//===----------------------------------------------------------------------===//
3369
3370std::string NVVM::MBarrierInitOp::getPtx() {
3371 bool isShared = isPtrInSharedCTASpace(getAddr());
3372 return isShared ? std::string("mbarrier.init.shared.b64 [%0], %1;")
3373 : std::string("mbarrier.init.b64 [%0], %1;");
3374}
3375
3376std::string NVVM::MBarrierArriveExpectTxOp::getPtx() {
3377 bool isShared = isPtrInSharedCTASpace(getAddr());
3378 return isShared
3379 ? std::string("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;")
3380 : std::string("mbarrier.arrive.expect_tx.b64 _, [%0], %1;");
3381}
3382
3383std::string NVVM::MBarrierTryWaitParityOp::getPtx() {
3384 bool isShared = isPtrInSharedCTASpace(getAddr());
3385 llvm::StringRef space = isShared ? ".shared" : "";
3386
3387 return llvm::formatv("{\n\t"
3388 ".reg .pred P1; \n\t"
3389 "LAB_WAIT: \n\t"
3390 "mbarrier.try_wait.parity{0}.b64 P1, [%0], %1, %2; \n\t"
3391 "@P1 bra.uni DONE; \n\t"
3392 "bra.uni LAB_WAIT; \n\t"
3393 "DONE: \n\t"
3394 "}",
3395 space);
3396}
3397
3398//===----------------------------------------------------------------------===//
3399// Canonicalization patterns
3400//===----------------------------------------------------------------------===//
3401
3404
3405 LogicalResult matchAndRewrite(SubFOp op,
3406 PatternRewriter &rewriter) const override {
3407 Location loc = op.getLoc();
3408 Value negRhs =
3409 LLVM::FNegOp::create(rewriter, loc, op.getRhs().getType(), op.getRhs());
3410
3411 rewriter.replaceOpWithNewOp<AddFOp>(op, op.getType(), op.getLhs(), negRhs,
3412 op.getRnd(), op.getSat(), op.getFtz());
3413 return success();
3414 }
3415};
3416
3417void SubFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
3418 MLIRContext *context) {
3419 patterns.add<ConvertFsubToFnegFadd>(context);
3420}
3421
3422//===----------------------------------------------------------------------===//
3423// getIntrinsicID/getIntrinsicIDAndArgs methods
3424//===----------------------------------------------------------------------===//
3425
3426/// Maps the (aligned, hasCount) pair to the `@llvm.nvvm.barrier.cta.sync.*`
3427/// intrinsic ID.
3428static llvm::Intrinsic::ID getBarrierSyncIntrinsic(bool aligned,
3429 bool hasCount) {
3430 if (hasCount) {
3431 return aligned ? llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_count
3432 : llvm::Intrinsic::nvvm_barrier_cta_sync_count;
3433 }
3434 return aligned ? llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_all
3435 : llvm::Intrinsic::nvvm_barrier_cta_sync_all;
3436}
3437
3438/// Maps the (aligned, kind) pair to the `@llvm.nvvm.barrier.cta.red.*`
3439/// intrinsic ID.
3440static llvm::Intrinsic::ID
3441getBarrierReductionIntrinsic(bool aligned, NVVM::BarrierReduction kind) {
3442 switch (kind) {
3443 case NVVM::BarrierReduction::AND:
3444 return aligned ? llvm::Intrinsic::nvvm_barrier_cta_red_and_aligned_all
3445 : llvm::Intrinsic::nvvm_barrier_cta_red_and_all;
3446 case NVVM::BarrierReduction::OR:
3447 return aligned ? llvm::Intrinsic::nvvm_barrier_cta_red_or_aligned_all
3448 : llvm::Intrinsic::nvvm_barrier_cta_red_or_all;
3449 case NVVM::BarrierReduction::POPC:
3450 return aligned ? llvm::Intrinsic::nvvm_barrier_cta_red_popc_aligned_all
3451 : llvm::Intrinsic::nvvm_barrier_cta_red_popc_all;
3452 }
3453 llvm_unreachable("unknown BarrierReduction kind");
3454}
3455
3456mlir::NVVM::IDArgPair NVVM::BarrierOp::getIntrinsicIDAndArgs(
3457 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3458 auto thisOp = cast<NVVM::BarrierOp>(op);
3459 llvm::Value *barrierId = thisOp.getBarrierId()
3460 ? mt.lookupValue(thisOp.getBarrierId())
3461 : builder.getInt32(0);
3462 bool hasCount = static_cast<bool>(thisOp.getNumberOfThreads());
3463 llvm::Intrinsic::ID id =
3464 getBarrierSyncIntrinsic(thisOp.getAligned(), hasCount);
3465 llvm::SmallVector<llvm::Value *> args = {barrierId};
3466 if (hasCount)
3467 args.push_back(mt.lookupValue(thisOp.getNumberOfThreads()));
3468 return {id, std::move(args)};
3469}
3470
3471mlir::NVVM::IDArgPair NVVM::BarrierArriveOp::getIntrinsicIDAndArgs(
3472 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3473 auto thisOp = cast<NVVM::BarrierArriveOp>(op);
3474 llvm::Value *barrierId = thisOp.getBarrierId()
3475 ? mt.lookupValue(thisOp.getBarrierId())
3476 : builder.getInt32(0);
3477 llvm::Value *numThreads = mt.lookupValue(thisOp.getNumberOfThreads());
3478 llvm::Intrinsic::ID id =
3479 thisOp.getAligned()
3480 ? llvm::Intrinsic::nvvm_barrier_cta_arrive_aligned_count
3481 : llvm::Intrinsic::nvvm_barrier_cta_arrive_count;
3482 return {id, {barrierId, numThreads}};
3483}
3484
3485mlir::NVVM::IDArgPair NVVM::BarrierReductionOp::getIntrinsicIDAndArgs(
3486 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3487 auto thisOp = cast<NVVM::BarrierReductionOp>(op);
3488 llvm::Intrinsic::ID id = getBarrierReductionIntrinsic(
3489 thisOp.getAligned(), thisOp.getReductionOp());
3490 llvm::Value *barrierId = thisOp.getBarrierId()
3491 ? mt.lookupValue(thisOp.getBarrierId())
3492 : builder.getInt32(0);
3494 barrierId,
3495 builder.CreateICmpNE(mt.lookupValue(thisOp.getReductionPredicate()),
3496 builder.getInt32(0))};
3497 return {id, std::move(args)};
3498}
3499
3501CosOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
3502 llvm::IRBuilderBase &builder) {
3503 auto thisOp = cast<NVVM::CosOp>(op);
3504 llvm::Intrinsic::ID id = thisOp.getFtz()
3505 ? llvm::Intrinsic::nvvm_cos_approx_ftz_f
3506 : llvm::Intrinsic::nvvm_cos_approx_f;
3507 return {id, {mt.lookupValue(thisOp.getSrc())}};
3508}
3509
3511SinOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
3512 llvm::IRBuilderBase &builder) {
3513 auto thisOp = cast<NVVM::SinOp>(op);
3514 llvm::Intrinsic::ID id = thisOp.getFtz()
3515 ? llvm::Intrinsic::nvvm_sin_approx_ftz_f
3516 : llvm::Intrinsic::nvvm_sin_approx_f;
3517 return {id, {mt.lookupValue(thisOp.getSrc())}};
3518}
3519
3521Log2Op::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
3522 llvm::IRBuilderBase &builder) {
3523 auto thisOp = cast<NVVM::Log2Op>(op);
3524 llvm::Intrinsic::ID id = thisOp.getFtz()
3525 ? llvm::Intrinsic::nvvm_lg2_approx_ftz_f
3526 : llvm::Intrinsic::nvvm_lg2_approx_f;
3527 return {id, {mt.lookupValue(thisOp.getSrc())}};
3528}
3529
3531Ex2Op::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
3532 llvm::IRBuilderBase &builder) {
3533 auto thisOp = cast<NVVM::Ex2Op>(op);
3534 llvm::Intrinsic::ID id = thisOp.getFtz()
3535 ? llvm::Intrinsic::nvvm_ex2_approx_ftz
3536 : llvm::Intrinsic::nvvm_ex2_approx;
3537 return {id, {mt.lookupValue(thisOp.getSrc())}};
3538}
3539
3541RsqrtOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
3542 llvm::IRBuilderBase &builder) {
3543 auto thisOp = cast<NVVM::RsqrtOp>(op);
3544 Type t = thisOp.getRes().getType();
3545 bool isFtz = thisOp.getFtz();
3546
3547 llvm::Intrinsic::ID id = [&] {
3548 if (t.isF32()) {
3549 return isFtz ? llvm::Intrinsic::nvvm_rsqrt_approx_ftz_f
3550 : llvm::Intrinsic::nvvm_rsqrt_approx_f;
3551 }
3552 // f64
3553 return isFtz ? llvm::Intrinsic::nvvm_rsqrt_approx_ftz_d
3554 : llvm::Intrinsic::nvvm_rsqrt_approx_d;
3555 }();
3556
3557 return {id, {mt.lookupValue(thisOp.getSrc())}};
3558}
3559
3561SqrtOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
3562 llvm::IRBuilderBase &builder) {
3563 auto thisOp = cast<NVVM::SqrtOp>(op);
3564 Type t = thisOp.getRes().getType();
3565 NVVM::FPRoundingMode rndMode = thisOp.getRnd();
3566 bool isFtz = thisOp.getFtz();
3567
3568 // RM is one of RN/RM/RP/RZ (verifier rejects NONE).
3569 // Subtracting 1 maps RN=1..RZ=4 to 0..3.
3570 unsigned rndIndex = static_cast<unsigned>(rndMode) - 1;
3571
3572 static constexpr llvm::Intrinsic::ID f32IDs[] = {
3573 llvm::Intrinsic::nvvm_sqrt_rn_f,
3574 llvm::Intrinsic::nvvm_sqrt_rm_f,
3575 llvm::Intrinsic::nvvm_sqrt_rp_f,
3576 llvm::Intrinsic::nvvm_sqrt_rz_f,
3577 };
3578 static constexpr llvm::Intrinsic::ID f32FTZIDs[] = {
3579 llvm::Intrinsic::nvvm_sqrt_rn_ftz_f,
3580 llvm::Intrinsic::nvvm_sqrt_rm_ftz_f,
3581 llvm::Intrinsic::nvvm_sqrt_rp_ftz_f,
3582 llvm::Intrinsic::nvvm_sqrt_rz_ftz_f,
3583 };
3584 static constexpr llvm::Intrinsic::ID f64IDs[] = {
3585 llvm::Intrinsic::nvvm_sqrt_rn_d,
3586 llvm::Intrinsic::nvvm_sqrt_rm_d,
3587 llvm::Intrinsic::nvvm_sqrt_rp_d,
3588 llvm::Intrinsic::nvvm_sqrt_rz_d,
3589 };
3590
3591 llvm::Intrinsic::ID id =
3592 t.isF32() ? (isFtz ? f32FTZIDs[rndIndex] : f32IDs[rndIndex])
3593 : f64IDs[rndIndex];
3594
3595 return {id, {mt.lookupValue(thisOp.getSrc())}};
3596}
3597
3599SqrtApproxOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
3600 llvm::IRBuilderBase &builder) {
3601 auto thisOp = cast<NVVM::SqrtApproxOp>(op);
3602 llvm::Intrinsic::ID id = thisOp.getFtz()
3603 ? llvm::Intrinsic::nvvm_sqrt_approx_ftz_f
3604 : llvm::Intrinsic::nvvm_sqrt_approx_f;
3605 return {id, {mt.lookupValue(thisOp.getSrc())}};
3606}
3607
3609DivFOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
3610 llvm::IRBuilderBase &builder) {
3611 auto thisOp = cast<NVVM::DivFOp>(op);
3612 bool isFtz = thisOp.getFtz();
3613
3614 llvm::Intrinsic::ID id;
3615
3616 if (thisOp.getApprox()) {
3617 id = isFtz ? llvm::Intrinsic::nvvm_div_approx_ftz_f
3618 : llvm::Intrinsic::nvvm_div_approx_f;
3619 } else if (thisOp.getFull()) {
3620 // Intrinsic Naming quirk: int_nvvm_div_full has no `_f` suffix (unlike
3621 // approx).
3622 id = isFtz ? llvm::Intrinsic::nvvm_div_full_ftz
3623 : llvm::Intrinsic::nvvm_div_full;
3624 } else {
3625 // Rounded form — three 4-entry tables indexed by (rndMode - 1).
3626 unsigned rndIndex = static_cast<unsigned>(thisOp.getRnd()) - 1;
3627
3628 static constexpr llvm::Intrinsic::ID f32IDs[] = {
3629 llvm::Intrinsic::nvvm_div_rn_f,
3630 llvm::Intrinsic::nvvm_div_rm_f,
3631 llvm::Intrinsic::nvvm_div_rp_f,
3632 llvm::Intrinsic::nvvm_div_rz_f,
3633 };
3634 static constexpr llvm::Intrinsic::ID f32FTZIDs[] = {
3635 llvm::Intrinsic::nvvm_div_rn_ftz_f,
3636 llvm::Intrinsic::nvvm_div_rm_ftz_f,
3637 llvm::Intrinsic::nvvm_div_rp_ftz_f,
3638 llvm::Intrinsic::nvvm_div_rz_ftz_f,
3639 };
3640 static constexpr llvm::Intrinsic::ID f64IDs[] = {
3641 llvm::Intrinsic::nvvm_div_rn_d,
3642 llvm::Intrinsic::nvvm_div_rm_d,
3643 llvm::Intrinsic::nvvm_div_rp_d,
3644 llvm::Intrinsic::nvvm_div_rz_d,
3645 };
3646 Type t = thisOp.getRes().getType();
3647 id = t.isF32() ? (isFtz ? f32FTZIDs[rndIndex] : f32IDs[rndIndex])
3648 : f64IDs[rndIndex];
3649 }
3650
3651 return {id,
3652 {mt.lookupValue(thisOp.getLhs()), mt.lookupValue(thisOp.getRhs())}};
3653}
3654
3656PMEventOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
3657 llvm::IRBuilderBase &builder) {
3658 auto thisOp = cast<NVVM::PMEventOp>(op);
3659 llvm::Type *i16Ty = llvm::Type::getInt16Ty(mt.getLLVMContext());
3660
3661 // With event-id, mask is generated as (1 << event-id)
3662 llvm::Value *maskVal;
3663 if (auto eventAttr = thisOp.getEventIdAttr()) {
3664 uint16_t mask = static_cast<uint16_t>(1u << eventAttr.getInt());
3665 maskVal = llvm::ConstantInt::get(i16Ty, mask);
3666 } else {
3667 maskVal =
3668 llvm::ConstantInt::get(i16Ty, thisOp.getMaskedEventIdAttr().getValue());
3669 }
3670
3671 return {llvm::Intrinsic::nvvm_pm_event_mask, {maskVal}};
3672}
3673
3674mlir::NVVM::IDArgPair MBarrierInitOp::getIntrinsicIDAndArgs(
3675 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3676 auto thisOp = cast<NVVM::MBarrierInitOp>(op);
3677 bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
3678 llvm::Intrinsic::ID id = isShared ? llvm::Intrinsic::nvvm_mbarrier_init_shared
3679 : llvm::Intrinsic::nvvm_mbarrier_init;
3680
3681 // Fill the Intrinsic Args
3683 args.push_back(mt.lookupValue(thisOp.getAddr()));
3684 args.push_back(mt.lookupValue(thisOp.getCount()));
3685
3686 return {id, std::move(args)};
3687}
3688
3689mlir::NVVM::IDArgPair MBarrierInvalOp::getIntrinsicIDAndArgs(
3690 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3691 auto thisOp = cast<NVVM::MBarrierInvalOp>(op);
3692 bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
3693 llvm::Intrinsic::ID id = isShared
3694 ? llvm::Intrinsic::nvvm_mbarrier_inval_shared
3695 : llvm::Intrinsic::nvvm_mbarrier_inval;
3696
3697 return {id, {mt.lookupValue(thisOp.getAddr())}};
3698}
3699
3700mlir::NVVM::IDArgPair MBarrierExpectTxOp::getIntrinsicIDAndArgs(
3701 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3702 auto thisOp = cast<NVVM::MBarrierExpectTxOp>(op);
3703
3704 bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr());
3705 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3706 // bit-0: Space
3707 // bit-1: Scope
3708 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3709
3710 static constexpr llvm::Intrinsic::ID IDs[] = {
3711 llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cta_space_cta,
3712 llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cta_space_cluster,
3713 llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cluster_space_cta,
3714 llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cluster_space_cluster};
3715
3716 // Fill the Intrinsic Args
3718 args.push_back(mt.lookupValue(thisOp.getAddr()));
3719 args.push_back(mt.lookupValue(thisOp.getTxcount()));
3720
3721 return {IDs[index], std::move(args)};
3722}
3723
3724mlir::NVVM::IDArgPair MBarrierCompleteTxOp::getIntrinsicIDAndArgs(
3725 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3726 auto thisOp = cast<NVVM::MBarrierCompleteTxOp>(op);
3727
3728 bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr());
3729 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3730 // bit-0: Space
3731 // bit-1: Scope
3732 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3733
3734 static constexpr llvm::Intrinsic::ID IDs[] = {
3735 llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cta_space_cta,
3736 llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cta_space_cluster,
3737 llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cluster_space_cta,
3738 llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cluster_space_cluster};
3739
3740 // Fill the Intrinsic Args
3742 args.push_back(mt.lookupValue(thisOp.getAddr()));
3743 args.push_back(mt.lookupValue(thisOp.getTxcount()));
3744
3745 return {IDs[index], std::move(args)};
3746}
3747
3748mlir::NVVM::IDArgPair MBarrierArriveOp::getIntrinsicIDAndArgs(
3749 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3750 auto thisOp = cast<NVVM::MBarrierArriveOp>(op);
3751
3752 bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr());
3753 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3754 // bit-0: Space
3755 // bit-1: Scope
3756 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3757
3758 static constexpr llvm::Intrinsic::ID IDs[] = {
3759 llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cta_space_cta,
3760 llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cta_space_cluster,
3761 llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cluster_space_cta,
3762 llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cluster_space_cluster};
3763 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3764 llvm::Intrinsic::nvvm_mbarrier_arrive_relaxed_scope_cta_space_cta,
3765 llvm::Intrinsic::nvvm_mbarrier_arrive_relaxed_scope_cta_space_cluster,
3766 llvm::Intrinsic::nvvm_mbarrier_arrive_relaxed_scope_cluster_space_cta,
3767 llvm::Intrinsic::
3768 nvvm_mbarrier_arrive_relaxed_scope_cluster_space_cluster};
3769 auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];
3770
3771 // Tidy-up the Intrinsic Args
3772 bool needCast = isPtrInGenericSpace(thisOp.getAddr());
3773 llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
3774 if (needCast)
3775 mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);
3776
3777 // We have the most basic mbarrier.arrive supported on sm_80.
3778 // It supports: Space=cta, scope=cta, No relaxed, No explicit count.
3779 // So, only for this combination use the legacy intrinsic.
3780 bool hasCount = static_cast<bool>(thisOp.getCount());
3781 if (!hasCount &&
3782 (id == llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cta_space_cta))
3783 return {llvm::Intrinsic::nvvm_mbarrier_arrive_shared, {mbar}};
3784
3785 // When count is not explicitly specified, the default is 1.
3786 llvm::LLVMContext &ctx = mt.getLLVMContext();
3787 llvm::Value *count =
3788 hasCount ? mt.lookupValue(thisOp.getCount())
3789 : llvm::ConstantInt::get(llvm::Type::getInt32Ty(ctx), 1);
3790 return {id, {mbar, count}};
3791}
3792
3793mlir::NVVM::IDArgPair MBarrierArriveDropOp::getIntrinsicIDAndArgs(
3794 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3795 auto thisOp = cast<NVVM::MBarrierArriveDropOp>(op);
3796
3797 bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr());
3798 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3799 // bit-0: Space
3800 // bit-1: Scope
3801 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3802
3803 static constexpr llvm::Intrinsic::ID IDs[] = {
3804 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cta_space_cta,
3805 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cta_space_cluster,
3806 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cluster_space_cta,
3807 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cluster_space_cluster};
3808 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3809 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_relaxed_scope_cta_space_cta,
3810 llvm::Intrinsic::
3811 nvvm_mbarrier_arrive_drop_relaxed_scope_cta_space_cluster,
3812 llvm::Intrinsic::
3813 nvvm_mbarrier_arrive_drop_relaxed_scope_cluster_space_cta,
3814 llvm::Intrinsic::
3815 nvvm_mbarrier_arrive_drop_relaxed_scope_cluster_space_cluster};
3816 auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];
3817
3818 // Tidy-up the Intrinsic Args
3819 bool needCast = isPtrInGenericSpace(thisOp.getAddr());
3820 llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
3821 if (needCast)
3822 mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);
3823
3824 // When count is not explicitly specified, the default is 1.
3825 llvm::LLVMContext &ctx = mt.getLLVMContext();
3826 bool hasCount = static_cast<bool>(thisOp.getCount());
3827 llvm::Value *count =
3828 hasCount ? mt.lookupValue(thisOp.getCount())
3829 : llvm::ConstantInt::get(llvm::Type::getInt32Ty(ctx), 1);
3830
3831 return {id, {mbar, count}};
3832}
3833
3834bool MBarrierArriveExpectTxOp::getAsmValues(
3835 RewriterBase &rewriter,
3836 llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
3837 &asmValues) {
3838 // Add all the operands but not the attrs to the asmValues list.
3839 // The attrs here are used to generate the right variants for
3840 // intrinsics-lowering. So, we ignore them while generating inline-PTX.
3841 for (auto val : getOperands())
3842 asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read});
3843
3844 return false;
3845}
3846
3847mlir::NVVM::IDArgPair MBarrierArriveExpectTxOp::getIntrinsicIDAndArgs(
3848 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3849 auto thisOp = cast<NVVM::MBarrierArriveExpectTxOp>(op);
3850
3851 bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr());
3852 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3853 // bit-0: Space
3854 // bit-1: Scope
3855 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3856
3857 // clang-format off
3858 static constexpr llvm::Intrinsic::ID IDs[] = {
3859 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cta_space_cta,
3860 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cta_space_cluster,
3861 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cluster_space_cta,
3862 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cluster_space_cluster};
3863 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3864 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cta_space_cta,
3865 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cta_space_cluster,
3866 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cluster_space_cta,
3867 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cluster_space_cluster};
3868 // clang-format on
3869 auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];
3870
3871 // Tidy-up the Intrinsic Args
3872 llvm::Value *txcount = mt.lookupValue(thisOp.getTxcount());
3873 llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
3874 bool needCast = isPtrInGenericSpace(thisOp.getAddr());
3875 if (needCast)
3876 mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);
3877
3878 return {id, {mbar, txcount}};
3879}
3880
3881mlir::NVVM::IDArgPair MBarrierArriveDropExpectTxOp::getIntrinsicIDAndArgs(
3882 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3883 auto thisOp = cast<NVVM::MBarrierArriveDropExpectTxOp>(op);
3884
3885 bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr());
3886 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3887 // bit-0: Space
3888 // bit-1: Scope
3889 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3890
3891 // clang-format off
3892 static constexpr llvm::Intrinsic::ID IDs[] = {
3893 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cta_space_cta,
3894 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cta_space_cluster,
3895 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cluster_space_cta,
3896 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cluster_space_cluster};
3897 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3898 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cta_space_cta,
3899 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cta_space_cluster,
3900 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cluster_space_cta,
3901 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cluster_space_cluster};
3902 // clang-format on
3903 auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];
3904
3905 // Tidy-up the Intrinsic Args
3906 llvm::Value *txcount = mt.lookupValue(thisOp.getTxcount());
3907 llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
3908 bool needCast = isPtrInGenericSpace(thisOp.getAddr());
3909 if (needCast)
3910 mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);
3911
3912 return {id, {mbar, txcount}};
3913}
3914
3915mlir::NVVM::IDArgPair MBarrierArriveNocompleteOp::getIntrinsicIDAndArgs(
3916 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3917 auto thisOp = cast<NVVM::MBarrierArriveNocompleteOp>(op);
3918 bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
3919 llvm::Intrinsic::ID id =
3920 isShared ? llvm::Intrinsic::nvvm_mbarrier_arrive_noComplete_shared
3921 : llvm::Intrinsic::nvvm_mbarrier_arrive_noComplete;
3922 // Fill the Intrinsic Args
3924 args.push_back(mt.lookupValue(thisOp.getAddr()));
3925 args.push_back(mt.lookupValue(thisOp.getCount()));
3926
3927 return {id, std::move(args)};
3928}
3929
3930mlir::NVVM::IDArgPair MBarrierArriveDropNocompleteOp::getIntrinsicIDAndArgs(
3931 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3932 auto thisOp = cast<NVVM::MBarrierArriveDropNocompleteOp>(op);
3933 bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
3934 llvm::Intrinsic::ID id =
3935 isShared ? llvm::Intrinsic::nvvm_mbarrier_arrive_drop_noComplete_shared
3936 : llvm::Intrinsic::nvvm_mbarrier_arrive_drop_noComplete;
3937 // Fill the Intrinsic Args
3939 args.push_back(mt.lookupValue(thisOp.getAddr()));
3940 args.push_back(mt.lookupValue(thisOp.getCount()));
3941
3942 return {id, std::move(args)};
3943}
3944
3945mlir::NVVM::IDArgPair MBarrierTestWaitOp::getIntrinsicIDAndArgs(
3946 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3947 auto thisOp = cast<NVVM::MBarrierTestWaitOp>(op);
3948 bool isPhaseParity = thisOp.getStateOrPhase().getType().isInteger(32);
3949 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3950 // bit-0: isPhaseParity
3951 // bit-1: Scope
3952 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isPhaseParity ? 1 : 0);
3953
3954 // clang-format off
3955 static constexpr llvm::Intrinsic::ID IDs[] = {
3956 llvm::Intrinsic::nvvm_mbarrier_test_wait_scope_cta_space_cta,
3957 llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_scope_cta_space_cta,
3958 llvm::Intrinsic::nvvm_mbarrier_test_wait_scope_cluster_space_cta,
3959 llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_scope_cluster_space_cta};
3960 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3961 llvm::Intrinsic::nvvm_mbarrier_test_wait_relaxed_scope_cta_space_cta,
3962 llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_relaxed_scope_cta_space_cta,
3963 llvm::Intrinsic::nvvm_mbarrier_test_wait_relaxed_scope_cluster_space_cta,
3964 llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_relaxed_scope_cluster_space_cta};
3965 // clang-format on
3966 auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];
3967
3968 // Tidy-up the Intrinsic Args
3969 llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
3970 llvm::Value *input = mt.lookupValue(thisOp.getStateOrPhase());
3971 bool needCast = isPtrInGenericSpace(thisOp.getAddr());
3972 if (needCast)
3973 mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);
3974
3975 return {id, {mbar, input}};
3976}
3977
3978mlir::NVVM::IDArgPair MBarrierTryWaitOp::getIntrinsicIDAndArgs(
3979 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3980 auto thisOp = cast<NVVM::MBarrierTryWaitOp>(op);
3981 bool isPhaseParity = thisOp.getStateOrPhase().getType().isInteger(32);
3982 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3983 bool hasTicks = static_cast<bool>(thisOp.getTicks());
3984 // bit-0: isPhaseParity
3985 // bit-1: Scope
3986 // bit-2: hasTicks
3987 size_t index = ((hasTicks ? 1 : 0) << 2) | ((isClusterScope ? 1 : 0) << 1) |
3988 (isPhaseParity ? 1 : 0);
3989
3990 // clang-format off
3991 static constexpr llvm::Intrinsic::ID IDs[] = {
3992 llvm::Intrinsic::nvvm_mbarrier_try_wait_scope_cta_space_cta,
3993 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_scope_cta_space_cta,
3994 llvm::Intrinsic::nvvm_mbarrier_try_wait_scope_cluster_space_cta,
3995 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_scope_cluster_space_cta,
3996 llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_scope_cta_space_cta,
3997 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_scope_cta_space_cta,
3998 llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_scope_cluster_space_cta,
3999 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_scope_cluster_space_cta};
4000 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
4001 llvm::Intrinsic::nvvm_mbarrier_try_wait_relaxed_scope_cta_space_cta,
4002 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_relaxed_scope_cta_space_cta,
4003 llvm::Intrinsic::nvvm_mbarrier_try_wait_relaxed_scope_cluster_space_cta,
4004 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_relaxed_scope_cluster_space_cta,
4005 llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_relaxed_scope_cta_space_cta,
4006 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_relaxed_scope_cta_space_cta,
4007 llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_relaxed_scope_cluster_space_cta,
4008 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_relaxed_scope_cluster_space_cta};
4009 // clang-format on
4010 auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];
4011
4012 // Tidy-up the mbarrier pointer
4013 llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
4014 bool needCast = isPtrInGenericSpace(thisOp.getAddr());
4015 if (needCast)
4016 mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);
4017
4018 // Fill the Intrinsic Args
4020 args.push_back(mbar);
4021 args.push_back(mt.lookupValue(thisOp.getStateOrPhase()));
4022 if (hasTicks)
4023 args.push_back(mt.lookupValue(thisOp.getTicks()));
4024
4025 return {id, std::move(args)};
4026}
4027
4028mlir::NVVM::IDArgPair CpAsyncMBarrierArriveOp::getIntrinsicIDAndArgs(
4029 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4030 auto thisOp = cast<NVVM::CpAsyncMBarrierArriveOp>(op);
4031 bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
4032
4033 llvm::Intrinsic::ID id;
4034 if (thisOp.getNoinc()) {
4035 id = isShared ? llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_noinc_shared
4036 : llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_noinc;
4037 } else {
4038 id = isShared ? llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_shared
4039 : llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive;
4040 }
4041
4042 return {id, {mt.lookupValue(thisOp.getAddr())}};
4043}
4044
4046MovMatrixOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
4047 llvm::IRBuilderBase &builder) {
4048 auto thisOp = cast<NVVM::MovMatrixOp>(op);
4049 return {llvm::Intrinsic::nvvm_movmatrix_sync_aligned_m8n8_trans_b16,
4050 {mt.lookupValue(thisOp.getSrc())}};
4051}
4052
// Pastes together the name of a cp.async shared<-global intrinsic from the
// cache modifier (`mod`), the copy size (`size`) and an optional `suffix`.
#define CP_ASYNC_ID_IMPL(mod, size, suffix)                                    \
  llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix

// Selects the `_s` flavour of the intrinsic when `has_cpsize` is true.
// The expansion and the condition are fully parenthesized so the macro can
// be used safely inside a larger expression.
#define GET_CP_ASYNC_ID(mod, size, has_cpsize)                                 \
  ((has_cpsize) ? CP_ASYNC_ID_IMPL(mod, size, _s)                              \
                : CP_ASYNC_ID_IMPL(mod, size, ))
4058
4059llvm::Intrinsic::ID
4060CpAsyncOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
4062 llvm::Intrinsic::ID id;
4063
4064 auto cpAsyncOp = cast<NVVM::CpAsyncOp>(op);
4065 bool hasCpSize = static_cast<bool>(cpAsyncOp.getCpSize());
4066 switch (cpAsyncOp.getSize()) {
4067 case 4:
4068 id = GET_CP_ASYNC_ID(ca, 4, hasCpSize);
4069 break;
4070 case 8:
4071 id = GET_CP_ASYNC_ID(ca, 8, hasCpSize);
4072 break;
4073 case 16:
4074 id = (cpAsyncOp.getModifier() == NVVM::LoadCacheModifierKind::CG)
4075 ? GET_CP_ASYNC_ID(cg, 16, hasCpSize)
4076 : GET_CP_ASYNC_ID(ca, 16, hasCpSize);
4077 break;
4078 default:
4079 llvm_unreachable("Invalid copy size in CpAsyncOp.");
4080 }
4081
4082 // Fill the Intrinsic Args
4083 args.push_back(mt.lookupValue(cpAsyncOp.getDst()));
4084 args.push_back(mt.lookupValue(cpAsyncOp.getSrc()));
4085 if (hasCpSize)
4086 args.push_back(mt.lookupValue(cpAsyncOp.getCpSize()));
4087
4088 return id;
4089}
4090
4091mlir::NVVM::IDArgPair CpAsyncBulkPrefetchOp::getIntrinsicIDAndArgs(
4092 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4093 auto thisOp = cast<NVVM::CpAsyncBulkPrefetchOp>(op);
4095 llvm::Intrinsic::ID id = llvm::Intrinsic::nvvm_cp_async_bulk_prefetch_L2;
4096
4097 // Fill the Intrinsic Args
4098 args.push_back(mt.lookupValue(thisOp.getSrcMem()));
4099 args.push_back(mt.lookupValue(thisOp.getSize()));
4100
4101 mlir::Value cacheHint = thisOp.getL2CacheHint();
4102 const bool hasCacheHint = static_cast<bool>(cacheHint);
4103 llvm::Value *i64Unused =
4104 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
4105 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
4106 args.push_back(builder.getInt1(hasCacheHint));
4107
4108 return {id, std::move(args)};
4109}
4110
4111mlir::NVVM::IDArgPair CpAsyncBulkGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
4112 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4113 auto thisOp = cast<NVVM::CpAsyncBulkGlobalToSharedClusterOp>(op);
4115
4116 // Fill the Intrinsic Args: dst, mbar, src, size.
4117 args.push_back(mt.lookupValue(thisOp.getDstMem()));
4118 args.push_back(mt.lookupValue(thisOp.getMbar()));
4119 args.push_back(mt.lookupValue(thisOp.getSrcMem()));
4120 args.push_back(mt.lookupValue(thisOp.getSize()));
4121
4122 // Multicast mask for shared::cluster only, if available.
4123 mlir::Value multicastMask = thisOp.getMulticastMask();
4124 const bool hasMulticastMask = static_cast<bool>(multicastMask);
4125 const bool isSharedCTA = isPtrInSharedCTASpace(thisOp.getDstMem());
4126 if (!isSharedCTA) {
4127 llvm::Value *i16Unused = llvm::ConstantInt::get(builder.getInt16Ty(), 0);
4128 args.push_back(hasMulticastMask ? mt.lookupValue(multicastMask)
4129 : i16Unused);
4130 }
4131
4132 // Cache hint, if available.
4133 mlir::Value cacheHint = thisOp.getL2CacheHint();
4134 const bool hasCacheHint = static_cast<bool>(cacheHint);
4135 llvm::Value *i64Unused = llvm::ConstantInt::get(builder.getInt64Ty(), 0);
4136 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
4137
4138 // Flag arguments for multicast and cachehint.
4139 if (!isSharedCTA)
4140 args.push_back(builder.getInt1(hasMulticastMask));
4141 args.push_back(builder.getInt1(hasCacheHint));
4142
4143 llvm::Intrinsic::ID id =
4144 isSharedCTA
4145 ? llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cta
4146 : llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster;
4147
4148 return {id, std::move(args)};
4149}
4150
4151mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
4152 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4153 auto thisOp = cast<NVVM::CpAsyncBulkSharedCTAToGlobalOp>(op);
4155 llvm::Intrinsic::ID id =
4156 llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global;
4157
4158 // Fill the Intrinsic Args
4159 args.push_back(mt.lookupValue(thisOp.getDstMem()));
4160 args.push_back(mt.lookupValue(thisOp.getSrcMem()));
4161 args.push_back(mt.lookupValue(thisOp.getSize()));
4162
4163 mlir::Value cacheHint = thisOp.getL2CacheHint();
4164 const bool hasCacheHint = static_cast<bool>(cacheHint);
4165 llvm::Value *i64Unused =
4166 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
4167 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
4168 args.push_back(builder.getInt1(hasCacheHint));
4169
4170 // Choose the bytemask variant
4171 if (mlir::Value byteMask = thisOp.getByteMask()) {
4172 args.push_back(mt.lookupValue(byteMask));
4173 id = llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global_bytemask;
4174 }
4175
4176 return {id, std::move(args)};
4177}
4178
4179bool CpAsyncBulkTensorGlobalToSharedClusterOp::getAsmValues(
4180 RewriterBase &rewriter,
4181 llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
4182 &asmValues) {
4183 // Add all the operands but not the attrs to the asmValues list.
4184 // The attrs here are used to generate the right variants for
4185 // intrinsics-lowering. So, we ignore them while generating inline-PTX.
4186 for (auto val : getOperands())
4187 asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read});
4188
4189 return false;
4190}
4191
4193CpAsyncBulkTensorGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
4194 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4195 auto thisOp = cast<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(op);
4196 const bool isCTAOnly = thisOp.getIsCTAOnly();
4198
4199 // Fill the Intrinsic Args
4200 args.push_back(mt.lookupValue(thisOp.getDstMem()));
4201 args.push_back(mt.lookupValue(thisOp.getMbar()));
4202 args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
4203
4204 // Coordinates and im2col-offsets
4205 for (mlir::Value v : thisOp.getCoordinates())
4206 args.push_back(mt.lookupValue(v));
4207 for (mlir::Value v : thisOp.getIm2colOffsets())
4208 args.push_back(mt.lookupValue(v));
4209
4210 // MulticastMask, if available
4211 mlir::Value mcMask = thisOp.getMulticastMask();
4212 const bool hasMC = static_cast<bool>(mcMask);
4213 llvm::Value *i16Zero =
4214 llvm::ConstantInt::get(llvm::Type::getInt16Ty(mt.getLLVMContext()), 0);
4215
4216 // CacheHint, if available
4217 mlir::Value cacheHint = thisOp.getL2CacheHint();
4218 const bool hasCacheHint = static_cast<bool>(cacheHint);
4219 llvm::Value *i64Zero =
4220 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
4221
4222 // Flag argument CTAGroup
4223 // CTA_1/2 is mapped to values 1 and 2 for the intrinsics.
4224 // Hence, the +1 to getGroup().
4225 const int32_t val =
4226 thisOp.getGroup() ? (static_cast<int32_t>(*thisOp.getGroup()) + 1) : 0;
4227 llvm::Value *cg =
4228 llvm::ConstantInt::get(llvm::Type::getInt32Ty(mt.getLLVMContext()), val);
4229
4230 if (!isCTAOnly) {
4231 // For shared::cluster, all the arguments that we build are applicable.
4232 args.push_back(hasMC ? mt.lookupValue(mcMask) : i16Zero);
4233 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Zero);
4234 args.push_back(builder.getInt1(hasMC));
4235 args.push_back(builder.getInt1(hasCacheHint));
4236 args.push_back(cg);
4237 } else {
4238 // For shared::cta, only cache-hint is applicable.
4239 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Zero);
4240 args.push_back(builder.getInt1(hasCacheHint));
4241 }
4242
4243 constexpr size_t numDims = 5; // 1D to 5D
4244 constexpr size_t numModes = 5; // Tile, Im2col, w, w_128, gather4
4245 using rowTy = std::array<llvm::Intrinsic::ID, numDims + 1>;
4246 using TableTy = std::array<rowTy, numModes>;
4247 static constexpr TableTy IDTable{
4248 {{notIntrinsic, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_1d,
4249 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_2d,
4250 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_3d,
4251 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_4d,
4252 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_5d},
4254 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d,
4255 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d,
4256 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d},
4258 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_3d,
4259 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_4d,
4260 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_5d},
4262 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_3d,
4263 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_4d,
4264 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_5d},
4266 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_gather4_2d}}};
4267
4268 static constexpr TableTy IDTableCTA{
4269 {{notIntrinsic,
4270 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_1d,
4271 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_2d,
4272 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_3d,
4273 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_4d,
4274 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_5d},
4276 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_3d,
4277 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_4d,
4278 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_5d},
4280 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_3d,
4281 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_4d,
4282 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_5d},
4284 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_3d,
4285 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_4d,
4286 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_5d},
4288 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_gather4_2d}}};
4289
4290 static_assert(
4291 (getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1) &&
4292 (getMaxEnumValForTMALoadMode() == std::size(IDTableCTA) - 1),
4293 "TMALoadModes must match number of rows in IDTable and IDTableCTA");
4294 size_t mode = static_cast<size_t>(thisOp.getMode());
4295 size_t dim = thisOp.getCoordinates().size();
4296 auto id = isCTAOnly ? IDTableCTA[mode][dim] : IDTable[mode][dim];
4297 assert(id != notIntrinsic &&
4298 "Invalid intrinsic for CpAsyncBulkTensorGlobalToSharedClusterOp.");
4299
4300 return {id, std::move(args)};
4301}
4302
/// Builds the intrinsic ID and argument list for lowering a
/// CpAsyncBulkTensorPrefetchOp.
///
/// Arguments are appended in this order: the TMA descriptor, each coordinate,
/// each im2col offset, the L2 cache-hint value (an i64 zero when the op
/// carries none) and an i1 flag telling whether a cache hint was supplied.
/// The intrinsic is read from a table indexed by [mode][number of
/// coordinates].
///
/// NOTE(review): `args` is used below but its declaration is not visible in
/// this listing (the embedded numbering skips a line after the cast) —
/// confirm against the original source.
4303mlir::NVVM::IDArgPair CpAsyncBulkTensorPrefetchOp::getIntrinsicIDAndArgs(
4304 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4305 auto thisOp = cast<NVVM::CpAsyncBulkTensorPrefetchOp>(op);
4307
4308 // Fill the Intrinsic Args
4309 args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
4310
4311 for (auto v : thisOp.getCoordinates())
4312 args.push_back(mt.lookupValue(v));
4313 for (auto v : thisOp.getIm2colOffsets())
4314 args.push_back(mt.lookupValue(v));
4315
 // The cache-hint slot is always filled: with the real operand when present,
 // otherwise with an i64 zero; the i1 flag that follows records which.
4316 mlir::Value cacheHint = thisOp.getL2CacheHint();
4317 const bool hasCacheHint = static_cast<bool>(cacheHint);
4318 llvm::Value *i64Unused =
4319 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
4320 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
4321 args.push_back(builder.getInt1(hasCacheHint));
4322
 // NI is a short alias that keeps the table rows compact.
4323 const unsigned NI = llvm::Intrinsic::not_intrinsic;
 // Rows are indexed by the integer value of the op's mode; columns by the
 // number of coordinates. Column 0 is NI in every row, and the gather4 row
 // has its single entry in column 5.
4324 static constexpr llvm::Intrinsic::ID IDTable[][6] = {
4325 {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d,
4326 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d,
4327 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d,
4328 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d,
4329 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d},
4330 {NI, NI, NI,
4331 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d,
4332 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d,
4333 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d},
4334 {NI, NI, NI,
4335 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_3d,
4336 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_4d,
4337 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_5d},
4338 {NI, NI, NI,
4339 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_3d,
4340 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_4d,
4341 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_5d},
4342 {NI, NI, NI, NI, NI,
4343 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_gather4_2d}};
4344
 // Compile-time check that the table has one row per TMALoadMode value.
4345 static_assert(getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1,
4346 "TMALoadModes must match number of rows in IDTable");
 // NOTE(review): IDTable[mode][dim] is read without a range check on either
 // index; presumably the op verifier bounds both — confirm.
4347 size_t mode = static_cast<size_t>(thisOp.getMode());
4348 size_t dim = thisOp.getCoordinates().size();
4349 llvm::Intrinsic::ID id = IDTable[mode][dim];
4350 if (id == llvm::Intrinsic::not_intrinsic)
4351 llvm_unreachable("Invalid intrinsic for CpAsyncBulkTensorPrefetchOp.");
4352
4353 return {id, std::move(args)};
4354}
4355
/// Builds the intrinsic ID and argument list for lowering a
/// CpAsyncBulkTensorSharedCTAToGlobalOp.
///
/// Arguments are appended in this order: the source memory operand, the TMA
/// descriptor, each coordinate, the L2 cache-hint value (an i64 zero when the
/// op carries none) and an i1 flag telling whether a cache hint was supplied.
/// The intrinsic is read from a table indexed by [mode][number of
/// coordinates].
///
/// NOTE(review): the return-type line and the declaration of `args` are not
/// visible in this listing (the embedded numbering skips them) — confirm
/// against the original source.
4357CpAsyncBulkTensorSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
4358 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4359 auto thisOp = cast<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(op);
4361
4362 // Fill the Intrinsic Args
4363 args.push_back(mt.lookupValue(thisOp.getSrcMem()));
4364 args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
4365
4366 for (auto v : thisOp.getCoordinates())
4367 args.push_back(mt.lookupValue(v));
4368
 // The cache-hint slot is always filled: with the real operand when present,
 // otherwise with an i64 zero; the i1 flag that follows records which.
4369 mlir::Value cacheHint = thisOp.getL2CacheHint();
4370 const bool hasCacheHint = static_cast<bool>(cacheHint);
4371 llvm::Value *i64Unused =
4372 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
4373 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
4374 args.push_back(builder.getInt1(hasCacheHint));
4375
 // NI is a short alias that keeps the table rows compact.
4376 const unsigned NI = llvm::Intrinsic::not_intrinsic;
 // Rows are indexed by the integer value of the op's mode; columns by the
 // number of coordinates. Column 0 is NI in every row, and the scatter4 row
 // has its single entry in column 5.
4377 static constexpr llvm::Intrinsic::ID IDTable[][6] = {
4378 {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_1d,
4379 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_2d,
4380 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_3d,
4381 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_4d,
4382 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_5d},
4383 {NI, NI, NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_3d,
4384 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_4d,
4385 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_5d},
4386 {NI, NI, NI, NI, NI,
4387 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_scatter4_2d}};
4388
 // Compile-time check that the table has one row per TMAStoreMode value.
4389 static_assert(getMaxEnumValForTMAStoreMode() == std::size(IDTable) - 1,
4390 "TMAStoreModes must match number of rows in IDTable");
 // NOTE(review): IDTable[mode][dim] is read without a range check on either
 // index; presumably the op verifier bounds both — confirm.
4391 size_t mode = static_cast<size_t>(thisOp.getMode());
4392 size_t dim = thisOp.getCoordinates().size();
4393 llvm::Intrinsic::ID id = IDTable[mode][dim];
4394 if (id == llvm::Intrinsic::not_intrinsic)
4395 llvm_unreachable(
4396 "Invalid intrinsic for CpAsyncBulkTensorSharedCTAToGlobalOp.");
4397
4398 return {id, std::move(args)};
4399}
4400
/// Builds the intrinsic ID and argument list for lowering a
/// CpAsyncBulkTensorReduceOp.
///
/// The intrinsic is selected from a three-level table indexed by
/// [reduction kind][mode][number of coordinates].
///
/// NOTE(review): `args` is used below but its declaration is not visible in
/// this listing, and the opening line of the second (im2col) row of every
/// reduction kind is absent as well — the embedded numbering skips those
/// lines. Confirm against the original source before editing the table.
4401NVVM::IDArgPair CpAsyncBulkTensorReduceOp::getIntrinsicIDAndArgs(
4402 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4403 auto thisOp = cast<NVVM::CpAsyncBulkTensorReduceOp>(op);
4404 llvm::LLVMContext &ctx = mt.getLLVMContext();
4405
4407
4408 // Arguments to the intrinsic:
4409 // shared_mem_ptr, tmaDesc, tensorDims
4410 // cache_hint(if applicable) and flag(boolean)
4411 args.push_back(mt.lookupValue(thisOp.getSrcMem()));
4412 args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
4413
4414 for (Value v : thisOp.getCoordinates())
4415 args.push_back(mt.lookupValue(v));
4416
 // The cache-hint slot is always filled: with the real operand when present,
 // otherwise with an i64 zero; the i1 flag that follows records which.
4417 mlir::Value cacheHint = thisOp.getL2CacheHint();
4418 const bool hasCacheHint = static_cast<bool>(cacheHint);
4419 llvm::Value *i64ZeroValue =
4420 llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), 0);
4421 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64ZeroValue);
4422 args.push_back(builder.getInt1(hasCacheHint));
4423
 // NOTE(review): this local has the same name as the file-scope
 // `notIntrinsic` constant and hides it inside this function.
4424 const llvm::Intrinsic::ID notIntrinsic = llvm::Intrinsic::not_intrinsic;
4425
4426 constexpr unsigned numRedKinds = 8; // ADD, MIN, MAX, INC, DEC, AND, OR, XOR
4427 constexpr unsigned numLayouts = 2; // TILE, IM2COL
4428 constexpr unsigned maxDim = 5; // 1D to 5D
 // One row holds maxDim + 1 entries so it can be indexed directly by the
 // coordinate count; entry 0 is notIntrinsic in the visible rows.
4429 using row = std::array<llvm::Intrinsic::ID, maxDim + 1>;
4430 using layoutTable = std::array<row, numLayouts>;
4431 using fullTable = std::array<layoutTable, numRedKinds>;
4432 static constexpr fullTable IDTable{
4433 {// RedTy::ADD
4434 {{{{notIntrinsic,
4435 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_1d,
4436 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_2d,
4437 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_3d,
4438 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_4d,
4439 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_5d}},
4441 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_3d,
4442 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_4d,
4443 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_5d}}}},
4444 // RedTy::MIN
4445 {{{{notIntrinsic,
4446 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_1d,
4447 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_2d,
4448 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_3d,
4449 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_4d,
4450 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_5d}},
4452 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_3d,
4453 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_4d,
4454 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_5d}}}},
4455 // RedTy::MAX
4456 {{{{notIntrinsic,
4457 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_1d,
4458 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_2d,
4459 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_3d,
4460 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_4d,
4461 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_5d}},
4463 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_3d,
4464 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_4d,
4465 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_5d}}}},
4466 // RedTy::INC
4467 {{{{notIntrinsic,
4468 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_1d,
4469 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_2d,
4470 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_3d,
4471 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_4d,
4472 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_5d}},
4474 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_3d,
4475 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_4d,
4476 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_5d}}}},
4477 // RedTy::DEC
4478 {{{{notIntrinsic,
4479 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_1d,
4480 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_2d,
4481 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_3d,
4482 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_4d,
4483 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_5d}},
4485 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_3d,
4486 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_4d,
4487 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_5d}}}},
4488 // RedTy::AND
4489 {{{{notIntrinsic,
4490 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_1d,
4491 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_2d,
4492 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_3d,
4493 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_4d,
4494 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_5d}},
4496 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_3d,
4497 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_4d,
4498 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_5d}}}},
4499 // RedTy::OR
4500 {{{{notIntrinsic,
4501 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_1d,
4502 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_2d,
4503 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_3d,
4504 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_4d,
4505 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_5d}},
4507 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_3d,
4508 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_4d,
4509 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_5d}}}},
4510 // RedTy::XOR
4511 {{{{notIntrinsic,
4512 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_1d,
4513 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_2d,
4514 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_3d,
4515 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_4d,
4516 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_5d}},
4518 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_3d,
4519 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_4d,
4520 llvm::Intrinsic::
4521 nvvm_cp_async_bulk_tensor_reduce_xor_im2col_5d}}}}}};
4522
 // Compile-time check that the outer table has one entry per TMAReduxKind.
4523 static_assert(getMaxEnumValForTMAReduxKind() == std::size(IDTable) - 1,
4524 "TMAReduxKinds must match number of rows in IDTable");
4525
4526 size_t redKind = static_cast<size_t>(thisOp.getRedKind());
4527 size_t mode = static_cast<size_t>(thisOp.getMode());
4528 size_t dim = thisOp.getCoordinates().size();
4529
 // Check each index against its table level before the lookup; these checks
 // are asserts, so they are only active in assert-enabled builds.
4530 assert(redKind < IDTable.size() &&
4531 "Invalid redKind for CpAsyncBulkTensorReduceOp");
4532 assert(mode < IDTable[redKind].size() &&
4533 "Invalid mode for CpAsyncBulkTensorReduceOp");
4534 assert(dim < IDTable[redKind][mode].size() &&
4535 "Invalid dim for CpAsyncBulkTensorReduceOp");
4536
4537 llvm::Intrinsic::ID intrinsicID = IDTable[redKind][mode][dim];
4538
4539 assert(intrinsicID != notIntrinsic &&
4540 "Invalid intrinsic for CpAsyncBulkTensorReduceOp.");
4541
4542 return {intrinsicID, std::move(args)};
4543}
4544
// Expands to nothing. It is passed as the `relu` argument for the rna
// rounding mode, so no relu infix ends up in the pasted intrinsic name.
4545#define _none
4546
// Chooses between the relu and non-relu nvvm_f2tf32 intrinsic names. Reads a
// `hasRelu` variable that must be in scope where the macro is expanded.
4547#define CVT_F2TF32_ID_IMPL(rnd, relu, sf) \
4548 hasRelu ? llvm::Intrinsic::nvvm_f2tf32_##rnd##relu##sf \
4549 : llvm::Intrinsic::nvvm_f2tf32_##rnd##sf
4550
// Chooses between the suffixed and unsuffixed variants by forwarding either
// `sf` or an empty suffix to CVT_F2TF32_ID_IMPL. Reads a `hasSatFinite`
// variable that must be in scope where the macro is expanded.
4551#define GET_CVT_F2TF32_ID(rnd, relu, sf) \
4552 hasSatFinite ? CVT_F2TF32_ID_IMPL(rnd, relu, sf) \
4553 : CVT_F2TF32_ID_IMPL(rnd, relu, )
4554
4555llvm::Intrinsic::ID
4556ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
4557 NVVM::SaturationMode sat, bool hasRelu) {
4558 using RndMode = NVVM::FPRoundingMode;
4559 bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
4560 switch (rnd) {
4561 case RndMode::RN:
4562 return GET_CVT_F2TF32_ID(rn, _relu, _satfinite);
4563 case RndMode::RZ:
4564 return GET_CVT_F2TF32_ID(rz, _relu, _satfinite);
4565 case RndMode::RNA:
4566 return GET_CVT_F2TF32_ID(rna, _none, _satfinite);
4567 default:
4568 llvm_unreachable("Invalid RoundingMode for CvtFloatToTF32Op");
4569 }
4570}
4571
/// Builds the intrinsic ID and argument list for lowering a
/// ConvertF32x2ToF4x2Op. The looked-up `a` and `b` operands are the two
/// arguments, and the op's relu flag picks between the ff_to_e2m1x2
/// rn_relu_satfinite and rn_satfinite intrinsics.
///
/// NOTE(review): the return type, the `mt` parameter and the declaration of
/// `args` are not visible in this listing (the embedded numbering skips
/// those lines) — confirm against the original source.
4573ConvertF32x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF4x2Op op,
4575 llvm::IRBuilderBase &builder) {
4577 args.push_back(mt.lookupValue(op.getA()));
4578 args.push_back(mt.lookupValue(op.getB()));
4579
4580 bool hasRelu = op.getRelu();
4581
4582 llvm::Intrinsic::ID intId =
4583 hasRelu ? llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_relu_satfinite
4584 : llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_satfinite;
4585
4586 return {intId, std::move(args)};
4587}
4588
// Chooses the nvvm_ff_to_<type> rn_relu_satfinite or rn_satfinite intrinsic
// name depending on `has_relu`.
4589#define GET_F32x2_TO_F6x2_ID(type, has_relu) \
4590 has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite \
4591 : llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite
4592
/// Returns the f32x2 -> f6x2 conversion intrinsic for the given destination
/// type: the e2m3x2 family for Float6E2M3FNType and the e3m2x2 family for
/// Float6E3M2FNType, with `hasRelu` selecting the relu variant. Any other
/// type is treated as unreachable.
4593llvm::Intrinsic::ID ConvertF32x2ToF6x2Op::getIntrinsicID(mlir::Type dstTy,
4594 bool hasRelu) {
 // NOTE(review): the head of the expression that this .Case chain hangs off
 // is not visible in this listing (the embedded numbering skips a line) —
 // confirm against the original source.
4596 .Case([&](mlir::Float6E2M3FNType) {
4597 return GET_F32x2_TO_F6x2_ID(e2m3x2, hasRelu);
4598 })
4599 .Case([&](mlir::Float6E3M2FNType) {
4600 return GET_F32x2_TO_F6x2_ID(e3m2x2, hasRelu);
4601 })
4602 .Default([](mlir::Type) {
4603 llvm_unreachable("Invalid conversion in ConvertF32x2ToF6x2Op");
4604 return llvm::Intrinsic::not_intrinsic;
4605 });
4606}
4607
/// Builds the intrinsic ID and argument list for lowering a
/// ConvertF16x2ToF4x2Op. The looked-up source operand is the only argument.
///
/// NOTE(review): the return type, the `mt` parameter and the declaration of
/// `args` are not visible in this listing (the embedded numbering skips
/// those lines) — confirm against the original source.
4609ConvertF16x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF16x2ToF4x2Op &op,
4611 llvm::IRBuilderBase &builder) {
4612 mlir::Type dstTy = op.getDstTy();
4613 bool hasRelu = op.getRelu();
4614
4615 llvm::Intrinsic::ID intId = llvm::Intrinsic::not_intrinsic;
4616
 // Only Float4E2M1FNType is handled. For any other destination type intId
 // keeps its not_intrinsic initial value and is returned as is.
 // NOTE(review): presumably the op verifier rejects other types — confirm.
4617 if (llvm::isa<mlir::Float4E2M1FNType>(dstTy))
4618 intId = hasRelu ? llvm::Intrinsic::nvvm_f16x2_to_e2m1x2_rn_relu_satfinite
4619 : llvm::Intrinsic::nvvm_f16x2_to_e2m1x2_rn_satfinite;
4620
4622 args.push_back(mt.lookupValue(op.getSrc()));
4623
4624 return {intId, std::move(args)};
4625}
4626
/// Builds the intrinsic ID and argument list for lowering a
/// ConvertBF16x2ToF4x2Op. The looked-up source operand is the only argument.
///
/// NOTE(review): the return type, the `mt` parameter and the declaration of
/// `args` are not visible in this listing (the embedded numbering skips
/// those lines) — confirm against the original source.
4628ConvertBF16x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertBF16x2ToF4x2Op &op,
4630 llvm::IRBuilderBase &builder) {
4631 mlir::Type dstTy = op.getDstTy();
4632 bool hasRelu = op.getRelu();
4633
4634 llvm::Intrinsic::ID intId = llvm::Intrinsic::not_intrinsic;
4635
 // Only Float4E2M1FNType is handled. For any other destination type intId
 // keeps its not_intrinsic initial value and is returned as is.
 // NOTE(review): presumably the op verifier rejects other types — confirm.
4636 if (llvm::isa<mlir::Float4E2M1FNType>(dstTy))
4637 intId = hasRelu ? llvm::Intrinsic::nvvm_bf16x2_to_e2m1x2_rn_relu_satfinite
4638 : llvm::Intrinsic::nvvm_bf16x2_to_e2m1x2_rn_satfinite;
4639
4641 args.push_back(mt.lookupValue(op.getSrc()));
4642
4643 return {intId, std::move(args)};
4644}
4645
/// Returns the f16x2 -> f6x2 conversion intrinsic for the given destination
/// type: the e2m3x2 family for Float6E2M3FNType and the e3m2x2 family for
/// Float6E3M2FNType, with `hasRelu` selecting the rn_relu_satfinite variant
/// over rn_satfinite. Any other type is treated as unreachable.
4646llvm::Intrinsic::ID ConvertF16x2ToF6x2Op::getIntrinsicID(mlir::Type dstTy,
4647 bool hasRelu) {
 // NOTE(review): the head of the expression that this .Case chain hangs off
 // is not visible in this listing (the embedded numbering skips a line) —
 // confirm against the original source.
4649 .Case<mlir::Float6E2M3FNType>([&](mlir::Float6E2M3FNType) {
4650 return hasRelu ? llvm::Intrinsic::nvvm_f16x2_to_e2m3x2_rn_relu_satfinite
4651 : llvm::Intrinsic::nvvm_f16x2_to_e2m3x2_rn_satfinite;
4652 })
4653 .Case<mlir::Float6E3M2FNType>([&](mlir::Float6E3M2FNType) {
4654 return hasRelu ? llvm::Intrinsic::nvvm_f16x2_to_e3m2x2_rn_relu_satfinite
4655 : llvm::Intrinsic::nvvm_f16x2_to_e3m2x2_rn_satfinite;
4656 })
4657 .Default([](mlir::Type) {
4658 llvm_unreachable("Invalid conversion in ConvertF16x2ToF6x2Op");
4659 return llvm::Intrinsic::not_intrinsic;
4660 });
4661}
4662
/// Returns the bf16x2 -> f6x2 conversion intrinsic for the given destination
/// type: the e2m3x2 family for Float6E2M3FNType and the e3m2x2 family for
/// Float6E3M2FNType, with `hasRelu` selecting the rn_relu_satfinite variant
/// over rn_satfinite. Any other type is treated as unreachable.
4663llvm::Intrinsic::ID ConvertBF16x2ToF6x2Op::getIntrinsicID(mlir::Type dstTy,
4664 bool hasRelu) {
 // NOTE(review): the head of the expression that this .Case chain hangs off
 // is not visible in this listing (the embedded numbering skips a line) —
 // confirm against the original source.
4666 .Case<mlir::Float6E2M3FNType>([&](mlir::Float6E2M3FNType) {
4667 return hasRelu
4668 ? llvm::Intrinsic::nvvm_bf16x2_to_e2m3x2_rn_relu_satfinite
4669 : llvm::Intrinsic::nvvm_bf16x2_to_e2m3x2_rn_satfinite;
4670 })
4671 .Case<mlir::Float6E3M2FNType>([&](mlir::Float6E3M2FNType) {
4672 return hasRelu
4673 ? llvm::Intrinsic::nvvm_bf16x2_to_e3m2x2_rn_relu_satfinite
4674 : llvm::Intrinsic::nvvm_bf16x2_to_e3m2x2_rn_satfinite;
4675 })
4676 .Default([](mlir::Type) {
4677 llvm_unreachable("Invalid conversion in ConvertBF16x2ToF6x2Op");
4678 return llvm::Intrinsic::not_intrinsic;
4679 });
4680}
4681
// Chooses the nvvm_ff_to_ue8m0x2_<rnd> intrinsic name, with a `_satfinite`
// suffix when `has_satf` is true.
4682#define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf) \
4683 has_satf ? llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd##_satfinite \
4684 : llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd
4685
// Chooses the nvvm_ff_to_<type>_rn intrinsic name, with a `_relu` suffix when
// `has_relu` is true.
4686#define GET_F32x2_TO_F8X2_S_ID(type, has_relu) \
4687 has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu \
4688 : llvm::Intrinsic::nvvm_ff_to_##type##_rn
4689
/// Returns the f32x2 -> f8x2 conversion intrinsic for the given destination
/// type.
///
/// For Float8E4M3FNType and Float8E5M2Type only `hasRelu` is consulted. For
/// Float8E8M0FNUType the rounding mode (RZ or RP) and the saturation mode
/// pick the intrinsic and `hasRelu` is not consulted; any other rounding mode
/// is treated as unreachable, as is any other destination type.
4690llvm::Intrinsic::ID
4691ConvertF32x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy, NVVM::FPRoundingMode rnd,
4692 NVVM::SaturationMode sat, bool hasRelu) {
4693 bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
4694 bool hasRoundingModeRZ = (rnd == NVVM::FPRoundingMode::RZ);
4695 bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP);
4696
 // NOTE(review): the head of the expression that this .Case chain hangs off
 // is not visible in this listing (the embedded numbering skips a line) —
 // confirm against the original source.
4698 .Case([&](mlir::Float8E4M3FNType) {
4699 return GET_F32x2_TO_F8X2_S_ID(e4m3x2, hasRelu);
4700 })
4701 .Case([&](mlir::Float8E5M2Type) {
4702 return GET_F32x2_TO_F8X2_S_ID(e5m2x2, hasRelu);
4703 })
4704 .Case([&](mlir::Float8E8M0FNUType) {
4705 if (hasRoundingModeRZ)
4706 return GET_F32x2_TO_F8X2_US_ID(rz, hasSatFinite);
4707 else if (hasRoundingModeRP)
4708 return GET_F32x2_TO_F8X2_US_ID(rp, hasSatFinite);
4709
4710 llvm_unreachable("Invalid conversion in ConvertF32x2ToF8x2Op");
4711 })
4712 .Default([](mlir::Type) {
4713 llvm_unreachable("Invalid conversion in ConvertF32x2ToF8x2Op");
4714 return llvm::Intrinsic::not_intrinsic;
4715 });
4716}
4717
// Chooses the nvvm_f16x2_to_<type>_rn intrinsic name, with a `_relu` suffix
// when `has_relu` is true.
4718#define GET_F16x2_TO_F8X2_ID(type, has_relu) \
4719 has_relu ? llvm::Intrinsic::nvvm_f16x2_to_##type##_rn_relu \
4720 : llvm::Intrinsic::nvvm_f16x2_to_##type##_rn
4721
/// Returns the f16x2 -> f8x2 conversion intrinsic for the given destination
/// type: the e4m3x2 family for Float8E4M3FNType and the e5m2x2 family for
/// Float8E5M2Type, with `hasRelu` selecting the relu variant. Any other type
/// is treated as unreachable.
4722llvm::Intrinsic::ID ConvertF16x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy,
4723 bool hasRelu) {
 // NOTE(review): the head of the expression that this .Case chain hangs off
 // is not visible in this listing (the embedded numbering skips a line) —
 // confirm against the original source.
4725 .Case([&](mlir::Float8E4M3FNType) {
4726 return GET_F16x2_TO_F8X2_ID(e4m3x2, hasRelu);
4727 })
4728 .Case([&](mlir::Float8E5M2Type) {
4729 return GET_F16x2_TO_F8X2_ID(e5m2x2, hasRelu);
4730 })
4731 .Default([](mlir::Type) {
4732 llvm_unreachable("Invalid conversion in ConvertF16x2ToF8x2Op");
4733 return llvm::Intrinsic::not_intrinsic;
4734 });
4735}
4736
/// Returns the bf16x2 -> f8x2 conversion intrinsic for the given destination
/// type.
///
/// For Float8E4M3FNType and Float8E5M2Type only `hasRelu` is consulted (both
/// variants are rn ... satfinite intrinsics). For Float8E8M0FNUType the
/// intrinsic comes from a four-entry table indexed by the saturation and
/// rounding modes, and `hasRelu` is not consulted. Any other destination type
/// is treated as unreachable.
4737llvm::Intrinsic::ID
4738ConvertBF16x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy,
4739 NVVM::FPRoundingMode rnd,
4740 NVVM::SaturationMode sat, bool hasRelu) {
4741 bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
4742
 // Indexed by (hasSatFinite << 1) | hasRoundingModeRP: bit 1 selects the
 // satfinite entries, bit 0 selects rp over rz.
4743 static constexpr llvm::Intrinsic::ID ue8m0x2IDs[] = {
4744 llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_rz,
4745 llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_rp,
4746 llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_rz_satfinite,
4747 llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_rp_satfinite,
4748 };
4749
 // NOTE(review): the head of the expression that this .Case chain hangs off
 // is not visible in this listing (the embedded numbering skips a line) —
 // confirm against the original source.
4751 .Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
4752 return hasRelu
4753 ? llvm::Intrinsic::nvvm_bf16x2_to_e4m3x2_rn_relu_satfinite
4754 : llvm::Intrinsic::nvvm_bf16x2_to_e4m3x2_rn_satfinite;
4755 })
4756 .Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
4757 return hasRelu
4758 ? llvm::Intrinsic::nvvm_bf16x2_to_e5m2x2_rn_relu_satfinite
4759 : llvm::Intrinsic::nvvm_bf16x2_to_e5m2x2_rn_satfinite;
4760 })
4761 .Case<mlir::Float8E8M0FNUType>([&](mlir::Float8E8M0FNUType) {
 // Every rounding mode other than RP selects an rz entry here.
 // NOTE(review): presumably the op verifier only admits RZ and RP for
 // this destination type — confirm.
4762 bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP);
4763 unsigned index = (hasSatFinite << 1) | hasRoundingModeRP;
4764 return ue8m0x2IDs[index];
4765 })
4766 .Default([](mlir::Type) {
4767 llvm_unreachable("Invalid conversion in ConvertBF16x2ToF8x2Op");
4768 return llvm::Intrinsic::not_intrinsic;
4769 });
4770}
4771
/// Builds the intrinsic ID and argument list for lowering a
/// ConvertF8x2ToF16x2Op. The intrinsic is chosen from the e4m3x2 or e5m2x2
/// family by a type switch, with the op's relu flag selecting the relu
/// variant. The single argument is the looked-up source value bitcast to i16.
4772NVVM::IDArgPair ConvertF8x2ToF16x2Op::getIntrinsicIDAndArgs(
4773 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4774 auto curOp = cast<NVVM::ConvertF8x2ToF16x2Op>(op);
4775
4776 bool hasRelu = curOp.getRelu();
4777
 // NOTE(review): the head of the initializer (what the .Case chain is applied
 // to) is not visible in this listing (the embedded numbering skips a line) —
 // confirm against the original source.
4778 llvm::Intrinsic::ID intId =
4780 .Case([&](Float8E4M3FNType type) {
4781 return hasRelu ? llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn_relu
4782 : llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn;
4783 })
4784 .Case([&](Float8E5M2Type type) {
4785 return hasRelu ? llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn_relu
4786 : llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn;
4787 })
4788 .Default([](mlir::Type type) {
4789 llvm_unreachable("Invalid type for ConvertF8x2ToF16x2Op");
4790 return llvm::Intrinsic::not_intrinsic;
4791 });
4792
 // The intrinsic receives the source as a single i16 value.
4793 llvm::Value *packedI16 =
4794 builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
4795 llvm::Type::getInt16Ty(builder.getContext()));
4796
4797 return {intId, {packedI16}};
4798}
4799
/// Builds the intrinsic ID and argument list for lowering a
/// ConvertF8x2ToBF16x2Op.
///
/// The first argument is the looked-up source value bitcast to i16. Unless
/// the op's source type is Float8E8M0FNUType, a second argument carries the
/// scale factor: the op's operand when present, otherwise the i16 constant
/// 0x7f7f.
///
/// NOTE(review): the declaration of `args` and the head of the type-switch
/// expression are not visible in this listing (the embedded numbering skips
/// those lines) — confirm against the original source.
4800NVVM::IDArgPair ConvertF8x2ToBF16x2Op::getIntrinsicIDAndArgs(
4801 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4802 auto curOp = cast<NVVM::ConvertF8x2ToBF16x2Op>(op);
4803 bool hasScale = static_cast<bool>(curOp.getScaleFactor());
4804 bool hasSatfinite = curOp.getSat() == NVVM::SaturationMode::SATFINITE;
4805 bool hasRelu = curOp.getRelu();
4806
 // Both tables are indexed by (hasSatfinite << 1) | hasRelu: bit 1 selects
 // the satfinite entries, bit 0 selects the relu entries.
4807 static constexpr llvm::Intrinsic::ID E4M3Ids[] = {
4808 llvm::Intrinsic::nvvm_e4m3x2_to_bf16x2_rn_scale_n2_ue8m0,
4809 llvm::Intrinsic::nvvm_e4m3x2_to_bf16x2_rn_relu_scale_n2_ue8m0,
4810 llvm::Intrinsic::nvvm_e4m3x2_to_bf16x2_rn_satfinite_scale_n2_ue8m0,
4811 llvm::Intrinsic::nvvm_e4m3x2_to_bf16x2_rn_relu_satfinite_scale_n2_ue8m0,
4812 };
4813
4814 static constexpr llvm::Intrinsic::ID E5M2Ids[] = {
4815 llvm::Intrinsic::nvvm_e5m2x2_to_bf16x2_rn_scale_n2_ue8m0,
4816 llvm::Intrinsic::nvvm_e5m2x2_to_bf16x2_rn_relu_scale_n2_ue8m0,
4817 llvm::Intrinsic::nvvm_e5m2x2_to_bf16x2_rn_satfinite_scale_n2_ue8m0,
4818 llvm::Intrinsic::nvvm_e5m2x2_to_bf16x2_rn_relu_satfinite_scale_n2_ue8m0,
4819 };
4820
 // The Float8E8M0FNUType case maps to a single intrinsic regardless of the
 // relu and saturation flags.
4821 llvm::Intrinsic::ID intId =
4823 .Case([&](Float8E8M0FNUType type) {
4824 return llvm::Intrinsic::nvvm_ue8m0x2_to_bf16x2;
4825 })
4826 .Case([&](Float8E4M3FNType type) {
4827 return E4M3Ids[hasSatfinite << 1 | hasRelu];
4828 })
4829 .Case([&](Float8E5M2Type type) {
4830 return E5M2Ids[hasSatfinite << 1 | hasRelu];
4831 })
4832 .Default([](mlir::Type type) {
4833 llvm_unreachable("Invalid type for ConvertF8x2ToBF16x2Op");
4834 return llvm::Intrinsic::not_intrinsic;
4835 });
4836 llvm::Value *packedI16 =
4837 builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
4838 llvm::Type::getInt16Ty(builder.getContext()));
4839
4841 args.push_back(packedI16);
 // No scale-factor argument is appended for a Float8E8M0FNUType source.
4842 if (!isa<Float8E8M0FNUType>(curOp.getSrcType()))
4843 args.push_back(
4844 hasScale ? mt.lookupValue(curOp.getScaleFactor())
4845 : builder.getInt16(0x7f7f)); // default scale factor (value of
4846 // 1 for both elements)
4847
4848 return {intId, std::move(args)};
4849}
4850
/// Builds the intrinsic ID and argument list for lowering a
/// ConvertF6x2ToF16x2Op. The intrinsic is chosen from the e2m3x2 or e3m2x2
/// family by a type switch, with the op's relu flag selecting the relu
/// variant. The single argument is the looked-up source value bitcast to i16.
4851NVVM::IDArgPair ConvertF6x2ToF16x2Op::getIntrinsicIDAndArgs(
4852 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4853 auto curOp = cast<NVVM::ConvertF6x2ToF16x2Op>(op);
4854
4855 bool hasRelu = curOp.getRelu();
4856
 // NOTE(review): the head of the initializer (what the .Case chain is applied
 // to) is not visible in this listing (the embedded numbering skips a line) —
 // confirm against the original source.
4857 llvm::Intrinsic::ID intId =
4859 .Case([&](Float6E2M3FNType type) {
4860 return hasRelu ? llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn_relu
4861 : llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn;
4862 })
4863 .Case([&](Float6E3M2FNType type) {
4864 return hasRelu ? llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn_relu
4865 : llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn;
4866 })
4867 .Default([](mlir::Type type) {
4868 llvm_unreachable("Invalid type for ConvertF6x2ToF16x2Op");
4869 return llvm::Intrinsic::not_intrinsic;
4870 });
4871
 // The intrinsic receives the source as a single i16 value.
4872 llvm::Value *packedI16 =
4873 builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
4874 llvm::Type::getInt16Ty(builder.getContext()));
4875
4876 return {intId, {packedI16}};
4877}
4878
/// Builds the intrinsic ID and argument list for lowering a
/// ConvertF6x2ToBF16x2Op.
///
/// The arguments are the looked-up source value bitcast to i16, followed by
/// the scale factor: the op's operand when present, otherwise the i16
/// constant 0x7f7f.
///
/// NOTE(review): the declaration of `args` and the head of the type-switch
/// expression are not visible in this listing (the embedded numbering skips
/// those lines) — confirm against the original source.
4879NVVM::IDArgPair ConvertF6x2ToBF16x2Op::getIntrinsicIDAndArgs(
4880 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4881 auto curOp = cast<NVVM::ConvertF6x2ToBF16x2Op>(op);
4882 bool hasScale = static_cast<bool>(curOp.getScaleFactor());
4883 bool hasSatfinite = curOp.getSat() == NVVM::SaturationMode::SATFINITE;
4884 bool hasRelu = curOp.getRelu();
4885
 // Both tables are indexed by `idx` below: bit 1 selects the satfinite
 // entries, bit 0 selects the relu entries.
4886 static constexpr llvm::Intrinsic::ID E2M3Ids[] = {
4887 llvm::Intrinsic::nvvm_e2m3x2_to_bf16x2_rn_scale_n2_ue8m0,
4888 llvm::Intrinsic::nvvm_e2m3x2_to_bf16x2_rn_relu_scale_n2_ue8m0,
4889 llvm::Intrinsic::nvvm_e2m3x2_to_bf16x2_rn_satfinite_scale_n2_ue8m0,
4890 llvm::Intrinsic::nvvm_e2m3x2_to_bf16x2_rn_relu_satfinite_scale_n2_ue8m0,
4891 };
4892
4893 static constexpr llvm::Intrinsic::ID E3M2Ids[] = {
4894 llvm::Intrinsic::nvvm_e3m2x2_to_bf16x2_rn_scale_n2_ue8m0,
4895 llvm::Intrinsic::nvvm_e3m2x2_to_bf16x2_rn_relu_scale_n2_ue8m0,
4896 llvm::Intrinsic::nvvm_e3m2x2_to_bf16x2_rn_satfinite_scale_n2_ue8m0,
4897 llvm::Intrinsic::nvvm_e3m2x2_to_bf16x2_rn_relu_satfinite_scale_n2_ue8m0,
4898 };
4899
4900 unsigned idx = (hasSatfinite << 1) | hasRelu;
4901 llvm::Intrinsic::ID intId =
4903 .Case([&](Float6E2M3FNType type) { return E2M3Ids[idx]; })
4904 .Case([&](Float6E3M2FNType type) { return E3M2Ids[idx]; })
4905 .Default([](mlir::Type type) {
4906 llvm_unreachable("Invalid type for ConvertF6x2ToBF16x2Op");
4907 return llvm::Intrinsic::not_intrinsic;
4908 });
4909
4910 llvm::Value *packedI16 =
4911 builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
4912 llvm::Type::getInt16Ty(builder.getContext()));
4913
4915 args.push_back(packedI16);
4916 args.push_back(
4917 hasScale
4918 ? mt.lookupValue(curOp.getScaleFactor())
4919 : builder.getInt16(
4920 0x7f7f)); // default scale factor (value of 1 for both elements)
4921
4922 return {intId, std::move(args)};
4923}
4924
/// Builds the intrinsic ID and argument list for lowering a
/// ConvertF4x2ToF16x2Op. Only Float4E2M1FNType is handled by the type switch,
/// with the op's relu flag selecting the relu variant. The single argument
/// is the looked-up source value zero-extended to i16.
4925NVVM::IDArgPair ConvertF4x2ToF16x2Op::getIntrinsicIDAndArgs(
4926 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4927 auto curOp = cast<NVVM::ConvertF4x2ToF16x2Op>(op);
4928
4929 bool hasRelu = curOp.getRelu();
4930
 // NOTE(review): the head of the initializer (what the .Case chain is applied
 // to) is not visible in this listing (the embedded numbering skips a line) —
 // confirm against the original source.
4931 llvm::Intrinsic::ID intId =
4933 .Case([&](Float4E2M1FNType type) {
4934 return hasRelu ? llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn_relu
4935 : llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn;
4936 })
4937 .Default([](mlir::Type type) {
4938 llvm_unreachable("Invalid type for ConvertF4x2ToF16x2Op");
4939 return llvm::Intrinsic::not_intrinsic;
4940 });
4941
 // Unlike the f6/f8 conversions in this file, the source is widened with a
 // zero-extension rather than a bitcast.
4942 llvm::Value *extendedI16 =
4943 builder.CreateZExt(mt.lookupValue(curOp.getSrc()),
4944 llvm::Type::getInt16Ty(builder.getContext()));
4945
4946 return {intId, {extendedI16}};
4947}
4948
/// Builds the intrinsic ID and argument list for lowering a
/// ConvertF4x2ToBF16x2Op.
///
/// The arguments are the looked-up source value zero-extended to i16,
/// followed by the scale factor: the op's operand when present, otherwise
/// the i16 constant 0x7f7f.
///
/// NOTE(review): the declaration of `args` and the head of the type-switch
/// expression are not visible in this listing (the embedded numbering skips
/// those lines) — confirm against the original source.
4949NVVM::IDArgPair ConvertF4x2ToBF16x2Op::getIntrinsicIDAndArgs(
4950 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4951 auto curOp = cast<NVVM::ConvertF4x2ToBF16x2Op>(op);
4952 bool hasScale = static_cast<bool>(curOp.getScaleFactor());
4953 bool hasSatfinite = curOp.getSat() == NVVM::SaturationMode::SATFINITE;
4954 bool hasRelu = curOp.getRelu();
4955
 // Indexed by `idx` below: bit 1 selects the satfinite entries, bit 0
 // selects the relu entries.
4956 static constexpr llvm::Intrinsic::ID E2M1Ids[] = {
4957 llvm::Intrinsic::nvvm_e2m1x2_to_bf16x2_rn_scale_n2_ue8m0,
4958 llvm::Intrinsic::nvvm_e2m1x2_to_bf16x2_rn_relu_scale_n2_ue8m0,
4959 llvm::Intrinsic::nvvm_e2m1x2_to_bf16x2_rn_satfinite_scale_n2_ue8m0,
4960 llvm::Intrinsic::nvvm_e2m1x2_to_bf16x2_rn_relu_satfinite_scale_n2_ue8m0,
4961 };
4962
4963 unsigned idx = (hasSatfinite << 1) | hasRelu;
4964 llvm::Intrinsic::ID intId =
4966 .Case([&](Float4E2M1FNType type) { return E2M1Ids[idx]; })
4967 .Default([](mlir::Type type) {
4968 llvm_unreachable("Invalid type for ConvertF4x2ToBF16x2Op");
4969 return llvm::Intrinsic::not_intrinsic;
4970 });
4971
4972 llvm::Value *extendedI16 =
4973 builder.CreateZExt(mt.lookupValue(curOp.getSrc()),
4974 llvm::Type::getInt16Ty(builder.getContext()));
4975
4977 args.push_back(extendedI16);
4978 args.push_back(
4979 hasScale
4980 ? mt.lookupValue(curOp.getScaleFactor())
4981 : builder.getInt16(
4982 0x7f7f)); // default scale factor (value of 1 for both elements)
4983
4984 return {intId, std::move(args)};
4985}
4986
/// Builds the intrinsic ID and argument list for lowering a
/// ConvertF32x2ToS2F6x2Op. The op's relu flag picks between the two
/// ff_to_s2f6x2 intrinsics. The arguments are the looked-up `a` and `b`
/// operands followed by the scale factor.
///
/// NOTE(review): `args` is used below but its declaration is not visible in
/// this listing (the embedded numbering skips a line) — confirm against the
/// original source.
4987NVVM::IDArgPair ConvertF32x2ToS2F6x2Op::getIntrinsicIDAndArgs(
4988 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4989 auto thisOp = cast<NVVM::ConvertF32x2ToS2F6x2Op>(op);
4990 bool hasRelu = thisOp.getRelu();
4991 bool hasScale = static_cast<bool>(thisOp.getScaleFactor());
4992
4993 llvm::Intrinsic::ID id =
4994 hasRelu
4995 ? llvm::Intrinsic::nvvm_ff_to_s2f6x2_rn_relu_satfinite_scale_n2_ue8m0
4996 : llvm::Intrinsic::nvvm_ff_to_s2f6x2_rn_satfinite_scale_n2_ue8m0;
4997
4998 // Fill the Intrinsic Args
5000 args.push_back(mt.lookupValue(thisOp.getA()));
5001 args.push_back(mt.lookupValue(thisOp.getB()));
 // Without a scale-factor operand the i16 constant 0x7f7f is passed instead.
5002 args.push_back(hasScale ? mt.lookupValue(thisOp.getScaleFactor())
5003 : builder.getInt16(0x7f7f));
5004 return {id, std::move(args)};
5005}
5006
/// Builds the intrinsic ID and argument list for lowering a
/// ConvertBF16x2ToS2F6x2Op. The op's relu flag picks between the two
/// bf16x2_to_s2f6x2 intrinsics. The arguments are the looked-up source
/// operand followed by the scale factor.
///
/// NOTE(review): `args` is used below but its declaration is not visible in
/// this listing (the embedded numbering skips a line) — confirm against the
/// original source.
5007NVVM::IDArgPair ConvertBF16x2ToS2F6x2Op::getIntrinsicIDAndArgs(
5008 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
5009 auto thisOp = cast<NVVM::ConvertBF16x2ToS2F6x2Op>(op);
5010 bool hasRelu = thisOp.getRelu();
5011 bool hasScale = static_cast<bool>(thisOp.getScaleFactor());
5012
5013 llvm::Intrinsic::ID id =
5014 hasRelu
5015 ? llvm::Intrinsic::
5016 nvvm_bf16x2_to_s2f6x2_rn_relu_satfinite_scale_n2_ue8m0
5017 : llvm::Intrinsic::nvvm_bf16x2_to_s2f6x2_rn_satfinite_scale_n2_ue8m0;
5018
5019 // Fill the Intrinsic Args
5021 args.push_back(mt.lookupValue(thisOp.getSrc()));
 // Without a scale-factor operand the i16 constant 0x7f7f is passed instead.
5022 args.push_back(hasScale ? mt.lookupValue(thisOp.getScaleFactor())
5023 : builder.getInt16(0x7f7f));
5024 return {id, std::move(args)};
5025}
5026
/// Builds the intrinsic ID and argument list for lowering a
/// ConvertS2F6x2ToBF16x2Op. The intrinsic comes from a four-entry table
/// indexed by the saturation and relu flags. The arguments are the looked-up
/// source value bitcast to i16, followed by the scale factor.
///
/// NOTE(review): `args` is used below but its declaration is not visible in
/// this listing (the embedded numbering skips a line) — confirm against the
/// original source.
5027NVVM::IDArgPair ConvertS2F6x2ToBF16x2Op::getIntrinsicIDAndArgs(
5028 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
5029 auto thisOp = cast<NVVM::ConvertS2F6x2ToBF16x2Op>(op);
5030 bool hasRelu = thisOp.getRelu();
5031 bool hasScale = static_cast<bool>(thisOp.getScaleFactor());
5032 bool hasSat = thisOp.getSat() == NVVM::SaturationMode::SATFINITE;
5033
 // Indexed by `idx` below: bit 1 selects the satfinite entries, bit 0
 // selects the relu entries.
5034 static constexpr llvm::Intrinsic::ID ids[] = {
5035 llvm::Intrinsic::nvvm_s2f6x2_to_bf16x2_rn_scale_n2_ue8m0,
5036 llvm::Intrinsic::nvvm_s2f6x2_to_bf16x2_rn_relu_scale_n2_ue8m0,
5037 llvm::Intrinsic::nvvm_s2f6x2_to_bf16x2_rn_satfinite_scale_n2_ue8m0,
5038 llvm::Intrinsic::nvvm_s2f6x2_to_bf16x2_rn_relu_satfinite_scale_n2_ue8m0,
5039 };
5040
5041 unsigned idx = (hasSat << 1) | hasRelu;
5042
5043 // Fill the Intrinsic Args
5045 llvm::Value *packedI16 =
5046 builder.CreateBitCast(mt.lookupValue(thisOp.getSrc()),
5047 llvm::Type::getInt16Ty(builder.getContext()));
5048 args.push_back(packedI16);
 // Without a scale-factor operand the i16 constant 0x7f7f is passed instead.
5049 args.push_back(hasScale ? mt.lookupValue(thisOp.getScaleFactor())
5050 : builder.getInt16(0x7f7f));
5051
5052 return {ids[idx], std::move(args)};
5053}
5054
5055llvm::Intrinsic::ID
5056Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op,
5059 auto curOp = cast<NVVM::Tcgen05AllocOp>(op);
5060 unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
5061 .getAddressSpace();
5062 bool isShared = as == NVVMMemorySpace::Shared;
5063 bool is2CTAMode = curOp.getGroup() == CTAGroupKind::CTA_2;
5064
5065 llvm::Intrinsic::ID id;
5066 if (isShared) {
5067 id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg2
5068 : llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg1;
5069 } else {
5070 id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_cg2
5071 : llvm::Intrinsic::nvvm_tcgen05_alloc_cg1;
5072 }
5073
5074 // Fill the Intrinsic Args
5075 args.push_back(mt.lookupValue(curOp.getAddr()));
5076 args.push_back(mt.lookupValue(curOp.getNCols()));
5077
5078 return id;
5079}
5080
5081llvm::Intrinsic::ID Tcgen05DeallocOp::getIntrinsicIDAndArgs(
5084 auto curOp = cast<NVVM::Tcgen05DeallocOp>(op);
5085 auto id = (curOp.getGroup() == CTAGroupKind::CTA_1)
5086 ? llvm::Intrinsic::nvvm_tcgen05_dealloc_cg1
5087 : llvm::Intrinsic::nvvm_tcgen05_dealloc_cg2;
5088
5089 // Fill the Intrinsic Args
5090 args.push_back(mt.lookupValue(curOp.getTaddr()));
5091 args.push_back(mt.lookupValue(curOp.getNCols()));
5092
5093 return id;
5094}
5095
// Expands to the tcgen05.commit intrinsic for CTA group suffix `cg`. `mc` is
// either `_mc` or empty, and `is_shared` selects the `_shared` variant at run
// time. The expansion is fully parenthesized so that it can be embedded in a
// larger conditional expression without relying on `?:` associativity.
#define TCGEN05_COMMIT_IMPL(cg, is_shared, mc)                                 \
  ((is_shared) ? llvm::Intrinsic::nvvm_tcgen05_commit##mc##_shared##_##cg      \
               : llvm::Intrinsic::nvvm_tcgen05_commit##mc##_##cg)

// Expands to the tcgen05.commit intrinsic for `cta_group`, choosing between
// the multicast and non-multicast variants with the run-time flag `has_mc`.
#define GET_TCGEN05_COMMIT_ID(cta_group, is_shared, has_mc)                    \
  ((has_mc) ? TCGEN05_COMMIT_IMPL(cta_group, is_shared, _mc)                   \
            : TCGEN05_COMMIT_IMPL(cta_group, is_shared, ))
5103
5104llvm::Intrinsic::ID
5105Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
5108 auto curOp = cast<NVVM::Tcgen05CommitOp>(op);
5109 unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
5110 .getAddressSpace();
5111 bool isShared = as == NVVMMemorySpace::Shared;
5112 bool hasMulticast = static_cast<bool>(curOp.getMulticastMask());
5113 bool is2CTAMode = curOp.getGroup() == CTAGroupKind::CTA_2;
5114
5115 llvm::Intrinsic::ID id =
5116 is2CTAMode ? GET_TCGEN05_COMMIT_ID(cg2, isShared, hasMulticast)
5117 : GET_TCGEN05_COMMIT_ID(cg1, isShared, hasMulticast);
5118
5119 // Fill the Intrinsic Args
5120 args.push_back(mt.lookupValue(curOp.getAddr()));
5121 if (hasMulticast)
5122 args.push_back(mt.lookupValue(curOp.getMulticastMask()));
5123
5124 return id;
5125}
5126
// Pastes together the name of a tcgen05.cp intrinsic from its shape/multicast
// suffix, source-format suffix (possibly empty) and CTA-group suffix.
#define TCGEN05_CP_IMPL(shape_mc, src_fmt, cg)                                 \
  llvm::Intrinsic::nvvm_tcgen05_cp##shape_mc##src_fmt##cg

// Chooses between the `_cg2` and `_cg1` variants with the run-time flag
// `is_2cta`. Parenthesized so the expansion is a single self-contained
// expression wherever it is used.
#define TCGEN05_CP_2CTA(shape_mc, src_fmt, is_2cta)                            \
  ((is_2cta) ? TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg2)                        \
             : TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg1))

// Expands to an immediately-invoked lambda that maps the run-time source
// format `src_fmt` onto the matching intrinsic-name suffix; any format other
// than B6x16_P32 and B4x16_P64 selects the variant without a format suffix.
#define GET_TCGEN05_CP_ID(shape_mc, src_fmt, is_2cta)                          \
  [&]() -> auto {                                                              \
    if ((src_fmt) == Tcgen05CpSrcFormat::B6x16_P32)                            \
      return TCGEN05_CP_2CTA(shape_mc, _b6x16_p32, is_2cta);                   \
    if ((src_fmt) == Tcgen05CpSrcFormat::B4x16_P64)                            \
      return TCGEN05_CP_2CTA(shape_mc, _b4x16_p64, is_2cta);                   \
    return TCGEN05_CP_2CTA(shape_mc, , is_2cta);                               \
  }()
5142
5144ConvertF32x2ToF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF16x2Op &op,
5146 llvm::IRBuilderBase &builder) {
5147 static constexpr llvm::Intrinsic::ID rndRNIds[] = {
5148 llvm::Intrinsic::nvvm_ff2f16x2_rn,
5149 llvm::Intrinsic::nvvm_ff2f16x2_rn_relu,
5150 llvm::Intrinsic::nvvm_ff2f16x2_rn_satfinite,
5151 llvm::Intrinsic::nvvm_ff2f16x2_rn_relu_satfinite,
5152 };
5153 static constexpr llvm::Intrinsic::ID rndRZIds[] = {
5154 llvm::Intrinsic::nvvm_ff2f16x2_rz,
5155 llvm::Intrinsic::nvvm_ff2f16x2_rz_relu,
5156 llvm::Intrinsic::nvvm_ff2f16x2_rz_satfinite,
5157 llvm::Intrinsic::nvvm_ff2f16x2_rz_relu_satfinite,
5158 };
5159 static constexpr llvm::Intrinsic::ID rndRSIds[] = {
5160 llvm::Intrinsic::nvvm_ff2f16x2_rs,
5161 llvm::Intrinsic::nvvm_ff2f16x2_rs_relu,
5162 llvm::Intrinsic::nvvm_ff2f16x2_rs_satfinite,
5163 llvm::Intrinsic::nvvm_ff2f16x2_rs_relu_satfinite,
5164 };
5165
5166 unsigned hasRelu = op.getRelu() ? 1 : 0;
5167 unsigned hasSatFinite =
5168 (op.getSat() == NVVM::SaturationMode::SATFINITE) ? 1 : 0;
5169 // idx: bit-0 - relu
5170 // bit-1 - satfinite
5171 unsigned idx = (hasSatFinite << 1) | hasRelu;
5172
5174 args.push_back(mt.lookupValue(op.getSrcHi()));
5175 args.push_back(mt.lookupValue(op.getSrcLo()));
5176 if (op.getRandomBits())
5177 args.push_back(mt.lookupValue(op.getRandomBits()));
5178
5179 switch (op.getRnd()) {
5180 case FPRoundingMode::RN:
5181 return {rndRNIds[idx], std::move(args)};
5182 case FPRoundingMode::RZ:
5183 return {rndRZIds[idx], std::move(args)};
5184 case FPRoundingMode::RS:
5185 return {rndRSIds[idx], std::move(args)};
5186 default:
5187 llvm_unreachable("Invalid rounding mode for ConvertF32x2ToF16x2Op");
5188 }
5189}
5190
5192ConvertF32x2ToBF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToBF16x2Op &op,
5194 llvm::IRBuilderBase &builder) {
5195 static constexpr llvm::Intrinsic::ID rndRNIds[] = {
5196 llvm::Intrinsic::nvvm_ff2bf16x2_rn,
5197 llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu,
5198 llvm::Intrinsic::nvvm_ff2bf16x2_rn_satfinite,
5199 llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu_satfinite,
5200 };
5201 static constexpr llvm::Intrinsic::ID rndRZIds[] = {
5202 llvm::Intrinsic::nvvm_ff2bf16x2_rz,
5203 llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu,
5204 llvm::Intrinsic::nvvm_ff2bf16x2_rz_satfinite,
5205 llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu_satfinite,
5206 };
5207 static constexpr llvm::Intrinsic::ID rndRSIds[] = {
5208 llvm::Intrinsic::nvvm_ff2bf16x2_rs,
5209 llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu,
5210 llvm::Intrinsic::nvvm_ff2bf16x2_rs_satfinite,
5211 llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu_satfinite,
5212 };
5213
5214 unsigned hasRelu = op.getRelu() ? 1 : 0;
5215 unsigned hasSatFinite =
5216 (op.getSat() == NVVM::SaturationMode::SATFINITE) ? 1 : 0;
5217 // idx: bit-0 - relu
5218 // bit-1 - satfinite
5219 unsigned idx = (hasSatFinite << 1) | hasRelu;
5220
5222 args.push_back(mt.lookupValue(op.getSrcHi()));
5223 args.push_back(mt.lookupValue(op.getSrcLo()));
5224 if (op.getRandomBits())
5225 args.push_back(mt.lookupValue(op.getRandomBits()));
5226
5227 switch (op.getRnd()) {
5228 case FPRoundingMode::RN:
5229 return {rndRNIds[idx], std::move(args)};
5230 case FPRoundingMode::RZ:
5231 return {rndRZIds[idx], std::move(args)};
5232 case FPRoundingMode::RS:
5233 return {rndRSIds[idx], std::move(args)};
5234 default:
5235 llvm_unreachable("Invalid rounding mode for ConvertF32x2ToBF16x2Op");
5236 }
5237}
5238
5239llvm::Intrinsic::ID ConvertF32x4ToF8x4Op::getIntrinsicID() {
5240 mlir::Type dstTy = getDstTy();
5241 bool hasRelu = getRelu();
5242
5244 .Case([&](mlir::Float8E4M3FNType) {
5245 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite
5246 : llvm::Intrinsic::nvvm_f32x4_to_e4m3x4_rs_satfinite;
5247 })
5248 .Case([&](mlir::Float8E5M2Type) {
5249 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e5m2x4_rs_relu_satfinite
5250 : llvm::Intrinsic::nvvm_f32x4_to_e5m2x4_rs_satfinite;
5251 })
5252 .Default([](mlir::Type) {
5253 llvm_unreachable("Invalid F8 type in ConvertF32x4ToF8x4Op");
5254 return llvm::Intrinsic::not_intrinsic;
5255 });
5256}
5257
5258llvm::Intrinsic::ID ConvertF32x4ToF6x4Op::getIntrinsicID() {
5259 mlir::Type dstTy = getDstTy();
5260 bool hasRelu = getRelu();
5261
5263 .Case([&](mlir::Float6E2M3FNType) {
5264 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e2m3x4_rs_relu_satfinite
5265 : llvm::Intrinsic::nvvm_f32x4_to_e2m3x4_rs_satfinite;
5266 })
5267 .Case([&](mlir::Float6E3M2FNType) {
5268 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e3m2x4_rs_relu_satfinite
5269 : llvm::Intrinsic::nvvm_f32x4_to_e3m2x4_rs_satfinite;
5270 })
5271 .Default([](mlir::Type) {
5272 llvm_unreachable("Invalid F6 type in ConvertF32x4ToF6x4Op");
5273 return llvm::Intrinsic::not_intrinsic;
5274 });
5275}
5276
5277llvm::Intrinsic::ID ConvertF32x4ToF4x4Op::getIntrinsicID() {
5278 mlir::Type dstTy = getDstTy();
5279 bool hasRelu = getRelu();
5280
5282 .Case([&](mlir::Float4E2M1FNType) {
5283 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e2m1x4_rs_relu_satfinite
5284 : llvm::Intrinsic::nvvm_f32x4_to_e2m1x4_rs_satfinite;
5285 })
5286 .Default([](mlir::Type) {
5287 llvm_unreachable("Invalid F4 type in ConvertF32x4ToF4x4Op");
5288 return llvm::Intrinsic::not_intrinsic;
5289 });
5290}
5291
5292llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(Operation &op) {
5293 auto curOp = cast<NVVM::Tcgen05CpOp>(op);
5294 bool is2CTA = curOp.getGroup() == CTAGroupKind::CTA_2;
5295 auto srcFmt = curOp.getSrcFormat();
5296 auto mc = curOp.getMulticast();
5297
5298 switch (curOp.getShape()) {
5299 case Tcgen05CpShape::SHAPE_128x256b:
5300 return GET_TCGEN05_CP_ID(_128x256b, srcFmt, is2CTA);
5301 case Tcgen05CpShape::SHAPE_128x128b:
5302 return GET_TCGEN05_CP_ID(_128x128b, srcFmt, is2CTA);
5303 case Tcgen05CpShape::SHAPE_4x256b:
5304 return GET_TCGEN05_CP_ID(_4x256b, srcFmt, is2CTA);
5305 case Tcgen05CpShape::SHAPE_32x128b:
5306 return GET_TCGEN05_CP_ID(_32x128b_warpx4, srcFmt, is2CTA);
5307 case Tcgen05CpShape::SHAPE_64x128b:
5308 return (mc == Tcgen05CpMulticast::WARPX2_01_23)
5309 ? GET_TCGEN05_CP_ID(_64x128b_warpx2_01_23, srcFmt, is2CTA)
5310 : GET_TCGEN05_CP_ID(_64x128b_warpx2_02_13, srcFmt, is2CTA);
5311 }
5312 llvm_unreachable("Invalid shape in tcgen05 cp Op");
5313}
5314
5315// Returns the valid vector length for a given shape and vector length, the
5316// function models the table mentioned in the tcgen05.{ld, st} Op description
5317static unsigned isValidVectorLength(NVVM::Tcgen05LdStShape shape,
5318 unsigned vecLen) {
5319 if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X128B)
5320 return vecLen >= 2;
5321 if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X256B)
5322 return vecLen >= 4;
5323 return true;
5324}
5325
5326LogicalResult Tcgen05LdOp::verify() {
5327 LogicalResult result = success();
5328 if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
5329 result = emitError("shape 16x32bx2 requires offset argument");
5330
5331 if (getShape() != NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && getOffset())
5332 result = emitError("offset argument is only supported for shape 16x32bx2");
5333
5334 auto resTy = getRes().getType();
5335 unsigned resLen = isa<VectorType>(resTy)
5336 ? llvm::cast<VectorType>(resTy).getNumElements()
5337 : 1;
5338 if (!isValidVectorLength(getShape(), resLen))
5339 result = emitError(llvm::formatv("invalid result type length {0} for shape "
5340 "{1} in tcgen05.ld Op",
5341 resLen, stringifyEnum(getShape())));
5342
5343 return result;
5344}
5345
5346LogicalResult Tcgen05StOp::verify() {
5347 LogicalResult result = success();
5348 if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
5349 result = emitError("shape 16x32bx2 requires offset argument");
5350
5351 auto valTy = getVal().getType();
5352 unsigned valLen = isa<VectorType>(valTy)
5353 ? llvm::cast<VectorType>(valTy).getNumElements()
5354 : 1;
5355 if (!isValidVectorLength(getShape(), valLen))
5356 result = emitError(llvm::formatv("invalid input length {0} for shape "
5357 "{1} in tcgen05.st Op",
5358 valLen, stringifyEnum(getShape())));
5359
5360 return result;
5361}
5362
5363/// Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might
5364/// have ConstantRangeAttr.
5367 SetIntRangeFn setResultRanges) {
5368 if (auto rangeAttr = op->getAttrOfType<LLVM::ConstantRangeAttr>("range")) {
5369 setResultRanges(result, {rangeAttr.getLower(), rangeAttr.getUpper(),
5370 rangeAttr.getLower(), rangeAttr.getUpper()});
5371 } else {
5372 setResultRanges(result, IntegerValueRange::getMaxRange(result).getValue());
5373 }
5374}
5375
5376/// Verify the range attribute satisfies LLVM ConstantRange constructor
5377/// requirements for NVVM SpecialRangeableRegisterOp.
5378static LogicalResult
5380 std::optional<LLVM::ConstantRangeAttr> rangeAttr) {
5381 if (!rangeAttr)
5382 return success();
5383
5384 const llvm::APInt &lower = rangeAttr->getLower();
5385 const llvm::APInt &upper = rangeAttr->getUpper();
5386
5387 // Check LLVM ConstantRange constructor condition
5388 if (lower == upper && !lower.isMaxValue() && !lower.isMinValue()) {
5389 unsigned bitWidth = lower.getBitWidth();
5390 llvm::APInt minVal = llvm::APInt::getMinValue(bitWidth);
5391 llvm::APInt maxVal = llvm::APInt::getMaxValue(bitWidth);
5392 return op->emitOpError(
5393 "invalid range attribute: Lower == Upper, but they aren't min (")
5394 << llvm::toString(minVal, 10, false) << ") or max ("
5395 << llvm::toString(maxVal, 10, false)
5396 << ") value! This is an invalid constant range.";
5397 }
5398
5399 return success();
5400}
5401
5402static llvm::Value *getAsPackedI32(llvm::Value *arg,
5403 llvm::IRBuilderBase &builder) {
5404 return builder.CreateBitCast(arg,
5405 llvm::Type::getInt32Ty(builder.getContext()));
5406}
5407
5408NVVM::IDArgPair DotAccumulate4WayOp::getIntrinsicIDAndArgs(
5409 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
5410 auto curOp = cast<NVVM::DotAccumulate4WayOp>(op);
5411
5413 args.push_back(getAsPackedI32(mt.lookupValue(curOp.getA()), builder));
5414 args.push_back(getAsPackedI32(mt.lookupValue(curOp.getB()), builder));
5415 args.push_back(mt.lookupValue(curOp.getC()));
5416
5417 bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
5418 bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
5419 unsigned type = (isASigned << 1) | isBSigned;
5420 const llvm::Intrinsic::ID ids[] = {
5421 llvm::Intrinsic::nvvm_idp4a_u_u,
5422 llvm::Intrinsic::nvvm_idp4a_u_s,
5423 llvm::Intrinsic::nvvm_idp4a_s_u,
5424 llvm::Intrinsic::nvvm_idp4a_s_s,
5425 };
5426 return {ids[type], args};
5427}
5428
5429NVVM::IDArgPair DotAccumulate2WayOp::getIntrinsicIDAndArgs(
5430 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
5431 auto curOp = cast<NVVM::DotAccumulate2WayOp>(op);
5432
5434 args.push_back(getAsPackedI32(mt.lookupValue(curOp.getA()), builder));
5435 args.push_back(getAsPackedI32(mt.lookupValue(curOp.getB()), builder));
5436 args.push_back(builder.getInt1(curOp.getBHi()));
5437 args.push_back(mt.lookupValue(curOp.getC()));
5438
5439 bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
5440 bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
5441 unsigned type = (isASigned << 1) | isBSigned;
5442 const llvm::Intrinsic::ID ids[] = {
5443 llvm::Intrinsic::nvvm_idp2a_u_u,
5444 llvm::Intrinsic::nvvm_idp2a_u_s,
5445 llvm::Intrinsic::nvvm_idp2a_s_u,
5446 llvm::Intrinsic::nvvm_idp2a_s_s,
5447 };
5448 return {ids[type], args};
5449}
5450
5451static llvm::Value *getParamCastedAddr(llvm::Value *addr,
5452 llvm::IRBuilderBase &builder) {
5453 return builder.CreateAddrSpaceCast(
5454 addr, builder.getPtrTy(llvm::NVPTXAS::ADDRESS_SPACE_ENTRY_PARAM));
5455}
5456
5458PrefetchOp::getIntrinsicIDAndArgs(NVVM::PrefetchOp &op,
5460 llvm::IRBuilderBase &builder) {
5461 using MemSpace = NVVM::NVVMMemorySpace;
5462 using CacheLevel = NVVM::PrefetchCacheLevel;
5463
5464 std::optional<NVVM::PrefetchCacheLevel> cacheLevel = op.getCacheLevel();
5465 std::optional<NVVM::CacheEvictionPriority> evictPriority =
5466 op.getEvictPriority();
5467 unsigned addressSpace =
5468 llvm::cast<LLVM::LLVMPointerType>(op.getAddr().getType())
5469 .getAddressSpace();
5470
5472 llvm::Value *addr = mt.lookupValue(op.getAddr());
5473 args.push_back(op.getInParamSpace() ? getParamCastedAddr(addr, builder)
5474 : addr);
5475
5476 if (op.getTensormap())
5477 return {llvm::Intrinsic::nvvm_prefetch_tensormap, args};
5478
5479 assert(cacheLevel && "expected cache level for non-tensormap prefetch");
5480
5481 if (op.getUniform() && *cacheLevel == CacheLevel::L1)
5482 return {llvm::Intrinsic::nvvm_prefetchu_L1, args};
5483
5484 if (evictPriority && *cacheLevel == CacheLevel::L2) {
5485 switch (*evictPriority) {
5486 case NVVM::CacheEvictionPriority::EvictLast:
5487 return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_last, args};
5488 case NVVM::CacheEvictionPriority::EvictNormal:
5489 return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_normal, args};
5490 default:
5491 llvm_unreachable("Invalid cache eviction priority");
5492 }
5493 }
5494
5495 switch (static_cast<MemSpace>(addressSpace)) {
5496 case MemSpace::Generic:
5497 return *cacheLevel == CacheLevel::L1
5498 ? NVVM::IDArgPair({llvm::Intrinsic::nvvm_prefetch_L1, args})
5499 : NVVM::IDArgPair({llvm::Intrinsic::nvvm_prefetch_L2, args});
5500 case MemSpace::Global:
5501 return *cacheLevel == CacheLevel::L1
5503 {llvm::Intrinsic::nvvm_prefetch_global_L1, args})
5504 : NVVM::IDArgPair(
5505 {llvm::Intrinsic::nvvm_prefetch_global_L2, args});
5506 case MemSpace::Local:
5507 return *cacheLevel == CacheLevel::L1
5509 {llvm::Intrinsic::nvvm_prefetch_local_L1, args})
5510 : NVVM::IDArgPair(
5511 {llvm::Intrinsic::nvvm_prefetch_local_L2, args});
5512 default:
5513 llvm_unreachable("Invalid pointer address space");
5514 }
5515}
5516
5517bool NVVM::InlinePtxOp::getAsmValues(
5518 RewriterBase &rewriter,
5519 llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
5520 &asmValues) {
5521 for (auto arg : getReadWriteArgs())
5522 asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::ReadWrite});
5523 for (auto arg : getResults())
5524 asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::Write});
5525 for (auto arg : getReadOnlyArgs())
5526 asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::Read});
5527 if (getPredicate())
5528 asmValues.push_back({getPredicate(), mlir::NVVM::PTXRegisterMod::Read});
5529 return false; // No manual mapping needed
5530}
5531
5532NVVM::IDArgPair ClusterLaunchControlTryCancelOp::getIntrinsicIDAndArgs(
5533 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
5534 auto curOp = cast<NVVM::ClusterLaunchControlTryCancelOp>(op);
5536 args.push_back(mt.lookupValue(curOp.getSmemAddress()));
5537 args.push_back(mt.lookupValue(curOp.getMbarrier()));
5538
5539 llvm::Intrinsic::ID intrinsicID =
5540 curOp.getMulticast()
5541 ? llvm::Intrinsic::
5542 nvvm_clusterlaunchcontrol_try_cancel_async_multicast_shared
5543 : llvm::Intrinsic::nvvm_clusterlaunchcontrol_try_cancel_async_shared;
5544
5545 return {intrinsicID, args};
5546}
5547
5548NVVM::IDArgPair ClusterLaunchControlQueryCancelOp::getIntrinsicIDAndArgs(
5549 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
5550 auto curOp = cast<NVVM::ClusterLaunchControlQueryCancelOp>(op);
5552 args.push_back(mt.lookupValue(curOp.getTryCancelResponse()));
5553
5554 llvm::Intrinsic::ID intrinsicID;
5555
5556 switch (curOp.getQueryType()) {
5557 case NVVM::ClusterLaunchControlQueryType::IS_CANCELED:
5558 intrinsicID =
5559 llvm::Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_is_canceled;
5560 break;
5561 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_X:
5562 intrinsicID = llvm::Intrinsic::
5563 nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_x;
5564 break;
5565 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Y:
5566 intrinsicID = llvm::Intrinsic::
5567 nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_y;
5568 break;
5569 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Z:
5570 intrinsicID = llvm::Intrinsic::
5571 nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_z;
5572 break;
5573 }
5574 return {intrinsicID, args};
5575}
5576
5578PermuteOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
5579 llvm::IRBuilderBase &builder) {
5580 auto thisOp = cast<NVVM::PermuteOp>(op);
5581 NVVM::PermuteMode mode = thisOp.getMode();
5582
5583 static constexpr llvm::Intrinsic::ID IDs[] = {
5584 llvm::Intrinsic::nvvm_prmt, llvm::Intrinsic::nvvm_prmt_f4e,
5585 llvm::Intrinsic::nvvm_prmt_b4e, llvm::Intrinsic::nvvm_prmt_rc8,
5586 llvm::Intrinsic::nvvm_prmt_ecl, llvm::Intrinsic::nvvm_prmt_ecr,
5587 llvm::Intrinsic::nvvm_prmt_rc16};
5588
5589 unsigned modeIndex = static_cast<unsigned>(mode);
5591 args.push_back(mt.lookupValue(thisOp.getLo()));
5592
5593 // Only first 3 modes (Default, f4e, b4e) need the hi operand.
5594 if (modeIndex < 3)
5595 args.push_back(mt.lookupValue(thisOp.getHi()));
5596
5597 args.push_back(mt.lookupValue(thisOp.getSelector()));
5598
5599 return {IDs[modeIndex], args};
5600}
5601
5602mlir::NVVM::IDArgPair TensormapReplaceOp::getIntrinsicIDAndArgs(
5603 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
5604 auto thisOp = cast<NVVM::TensormapReplaceOp>(op);
5605
5607 args.push_back(mt.lookupValue(thisOp.getAddr()));
5608 if (thisOp.getOrd())
5609 args.push_back(builder.getInt32(thisOp.getOrd().value()));
5610 if (thisOp.getNewValue())
5611 args.push_back(mt.lookupValue(thisOp.getNewValue()));
5612 if (auto attr = thisOp.getNewValueAttr()) {
5613 auto val =
5615 .Case<TensormapElemtypeAttr, TensormapInterleaveLayoutAttr,
5616 TensormapSwizzleModeAttr, TensormapSwizzleAtomicityAttr,
5617 TensormapFillModeAttr>([](auto attr) {
5618 return static_cast<unsigned>(attr.getValue());
5619 })
5620 .Default([](auto attr) {
5621 llvm_unreachable("Invalid attribute type");
5622 return 0;
5623 });
5624 args.push_back(builder.getInt32(val));
5625 }
5626
5627 static constexpr llvm::Intrinsic::ID IDs[] = {
5628 llvm::Intrinsic::nvvm_tensormap_replace_global_address,
5629 llvm::Intrinsic::nvvm_tensormap_replace_rank,
5630 llvm::Intrinsic::nvvm_tensormap_replace_box_dim,
5631 llvm::Intrinsic::nvvm_tensormap_replace_global_dim,
5632 llvm::Intrinsic::nvvm_tensormap_replace_global_stride,
5633 llvm::Intrinsic::nvvm_tensormap_replace_element_stride,
5634 llvm::Intrinsic::nvvm_tensormap_replace_elemtype,
5635 llvm::Intrinsic::nvvm_tensormap_replace_interleave_layout,
5636 llvm::Intrinsic::nvvm_tensormap_replace_swizzle_mode,
5637 llvm::Intrinsic::nvvm_tensormap_replace_swizzle_atomicity,
5638 llvm::Intrinsic::nvvm_tensormap_replace_fill_mode,
5639 };
5640
5641 unsigned fieldIndex = static_cast<unsigned>(thisOp.getField());
5642
5643 return {IDs[fieldIndex], args};
5644}
5645
5646//===----------------------------------------------------------------------===//
5647// NVVM tcgen05.mma functions
5648//===----------------------------------------------------------------------===//
5649
5651Tcgen05MMAOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
5652 llvm::IRBuilderBase &builder) {
5653
5654 auto thisOp = cast<NVVM::Tcgen05MMAOp>(op);
5656
5657 args.push_back(mt.lookupValue(thisOp.getMatrixD()));
5658
5659 llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
5660 const bool isATensor = isa<llvm::PointerType>(A->getType());
5661 args.push_back(A);
5662
5663 args.push_back(mt.lookupValue(thisOp.getMatrixB()));
5664 args.push_back(mt.lookupValue(thisOp.getIdesc()));
5665 args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
5666
5667 using EnableAShiftArray = std::array<llvm::Intrinsic::ID, 2>;
5668 using CtaGroupArray = std::array<EnableAShiftArray, 2>;
5669 using IsATensorArray = std::array<CtaGroupArray, 2>;
5670 using HasScaleInputDArray = std::array<IsATensorArray, 2>;
5671 using HasDisableOutputLaneArray = std::array<HasScaleInputDArray, 2>;
5672
5673 // [hasDisableOutputLane][hasScaleInputD][isATensor][CtaGroup][EnableAShift]
5674 static constexpr HasDisableOutputLaneArray tcgen05MMAIDs = {
5675 { // without diable output lane
5676 {{// without scale input D
5677 {{
5678 // shared
5679 {{// cg1
5680 {llvm::Intrinsic::nvvm_tcgen05_mma_shared, notIntrinsic},
5681 // cg2
5682 {llvm::Intrinsic::nvvm_tcgen05_mma_shared, notIntrinsic}}},
5683 {{// tensor
5684 {
5685 // cg1
5686 llvm::Intrinsic::nvvm_tcgen05_mma_tensor,
5687 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_ashift,
5688 },
5689 {
5690 // cg2
5691 llvm::Intrinsic::nvvm_tcgen05_mma_tensor,
5692 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_ashift,
5693 }}},
5694 }},
5695 // with scale input D
5696 {{ // shared
5697 {{// cg1
5698 {llvm::Intrinsic::nvvm_tcgen05_mma_shared_scale_d, notIntrinsic},
5699 // cg2
5700 {llvm::Intrinsic::nvvm_tcgen05_mma_shared_scale_d, notIntrinsic}}},
5701 {{// tensor
5702 {
5703 // cg1
5704 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d,
5705 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_ashift,
5706 },
5707 {
5708 // cg2
5709 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d,
5710 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_ashift,
5711 }}}}}}},
5712 // with disable output lane
5713 {{ // without scale input D
5714 {{ // shared
5715 {{// cg1
5716 {llvm::Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg1,
5717 notIntrinsic},
5718 // cg2
5719 {llvm::Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg2,
5720 notIntrinsic}}},
5721 {{// cg1
5722 {
5723 llvm::Intrinsic::
5724 nvvm_tcgen05_mma_tensor_disable_output_lane_cg1,
5725 llvm::Intrinsic::
5726 nvvm_tcgen05_mma_tensor_disable_output_lane_cg1_ashift,
5727 },
5728 // cg2
5729 {
5730 llvm::Intrinsic::
5731 nvvm_tcgen05_mma_tensor_disable_output_lane_cg2,
5732 llvm::Intrinsic::
5733 nvvm_tcgen05_mma_tensor_disable_output_lane_cg2_ashift,
5734 }}}}},
5735 // with scale input D
5736 {{ // shared
5737 {{// cg1
5738 {llvm::Intrinsic::
5739 nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg1,
5740 notIntrinsic},
5741 // cg2
5742 {llvm::Intrinsic::
5743 nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg2,
5744 notIntrinsic}}},
5745 // tensor
5746 {{// cg1
5747 {llvm::Intrinsic::
5748 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1,
5749 llvm::Intrinsic::
5750 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1_ashift},
5751 // cg2
5752 {
5753 llvm::Intrinsic::
5754 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2,
5755 llvm::Intrinsic::
5756 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2_ashift,
5757 }}}}}}}}};
5758
5759 llvm::Value *ScaleInputD = mt.lookupValue(thisOp.getScaleInputD());
5760 bool hasScaleInputD = ScaleInputD != nullptr;
5761
5762 llvm::Value *DisableOutputLane =
5763 mt.lookupValue(thisOp.getDisableOutputLane());
5764 bool hasDisableOutputLane = DisableOutputLane != nullptr;
5765
5766 const unsigned ctaGroup =
5767 static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup()));
5768
5769 llvm::Intrinsic::ID ID =
5770 tcgen05MMAIDs[hasDisableOutputLane][hasScaleInputD][isATensor]
5771 [ctaGroup - 1][thisOp.getAShift()];
5772
5773 assert(ID != notIntrinsic && "Invalid intrinsic for Tcgen05MMAOp.");
5774
5775 if (hasScaleInputD)
5776 args.push_back(ScaleInputD);
5777
5778 if (hasDisableOutputLane)
5779 args.push_back(DisableOutputLane);
5780
5781 args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
5782
5783 if (!hasDisableOutputLane)
5784 args.push_back(builder.getInt32(ctaGroup));
5785
5786 args.push_back(
5787 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
5788
5789 return {ID, args};
5790}
5791
5792static LogicalResult
5793verifyTcgen05MMAOp(bool isATensor, mlir::Value disableOutputLane,
5794 NVVM::CTAGroupKind ctaGroup, bool hasAShift,
5795 NVVM::Tcgen05MMACollectorOp collectorOp, Location loc) {
5796
5797 if (disableOutputLane) {
5798 mlir::VectorType disableOutputLaneType =
5799 cast<mlir::VectorType>(disableOutputLane.getType());
5800 if ((ctaGroup == NVVM::CTAGroupKind::CTA_1 &&
5801 disableOutputLaneType.getNumElements() != 4) ||
5802 (ctaGroup == NVVM::CTAGroupKind::CTA_2 &&
5803 disableOutputLaneType.getNumElements() != 8))
5804 return emitError(loc) << "Disable Output Lane of length "
5805 << disableOutputLaneType.getNumElements()
5806 << " is incompatible with CtaGroupAttr";
5807 }
5808
5809 if (hasAShift && !isATensor)
5810 return emitError(
5811 loc, "A-shift can be applied only when matrix A is in tensor memory");
5812
5813 if (hasAShift == true && (collectorOp == Tcgen05MMACollectorOp::FILL ||
5814 collectorOp == Tcgen05MMACollectorOp::USE))
5815 return emitError(
5816 loc, "Cannot use collector buffer operation fill or use with ashift");
5817
5818 return success();
5819}
5820
5821LogicalResult Tcgen05MMAOp::verify() {
5822 return verifyTcgen05MMAOp(isa<LLVM::LLVMPointerType>(getMatrixA().getType()),
5823 getDisableOutputLane(), getCtaGroup(), getAShift(),
5824 getCollectorOp(), getLoc());
5825}
5826
5827//===----------------------------------------------------------------------===//
5828// NVVM tcgen05.mma.sp functions
5829//===----------------------------------------------------------------------===//
5830
5831mlir::NVVM::IDArgPair Tcgen05MMASparseOp::getIntrinsicIDAndArgs(
5832 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
5833
5834 auto thisOp = cast<NVVM::Tcgen05MMASparseOp>(op);
5836
5837 args.push_back(mt.lookupValue(thisOp.getMatrixD()));
5838
5839 llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
5840 bool isATensor = isa<llvm::PointerType>(A->getType());
5841 args.push_back(A);
5842
5843 args.push_back(mt.lookupValue(thisOp.getMatrixB()));
5844 args.push_back(mt.lookupValue(thisOp.getIdesc()));
5845 args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
5846 args.push_back(mt.lookupValue(thisOp.getSparseMetadata()));
5847
5848 using EnableAShiftArray = std::array<llvm::Intrinsic::ID, 2>;
5849 using CtaGroupArray = std::array<EnableAShiftArray, 2>;
5850 using IsATensorArray = std::array<CtaGroupArray, 2>;
5851 using HasScaleInputDArray = std::array<IsATensorArray, 2>;
5852 using HasDisableOutputLaneArray = std::array<HasScaleInputDArray, 2>;
5853
5854 // [hasDisableOutputLane][hasScaleInputD][isATensor][CtaGroup][EnableAShift]
5855 static constexpr HasDisableOutputLaneArray tcgen05MMASparseIDs = {
5856 { // without diable output lane
5857 {{// without scale input D
5858 {{
5859 // shared
5860 {{// cg1
5861 {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared, notIntrinsic},
5862 // cg2
5863 {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared, notIntrinsic}}},
5864 {{// tensor
5865 {
5866 // cg1
5867 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor,
5868 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_ashift,
5869 },
5870 {
5871 // cg2
5872 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor,
5873 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_ashift,
5874 }}},
5875 }},
5876 // with scale input D
5877 {{ // shared
5878 {{// cg1
5879 {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d,
5880 notIntrinsic},
5881 // cg2
5882 {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d,
5883 notIntrinsic}}},
5884 {{// tensor
5885 {
5886 // cg1
5887 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d,
5888 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_ashift,
5889 },
5890 {
5891 // cg2
5892 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d,
5893 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_ashift,
5894 }}}}}}},
5895 // with disable output lane
5896 {{ // without scale input D
5897 {{ // shared
5898 {{// cg1
5899 {llvm::Intrinsic::
5900 nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg1,
5901 notIntrinsic},
5902 // cg2
5903 {llvm::Intrinsic::
5904 nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg2,
5905 notIntrinsic}}},
5906 {{// cg1
5907 {
5908 llvm::Intrinsic::
5909 nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1,
5910 llvm::Intrinsic::
5911 nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1_ashift,
5912 },
5913 // cg2
5914 {
5915 llvm::Intrinsic::
5916 nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2,
5917 llvm::Intrinsic::
5918 nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2_ashift,
5919 }}}}},
5920 // with scale input D
5921 {{ // shared
5922 {{// cg1
5923 {llvm::Intrinsic::
5924 nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg1,
5925 notIntrinsic},
5926 // cg2
5927 {llvm::Intrinsic::
5928 nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg2,
5929 notIntrinsic}}},
5930 // tensor
5931 {{// cg1
5932 {llvm::Intrinsic::
5933 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1,
5934 llvm::Intrinsic::
5935 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1_ashift},
5936 // cg2
5937 {
5938 llvm::Intrinsic::
5939 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2,
5940 llvm::Intrinsic::
5941 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2_ashift,
5942 }}}}}}}}};
5943
5944 llvm::Value *ScaleInputD = mt.lookupValue(thisOp.getScaleInputD());
5945 bool hasScaleInputD = ScaleInputD != nullptr;
5946
5947 llvm::Value *DisableOutputLane =
5948 mt.lookupValue(thisOp.getDisableOutputLane());
5949 bool hasDisableOutputLane = DisableOutputLane != nullptr;
5950
5951 unsigned ctaGroup =
5952 static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup()));
5953
5954 llvm::Intrinsic::ID ID =
5955 tcgen05MMASparseIDs[hasDisableOutputLane][hasScaleInputD][isATensor]
5956 [ctaGroup - 1][thisOp.getAShift()];
5957
5958 assert(ID != notIntrinsic && "Invalid intrinsic for Tcgen05MMASparseOp.");
5959
5960 if (hasScaleInputD)
5961 args.push_back(ScaleInputD);
5962
5963 if (hasDisableOutputLane)
5964 args.push_back(DisableOutputLane);
5965
5966 args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
5967
5968 if (!hasDisableOutputLane)
5969 args.push_back(builder.getInt32(ctaGroup));
5970
5971 args.push_back(
5972 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
5973
5974 return {ID, args};
5975}
5976
5977LogicalResult Tcgen05MMASparseOp::verify() {
5978 return verifyTcgen05MMAOp(isa<LLVM::LLVMPointerType>(getMatrixA().getType()),
5979 getDisableOutputLane(), getCtaGroup(), getAShift(),
5980 getCollectorOp(), getLoc());
5981}
5982
5983//===----------------------------------------------------------------------===//
5984// NVVM tcgen05.mma.block_scale functions
5985//===----------------------------------------------------------------------===//
5986
// Builds the LLVM intrinsic ID and call arguments for a
// Tcgen05MMABlockScaleOp.
//
// Arguments are appended in this order: matrix D, matrix A, matrix B, idesc,
// enable-input-D, scale A, scale B, the CTA group as an i32 constant, and the
// collector op as an i32 constant. The intrinsic is chosen from the op's kind
// and block-scale attributes; the `tensor` variant is used when the translated
// matrix A has an LLVM pointer type and the `shared` variant otherwise.
//
// NOTE(review): `args` is used below but its declaration is not visible in
// this listing (listing line 5991 is absent) -- confirm against the source.
5987mlir::NVVM::IDArgPair Tcgen05MMABlockScaleOp::getIntrinsicIDAndArgs(
5988 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
5989
5990 auto thisOp = cast<NVVM::Tcgen05MMABlockScaleOp>(op);
5992
5993 args.push_back(mt.lookupValue(thisOp.getMatrixD()));
5994
 // The LLVM type of the translated A operand decides tensor vs. shared.
5995 llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
5996 bool isATensor = isa<llvm::PointerType>(A->getType());
5997 args.push_back(A);
5998
5999 args.push_back(mt.lookupValue(thisOp.getMatrixB()));
6000 args.push_back(mt.lookupValue(thisOp.getIdesc()));
6001 args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
6002 args.push_back(mt.lookupValue(thisOp.getScaleA()));
6003 args.push_back(mt.lookupValue(thisOp.getScaleB()));
6004 args.push_back(builder.getInt32(
6005 static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup()))));
6006 args.push_back(
6007 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
6008
 // Any (kind, blockScale) pair without an explicit return in the lambda
 // reaches the llvm_unreachable at its end.
6009 auto kind = thisOp.getKind();
6010 auto blockScale = thisOp.getBlockScale();
6011 llvm::Intrinsic::ID ID = [&]() {
6012 if (kind == NVVM::Tcgen05MMAKind::MXF8F6F4) {
6013 if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
6014 return isATensor ? llvm::Intrinsic::
6015 nvvm_tcgen05_mma_tensor_mxf8f6f4_block_scale
6016 : llvm::Intrinsic::
6017 nvvm_tcgen05_mma_shared_mxf8f6f4_block_scale;
6018 } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
6019 return isATensor
6020 ? llvm::Intrinsic::
6021 nvvm_tcgen05_mma_tensor_mxf8f6f4_block_scale_block32
6022 : llvm::Intrinsic::
6023 nvvm_tcgen05_mma_shared_mxf8f6f4_block_scale_block32;
6024 }
6025 } else if (kind == NVVM::Tcgen05MMAKind::MXF4) {
6026 if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
6027 return isATensor
6028 ? llvm::Intrinsic::nvvm_tcgen05_mma_tensor_mxf4_block_scale
6029 : llvm::Intrinsic::nvvm_tcgen05_mma_shared_mxf4_block_scale;
6030 } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
6031 return isATensor ? llvm::Intrinsic::
6032 nvvm_tcgen05_mma_tensor_mxf4_block_scale_block32
6033 : llvm::Intrinsic::
6034 nvvm_tcgen05_mma_shared_mxf4_block_scale_block32;
6035 }
6036 } else if (kind == NVVM::Tcgen05MMAKind::MXF4NVF4) {
6037 if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
6038 return isATensor
6039 ? llvm::Intrinsic::
6040 nvvm_tcgen05_mma_tensor_mxf4nvf4_block_scale_block32
6041 : llvm::Intrinsic::
6042 nvvm_tcgen05_mma_shared_mxf4nvf4_block_scale_block32;
6043
6044 } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16) {
6045 return isATensor
6046 ? llvm::Intrinsic::
6047 nvvm_tcgen05_mma_tensor_mxf4nvf4_block_scale_block16
6048 : llvm::Intrinsic::
6049 nvvm_tcgen05_mma_shared_mxf4nvf4_block_scale_block16;
6050 }
6051 }
6052 llvm_unreachable("Invalid tcgen05.mma.block_scale attributes");
6053 }();
6054
6055 return {ID, args};
6056}
6057
6058static LogicalResult verifyTcgen05MMABlockScaleOp(
6059 NVVM::Tcgen05MMACollectorOp collectorOp, NVVM::Tcgen05MMAKind kind,
6060 NVVM::Tcgen05MMABlockScale blockScale, Location loc) {
6061 if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT &&
6062 kind == NVVM::Tcgen05MMAKind::MXF4NVF4)
6063 return emitError(loc, "mxf4nvf4 requires block scale attribute");
6064
6065 if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16 &&
6066 kind != NVVM::Tcgen05MMAKind::MXF4NVF4)
6067 return emitError(loc,
6068 llvm::formatv("{} kind does not support block16 attribute",
6069 stringifyEnum(kind)));
6070
6071 return success();
6072}
6073
6074LogicalResult Tcgen05MMABlockScaleOp::verify() {
6075 return verifyTcgen05MMABlockScaleOp(getCollectorOp(), getKind(),
6076 getBlockScale(), getLoc());
6077}
6078
6079//===----------------------------------------------------------------------===//
6080// NVVM tcgen05.mma.sp.block_scale functions
6081//===----------------------------------------------------------------------===//
6082
// Builds the LLVM intrinsic ID and call arguments for a
// Tcgen05MMASparseBlockScaleOp.
//
// Arguments are appended in this order: matrix D, matrix A, matrix B, idesc,
// enable-input-D, sparse metadata, scale A, scale B, the CTA group as an i32
// constant, and the collector op as an i32 constant. The intrinsic is chosen
// from the op's kind and block-scale attributes; the `tensor` variant is used
// when the translated matrix A has an LLVM pointer type and the `shared`
// variant otherwise.
//
// NOTE(review): `args` is used below but its declaration is not visible in
// this listing (listing line 6087 is absent) -- confirm against the source.
6083mlir::NVVM::IDArgPair Tcgen05MMASparseBlockScaleOp::getIntrinsicIDAndArgs(
6084 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
6085
6086 auto thisOp = cast<NVVM::Tcgen05MMASparseBlockScaleOp>(op);
6088
6089 args.push_back(mt.lookupValue(thisOp.getMatrixD()));
6090
 // The LLVM type of the translated A operand decides tensor vs. shared.
6091 llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
6092 bool isATensor = isa<llvm::PointerType>(A->getType());
6093 args.push_back(A);
6094
6095 args.push_back(mt.lookupValue(thisOp.getMatrixB()));
6096 args.push_back(mt.lookupValue(thisOp.getIdesc()));
6097 args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
6098 args.push_back(mt.lookupValue(thisOp.getSparseMetadata()));
6099 args.push_back(mt.lookupValue(thisOp.getScaleA()));
6100 args.push_back(mt.lookupValue(thisOp.getScaleB()));
6101 args.push_back(builder.getInt32(
6102 static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup()))));
6103 args.push_back(
6104 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
6105
 // Any (kind, blockScale) pair without an explicit return in the lambda
 // reaches the llvm_unreachable at its end.
6106 auto kind = thisOp.getKind();
6107 auto blockScale = thisOp.getBlockScale();
6108 llvm::Intrinsic::ID ID = [&]() {
6109 if (kind == NVVM::Tcgen05MMAKind::MXF8F6F4) {
6110 if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
6111 return isATensor ? llvm::Intrinsic::
6112 nvvm_tcgen05_mma_sp_tensor_mxf8f6f4_block_scale
6113 : llvm::Intrinsic::
6114 nvvm_tcgen05_mma_sp_shared_mxf8f6f4_block_scale;
6115 } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
6116 return isATensor
6117 ? llvm::Intrinsic::
6118 nvvm_tcgen05_mma_sp_tensor_mxf8f6f4_block_scale_block32
6119 : llvm::Intrinsic::
6120 nvvm_tcgen05_mma_sp_shared_mxf8f6f4_block_scale_block32;
6121 }
6122 } else if (kind == NVVM::Tcgen05MMAKind::MXF4) {
6123 if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
6124 return isATensor ? llvm::Intrinsic::
6125 nvvm_tcgen05_mma_sp_tensor_mxf4_block_scale
6126 : llvm::Intrinsic::
6127 nvvm_tcgen05_mma_sp_shared_mxf4_block_scale;
6128 } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
6129 return isATensor
6130 ? llvm::Intrinsic::
6131 nvvm_tcgen05_mma_sp_tensor_mxf4_block_scale_block32
6132 : llvm::Intrinsic::
6133 nvvm_tcgen05_mma_sp_shared_mxf4_block_scale_block32;
6134 }
6135 } else if (kind == NVVM::Tcgen05MMAKind::MXF4NVF4) {
6136 if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
6137 return isATensor
6138 ? llvm::Intrinsic::
6139 nvvm_tcgen05_mma_sp_tensor_mxf4nvf4_block_scale_block32
6140 : llvm::Intrinsic::
6141 nvvm_tcgen05_mma_sp_shared_mxf4nvf4_block_scale_block32;
6142
6143 } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16) {
6144 return isATensor
6145 ? llvm::Intrinsic::
6146 nvvm_tcgen05_mma_sp_tensor_mxf4nvf4_block_scale_block16
6147 : llvm::Intrinsic::
6148 nvvm_tcgen05_mma_sp_shared_mxf4nvf4_block_scale_block16;
6149 }
6150 }
6151 llvm_unreachable("Invalid tcgen05.mma.sp.block_scale attributes");
6152 }();
6153
6154 return {ID, args};
6155}
6156
6157LogicalResult Tcgen05MMASparseBlockScaleOp::verify() {
6158 return verifyTcgen05MMABlockScaleOp(getCollectorOp(), getKind(),
6159 getBlockScale(), getLoc());
6160}
6161
6162//===----------------------------------------------------------------------===//
6163// NVVM tcgen05.mma.ws functions
6164//===----------------------------------------------------------------------===//
6165
// Builds the LLVM intrinsic ID and call arguments for a Tcgen05MMAWsOp.
//
// Arguments are appended in this order: matrix D, matrix A, matrix B, idesc,
// enable-input-D, the zero-column mask (only when the op has one), then the
// kind, collector B buffer and collector op as i32 constants. The intrinsic
// depends on whether a zero-column mask is present and on whether the
// translated matrix A has an LLVM pointer type (`tensor`) or not (`shared`).
//
// NOTE(review): `args` is used below but its declaration is not visible in
// this listing (listing line 6170 is absent) -- confirm against the source.
6166mlir::NVVM::IDArgPair Tcgen05MMAWsOp::getIntrinsicIDAndArgs(
6167 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
6168
6169 auto thisOp = cast<NVVM::Tcgen05MMAWsOp>(op);
6171
6172 args.push_back(mt.lookupValue(thisOp.getMatrixD()));
6173
6174 llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
6175 bool isATensor = isa<llvm::PointerType>(A->getType());
6176 args.push_back(A);
6177
6178 args.push_back(mt.lookupValue(thisOp.getMatrixB()));
6179 args.push_back(mt.lookupValue(thisOp.getIdesc()));
6180 args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
6181
 // The mask is tested on the MLIR value; a null value means the operand was
 // not provided, and then no mask argument is appended.
6182 mlir::Value ZeroColMask = thisOp.getZeroColMask();
6183 llvm::Intrinsic::ID ID = notIntrinsic;
6184 if (ZeroColMask) {
6185 args.push_back(mt.lookupValue(ZeroColMask));
6186 ID = isATensor ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_tensor_zero_col_mask
6187 : llvm::Intrinsic::nvvm_tcgen05_mma_ws_shared_zero_col_mask;
6188 } else
6189 ID = isATensor ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_tensor
6190 : llvm::Intrinsic::nvvm_tcgen05_mma_ws_shared;
6191
6192 args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
6193 args.push_back(
6194 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorBBuffer())));
6195 args.push_back(
6196 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
6197
6198 return {ID, args};
6199}
6200
6201//===----------------------------------------------------------------------===//
6202// NVVM tcgen05.mma.ws.sp functions
6203//===----------------------------------------------------------------------===//
6204
// Builds the LLVM intrinsic ID and call arguments for a Tcgen05MMAWsSparseOp.
//
// Arguments are appended in this order: matrix D, matrix A, matrix B, idesc,
// enable-input-D, sparse metadata, the zero-column mask (only when the op has
// one), then the kind, collector B buffer and collector op as i32 constants.
// The intrinsic depends on whether a zero-column mask is present and on
// whether the translated matrix A has an LLVM pointer type (`tensor`) or not
// (`shared`).
//
// NOTE(review): `args` is used below but its declaration is not visible in
// this listing (listing line 6209 is absent) -- confirm against the source.
6205mlir::NVVM::IDArgPair Tcgen05MMAWsSparseOp::getIntrinsicIDAndArgs(
6206 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
6207
6208 auto thisOp = cast<NVVM::Tcgen05MMAWsSparseOp>(op);
6210
6211 args.push_back(mt.lookupValue(thisOp.getMatrixD()));
6212
6213 llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
6214 bool isATensor = isa<llvm::PointerType>(A->getType());
6215 args.push_back(A);
6216
6217 args.push_back(mt.lookupValue(thisOp.getMatrixB()));
6218 args.push_back(mt.lookupValue(thisOp.getIdesc()));
6219 args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
6220 args.push_back(mt.lookupValue(thisOp.getSparseMetadata()));
6221
 // The mask is tested on the MLIR value; a null value means the operand was
 // not provided, and then no mask argument is appended.
6222 mlir::Value ZeroColMask = thisOp.getZeroColMask();
6223 llvm::Intrinsic::ID ID = notIntrinsic;
6224 if (ZeroColMask) {
6225 args.push_back(mt.lookupValue(ZeroColMask));
6226 ID = isATensor
6227 ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_tensor_zero_col_mask
6228 : llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_shared_zero_col_mask;
6229 } else
6230 ID = isATensor ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_tensor
6231 : llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_shared;
6232
6233 args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
6234 args.push_back(
6235 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorBBuffer())));
6236 args.push_back(
6237 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
6238
6239 return {ID, args};
6240}
6241
6242//===----------------------------------------------------------------------===//
6243// NVVM tcgen05.ld.red functions
6244//===----------------------------------------------------------------------===//
6245
// Pastes SHAPE, NUM and TYPE into the enumerator name
// llvm::Intrinsic::nvvm_tcgen05_ld_red_<SHAPE>_<NUM>_<TYPE>.
6246#define TCGEN05LDRED(SHAPE, NUM, TYPE) \
6247 llvm::Intrinsic::nvvm_tcgen05_ld_red_##SHAPE##_##NUM##_##TYPE
6248
// Builds the LLVM intrinsic ID and call arguments for a Tcgen05LdRedOp.
//
// The intrinsic is looked up in a per-shape table indexed by
// [log2(number of elements in the data vector)][reduction value is f32].
// Arguments are appended in this order: the address, the offset (only for the
// 16x32bx2 shape), an i32 reduction selector, and, for f32 only, the abs and
// nan flags as i1 constants.
//
// NOTE(review): three lines of this function are not visible in this listing
// (listing lines 6252, 6260 and 6271 are absent): the declaration of `args`
// and the first row of each table. The visible rows start at x2 while the
// index is log2(Num), so presumably a row for index 0 precedes them --
// confirm against the source.
6249mlir::NVVM::IDArgPair NVVM::Tcgen05LdRedOp::getIntrinsicIDAndArgs(
6250 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
6251 auto thisOp = cast<NVVM::Tcgen05LdRedOp>(op);
6253
6254 mlir::VectorType VecResTy =
6255 cast<mlir::VectorType>(thisOp.getData().getType());
6256 unsigned Num = VecResTy.getNumElements();
6257 bool IsFloat = thisOp.getRedVal().getType().isF32();
6258
 // Each row is {i32 intrinsic, f32 intrinsic} for one vector length.
6259 llvm::Intrinsic::ID Shape32x32b[][2] = {
6261 {TCGEN05LDRED(32x32b, x2, i32), TCGEN05LDRED(32x32b, x2, f32)},
6262 {TCGEN05LDRED(32x32b, x4, i32), TCGEN05LDRED(32x32b, x4, f32)},
6263 {TCGEN05LDRED(32x32b, x8, i32), TCGEN05LDRED(32x32b, x8, f32)},
6264 {TCGEN05LDRED(32x32b, x16, i32), TCGEN05LDRED(32x32b, x16, f32)},
6265 {TCGEN05LDRED(32x32b, x32, i32), TCGEN05LDRED(32x32b, x32, f32)},
6266 {TCGEN05LDRED(32x32b, x64, i32), TCGEN05LDRED(32x32b, x64, f32)},
6267 {TCGEN05LDRED(32x32b, x128, i32), TCGEN05LDRED(32x32b, x128, f32)},
6268 };
6269
6270 llvm::Intrinsic::ID Shape16x32bx2[][2] = {
6272 {TCGEN05LDRED(16x32bx2, x2, i32), TCGEN05LDRED(16x32bx2, x2, f32)},
6273 {TCGEN05LDRED(16x32bx2, x4, i32), TCGEN05LDRED(16x32bx2, x4, f32)},
6274 {TCGEN05LDRED(16x32bx2, x8, i32), TCGEN05LDRED(16x32bx2, x8, f32)},
6275 {TCGEN05LDRED(16x32bx2, x16, i32), TCGEN05LDRED(16x32bx2, x16, f32)},
6276 {TCGEN05LDRED(16x32bx2, x32, i32), TCGEN05LDRED(16x32bx2, x32, f32)},
6277 {TCGEN05LDRED(16x32bx2, x64, i32), TCGEN05LDRED(16x32bx2, x64, f32)},
6278 {TCGEN05LDRED(16x32bx2, x128, i32), TCGEN05LDRED(16x32bx2, x128, f32)},
6279 };
6280
6281 NVVM::Tcgen05LdStShape shape = thisOp.getShape();
6282 unsigned ID = [&]() {
6283 // `num` contains the length of vector and log2 of `num` returns the index
6284 // into the shape array
 // NOTE(review): the std::log2 result is converted to unsigned, so this
 // presumably relies on Num being a power of two -- confirm that the op's
 // type constraints guarantee it.
6285 unsigned idx = std::log2(Num);
6286 switch (shape) {
6287 case NVVM::Tcgen05LdStShape::SHAPE_32X32B:
6288 return Shape32x32b[idx][IsFloat];
6289 case NVVM::Tcgen05LdStShape::SHAPE_16X32BX2:
6290 return Shape16x32bx2[idx][IsFloat];
6291 default:
6292 llvm_unreachable("unhandled tcgen05.ld lowering");
6293 }
6294 }();
6295
6296 args.push_back(mt.lookupValue(thisOp.getAddr()));
6297
 // Only the 16x32bx2 shape passes the offset operand.
6298 if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2)
6299 args.push_back(mt.lookupValue(thisOp.getOffset()));
6300
 // Reduction selector: 0 for MIN, 1 for every other reduction kind.
6301 args.push_back(
6302 builder.getInt32(thisOp.getOp() == NVVM::ReductionKind::MIN ? 0 : 1));
6303
 // The abs / nan flags are appended for the f32 variants only.
6304 if (IsFloat) {
6305 args.push_back(builder.getInt1(static_cast<unsigned>(thisOp.getAbs())));
6306 args.push_back(builder.getInt1(static_cast<unsigned>(thisOp.getNan())));
6307 }
6308 return {ID, args};
6309}
6310
6311LogicalResult Tcgen05LdRedOp::verify() {
6312 VectorType data = cast<VectorType>(getData().getType());
6313 Type redVal = getRedVal().getType();
6314
6315 if (data.getElementType() != redVal)
6316 return emitError(
6317 "type of reduction value and element type of vector data should match");
6318
6319 if (getOp() != NVVM::ReductionKind::MIN &&
6320 getOp() != NVVM::ReductionKind::MAX)
6321 return emitError("only min and max reduction kinds are supported");
6322
6323 if (redVal.isInteger() && (getAbs() || getNan())) {
6324 return emitError("abs or nan is only applicable for f32 type");
6325 }
6326 return success();
6327}
6328
6329//===----------------------------------------------------------------------===//
6330// NVVMDialect initialization, type parsing, and registration.
6331//===----------------------------------------------------------------------===//
6332
6333// TODO: This should be the llvm.nvvm dialect once this is supported.
// Registers the dialect's operations and attributes, allows unknown
// operations, and declares the interfaces that are promised for the dialect
// and for NVVMTargetAttr. The operation and attribute lists are expanded from
// the included .inc files, selected by the GET_OP_LIST / GET_ATTRDEF_LIST
// macros defined just before each include.
6334void NVVMDialect::initialize() {
6335 addOperations<
6336#define GET_OP_LIST
6337#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
6338 >();
6339 addAttributes<
6340#define GET_ATTRDEF_LIST
6341#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
6342 >();
6343
6344 // Support unknown operations because not all NVVM operations are
6345 // registered.
6346 allowUnknownOperations();
6347 declarePromisedInterface<ConvertToLLVMPatternInterface, NVVMDialect>();
6348 declarePromisedInterface<gpu::TargetAttrInterface, NVVMTargetAttr>();
6349}
6350
6351LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
6352 NamedAttribute attr) {
6353 StringAttr attrName = attr.getName();
6354 // Kernel function attribute should be attached to functions.
6355 if (attrName == NVVMDialect::getKernelFuncAttrName()) {
6356 if (!isa<LLVM::LLVMFuncOp>(op)) {
6357 return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName()
6358 << "' attribute attached to unexpected op";
6359 }
6360 }
6361 // If maxntid / reqntid / cluster_dim exist, it must be an array with max 3
6362 // dim
6363 if (attrName == NVVMDialect::getMaxntidAttrName() ||
6364 attrName == NVVMDialect::getReqntidAttrName() ||
6365 attrName == NVVMDialect::getClusterDimAttrName()) {
6366 auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue());
6367 if (!values || values.empty() || values.size() > 3) {
6368 return op->emitError()
6369 << "'" << attrName
6370 << "' attribute must be integer array with maximum 3 index";
6371 }
6372 }
6373 // If minctasm / maxnreg / cluster_max_blocks exist, it must be an integer
6374 // attribute
6375 if (attrName == NVVMDialect::getMinctasmAttrName() ||
6376 attrName == NVVMDialect::getMaxnregAttrName() ||
6377 attrName == NVVMDialect::getClusterMaxBlocksAttrName()) {
6378 if (!llvm::dyn_cast<IntegerAttr>(attr.getValue())) {
6379 return op->emitError()
6380 << "'" << attrName << "' attribute must be integer constant";
6381 }
6382 }
6383 // blocksareclusters must be used along with reqntid and cluster_dim
6384 if (attrName == NVVMDialect::getBlocksAreClustersAttrName()) {
6385 if (!op->hasAttr(NVVMDialect::getReqntidAttrName()) ||
6386 !op->hasAttr(NVVMDialect::getClusterDimAttrName())) {
6387 return op->emitError()
6388 << "'" << attrName << "' attribute must be used along with " << "'"
6389 << NVVMDialect::getReqntidAttrName() << "' and " << "'"
6390 << NVVMDialect::getClusterDimAttrName() << "'";
6391 }
6392 }
6393
6394 return success();
6395}
6396
6397LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
6398 unsigned regionIndex,
6399 unsigned argIndex,
6400 NamedAttribute argAttr) {
6401 auto funcOp = dyn_cast<FunctionOpInterface>(op);
6402 if (!funcOp)
6403 return success();
6404
6405 bool isKernel = op->hasAttr(NVVMDialect::getKernelFuncAttrName());
6406 StringAttr attrName = argAttr.getName();
6407 if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
6408 if (!isKernel) {
6409 return op->emitError()
6410 << "'" << attrName
6411 << "' attribute must be present only on kernel arguments";
6412 }
6413 if (!isa<UnitAttr>(argAttr.getValue()))
6414 return op->emitError() << "'" << attrName << "' must be a unit attribute";
6415 if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName())) {
6416 return op->emitError()
6417 << "'" << attrName
6418 << "' attribute requires the argument to also have attribute '"
6419 << LLVM::LLVMDialect::getByValAttrName() << "'";
6420 }
6421 }
6422
6423 return success();
6424}
6425
6426//===----------------------------------------------------------------------===//
6427// NVVM Address Space Attr
6428//===----------------------------------------------------------------------===//
6429
6430unsigned NVVMMemorySpaceAttr::getAddressSpace() const {
6431 return static_cast<unsigned>(getValue());
6432}
6433
// Forwards all arguments unchanged to LLVM::detail::isValidLoadStoreImpl and
// returns its result.
// NOTE(review): the last line of the parameter list, which must declare
// `emitError` and open the body, is not visible in this listing (listing line
// 6437 is absent) -- confirm against the source.
6434bool NVVMMemorySpaceAttr::isValidLoad(
6435 Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
6436 const ::mlir::DataLayout *dataLayout,
6438 return LLVM::detail::isValidLoadStoreImpl(type, ordering, alignment,
6439 dataLayout, emitError);
6440}
6441
// Forwards all arguments unchanged to LLVM::detail::isValidLoadStoreImpl and
// returns its result.
// NOTE(review): the last line of the parameter list, which must declare
// `emitError` and open the body, is not visible in this listing (listing line
// 6445 is absent) -- confirm against the source.
6442bool NVVMMemorySpaceAttr::isValidStore(
6443 Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
6444 const ::mlir::DataLayout *dataLayout,
6446 return LLVM::detail::isValidLoadStoreImpl(type, ordering, alignment,
6447 dataLayout, emitError);
6448}
6449
// Not implemented yet: trips the assertion when assertions are enabled and
// otherwise returns false without looking at its arguments.
// NOTE(review): the last line of the parameter list is not visible in this
// listing (listing line 6453 is absent) -- confirm against the source.
6450bool NVVMMemorySpaceAttr::isValidAtomicOp(
6451 ptr::AtomicBinOp op, Type type, ptr::AtomicOrdering ordering,
6452 std::optional<int64_t> alignment, const ::mlir::DataLayout *dataLayout,
6454 // TODO: update this method once `ptr.atomic_rmw` is implemented.
6455 assert(false && "unimplemented, see TODO in the source.");
6456 return false;
6457}
6458
// Not implemented yet: trips the assertion when assertions are enabled and
// otherwise returns false without looking at its arguments.
// NOTE(review): the last line of the parameter list is not visible in this
// listing (listing line 6463 is absent) -- confirm against the source.
6459bool NVVMMemorySpaceAttr::isValidAtomicXchg(
6460 Type type, ptr::AtomicOrdering successOrdering,
6461 ptr::AtomicOrdering failureOrdering, std::optional<int64_t> alignment,
6462 const ::mlir::DataLayout *dataLayout,
6464 // TODO: update this method once `ptr.atomic_cmpxchg` is implemented.
6465 assert(false && "unimplemented, see TODO in the source.");
6466 return false;
6467}
6468
// Not implemented yet: trips the assertion when assertions are enabled and
// otherwise returns false without looking at `tgt`, `src` or `emitError`.
6469bool NVVMMemorySpaceAttr::isValidAddrSpaceCast(
6470 Type tgt, Type src, function_ref<InFlightDiagnostic()> emitError) const {
6471 // TODO: update this method once the `ptr.addrspace_cast` op is added to the
6472 // dialect.
6473 assert(false && "unimplemented, see TODO in the source.");
6474 return false;
6475}
6476
// Not implemented yet: trips the assertion when assertions are enabled and
// otherwise returns false without looking at its arguments.
// NOTE(review): the last line of the parameter list is not visible in this
// listing (listing line 6479 is absent) -- confirm against the source.
6477bool NVVMMemorySpaceAttr::isValidPtrIntCast(
6478 Type intLikeTy, Type ptrLikeTy,
6480 // TODO: update this method once the int-cast ops are added to the `ptr`
6481 // dialect.
6482 assert(false && "unimplemented, see TODO in the source.");
6483 return false;
6484}
6485
6486//===----------------------------------------------------------------------===//
6487// NVVM target attribute.
6488//===----------------------------------------------------------------------===//
6489LogicalResult
6490NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
6491 int optLevel, StringRef triple, StringRef chip,
6492 StringRef features, DictionaryAttr flags,
6493 ArrayAttr files, bool verifyTarget) {
6494 if (optLevel < 0 || optLevel > 3) {
6495 emitError() << "The optimization level must be a number between 0 and 3.";
6496 return failure();
6497 }
6498 if (triple.empty()) {
6499 emitError() << "The target triple cannot be empty.";
6500 return failure();
6501 }
6502 if (chip.empty()) {
6503 emitError() << "The target chip cannot be empty.";
6504 return failure();
6505 }
6506 if (files && !llvm::all_of(files, [](::mlir::Attribute attr) {
6507 return mlir::isa_and_nonnull<StringAttr>(attr);
6508 })) {
6509 emitError() << "All the elements in the `link` array must be strings.";
6510 return failure();
6511 }
6512 return success();
6513}
6514
// Checks the ops nested in a GPU module against the SM version of this
// target. Returns success immediately when getVerifyTarget() is false.
// Fails when `gpuModule` is not a gpu::GPUModuleOp, when
// NVVMCheckSMVersion::isMinimumSMVersion rejects the target SM version, or
// when a nested op implementing NVVM::RequiresSMInterface reports a
// requirement that is not compatible with the target.
//
// NOTE(review): the initializer of `targetFullSmVersion` is not visible in
// this listing (listing line 6526 is absent) -- confirm against the source.
6515LogicalResult NVVMTargetAttr::verifyTarget(Operation *gpuModule) {
6516 if (!getVerifyTarget())
6517 return success();
6518
6519 auto gpuModuleOp = llvm::dyn_cast<gpu::GPUModuleOp>(gpuModule);
6520 if (!gpuModuleOp) {
6521 return emitError(gpuModule->getLoc(),
6522 "NVVM target attribute must be attached to a GPU module");
6523 }
6524
6525 const unsigned targetFullSmVersion =
6527 if (!NVVMCheckSMVersion::isMinimumSMVersion(targetFullSmVersion)) {
6528 return emitError(gpuModule->getLoc(),
6529 "Minimum NVVM target SM version is sm_20");
6530 }
6531
 // The walk is interrupted at the first incompatible op, after a diagnostic
 // has been emitted on that op.
6532 if (gpuModuleOp
6533 ->walk([&](Operation *op) {
6534 if (auto reqOp = llvm::dyn_cast<NVVM::RequiresSMInterface>(op)) {
6535 const NVVMCheckSMVersion requirement =
6536 reqOp.getRequiredMinSMVersion();
6537 if (!requirement.isCompatibleWith(targetFullSmVersion)) {
6538 op->emitOpError() << "is not supported on " << getChip();
6539 return WalkResult::interrupt();
6540 }
6541 }
6542 return WalkResult::advance();
6543 })
6544 .wasInterrupted())
6545 return failure();
6546
6547 return success();
6548}
6549
6550#define GET_OP_CLASSES
6551#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
6552
6553#define GET_ATTRDEF_CLASSES
6554#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
for(Operation *op :ops)
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
ArrayAttr()
b getContext())
#define GET_TCGEN05_CP_ID(shape_mc, src_fmt, is_2cta)
static LogicalResult verifyTMALoadParams(size_t tensorDims, size_t numIm2colOff, TMALoadMode mode, Location loc)
static LogicalResult verifyTcgen05MMAOp(bool isATensor, mlir::Value disableOutputLane, NVVM::CTAGroupKind ctaGroup, bool hasAShift, NVVM::Tcgen05MMACollectorOp collectorOp, Location loc)
#define _none
static bool isPtrInAddrSpace(mlir::Value ptr, NVVMMemorySpace targetAS)
static bool isCompatibleReturnTypesOptionalResult(TypeRange inferred, TypeRange actual)
For ops with optional results, allow the user to omit the result even when inference would produce on...
static bool isPtrInSharedCTASpace(mlir::Value ptr)
static LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA)
static llvm::nvvm::CTAGroupKind getNVVMCtaGroupKind(NVVM::CTAGroupKind ctaGroup)
static void addInferredMultiplicandTypes(MLIRContext *ctx, OperationState &result, ValueRange operandA, ValueRange operandB, std::optional< std::array< MMATypes, 2 > > multiplicandPtxTypes)
#define GET_CVT_F2TF32_ID(rnd, relu, sf)
static void addBlockScaleProperties(OpBuilder &builder, OperationState &result, ArrayRef< int64_t > shape, ScaleVecSize scaleVecSize, BlockScaleFormat blockScaleFormat, MMABlockScaleKind kind)
#define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf)
static llvm::Value * getParamCastedAddr(llvm::Value *addr, llvm::IRBuilderBase &builder)
static LogicalResult verifyAddSubFOp(OpType op)
static LogicalResult verifyTcgen05MMABlockScaleOp(NVVM::Tcgen05MMACollectorOp collectorOp, NVVM::Tcgen05MMAKind kind, NVVM::Tcgen05MMABlockScale blockScale, Location loc)
static llvm::Value * packValInto64Bits(llvm::IRBuilderBase &builder, llvm::Value *result, llvm::Value *field, unsigned sizeInBits, unsigned start)
Packs the given field into the result.
static void printOperandList(OpAsmPrinter &p, StringRef name, ArrayRef< Value > operands)
#define GET_F32x2_TO_F6x2_ID(type, has_relu)
static llvm::Value * getAsPackedI32(llvm::Value *arg, llvm::IRBuilderBase &builder)
#define GET_F16x2_TO_F8X2_ID(type, has_relu)
static LogicalResult verifyMBarrierArriveLikeOp(Operation *op, Value addr, NVVM::MemScopeKind scope, Value retVal=nullptr)
static llvm::Value * castPtrToAddrSpace(llvm::IRBuilderBase &builder, llvm::Value *ptr, NVVMMemorySpace targetAS)
static LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD, NVVM::WGMMATypes typeA, NVVM::WGMMATypes typeB)
static llvm::Intrinsic::ID getBarrierReductionIntrinsic(bool aligned, NVVM::BarrierReduction kind)
Maps the (aligned, kind) pair to the @llvm.nvvm.barrier.cta.red.
static void inferAndSetMultiplicandTypes(MLIRContext *ctx, NamedAttrList &attrs, const SmallVectorImpl< Type > &operandTypes)
static LogicalResult parseMmaOperand(OpAsmParser &parser, StringRef operandName, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &regs)
static std::pair< mlir::Type, unsigned > inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n, int k, MLIRContext *context)
static bool isInt8PtxType(MMATypes type)
#define TCGEN05LDRED(SHAPE, NUM, TYPE)
static bool isInt4PtxType(MMATypes type)
static bool isIntegerPtxType(MMATypes type)
#define GET_F32x2_TO_F8X2_S_ID(type, has_relu)
static MMATypes inferPtxTypeFromResult(OpTy op)
static LogicalResult verifyConstantRangeAttr(Operation *op, std::optional< LLVM::ConstantRangeAttr > rangeAttr)
Verify the range attribute satisfies LLVM ConstantRange constructor requirements for NVVM SpecialRang...
static LogicalResult parseMmaTypeSignature(OpAsmParser &parser, SmallVectorImpl< Type > &operandTypes)
static FailureOr< int > getAllowedSizeK(NVVM::WGMMATypes typeA)
static bool isPtrInSharedClusterSpace(mlir::Value ptr)
#define GET_CP_ASYNC_ID(mod, size, has_cpsize)
static unsigned isValidVectorLength(NVVM::Tcgen05LdStShape shape, unsigned vecLen)
#define GET_TCGEN05_COMMIT_ID(cta_group, is_shared, has_mc)
static LogicalResult verifyConvertF32x2ToFP16x2Op(Twine dstType, FPRoundingMode rnd, bool hasRandomBits, Operation *op)
static void nvvmInferResultRanges(Operation *op, Value result, ArrayRef<::mlir::ConstantIntRanges > argRanges, SetIntRangeFn setResultRanges)
Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might have ConstantRangeAttr.
static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims, bool isIm2Col, size_t numIm2ColOffsets, Location loc)
static bool isPtrInGenericSpace(mlir::Value ptr)
static void processOperandFragments(Op &op, std::array< MMAOperandFragment, 3 > &frags, SmallVectorImpl< Type > &regTypes, SmallVectorImpl< StringRef > &ignoreAttrNames)
static llvm::Intrinsic::ID getBarrierSyncIntrinsic(bool aligned, bool hasCount)
Maps the (aligned, hasCount) pair to the @llvm.nvvm.barrier.cta.sync.
static constexpr unsigned notIntrinsic
static LogicalResult inferMBarrierArriveResultTypes(MLIRContext *context, Value addr, SmallVectorImpl< Type > &inferredReturnTypes)
Only shared_cluster (ptr<7>) produces zero results; all other address spaces (including generic) retu...
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
@ OptionalSquare
Square brackets supporting zero or more ops, or nothing.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
void printArrowTypeList(TypeRange &&types)
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerType getI16Type()
Definition Builders.cpp:65
UnitAttr getUnitAttr()
Definition Builders.cpp:102
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:167
IntegerType getI32Type()
Definition Builders.cpp:67
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
MLIRContext * getContext() const
Definition Builders.h:56
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition Builders.h:100
This class represents a diagnostic that is inflight and set to be reported.
static IntegerValueRange getMaxRange(Value value)
Create a maximal range ([0, uint_max(t)] / [int_min(t), int_max(t)]) that is used to mark the v...
Implementation class for module translation.
llvm::Value * lookupValue(Value value) const
Finds an LLVM IR value corresponding to the given MLIR value.
void mapValue(Value mlir, llvm::Value *llvm)
Stores the mapping between an MLIR value and its LLVM IR counterpart.
llvm::LLVMContext & getLLVMContext() const
Returns the LLVM context in which the IR is being constructed.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
std::optional< NamedAttribute > getNamed(StringRef name) const
Return the specified named attribute if present, std::nullopt otherwise.
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
Attribute set(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
This class helps build Operations.
Definition Builders.h:209
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
AttrClass getAttrOfType(StringAttr name)
Definition Operation.h:575
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Definition Operation.h:585
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:240
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isF64() const
Definition Types.cpp:41
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isF32() const
Definition Types.cpp:40
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
bool isF16() const
Definition Types.cpp:38
bool isBF16() const
Definition Types.cpp:37
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
bool isValidLoadStoreImpl(Type type, ptr::AtomicOrdering ordering, std::optional< int64_t > alignment, const ::mlir::DataLayout *dataLayout, function_ref< InFlightDiagnostic()> emitError)
Checks whether the given type is an LLVM type that can be loaded or stored.
Definition LLVMAttrs.cpp:60
SmallVector< int64_t, 4 > getCoordinates(ArrayRef< int64_t > basis, unsigned linearIndex)
@ Write
Write register with '=' modifier.
@ ReadWrite
ReadWrite register with '+' modifier.
@ Read
Read register with no modifier.
std::pair< mlir::Type, unsigned > inferMMAType(mlir::NVVM::MMATypes type, mlir::NVVM::MMAFrag frag, int nRow, int nCol, mlir::MLIRContext *context)
Return the element type and number of elements associated with a wmma matrix of given characteristics.
std::pair< llvm::Intrinsic::ID, llvm::SmallVector< llvm::Value * > > IDArgPair
A pair type of LLVM's Intrinsic ID and args (which are llvm values).
Definition NVVMDialect.h:53
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition Visitors.h:102
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
uint64_t getN(LevelType lt)
Definition Enums.h:442
uint64_t getM(LevelType lt)
Definition Enums.h:443
Include the generated interface declarations.
llvm::function_ref< void(Value, const ConstantIntRanges &)> SetIntRangeFn
The type of the setResultRanges callback provided to ops implementing InferIntRangeInterface.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
LogicalResult matchAndRewrite(SubFOp op, PatternRewriter &rewriter) const override
static bool isMinimumSMVersion(unsigned fullSmVersion)
static unsigned getTargetFullSmVersionFromStr(StringRef smVersionString)
bool isCompatibleWith(const unsigned &targetFullSmVersion) const
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.