MLIR 23.0.0git
NVVMDialect.cpp
Go to the documentation of this file.
1//===- NVVMDialect.cpp - NVVM IR Ops and Dialect registration -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the types and operation details for the NVVM IR dialect in
10// MLIR, and the LLVM IR dialect. It also registers the dialect.
11//
12// The NVVM dialect only contains GPU specific additions on top of the general
13// LLVM dialect.
14//
15//===----------------------------------------------------------------------===//
16
18
22#include "mlir/IR/Builders.h"
25#include "mlir/IR/Diagnostics.h"
27#include "mlir/IR/MLIRContext.h"
28#include "mlir/IR/Operation.h"
30#include "mlir/IR/Types.h"
31#include "llvm/ADT/STLExtras.h"
32#include "llvm/ADT/TypeSwitch.h"
33#include "llvm/IR/IRBuilder.h"
34#include "llvm/IR/NVVMIntrinsicUtils.h"
35#include "llvm/Support/Casting.h"
36#include "llvm/Support/FormatVariadic.h"
37#include "llvm/Support/NVPTXAddrSpace.h"
38#include "llvm/Support/raw_ostream.h"
39#include <cassert>
40#include <optional>
41#include <string>
42
43using namespace mlir;
44using namespace NVVM;
45
46#include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
47#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"
48
// Sentinel meaning "no matching LLVM intrinsic"; shorthand for
// llvm::Intrinsic::not_intrinsic used by the lowering helpers in this file.
static constexpr unsigned notIntrinsic = llvm::Intrinsic::not_intrinsic;
50
51//===----------------------------------------------------------------------===//
52// Helper/Utility methods
53//===----------------------------------------------------------------------===//
54
55static bool isPtrInAddrSpace(mlir::Value ptr, NVVMMemorySpace targetAS) {
56 auto ptrTy = llvm::cast<LLVM::LLVMPointerType>(ptr.getType());
57 return ptrTy.getAddressSpace() == static_cast<unsigned>(targetAS);
58}
59
61 return isPtrInAddrSpace(ptr, NVVMMemorySpace::Generic);
62}
63
65 return isPtrInAddrSpace(ptr, NVVMMemorySpace::Shared);
66}
67
69 return isPtrInAddrSpace(ptr, NVVMMemorySpace::SharedCluster);
70}
71
72static llvm::Value *castPtrToAddrSpace(llvm::IRBuilderBase &builder,
73 llvm::Value *ptr,
74 NVVMMemorySpace targetAS) {
75 unsigned AS = static_cast<unsigned>(targetAS);
76 return builder.CreateAddrSpaceCast(
77 ptr, llvm::PointerType::get(builder.getContext(), AS));
78}
79
// Helper method to convert CtaGroupKind in NVVM Dialect to CtaGroupKind in
// LLVM. The two enums are value-parallel (CTA_1 -> CG_1, CTA_2 -> CG_2); any
// other value is a programmer error.
static llvm::nvvm::CTAGroupKind
getNVVMCtaGroupKind(NVVM::CTAGroupKind ctaGroup) {
  switch (ctaGroup) {
  case NVVM::CTAGroupKind::CTA_1:
    return llvm::nvvm::CTAGroupKind::CG_1;
  case NVVM::CTAGroupKind::CTA_2:
    return llvm::nvvm::CTAGroupKind::CG_2;
  }
  llvm_unreachable("unsupported cta_group value");
}
91
92//===----------------------------------------------------------------------===//
93// Verifier methods
94//===----------------------------------------------------------------------===//
95
96// This verifier is shared among the following Ops:
97// CpAsyncBulkTensorSharedCTAToGlobalOp (TMA Store)
98// CpAsyncBulkTensorReduceOp (TMA Store-Reduce)
99static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims,
100 bool isIm2Col,
101 size_t numIm2ColOffsets,
102 Location loc) {
103 if (tensorDims < 1 || tensorDims > 5)
104 return emitError(loc, "expects coordinates between 1 to 5 dimension");
105
106 // For Im2Col mode, there are two constraints:
107 if (isIm2Col) {
108 // 1. Tensor must always be at least 3-d.
109 if (tensorDims < 3)
110 return emitError(
111 loc,
112 "to use im2col mode, the tensor has to be at least 3-dimensional");
113 // 2. When there are Im2ColOffsets, they must be (Dims - 2) in number.
114 if (numIm2ColOffsets && (tensorDims != (numIm2ColOffsets + 2)))
115 return emitError(
116 loc, "im2col offsets must be 2 less than number of coordinates");
117 }
118 return success();
119}
120
121LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
122 TMAStoreMode mode = getMode();
123 // We lower through inline-ptx when getPredicate() is true.
124 // a) Only TILE mode is supported
125 // b) Cache-hint is not supported
126 if (getPredicate()) {
127 if (mode != TMAStoreMode::TILE)
128 return emitError("Inline-ptx lowering supported only for Tile mode.");
129 if (getL2CacheHint())
130 return emitError("Inline-ptx lowering unsupported with L2 cache-hint.");
131 }
132
133 size_t dims = getCoordinates().size();
134 switch (mode) {
135 case TMAStoreMode::TILE:
136 return cpAsyncBulkTensorCommonVerifier(dims, false, 0, getLoc());
137 case TMAStoreMode::IM2COL:
138 return cpAsyncBulkTensorCommonVerifier(dims, true, 0, getLoc());
139 case TMAStoreMode::TILE_SCATTER4:
140 if (dims != 5)
141 return emitError("Scatter4 mode expects 5 coordinates");
142 }
143 return success();
144}
145
146LogicalResult CpAsyncOp::verify() {
147 if (getModifier() != LoadCacheModifierKind::CG &&
148 getModifier() != LoadCacheModifierKind::CA)
149 return emitError("Only CG and CA cache modifiers are supported.");
150 if (getSize() != 4 && getSize() != 8 && getSize() != 16)
151 return emitError("expected byte size to be either 4, 8 or 16.");
152 if (getModifier() == LoadCacheModifierKind::CG && getSize() != 16)
153 return emitError("CG cache modifier is only support for 16 bytes copy.");
154 return success();
155}
156
157// This verify params can be shared across TMA Load and Prefetch Ops.
158static LogicalResult verifyTMALoadParams(size_t tensorDims, size_t numIm2colOff,
159 TMALoadMode mode, Location loc) {
160 if (tensorDims < 1 || tensorDims > 5)
161 return emitError(loc, "expects coordinates between 1 to 5 dimension");
162
163 auto checkTMALoadParams = [&](TMALoadMode mode, bool isIm2col,
164 size_t expectedIm2colOff) -> LogicalResult {
165 if (isIm2col && (tensorDims < 3))
166 return emitError(loc)
167 << "to use " << mode
168 << " mode, the tensor has to be at least 3-dimensional";
169
170 if (numIm2colOff != expectedIm2colOff)
171 return emitError(loc) << " im2col offsets expected " << expectedIm2colOff
172 << " (provided " << numIm2colOff << ")";
173
174 return success();
175 };
176
177 switch (mode) {
178 case TMALoadMode::TILE:
179 return checkTMALoadParams(mode, false, 0);
180 case TMALoadMode::IM2COL:
181 return checkTMALoadParams(mode, true, tensorDims - 2);
182 case TMALoadMode::IM2COL_W:
183 case TMALoadMode::IM2COL_W_128:
184 return checkTMALoadParams(mode, true, 2);
185 case TMALoadMode::TILE_GATHER4:
186 return (tensorDims == 5)
187 ? checkTMALoadParams(mode, false, 0)
188 : emitError(loc, "Gather4 mode expects 5 coordinates");
189 }
190 return success();
191}
192
193LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
194 return verifyTMALoadParams(getCoordinates().size(), getIm2colOffsets().size(),
195 getMode(), getLoc());
196}
197
198LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
199 TMALoadMode mode = getMode();
200 bool isCTAOnly = getIsCTAOnly();
201 if (getPredicate()) { // Inline-asm based lowering
202 if (isCTAOnly)
203 return emitError("Predicate is supported only for shared::cluster mode.");
204 if (mode != TMALoadMode::TILE && mode != TMALoadMode::IM2COL)
205 return emitError(
206 "Predicate is supported only for Tile and Im2col modes.");
207 } else { // Intrinsics-based lowering
208 NVVMMemorySpace expectedAS =
209 isCTAOnly ? NVVMMemorySpace::Shared : NVVMMemorySpace::SharedCluster;
210 unsigned AS = llvm::cast<LLVM::LLVMPointerType>(getDstMem().getType())
211 .getAddressSpace();
212 if (AS != expectedAS)
213 return emitError()
214 << (isCTAOnly
215 ? "Shared::cta destination requires address-space 3."
216 : "Shared::cluster destination requires address-space 7.");
217 // Checks specific to shared::cta mode
218 if (isCTAOnly) {
219 if (getMulticastMask())
220 return emitError("Multicast is not supported with shared::cta mode.");
221 if (getGroup())
222 return emitError("CTAGroup is not supported with shared::cta mode.");
223 }
224 }
225
226 return verifyTMALoadParams(getCoordinates().size(), getIm2colOffsets().size(),
227 getMode(), getLoc());
228}
229
230LogicalResult CpAsyncBulkTensorReduceOp::verify() {
231 TMAStoreMode mode = getMode();
232 size_t dims = getCoordinates().size();
233 switch (mode) {
234 case TMAStoreMode::TILE:
235 return cpAsyncBulkTensorCommonVerifier(dims, false, 0, getLoc());
236 case TMAStoreMode::IM2COL:
237 return cpAsyncBulkTensorCommonVerifier(dims, true, 0, getLoc());
238 case TMAStoreMode::TILE_SCATTER4:
239 return emitError("Scatter mode unsupported for CpAsyncBulkTensorReduceOp");
240 }
241 return success();
242}
243
244LogicalResult CpAsyncBulkGlobalToSharedClusterOp::verify() {
245 bool isSharedCTA = isPtrInSharedCTASpace(getDstMem());
246 if (isSharedCTA && getMulticastMask())
247 return emitError("Multicast is not supported with shared::cta mode.");
248
249 return success();
250}
251
252static LogicalResult verifyMBarrierArriveLikeOp(Operation *op, Value addr,
253 NVVM::MemScopeKind scope,
254 Value retVal = nullptr) {
255 if (scope != NVVM::MemScopeKind::CTA && scope != NVVM::MemScopeKind::CLUSTER)
256 return op->emitError("mbarrier scope must be either CTA or Cluster");
257
258 bool isSharedCluster = isPtrInSharedClusterSpace(addr);
259 bool hasRetValue = static_cast<bool>(retVal);
260 if (isSharedCluster && hasRetValue)
261 return op->emitError(
262 "mbarrier in shared_cluster space cannot return any value");
263
264 return success();
265}
266
267LogicalResult MBarrierArriveOp::verify() {
268 return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope(),
269 getRes());
270}
271
272LogicalResult MBarrierArriveDropOp::verify() {
273 return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope(),
274 getRes());
275}
276
277LogicalResult MBarrierArriveExpectTxOp::verify() {
278 // The inline-ptx version of this Op does not support all features.
279 // With predicate, this Op lowers to inline-ptx. So, verify and
280 // error-out if there are unsupported features.
281 if (getPredicate()) {
282 if (getScope() != NVVM::MemScopeKind::CTA)
283 return emitError("mbarrier scope must be CTA when using predicate");
284
285 if (isPtrInSharedClusterSpace(getAddr()))
286 return emitError("mbarrier in shared_cluster space is not supported when "
287 "using predicate");
288
289 if (getRes())
290 return emitError("return-value is not supported when using predicate");
291
292 if (getRelaxed() == true)
293 return emitError("mbarrier with relaxed semantics is not supported when "
294 "using predicate");
295 }
296 return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope(),
297 getRes());
298}
299
300LogicalResult MBarrierArriveDropExpectTxOp::verify() {
301 return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope(),
302 getRes());
303}
304
305LogicalResult MBarrierExpectTxOp::verify() {
306 return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope());
307}
308
309LogicalResult MBarrierCompleteTxOp::verify() {
310 return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope());
311}
312
313LogicalResult MBarrierTestWaitOp::verify() {
314 return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope());
315}
316
317LogicalResult MBarrierTryWaitOp::verify() {
318 return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope());
319}
320
321LogicalResult ConvertFloatToTF32Op::verify() {
322 using RndMode = NVVM::FPRoundingMode;
323 switch (getRnd()) {
324 case RndMode::RNA:
325 if (getRelu())
326 return emitError("Relu not supported with rna rounding mode.");
327 break;
328 case RndMode::RN:
329 case RndMode::RZ:
330 break;
331 default:
332 return emitError(
333 "Only {rn,rz,rna} rounding modes supported for ConvertFloatToTF32Op.");
334 }
335 return success();
336}
337
338LogicalResult ConvertF32x2ToF6x2Op::verify() {
340
341 if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy())) {
342 return emitOpError("Only ")
343 << mlir::Float6E2M3FNType::get(ctx) << " and "
344 << mlir::Float6E3M2FNType::get(ctx)
345 << " types are supported for conversions from f32x2 to f6x2.";
346 }
347 return success();
348}
349
350LogicalResult ConvertF32x2ToF8x2Op::verify() {
351 using RndMode = NVVM::FPRoundingMode;
352 using SatMode = NVVM::SaturationMode;
353
354 bool isRoundingModeRN = getRnd() == RndMode::RN;
355 bool isRoundingModeRZ = getRnd() == RndMode::RZ;
356 bool isRoundingModeRP = getRnd() == RndMode::RP;
357 bool isSatFinite = getSat() == SatMode::SATFINITE;
358
359 bool hasRelu = getRelu();
360
362
364 .Case<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(
365 [&](mlir::Type) -> LogicalResult {
366 if (!isRoundingModeRN) {
367 return emitOpError("Only RN rounding mode is supported for "
368 "conversions from f32x2 to ")
369 << mlir::Float8E4M3FNType::get(ctx) << " and "
370 << mlir::Float8E5M2Type::get(ctx) << " types";
371 }
372 if (!isSatFinite) {
373 return emitOpError("Only SATFINITE saturation mode is supported "
374 "for conversions "
375 "from f32x2 to ")
376 << mlir::Float8E4M3FNType::get(ctx) << " and "
377 << mlir::Float8E5M2Type::get(ctx) << " types";
378 }
379 return success();
380 })
381 .Case<mlir::Float8E8M0FNUType>([&](mlir::Type) -> LogicalResult {
382 if (!(isRoundingModeRZ || isRoundingModeRP)) {
383 return emitOpError("Only RZ and RP rounding modes are supported for "
384 "conversions from f32x2 to ")
385 << mlir::Float8E8M0FNUType::get(ctx) << " type";
386 }
387 if (hasRelu) {
388 return emitOpError("relu not supported for conversions to ")
389 << mlir::Float8E8M0FNUType::get(ctx) << " type";
390 }
391 return success();
392 })
393 .Default([&](mlir::Type) {
394 return emitOpError("Only ")
395 << mlir::Float8E4M3FNType::get(ctx) << ", "
396 << mlir::Float8E5M2Type::get(ctx) << ", and "
397 << mlir::Float8E8M0FNUType::get(ctx)
398 << " types are "
399 "supported for conversions from f32x2 to f8x2";
400 });
401}
402
403LogicalResult ConvertF16x2ToF8x2Op::verify() {
405
406 if (!llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(getDstTy())) {
407 return emitOpError("Only ")
408 << mlir::Float8E4M3FNType::get(ctx) << " and "
409 << mlir::Float8E5M2Type::get(ctx)
410 << " types are supported for conversions from f16x2 to f8x2.";
411 }
412 return success();
413}
414
415LogicalResult ConvertBF16x2ToF8x2Op::verify() {
416 using RndMode = NVVM::FPRoundingMode;
417 using SatMode = NVVM::SaturationMode;
418
419 bool isRoundingModeRN = getRnd() == RndMode::RN;
420 bool isRoundingModeRZ = getRnd() == RndMode::RZ;
421 bool isRoundingModeRP = getRnd() == RndMode::RP;
422 bool isSatFinite = getSat() == SatMode::SATFINITE;
423 bool hasRelu = getRelu();
424
426
428 .Case<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(
429 [&](mlir::Type) -> LogicalResult {
430 if (!isRoundingModeRN)
431 return emitOpError("Only RN rounding mode is supported for "
432 "conversions from bf16x2 to ")
433 << mlir::Float8E4M3FNType::get(ctx) << " and "
434 << mlir::Float8E5M2Type::get(ctx) << " types";
435 if (!isSatFinite)
436 return emitOpError("Only SATFINITE saturation mode is supported "
437 "for conversions from bf16x2 to ")
438 << mlir::Float8E4M3FNType::get(ctx) << " and "
439 << mlir::Float8E5M2Type::get(ctx) << " types";
440 return success();
441 })
442 .Case<mlir::Float8E8M0FNUType>([&](mlir::Type) -> LogicalResult {
443 if (!(isRoundingModeRZ || isRoundingModeRP))
444 return emitOpError("Only RZ and RP rounding modes are supported for "
445 "conversions from bf16x2 to ")
446 << mlir::Float8E8M0FNUType::get(ctx) << " type";
447 if (hasRelu)
448 return emitOpError("relu not supported for conversions to ")
449 << mlir::Float8E8M0FNUType::get(ctx) << " type";
450 return success();
451 })
452 .Default([&](mlir::Type) -> LogicalResult {
453 llvm_unreachable("Invalid conversion in ConvertBF16x2ToF8x2Op");
454 return failure();
455 });
456}
457
458LogicalResult ConvertF32x2ToF4x2Op::verify() {
460
461 if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
462 return emitOpError("Only ")
463 << mlir::Float4E2M1FNType::get(ctx)
464 << " type is supported for conversions from f32x2 to f4x2.";
465
466 return success();
467}
468
469LogicalResult ConvertF8x2ToF16x2Op::verify() {
471
472 if (!llvm::isa<Float8E4M3FNType, Float8E5M2Type>(getSrcType()))
473 return emitOpError("Only ")
474 << mlir::Float8E4M3FNType::get(ctx) << " and "
475 << mlir::Float8E5M2Type::get(ctx)
476 << " types are supported for conversions from f8x2 to f16x2.";
477
478 return success();
479}
480
481LogicalResult ConvertF8x2ToBF16x2Op::verify() {
483 if (!llvm::isa<Float8E8M0FNUType>(getSrcType()))
484 return emitOpError("Only ")
485 << mlir::Float8E8M0FNUType::get(ctx)
486 << " type is supported for conversions from f8x2 to bf16x2.";
487
488 return success();
489}
490
491LogicalResult ConvertF6x2ToF16x2Op::verify() {
493
494 if (!llvm::isa<Float6E2M3FNType, Float6E3M2FNType>(getSrcType()))
495 return emitOpError("Only ")
496 << mlir::Float6E2M3FNType::get(ctx) << " and "
497 << mlir::Float6E3M2FNType::get(ctx)
498 << " types are supported for conversions from f6x2 to f16x2.";
499
500 return success();
501}
502
503LogicalResult ConvertF4x2ToF16x2Op::verify() {
505
506 if (!llvm::isa<Float4E2M1FNType>(getSrcType()))
507 return emitOpError("Only ")
508 << mlir::Float4E2M1FNType::get(ctx)
509 << " type is supported for conversions from f4x2 to f16x2.";
510
511 return success();
512}
513
514LogicalResult PermuteOp::verify() {
515 using Mode = NVVM::PermuteMode;
516 bool hasHi = static_cast<bool>(getHi());
517
518 switch (getMode()) {
519 case Mode::DEFAULT:
520 case Mode::F4E:
521 case Mode::B4E:
522 if (!hasHi)
523 return emitError("mode '") << getMode() << "' requires 'hi' operand.";
524 break;
525 case Mode::RC8:
526 case Mode::ECL:
527 case Mode::ECR:
528 case Mode::RC16:
529 if (hasHi)
530 return emitError("mode '")
531 << getMode() << "' does not accept 'hi' operand.";
532 break;
533 }
534
535 return success();
536}
537
538//===----------------------------------------------------------------------===//
539// Stochastic Rounding Conversion Ops
540//===----------------------------------------------------------------------===//
541
542static LogicalResult verifyConvertF32x2ToFP16x2Op(Twine dstType,
543 FPRoundingMode rnd,
544 bool hasRandomBits,
545 Operation *op) {
546 static constexpr FPRoundingMode validRndModes[] = {
547 FPRoundingMode::RN, FPRoundingMode::RZ, FPRoundingMode::RS};
548
549 if (!llvm::is_contained(validRndModes, rnd)) {
550 return op->emitOpError(
551 "Only RN, RZ, and RS rounding modes are supported for "
552 "conversions from f32x2 to ")
553 << dstType << ".";
554 }
555
556 if (rnd == FPRoundingMode::RS) {
557 if (!hasRandomBits) {
558 return op->emitOpError("random_bits is required for RS rounding mode.");
559 }
560 } else {
561 if (hasRandomBits) {
562 return op->emitOpError(
563 "random_bits not supported for RN and RZ rounding modes.");
564 }
565 }
566
567 return success();
568}
569
570LogicalResult ConvertF32x2ToF16x2Op::verify() {
571 return verifyConvertF32x2ToFP16x2Op("f16x2", getRnd(),
572 getRandomBits() ? true : false, *this);
573}
574
575LogicalResult ConvertF32x2ToBF16x2Op::verify() {
576 return verifyConvertF32x2ToFP16x2Op("bf16x2", getRnd(),
577 getRandomBits() ? true : false, *this);
578}
579
580LogicalResult ConvertF32x4ToF8x4Op::verify() {
582
583 if (!llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(getDstTy()))
584 return emitOpError("Only ")
585 << mlir::Float8E4M3FNType::get(ctx) << " and "
586 << mlir::Float8E5M2Type::get(ctx)
587 << " types are supported for conversions from f32x4 to f8x4.";
588
589 return success();
590}
591
592LogicalResult ConvertF32x4ToF6x4Op::verify() {
594
595 if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy()))
596 return emitOpError("Only ")
597 << mlir::Float6E2M3FNType::get(ctx) << " and "
598 << mlir::Float6E3M2FNType::get(ctx)
599 << " types are supported for conversions from f32x4 to f6x4.";
600
601 return success();
602}
603
604LogicalResult ConvertF32x4ToF4x4Op::verify() {
606
607 if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
608 return emitOpError("Only ") << mlir::Float4E2M1FNType::get(ctx)
609 << " type is supported for conversions from "
610 "f32x4 to f4x4.";
611
612 return success();
613}
614
615LogicalResult BulkStoreOp::verify() {
616 if (getInitVal() != 0)
617 return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
618 return success();
619}
620
621LogicalResult PMEventOp::verify() {
622 auto eventId = getEventId();
623 auto maskedEventId = getMaskedEventId();
624 if (!maskedEventId && !eventId) {
625 return emitOpError() << "either `id` or `mask` must be set";
626 }
627
628 if (maskedEventId && eventId) {
629 return emitOpError() << "`id` and `mask` cannot be set at the same time";
630 }
631
632 if (eventId) {
633 if (eventId < 0 || eventId > 15) {
634 return emitOpError() << "`id` must be between 0 and 15";
635 }
636 }
637
638 return llvm::success();
639}
640
641// Given the element type of an operand and whether or not it is an accumulator,
642// this function returns the PTX type (`NVVM::MMATypes`) that corresponds to the
643// operand's element type.
644std::optional<mlir::NVVM::MMATypes>
645MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) {
646 auto half2Type =
647 VectorType::get(2, Float16Type::get(operandElType.getContext()));
648 if (operandElType.isF64())
649 return NVVM::MMATypes::f64;
650 if (operandElType.isF16() || operandElType == half2Type)
651 return NVVM::MMATypes::f16;
652 if (operandElType.isF32() && isAccumulator)
653 return NVVM::MMATypes::f32;
654 if (operandElType.isF32() && !isAccumulator)
655 return NVVM::MMATypes::tf32;
656 if (llvm::isa<IntegerType>(operandElType)) {
657 if (isAccumulator)
658 return NVVM::MMATypes::s32;
659 return std::nullopt;
660 }
661
662 if (auto structType = llvm::dyn_cast<LLVM::LLVMStructType>(operandElType)) {
663 if (structType.getBody().empty())
664 return std::nullopt;
665 return inferOperandMMAType(structType.getBody()[0], isAccumulator);
666 }
667
668 return std::nullopt;
669}
670
671static bool isInt4PtxType(MMATypes type) {
672 return (type == MMATypes::u4 || type == MMATypes::s4);
673}
674
675static bool isInt8PtxType(MMATypes type) {
676 return (type == MMATypes::u8 || type == MMATypes::s8);
677}
678
679static bool isIntegerPtxType(MMATypes type) {
680 return isInt4PtxType(type) || isInt8PtxType(type) || type == MMATypes::b1 ||
681 type == MMATypes::s32;
682}
683
684MMATypes MmaOp::accumPtxType() {
685 std::optional<mlir::NVVM::MMATypes> val = inferOperandMMAType(
686 getODSOperands(2).getTypes().front(), /*isAccumulator=*/true);
687 assert(val.has_value() && "accumulator PTX type should always be inferrable");
688 return val.value();
689}
690
691MMATypes MmaOp::resultPtxType() {
692 std::optional<mlir::NVVM::MMATypes> val =
693 inferOperandMMAType(getResult().getType(), /*isAccumulator=*/true);
694 assert(val.has_value() && "result PTX type should always be inferrable");
695 return val.value();
696}
697
// Pretty-prints an MmaOp as:
//   A[...] B[...] C[...] attr-dict : (typeA, typeB, typeC) -> resType
// PTX-type attributes that are inferrable from the operand types are elided
// from the attribute dictionary.
void MmaOp::print(OpAsmPrinter &p) {
  SmallVector<Type, 4> regTypes;
  // Bundles one operand segment (A, B or C): its display name, the name of
  // its PTX-type attribute (empty for C), and the registers it spans.
  struct MMAOperandFragment {
    StringRef operandName;
    StringRef ptxTypeAttr;
    SmallVector<Value, 4> regs;
    explicit MMAOperandFragment(StringRef name, StringRef ptxTypeName)
        : operandName(name), ptxTypeAttr(ptxTypeName) {}
  };

  std::array<MMAOperandFragment, 3> frags{
      MMAOperandFragment("A", getMultiplicandAPtxTypeAttrName()),
      MMAOperandFragment("B", getMultiplicandBPtxTypeAttrName()),
      MMAOperandFragment("C", "")};
  // The segment-size attribute is implied by the printed form, so elide it.
  SmallVector<StringRef, 4> ignoreAttrNames{
      mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()};

  for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
    auto &frag = frags[fragIdx];
    auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
    for (auto operandIdx = varOperandSpec.first;
         operandIdx < varOperandSpec.first + varOperandSpec.second;
         operandIdx++) {
      frag.regs.push_back(this->getOperand(operandIdx));
      if (operandIdx == 0) {
        regTypes.push_back(this->getOperand(operandIdx).getType());
      }
    }
    // If the PTX type is inferrable from the register type, its attribute
    // need not be printed (fragIdx >= 2 marks the accumulator segment).
    std::optional<MMATypes> inferredType = MmaOp::inferOperandMMAType(
        regTypes.back(), /*isAccumulator=*/fragIdx >= 2);
    if (inferredType)
      ignoreAttrNames.push_back(frag.ptxTypeAttr);
  }

  // Prints one segment as `<name>[reg, reg, ...]`.
  auto printMmaOperand = [&](const MMAOperandFragment &frag) -> void {
    p << " " << frag.operandName;
    p << "[";
    p.printOperands(frag.regs);
    p << "] ";
  };

  for (const auto &frag : frags) {
    printMmaOperand(frag);
  }

  p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);

  // Print the types of the operands and result.
  p << " : "
    << "(";
  llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
                                             frags[1].regs[0].getType(),
                                             frags[2].regs[0].getType()},
                        p);
  p << ")";
  p.printArrowTypeList(TypeRange{this->getRes().getType()});
}
755
// Builder for MmaOp.
//   resultType            — type of the op's result struct.
//   operandA/B/C          — register lists for the three fragments.
//   shape                 — the (m, n, k) MMA shape.
//   b1Op / intOverflow    — optional b1-op and integer-overflow attributes.
//   multiplicandPtxTypes  — optional explicit PTX types for A and B; when
//                           absent they are inferred from the operand types.
//   multiplicandLayouts   — optional layouts for A and B; defaults to
//                           row-major A, column-major B.
void MmaOp::build(OpBuilder &builder, OperationState &result, Type resultType,
                  ValueRange operandA, ValueRange operandB, ValueRange operandC,
                  ArrayRef<int64_t> shape, std::optional<MMAB1Op> b1Op,
                  std::optional<MMAIntOverflow> intOverflow,
                  std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
                  std::optional<std::array<MMALayout, 2>> multiplicandLayouts) {

  assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
  MLIRContext *ctx = builder.getContext();
  result.addAttribute(
      "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));

  result.addOperands(operandA);
  result.addOperands(operandB);
  result.addOperands(operandC);

  // PTX types for A/B: explicit if given, otherwise inferred (and omitted
  // entirely when inference fails).
  if (multiplicandPtxTypes) {
    result.addAttribute("multiplicandAPtxType",
                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
    result.addAttribute("multiplicandBPtxType",
                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
  } else {
    if (auto res = inferOperandMMAType(operandA[0].getType(), false))
      result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
    if (auto res = inferOperandMMAType(operandB[0].getType(), false))
      result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
  }

  // Layouts default to row-major A / column-major B when not specified.
  if (multiplicandLayouts) {
    result.addAttribute("layoutA",
                        MMALayoutAttr::get(ctx, (*multiplicandLayouts)[0]));
    result.addAttribute("layoutB",
                        MMALayoutAttr::get(ctx, (*multiplicandLayouts)[1]));
  } else {
    result.addAttribute("layoutA", MMALayoutAttr::get(ctx, MMALayout::row));
    result.addAttribute("layoutB", MMALayoutAttr::get(ctx, MMALayout::col));
  }

  if (intOverflow.has_value())
    result.addAttribute("intOverflowBehavior",
                        MMAIntOverflowAttr::get(ctx, *intOverflow));
  if (b1Op.has_value())
    result.addAttribute("b1Op", MMAB1OpAttr::get(ctx, *b1Op));

  result.addTypes(resultType);
  // Record how many operands belong to each of the A/B/C segments.
  result.addAttribute(
      MmaOp::getOperandSegmentSizeAttr(),
      builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()),
                                    static_cast<int32_t>(operandB.size()),
                                    static_cast<int32_t>(operandC.size())}));
}
807
808// <operation> :=
809// A `[` $operandA `]` B `[` $operandB `]` C `[` $operandC `]`
810// attr-dict : (type($operandA[0]), type($operandB[0]), type($operandC[0]))
811// `->` type($res)
812ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) {
813 struct MMAOperandFragment {
814 std::optional<MMATypes> elemtype;
815 SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
816 SmallVector<Type> regTypes;
817 };
818
819 Builder &builder = parser.getBuilder();
820 std::array<MMAOperandFragment, 4> frags;
821
822 NamedAttrList namedAttributes;
823
824 // A helper to parse the operand segments.
825 auto parseMmaOperand = [&](StringRef operandName,
826 MMAOperandFragment &frag) -> LogicalResult {
827 if (parser.parseKeyword(operandName).failed())
828 return failure();
829 if (parser
830 .parseOperandList(frag.regs, OpAsmParser::Delimiter::OptionalSquare)
831 .failed())
832 return failure();
833 return success();
834 };
835
836 // Parse the operand segments.
837 if (parseMmaOperand("A", frags[0]).failed())
838 return failure();
839 if (parseMmaOperand("B", frags[1]).failed())
840 return failure();
841 if (parseMmaOperand("C", frags[2]).failed())
842 return failure();
843
844 if (parser.parseOptionalAttrDict(namedAttributes).failed())
845 return failure();
846
847 // Parse the type specification and resolve operands.
848 SmallVector<Type, 3> operandTypes;
849 if (failed(parser.parseColon()))
850 return failure();
851 if (failed(parser.parseLParen()))
852 return failure();
853 if (failed(parser.parseTypeList(operandTypes)))
854 return failure();
855 if (failed(parser.parseRParen()))
856 if (operandTypes.size() != 3)
857 return parser.emitError(
858 parser.getNameLoc(),
859 "expected one type for each operand segment but got " +
860 Twine(operandTypes.size()) + " types");
861 for (const auto &iter : llvm::enumerate(operandTypes)) {
862 auto &frag = frags[iter.index()];
863 frag.regTypes.resize(frag.regs.size(), iter.value());
864 if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
865 parser.getNameLoc(), result.operands)))
866 return failure();
867 frag.elemtype = inferOperandMMAType(frag.regTypes[0],
868 /*isAccumulator*/ iter.index() < 2);
869 }
870
871 Type resultType;
872 if (parser.parseArrow() || parser.parseType(resultType))
873 return failure();
874 frags[3].elemtype = inferOperandMMAType(resultType, /*isAccumulator*/ true);
875
876 std::array<StringRef, 2> names{"multiplicandAPtxType",
877 "multiplicandBPtxType"};
878 for (unsigned idx = 0; idx < names.size(); idx++) {
879 const auto &frag = frags[idx];
880 std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
881 if (!frag.elemtype.has_value() && !attr.has_value()) {
882 return parser.emitError(
883 parser.getNameLoc(),
884 "attribute " + names[idx] +
885 " is not provided explicitly and cannot be inferred");
886 }
887 if (!attr.has_value())
888 result.addAttribute(
889 names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));
890 }
891
892 result.addTypes(resultType);
893 if (!namedAttributes.empty())
894 result.addAttributes(namedAttributes);
895 result.addAttribute(MmaOp::getOperandSegmentSizeAttr(),
896 builder.getDenseI32ArrayAttr({
897 static_cast<int32_t>(frags[0].regs.size()),
898 static_cast<int32_t>(frags[1].regs.size()),
899 static_cast<int32_t>(frags[2].regs.size()),
900 }));
901 return success();
902}
903
// Verifies an nvvm.mma.sync op:
//  - builds the sets of allowed operand/result fragment types and (m, n, k)
//    shapes for the declared multiplicand PTX element types,
//  - checks each operand segment (A, B, C) and the result struct type
//    against those sets,
//  - enforces op-specific attribute requirements (b1Op for binary MMA,
//    intOverflowBehavior for int4/int8) and the row/col layout restriction.
904 LogicalResult MmaOp::verify() {
905 MLIRContext *context = getContext();
906 auto f16Ty = Float16Type::get(context);
907 auto i32Ty = IntegerType::get(context, 32);
908 auto f16x2Ty = VectorType::get(2, f16Ty);
909 auto f32Ty = Float32Type::get(context);
910 auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
911 context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
912
913 auto s32x4StructTy =
914 LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
915 auto f32x8StructTy =
916 LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
917 auto f16x2x2StructTy =
918 LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
919 auto f32x4StructTy =
920 LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
921 auto s32x2StructTy =
922 LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});
923
924 std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
925 getShapeAttr().getK()};
926
927 // These variables define the set of allowed data types for matrices A, B, C,
928 // and result.
929 using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
930 using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>;
931 AllowedShapes allowedShapes;
932 AllowedTypes expectedA;
933 AllowedTypes expectedB;
934 AllowedTypes expectedC;
935 SmallVector<Type> expectedResult;
936
937 // When M = 16, we just need to calculate the number of 8xk tiles, where
938 // k is a factor that depends on the data type.
939 if (mmaShape[0] == 16) {
940 int64_t kFactor;
941 Type multiplicandFragType;
  // kFactor is the K extent covered by one register of the A/B fragment.
  // The integer cases (s4/u4/b1/s8/u8) deliberately leave
  // multiplicandFragType unset here: the isIntegerPtxType() branch below
  // assigns i32 for all of them.
942 switch (*getMultiplicandAPtxType()) {
943 case MMATypes::tf32:
944 kFactor = 4;
945 multiplicandFragType = i32Ty;
946 expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
947 context, {f32Ty, f32Ty, f32Ty, f32Ty}));
948 break;
949 case MMATypes::bf16:
950 kFactor = 8;
951 multiplicandFragType = i32Ty;
952 expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
953 context, {f32Ty, f32Ty, f32Ty, f32Ty}));
954 break;
955 case MMATypes::f16:
956 kFactor = 8;
957 multiplicandFragType = f16x2Ty;
958 expectedResult.push_back(f16x2x2StructTy);
959 expectedResult.push_back(f32x4StructTy);
960 break;
961 case MMATypes::s4:
962 case MMATypes::u4:
963 kFactor = 32;
964 break;
965 case MMATypes::b1:
966 kFactor = 128;
967 break;
968 case MMATypes::s8:
969 case MMATypes::u8:
970 kFactor = 16;
971 break;
972 default:
973 return emitError("invalid shape or multiplicand type: ")
974 << getMultiplicandAPtxType().value();
975 }
976
977 if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
978 expectedResult.push_back(s32x4StructTy);
979 expectedC.emplace_back(4, i32Ty);
980 multiplicandFragType = i32Ty;
981 } else {
  // Floating-point variants accept either an f16x2 or an f32 accumulator.
982 expectedC.emplace_back(2, f16x2Ty);
983 expectedC.emplace_back(4, f32Ty);
984 }
985
  // Number of registers per fragment = number of 8xkFactor tiles covered.
986 int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
987 int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
988 expectedA.emplace_back(unitA, multiplicandFragType);
989 expectedB.emplace_back(unitB, multiplicandFragType);
  // Each element type admits two K extents: kFactor and 2*kFactor.
990 allowedShapes.push_back({16, 8, kFactor});
991 allowedShapes.push_back({16, 8, kFactor * 2});
992
993 if (resultPtxType() != accumPtxType())
994 return emitOpError("ctype does not match dtype");
995 }
996
997 // In the M=8 case, there is only 1 possible case per data type.
998 if (mmaShape[0] == 8) {
999 if (*getMultiplicandAPtxType() == MMATypes::f16) {
1000 expectedA.emplace_back(2, f16x2Ty);
1001 expectedB.emplace_back(2, f16x2Ty);
1002 expectedResult.push_back(f16x2x4StructTy);
1003 expectedResult.push_back(f32x8StructTy);
1004 expectedC.emplace_back(4, f16x2Ty);
1005 expectedC.emplace_back(8, f32Ty);
1006 allowedShapes.push_back({8, 8, 4});
1007 }
1008 if (*getMultiplicandAPtxType() == MMATypes::f64) {
1009 Type f64Ty = Float64Type::get(context);
1010 expectedA.emplace_back(1, f64Ty);
1011 expectedB.emplace_back(1, f64Ty);
1012 expectedC.emplace_back(2, f64Ty);
1013 expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
1014 context, SmallVector<Type>(2, f64Ty)));
1015 allowedShapes.push_back({8, 8, 4});
1016 }
1017 if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
1018 expectedA.push_back({i32Ty});
1019 expectedB.push_back({i32Ty});
1020 expectedC.push_back({i32Ty, i32Ty});
1021 expectedResult.push_back(s32x2StructTy);
1022 if (isInt4PtxType(getMultiplicandAPtxType().value()))
1023 allowedShapes.push_back({8, 8, 32});
1024 if (isInt8PtxType(getMultiplicandAPtxType().value()))
1025 allowedShapes.push_back({8, 8, 16});
1026 if (getMultiplicandAPtxType().value() == MMATypes::b1)
1027 allowedShapes.push_back({8, 8, 128});
1028 }
1029 }
1030
  // All diagnostics below are accumulated into errorMessage through
  // errorStream, then emitted via emitOpError.
1031 std::string errorMessage;
1032 llvm::raw_string_ostream errorStream(errorMessage);
1033
1034 // Check that we matched an existing shape/dtype combination.
1035 if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
1036 !llvm::is_contained(allowedShapes, mmaShape)) {
1037 errorStream << "unimplemented variant for MMA shape <";
1038 llvm::interleaveComma(mmaShape, errorStream);
1039 errorStream << ">";
1040 return emitOpError(errorMessage);
1041 }
1042
1043 // Verify the operand types for segments of A, B, and C operands.
  // ODS operand segments 0/1/2 correspond to A/B/C; each segment's types
  // must exactly match one of the allowed type vectors built above.
1044 std::array<StringRef, 3> operandNames{"A", "B", "C"};
1045 for (const auto &iter : llvm::enumerate(
1046 SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) {
1047 auto spec = this->getODSOperandIndexAndLength(iter.index());
1048 SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first,
1049 operand_type_begin() + spec.first +
1050 spec.second);
1051 bool match = llvm::is_contained(iter.value(), operandTySeg);
1052
1053 if (!match) {
1054 errorStream << "Could not match types for the "
1055 << operandNames[iter.index()]
1056 << " operands; expected one of ";
1057 for (const auto &x : iter.value()) {
1058 errorStream << x.size() << "x" << x[0] << " ";
1059 }
1060 errorStream << "but got ";
1061 llvm::interleaveComma(operandTySeg, errorStream);
1062 return emitOpError(errorMessage);
1063 }
1064 }
1065
1066 // Check the result type
1067 if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
1068 return expectedResultType == getResult().getType();
1069 })) {
1070 errorStream
1071 << "Could not match allowed types for the result; expected one of ";
1072 llvm::interleaveComma(expectedResult, errorStream);
1073 errorStream << " but got " << getResult().getType();
1074 return emitOpError(errorMessage);
1075 }
1076
1077 // Ensure that binary MMA variants have a b1 MMA operation defined.
1078 if (getMultiplicandAPtxType() == MMATypes::b1 && !getB1Op()) {
1079 return emitOpError("op requires " + getB1OpAttrName().strref() +
1080 " attribute");
1081 }
1082
1083 // Ensure int4/int8 MMA variants specify the accum overflow behavior
1084 // attribute.
1085 if (isInt4PtxType(*getMultiplicandAPtxType()) ||
1086 isInt8PtxType(*getMultiplicandAPtxType())) {
1087 if (!getIntOverflowBehavior())
1088 return emitOpError("op requires " +
1089 getIntOverflowBehaviorAttrName().strref() +
1090 " attribute");
1091 }
1092
1093 // Validate layout combinations. According to the operation description, most
1094 // MMA operations require layoutA=row and layoutB=col. Only m8n8k4 with f16
1095 // can use other layout combinations.
1096 bool isM8N8K4_F16 =
1097 (mmaShape[0] == 8 && mmaShape[1] == 8 && mmaShape[2] == 4 &&
1098 getMultiplicandAPtxType() == MMATypes::f16);
1099
1100 if (!isM8N8K4_F16) {
1101 // For all other shapes/types, layoutA must be row and layoutB must be col
1102 if (getLayoutA() != MMALayout::row || getLayoutB() != MMALayout::col) {
1103 return emitOpError("requires layoutA = #nvvm.mma_layout<row> and "
1104 "layoutB = #nvvm.mma_layout<col> for shape <")
1105 << mmaShape[0] << ", " << mmaShape[1] << ", " << mmaShape[2]
1106 << "> with element types " << *getMultiplicandAPtxType() << " and "
1107 << *getMultiplicandBPtxType()
1108 << ". Only m8n8k4 with f16 supports other layouts.";
1109 }
1110 }
1111
1112 return success();
1113}
1114
1115MMATypes MmaSpOp::accumPtxType() {
1116 std::optional<mlir::NVVM::MMATypes> val = MmaOp::inferOperandMMAType(
1117 getODSOperands(2).getTypes().front(), /*isAccumulator=*/true);
1118 assert(val.has_value() && "accumulator PTX type should always be inferrable");
1119 return val.value();
1120}
1121
1122MMATypes MmaSpOp::resultPtxType() {
1123 std::optional<mlir::NVVM::MMATypes> val =
1124 MmaOp::inferOperandMMAType(getResult().getType(), /*isAccumulator=*/true);
1125 assert(val.has_value() && "result PTX type should always be inferrable");
1126 return val.value();
1127}
1128
// Maps an nvvm.mma.sp op to its LLVM intrinsic ID and translated argument
// list. All MLIR operands are forwarded, in order, through the module
// translation's value mapping; the intrinsic is selected from the shape,
// overflow/metadata/kind attributes, and the operand/result PTX types.
// NOTE(review): this doxygen-derived listing elides original lines 1129 and
// 1135 (presumably the return-type line and the declaration of `args`) —
// consult the actual source before editing.
1130 MmaSpOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
1131 llvm::IRBuilderBase &builder) {
1132 auto thisOp = cast<NVVM::MmaSpOp>(op);
1133
1134 // Get operands
1136 for (mlir::Value v : thisOp.getOperands())
1137 args.push_back(mt.lookupValue(v));
1138
1139 // Get intrinsic ID using the existing getIntrinsicID method
1140 auto intId = MmaSpOp::getIntrinsicID(
1141 thisOp.getShape().getM(), thisOp.getShape().getN(),
1142 thisOp.getShape().getK(), thisOp.getIntOverflowBehavior(),
1143 thisOp.getOrderedMetadata(), thisOp.getKind(),
1144 *thisOp.getMultiplicandAPtxType(), *thisOp.getMultiplicandBPtxType(),
1145 thisOp.accumPtxType(), thisOp.resultPtxType());
1146
1147 return {intId, args};
1148}
1149
// Prints an nvvm.mma.sp op in the custom form:
//   A[...] B[...] C[...] sparseMetadata[...] selector[...] {attrs}
//     : (A-regtype, B-regtype, C-regtype) -> result-type
// PTX-type attributes that can be re-inferred from the register types are
// elided from the attribute dictionary.
1150 void MmaSpOp::print(OpAsmPrinter &p) {
  // One representative register type per variadic fragment (A, B, C), in
  // order; consumed by the trailing type-signature loop below.
1151 SmallVector<Type, 4> regTypes;
1152 struct MMAOperandFragment {
1153 StringRef operandName;
1154 StringRef ptxTypeAttr;
1155 SmallVector<Value, 4> regs;
1156 explicit MMAOperandFragment(StringRef name, StringRef ptxTypeName)
1157 : operandName(name), ptxTypeAttr(ptxTypeName) {}
1158 };
1159
  // Only A and B carry a PTX-type attribute name; the empty names for
  // C / sparseMetadata / selector are never meaningful entries if pushed
  // into ignoreAttrNames.
1160 std::array<MMAOperandFragment, 5> frags{
1161 MMAOperandFragment("A", getMultiplicandAPtxTypeAttrName()),
1162 MMAOperandFragment("B", getMultiplicandBPtxTypeAttrName()),
1163 MMAOperandFragment("C", ""), MMAOperandFragment("sparseMetadata", ""),
1164 MMAOperandFragment("selector", "")};
1165 SmallVector<StringRef, 4> ignoreAttrNames{
1166 mlir::NVVM::MmaSpOp::getOperandSegmentSizeAttr()};
1167
1168 // Handle variadic operands A, B, C
1169 for (unsigned fragIdx = 0; fragIdx < 3; fragIdx++) {
1170 auto &frag = frags[fragIdx];
1171 auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
1172 for (auto operandIdx = varOperandSpec.first;
1173 operandIdx < varOperandSpec.first + varOperandSpec.second;
1174 operandIdx++) {
1175 frag.regs.push_back(this->getOperand(operandIdx));
  // Record only the first register's type as the fragment's type.
1176 if (operandIdx == varOperandSpec.first) {
1177 regTypes.push_back(this->getOperand(operandIdx).getType());
1178 }
1179 }
  // If the PTX element type is inferrable from the register type, the
  // corresponding attribute need not be printed.
1180 std::optional<MMATypes> inferredType = MmaOp::inferOperandMMAType(
1181 regTypes.back(), /*isAccumulator=*/fragIdx >= 2);
1182 if (inferredType)
1183 ignoreAttrNames.push_back(frag.ptxTypeAttr);
1184 }
1185
1186 // Handle sparse metadata and selector (single operands)
1187 frags[3].regs.push_back(getSparseMetadata());
1188 frags[4].regs.push_back(getSparsitySelector());
1189
  // Emits one segment as ` name[%op0, %op1, ...]`.
1190 auto printMmaSpOperand = [&](const MMAOperandFragment &frag) -> void {
1191 p << " " << frag.operandName;
1192 p << "[";
1193 p.printOperands(frag.regs);
1194 p << "]";
1195 };
1196
1197 for (const auto &frag : frags)
1198 printMmaSpOperand(frag);
1199
1200 p.printOptionalAttrDict((*this)->getAttrs(), ignoreAttrNames);
  // Type signature: only the three variadic fragments appear; metadata and
  // selector types are fixed (i32) and re-assumed by the parser.
1201 p << " : ";
1202 p << "(";
1203 for (int i = 0; i < 3; ++i) {
1204 p << regTypes[i];
1205 if (i < 2)
1206 p << ", ";
1207 }
1208 p << ") -> " << getResult().getType();
1209}
1210
1211void MmaSpOp::build(
1212 OpBuilder &builder, OperationState &result, Type resultType,
1213 ValueRange operandA, ValueRange operandB, ValueRange operandC,
1214 Value sparseMetadata, Value sparsitySelector, ArrayRef<int64_t> shape,
1215 std::optional<MMAIntOverflow> intOverflow,
1216 std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes) {
1217
1218 assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
1219 MLIRContext *ctx = builder.getContext();
1220 result.addAttribute(
1221 "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));
1222
1223 result.addOperands(operandA);
1224 result.addOperands(operandB);
1225 result.addOperands(operandC);
1226 result.addOperands(sparseMetadata);
1227 result.addOperands(sparsitySelector);
1228
1229 if (multiplicandPtxTypes) {
1230 result.addAttribute("multiplicandAPtxType",
1231 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
1232 result.addAttribute("multiplicandBPtxType",
1233 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
1234 } else {
1235 if (auto res = MmaOp::inferOperandMMAType(operandA[0].getType(), false))
1236 result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
1237 if (auto res = MmaOp::inferOperandMMAType(operandB[0].getType(), false))
1238 result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
1239 }
1240
1241 if (intOverflow.has_value())
1242 result.addAttribute("intOverflowBehavior",
1243 MMAIntOverflowAttr::get(ctx, *intOverflow));
1244
1245 result.addTypes(resultType);
1246 result.addAttribute(
1247 MmaSpOp::getOperandSegmentSizeAttr(),
1248 builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()),
1249 static_cast<int32_t>(operandB.size()),
1250 static_cast<int32_t>(operandC.size()), 1,
1251 1})); // sparseMetadata and sparsitySelector
1252}
1253
// Parses the custom form printed by MmaSpOp::print:
//   A[...] B[...] C[...] sparseMetadata[...] selector[...] {attrs}
//     : (A-type, B-type, C-type) -> result-type
// Every register within a fragment shares the single type given in the
// signature; metadata and selector are assumed to be i32.
1254 ParseResult MmaSpOp::parse(OpAsmParser &parser, OperationState &result) {
1255 struct MMAOperandFragment {
1256 std::optional<MMATypes> elemtype;
1257 SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
1258 SmallVector<Type> regTypes;
1259 };
1260
1261 Builder &builder = parser.getBuilder();
  // Six entries: indices 0-4 are the five operand segments listed in the
  // comment; index 5 carries only the elemtype inferred from the result.
1262 std::array<MMAOperandFragment, 6> frags; // A, B, C, sparseMetadata, selector
1263
1264 NamedAttrList namedAttributes;
1265
1266 // A helper to parse the operand segments.
1267 auto parseMmaSpOperand = [&](StringRef operandName,
1268 MMAOperandFragment &frag) -> LogicalResult {
1269 if (parser.parseKeyword(operandName).failed())
1270 return failure();
1271 if (parser
1272 .parseOperandList(frag.regs, OpAsmParser::Delimiter::OptionalSquare)
1273 .failed())
1274 return failure();
1275 return success();
1276 };
1277
1278 // Parse the operand segments.
1279 if (parseMmaSpOperand("A", frags[0]).failed())
1280 return failure();
1281 if (parseMmaSpOperand("B", frags[1]).failed())
1282 return failure();
1283 if (parseMmaSpOperand("C", frags[2]).failed())
1284 return failure();
1285 if (parseMmaSpOperand("sparseMetadata", frags[3]).failed())
1286 return failure();
1287 if (parseMmaSpOperand("selector", frags[4]).failed())
1288 return failure();
1289
1290 if (parser.parseOptionalAttrDict(namedAttributes).failed())
1291 return failure();
1292
1293 // Parse the type specification and resolve operands.
1294 SmallVector<Type, 3> operandTypes;
1295 if (failed(parser.parseColon()))
1296 return failure();
1297 if (failed(parser.parseLParen()))
1298 return failure();
1299 if (failed(parser.parseTypeList(operandTypes)))
1300 return failure();
1301 if (failed(parser.parseRParen()))
1302 return failure();
1303 if (operandTypes.size() != 3)
1304 return parser.emitError(
1305 parser.getNameLoc(),
1306 "expected one type for each operand segment but got " +
1307 Twine(operandTypes.size()) + " types");
  // Resolve A, B, C: replicate the single signature type across every
  // register of the fragment, and remember the inferred PTX element type
  // (C, index 2, is treated as an accumulator).
1308 for (const auto &iter : llvm::enumerate(operandTypes)) {
1309 auto &frag = frags[iter.index()];
1310 frag.regTypes.resize(frag.regs.size(), iter.value());
1311 if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
1312 parser.getNameLoc(), result.operands)))
1313 return failure();
1314 frag.elemtype =
1315 MmaOp::inferOperandMMAType(frag.regTypes[0],
1316 /*isAccumulator*/ iter.index() >= 2);
1317 }
1318
1319 Type resultType;
1320 if (parser.parseArrow() || parser.parseType(resultType))
1321 return failure();
1322 frags[5].elemtype =
1323 MmaOp::inferOperandMMAType(resultType, /*isAccumulator*/ true);
1324
1325 // Resolve sparse metadata and selector (assume i32 type)
1326 Type i32Type = builder.getIntegerType(32);
1327 if (parser
1328 .resolveOperands(frags[3].regs, i32Type, parser.getCurrentLocation(),
1329 result.operands)
1330 .failed())
1331 return failure();
1332 if (parser
1333 .resolveOperands(frags[4].regs, i32Type, parser.getCurrentLocation(),
1334 result.operands)
1335 .failed())
1336 return failure();
1337
  // The multiplicand PTX-type attributes must either be inferrable from the
  // register types or provided explicitly in the attribute dictionary.
1338 std::array<StringRef, 2> names{"multiplicandAPtxType",
1339 "multiplicandBPtxType"};
1340 for (unsigned idx = 0; idx < names.size(); idx++) {
1341 const auto &frag = frags[idx];
1342 std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
1343 if (!frag.elemtype.has_value() && !attr.has_value()) {
1344 return parser.emitError(
1345 parser.getNameLoc(),
1346 "attribute " + names[idx] +
1347 " is not provided explicitly and cannot be inferred");
1348 }
1349 if (!attr.has_value())
1350 result.addAttribute(
1351 names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));
1352 }
1353
1354 result.addTypes(resultType);
1355 if (!namedAttributes.empty())
1356 result.addAttributes(namedAttributes);
1357 result.addAttribute(MmaSpOp::getOperandSegmentSizeAttr(),
1358 builder.getDenseI32ArrayAttr({
1359 static_cast<int32_t>(frags[0].regs.size()),
1360 static_cast<int32_t>(frags[1].regs.size()),
1361 static_cast<int32_t>(frags[2].regs.size()),
1362 1, // sparseMetadata
1363 1 // sparsitySelector
1364 }));
1365 return success();
1366}
1367
// Verifies an nvvm.mma.sp (sparse MMA) op. Mirrors MmaOp::verify but with
// sparse-specific differences:
//  - allowed (m, n, k) shapes are enumerated per element type in the switch
//    (sparse PTX shapes differ from the dense kFactor-derived ones),
//  - the A fragment is halved to account for 2:4 structured sparsity,
//  - the sparse metadata and sparsity selector operands must be i32.
1368 LogicalResult MmaSpOp::verify() {
1369 MLIRContext *context = getContext();
1370 auto f16Ty = Float16Type::get(context);
1371 auto i32Ty = IntegerType::get(context, 32);
1372 auto f16x2Ty = VectorType::get(2, f16Ty);
1373 auto f32Ty = Float32Type::get(context);
1374 auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
1375 context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
1376
1377 auto s32x4StructTy =
1378 LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
1379 auto f32x8StructTy =
1380 LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
1381 auto f16x2x2StructTy =
1382 LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
1383 auto f32x4StructTy =
1384 LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
1385 auto s32x2StructTy =
1386 LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});
1387
1388 std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
1389 getShapeAttr().getK()};
1390
1391 // These variables define the set of allowed data types for matrices A, B, C,
1392 // and result.
1393 using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
1394 using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>;
1395 AllowedShapes allowedShapes;
1396 AllowedTypes expectedA;
1397 AllowedTypes expectedB;
1398 AllowedTypes expectedC;
1399 SmallVector<Type> expectedResult;
1400
1401 // When M = 16, we just need to calculate the number of 8xk tiles, where
1402 // k is a factor that depends on the data type.
1403 if (mmaShape[0] == 16) {
1404 int64_t kFactor;
1405 Type multiplicandFragType;
  // Integer cases (s4/u4/s8/u8) leave multiplicandFragType unset here; the
  // isIntegerPtxType() branch below assigns i32 for them.
1406 switch (*getMultiplicandAPtxType()) {
1407 case MMATypes::tf32:
1408 kFactor = 4;
1409 multiplicandFragType = i32Ty;
1410 expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
1411 context, {f32Ty, f32Ty, f32Ty, f32Ty}));
1412 // Sparse MMA supports m16n8k8 and m16n8k16 for tf32
1413 allowedShapes.push_back({16, 8, 8});
1414 allowedShapes.push_back({16, 8, 16});
1415 break;
1416 case MMATypes::bf16:
1417 kFactor = 8;
1418 multiplicandFragType = i32Ty;
1419 expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
1420 context, {f32Ty, f32Ty, f32Ty, f32Ty}));
1421 // Sparse MMA supports m16n8k16 and m16n8k32 for bf16
1422 allowedShapes.push_back({16, 8, 16});
1423 allowedShapes.push_back({16, 8, 32});
1424 break;
1425 case MMATypes::f16:
1426 kFactor = 8;
1427 multiplicandFragType = f16x2Ty;
1428 expectedResult.push_back(f16x2x2StructTy);
1429 expectedResult.push_back(f32x4StructTy);
1430 // Sparse MMA supports m16n8k16 and m16n8k32 for f16
1431 allowedShapes.push_back({16, 8, 16});
1432 allowedShapes.push_back({16, 8, 32});
1433 break;
1434 case MMATypes::s4:
1435 case MMATypes::u4:
1436 kFactor = 32;
1437 // Sparse MMA supports m16n8k64 and m16n8k128 for s4/u4
1438 allowedShapes.push_back({16, 8, 64});
1439 allowedShapes.push_back({16, 8, 128});
1440 break;
1441 case MMATypes::s8:
1442 case MMATypes::u8:
1443 kFactor = 16;
1444 // Sparse MMA supports m16n8k32 and m16n8k64 for s8/u8
1445 allowedShapes.push_back({16, 8, 32});
1446 allowedShapes.push_back({16, 8, 64});
1447 break;
1448 case MMATypes::e4m3:
1449 case MMATypes::e5m2:
1450 case MMATypes::e3m2:
1451 case MMATypes::e2m3:
1452 case MMATypes::e2m1:
1453 kFactor = 16;
1454 multiplicandFragType = i32Ty;
1455 expectedResult.push_back(f16x2x2StructTy);
1456 expectedResult.push_back(f32x4StructTy);
1457 // Sparse MMA supports m16n8k64 for FP8 types
1458 allowedShapes.push_back({16, 8, 64});
1459 break;
1460 default:
1461 return emitError("invalid shape or multiplicand type: ")
1462 << getMultiplicandAPtxType().value();
1463 }
1464
1465 if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
1466 expectedResult.push_back(s32x4StructTy);
1467 expectedC.emplace_back(4, i32Ty);
1468 multiplicandFragType = i32Ty;
  // NOTE(review): this range check relies on the MMATypes enumerators
  // e4m3..e2m1 being declared contiguously and in this order — confirm
  // against the enum definition if it changes.
1469 } else if (*getMultiplicandAPtxType() >= MMATypes::e4m3 &&
1470 *getMultiplicandAPtxType() <= MMATypes::e2m1) {
1471 // FP8 types
1472 expectedC.emplace_back(2, f16x2Ty);
1473 expectedC.emplace_back(4, f32Ty);
  // NOTE(review): this branch is identical to the FP8 branch above; the two
  // could be merged without behavior change.
1474 } else {
1475 expectedC.emplace_back(2, f16x2Ty);
1476 expectedC.emplace_back(4, f32Ty);
1477 }
1478
1479 // For sparse MMA, A operand is compressed (2:4 sparsity means half the
1480 // elements)
1481 int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor) / 2;
1482 int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
1483 expectedA.emplace_back(unitA, multiplicandFragType);
1484 expectedB.emplace_back(unitB, multiplicandFragType);
1485
1486 if (resultPtxType() != accumPtxType())
1487 return emitOpError("ctype does not match dtype");
1488 }
1489
1490 // In the M=8 case, there is only 1 possible case per data type.
  // NOTE(review): this M=8 section mirrors the dense MmaOp::verify; confirm
  // whether any m8 sparse variant is actually accepted by the shapes above.
1491 if (mmaShape[0] == 8) {
1492 if (*getMultiplicandAPtxType() == MMATypes::f16) {
1493 expectedA.emplace_back(2, f16x2Ty);
1494 expectedB.emplace_back(2, f16x2Ty);
1495 expectedResult.push_back(f16x2x4StructTy);
1496 expectedResult.push_back(f32x8StructTy);
1497 expectedC.emplace_back(4, f16x2Ty);
1498 expectedC.emplace_back(8, f32Ty);
1499 allowedShapes.push_back({8, 8, 4});
1500 }
1501 if (*getMultiplicandAPtxType() == MMATypes::f64) {
1502 Type f64Ty = Float64Type::get(context);
1503 expectedA.emplace_back(1, f64Ty);
1504 expectedB.emplace_back(1, f64Ty);
1505 expectedC.emplace_back(2, f64Ty);
1506 expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
1507 context, SmallVector<Type>(2, f64Ty)));
1508 allowedShapes.push_back({8, 8, 4});
1509 }
1510 if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
1511 expectedA.push_back({i32Ty});
1512 expectedB.push_back({i32Ty});
1513 expectedC.push_back({i32Ty, i32Ty});
1514 expectedResult.push_back(s32x2StructTy);
1515 if (isInt4PtxType(getMultiplicandAPtxType().value()))
1516 allowedShapes.push_back({8, 8, 32});
1517 if (isInt8PtxType(getMultiplicandAPtxType().value()))
1518 allowedShapes.push_back({8, 8, 16});
1519 }
1520 }
1521
  // Diagnostics below are accumulated into errorMessage via errorStream.
1522 std::string errorMessage;
1523 llvm::raw_string_ostream errorStream(errorMessage);
1524
1525 // Check that we matched an existing shape/dtype combination.
1526 if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
1527 !llvm::is_contained(allowedShapes, mmaShape)) {
1528 errorStream << "unimplemented variant for MMA shape <";
1529 llvm::interleaveComma(mmaShape, errorStream);
1530 errorStream << ">";
1531 return emitOpError(errorMessage);
1532 }
1533
1534 // Verify the operand types for segments of A, B, and C operands.
1535 std::array<StringRef, 3> operandNames{"A", "B", "C"};
1536 for (const auto &iter : llvm::enumerate(
1537 SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) {
1538 auto spec = this->getODSOperandIndexAndLength(iter.index());
1539 SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first,
1540 operand_type_begin() + spec.first +
1541 spec.second);
1542 bool match = llvm::is_contained(iter.value(), operandTySeg);
1543
1544 if (!match) {
1545 errorStream << "Could not match types for the "
1546 << operandNames[iter.index()]
1547 << " operands; expected one of ";
1548 for (const auto &x : iter.value()) {
1549 errorStream << x.size() << "x" << x[0] << " ";
1550 }
1551 errorStream << "but got ";
1552 llvm::interleaveComma(operandTySeg, errorStream);
1553 return emitOpError(errorMessage);
1554 }
1555 }
1556
1557 // Check the result type
1558 if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
1559 return expectedResultType == getResult().getType();
1560 })) {
1561 errorStream
1562 << "Could not match allowed types for the result; expected one of ";
1563 llvm::interleaveComma(expectedResult, errorStream);
1564 errorStream << " but got " << getResult().getType();
1565 return emitOpError(errorMessage);
1566 }
1567
1568 // Ensure int4/int8 MMA variants specify the accum overflow behavior
1569 // attribute.
1570 if (isInt4PtxType(*getMultiplicandAPtxType()) ||
1571 isInt8PtxType(*getMultiplicandAPtxType())) {
1572 if (!getIntOverflowBehavior())
1573 return emitOpError("op requires " +
1574 getIntOverflowBehaviorAttrName().strref() +
1575 " attribute");
1576 }
1577
1578 // Validate sparse metadata type (should be i32)
1579 if (!getSparseMetadata().getType().isInteger(32)) {
1580 return emitOpError() << "sparse metadata must be i32 type";
1581 }
1582
1583 // Validate sparsity selector type (should be i32)
1584 if (!getSparsitySelector().getType().isInteger(32)) {
1585 return emitOpError() << "sparsity selector must be i32 type";
1586 }
1587
1588 return success();
1589}
1590
1591//===----------------------------------------------------------------------===//
1592// MMA Block Scale Operations - Shared Helpers
1593//===----------------------------------------------------------------------===//
1594
1595 namespace {
1596 // Shared structure for MMA operand fragments (A, B, C)
// Bundles one operand segment of a block-scale MMA op for printing:
//  - operandName: the segment keyword emitted before the bracket list,
//  - ptxTypeAttr: the name of the PTX-type attribute that can be elided
//    when the element type is inferrable ("" when the segment has none),
//  - regs: the SSA values making up the fragment.
// File-local: kept in an anonymous namespace.
1597 struct MMAOperandFragment {
1598 StringRef operandName;
1599 StringRef ptxTypeAttr;
1600 SmallVector<Value, 4> regs;
1601 explicit MMAOperandFragment(StringRef name, StringRef ptxTypeName)
1602 : operandName(name), ptxTypeAttr(ptxTypeName) {}
1603};
1604 } // namespace
1605
1606// Helper to print operand list in the format: name[operands]
1607static void printOperandList(OpAsmPrinter &p, StringRef name,
1608 ArrayRef<Value> operands) {
1609 p << " " << name << "[";
1610 p.printOperands(operands);
1611 p << "]";
1612}
1613
1614 // Helper to parse operand list in the format: name[operands]
// Consumes the keyword `operandName` followed by a square-bracketed operand
// list, appending the unresolved operands to the caller's vector.
// NOTE(review): this doxygen-derived listing elides original lines 1617 and
// 1620 (the operand-vector parameter and the parseOperandList call) —
// consult the actual source before editing.
1615 static LogicalResult
1616 parseMmaOperand(OpAsmParser &parser, StringRef operandName,
1618  if (parser.parseKeyword(operandName).failed())
1619    return failure();
1621      .failed())
1622    return failure();
1623  return success();
1624}
1625
1626// Helper to process operand fragments and determine which attributes can be
1627// inferred
1628template <typename Op>
1629static void
1630processOperandFragments(Op &op, std::array<MMAOperandFragment, 3> &frags,
1631 SmallVectorImpl<Type> &regTypes,
1632 SmallVectorImpl<StringRef> &ignoreAttrNames) {
1633 for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
1634 auto &frag = frags[fragIdx];
1635 auto varOperandSpec = op.getODSOperandIndexAndLength(fragIdx);
1636 for (auto operandIdx = varOperandSpec.first;
1637 operandIdx < varOperandSpec.first + varOperandSpec.second;
1638 operandIdx++) {
1639 frag.regs.push_back(op.getOperand(operandIdx));
1640 if (fragIdx == 0 && operandIdx == varOperandSpec.first) {
1641 regTypes.push_back(op.getOperand(operandIdx).getType());
1642 }
1643 }
1644 if (fragIdx < 2) {
1645 regTypes.push_back(frag.regs[0].getType());
1646 }
1647 std::optional<MMATypes> inferredType =
1648 MmaOp::inferOperandMMAType(regTypes.back(),
1649 /*isAccumulator=*/fragIdx >= 2);
1650 if (inferredType)
1651 ignoreAttrNames.push_back(frag.ptxTypeAttr);
1652 }
1653}
1654
1655 // Helper to parse type signature: (A_type, B_type, C_type)
// Consumes `: (type, type, type)` and stores the three parsed types into
// `operandTypes`; emits a diagnostic if the count is not exactly 3.
// NOTE(review): this doxygen-derived listing elides original line 1657 (the
// line carrying the function name and first parameter) — consult the actual
// source before editing.
1656 static LogicalResult
1658 SmallVectorImpl<Type> &operandTypes) {
1659 if (parser.parseColon().failed() || parser.parseLParen().failed())
1660 return failure();
1661
  // Parses one element of the comma-separated type list.
1662 auto typeParser = [&]() {
1663 Type ty;
1664 if (parser.parseType(ty).failed())
1665 return failure();
1666 operandTypes.push_back(ty);
1667 return success();
1668 };
1669 if (parser.parseCommaSeparatedList(typeParser))
1670 return failure();
1671
1672 if (operandTypes.size() != 3)
1673 return parser.emitError(parser.getCurrentLocation(),
1674 "expected exactly 3 types");
1675
1676 return parser.parseRParen();
1677}
1678
1679 // Helper to infer and set multiplicand PTX type attributes
// For each of multiplicandAPtxType/multiplicandBPtxType: if the attribute is
// not already present in `attrs`, try to infer it from the corresponding
// parsed operand type and set it; attributes that cannot be inferred are
// left unset.
// NOTE(review): this doxygen-derived listing elides original line 1681
// (carrying the function name and leading parameters) — consult the actual
// source before editing.
1680 static void
1682 const SmallVectorImpl<Type> &operandTypes) {
1683 if (!attrs.get("multiplicandAPtxType")) {
1684 if (auto inferredType =
1685 MmaOp::inferOperandMMAType(operandTypes[0], false)) {
1686 attrs.set("multiplicandAPtxType", MMATypesAttr::get(ctx, *inferredType));
1687 }
1688 }
1689 if (!attrs.get("multiplicandBPtxType")) {
1690 if (auto inferredType =
1691 MmaOp::inferOperandMMAType(operandTypes[1], false)) {
1692 attrs.set("multiplicandBPtxType", MMATypesAttr::get(ctx, *inferredType));
1693 }
1694 }
1695}
1696
1697 // Helper to add common block scale properties
// Populates the common block-scale properties on the op being built: the
// (m, n, k) shape attribute plus the scaleVecSize, blockScaleFormat, and
// kind attributes.
// NOTE(review): this doxygen-derived listing elides original lines 1699-1700
// (the function name and leading parameters; presumably the builder, the
// OperationState, and the shape array) — consult the actual source before
// editing.
1698 template <typename OpType>
1701 ScaleVecSize scaleVecSize,
1702 BlockScaleFormat blockScaleFormat,
1703 MMABlockScaleKind kind) {
1704 MLIRContext *ctx = builder.getContext();
1705 auto &properties = result.getOrAddProperties<typename OpType::Properties>();
1706 properties.setShape(
1707 builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));
1708 properties.setScaleVecSize(ScaleVecSizeAttr::get(ctx, scaleVecSize));
1709 properties.setBlockScaleFormat(
1710 BlockScaleFormatAttr::get(ctx, blockScaleFormat));
1711 properties.setKind(MMABlockScaleKindAttr::get(ctx, kind));
1712}
1713
1714 // Helper to infer and add multiplicand PTX types to builder
// Adds multiplicandAPtxType/multiplicandBPtxType to the op state: taken
// verbatim from `multiplicandPtxTypes` when provided, otherwise inferred
// from the first register of each fragment (and omitted when inference
// fails). Same policy as MmaSpOp::build.
// NOTE(review): this doxygen-derived listing elides original lines 1715-1716
// (the function name and leading parameters; presumably the MLIRContext, the
// OperationState, and operandA) — consult the actual source before editing.
1717 ValueRange operandB,
1718 std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes) {
1719 if (multiplicandPtxTypes) {
1720 result.addAttribute("multiplicandAPtxType",
1721 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
1722 result.addAttribute("multiplicandBPtxType",
1723 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
1724 } else {
1725 if (auto res = MmaOp::inferOperandMMAType(operandA[0].getType(), false))
1726 result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
1727 if (auto res = MmaOp::inferOperandMMAType(operandB[0].getType(), false))
1728 result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
1729 }
1730}
1731
1732// Template helper for common accumPtxType/resultPtxType implementation
1733template <typename OpTy>
1734static MMATypes inferPtxTypeFromResult(OpTy op) {
1735 return *MmaOp::inferOperandMMAType(
1736 cast<LLVM::LLVMStructType>(op.getRes().getType()).getBody()[0],
1737 /*isAccumulator=*/true);
1738}
1739
1740//===----------------------------------------------------------------------===//
1741// MmaBlockScaleOp
1742//===----------------------------------------------------------------------===//
1743
// Prints an nvvm block-scale MMA op in the custom form:
//   A[...] B[...] C[...] scaleA[data, byteId, threadId]
//     scaleB[data, byteId, threadId] {attrs}
//     : (A-regtype, B-regtype, C-regtype) -> result-type
// PTX-type attributes inferrable from register types are elided.
1744 void MmaBlockScaleOp::print(OpAsmPrinter &p) {
  // Filled by processOperandFragments (one representative type per
  // fragment); the type signature below reads the fragment types directly.
1745 SmallVector<Type, 4> regTypes;
1746 std::array<MMAOperandFragment, 3> frags{
1747 MMAOperandFragment("A", getMultiplicandAPtxTypeAttrName()),
1748 MMAOperandFragment("B", getMultiplicandBPtxTypeAttrName()),
1749 MMAOperandFragment("C", "")};
1750 SmallVector<StringRef, 4> ignoreAttrNames{
1751 mlir::NVVM::MmaBlockScaleOp::getOperandSegmentSizeAttr()};
1752
1753 processOperandFragments(*this, frags, regTypes, ignoreAttrNames);
1754
1755 // Print A, B, C operands
1756 for (const auto &frag : frags)
1757 printOperandList(p, frag.operandName, frag.regs);
1758
1759 // Print scale operands
1760 printOperandList(p, "scaleA",
1761 {getScaleAData(), getByteIdA(), getThreadIdA()});
1762 printOperandList(p, "scaleB",
1763 {getScaleBData(), getByteIdB(), getThreadIdB()});
1764
1765 p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);
1766
1767 // Print type signature
1768 p << " : (";
1769 llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
1770 frags[1].regs[0].getType(),
1771 frags[2].regs[0].getType()},
1772 p);
1773 p << ")";
1774 p.printArrowTypeList(TypeRange{this->getRes().getType()});
1775}
1776
// Parses the custom form emitted by MmaBlockScaleOp::print. Scale operands
// are resolved with fixed types (i32 data, i16 byteId, i16 threadId).
// NOTE(review): this doxygen-derived listing elides original line 1778
// (presumably the `OperationState &result) {` parameter line) — consult the
// actual source before editing.
1777 ParseResult MmaBlockScaleOp::parse(OpAsmParser &parser,
1779 struct LocalOperandFragment {
  // elemtype is computed per fragment below but attribute handling is done
  // by inferAndSetMultiplicandTypes from the parsed operand types.
1780 std::optional<MMATypes> elemtype;
1781 SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
1782 };
1783
1784 Builder &builder = parser.getBuilder();
1785 std::array<LocalOperandFragment, 3> frags;
1786 NamedAttrList namedAttributes;
1787
1788 // Parse A[...] B[...] C[...]
1789 if (parseMmaOperand(parser, "A", frags[0].regs).failed() ||
1790 parseMmaOperand(parser, "B", frags[1].regs).failed() ||
1791 parseMmaOperand(parser, "C", frags[2].regs).failed())
1792 return failure();
1793
1794 // Parse scale operands: scaleA[...] scaleB[...]
1795 SmallVector<OpAsmParser::UnresolvedOperand, 3> scaleAOperands, scaleBOperands;
1796 if (parseMmaOperand(parser, "scaleA", scaleAOperands).failed() ||
1797 parseMmaOperand(parser, "scaleB", scaleBOperands).failed())
1798 return failure();
1799
1800 if (parser.parseOptionalAttrDict(namedAttributes).failed())
1801 return failure();
1802
1803 // Parse type signature
1804 SmallVector<Type, 3> operandTypes;
1805 if (parseMmaTypeSignature(parser, operandTypes).failed())
1806 return failure();
1807
1808 // Parse result type
1809 SmallVector<Type, 1> resultTypes;
1810 if (parser.parseArrowTypeList(resultTypes).failed())
1811 return failure();
1812
1813 // Infer element types and resolve operands
  // Every register of a fragment shares the fragment's signature type
  // (index 2, the C fragment, is treated as an accumulator).
1814 for (const auto &[idx, frag] : llvm::enumerate(frags)) {
1815 frag.elemtype = MmaOp::inferOperandMMAType(operandTypes[idx],
1816 /*isAccumulator=*/idx >= 2);
1817 if (parser
1818 .resolveOperands(frag.regs, operandTypes[idx], parser.getNameLoc(),
1819 result.operands)
1820 .failed())
1821 return failure();
1822 }
1823
1824 // Resolve scale operands
  // Fixed layout per scale group: i32 scale data, i16 byte id, i16 thread id.
1825 SmallVector<Type, 3> scaleTypes = {builder.getI32Type(), builder.getI16Type(),
1826 builder.getI16Type()};
1827 if (parser
1828 .resolveOperands(scaleAOperands, scaleTypes, parser.getNameLoc(),
1829 result.operands)
1830 .failed() ||
1831 parser
1832 .resolveOperands(scaleBOperands, scaleTypes, parser.getNameLoc(),
1833 result.operands)
1834 .failed())
1835 return failure();
1836
1837 // Add attributes
1838 result.addAttributes(namedAttributes);
1839 inferAndSetMultiplicandTypes(parser.getContext(), result.attributes,
1840 operandTypes);
1841
1842 result.addTypes(resultTypes);
  // Segment sizes: variadic A/B/C plus six singleton scale operands.
1843 result.addAttribute(MmaBlockScaleOp::getOperandSegmentSizeAttr(),
1844 builder.getDenseI32ArrayAttr({
1845 static_cast<int32_t>(frags[0].regs.size()),
1846 static_cast<int32_t>(frags[1].regs.size()),
1847 static_cast<int32_t>(frags[2].regs.size()),
1848 1, // scaleAData
1849 1, // byteIdA
1850 1, // threadIdA
1851 1, // scaleBData
1852 1, // byteIdB
1853 1 // threadIdB
1854 }));
1855 return success();
1856}
1857
/// Programmatic builder for nvvm.mma.block_scale. `shape` must contain
/// exactly (m, n, k). `multiplicandPtxTypes` is forwarded to
/// addInferredMultiplicandTypes, which presumably infers the PTX element
/// types from the operands when it is std::nullopt — confirm in that helper.
void MmaBlockScaleOp::build(
    OpBuilder &builder, OperationState &result, Type resultType,
    ValueRange operandA, ValueRange operandB, ValueRange operandC,
    Value scaleAData, Value byteIdA, Value threadIdA, Value scaleBData,
    Value byteIdB, Value threadIdB, ArrayRef<int64_t> shape,
    std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
    ScaleVecSize scaleVecSize, BlockScaleFormat blockScaleFormat,
    MMABlockScaleKind kind) {
  assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");

  // NOTE(review): the opening line of this call was lost when this listing
  // was extracted. Judging from MmaSpBlockScaleOp::build below, it is a
  // helper taking (builder, result, shape, scaleVecSize, blockScaleFormat,
  // kind) that records the shape/scale attributes — restore it from the
  // upstream sources.
      blockScaleFormat, kind);

  // Operand order must match the segment-size attribute below.
  result.addOperands(operandA);
  result.addOperands(operandB);
  result.addOperands(operandC);
  result.addOperands(
      {scaleAData, byteIdA, threadIdA, scaleBData, byteIdB, threadIdB});

  addInferredMultiplicandTypes(builder.getContext(), result, operandA, operandB,
                               multiplicandPtxTypes);

  result.addTypes(resultType);
  // The six scale operands are scalars, hence segment size 1 each.
  result.addAttribute(MmaBlockScaleOp::getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr({
                          static_cast<int32_t>(operandA.size()),
                          static_cast<int32_t>(operandB.size()),
                          static_cast<int32_t>(operandC.size()),
                          1, // scaleAData
                          1, // byteIdA
                          1, // threadIdA
                          1, // scaleBData
                          1, // byteIdB
                          1  // threadIdB
                      }));
}
1894
1895NVVM::IDArgPair MmaBlockScaleOp::getIntrinsicIDAndArgs(
1896 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1897 auto curOp = cast<NVVM::MmaBlockScaleOp>(op);
1898
1900 // Add A, B, C operands
1901 for (Value operand : curOp.getOperandA())
1902 args.push_back(mt.lookupValue(operand));
1903 for (Value operand : curOp.getOperandB())
1904 args.push_back(mt.lookupValue(operand));
1905 for (Value operand : curOp.getOperandC())
1906 args.push_back(mt.lookupValue(operand));
1907
1908 // Add scale operands
1909 args.push_back(mt.lookupValue(curOp.getScaleAData()));
1910 args.push_back(mt.lookupValue(curOp.getByteIdA()));
1911 args.push_back(mt.lookupValue(curOp.getThreadIdA()));
1912 args.push_back(mt.lookupValue(curOp.getScaleBData()));
1913 args.push_back(mt.lookupValue(curOp.getByteIdB()));
1914 args.push_back(mt.lookupValue(curOp.getThreadIdB()));
1915
1916 unsigned intId = MmaBlockScaleOp::getIntrinsicID(
1917 curOp.getShape().getM(), curOp.getShape().getN(), curOp.getShape().getK(),
1918 *curOp.getMultiplicandAPtxType(), *curOp.getMultiplicandBPtxType(),
1919 inferPtxTypeFromResult(curOp), curOp.getScaleVecSize(),
1920 curOp.getBlockScaleFormat(), curOp.getKind());
1921
1922 return {intId, args};
1923}
1924
1925LogicalResult MmaBlockScaleOp::verify() {
1926 LogicalResult result = success();
1927 int m = getShape().getM();
1928 int n = getShape().getN();
1929 int k = getShape().getK();
1930
1931 if (m == 16 && n == 8 && k == 64) {
1932 if (getMultiplicandAPtxType() != NVVM::MMATypes::e2m1 ||
1933 getMultiplicandBPtxType() != NVVM::MMATypes::e2m1)
1935 "unsupported MMATypes attribute for mma.m16n8k64.(mxf4nvf4|mxf4)");
1936 if (getKind() == NVVM::MMABlockScaleKind::MXF4) {
1937 if (getScaleVecSize() != NVVM::ScaleVecSize::X2)
1939 "unsupported ScaleVecSize attribute for mma.m16n8k64.mxf4");
1940 if (getBlockScaleFormat() != NVVM::BlockScaleFormat::UE8M0)
1942 "unsupported BlockScaleFormat attribute for mma.m16n8k64.mxf4");
1943 } else if (getKind() == NVVM::MMABlockScaleKind::MXF4NVF4) {
1944 if (!((getScaleVecSize() == NVVM::ScaleVecSize::X2 &&
1945 getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0) ||
1946 (getScaleVecSize() == NVVM::ScaleVecSize::X4 &&
1947 (getBlockScaleFormat() == NVVM::BlockScaleFormat::UE4M3 ||
1948 getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0))))
1949 result = emitOpError("unsupported ScaleVecSize and BlockScaleFormat "
1950 "attributes for mma.m16n8k64.mxf4nvf4");
1951 } else {
1952 result = emitOpError("unsupported Kind attribute for mma.m16n8k64");
1953 }
1954 } else if (m == 16 && n == 8 && k == 32) {
1955 if (!(getKind() == NVVM::MMABlockScaleKind::MXF8F6F4 &&
1956 getScaleVecSize() == NVVM::ScaleVecSize::X1 &&
1957 getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0))
1958 result =
1959 emitOpError("unsupported Kind, ScaleVecSize and BlockScaleFormat "
1960 "attributes for mma.m16n8k32");
1961 } else {
1962 result = emitOpError("unsupported Geom for mma with block scaling");
1963 }
1964 return result;
1965}
1966
1967//===----------------------------------------------------------------------===//
1968// MmaSpBlockScaleOp
1969//===----------------------------------------------------------------------===//
1970
1971void MmaSpBlockScaleOp::print(OpAsmPrinter &p) {
1972 SmallVector<Type, 4> regTypes;
1973 std::array<MMAOperandFragment, 3> frags{
1974 MMAOperandFragment("A", getMultiplicandAPtxTypeAttrName()),
1975 MMAOperandFragment("B", getMultiplicandBPtxTypeAttrName()),
1976 MMAOperandFragment("C", "")};
1977 SmallVector<StringRef, 4> ignoreAttrNames{
1978 mlir::NVVM::MmaSpBlockScaleOp::getOperandSegmentSizeAttr()};
1979
1980 processOperandFragments(*this, frags, regTypes, ignoreAttrNames);
1981
1982 // Print A, B, C operands
1983 for (const auto &frag : frags)
1984 printOperandList(p, frag.operandName, frag.regs);
1985
1986 // Print sparse-specific operands
1987 printOperandList(p, "sparseMetadata", {getSparseMetadata()});
1988 printOperandList(p, "selector", {getSparsitySelector()});
1989
1990 // Print scale operands
1991 printOperandList(p, "scaleA",
1992 {getScaleAData(), getByteIdA(), getThreadIdA()});
1993 printOperandList(p, "scaleB",
1994 {getScaleBData(), getByteIdB(), getThreadIdB()});
1995
1996 p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);
1997
1998 // Print type signature
1999 p << " : (";
2000 llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
2001 frags[1].regs[0].getType(),
2002 frags[2].regs[0].getType()},
2003 p);
2004 p << ")";
2005 p.printArrowTypeList(TypeRange{this->getRes().getType()});
2006}
2007
2008ParseResult MmaSpBlockScaleOp::parse(OpAsmParser &parser,
2010 struct LocalOperandFragment {
2011 std::optional<MMATypes> elemtype;
2012 SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
2013 };
2014
2015 Builder &builder = parser.getBuilder();
2016 std::array<LocalOperandFragment, 3> frags;
2017 NamedAttrList namedAttributes;
2018
2019 // Parse A[...] B[...] C[...]
2020 if (parseMmaOperand(parser, "A", frags[0].regs).failed() ||
2021 parseMmaOperand(parser, "B", frags[1].regs).failed() ||
2022 parseMmaOperand(parser, "C", frags[2].regs).failed())
2023 return failure();
2024
2025 // Parse sparse-specific operands
2027 selectorOperands;
2028 if (parseMmaOperand(parser, "sparseMetadata", metadataOperands).failed() ||
2029 parseMmaOperand(parser, "selector", selectorOperands).failed())
2030 return failure();
2031
2032 // Parse scale operands
2033 SmallVector<OpAsmParser::UnresolvedOperand, 3> scaleAOperands, scaleBOperands;
2034 if (parseMmaOperand(parser, "scaleA", scaleAOperands).failed() ||
2035 parseMmaOperand(parser, "scaleB", scaleBOperands).failed())
2036 return failure();
2037
2038 if (parser.parseOptionalAttrDict(namedAttributes).failed())
2039 return failure();
2040
2041 // Parse type signature
2042 SmallVector<Type, 3> operandTypes;
2043 if (parseMmaTypeSignature(parser, operandTypes).failed())
2044 return failure();
2045
2046 // Parse result type
2047 SmallVector<Type, 1> resultTypes;
2048 if (parser.parseArrowTypeList(resultTypes).failed())
2049 return failure();
2050
2051 // Infer element types and resolve operands
2052 for (const auto &[idx, frag] : llvm::enumerate(frags)) {
2053 frag.elemtype = MmaOp::inferOperandMMAType(operandTypes[idx],
2054 /*isAccumulator=*/idx >= 2);
2055 if (parser
2056 .resolveOperands(frag.regs, operandTypes[idx], parser.getNameLoc(),
2057 result.operands)
2058 .failed())
2059 return failure();
2060 }
2061
2062 // Resolve sparse metadata and selector
2063 Type i32Type = builder.getI32Type();
2064 if (parser
2065 .resolveOperands(metadataOperands, i32Type, parser.getNameLoc(),
2066 result.operands)
2067 .failed() ||
2068 parser
2069 .resolveOperands(selectorOperands, i32Type, parser.getNameLoc(),
2070 result.operands)
2071 .failed())
2072 return failure();
2073
2074 // Resolve scale operands
2075 SmallVector<Type, 3> scaleTypes = {i32Type, builder.getI16Type(),
2076 builder.getI16Type()};
2077 if (parser
2078 .resolveOperands(scaleAOperands, scaleTypes, parser.getNameLoc(),
2079 result.operands)
2080 .failed() ||
2081 parser
2082 .resolveOperands(scaleBOperands, scaleTypes, parser.getNameLoc(),
2083 result.operands)
2084 .failed())
2085 return failure();
2086
2087 // Add attributes
2088 result.addAttributes(namedAttributes);
2089 inferAndSetMultiplicandTypes(parser.getContext(), result.attributes,
2090 operandTypes);
2091
2092 // orderedMetadata is mandatory
2093 if (!result.attributes.get("orderedMetadata"))
2094 result.addAttribute("orderedMetadata", builder.getUnitAttr());
2095
2096 result.addTypes(resultTypes);
2097 result.addAttribute(MmaSpBlockScaleOp::getOperandSegmentSizeAttr(),
2098 builder.getDenseI32ArrayAttr({
2099 static_cast<int32_t>(frags[0].regs.size()),
2100 static_cast<int32_t>(frags[1].regs.size()),
2101 static_cast<int32_t>(frags[2].regs.size()),
2102 1, // sparseMetadata
2103 1, // sparsitySelector
2104 1, // scaleAData
2105 1, // byteIdA
2106 1, // threadIdA
2107 1, // scaleBData
2108 1, // byteIdB
2109 1 // threadIdB
2110 }));
2111 return success();
2112}
2113
/// Programmatic builder for nvvm.mma.sp.block_scale. `shape` must contain
/// exactly (m, n, k). The unit attribute `orderedMetadata` is always set, in
/// line with the "mandatory" handling in parse() and verify().
void MmaSpBlockScaleOp::build(
    OpBuilder &builder, OperationState &result, Type resultType,
    ValueRange operandA, ValueRange operandB, ValueRange operandC,
    Value sparseMetadata, Value sparsitySelector, Value scaleAData,
    Value byteIdA, Value threadIdA, Value scaleBData, Value byteIdB,
    Value threadIdB, ArrayRef<int64_t> shape,
    std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
    ScaleVecSize scaleVecSize, BlockScaleFormat blockScaleFormat,
    MMABlockScaleKind kind) {
  assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");

  // NOTE(review): the opening line of this call (the callee name) was lost
  // when this listing was extracted; the surviving argument list is
  // (builder, result, shape, scaleVecSize, blockScaleFormat, kind). Restore
  // it from the upstream sources.
      builder, result, shape, scaleVecSize, blockScaleFormat, kind);
  result.addAttribute("orderedMetadata", builder.getUnitAttr());

  // Operand order must match the segment-size attribute below.
  result.addOperands(operandA);
  result.addOperands(operandB);
  result.addOperands(operandC);
  result.addOperands({sparseMetadata, sparsitySelector, scaleAData, byteIdA,
                      threadIdA, scaleBData, byteIdB, threadIdB});

  addInferredMultiplicandTypes(builder.getContext(), result, operandA, operandB,
                               multiplicandPtxTypes);

  result.addTypes(resultType);
  // Trailing sparse/scale operands are scalars, hence segment size 1 each.
  result.addAttribute(MmaSpBlockScaleOp::getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr({
                          static_cast<int32_t>(operandA.size()),
                          static_cast<int32_t>(operandB.size()),
                          static_cast<int32_t>(operandC.size()),
                          1, // sparseMetadata
                          1, // sparsitySelector
                          1, // scaleAData
                          1, // byteIdA
                          1, // threadIdA
                          1, // scaleBData
                          1, // byteIdB
                          1  // threadIdB
                      }));
}
2154
2155NVVM::IDArgPair MmaSpBlockScaleOp::getIntrinsicIDAndArgs(
2156 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
2157 auto curOp = cast<NVVM::MmaSpBlockScaleOp>(op);
2158
2160 // Add A, B, C operands
2161 for (Value operand : curOp.getOperandA())
2162 args.push_back(mt.lookupValue(operand));
2163 for (Value operand : curOp.getOperandB())
2164 args.push_back(mt.lookupValue(operand));
2165 for (Value operand : curOp.getOperandC())
2166 args.push_back(mt.lookupValue(operand));
2167
2168 // Add sparse metadata and selector
2169 args.push_back(mt.lookupValue(curOp.getSparseMetadata()));
2170 args.push_back(mt.lookupValue(curOp.getSparsitySelector()));
2171
2172 // Add scale operands
2173 args.push_back(mt.lookupValue(curOp.getScaleAData()));
2174 args.push_back(mt.lookupValue(curOp.getByteIdA()));
2175 args.push_back(mt.lookupValue(curOp.getThreadIdA()));
2176 args.push_back(mt.lookupValue(curOp.getScaleBData()));
2177 args.push_back(mt.lookupValue(curOp.getByteIdB()));
2178 args.push_back(mt.lookupValue(curOp.getThreadIdB()));
2179
2180 unsigned intId = MmaSpBlockScaleOp::getIntrinsicID(
2181 curOp.getShape().getM(), curOp.getShape().getN(), curOp.getShape().getK(),
2182 *curOp.getMultiplicandAPtxType(), *curOp.getMultiplicandBPtxType(),
2183 inferPtxTypeFromResult(curOp), curOp.getScaleVecSize(),
2184 curOp.getBlockScaleFormat(), curOp.getKind());
2185
2186 return {intId, args};
2187}
2188
2189LogicalResult MmaSpBlockScaleOp::verify() {
2190 // Check that orderedMetadata is present
2191 if (!getOrderedMetadata()) {
2192 return emitOpError("'orderedMetadata' attribute is mandatory");
2193 }
2194
2195 LogicalResult result = success();
2196 int m = getShape().getM();
2197 int n = getShape().getN();
2198 int k = getShape().getK();
2199
2200 if (m == 16 && n == 8 && k == 128) {
2201 if (getMultiplicandAPtxType() != NVVM::MMATypes::e2m1 ||
2202 getMultiplicandBPtxType() != NVVM::MMATypes::e2m1)
2204 "unsupported MMATypes attribute for mma.m16n8k128.(mxf4nvf4|mxf4)");
2205 if (getKind() == NVVM::MMABlockScaleKind::MXF4) {
2206 if (getScaleVecSize() != NVVM::ScaleVecSize::X2)
2208 "unsupported ScaleVecSize attribute for mma.m16n8k128.mxf4");
2209 if (getBlockScaleFormat() != NVVM::BlockScaleFormat::UE8M0)
2211 "unsupported BlockScaleFormat attribute for mma.m16n8k128.mxf4");
2212 } else if (getKind() == NVVM::MMABlockScaleKind::MXF4NVF4) {
2213 if (!((getScaleVecSize() == NVVM::ScaleVecSize::X2 &&
2214 getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0) ||
2215 (getScaleVecSize() == NVVM::ScaleVecSize::X4 &&
2216 (getBlockScaleFormat() == NVVM::BlockScaleFormat::UE4M3 ||
2217 getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0))))
2218 result = emitOpError("unsupported ScaleVecSize and BlockScaleFormat "
2219 "attributes for mma.m16n8k128.mxf4nvf4");
2220 } else {
2221 result = emitOpError("unsupported Kind attribute for mma.m16n8k128");
2222 }
2223 } else if (m == 16 && n == 8 && k == 64) {
2224 if (!(getKind() == NVVM::MMABlockScaleKind::MXF8F6F4 &&
2225 getScaleVecSize() == NVVM::ScaleVecSize::X1 &&
2226 getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0))
2227 result =
2228 emitOpError("unsupported Kind, ScaleVecSize and BlockScaleFormat "
2229 "attributes for mma.m16n8k64");
2230 } else {
2231 result = emitOpError("unsupported Geom for sparse mma with block scaling");
2232 }
2233 return result;
2234}
2235
2236LogicalResult ShflOp::verify() {
2237 auto returnStructType = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
2238
2239 auto verifyTypeError = [&](Twine desc, Type expectedType,
2240 Type actualType) -> LogicalResult {
2241 return emitOpError("expected " + desc + " to be of type ")
2242 << expectedType << " but got " << actualType << " instead";
2243 };
2244
2245 if (returnStructType) {
2246 if (!getReturnValueAndIsValid())
2247 return emitOpError("\"return_value_and_is_valid\" attribute must be "
2248 "specified when the return type is a struct type");
2249
2250 if (returnStructType.getBody().size() != 2)
2251 return emitOpError("expected return type to be a two-element struct");
2252
2253 llvm::ArrayRef<Type> returnStruct = returnStructType.getBody();
2254 auto resultType = returnStruct[0];
2255 if (resultType != getVal().getType())
2256 return verifyTypeError("first element in the returned struct",
2257 getVal().getType(), resultType);
2258
2259 auto predicateType = returnStruct[1];
2260 if (!predicateType.isInteger(1))
2261 return verifyTypeError("second element in the returned struct",
2262 mlir::IntegerType::get(getContext(), 1),
2263 predicateType);
2264 } else {
2265 if (getReturnValueAndIsValid())
2266 return emitOpError("expected return type to be a two-element struct");
2267
2268 if (getType() != getVal().getType())
2269 return verifyTypeError("return type", getVal().getType(), getType());
2270 }
2271 return success();
2272}
2273
2274std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
2275 NVVM::MMAFrag frag, int nRow,
2276 int nCol,
2277 MLIRContext *context) {
2278 unsigned numberElements = 0;
2279 Type elementType;
2280 OpBuilder builder(context);
2281 Type f16x2 = VectorType::get(2, builder.getF16Type());
2282 if (type == NVVM::MMATypes::f16) {
2283 elementType = f16x2;
2284 if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
2285 numberElements = 8;
2286 else
2287 numberElements = 4;
2288 } else if (type == NVVM::MMATypes::f32) {
2289 elementType = builder.getF32Type();
2290 numberElements = 8;
2291 } else if (type == NVVM::MMATypes::f64) {
2292 elementType = builder.getF64Type();
2293 if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
2294 numberElements = 1;
2295 else
2296 numberElements = 2;
2297 } else if (type == NVVM::MMATypes::tf32) {
2298 elementType = builder.getI32Type();
2299 numberElements = 4;
2300 } else if (type == NVVM::MMATypes::s8 || type == NVVM::MMATypes::u8) {
2301 elementType = builder.getI32Type();
2302 int parallelSize = 0;
2303 if (frag == NVVM::MMAFrag::a)
2304 parallelSize = nRow;
2305 if (frag == NVVM::MMAFrag::b)
2306 parallelSize = nCol;
2307
2308 // m == 16 && n == 16 && k == 16
2309 if (parallelSize == 16)
2310 numberElements = 2;
2311 // m == 8 && n == 32 && k == 16 or m == 32 && n == 8 && k == 16
2312 else if (parallelSize == 8)
2313 numberElements = 1;
2314 else if (parallelSize == 32)
2315 numberElements = 4;
2316 } else if (type == NVVM::MMATypes::s32) {
2317 elementType = builder.getI32Type();
2318 numberElements = 8;
2319 }
2320 assert(numberElements != 0 && elementType != nullptr);
2321 return std::make_pair(elementType, numberElements);
2322}
2323
2324static std::pair<mlir::Type, unsigned>
2325inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n,
2326 int k, MLIRContext *context) {
2327 int nRow, nCol;
2328 if (frag == NVVM::MMAFrag::a) {
2329 nRow = m;
2330 nCol = k;
2331 } else if (frag == NVVM::MMAFrag::b) {
2332 nRow = k;
2333 nCol = n;
2334 } else {
2335 nRow = m;
2336 nCol = n;
2337 }
2338 assert(nRow && nCol);
2339 return inferMMAType(type, frag, nRow, nCol, context);
2340}
2341
2342LogicalResult NVVM::WMMALoadOp::verify() {
2343 unsigned addressSpace =
2344 llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
2345 if (addressSpace != 0 && addressSpace != NVVMMemorySpace::Global &&
2346 addressSpace != NVVMMemorySpace::Shared)
2347 return emitOpError("expected source pointer in memory "
2348 "space 0, 1, 3");
2349
2350 if (NVVM::WMMALoadOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
2351 getEltype(), getFrag()) == 0)
2352 return emitOpError() << "invalid attribute combination";
2353 std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
2354 getEltype(), getFrag(), getM(), getN(), getK(), getContext());
2355 // Special case for f64 fragments
2356 Type f64Ty = Float64Type::get(getContext());
2357 if (typeInfo.first == f64Ty && typeInfo.second == 1) {
2358 if (getType() != f64Ty)
2359 return emitOpError("expected destination type to be f64");
2360 return success();
2361 }
2362 // Everything else is a struct
2363 Type dstType = LLVM::LLVMStructType::getLiteral(
2364 getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
2365 if (getType() != dstType)
2366 return emitOpError("expected destination type is a structure of ")
2367 << typeInfo.second << " elements of type " << typeInfo.first;
2368 return success();
2369}
2370
2371LogicalResult NVVM::WMMAStoreOp::verify() {
2372 unsigned addressSpace =
2373 llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
2374 if (addressSpace != 0 && addressSpace != NVVMMemorySpace::Global &&
2375 addressSpace != NVVMMemorySpace::Shared)
2376 return emitOpError("expected operands to be a source pointer in memory "
2377 "space 0, 1, 3");
2378
2379 if (NVVM::WMMAStoreOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
2380 getEltype()) == 0)
2381 return emitOpError() << "invalid attribute combination";
2382 std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
2383 getEltype(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
2384 if (getArgs().size() != typeInfo.second)
2385 return emitOpError() << "expected " << typeInfo.second << " data operands";
2386 if (llvm::any_of(getArgs(), [&typeInfo](Value operands) {
2387 return operands.getType() != typeInfo.first;
2388 }))
2389 return emitOpError() << "expected data operands of type " << typeInfo.first;
2390 return success();
2391}
2392
/// Verifier for nvvm.wmma.mma: checks the attribute combination and that the
/// flattened argument list and destination struct match the fragment shapes
/// inferred from (m, n, k).
LogicalResult NVVM::WMMAMmaOp::verify() {
  // A zero intrinsic ID means no PTX instruction matches the attributes.
  if (NVVM::WMMAMmaOp::getIntrinsicID(getM(), getN(), getK(), getLayoutA(),
                                      getLayoutB(), getEltypeA(),
                                      getEltypeB()) == 0)
    return emitOpError() << "invalid attribute combination";
  // NOTE(review): the B fragment is inferred from getEltypeA() and the C
  // fragment from getEltypeB() — i.e. `eltypeA` appears to describe both
  // multiplicands and `eltypeB` the accumulator. Confirm against the op
  // definition before "fixing" this apparent asymmetry.
  std::pair<Type, unsigned> typeInfoA = inferMMATypeFromMNK(
      getEltypeA(), NVVM::MMAFrag::a, getM(), getN(), getK(), getContext());
  std::pair<Type, unsigned> typeInfoB = inferMMATypeFromMNK(
      getEltypeA(), NVVM::MMAFrag::b, getM(), getN(), getK(), getContext());
  std::pair<Type, unsigned> typeInfoC = inferMMATypeFromMNK(
      getEltypeB(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
  // Expected flattened argument list: A regs, then B regs, then C regs.
  SmallVector<Type, 32> arguments;
  arguments.append(typeInfoA.second, typeInfoA.first);
  arguments.append(typeInfoB.second, typeInfoB.first);
  arguments.append(typeInfoC.second, typeInfoC.first);
  unsigned numArgs = arguments.size();
  if (getArgs().size() != numArgs)
    return emitOpError() << "expected " << numArgs << " arguments";
  for (unsigned i = 0; i < numArgs; i++) {
    if (getArgs()[i].getType() != arguments[i])
      return emitOpError() << "expected argument " << i << " to be of type "
                           << arguments[i];
  }
  // The destination accumulator has the same shape as the C fragment.
  Type dstType = LLVM::LLVMStructType::getLiteral(
      getContext(), SmallVector<Type, 8>(typeInfoC.second, typeInfoC.first));
  if (getType() != dstType)
    return emitOpError("expected destination type is a structure of ")
           << typeInfoC.second << " elements of type " << typeInfoC.first;
  return success();
}
2423
/// Verifier for nvvm.ldmatrix: checks the supported num/shape/layout/element
/// type combinations and that the destination is i32 (one register) or a
/// homogeneous struct of i32 (multiple registers).
LogicalResult NVVM::LdMatrixOp::verify() {
  uint32_t num = getNum(), m = getShape().getM(), n = getShape().getN();
  // 8x8 tiles: 1, 2, or 4 matrices of b16 elements; any layout.
  if (m == 8 && n == 8) {
    if (num != 1 && num != 2 && num != 4) {
      return emitOpError("expected num attribute to be 1, 2 or 4 for 8x8 "
                         "matrix");
    }
    if (getEltType() != LdStMatrixEltType::B16) {
      return emitOpError("expected element type to be b16 for 8x8 matrix");
    }
  // 8x16 tiles: row layout only, packed sub-byte element types.
  } else if (m == 8 && n == 16) {
    if (num != 1 && num != 2 && num != 4) {
      return emitOpError("expected num attribute to be 1, 2 or 4 for 8x16 "
                         "matrix");
    }
    if (getLayout() != MMALayout::row) {
      return emitOpError("expected layout to be row for 8x16 matrix");
    }
    if (getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 &&
        getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) {
      return emitOpError("expected element type to be b8x16.b4x16_p64 or "
                         "b8x16.b6x16_p32 for 8x16 matrix");
    }
  // 16x16 tiles: col layout only, num limited to 1 or 2.
  } else if (m == 16 && n == 16) {
    if (num != 1 && num != 2) {
      return emitOpError("expected num attribute to be 1 or 2 for 16x16 "
                         "matrix");
    }
    if (getLayout() != MMALayout::col) {
      return emitOpError("expected layout to be col for 16x16 matrix");
    }
    if (getEltType() != LdStMatrixEltType::B8 &&
        getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 &&
        getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) {
      return emitOpError("expected element type to be b8, b8x16.b4x16_p64 or "
                         "b8x16.b6x16_p32 for 16x16 matrix");
    }
  } else {
    return emitOpError("expected shape to be 8x8, 8x16 or 16x16");
  }

  // A 16x16 tile yields two i32 registers per matrix; other shapes one.
  Type i32 = IntegerType::get(getContext(), 32);
  uint32_t numElements = (m == 16 && n == 16 ? num * 2 : num);
  if (numElements == 1 && getType() != i32)
    return emitOpError("expected destination type is i32");
  if (numElements == 2 || numElements == 4) {
    Type dstType = LLVM::LLVMStructType::getLiteral(
        getContext(), SmallVector<Type>(numElements, i32));
    if (getType() != dstType)
      return emitOpError("expected destination type is a structure of ")
             << numElements << " elements of type i32";
  }

  return success();
}
2479
2480LogicalResult NVVM::StMatrixOp::verify() {
2481 int numMatrix = getSources().size();
2482 if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
2483 return emitOpError("expected num attribute to be 1, 2 or 4");
2484
2485 int m = getShape().getM(), n = getShape().getN();
2486 if (m == 8 && n == 8) {
2487 if (getEltType() != NVVM::LdStMatrixEltType::B16) {
2488 return emitOpError("expected element type to be B16 for 8x8 matrix");
2489 }
2490 } else if (m == 16 && n == 8) {
2491 if (getEltType() != NVVM::LdStMatrixEltType::B8) {
2492 return emitOpError("expected element type to be B8 for 16x8 matrix");
2493 }
2494 if (getLayout() != NVVM::MMALayout::col) {
2495 return emitOpError("expected layout to be col for 16x8 matrix");
2496 }
2497 } else {
2498 return emitOpError("expected shape to be 8x8 or 16x8");
2499 }
2500
2501 return success();
2502}
2503
2504static FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) {
2505 if (typeA == NVVM::WGMMATypes::tf32)
2506 return 8;
2507 if (typeA == NVVM::WGMMATypes::f16 || typeA == NVVM::WGMMATypes::bf16)
2508 return 16;
2509 if (typeA == NVVM::WGMMATypes::s8 || typeA == NVVM::WGMMATypes::u8)
2510 return 32;
2511 if (typeA == NVVM::WGMMATypes::e4m3 || typeA == NVVM::WGMMATypes::e5m2)
2512 return 32;
2513 if (typeA == NVVM::WGMMATypes::b1)
2514 return 256;
2515 return failure();
2516}
2517
2518static LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD,
2519 NVVM::WGMMATypes typeA,
2520 NVVM::WGMMATypes typeB) {
2521 switch (typeA) {
2522 case NVVM::WGMMATypes::f16:
2523 if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
2524 typeB == NVVM::WGMMATypes::f16)
2525 return success();
2526 break;
2527 case NVVM::WGMMATypes::tf32:
2528 if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32)
2529 return success();
2530 break;
2531 case NVVM::WGMMATypes::u8:
2532 case NVVM::WGMMATypes::s8:
2533 if (typeD == NVVM::WGMMATypes::s32 &&
2534 (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8))
2535 return success();
2536 break;
2537 case NVVM::WGMMATypes::b1:
2538 if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1)
2539 return success();
2540 break;
2541 case NVVM::WGMMATypes::bf16:
2542 if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
2543 typeB == NVVM::WGMMATypes::bf16)
2544 return success();
2545 break;
2546 case NVVM::WGMMATypes::e4m3:
2547 case NVVM::WGMMATypes::e5m2:
2548 if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
2549 (typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3))
2550 return success();
2551 break;
2552 case WGMMATypes::f32:
2553 case WGMMATypes::s32:
2554 llvm_unreachable("unsupported input types");
2555 break;
2556 }
2557 return failure();
2558}
2559
2560static LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA) {
2561 SmallVector<int> allowedN = {8, 16, 24, 32, 40, 48, 56, 64,
2562 72, 80, 88, 96, 104, 112, 120, 128,
2563 136, 144, 152, 160, 168, 176, 184, 192,
2564 200, 208, 216, 224, 232, 240, 248, 256};
2565 SmallVector<int> allowedNshort = {8, 16, 24, 32, 48, 64,
2566 80, 96, 112, 128, 144, 160,
2567 176, 192, 208, 224, 240, 256};
2568 switch (typeA) {
2569 case WGMMATypes::f16:
2570 case WGMMATypes::tf32:
2571 case WGMMATypes::bf16:
2572 case WGMMATypes::e4m3:
2573 case WGMMATypes::e5m2:
2574 if (llvm::is_contained(allowedN, sizeN))
2575 return success();
2576 break;
2577 case WGMMATypes::u8:
2578 case WGMMATypes::s8:
2579 case WGMMATypes::b1:
2580 if (llvm::is_contained(allowedNshort, sizeN))
2581 return success();
2582 break;
2583 case WGMMATypes::f32:
2584 case WGMMATypes::s32:
2585 llvm_unreachable("unsupported input types");
2586 break;
2587 }
2588 return failure();
2589}
2590
/// Verifier for nvvm.wgmma.mma_async: checks the result struct, the
/// typeD/typeA/typeB combination, the m/n/k geometry, layout (transpose)
/// restrictions, the result register count, and the satfinite modifier.
LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
  Value outValue = getResults();
  auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
  if (!stype)
    return emitOpError() << "expected results to be struct";
  int outputSize = stype.getBody().size();
  WGMMATypes typeD = getTypeD();
  WGMMATypes typeA = getTypeA();
  WGMMATypes typeB = getTypeB();

  // The accumulator struct must be homogeneous.
  for (Type t : stype.getBody()) {
    if (t != stype.getBody().front())
      return emitOpError()
             << "all elements in struct must be same type but there is " << t;
  }

  if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
      typeD != WGMMATypes::s32) {
    return emitOpError() << "does not support the given output type " << typeD;
  }
  // Input negation (scale = neg) only exists for floating-point inputs.
  if (typeD == WGMMATypes::s32 &&
      (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
    return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg";
  }

  if (failed(isAllowedWGMMADataType(typeD, typeA, typeB))) {
    return emitOpError() << typeD << " += " << typeA << " * " << typeB
                         << ", it is not supported.";
  }

  // Check M
  if (getShape().getM() != 64)
    return emitOpError() << "shape 'm' must be 64";

  // Check K
  // NOTE(review): if allowedK failed, the .value() call in the diagnostic
  // below would be invalid; today this is unreachable because the only
  // types without an allowed K (f32/s32 inputs) are rejected by
  // isAllowedWGMMADataType above.
  FailureOr<int> allowedK = getAllowedSizeK(typeA);
  if (failed(allowedK) || allowedK.value() != getShape().getK())
    return emitOpError() << "shape 'k' must be " << allowedK.value()
                         << " for input type " << typeA;

  // Check N
  if (failed(isAllowedSizeN(getShape().getN(), typeA))) {
    return emitOpError() << "has input type " << typeA << " n is set to "
                         << getShape().getN() << ", it is not supported.";
  }

  // Check transpose (only available for f16/bf16)
  // Matrices A should be stored in row-major and B in column-major.
  // Only f16/bf16 matrices can be stored in either column-major or row-major
  // by setting the transpose value(imm-trans-a,imm-trans-b) in PTX code.
  if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
      (getLayoutA() == mlir::NVVM::MMALayout::col ||
       getLayoutB() == mlir::NVVM::MMALayout::row)) {
    return emitOpError()
           << "given layouts layout_a = " << getLayoutA()
           << " and layout_b = " << getLayoutB() << " for input types " << typeA
           << " and " << typeB
           << " requires transpose. However, this is only supported for: "
           << MMATypes::f16 << " and " << MMATypes::bf16;
  }

  // Check result registers: n/2 registers for 32-bit accumulators (f32/s32),
  // n/4 for f16.
  int expectedOutput = 0;
  if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32)
    expectedOutput = getShape().getN() / 2;
  if (typeD == WGMMATypes::f16)
    expectedOutput = getShape().getN() / 4;
  if (outputSize != expectedOutput) {
    return emitOpError() << "results " << expectedOutput
                         << ", however output struct has " << outputSize
                         << " elements";
  }
  // Check satfinite (only available for s32 accumulator)
  if (typeD != WGMMATypes::s32 &&
      getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
          NVVM::MMAIntOverflow::satfinite) {
    return emitOpError()
           << " `satfinite` can be only used with s32 accumulator, however "
              "the current accumulator is "
           << typeD;
  }

  return success();
}
2675
/// Builds the inline-PTX string for `wgmma.mma_async`.
/// The `$N` operand indices used here must line up with the operand order
/// produced by WgmmaMmaAsyncOp::getAsmValues(): output registers first (and
/// the same registers again as read/write inouts), then descriptor A,
/// descriptor B, the scale-d predicate, optional scale-a/scale-b, and the
/// optional transpose immediates for f16/bf16.
std::string NVVM::WgmmaMmaAsyncOp::getPtx() {

  int m = getShape().getM(), n = getShape().getN(), k = getShape().getK();
  // Only f16/bf16 inputs carry the extra transpose (imm-trans-a/b) operands.
  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;

  StringRef outputTypeName = stringifyWGMMATypes(getTypeD());

  // f16 accumulators pack two halves per 32-bit register (N/4 registers);
  // f32/s32 use one register per two columns (N/2 registers).
  int expectedOutputRegisters = 0;
  if (getTypeD() == WGMMATypes::f16)
    expectedOutputRegisters = getShape().getN() / 4;
  else
    expectedOutputRegisters = getShape().getN() / 2;

  std::string ptx;
  llvm::raw_string_ostream ss(ptx);

  // Predicate register p holds the scale-d flag; its asm operand index is
  // 2*outputs (outputs counted twice for the inout mapping) + 2 (after the
  // two descriptors).
  ss << "{\n"
        ".reg .pred p;\n"
        "setp.ne.b32 p, $"
     << ((expectedOutputRegisters * 2) + 2)
     << ", 0;\n"
        "wgmma.mma_async.sync.aligned.m"
     << m << "n" << n << "k" << k << "." << outputTypeName << "." << getTypeA()
     << "." << getTypeB();
  if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
      NVVM::MMAIntOverflow::satfinite)
    ss << ".satfinite";
  ss << " {";
  // Emit the comma-separated output register list: $0 .. $(outputs-1).
  int regCnt = 0;
  for (; regCnt < expectedOutputRegisters; ++regCnt) {
    ss << "$" << regCnt;
    if (regCnt != expectedOutputRegisters - 1)
      ss << ", ";
  }

  ss << "},";
  // Need to map read/write registers correctly.
  regCnt = (regCnt * 2);
  // $regCnt / $regCnt+1 are descriptor A and B; $regCnt+2 (consumed by setp
  // above) is the scale-d flag.
  ss << " $" << (regCnt) << ","
     << " $" << (regCnt + 1) << ","
     << " p";
  if (getTypeD() != WGMMATypes::s32) {
    // scale-a / scale-b immediates (only for non-s32 accumulators).
    ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4);
  }
  // Don't add transpose parameters unless needed.
  if (isF16) {
    ss << ", $" << (regCnt + 5) << ", $" << (regCnt + 6);
  }
  ss << ";\n"
     << "}\n";
  return ptx;
}
2728
2729bool NVVM::WgmmaMmaAsyncOp::getAsmValues(
2730 RewriterBase &rewriter,
2731 llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
2732 &asmValues) {
2733 bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
2734 if (getResults())
2735 asmValues.push_back({getResults(), mlir::NVVM::PTXRegisterMod::Write});
2736 if (getInouts())
2737 asmValues.push_back({getInouts(), mlir::NVVM::PTXRegisterMod::ReadWrite});
2738 asmValues.push_back({getDescriptorA(), mlir::NVVM::PTXRegisterMod::Read});
2739 asmValues.push_back({getDescriptorB(), mlir::NVVM::PTXRegisterMod::Read});
2740 asmValues.push_back({makeConstantI32(rewriter, static_cast<int>(getScaleD())),
2742 if (getTypeD() != WGMMATypes::s32) {
2743 asmValues.push_back(
2744 {makeConstantI32(rewriter,
2745 getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
2747 asmValues.push_back(
2748 {makeConstantI32(rewriter,
2749 getScaleB() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
2751 }
2752 if (isF16) {
2753 asmValues.push_back(
2754 {makeConstantI32(rewriter, static_cast<int>(getLayoutA())),
2756 asmValues.push_back(
2757 {makeConstantI32(rewriter, 1 - static_cast<int>(getLayoutB())),
2759 }
2760 return true; // Has manual mapping
2761}
2762
2763LogicalResult NVVM::FenceProxyOp::verify() {
2764 if (getKind() == NVVM::ProxyKind::async_shared && !getSpace().has_value()) {
2765 return emitOpError() << "async_shared fence requires space attribute";
2766 }
2767 if (getKind() != NVVM::ProxyKind::async_shared && getSpace().has_value()) {
2768 return emitOpError() << "only async_shared fence can have space attribute";
2769 }
2770 return success();
2771}
2772
2773LogicalResult NVVM::FenceProxyAcquireOp::verify() {
2774 if (getFromProxy() != NVVM::ProxyKind::GENERIC)
2775 return emitOpError("uni-directional proxies only support generic for "
2776 "from_proxy attribute");
2777
2778 if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
2779 return emitOpError("uni-directional proxies only support tensormap "
2780 "for to_proxy attribute");
2781 return success();
2782}
2783
2784LogicalResult NVVM::FenceProxyReleaseOp::verify() {
2785 if (getFromProxy() != NVVM::ProxyKind::GENERIC)
2786 return emitOpError("uni-directional proxies only support generic for "
2787 "from_proxy attribute");
2788
2789 if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
2790 return emitOpError("uni-directional proxies only support tensormap "
2791 "for to_proxy attribute");
2792 return success();
2793}
2794
2795LogicalResult NVVM::FenceProxySyncRestrictOp::verify() {
2796 if (getFromProxy() != NVVM::ProxyKind::GENERIC)
2797 return emitOpError("only generic is support for from_proxy attribute");
2798
2799 if (getToProxy() != NVVM::ProxyKind::async)
2800 return emitOpError("only async is supported for to_proxy attribute");
2801 return success();
2802}
2803
2804LogicalResult NVVM::SetMaxRegisterOp::verify() {
2805 if (getRegCount() % 8)
2806 return emitOpError("new register size must be multiple of 8");
2807 if (getRegCount() < 24 || getRegCount() > 256)
2808 return emitOpError("new register size must be in between 24 to 256");
2809 return success();
2810}
2811
2812LogicalResult NVVM::BarrierOp::verify() {
2813 if (getNumberOfThreads() && !getBarrierId())
2814 return emitOpError(
2815 "barrier id is missing, it should be set between 0 to 15");
2816
2817 if (getBarrierId() && (getReductionOp() || getReductionPredicate()))
2818 return emitOpError("reduction are only available when id is 0");
2819
2820 if ((getReductionOp() && !getReductionPredicate()) ||
2821 (!getReductionOp() && getReductionPredicate()))
2822 return emitOpError("reduction predicate and reduction operation must be "
2823 "specified together");
2824
2825 return success();
2826}
2827
/// Verifies that the multicast qualifier of a tcgen05.cp op is compatible
/// with its shape: full shapes take no multicast, 64x128b requires a
/// two-way warp split, and 32x128b requires a four-way warp split.
LogicalResult NVVM::Tcgen05CpOp::verify() {
  auto mc = getMulticast();

  using SH = Tcgen05CpShape;
  using MC = Tcgen05CpMulticast;
  switch (getShape()) {
  case SH::SHAPE_128x256b:
  case SH::SHAPE_128x128b:
  case SH::SHAPE_4x256b:
    // These shapes do not accept any multicast qualifier.
    if (mc != MC::NONE)
      return emitError("Invalid multicast type for tcgen05.cp Op");
    break;
  case SH::SHAPE_64x128b:
    // 64x128b must be multicast across a pair of warps.
    if (mc != MC::WARPX2_01_23 && mc != MC::WARPX2_02_13)
      return emitError("Shape 64x128b requires multicast warpx2_01_23 or "
                       "warpx2_02_13 for tcgen05.cp Op");
    break;
  case SH::SHAPE_32x128b:
    // 32x128b must be multicast across four warps.
    if (mc != MC::WARPX4)
      return emitError(
          "Shape 32x128b requires multicast warpx4 for tcgen05.cp Op");
    break;
  }
  return success();
}
2853
2854LogicalResult NVVM::MatchSyncOp::verify() {
2855 if (getKind() == NVVM::MatchSyncKind::all) {
2856 auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
2857 if (!type || type.getBody().size() != 2 ||
2858 !type.getBody()[0].isInteger(32) || !type.getBody()[1].isInteger(1)) {
2859 return emitOpError("match.sync 'all' returns a two element struct with "
2860 "first element as i32 and second element as i1");
2861 }
2862 } else {
2863 if (!getType().isInteger(32)) {
2864 return emitOpError("match.sync 'any' returns an i32");
2865 }
2866 }
2867 return success();
2868}
2869
2870LogicalResult NVVM::VoteSyncOp::verify() {
2871 if (getKind() == NVVM::VoteSyncKind::ballot) {
2872 if (!getType().isInteger(32)) {
2873 return emitOpError("vote.sync 'ballot' returns an i32");
2874 }
2875 } else {
2876 if (!getType().isInteger(1)) {
2877 return emitOpError("vote.sync 'any', 'all' and 'uni' returns an i1");
2878 }
2879 }
2880 return success();
2881}
2882
/// Verifies nvvm.prefetch. The op has two mutually exclusive forms:
/// a tensormap prefetch and a cache-level prefetch, each constraining the
/// pointer's address space and the allowed modifiers differently.
LogicalResult NVVM::PrefetchOp::verify() {
  using MemSpace = NVVM::NVVMMemorySpace;
  using CacheLevel = NVVM::PrefetchCacheLevel;

  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getAddr().getType()).getAddressSpace();
  std::optional<NVVM::CacheEvictionPriority> evictPriority = getEvictPriority();
  std::optional<NVVM::PrefetchCacheLevel> cacheLevel = getCacheLevel();

  // The tensormap and cache-level forms are mutually exclusive.
  if (getTensormap() && cacheLevel)
    return emitOpError("cannot specify both tensormap and cache level");

  if (getTensormap()) {
    // Tensormap form: generic/constant pointer only, no eviction priority,
    // and in_param_space is only meaningful for generic pointers.
    if (addressSpace != MemSpace::Generic &&
        addressSpace != MemSpace::Constant) {
      return emitOpError(
          "prefetch tensormap requires a generic or constant pointer");
    }

    if (evictPriority) {
      return emitOpError(
          "prefetch tensormap does not support eviction priority");
    }

    if (getInParamSpace() && addressSpace != MemSpace::Generic) {
      return emitOpError(
          "in_param_space can only be specified for a generic pointer");
    }

  } else if (cacheLevel) {
    // Cache-level form: generic/global/local pointer only.
    if (addressSpace != MemSpace::Generic && addressSpace != MemSpace::Global &&
        addressSpace != MemSpace::Local) {
      return emitOpError("prefetch to cache level requires a generic, global, "
                         "or local pointer");
    }

    // `uniform` is restricted to L1 with a generic pointer.
    if (getUniform()) {
      if (*cacheLevel != CacheLevel::L1) {
        return emitOpError(
            "unsupported cache level, the only supported uniform "
            "cache level is L1");
      }

      if (addressSpace != MemSpace::Generic) {
        return emitOpError(
            "prefetch to uniform cache requires a generic pointer");
      }
    }

    // Eviction priority is restricted to L2 with a global pointer, and only
    // evict_normal / evict_last are accepted.
    if (evictPriority) {
      if (*cacheLevel != CacheLevel::L2)
        return emitOpError(
            "cache eviction priority supported only for cache level L2");

      if (addressSpace != MemSpace::Global)
        return emitOpError("cache eviction priority requires a global pointer");

      if (*evictPriority != NVVM::CacheEvictionPriority::EvictNormal &&
          *evictPriority != NVVM::CacheEvictionPriority::EvictLast)
        return emitOpError(
            "unsupported cache eviction priority, only evict_last and "
            "evict_normal are supported");
    }

    // The predicate operand is only valid on the tensormap form.
    if (getPredicate())
      return emitOpError("predicate supported only on prefetch tensormap");

  } else {
    return emitOpError(
        "requires specification of either cache level or tensormap");
  }

  return success();
}
2957
2958LogicalResult NVVM::ClusterLaunchControlQueryCancelOp::verify() {
2959 switch (getQueryType()) {
2960 case NVVM::ClusterLaunchControlQueryType::IS_CANCELED:
2961 if (!getType().isInteger(1))
2962 return emitOpError("is_canceled query type returns an i1");
2963 break;
2964 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_X:
2965 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Y:
2966 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Z:
2967 if (!getType().isInteger(32)) {
2968 return emitOpError("get_first_cta_id_x, get_first_cta_id_y, "
2969 "get_first_cta_id_z query types return an i32");
2970 }
2971 break;
2972 }
2973 return success();
2974}
2975
2976LogicalResult NVVM::ReduxOp::verify() {
2977 mlir::Type reduxType = getType();
2978
2979 if (!reduxType.isF32()) {
2980 if (getAbs())
2981 return emitOpError("abs attribute is supported only for f32 type");
2982 if (getNan())
2983 return emitOpError("nan attribute is supported only for f32 type");
2984 }
2985
2986 NVVM::ReductionKind kind = getKind();
2987 switch (kind) {
2988 case NVVM::ReductionKind::ADD:
2989 case NVVM::ReductionKind::AND:
2990 case NVVM::ReductionKind::OR:
2991 case NVVM::ReductionKind::XOR:
2992 case NVVM::ReductionKind::MAX:
2993 case NVVM::ReductionKind::MIN:
2994 case NVVM::ReductionKind::UMAX:
2995 case NVVM::ReductionKind::UMIN:
2996 if (!reduxType.isInteger(32))
2997 return emitOpError("'")
2998 << kind << "' reduction kind unsupported with " << reduxType
2999 << " type. Only supported type is 'i32'.";
3000 break;
3001 case NVVM::ReductionKind::FMIN:
3002 case NVVM::ReductionKind::FMAX:
3003 if (!reduxType.isF32())
3004 return emitOpError("'")
3005 << kind << "' reduction kind unsupported with " << reduxType
3006 << " type. Only supported type is 'f32'.";
3007 break;
3008 }
3009
3010 return success();
3011}
3012
/// Verifies nvvm.tensormap.replace: each tensormap field takes either a
/// value operand (`new_value`, with a field-specific integer width) or an
/// enum attribute (`new_value_attr`), and only the array-like fields accept
/// (or require) an ordinal.
LogicalResult NVVM::TensormapReplaceOp::verify() {
  auto ord = getOrd();
  Value newVal = getNewValue();
  auto newValAttr = getNewValueAttr();
  auto fieldName = stringifyEnum(getField());

  // Only the array-like fields may carry an ordinal at all.
  if (ord && !llvm::is_contained({NVVM::TensormapField::BOX_DIM,
                                  NVVM::TensormapField::GLOBAL_DIM,
                                  NVVM::TensormapField::GLOBAL_STRIDE,
                                  NVVM::TensormapField::ELEMENT_STRIDE},
                                 getField()))
    return emitOpError("ordinal is not supported for ")
           << fieldName << " field";

  // Diagnostic builder for fields expecting a typed `new_value` operand.
  auto invalidNewVal = [&](llvm::Twine type) -> std::string {
    return llvm::Twine("new_value must be specified and must be an " + type +
                       " for " + llvm::Twine(fieldName) + " field")
        .str();
  };

  // Diagnostic builder for fields expecting a `new_value_attr` enum.
  auto invalidNewValAttr = [&]() -> std::string {
    return (llvm::Twine(
                "new_value_attr must be specified and must be a valid ") +
            llvm::Twine(fieldName) + " attribute for " + fieldName + " field")
        .str();
  };

  switch (getField()) {
  case NVVM::TensormapField::GLOBAL_ADDRESS:
    if (!(newVal && newVal.getType().isInteger(64)))
      return emitOpError(invalidNewVal("i64"));
    break;
  case NVVM::TensormapField::RANK:
    if (!(newVal && newVal.getType().isInteger(32)))
      return emitOpError(invalidNewVal("i32"));
    break;
  case NVVM::TensormapField::GLOBAL_STRIDE:
    // Array-like field: the ordinal selects which stride to replace.
    if (!ord)
      return emitOpError("ordinal is required for global_stride field");
    if (!(newVal && newVal.getType().isInteger(64)))
      return emitOpError(invalidNewVal("i64"));
    break;
  case NVVM::TensormapField::BOX_DIM:
  case NVVM::TensormapField::GLOBAL_DIM:
  case NVVM::TensormapField::ELEMENT_STRIDE:
    // Array-like fields with 32-bit elements.
    if (!ord)
      return emitOpError("ordinal is required for ")
             << stringifyEnum(getField()) << " field";
    if (!(newVal && newVal.getType().isInteger(32)))
      return emitOpError(invalidNewVal("i32"));
    break;
  case NVVM::TensormapField::ELEMTYPE:
    if (!(newValAttr && llvm::isa<TensormapElemtypeAttr>(*newValAttr)))
      return emitOpError(invalidNewValAttr());
    break;
  case NVVM::TensormapField::INTERLEAVE_LAYOUT:
    if (!(newValAttr && llvm::isa<TensormapInterleaveLayoutAttr>(*newValAttr)))
      return emitOpError(invalidNewValAttr());
    break;
  case NVVM::TensormapField::SWIZZLE_MODE:
    if (!(newValAttr && llvm::isa<TensormapSwizzleModeAttr>(*newValAttr)))
      return emitOpError(invalidNewValAttr());
    break;
  case NVVM::TensormapField::SWIZZLE_ATOMICITY:
    if (!(newValAttr && llvm::isa<TensormapSwizzleAtomicityAttr>(*newValAttr)))
      return emitOpError(invalidNewValAttr());
    break;
  case NVVM::TensormapField::FILL_MODE:
    if (!(newValAttr && llvm::isa<TensormapFillModeAttr>(*newValAttr)))
      return emitOpError(invalidNewValAttr());
    break;
  }

  return success();
}
3088
3089template <typename OpType>
3090static LogicalResult verifyAddSubFOp(OpType op) {
3091 mlir::NVVM::FPRoundingMode rndMode = op.getRnd();
3092 mlir::NVVM::SaturationMode satMode = op.getSat();
3093 bool isFTZ = op.getFtz();
3094
3095 mlir::Type opType = op.getRes().getType();
3096 mlir::Type opBaseType = isa<VectorType>(opType)
3097 ? cast<VectorType>(opType).getElementType()
3098 : opType;
3099
3100 if (opBaseType.isF64() && (satMode != NVVM::SaturationMode::NONE || isFTZ))
3101 return op.emitOpError("FTZ and saturation are not supported for "
3102 "additions/subtractions involving f64 type");
3103
3104 if (opBaseType.isF16() && !(rndMode == NVVM::FPRoundingMode::RN ||
3105 rndMode == NVVM::FPRoundingMode::NONE))
3106 return op.emitOpError("only RN rounding mode is supported for f16 and "
3107 "vector<2xf16> additions/subtractions");
3108
3109 if (opBaseType.isBF16()) {
3110 if (rndMode != NVVM::FPRoundingMode::RN &&
3111 rndMode != NVVM::FPRoundingMode::NONE)
3112 return op.emitOpError("only RN rounding mode is supported for bf16 and "
3113 "vector<2xbf16> additions/subtractions");
3114 if (satMode != NVVM::SaturationMode::NONE || isFTZ)
3115 return op.emitOpError("FTZ and saturation are not supported for bf16 and "
3116 "vector<2xbf16> additions/subtractions");
3117 }
3118
3119 // FIXME: This is a temporary check disallowing lowering to add.rn.ftz.f16(x2)
3120 // PTX instructions since the corresponding LLVM intrinsic is missing. This
3121 // should be removed once the intrinsics for f16 addition (with FTZ only) are
3122 // available.
3123 if (opBaseType.isF16() && isFTZ && satMode == NVVM::SaturationMode::NONE)
3124 return op.emitOpError("FTZ with no saturation is not supported for f16 and "
3125 "vector<2xf16> additions/subtractions");
3126
3127 return success();
3128}
3129
/// Modifier checks are shared with SubFOp via verifyAddSubFOp.
LogicalResult NVVM::AddFOp::verify() { return verifyAddSubFOp<AddFOp>(*this); }
3131
/// Modifier checks are shared with AddFOp via verifyAddSubFOp.
LogicalResult NVVM::SubFOp::verify() { return verifyAddSubFOp<SubFOp>(*this); }
3133
3134LogicalResult NVVM::FmaOp::verify() {
3135 auto opType = getRes().getType();
3136 mlir::NVVM::FPRoundingMode rndMode = getRnd();
3137 mlir::NVVM::SaturationMode satMode = getSat();
3138 bool isFTZ = getFtz();
3139 bool isRelu = getRelu();
3140 bool hasOOB = getOob();
3141
3142 auto getBaseFType = [](Type type) -> Type {
3143 if (isa<VectorType>(type))
3144 return cast<VectorType>(type).getElementType();
3145 return type;
3146 };
3147
3148 auto opBaseType = getBaseFType(opType);
3149
3150 if (rndMode == NVVM::FPRoundingMode::NONE)
3151 return emitOpError("rounding mode must be specified");
3152
3153 if (isRelu && satMode == NVVM::SaturationMode::SAT)
3154 return emitOpError("relu and saturation are not supported together");
3155
3156 if (hasOOB && (satMode == NVVM::SaturationMode::SAT || isFTZ))
3157 return emitOpError("oob is not supported with saturation or FTZ");
3158
3159 if (!(opBaseType.isF16() || opBaseType.isBF16()) && (isRelu || hasOOB))
3160 return emitOpError("relu and oob are only supported for f16 and bf16");
3161
3162 if (opBaseType.isF64() && (satMode != NVVM::SaturationMode::NONE || isFTZ))
3163 return emitOpError("FTZ and saturation are not supported for f64 type");
3164
3165 if (opBaseType.isF16() && rndMode != NVVM::FPRoundingMode::RN)
3166 return emitOpError(
3167 "only RN rounding mode is supported for f16 and vector<2xf16>");
3168
3169 if (opBaseType.isBF16()) {
3170 if (rndMode != NVVM::FPRoundingMode::RN)
3171 return emitOpError(
3172 "only RN rounding mode is supported for bf16 and vector<2xbf16>");
3173 if (satMode != NVVM::SaturationMode::NONE || isFTZ)
3174 return emitOpError(
3175 "FTZ and saturation are not supported for bf16 and vector<2xbf16>");
3176 }
3177
3178 return success();
3179}
3180
3181/// Packs the given `field` into the `result`.
3182/// The `result` is 64-bits and each `field` can be 32-bits or narrower.
3183static llvm::Value *
3184packValInto64Bits(llvm::IRBuilderBase &builder,
3185 llvm::Value *result, // the `result` (unset bits are zero)
3186 llvm::Value *field, // `field` to pack into `result`
3187 unsigned sizeInBits, // Size of `field` in bits
3188 unsigned start) { // Starting bit within `result`
3189 field = builder.CreateZExtOrBitCast(field, builder.getInt32Ty());
3190
3191 unsigned mask = (sizeInBits < 32 ? ((1u << sizeInBits) - 1) : 0xffffffffu);
3192 if (mask != 0xffffffffu)
3193 field = builder.CreateAnd(field, builder.getInt32(mask));
3194
3195 field = builder.CreateZExtOrBitCast(field, builder.getInt64Ty());
3196 field = builder.CreateShl(field, start);
3197
3198 return builder.CreateOr(result, field);
3199}
3200
3201void Tcgen05MmaSmemDescOp::createSmemDescriptor(Operation &op,
3203 llvm::IRBuilderBase &builder) {
3204 auto thisOp = cast<NVVM::Tcgen05MmaSmemDescOp>(op);
3205 llvm::Value *smemDesc = builder.getInt64(0);
3206
3207 smemDesc = packValInto64Bits(builder, smemDesc,
3208 mt.lookupValue(thisOp.getStartAddr()), 14, 0);
3209 smemDesc = packValInto64Bits(
3210 builder, smemDesc, mt.lookupValue(thisOp.getLeadingDimOffset()), 14, 16);
3211 smemDesc = packValInto64Bits(
3212 builder, smemDesc, mt.lookupValue(thisOp.getStrideDimOffset()), 14, 32);
3213
3214 smemDesc = packValInto64Bits(builder, smemDesc, builder.getInt32(1), 3, 46);
3215 smemDesc = packValInto64Bits(builder, smemDesc,
3216 mt.lookupValue(thisOp.getBaseOffset()), 3, 49);
3217 smemDesc = packValInto64Bits(
3218 builder, smemDesc, mt.lookupValue(thisOp.getLeadingDimMode()), 1, 52);
3219 smemDesc = packValInto64Bits(builder, smemDesc,
3220 mt.lookupValue(thisOp.getSwizzleMode()), 3, 61);
3221
3222 mt.mapValue(thisOp.getRes()) = smemDesc;
3223}
3224
3225//===----------------------------------------------------------------------===//
3226// getPtx methods
3227//===----------------------------------------------------------------------===//
3228
3229std::string NVVM::MBarrierInitOp::getPtx() {
3230 bool isShared = isPtrInSharedCTASpace(getAddr());
3231 return isShared ? std::string("mbarrier.init.shared.b64 [%0], %1;")
3232 : std::string("mbarrier.init.b64 [%0], %1;");
3233}
3234
3235std::string NVVM::MBarrierArriveExpectTxOp::getPtx() {
3236 bool isShared = isPtrInSharedCTASpace(getAddr());
3237 return isShared
3238 ? std::string("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;")
3239 : std::string("mbarrier.arrive.expect_tx.b64 _, [%0], %1;");
3240}
3241
3242std::string NVVM::MBarrierTryWaitParityOp::getPtx() {
3243 bool isShared = isPtrInSharedCTASpace(getAddr());
3244 llvm::StringRef space = isShared ? ".shared" : "";
3245
3246 return llvm::formatv("{\n\t"
3247 ".reg .pred P1; \n\t"
3248 "LAB_WAIT: \n\t"
3249 "mbarrier.try_wait.parity{0}.b64 P1, [%0], %1, %2; \n\t"
3250 "@P1 bra.uni DONE; \n\t"
3251 "bra.uni LAB_WAIT; \n\t"
3252 "DONE: \n\t"
3253 "}",
3254 space);
3255}
3256
3257//===----------------------------------------------------------------------===//
3258// Canonicalization patterns
3259//===----------------------------------------------------------------------===//
3260
3263
3264 LogicalResult matchAndRewrite(SubFOp op,
3265 PatternRewriter &rewriter) const override {
3266 Location loc = op.getLoc();
3267 Value negRhs =
3268 LLVM::FNegOp::create(rewriter, loc, op.getRhs().getType(), op.getRhs());
3269
3270 rewriter.replaceOpWithNewOp<AddFOp>(op, op.getType(), op.getLhs(), negRhs,
3271 op.getRnd(), op.getSat(), op.getFtz());
3272 return success();
3273 }
3274};
3275
/// Registers the fsub -> fneg + fadd canonicalization for nvvm.sub.f.
void SubFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                         MLIRContext *context) {
  patterns.add<ConvertFsubToFnegFadd>(context);
}
3280
3281//===----------------------------------------------------------------------===//
3282// getIntrinsicID/getIntrinsicIDAndArgs methods
3283//===----------------------------------------------------------------------===//
3284
3285mlir::NVVM::IDArgPair NVVM::BarrierOp::getIntrinsicIDAndArgs(
3286 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3287 auto thisOp = cast<NVVM::BarrierOp>(op);
3288 llvm::Value *barrierId = thisOp.getBarrierId()
3289 ? mt.lookupValue(thisOp.getBarrierId())
3290 : builder.getInt32(0);
3291 llvm::Intrinsic::ID id;
3292 llvm::SmallVector<llvm::Value *> args = {barrierId};
3293 if (thisOp.getNumberOfThreads()) {
3294 id = llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_count;
3295 args.push_back(mt.lookupValue(thisOp.getNumberOfThreads()));
3296 } else if (thisOp.getReductionOp()) {
3297 switch (*thisOp.getReductionOp()) {
3298 case NVVM::BarrierReduction::AND:
3299 id = llvm::Intrinsic::nvvm_barrier_cta_red_and_aligned_all;
3300 break;
3301 case NVVM::BarrierReduction::OR:
3302 id = llvm::Intrinsic::nvvm_barrier_cta_red_or_aligned_all;
3303 break;
3304 case NVVM::BarrierReduction::POPC:
3305 id = llvm::Intrinsic::nvvm_barrier_cta_red_popc_aligned_all;
3306 break;
3307 }
3308 args.push_back(builder.CreateICmpNE(
3309 mt.lookupValue(thisOp.getReductionPredicate()), builder.getInt32(0)));
3310 } else {
3311 id = llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_all;
3312 }
3313
3314 return {id, std::move(args)};
3315}
3316
3318PMEventOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
3319 llvm::IRBuilderBase &builder) {
3320 auto thisOp = cast<NVVM::PMEventOp>(op);
3321 llvm::Type *i16Ty = llvm::Type::getInt16Ty(mt.getLLVMContext());
3322
3323 // With event-id, mask is generated as (1 << event-id)
3324 llvm::Value *maskVal;
3325 if (auto eventAttr = thisOp.getEventIdAttr()) {
3326 uint16_t mask = static_cast<uint16_t>(1u << eventAttr.getInt());
3327 maskVal = llvm::ConstantInt::get(i16Ty, mask);
3328 } else {
3329 maskVal =
3330 llvm::ConstantInt::get(i16Ty, thisOp.getMaskedEventIdAttr().getValue());
3331 }
3332
3333 return {llvm::Intrinsic::nvvm_pm_event_mask, {maskVal}};
3334}
3335
3336mlir::NVVM::IDArgPair MBarrierInitOp::getIntrinsicIDAndArgs(
3337 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3338 auto thisOp = cast<NVVM::MBarrierInitOp>(op);
3339 bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
3340 llvm::Intrinsic::ID id = isShared ? llvm::Intrinsic::nvvm_mbarrier_init_shared
3341 : llvm::Intrinsic::nvvm_mbarrier_init;
3342
3343 // Fill the Intrinsic Args
3345 args.push_back(mt.lookupValue(thisOp.getAddr()));
3346 args.push_back(mt.lookupValue(thisOp.getCount()));
3347
3348 return {id, std::move(args)};
3349}
3350
3351mlir::NVVM::IDArgPair MBarrierInvalOp::getIntrinsicIDAndArgs(
3352 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3353 auto thisOp = cast<NVVM::MBarrierInvalOp>(op);
3354 bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
3355 llvm::Intrinsic::ID id = isShared
3356 ? llvm::Intrinsic::nvvm_mbarrier_inval_shared
3357 : llvm::Intrinsic::nvvm_mbarrier_inval;
3358
3359 return {id, {mt.lookupValue(thisOp.getAddr())}};
3360}
3361
3362mlir::NVVM::IDArgPair MBarrierExpectTxOp::getIntrinsicIDAndArgs(
3363 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3364 auto thisOp = cast<NVVM::MBarrierExpectTxOp>(op);
3365
3366 bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr());
3367 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3368 // bit-0: Space
3369 // bit-1: Scope
3370 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3371
3372 static constexpr llvm::Intrinsic::ID IDs[] = {
3373 llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cta_space_cta,
3374 llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cta_space_cluster,
3375 llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cluster_space_cta,
3376 llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cluster_space_cluster};
3377
3378 // Fill the Intrinsic Args
3380 args.push_back(mt.lookupValue(thisOp.getAddr()));
3381 args.push_back(mt.lookupValue(thisOp.getTxcount()));
3382
3383 return {IDs[index], std::move(args)};
3384}
3385
3386mlir::NVVM::IDArgPair MBarrierCompleteTxOp::getIntrinsicIDAndArgs(
3387 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3388 auto thisOp = cast<NVVM::MBarrierCompleteTxOp>(op);
3389
3390 bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr());
3391 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3392 // bit-0: Space
3393 // bit-1: Scope
3394 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3395
3396 static constexpr llvm::Intrinsic::ID IDs[] = {
3397 llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cta_space_cta,
3398 llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cta_space_cluster,
3399 llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cluster_space_cta,
3400 llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cluster_space_cluster};
3401
3402 // Fill the Intrinsic Args
3404 args.push_back(mt.lookupValue(thisOp.getAddr()));
3405 args.push_back(mt.lookupValue(thisOp.getTxcount()));
3406
3407 return {IDs[index], std::move(args)};
3408}
3409
/// Selects the mbarrier.arrive intrinsic variant from the pointer's shared
/// space (CTA vs. cluster), the op's memory scope, and the `relaxed` flag,
/// falling back to the legacy sm_80 intrinsic for the simplest form.
mlir::NVVM::IDArgPair MBarrierArriveOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::MBarrierArriveOp>(op);

  bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr());
  bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
  // bit-0: Space
  // bit-1: Scope
  size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);

  // Intrinsic tables indexed by the scope/space encoding above; the second
  // table holds the `relaxed`-ordering variants.
  static constexpr llvm::Intrinsic::ID IDs[] = {
      llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cta_space_cta,
      llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cta_space_cluster,
      llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cluster_space_cta,
      llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cluster_space_cluster};
  static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
      llvm::Intrinsic::nvvm_mbarrier_arrive_relaxed_scope_cta_space_cta,
      llvm::Intrinsic::nvvm_mbarrier_arrive_relaxed_scope_cta_space_cluster,
      llvm::Intrinsic::nvvm_mbarrier_arrive_relaxed_scope_cluster_space_cta,
      llvm::Intrinsic::
          nvvm_mbarrier_arrive_relaxed_scope_cluster_space_cluster};
  auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];

  // Tidy-up the Intrinsic Args
  // A generic pointer is address-space-cast to shared before the call.
  bool needCast = isPtrInGenericSpace(thisOp.getAddr());
  llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
  if (needCast)
    mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);

  // We have the most basic mbarrier.arrive supported on sm_80.
  // It supports: Space=cta, scope=cta, No relaxed, No explicit count.
  // So, only for this combination use the legacy intrinsic.
  bool hasCount = static_cast<bool>(thisOp.getCount());
  if (!hasCount &&
      (id == llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cta_space_cta))
    return {llvm::Intrinsic::nvvm_mbarrier_arrive_shared, {mbar}};

  // When count is not explicitly specified, the default is 1.
  llvm::LLVMContext &ctx = mt.getLLVMContext();
  llvm::Value *count =
      hasCount ? mt.lookupValue(thisOp.getCount())
               : llvm::ConstantInt::get(llvm::Type::getInt32Ty(ctx), 1);
  return {id, {mbar, count}};
}
3454
/// Selects the mbarrier.arrive_drop intrinsic variant from the pointer's
/// shared space (CTA vs. cluster), the op's memory scope, and the `relaxed`
/// flag. Unlike MBarrierArriveOp, there is no legacy-intrinsic fallback.
mlir::NVVM::IDArgPair MBarrierArriveDropOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::MBarrierArriveDropOp>(op);

  bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr());
  bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
  // bit-0: Space
  // bit-1: Scope
  size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);

  // Intrinsic tables indexed by the scope/space encoding above; the second
  // table holds the `relaxed`-ordering variants.
  static constexpr llvm::Intrinsic::ID IDs[] = {
      llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cta_space_cta,
      llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cta_space_cluster,
      llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cluster_space_cta,
      llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cluster_space_cluster};
  static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
      llvm::Intrinsic::nvvm_mbarrier_arrive_drop_relaxed_scope_cta_space_cta,
      llvm::Intrinsic::
          nvvm_mbarrier_arrive_drop_relaxed_scope_cta_space_cluster,
      llvm::Intrinsic::
          nvvm_mbarrier_arrive_drop_relaxed_scope_cluster_space_cta,
      llvm::Intrinsic::
          nvvm_mbarrier_arrive_drop_relaxed_scope_cluster_space_cluster};
  auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];

  // Tidy-up the Intrinsic Args
  // A generic pointer is address-space-cast to shared before the call.
  bool needCast = isPtrInGenericSpace(thisOp.getAddr());
  llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
  if (needCast)
    mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);

  // When count is not explicitly specified, the default is 1.
  llvm::LLVMContext &ctx = mt.getLLVMContext();
  bool hasCount = static_cast<bool>(thisOp.getCount());
  llvm::Value *count =
      hasCount ? mt.lookupValue(thisOp.getCount())
               : llvm::ConstantInt::get(llvm::Type::getInt32Ty(ctx), 1);

  return {id, {mbar, count}};
}
3495
3496bool MBarrierArriveExpectTxOp::getAsmValues(
3497 RewriterBase &rewriter,
3498 llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
3499 &asmValues) {
3500 // Add all the operands but not the attrs to the asmValues list.
3501 // The attrs here are used to generate the right variants for
3502 // intrinsics-lowering. So, we ignore them while generating inline-PTX.
3503 for (auto val : getOperands())
3504 asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read});
3505
3506 return false;
3507}
3508
3509mlir::NVVM::IDArgPair MBarrierArriveExpectTxOp::getIntrinsicIDAndArgs(
3510 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3511 auto thisOp = cast<NVVM::MBarrierArriveExpectTxOp>(op);
3512
3513 bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr());
3514 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3515 // bit-0: Space
3516 // bit-1: Scope
3517 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3518
3519 // clang-format off
3520 static constexpr llvm::Intrinsic::ID IDs[] = {
3521 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cta_space_cta,
3522 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cta_space_cluster,
3523 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cluster_space_cta,
3524 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cluster_space_cluster};
3525 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3526 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cta_space_cta,
3527 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cta_space_cluster,
3528 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cluster_space_cta,
3529 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cluster_space_cluster};
3530 // clang-format on
3531 auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];
3532
3533 // Tidy-up the Intrinsic Args
3534 llvm::Value *txcount = mt.lookupValue(thisOp.getTxcount());
3535 llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
3536 bool needCast = isPtrInGenericSpace(thisOp.getAddr());
3537 if (needCast)
3538 mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);
3539
3540 return {id, {mbar, txcount}};
3541}
3542
3543mlir::NVVM::IDArgPair MBarrierArriveDropExpectTxOp::getIntrinsicIDAndArgs(
3544 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3545 auto thisOp = cast<NVVM::MBarrierArriveDropExpectTxOp>(op);
3546
3547 bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr());
3548 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3549 // bit-0: Space
3550 // bit-1: Scope
3551 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3552
3553 // clang-format off
3554 static constexpr llvm::Intrinsic::ID IDs[] = {
3555 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cta_space_cta,
3556 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cta_space_cluster,
3557 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cluster_space_cta,
3558 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cluster_space_cluster};
3559 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3560 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cta_space_cta,
3561 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cta_space_cluster,
3562 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cluster_space_cta,
3563 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cluster_space_cluster};
3564 // clang-format on
3565 auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];
3566
3567 // Tidy-up the Intrinsic Args
3568 llvm::Value *txcount = mt.lookupValue(thisOp.getTxcount());
3569 llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
3570 bool needCast = isPtrInGenericSpace(thisOp.getAddr());
3571 if (needCast)
3572 mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);
3573
3574 return {id, {mbar, txcount}};
3575}
3576
3577mlir::NVVM::IDArgPair MBarrierArriveNocompleteOp::getIntrinsicIDAndArgs(
3578 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3579 auto thisOp = cast<NVVM::MBarrierArriveNocompleteOp>(op);
3580 bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
3581 llvm::Intrinsic::ID id =
3582 isShared ? llvm::Intrinsic::nvvm_mbarrier_arrive_noComplete_shared
3583 : llvm::Intrinsic::nvvm_mbarrier_arrive_noComplete;
3584 // Fill the Intrinsic Args
3586 args.push_back(mt.lookupValue(thisOp.getAddr()));
3587 args.push_back(mt.lookupValue(thisOp.getCount()));
3588
3589 return {id, std::move(args)};
3590}
3591
3592mlir::NVVM::IDArgPair MBarrierArriveDropNocompleteOp::getIntrinsicIDAndArgs(
3593 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3594 auto thisOp = cast<NVVM::MBarrierArriveDropNocompleteOp>(op);
3595 bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
3596 llvm::Intrinsic::ID id =
3597 isShared ? llvm::Intrinsic::nvvm_mbarrier_arrive_drop_noComplete_shared
3598 : llvm::Intrinsic::nvvm_mbarrier_arrive_drop_noComplete;
3599 // Fill the Intrinsic Args
3601 args.push_back(mt.lookupValue(thisOp.getAddr()));
3602 args.push_back(mt.lookupValue(thisOp.getCount()));
3603
3604 return {id, std::move(args)};
3605}
3606
3607mlir::NVVM::IDArgPair MBarrierTestWaitOp::getIntrinsicIDAndArgs(
3608 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3609 auto thisOp = cast<NVVM::MBarrierTestWaitOp>(op);
3610 bool isPhaseParity = thisOp.getStateOrPhase().getType().isInteger(32);
3611 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3612 // bit-0: isPhaseParity
3613 // bit-1: Scope
3614 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isPhaseParity ? 1 : 0);
3615
3616 // clang-format off
3617 static constexpr llvm::Intrinsic::ID IDs[] = {
3618 llvm::Intrinsic::nvvm_mbarrier_test_wait_scope_cta_space_cta,
3619 llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_scope_cta_space_cta,
3620 llvm::Intrinsic::nvvm_mbarrier_test_wait_scope_cluster_space_cta,
3621 llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_scope_cluster_space_cta};
3622 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3623 llvm::Intrinsic::nvvm_mbarrier_test_wait_relaxed_scope_cta_space_cta,
3624 llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_relaxed_scope_cta_space_cta,
3625 llvm::Intrinsic::nvvm_mbarrier_test_wait_relaxed_scope_cluster_space_cta,
3626 llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_relaxed_scope_cluster_space_cta};
3627 // clang-format on
3628 auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];
3629
3630 // Tidy-up the Intrinsic Args
3631 llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
3632 llvm::Value *input = mt.lookupValue(thisOp.getStateOrPhase());
3633 bool needCast = isPtrInGenericSpace(thisOp.getAddr());
3634 if (needCast)
3635 mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);
3636
3637 return {id, {mbar, input}};
3638}
3639
3640mlir::NVVM::IDArgPair MBarrierTryWaitOp::getIntrinsicIDAndArgs(
3641 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3642 auto thisOp = cast<NVVM::MBarrierTryWaitOp>(op);
3643 bool isPhaseParity = thisOp.getStateOrPhase().getType().isInteger(32);
3644 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3645 bool hasTicks = static_cast<bool>(thisOp.getTicks());
3646 // bit-0: isPhaseParity
3647 // bit-1: Scope
3648 // bit-2: hasTicks
3649 size_t index = ((hasTicks ? 1 : 0) << 2) | ((isClusterScope ? 1 : 0) << 1) |
3650 (isPhaseParity ? 1 : 0);
3651
3652 // clang-format off
3653 static constexpr llvm::Intrinsic::ID IDs[] = {
3654 llvm::Intrinsic::nvvm_mbarrier_try_wait_scope_cta_space_cta,
3655 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_scope_cta_space_cta,
3656 llvm::Intrinsic::nvvm_mbarrier_try_wait_scope_cluster_space_cta,
3657 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_scope_cluster_space_cta,
3658 llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_scope_cta_space_cta,
3659 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_scope_cta_space_cta,
3660 llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_scope_cluster_space_cta,
3661 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_scope_cluster_space_cta};
3662 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3663 llvm::Intrinsic::nvvm_mbarrier_try_wait_relaxed_scope_cta_space_cta,
3664 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_relaxed_scope_cta_space_cta,
3665 llvm::Intrinsic::nvvm_mbarrier_try_wait_relaxed_scope_cluster_space_cta,
3666 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_relaxed_scope_cluster_space_cta,
3667 llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_relaxed_scope_cta_space_cta,
3668 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_relaxed_scope_cta_space_cta,
3669 llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_relaxed_scope_cluster_space_cta,
3670 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_relaxed_scope_cluster_space_cta};
3671 // clang-format on
3672 auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];
3673
3674 // Tidy-up the mbarrier pointer
3675 llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
3676 bool needCast = isPtrInGenericSpace(thisOp.getAddr());
3677 if (needCast)
3678 mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);
3679
3680 // Fill the Intrinsic Args
3682 args.push_back(mbar);
3683 args.push_back(mt.lookupValue(thisOp.getStateOrPhase()));
3684 if (hasTicks)
3685 args.push_back(mt.lookupValue(thisOp.getTicks()));
3686
3687 return {id, std::move(args)};
3688}
3689
3690mlir::NVVM::IDArgPair CpAsyncMBarrierArriveOp::getIntrinsicIDAndArgs(
3691 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3692 auto thisOp = cast<NVVM::CpAsyncMBarrierArriveOp>(op);
3693 bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
3694
3695 llvm::Intrinsic::ID id;
3696 if (thisOp.getNoinc()) {
3697 id = isShared ? llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_noinc_shared
3698 : llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_noinc;
3699 } else {
3700 id = isShared ? llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_shared
3701 : llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive;
3702 }
3703
3704 return {id, {mt.lookupValue(thisOp.getAddr())}};
3705}
3706
// Builds the name of a cp.async intrinsic from its cache modifier (ca/cg),
// copy size, and an optional `_s` suffix for the variant that takes an
// explicit cp-size operand.
#define CP_ASYNC_ID_IMPL(mod, size, suffix)                                    \
  llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix

// Selects between the plain and the `_s` (explicit cp-size) intrinsic.
// Expects a boolean expression as `has_cpsize`; expands to a conditional, so
// callers should use it in expression position (as done below in CpAsyncOp).
#define GET_CP_ASYNC_ID(mod, size, has_cpsize)                                 \
  has_cpsize ? CP_ASYNC_ID_IMPL(mod, size, _s) : CP_ASYNC_ID_IMPL(mod, size, )
3712
3713llvm::Intrinsic::ID
3714CpAsyncOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
3716 llvm::Intrinsic::ID id;
3717
3718 auto cpAsyncOp = cast<NVVM::CpAsyncOp>(op);
3719 bool hasCpSize = static_cast<bool>(cpAsyncOp.getCpSize());
3720 switch (cpAsyncOp.getSize()) {
3721 case 4:
3722 id = GET_CP_ASYNC_ID(ca, 4, hasCpSize);
3723 break;
3724 case 8:
3725 id = GET_CP_ASYNC_ID(ca, 8, hasCpSize);
3726 break;
3727 case 16:
3728 id = (cpAsyncOp.getModifier() == NVVM::LoadCacheModifierKind::CG)
3729 ? GET_CP_ASYNC_ID(cg, 16, hasCpSize)
3730 : GET_CP_ASYNC_ID(ca, 16, hasCpSize);
3731 break;
3732 default:
3733 llvm_unreachable("Invalid copy size in CpAsyncOp.");
3734 }
3735
3736 // Fill the Intrinsic Args
3737 args.push_back(mt.lookupValue(cpAsyncOp.getDst()));
3738 args.push_back(mt.lookupValue(cpAsyncOp.getSrc()));
3739 if (hasCpSize)
3740 args.push_back(mt.lookupValue(cpAsyncOp.getCpSize()));
3741
3742 return id;
3743}
3744
3745mlir::NVVM::IDArgPair CpAsyncBulkPrefetchOp::getIntrinsicIDAndArgs(
3746 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3747 auto thisOp = cast<NVVM::CpAsyncBulkPrefetchOp>(op);
3749 llvm::Intrinsic::ID id = llvm::Intrinsic::nvvm_cp_async_bulk_prefetch_L2;
3750
3751 // Fill the Intrinsic Args
3752 args.push_back(mt.lookupValue(thisOp.getSrcMem()));
3753 args.push_back(mt.lookupValue(thisOp.getSize()));
3754
3755 mlir::Value cacheHint = thisOp.getL2CacheHint();
3756 const bool hasCacheHint = static_cast<bool>(cacheHint);
3757 llvm::Value *i64Unused =
3758 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
3759 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
3760 args.push_back(builder.getInt1(hasCacheHint));
3761
3762 return {id, std::move(args)};
3763}
3764
3765mlir::NVVM::IDArgPair CpAsyncBulkGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
3766 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3767 auto thisOp = cast<NVVM::CpAsyncBulkGlobalToSharedClusterOp>(op);
3769
3770 // Fill the Intrinsic Args: dst, mbar, src, size.
3771 args.push_back(mt.lookupValue(thisOp.getDstMem()));
3772 args.push_back(mt.lookupValue(thisOp.getMbar()));
3773 args.push_back(mt.lookupValue(thisOp.getSrcMem()));
3774 args.push_back(mt.lookupValue(thisOp.getSize()));
3775
3776 // Multicast mask for shared::cluster only, if available.
3777 mlir::Value multicastMask = thisOp.getMulticastMask();
3778 const bool hasMulticastMask = static_cast<bool>(multicastMask);
3779 const bool isSharedCTA = isPtrInSharedCTASpace(thisOp.getDstMem());
3780 if (!isSharedCTA) {
3781 llvm::Value *i16Unused = llvm::ConstantInt::get(builder.getInt16Ty(), 0);
3782 args.push_back(hasMulticastMask ? mt.lookupValue(multicastMask)
3783 : i16Unused);
3784 }
3785
3786 // Cache hint, if available.
3787 mlir::Value cacheHint = thisOp.getL2CacheHint();
3788 const bool hasCacheHint = static_cast<bool>(cacheHint);
3789 llvm::Value *i64Unused = llvm::ConstantInt::get(builder.getInt64Ty(), 0);
3790 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
3791
3792 // Flag arguments for multicast and cachehint.
3793 if (!isSharedCTA)
3794 args.push_back(builder.getInt1(hasMulticastMask));
3795 args.push_back(builder.getInt1(hasCacheHint));
3796
3797 llvm::Intrinsic::ID id =
3798 isSharedCTA
3799 ? llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cta
3800 : llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster;
3801
3802 return {id, std::move(args)};
3803}
3804
3805mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
3806 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3807 auto thisOp = cast<NVVM::CpAsyncBulkSharedCTAToGlobalOp>(op);
3809 llvm::Intrinsic::ID id =
3810 llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global;
3811
3812 // Fill the Intrinsic Args
3813 args.push_back(mt.lookupValue(thisOp.getDstMem()));
3814 args.push_back(mt.lookupValue(thisOp.getSrcMem()));
3815 args.push_back(mt.lookupValue(thisOp.getSize()));
3816
3817 mlir::Value cacheHint = thisOp.getL2CacheHint();
3818 const bool hasCacheHint = static_cast<bool>(cacheHint);
3819 llvm::Value *i64Unused =
3820 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
3821 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
3822 args.push_back(builder.getInt1(hasCacheHint));
3823
3824 // Choose the bytemask variant
3825 if (mlir::Value byteMask = thisOp.getByteMask()) {
3826 args.push_back(mt.lookupValue(byteMask));
3827 id = llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global_bytemask;
3828 }
3829
3830 return {id, std::move(args)};
3831}
3832
3833bool CpAsyncBulkTensorGlobalToSharedClusterOp::getAsmValues(
3834 RewriterBase &rewriter,
3835 llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
3836 &asmValues) {
3837 // Add all the operands but not the attrs to the asmValues list.
3838 // The attrs here are used to generate the right variants for
3839 // intrinsics-lowering. So, we ignore them while generating inline-PTX.
3840 for (auto val : getOperands())
3841 asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read});
3842
3843 return false;
3844}
3845
3847CpAsyncBulkTensorGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
3848 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3849 auto thisOp = cast<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(op);
3850 const bool isCTAOnly = thisOp.getIsCTAOnly();
3852
3853 // Fill the Intrinsic Args
3854 args.push_back(mt.lookupValue(thisOp.getDstMem()));
3855 args.push_back(mt.lookupValue(thisOp.getMbar()));
3856 args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
3857
3858 // Coordinates and im2col-offsets
3859 for (mlir::Value v : thisOp.getCoordinates())
3860 args.push_back(mt.lookupValue(v));
3861 for (mlir::Value v : thisOp.getIm2colOffsets())
3862 args.push_back(mt.lookupValue(v));
3863
3864 // MulticastMask, if available
3865 mlir::Value mcMask = thisOp.getMulticastMask();
3866 const bool hasMC = static_cast<bool>(mcMask);
3867 llvm::Value *i16Zero =
3868 llvm::ConstantInt::get(llvm::Type::getInt16Ty(mt.getLLVMContext()), 0);
3869
3870 // CacheHint, if available
3871 mlir::Value cacheHint = thisOp.getL2CacheHint();
3872 const bool hasCacheHint = static_cast<bool>(cacheHint);
3873 llvm::Value *i64Zero =
3874 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
3875
3876 // Flag argument CTAGroup
3877 // CTA_1/2 is mapped to values 1 and 2 for the intrinsics.
3878 // Hence, the +1 to getGroup().
3879 const int32_t val =
3880 thisOp.getGroup() ? (static_cast<int32_t>(*thisOp.getGroup()) + 1) : 0;
3881 llvm::Value *cg =
3882 llvm::ConstantInt::get(llvm::Type::getInt32Ty(mt.getLLVMContext()), val);
3883
3884 if (!isCTAOnly) {
3885 // For shared::cluster, all the arguments that we build are applicable.
3886 args.push_back(hasMC ? mt.lookupValue(mcMask) : i16Zero);
3887 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Zero);
3888 args.push_back(builder.getInt1(hasMC));
3889 args.push_back(builder.getInt1(hasCacheHint));
3890 args.push_back(cg);
3891 } else {
3892 // For shared::cta, only cache-hint is applicable.
3893 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Zero);
3894 args.push_back(builder.getInt1(hasCacheHint));
3895 }
3896
3897 constexpr size_t numDims = 5; // 1D to 5D
3898 constexpr size_t numModes = 5; // Tile, Im2col, w, w_128, gather4
3899 using rowTy = std::array<llvm::Intrinsic::ID, numDims + 1>;
3900 using TableTy = std::array<rowTy, numModes>;
3901 static constexpr TableTy IDTable{
3902 {{notIntrinsic, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_1d,
3903 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_2d,
3904 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_3d,
3905 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_4d,
3906 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_5d},
3908 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d,
3909 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d,
3910 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d},
3912 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_3d,
3913 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_4d,
3914 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_5d},
3916 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_3d,
3917 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_4d,
3918 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_5d},
3920 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_gather4_2d}}};
3921
3922 static constexpr TableTy IDTableCTA{
3923 {{notIntrinsic,
3924 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_1d,
3925 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_2d,
3926 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_3d,
3927 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_4d,
3928 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_5d},
3930 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_3d,
3931 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_4d,
3932 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_5d},
3934 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_3d,
3935 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_4d,
3936 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_5d},
3938 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_3d,
3939 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_4d,
3940 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_5d},
3942 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_gather4_2d}}};
3943
3944 static_assert(
3945 (getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1) &&
3946 (getMaxEnumValForTMALoadMode() == std::size(IDTableCTA) - 1),
3947 "TMALoadModes must match number of rows in IDTable and IDTableCTA");
3948 size_t mode = static_cast<size_t>(thisOp.getMode());
3949 size_t dim = thisOp.getCoordinates().size();
3950 auto id = isCTAOnly ? IDTableCTA[mode][dim] : IDTable[mode][dim];
3951 assert(id != notIntrinsic &&
3952 "Invalid intrinsic for CpAsyncBulkTensorGlobalToSharedClusterOp.");
3953
3954 return {id, std::move(args)};
3955}
3956
3957mlir::NVVM::IDArgPair CpAsyncBulkTensorPrefetchOp::getIntrinsicIDAndArgs(
3958 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
3959 auto thisOp = cast<NVVM::CpAsyncBulkTensorPrefetchOp>(op);
3961
3962 // Fill the Intrinsic Args
3963 args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
3964
3965 for (auto v : thisOp.getCoordinates())
3966 args.push_back(mt.lookupValue(v));
3967 for (auto v : thisOp.getIm2colOffsets())
3968 args.push_back(mt.lookupValue(v));
3969
3970 mlir::Value cacheHint = thisOp.getL2CacheHint();
3971 const bool hasCacheHint = static_cast<bool>(cacheHint);
3972 llvm::Value *i64Unused =
3973 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
3974 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
3975 args.push_back(builder.getInt1(hasCacheHint));
3976
3977 const unsigned NI = llvm::Intrinsic::not_intrinsic;
3978 static constexpr llvm::Intrinsic::ID IDTable[][6] = {
3979 {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d,
3980 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d,
3981 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d,
3982 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d,
3983 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d},
3984 {NI, NI, NI,
3985 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d,
3986 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d,
3987 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d},
3988 {NI, NI, NI,
3989 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_3d,
3990 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_4d,
3991 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_5d},
3992 {NI, NI, NI,
3993 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_3d,
3994 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_4d,
3995 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_5d},
3996 {NI, NI, NI, NI, NI,
3997 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_gather4_2d}};
3998
3999 static_assert(getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1,
4000 "TMALoadModes must match number of rows in IDTable");
4001 size_t mode = static_cast<size_t>(thisOp.getMode());
4002 size_t dim = thisOp.getCoordinates().size();
4003 llvm::Intrinsic::ID id = IDTable[mode][dim];
4004 if (id == llvm::Intrinsic::not_intrinsic)
4005 llvm_unreachable("Invalid intrinsic for CpAsyncBulkTensorPrefetchOp.");
4006
4007 return {id, std::move(args)};
4008}
4009
4011CpAsyncBulkTensorSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
4012 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4013 auto thisOp = cast<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(op);
4015
4016 // Fill the Intrinsic Args
4017 args.push_back(mt.lookupValue(thisOp.getSrcMem()));
4018 args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
4019
4020 for (auto v : thisOp.getCoordinates())
4021 args.push_back(mt.lookupValue(v));
4022
4023 mlir::Value cacheHint = thisOp.getL2CacheHint();
4024 const bool hasCacheHint = static_cast<bool>(cacheHint);
4025 llvm::Value *i64Unused =
4026 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
4027 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
4028 args.push_back(builder.getInt1(hasCacheHint));
4029
4030 const unsigned NI = llvm::Intrinsic::not_intrinsic;
4031 static constexpr llvm::Intrinsic::ID IDTable[][6] = {
4032 {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_1d,
4033 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_2d,
4034 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_3d,
4035 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_4d,
4036 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_5d},
4037 {NI, NI, NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_3d,
4038 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_4d,
4039 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_5d},
4040 {NI, NI, NI, NI, NI,
4041 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_scatter4_2d}};
4042
4043 static_assert(getMaxEnumValForTMAStoreMode() == std::size(IDTable) - 1,
4044 "TMAStoreModes must match number of rows in IDTable");
4045 size_t mode = static_cast<size_t>(thisOp.getMode());
4046 size_t dim = thisOp.getCoordinates().size();
4047 llvm::Intrinsic::ID id = IDTable[mode][dim];
4048 if (id == llvm::Intrinsic::not_intrinsic)
4049 llvm_unreachable(
4050 "Invalid intrinsic for CpAsyncBulkTensorSharedCTAToGlobalOp.");
4051
4052 return {id, std::move(args)};
4053}
4054
4055NVVM::IDArgPair CpAsyncBulkTensorReduceOp::getIntrinsicIDAndArgs(
4056 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4057 auto thisOp = cast<NVVM::CpAsyncBulkTensorReduceOp>(op);
4058 llvm::LLVMContext &ctx = mt.getLLVMContext();
4059
4061
4062 // Arguments to the intrinsic:
4063 // shared_mem_ptr, tmaDesc, tensorDims
4064 // cache_hint(if applicable) and flag(boolean)
4065 args.push_back(mt.lookupValue(thisOp.getSrcMem()));
4066 args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
4067
4068 for (Value v : thisOp.getCoordinates())
4069 args.push_back(mt.lookupValue(v));
4070
4071 mlir::Value cacheHint = thisOp.getL2CacheHint();
4072 const bool hasCacheHint = static_cast<bool>(cacheHint);
4073 llvm::Value *i64ZeroValue =
4074 llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), 0);
4075 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64ZeroValue);
4076 args.push_back(builder.getInt1(hasCacheHint));
4077
4078 const llvm::Intrinsic::ID notIntrinsic = llvm::Intrinsic::not_intrinsic;
4079
4080 constexpr unsigned numRedKinds = 8; // ADD, MIN, MAX, INC, DEC, AND, OR, XOR
4081 constexpr unsigned numLayouts = 2; // TILE, IM2COL
4082 constexpr unsigned maxDim = 5; // 1D to 5D
4083 using row = std::array<llvm::Intrinsic::ID, maxDim + 1>;
4084 using layoutTable = std::array<row, numLayouts>;
4085 using fullTable = std::array<layoutTable, numRedKinds>;
4086 static constexpr fullTable IDTable{
4087 {// RedTy::ADD
4088 {{{{notIntrinsic,
4089 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_1d,
4090 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_2d,
4091 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_3d,
4092 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_4d,
4093 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_5d}},
4095 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_3d,
4096 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_4d,
4097 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_5d}}}},
4098 // RedTy::MIN
4099 {{{{notIntrinsic,
4100 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_1d,
4101 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_2d,
4102 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_3d,
4103 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_4d,
4104 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_5d}},
4106 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_3d,
4107 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_4d,
4108 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_5d}}}},
4109 // RedTy::MAX
4110 {{{{notIntrinsic,
4111 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_1d,
4112 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_2d,
4113 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_3d,
4114 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_4d,
4115 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_5d}},
4117 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_3d,
4118 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_4d,
4119 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_5d}}}},
4120 // RedTy::INC
4121 {{{{notIntrinsic,
4122 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_1d,
4123 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_2d,
4124 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_3d,
4125 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_4d,
4126 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_5d}},
4128 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_3d,
4129 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_4d,
4130 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_5d}}}},
4131 // RedTy::DEC
4132 {{{{notIntrinsic,
4133 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_1d,
4134 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_2d,
4135 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_3d,
4136 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_4d,
4137 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_5d}},
4139 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_3d,
4140 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_4d,
4141 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_5d}}}},
4142 // RedTy::AND
4143 {{{{notIntrinsic,
4144 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_1d,
4145 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_2d,
4146 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_3d,
4147 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_4d,
4148 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_5d}},
4150 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_3d,
4151 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_4d,
4152 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_5d}}}},
4153 // RedTy::OR
4154 {{{{notIntrinsic,
4155 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_1d,
4156 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_2d,
4157 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_3d,
4158 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_4d,
4159 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_5d}},
4161 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_3d,
4162 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_4d,
4163 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_5d}}}},
4164 // RedTy::XOR
4165 {{{{notIntrinsic,
4166 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_1d,
4167 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_2d,
4168 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_3d,
4169 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_4d,
4170 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_5d}},
4172 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_3d,
4173 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_4d,
4174 llvm::Intrinsic::
4175 nvvm_cp_async_bulk_tensor_reduce_xor_im2col_5d}}}}}};
4176
4177 static_assert(getMaxEnumValForTMAReduxKind() == std::size(IDTable) - 1,
4178 "TMAReduxKinds must match number of rows in IDTable");
4179
4180 size_t redKind = static_cast<size_t>(thisOp.getRedKind());
4181 size_t mode = static_cast<size_t>(thisOp.getMode());
4182 size_t dim = thisOp.getCoordinates().size();
4183
4184 assert(redKind < IDTable.size() &&
4185 "Invalid redKind for CpAsyncBulkTensorReduceOp");
4186 assert(mode < IDTable[redKind].size() &&
4187 "Invalid mode for CpAsyncBulkTensorReduceOp");
4188 assert(dim < IDTable[redKind][mode].size() &&
4189 "Invalid dim for CpAsyncBulkTensorReduceOp");
4190
4191 llvm::Intrinsic::ID intrinsicID = IDTable[redKind][mode][dim];
4192
4193 assert(intrinsicID != notIntrinsic &&
4194 "Invalid intrinsic for CpAsyncBulkTensorReduceOp.");
4195
4196 return {intrinsicID, std::move(args)};
4197}
4198
// Empty token: pasted where an intrinsic family has no relu-flavoured name
// component (the rna rounding mode below).
#define _none

// Picks the relu / non-relu f2tf32 intrinsic for rounding mode `rnd` and
// suffix `sf`. Expands against a local `hasRelu` bool at the use site.
#define CVT_F2TF32_ID_IMPL(rnd, relu, sf)                                      \
  hasRelu ? llvm::Intrinsic::nvvm_f2tf32_##rnd##relu##sf                       \
          : llvm::Intrinsic::nvvm_f2tf32_##rnd##sf

// Picks the satfinite / plain variant. Expands against a local
// `hasSatFinite` bool at the use site.
#define GET_CVT_F2TF32_ID(rnd, relu, sf)                                       \
  hasSatFinite ? CVT_F2TF32_ID_IMPL(rnd, relu, sf)                             \
               : CVT_F2TF32_ID_IMPL(rnd, relu, )

/// Returns the f32 -> tf32 conversion intrinsic for the given rounding mode,
/// saturation mode and relu flag. Only RN, RZ and RNA rounding are handled;
/// RNA has no relu variant (hence the `_none` token).
llvm::Intrinsic::ID
ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
                                     NVVM::SaturationMode sat, bool hasRelu) {
  using RndMode = NVVM::FPRoundingMode;
  bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
  switch (rnd) {
  case RndMode::RN:
    return GET_CVT_F2TF32_ID(rn, _relu, _satfinite);
  case RndMode::RZ:
    return GET_CVT_F2TF32_ID(rz, _relu, _satfinite);
  case RndMode::RNA:
    return GET_CVT_F2TF32_ID(rna, _none, _satfinite);
  default:
    llvm_unreachable("Invalid RoundingMode for CvtFloatToTF32Op");
  }
}
4225
4227ConvertF32x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF4x2Op op,
4229 llvm::IRBuilderBase &builder) {
4231 args.push_back(mt.lookupValue(op.getA()));
4232 args.push_back(mt.lookupValue(op.getB()));
4233
4234 bool hasRelu = op.getRelu();
4235
4236 llvm::Intrinsic::ID intId =
4237 hasRelu ? llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_relu_satfinite
4238 : llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_satfinite;
4239
4240 return {intId, std::move(args)};
4241}
4242
4243#define GET_F32x2_TO_F6x2_ID(type, has_relu) \
4244 has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite \
4245 : llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite
4246
4247llvm::Intrinsic::ID ConvertF32x2ToF6x2Op::getIntrinsicID(mlir::Type dstTy,
4248 bool hasRelu) {
4250 .Case([&](mlir::Float6E2M3FNType) {
4251 return GET_F32x2_TO_F6x2_ID(e2m3x2, hasRelu);
4252 })
4253 .Case([&](mlir::Float6E3M2FNType) {
4254 return GET_F32x2_TO_F6x2_ID(e3m2x2, hasRelu);
4255 })
4256 .Default([](mlir::Type) {
4257 llvm_unreachable("Invalid conversion in ConvertF32x2ToF6x2Op");
4258 return llvm::Intrinsic::not_intrinsic;
4259 });
4260}
4261
4263ConvertF16x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF16x2ToF4x2Op &op,
4265 llvm::IRBuilderBase &builder) {
4266 mlir::Type dstTy = op.getDstTy();
4267 bool hasRelu = op.getRelu();
4268
4269 llvm::Intrinsic::ID intId = llvm::Intrinsic::not_intrinsic;
4270
4271 if (llvm::isa<mlir::Float4E2M1FNType>(dstTy))
4272 intId = hasRelu ? llvm::Intrinsic::nvvm_f16x2_to_e2m1x2_rn_relu_satfinite
4273 : llvm::Intrinsic::nvvm_f16x2_to_e2m1x2_rn_satfinite;
4274
4276 args.push_back(mt.lookupValue(op.getSrc()));
4277
4278 return {intId, std::move(args)};
4279}
4280
4282ConvertBF16x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertBF16x2ToF4x2Op &op,
4284 llvm::IRBuilderBase &builder) {
4285 mlir::Type dstTy = op.getDstTy();
4286 bool hasRelu = op.getRelu();
4287
4288 llvm::Intrinsic::ID intId = llvm::Intrinsic::not_intrinsic;
4289
4290 if (llvm::isa<mlir::Float4E2M1FNType>(dstTy))
4291 intId = hasRelu ? llvm::Intrinsic::nvvm_bf16x2_to_e2m1x2_rn_relu_satfinite
4292 : llvm::Intrinsic::nvvm_bf16x2_to_e2m1x2_rn_satfinite;
4293
4295 args.push_back(mt.lookupValue(op.getSrc()));
4296
4297 return {intId, std::move(args)};
4298}
4299
4300llvm::Intrinsic::ID ConvertF16x2ToF6x2Op::getIntrinsicID(mlir::Type dstTy,
4301 bool hasRelu) {
4303 .Case<mlir::Float6E2M3FNType>([&](mlir::Float6E2M3FNType) {
4304 return hasRelu ? llvm::Intrinsic::nvvm_f16x2_to_e2m3x2_rn_relu_satfinite
4305 : llvm::Intrinsic::nvvm_f16x2_to_e2m3x2_rn_satfinite;
4306 })
4307 .Case<mlir::Float6E3M2FNType>([&](mlir::Float6E3M2FNType) {
4308 return hasRelu ? llvm::Intrinsic::nvvm_f16x2_to_e3m2x2_rn_relu_satfinite
4309 : llvm::Intrinsic::nvvm_f16x2_to_e3m2x2_rn_satfinite;
4310 })
4311 .Default([](mlir::Type) {
4312 llvm_unreachable("Invalid conversion in ConvertF16x2ToF6x2Op");
4313 return llvm::Intrinsic::not_intrinsic;
4314 });
4315}
4316
4317llvm::Intrinsic::ID ConvertBF16x2ToF6x2Op::getIntrinsicID(mlir::Type dstTy,
4318 bool hasRelu) {
4320 .Case<mlir::Float6E2M3FNType>([&](mlir::Float6E2M3FNType) {
4321 return hasRelu
4322 ? llvm::Intrinsic::nvvm_bf16x2_to_e2m3x2_rn_relu_satfinite
4323 : llvm::Intrinsic::nvvm_bf16x2_to_e2m3x2_rn_satfinite;
4324 })
4325 .Case<mlir::Float6E3M2FNType>([&](mlir::Float6E3M2FNType) {
4326 return hasRelu
4327 ? llvm::Intrinsic::nvvm_bf16x2_to_e3m2x2_rn_relu_satfinite
4328 : llvm::Intrinsic::nvvm_bf16x2_to_e3m2x2_rn_satfinite;
4329 })
4330 .Default([](mlir::Type) {
4331 llvm_unreachable("Invalid conversion in ConvertBF16x2ToF6x2Op");
4332 return llvm::Intrinsic::not_intrinsic;
4333 });
4334}
4335
4336#define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf) \
4337 has_satf ? llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd##_satfinite \
4338 : llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd
4339
4340#define GET_F32x2_TO_F8X2_S_ID(type, has_relu) \
4341 has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu \
4342 : llvm::Intrinsic::nvvm_ff_to_##type##_rn
4343
4344llvm::Intrinsic::ID
4345ConvertF32x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy, NVVM::FPRoundingMode rnd,
4346 NVVM::SaturationMode sat, bool hasRelu) {
4347 bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
4348 bool hasRoundingModeRZ = (rnd == NVVM::FPRoundingMode::RZ);
4349 bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP);
4350
4352 .Case([&](mlir::Float8E4M3FNType) {
4353 return GET_F32x2_TO_F8X2_S_ID(e4m3x2, hasRelu);
4354 })
4355 .Case([&](mlir::Float8E5M2Type) {
4356 return GET_F32x2_TO_F8X2_S_ID(e5m2x2, hasRelu);
4357 })
4358 .Case([&](mlir::Float8E8M0FNUType) {
4359 if (hasRoundingModeRZ)
4360 return GET_F32x2_TO_F8X2_US_ID(rz, hasSatFinite);
4361 else if (hasRoundingModeRP)
4362 return GET_F32x2_TO_F8X2_US_ID(rp, hasSatFinite);
4363
4364 llvm_unreachable("Invalid conversion in ConvertF32x2ToF8x2Op");
4365 })
4366 .Default([](mlir::Type) {
4367 llvm_unreachable("Invalid conversion in ConvertF32x2ToF8x2Op");
4368 return llvm::Intrinsic::not_intrinsic;
4369 });
4370}
4371
4372#define GET_F16x2_TO_F8X2_ID(type, has_relu) \
4373 has_relu ? llvm::Intrinsic::nvvm_f16x2_to_##type##_rn_relu \
4374 : llvm::Intrinsic::nvvm_f16x2_to_##type##_rn
4375
4376llvm::Intrinsic::ID ConvertF16x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy,
4377 bool hasRelu) {
4379 .Case([&](mlir::Float8E4M3FNType) {
4380 return GET_F16x2_TO_F8X2_ID(e4m3x2, hasRelu);
4381 })
4382 .Case([&](mlir::Float8E5M2Type) {
4383 return GET_F16x2_TO_F8X2_ID(e5m2x2, hasRelu);
4384 })
4385 .Default([](mlir::Type) {
4386 llvm_unreachable("Invalid conversion in ConvertF16x2ToF8x2Op");
4387 return llvm::Intrinsic::not_intrinsic;
4388 });
4389}
4390
4391llvm::Intrinsic::ID
4392ConvertBF16x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy,
4393 NVVM::FPRoundingMode rnd,
4394 NVVM::SaturationMode sat, bool hasRelu) {
4395 bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
4396
4397 static constexpr llvm::Intrinsic::ID ue8m0x2IDs[] = {
4398 llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_rz,
4399 llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_rp,
4400 llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_rz_satfinite,
4401 llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_rp_satfinite,
4402 };
4403
4405 .Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
4406 return hasRelu
4407 ? llvm::Intrinsic::nvvm_bf16x2_to_e4m3x2_rn_relu_satfinite
4408 : llvm::Intrinsic::nvvm_bf16x2_to_e4m3x2_rn_satfinite;
4409 })
4410 .Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
4411 return hasRelu
4412 ? llvm::Intrinsic::nvvm_bf16x2_to_e5m2x2_rn_relu_satfinite
4413 : llvm::Intrinsic::nvvm_bf16x2_to_e5m2x2_rn_satfinite;
4414 })
4415 .Case<mlir::Float8E8M0FNUType>([&](mlir::Float8E8M0FNUType) {
4416 bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP);
4417 unsigned index = (hasSatFinite << 1) | hasRoundingModeRP;
4418 return ue8m0x2IDs[index];
4419 })
4420 .Default([](mlir::Type) {
4421 llvm_unreachable("Invalid conversion in ConvertBF16x2ToF8x2Op");
4422 return llvm::Intrinsic::not_intrinsic;
4423 });
4424}
4425
4426NVVM::IDArgPair ConvertF8x2ToF16x2Op::getIntrinsicIDAndArgs(
4427 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4428 auto curOp = cast<NVVM::ConvertF8x2ToF16x2Op>(op);
4429
4430 bool hasRelu = curOp.getRelu();
4431
4432 llvm::Intrinsic::ID intId =
4434 .Case([&](Float8E4M3FNType type) {
4435 return hasRelu ? llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn_relu
4436 : llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn;
4437 })
4438 .Case([&](Float8E5M2Type type) {
4439 return hasRelu ? llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn_relu
4440 : llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn;
4441 })
4442 .Default([](mlir::Type type) {
4443 llvm_unreachable("Invalid type for ConvertF8x2ToF16x2Op");
4444 return llvm::Intrinsic::not_intrinsic;
4445 });
4446
4447 llvm::Value *packedI16 =
4448 builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
4449 llvm::Type::getInt16Ty(builder.getContext()));
4450
4451 return {intId, {packedI16}};
4452}
4453
4454NVVM::IDArgPair ConvertF8x2ToBF16x2Op::getIntrinsicIDAndArgs(
4455 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4456 auto curOp = cast<NVVM::ConvertF8x2ToBF16x2Op>(op);
4457
4458 llvm::Intrinsic::ID intId = llvm::Intrinsic::nvvm_ue8m0x2_to_bf16x2;
4459 llvm::Value *packedI16 =
4460 builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
4461 llvm::Type::getInt16Ty(builder.getContext()));
4462
4463 return {intId, {packedI16}};
4464}
4465
4466NVVM::IDArgPair ConvertF6x2ToF16x2Op::getIntrinsicIDAndArgs(
4467 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4468 auto curOp = cast<NVVM::ConvertF6x2ToF16x2Op>(op);
4469
4470 bool hasRelu = curOp.getRelu();
4471
4472 llvm::Intrinsic::ID intId =
4474 .Case([&](Float6E2M3FNType type) {
4475 return hasRelu ? llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn_relu
4476 : llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn;
4477 })
4478 .Case([&](Float6E3M2FNType type) {
4479 return hasRelu ? llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn_relu
4480 : llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn;
4481 })
4482 .Default([](mlir::Type type) {
4483 llvm_unreachable("Invalid type for ConvertF6x2ToF16x2Op");
4484 return llvm::Intrinsic::not_intrinsic;
4485 });
4486
4487 llvm::Value *packedI16 =
4488 builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
4489 llvm::Type::getInt16Ty(builder.getContext()));
4490
4491 return {intId, {packedI16}};
4492}
4493
4494NVVM::IDArgPair ConvertF4x2ToF16x2Op::getIntrinsicIDAndArgs(
4495 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4496 auto curOp = cast<NVVM::ConvertF4x2ToF16x2Op>(op);
4497
4498 bool hasRelu = curOp.getRelu();
4499
4500 llvm::Intrinsic::ID intId =
4502 .Case([&](Float4E2M1FNType type) {
4503 return hasRelu ? llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn_relu
4504 : llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn;
4505 })
4506 .Default([](mlir::Type type) {
4507 llvm_unreachable("Invalid type for ConvertF4x2ToF16x2Op");
4508 return llvm::Intrinsic::not_intrinsic;
4509 });
4510
4511 llvm::Value *extendedI16 =
4512 builder.CreateZExt(mt.lookupValue(curOp.getSrc()),
4513 llvm::Type::getInt16Ty(builder.getContext()));
4514
4515 return {intId, {extendedI16}};
4516}
4517
4518NVVM::IDArgPair ConvertF32x2ToS2F6x2Op::getIntrinsicIDAndArgs(
4519 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4520 auto thisOp = cast<NVVM::ConvertF32x2ToS2F6x2Op>(op);
4521 bool hasRelu = thisOp.getRelu();
4522 bool hasScale = static_cast<bool>(thisOp.getScaleFactor());
4523
4524 llvm::Intrinsic::ID id =
4525 hasRelu
4526 ? llvm::Intrinsic::nvvm_ff_to_s2f6x2_rn_relu_satfinite_scale_n2_ue8m0
4527 : llvm::Intrinsic::nvvm_ff_to_s2f6x2_rn_satfinite_scale_n2_ue8m0;
4528
4529 // Fill the Intrinsic Args
4531 args.push_back(mt.lookupValue(thisOp.getA()));
4532 args.push_back(mt.lookupValue(thisOp.getB()));
4533 args.push_back(hasScale ? mt.lookupValue(thisOp.getScaleFactor())
4534 : builder.getInt16(0x7f7f));
4535 return {id, std::move(args)};
4536}
4537
4538NVVM::IDArgPair ConvertBF16x2ToS2F6x2Op::getIntrinsicIDAndArgs(
4539 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4540 auto thisOp = cast<NVVM::ConvertBF16x2ToS2F6x2Op>(op);
4541 bool hasRelu = thisOp.getRelu();
4542 bool hasScale = static_cast<bool>(thisOp.getScaleFactor());
4543
4544 llvm::Intrinsic::ID id =
4545 hasRelu
4546 ? llvm::Intrinsic::
4547 nvvm_bf16x2_to_s2f6x2_rn_relu_satfinite_scale_n2_ue8m0
4548 : llvm::Intrinsic::nvvm_bf16x2_to_s2f6x2_rn_satfinite_scale_n2_ue8m0;
4549
4550 // Fill the Intrinsic Args
4552 args.push_back(mt.lookupValue(thisOp.getSrc()));
4553 args.push_back(hasScale ? mt.lookupValue(thisOp.getScaleFactor())
4554 : builder.getInt16(0x7f7f));
4555 return {id, std::move(args)};
4556}
4557
4558NVVM::IDArgPair ConvertS2F6x2ToBF16x2Op::getIntrinsicIDAndArgs(
4559 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4560 auto thisOp = cast<NVVM::ConvertS2F6x2ToBF16x2Op>(op);
4561 bool hasRelu = thisOp.getRelu();
4562 bool hasScale = static_cast<bool>(thisOp.getScaleFactor());
4563 bool hasSat = thisOp.getSat() == NVVM::SaturationMode::SATFINITE;
4564
4565 static constexpr llvm::Intrinsic::ID ids[] = {
4566 llvm::Intrinsic::nvvm_s2f6x2_to_bf16x2_rn_scale_n2_ue8m0,
4567 llvm::Intrinsic::nvvm_s2f6x2_to_bf16x2_rn_relu_scale_n2_ue8m0,
4568 llvm::Intrinsic::nvvm_s2f6x2_to_bf16x2_rn_satfinite_scale_n2_ue8m0,
4569 llvm::Intrinsic::nvvm_s2f6x2_to_bf16x2_rn_relu_satfinite_scale_n2_ue8m0,
4570 };
4571
4572 unsigned idx = (hasSat << 1) | hasRelu;
4573
4574 // Fill the Intrinsic Args
4576 llvm::Value *packedI16 =
4577 builder.CreateBitCast(mt.lookupValue(thisOp.getSrc()),
4578 llvm::Type::getInt16Ty(builder.getContext()));
4579 args.push_back(packedI16);
4580 args.push_back(hasScale ? mt.lookupValue(thisOp.getScaleFactor())
4581 : builder.getInt16(0x7f7f));
4582
4583 return {ids[idx], std::move(args)};
4584}
4585
4586llvm::Intrinsic::ID
4587Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op,
4590 auto curOp = cast<NVVM::Tcgen05AllocOp>(op);
4591 unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
4592 .getAddressSpace();
4593 bool isShared = as == NVVMMemorySpace::Shared;
4594 bool is2CTAMode = curOp.getGroup() == CTAGroupKind::CTA_2;
4595
4596 llvm::Intrinsic::ID id;
4597 if (isShared) {
4598 id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg2
4599 : llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg1;
4600 } else {
4601 id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_cg2
4602 : llvm::Intrinsic::nvvm_tcgen05_alloc_cg1;
4603 }
4604
4605 // Fill the Intrinsic Args
4606 args.push_back(mt.lookupValue(curOp.getAddr()));
4607 args.push_back(mt.lookupValue(curOp.getNCols()));
4608
4609 return id;
4610}
4611
4612llvm::Intrinsic::ID Tcgen05DeallocOp::getIntrinsicIDAndArgs(
4615 auto curOp = cast<NVVM::Tcgen05DeallocOp>(op);
4616 auto id = (curOp.getGroup() == CTAGroupKind::CTA_1)
4617 ? llvm::Intrinsic::nvvm_tcgen05_dealloc_cg1
4618 : llvm::Intrinsic::nvvm_tcgen05_dealloc_cg2;
4619
4620 // Fill the Intrinsic Args
4621 args.push_back(mt.lookupValue(curOp.getTaddr()));
4622 args.push_back(mt.lookupValue(curOp.getNCols()));
4623
4624 return id;
4625}
4626
4627#define TCGEN05_COMMIT_IMPL(cg, is_shared, mc) \
4628 is_shared ? llvm::Intrinsic::nvvm_tcgen05_commit##mc##_shared##_##cg \
4629 : llvm::Intrinsic::nvvm_tcgen05_commit##mc##_##cg
4630
4631#define GET_TCGEN05_COMMIT_ID(cta_group, is_shared, has_mc) \
4632 has_mc ? TCGEN05_COMMIT_IMPL(cta_group, is_shared, _mc) \
4633 : TCGEN05_COMMIT_IMPL(cta_group, is_shared, )
4634
4635llvm::Intrinsic::ID
4636Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
4639 auto curOp = cast<NVVM::Tcgen05CommitOp>(op);
4640 unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
4641 .getAddressSpace();
4642 bool isShared = as == NVVMMemorySpace::Shared;
4643 bool hasMulticast = static_cast<bool>(curOp.getMulticastMask());
4644 bool is2CTAMode = curOp.getGroup() == CTAGroupKind::CTA_2;
4645
4646 llvm::Intrinsic::ID id =
4647 is2CTAMode ? GET_TCGEN05_COMMIT_ID(cg2, isShared, hasMulticast)
4648 : GET_TCGEN05_COMMIT_ID(cg1, isShared, hasMulticast);
4649
4650 // Fill the Intrinsic Args
4651 args.push_back(mt.lookupValue(curOp.getAddr()));
4652 if (hasMulticast)
4653 args.push_back(mt.lookupValue(curOp.getMulticastMask()));
4654
4655 return id;
4656}
4657
// Pastes shape/multicast, source format and CTA-group tokens into a
// tcgen05.cp intrinsic name.
#define TCGEN05_CP_IMPL(shape_mc, src_fmt, cg)                                 \
  llvm::Intrinsic::nvvm_tcgen05_cp##shape_mc##src_fmt##cg

// Chooses the cg2 or cg1 variant based on `is_2cta`.
#define TCGEN05_CP_2CTA(shape_mc, src_fmt, is_2cta)                            \
  is_2cta ? TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg2)                           \
          : TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg1)

// Dispatches on the runtime `src_fmt` value to pick the decompression
// variant (b6x16_p32, b4x16_p64, or none). Wrapped in an immediately-invoked
// lambda so it can be used as a single expression in return statements.
#define GET_TCGEN05_CP_ID(shape_mc, src_fmt, is_2cta)                          \
  [&]() -> auto {                                                              \
    if ((src_fmt) == Tcgen05CpSrcFormat::B6x16_P32)                            \
      return TCGEN05_CP_2CTA(shape_mc, _b6x16_p32, is_2cta);                   \
    if ((src_fmt) == Tcgen05CpSrcFormat::B4x16_P64)                            \
      return TCGEN05_CP_2CTA(shape_mc, _b4x16_p64, is_2cta);                   \
    return TCGEN05_CP_2CTA(shape_mc, , is_2cta);                               \
  }()
4673
4675ConvertF32x2ToF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF16x2Op &op,
4677 llvm::IRBuilderBase &builder) {
4678 static constexpr llvm::Intrinsic::ID rndRNIds[] = {
4679 llvm::Intrinsic::nvvm_ff2f16x2_rn,
4680 llvm::Intrinsic::nvvm_ff2f16x2_rn_relu,
4681 llvm::Intrinsic::nvvm_ff2f16x2_rn_satfinite,
4682 llvm::Intrinsic::nvvm_ff2f16x2_rn_relu_satfinite,
4683 };
4684 static constexpr llvm::Intrinsic::ID rndRZIds[] = {
4685 llvm::Intrinsic::nvvm_ff2f16x2_rz,
4686 llvm::Intrinsic::nvvm_ff2f16x2_rz_relu,
4687 llvm::Intrinsic::nvvm_ff2f16x2_rz_satfinite,
4688 llvm::Intrinsic::nvvm_ff2f16x2_rz_relu_satfinite,
4689 };
4690 static constexpr llvm::Intrinsic::ID rndRSIds[] = {
4691 llvm::Intrinsic::nvvm_ff2f16x2_rs,
4692 llvm::Intrinsic::nvvm_ff2f16x2_rs_relu,
4693 llvm::Intrinsic::nvvm_ff2f16x2_rs_satfinite,
4694 llvm::Intrinsic::nvvm_ff2f16x2_rs_relu_satfinite,
4695 };
4696
4697 unsigned hasRelu = op.getRelu() ? 1 : 0;
4698 unsigned hasSatFinite =
4699 (op.getSat() == NVVM::SaturationMode::SATFINITE) ? 1 : 0;
4700 // idx: bit-0 - relu
4701 // bit-1 - satfinite
4702 unsigned idx = (hasSatFinite << 1) | hasRelu;
4703
4705 args.push_back(mt.lookupValue(op.getSrcHi()));
4706 args.push_back(mt.lookupValue(op.getSrcLo()));
4707 if (op.getRandomBits())
4708 args.push_back(mt.lookupValue(op.getRandomBits()));
4709
4710 switch (op.getRnd()) {
4711 case FPRoundingMode::RN:
4712 return {rndRNIds[idx], std::move(args)};
4713 case FPRoundingMode::RZ:
4714 return {rndRZIds[idx], std::move(args)};
4715 case FPRoundingMode::RS:
4716 return {rndRSIds[idx], std::move(args)};
4717 default:
4718 llvm_unreachable("Invalid rounding mode for ConvertF32x2ToF16x2Op");
4719 }
4720}
4721
4723ConvertF32x2ToBF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToBF16x2Op &op,
4725 llvm::IRBuilderBase &builder) {
4726 static constexpr llvm::Intrinsic::ID rndRNIds[] = {
4727 llvm::Intrinsic::nvvm_ff2bf16x2_rn,
4728 llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu,
4729 llvm::Intrinsic::nvvm_ff2bf16x2_rn_satfinite,
4730 llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu_satfinite,
4731 };
4732 static constexpr llvm::Intrinsic::ID rndRZIds[] = {
4733 llvm::Intrinsic::nvvm_ff2bf16x2_rz,
4734 llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu,
4735 llvm::Intrinsic::nvvm_ff2bf16x2_rz_satfinite,
4736 llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu_satfinite,
4737 };
4738 static constexpr llvm::Intrinsic::ID rndRSIds[] = {
4739 llvm::Intrinsic::nvvm_ff2bf16x2_rs,
4740 llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu,
4741 llvm::Intrinsic::nvvm_ff2bf16x2_rs_satfinite,
4742 llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu_satfinite,
4743 };
4744
4745 unsigned hasRelu = op.getRelu() ? 1 : 0;
4746 unsigned hasSatFinite =
4747 (op.getSat() == NVVM::SaturationMode::SATFINITE) ? 1 : 0;
4748 // idx: bit-0 - relu
4749 // bit-1 - satfinite
4750 unsigned idx = (hasSatFinite << 1) | hasRelu;
4751
4753 args.push_back(mt.lookupValue(op.getSrcHi()));
4754 args.push_back(mt.lookupValue(op.getSrcLo()));
4755 if (op.getRandomBits())
4756 args.push_back(mt.lookupValue(op.getRandomBits()));
4757
4758 switch (op.getRnd()) {
4759 case FPRoundingMode::RN:
4760 return {rndRNIds[idx], std::move(args)};
4761 case FPRoundingMode::RZ:
4762 return {rndRZIds[idx], std::move(args)};
4763 case FPRoundingMode::RS:
4764 return {rndRSIds[idx], std::move(args)};
4765 default:
4766 llvm_unreachable("Invalid rounding mode for ConvertF32x2ToBF16x2Op");
4767 }
4768}
4769
4770llvm::Intrinsic::ID ConvertF32x4ToF8x4Op::getIntrinsicID() {
4771 mlir::Type dstTy = getDstTy();
4772 bool hasRelu = getRelu();
4773
4775 .Case([&](mlir::Float8E4M3FNType) {
4776 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite
4777 : llvm::Intrinsic::nvvm_f32x4_to_e4m3x4_rs_satfinite;
4778 })
4779 .Case([&](mlir::Float8E5M2Type) {
4780 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e5m2x4_rs_relu_satfinite
4781 : llvm::Intrinsic::nvvm_f32x4_to_e5m2x4_rs_satfinite;
4782 })
4783 .Default([](mlir::Type) {
4784 llvm_unreachable("Invalid F8 type in ConvertF32x4ToF8x4Op");
4785 return llvm::Intrinsic::not_intrinsic;
4786 });
4787}
4788
4789llvm::Intrinsic::ID ConvertF32x4ToF6x4Op::getIntrinsicID() {
4790 mlir::Type dstTy = getDstTy();
4791 bool hasRelu = getRelu();
4792
4794 .Case([&](mlir::Float6E2M3FNType) {
4795 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e2m3x4_rs_relu_satfinite
4796 : llvm::Intrinsic::nvvm_f32x4_to_e2m3x4_rs_satfinite;
4797 })
4798 .Case([&](mlir::Float6E3M2FNType) {
4799 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e3m2x4_rs_relu_satfinite
4800 : llvm::Intrinsic::nvvm_f32x4_to_e3m2x4_rs_satfinite;
4801 })
4802 .Default([](mlir::Type) {
4803 llvm_unreachable("Invalid F6 type in ConvertF32x4ToF6x4Op");
4804 return llvm::Intrinsic::not_intrinsic;
4805 });
4806}
4807
4808llvm::Intrinsic::ID ConvertF32x4ToF4x4Op::getIntrinsicID() {
4809 mlir::Type dstTy = getDstTy();
4810 bool hasRelu = getRelu();
4811
4813 .Case([&](mlir::Float4E2M1FNType) {
4814 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e2m1x4_rs_relu_satfinite
4815 : llvm::Intrinsic::nvvm_f32x4_to_e2m1x4_rs_satfinite;
4816 })
4817 .Default([](mlir::Type) {
4818 llvm_unreachable("Invalid F4 type in ConvertF32x4ToF4x4Op");
4819 return llvm::Intrinsic::not_intrinsic;
4820 });
4821}
4822
/// Returns the tcgen05.cp intrinsic matching the op's shape, multicast mode,
/// source format and CTA group.
llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(Operation &op) {
  auto curOp = cast<NVVM::Tcgen05CpOp>(op);
  bool is2CTA = curOp.getGroup() == CTAGroupKind::CTA_2;
  auto srcFmt = curOp.getSrcFormat();
  auto mc = curOp.getMulticast();

  switch (curOp.getShape()) {
  case Tcgen05CpShape::SHAPE_128x256b:
    return GET_TCGEN05_CP_ID(_128x256b, srcFmt, is2CTA);
  case Tcgen05CpShape::SHAPE_128x128b:
    return GET_TCGEN05_CP_ID(_128x128b, srcFmt, is2CTA);
  case Tcgen05CpShape::SHAPE_4x256b:
    return GET_TCGEN05_CP_ID(_4x256b, srcFmt, is2CTA);
  case Tcgen05CpShape::SHAPE_32x128b:
    // 32x128b only exists in the warpx4 multicast form.
    return GET_TCGEN05_CP_ID(_32x128b_warpx4, srcFmt, is2CTA);
  case Tcgen05CpShape::SHAPE_64x128b:
    // 64x128b supports two warpx2 lane pairings, chosen by the multicast
    // attribute.
    return (mc == Tcgen05CpMulticast::WARPX2_01_23)
               ? GET_TCGEN05_CP_ID(_64x128b_warpx2_01_23, srcFmt, is2CTA)
               : GET_TCGEN05_CP_ID(_64x128b_warpx2_02_13, srcFmt, is2CTA);
  }
  llvm_unreachable("Invalid shape in tcgen05 cp Op");
}
4845
4846// Returns the valid vector length for a given shape and vector length, the
4847// function models the table mentioned in the tcgen05.{ld, st} Op description
4848static unsigned isValidVectorLength(NVVM::Tcgen05LdStShape shape,
4849 unsigned vecLen) {
4850 if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X128B)
4851 return vecLen >= 2;
4852 if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X256B)
4853 return vecLen >= 4;
4854 return true;
4855}
4856
4857LogicalResult Tcgen05LdOp::verify() {
4858 LogicalResult result = success();
4859 if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
4860 result = emitError("shape 16x32bx2 requires offset argument");
4861
4862 if (getShape() != NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && getOffset())
4863 result = emitError("offset argument is only supported for shape 16x32bx2");
4864
4865 auto resTy = getRes().getType();
4866 unsigned resLen = isa<VectorType>(resTy)
4867 ? llvm::cast<VectorType>(resTy).getNumElements()
4868 : 1;
4869 if (!isValidVectorLength(getShape(), resLen))
4870 result = emitError(llvm::formatv("invalid result type length {0} for shape "
4871 "{1} in tcgen05.ld Op",
4872 resLen, stringifyEnum(getShape())));
4873
4874 return result;
4875}
4876
4877LogicalResult Tcgen05StOp::verify() {
4878 LogicalResult result = success();
4879 if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
4880 result = emitError("shape 16x32bx2 requires offset argument");
4881
4882 auto valTy = getVal().getType();
4883 unsigned valLen = isa<VectorType>(valTy)
4884 ? llvm::cast<VectorType>(valTy).getNumElements()
4885 : 1;
4886 if (!isValidVectorLength(getShape(), valLen))
4887 result = emitError(llvm::formatv("invalid input length {0} for shape "
4888 "{1} in tcgen05.st Op",
4889 valLen, stringifyEnum(getShape())));
4890
4891 return result;
4892}
4893
4894/// Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might
4895/// have ConstantRangeAttr.
4898 SetIntRangeFn setResultRanges) {
4899 if (auto rangeAttr = op->getAttrOfType<LLVM::ConstantRangeAttr>("range")) {
4900 setResultRanges(result, {rangeAttr.getLower(), rangeAttr.getUpper(),
4901 rangeAttr.getLower(), rangeAttr.getUpper()});
4902 } else {
4903 setResultRanges(result, IntegerValueRange::getMaxRange(result).getValue());
4904 }
4905}
4906
4907/// Verify the range attribute satisfies LLVM ConstantRange constructor
4908/// requirements for NVVM SpecialRangeableRegisterOp.
4909static LogicalResult
4911 std::optional<LLVM::ConstantRangeAttr> rangeAttr) {
4912 if (!rangeAttr)
4913 return success();
4914
4915 const llvm::APInt &lower = rangeAttr->getLower();
4916 const llvm::APInt &upper = rangeAttr->getUpper();
4917
4918 // Check LLVM ConstantRange constructor condition
4919 if (lower == upper && !lower.isMaxValue() && !lower.isMinValue()) {
4920 unsigned bitWidth = lower.getBitWidth();
4921 llvm::APInt minVal = llvm::APInt::getMinValue(bitWidth);
4922 llvm::APInt maxVal = llvm::APInt::getMaxValue(bitWidth);
4923 return op->emitOpError(
4924 "invalid range attribute: Lower == Upper, but they aren't min (")
4925 << llvm::toString(minVal, 10, false) << ") or max ("
4926 << llvm::toString(maxVal, 10, false)
4927 << ") value! This is an invalid constant range.";
4928 }
4929
4930 return success();
4931}
4932
4933static llvm::Value *getAsPackedI32(llvm::Value *arg,
4934 llvm::IRBuilderBase &builder) {
4935 return builder.CreateBitCast(arg,
4936 llvm::Type::getInt32Ty(builder.getContext()));
4937}
4938
4939NVVM::IDArgPair DotAccumulate4WayOp::getIntrinsicIDAndArgs(
4940 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4941 auto curOp = cast<NVVM::DotAccumulate4WayOp>(op);
4942
4944 args.push_back(getAsPackedI32(mt.lookupValue(curOp.getA()), builder));
4945 args.push_back(getAsPackedI32(mt.lookupValue(curOp.getB()), builder));
4946 args.push_back(mt.lookupValue(curOp.getC()));
4947
4948 bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
4949 bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
4950 unsigned type = (isASigned << 1) | isBSigned;
4951 const llvm::Intrinsic::ID ids[] = {
4952 llvm::Intrinsic::nvvm_idp4a_u_u,
4953 llvm::Intrinsic::nvvm_idp4a_u_s,
4954 llvm::Intrinsic::nvvm_idp4a_s_u,
4955 llvm::Intrinsic::nvvm_idp4a_s_s,
4956 };
4957 return {ids[type], args};
4958}
4959
4960NVVM::IDArgPair DotAccumulate2WayOp::getIntrinsicIDAndArgs(
4961 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
4962 auto curOp = cast<NVVM::DotAccumulate2WayOp>(op);
4963
4965 args.push_back(getAsPackedI32(mt.lookupValue(curOp.getA()), builder));
4966 args.push_back(getAsPackedI32(mt.lookupValue(curOp.getB()), builder));
4967 args.push_back(builder.getInt1(curOp.getBHi()));
4968 args.push_back(mt.lookupValue(curOp.getC()));
4969
4970 bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
4971 bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
4972 unsigned type = (isASigned << 1) | isBSigned;
4973 const llvm::Intrinsic::ID ids[] = {
4974 llvm::Intrinsic::nvvm_idp2a_u_u,
4975 llvm::Intrinsic::nvvm_idp2a_u_s,
4976 llvm::Intrinsic::nvvm_idp2a_s_u,
4977 llvm::Intrinsic::nvvm_idp2a_s_s,
4978 };
4979 return {ids[type], args};
4980}
4981
4982static llvm::Value *getParamCastedAddr(llvm::Value *addr,
4983 llvm::IRBuilderBase &builder) {
4984 return builder.CreateAddrSpaceCast(
4985 addr, builder.getPtrTy(llvm::NVPTXAS::ADDRESS_SPACE_ENTRY_PARAM));
4986}
4987
4989PrefetchOp::getIntrinsicIDAndArgs(NVVM::PrefetchOp &op,
4991 llvm::IRBuilderBase &builder) {
4992 using MemSpace = NVVM::NVVMMemorySpace;
4993 using CacheLevel = NVVM::PrefetchCacheLevel;
4994
4995 std::optional<NVVM::PrefetchCacheLevel> cacheLevel = op.getCacheLevel();
4996 std::optional<NVVM::CacheEvictionPriority> evictPriority =
4997 op.getEvictPriority();
4998 unsigned addressSpace =
4999 llvm::cast<LLVM::LLVMPointerType>(op.getAddr().getType())
5000 .getAddressSpace();
5001
5003 llvm::Value *addr = mt.lookupValue(op.getAddr());
5004 args.push_back(op.getInParamSpace() ? getParamCastedAddr(addr, builder)
5005 : addr);
5006
5007 if (op.getTensormap())
5008 return {llvm::Intrinsic::nvvm_prefetch_tensormap, args};
5009
5010 assert(cacheLevel && "expected cache level for non-tensormap prefetch");
5011
5012 if (op.getUniform() && *cacheLevel == CacheLevel::L1)
5013 return {llvm::Intrinsic::nvvm_prefetchu_L1, args};
5014
5015 if (evictPriority && *cacheLevel == CacheLevel::L2) {
5016 switch (*evictPriority) {
5017 case NVVM::CacheEvictionPriority::EvictLast:
5018 return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_last, args};
5019 case NVVM::CacheEvictionPriority::EvictNormal:
5020 return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_normal, args};
5021 default:
5022 llvm_unreachable("Invalid cache eviction priority");
5023 }
5024 }
5025
5026 switch (static_cast<MemSpace>(addressSpace)) {
5027 case MemSpace::Generic:
5028 return *cacheLevel == CacheLevel::L1
5029 ? NVVM::IDArgPair({llvm::Intrinsic::nvvm_prefetch_L1, args})
5030 : NVVM::IDArgPair({llvm::Intrinsic::nvvm_prefetch_L2, args});
5031 case MemSpace::Global:
5032 return *cacheLevel == CacheLevel::L1
5034 {llvm::Intrinsic::nvvm_prefetch_global_L1, args})
5035 : NVVM::IDArgPair(
5036 {llvm::Intrinsic::nvvm_prefetch_global_L2, args});
5037 case MemSpace::Local:
5038 return *cacheLevel == CacheLevel::L1
5040 {llvm::Intrinsic::nvvm_prefetch_local_L1, args})
5041 : NVVM::IDArgPair(
5042 {llvm::Intrinsic::nvvm_prefetch_local_L2, args});
5043 default:
5044 llvm_unreachable("Invalid pointer address space");
5045 }
5046}
5047
5048bool NVVM::InlinePtxOp::getAsmValues(
5049 RewriterBase &rewriter,
5050 llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
5051 &asmValues) {
5052 for (auto arg : getReadWriteArgs())
5053 asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::ReadWrite});
5054 for (auto arg : getResults())
5055 asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::Write});
5056 for (auto arg : getReadOnlyArgs())
5057 asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::Read});
5058 if (getPredicate())
5059 asmValues.push_back({getPredicate(), mlir::NVVM::PTXRegisterMod::Read});
5060 return false; // No manual mapping needed
5061}
5062
5063NVVM::IDArgPair ClusterLaunchControlTryCancelOp::getIntrinsicIDAndArgs(
5064 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
5065 auto curOp = cast<NVVM::ClusterLaunchControlTryCancelOp>(op);
5067 args.push_back(mt.lookupValue(curOp.getSmemAddress()));
5068 args.push_back(mt.lookupValue(curOp.getMbarrier()));
5069
5070 llvm::Intrinsic::ID intrinsicID =
5071 curOp.getMulticast()
5072 ? llvm::Intrinsic::
5073 nvvm_clusterlaunchcontrol_try_cancel_async_multicast_shared
5074 : llvm::Intrinsic::nvvm_clusterlaunchcontrol_try_cancel_async_shared;
5075
5076 return {intrinsicID, args};
5077}
5078
5079NVVM::IDArgPair ClusterLaunchControlQueryCancelOp::getIntrinsicIDAndArgs(
5080 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
5081 auto curOp = cast<NVVM::ClusterLaunchControlQueryCancelOp>(op);
5083 args.push_back(mt.lookupValue(curOp.getTryCancelResponse()));
5084
5085 llvm::Intrinsic::ID intrinsicID;
5086
5087 switch (curOp.getQueryType()) {
5088 case NVVM::ClusterLaunchControlQueryType::IS_CANCELED:
5089 intrinsicID =
5090 llvm::Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_is_canceled;
5091 break;
5092 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_X:
5093 intrinsicID = llvm::Intrinsic::
5094 nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_x;
5095 break;
5096 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Y:
5097 intrinsicID = llvm::Intrinsic::
5098 nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_y;
5099 break;
5100 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Z:
5101 intrinsicID = llvm::Intrinsic::
5102 nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_z;
5103 break;
5104 }
5105 return {intrinsicID, args};
5106}
5107
5109PermuteOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
5110 llvm::IRBuilderBase &builder) {
5111 auto thisOp = cast<NVVM::PermuteOp>(op);
5112 NVVM::PermuteMode mode = thisOp.getMode();
5113
5114 static constexpr llvm::Intrinsic::ID IDs[] = {
5115 llvm::Intrinsic::nvvm_prmt, llvm::Intrinsic::nvvm_prmt_f4e,
5116 llvm::Intrinsic::nvvm_prmt_b4e, llvm::Intrinsic::nvvm_prmt_rc8,
5117 llvm::Intrinsic::nvvm_prmt_ecl, llvm::Intrinsic::nvvm_prmt_ecr,
5118 llvm::Intrinsic::nvvm_prmt_rc16};
5119
5120 unsigned modeIndex = static_cast<unsigned>(mode);
5122 args.push_back(mt.lookupValue(thisOp.getLo()));
5123
5124 // Only first 3 modes (Default, f4e, b4e) need the hi operand.
5125 if (modeIndex < 3)
5126 args.push_back(mt.lookupValue(thisOp.getHi()));
5127
5128 args.push_back(mt.lookupValue(thisOp.getSelector()));
5129
5130 return {IDs[modeIndex], args};
5131}
5132
5133mlir::NVVM::IDArgPair TensormapReplaceOp::getIntrinsicIDAndArgs(
5134 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
5135 auto thisOp = cast<NVVM::TensormapReplaceOp>(op);
5136
5138 args.push_back(mt.lookupValue(thisOp.getAddr()));
5139 if (thisOp.getOrd())
5140 args.push_back(builder.getInt32(thisOp.getOrd().value()));
5141 if (thisOp.getNewValue())
5142 args.push_back(mt.lookupValue(thisOp.getNewValue()));
5143 if (auto attr = thisOp.getNewValueAttr()) {
5144 auto val =
5146 .Case<TensormapElemtypeAttr, TensormapInterleaveLayoutAttr,
5147 TensormapSwizzleModeAttr, TensormapSwizzleAtomicityAttr,
5148 TensormapFillModeAttr>([](auto attr) {
5149 return static_cast<unsigned>(attr.getValue());
5150 })
5151 .Default([](auto attr) {
5152 llvm_unreachable("Invalid attribute type");
5153 return 0;
5154 });
5155 args.push_back(builder.getInt32(val));
5156 }
5157
5158 static constexpr llvm::Intrinsic::ID IDs[] = {
5159 llvm::Intrinsic::nvvm_tensormap_replace_global_address,
5160 llvm::Intrinsic::nvvm_tensormap_replace_rank,
5161 llvm::Intrinsic::nvvm_tensormap_replace_box_dim,
5162 llvm::Intrinsic::nvvm_tensormap_replace_global_dim,
5163 llvm::Intrinsic::nvvm_tensormap_replace_global_stride,
5164 llvm::Intrinsic::nvvm_tensormap_replace_element_stride,
5165 llvm::Intrinsic::nvvm_tensormap_replace_elemtype,
5166 llvm::Intrinsic::nvvm_tensormap_replace_interleave_layout,
5167 llvm::Intrinsic::nvvm_tensormap_replace_swizzle_mode,
5168 llvm::Intrinsic::nvvm_tensormap_replace_swizzle_atomicity,
5169 llvm::Intrinsic::nvvm_tensormap_replace_fill_mode,
5170 };
5171
5172 unsigned fieldIndex = static_cast<unsigned>(thisOp.getField());
5173
5174 return {IDs[fieldIndex], args};
5175}
5176
5177//===----------------------------------------------------------------------===//
5178// NVVM tcgen05.mma functions
5179//===----------------------------------------------------------------------===//
5180
5182Tcgen05MMAOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
5183 llvm::IRBuilderBase &builder) {
5184
5185 auto thisOp = cast<NVVM::Tcgen05MMAOp>(op);
5187
5188 args.push_back(mt.lookupValue(thisOp.getMatrixD()));
5189
5190 llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
5191 const bool isATensor = isa<llvm::PointerType>(A->getType());
5192 args.push_back(A);
5193
5194 args.push_back(mt.lookupValue(thisOp.getMatrixB()));
5195 args.push_back(mt.lookupValue(thisOp.getIdesc()));
5196 args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
5197
5198 using EnableAShiftArray = std::array<llvm::Intrinsic::ID, 2>;
5199 using CtaGroupArray = std::array<EnableAShiftArray, 2>;
5200 using IsATensorArray = std::array<CtaGroupArray, 2>;
5201 using HasScaleInputDArray = std::array<IsATensorArray, 2>;
5202 using HasDisableOutputLaneArray = std::array<HasScaleInputDArray, 2>;
5203
5204 // [hasDisableOutputLane][hasScaleInputD][isATensor][CtaGroup][EnableAShift]
5205 static constexpr HasDisableOutputLaneArray tcgen05MMAIDs = {
5206 { // without diable output lane
5207 {{// without scale input D
5208 {{
5209 // shared
5210 {{// cg1
5211 {llvm::Intrinsic::nvvm_tcgen05_mma_shared, notIntrinsic},
5212 // cg2
5213 {llvm::Intrinsic::nvvm_tcgen05_mma_shared, notIntrinsic}}},
5214 {{// tensor
5215 {
5216 // cg1
5217 llvm::Intrinsic::nvvm_tcgen05_mma_tensor,
5218 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_ashift,
5219 },
5220 {
5221 // cg2
5222 llvm::Intrinsic::nvvm_tcgen05_mma_tensor,
5223 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_ashift,
5224 }}},
5225 }},
5226 // with scale input D
5227 {{ // shared
5228 {{// cg1
5229 {llvm::Intrinsic::nvvm_tcgen05_mma_shared_scale_d, notIntrinsic},
5230 // cg2
5231 {llvm::Intrinsic::nvvm_tcgen05_mma_shared_scale_d, notIntrinsic}}},
5232 {{// tensor
5233 {
5234 // cg1
5235 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d,
5236 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_ashift,
5237 },
5238 {
5239 // cg2
5240 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d,
5241 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_ashift,
5242 }}}}}}},
5243 // with disable output lane
5244 {{ // without scale input D
5245 {{ // shared
5246 {{// cg1
5247 {llvm::Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg1,
5248 notIntrinsic},
5249 // cg2
5250 {llvm::Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg2,
5251 notIntrinsic}}},
5252 {{// cg1
5253 {
5254 llvm::Intrinsic::
5255 nvvm_tcgen05_mma_tensor_disable_output_lane_cg1,
5256 llvm::Intrinsic::
5257 nvvm_tcgen05_mma_tensor_disable_output_lane_cg1_ashift,
5258 },
5259 // cg2
5260 {
5261 llvm::Intrinsic::
5262 nvvm_tcgen05_mma_tensor_disable_output_lane_cg2,
5263 llvm::Intrinsic::
5264 nvvm_tcgen05_mma_tensor_disable_output_lane_cg2_ashift,
5265 }}}}},
5266 // with scale input D
5267 {{ // shared
5268 {{// cg1
5269 {llvm::Intrinsic::
5270 nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg1,
5271 notIntrinsic},
5272 // cg2
5273 {llvm::Intrinsic::
5274 nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg2,
5275 notIntrinsic}}},
5276 // tensor
5277 {{// cg1
5278 {llvm::Intrinsic::
5279 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1,
5280 llvm::Intrinsic::
5281 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1_ashift},
5282 // cg2
5283 {
5284 llvm::Intrinsic::
5285 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2,
5286 llvm::Intrinsic::
5287 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2_ashift,
5288 }}}}}}}}};
5289
5290 llvm::Value *ScaleInputD = mt.lookupValue(thisOp.getScaleInputD());
5291 bool hasScaleInputD = ScaleInputD != nullptr;
5292
5293 llvm::Value *DisableOutputLane =
5294 mt.lookupValue(thisOp.getDisableOutputLane());
5295 bool hasDisableOutputLane = DisableOutputLane != nullptr;
5296
5297 const unsigned ctaGroup =
5298 static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup()));
5299
5300 llvm::Intrinsic::ID ID =
5301 tcgen05MMAIDs[hasDisableOutputLane][hasScaleInputD][isATensor]
5302 [ctaGroup - 1][thisOp.getAShift()];
5303
5304 assert(ID != notIntrinsic && "Invalid intrinsic for Tcgen05MMAOp.");
5305
5306 if (hasScaleInputD)
5307 args.push_back(ScaleInputD);
5308
5309 if (hasDisableOutputLane)
5310 args.push_back(DisableOutputLane);
5311
5312 args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
5313
5314 if (!hasDisableOutputLane)
5315 args.push_back(builder.getInt32(ctaGroup));
5316
5317 args.push_back(
5318 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
5319
5320 return {ID, args};
5321}
5322
5323static LogicalResult
5324verifyTcgen05MMAOp(bool isATensor, mlir::Value disableOutputLane,
5325 NVVM::CTAGroupKind ctaGroup, bool hasAShift,
5326 NVVM::Tcgen05MMACollectorOp collectorOp, Location loc) {
5327
5328 if (disableOutputLane) {
5329 mlir::VectorType disableOutputLaneType =
5330 cast<mlir::VectorType>(disableOutputLane.getType());
5331 if ((ctaGroup == NVVM::CTAGroupKind::CTA_1 &&
5332 disableOutputLaneType.getNumElements() != 4) ||
5333 (ctaGroup == NVVM::CTAGroupKind::CTA_2 &&
5334 disableOutputLaneType.getNumElements() != 8))
5335 return emitError(loc) << "Disable Output Lane of length "
5336 << disableOutputLaneType.getNumElements()
5337 << " is incompatible with CtaGroupAttr";
5338 }
5339
5340 if (hasAShift && !isATensor)
5341 return emitError(
5342 loc, "A-shift can be applied only when matrix A is in tensor memory");
5343
5344 if (hasAShift == true && (collectorOp == Tcgen05MMACollectorOp::FILL ||
5345 collectorOp == Tcgen05MMACollectorOp::USE))
5346 return emitError(
5347 loc, "Cannot use collector buffer operation fill or use with ashift");
5348
5349 return success();
5350}
5351
5352LogicalResult Tcgen05MMAOp::verify() {
5353 return verifyTcgen05MMAOp(isa<LLVM::LLVMPointerType>(getMatrixA().getType()),
5354 getDisableOutputLane(), getCtaGroup(), getAShift(),
5355 getCollectorOp(), getLoc());
5356}
5357
5358//===----------------------------------------------------------------------===//
5359// NVVM tcgen05.mma.sp functions
5360//===----------------------------------------------------------------------===//
5361
5362mlir::NVVM::IDArgPair Tcgen05MMASparseOp::getIntrinsicIDAndArgs(
5363 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
5364
5365 auto thisOp = cast<NVVM::Tcgen05MMASparseOp>(op);
5367
5368 args.push_back(mt.lookupValue(thisOp.getMatrixD()));
5369
5370 llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
5371 bool isATensor = isa<llvm::PointerType>(A->getType());
5372 args.push_back(A);
5373
5374 args.push_back(mt.lookupValue(thisOp.getMatrixB()));
5375 args.push_back(mt.lookupValue(thisOp.getIdesc()));
5376 args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
5377 args.push_back(mt.lookupValue(thisOp.getSparseMetadata()));
5378
5379 using EnableAShiftArray = std::array<llvm::Intrinsic::ID, 2>;
5380 using CtaGroupArray = std::array<EnableAShiftArray, 2>;
5381 using IsATensorArray = std::array<CtaGroupArray, 2>;
5382 using HasScaleInputDArray = std::array<IsATensorArray, 2>;
5383 using HasDisableOutputLaneArray = std::array<HasScaleInputDArray, 2>;
5384
5385 // [hasDisableOutputLane][hasScaleInputD][isATensor][CtaGroup][EnableAShift]
5386 static constexpr HasDisableOutputLaneArray tcgen05MMASparseIDs = {
5387 { // without diable output lane
5388 {{// without scale input D
5389 {{
5390 // shared
5391 {{// cg1
5392 {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared, notIntrinsic},
5393 // cg2
5394 {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared, notIntrinsic}}},
5395 {{// tensor
5396 {
5397 // cg1
5398 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor,
5399 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_ashift,
5400 },
5401 {
5402 // cg2
5403 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor,
5404 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_ashift,
5405 }}},
5406 }},
5407 // with scale input D
5408 {{ // shared
5409 {{// cg1
5410 {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d,
5411 notIntrinsic},
5412 // cg2
5413 {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d,
5414 notIntrinsic}}},
5415 {{// tensor
5416 {
5417 // cg1
5418 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d,
5419 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_ashift,
5420 },
5421 {
5422 // cg2
5423 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d,
5424 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_ashift,
5425 }}}}}}},
5426 // with disable output lane
5427 {{ // without scale input D
5428 {{ // shared
5429 {{// cg1
5430 {llvm::Intrinsic::
5431 nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg1,
5432 notIntrinsic},
5433 // cg2
5434 {llvm::Intrinsic::
5435 nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg2,
5436 notIntrinsic}}},
5437 {{// cg1
5438 {
5439 llvm::Intrinsic::
5440 nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1,
5441 llvm::Intrinsic::
5442 nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1_ashift,
5443 },
5444 // cg2
5445 {
5446 llvm::Intrinsic::
5447 nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2,
5448 llvm::Intrinsic::
5449 nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2_ashift,
5450 }}}}},
5451 // with scale input D
5452 {{ // shared
5453 {{// cg1
5454 {llvm::Intrinsic::
5455 nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg1,
5456 notIntrinsic},
5457 // cg2
5458 {llvm::Intrinsic::
5459 nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg2,
5460 notIntrinsic}}},
5461 // tensor
5462 {{// cg1
5463 {llvm::Intrinsic::
5464 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1,
5465 llvm::Intrinsic::
5466 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1_ashift},
5467 // cg2
5468 {
5469 llvm::Intrinsic::
5470 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2,
5471 llvm::Intrinsic::
5472 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2_ashift,
5473 }}}}}}}}};
5474
5475 llvm::Value *ScaleInputD = mt.lookupValue(thisOp.getScaleInputD());
5476 bool hasScaleInputD = ScaleInputD != nullptr;
5477
5478 llvm::Value *DisableOutputLane =
5479 mt.lookupValue(thisOp.getDisableOutputLane());
5480 bool hasDisableOutputLane = DisableOutputLane != nullptr;
5481
5482 unsigned ctaGroup =
5483 static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup()));
5484
5485 llvm::Intrinsic::ID ID =
5486 tcgen05MMASparseIDs[hasDisableOutputLane][hasScaleInputD][isATensor]
5487 [ctaGroup - 1][thisOp.getAShift()];
5488
5489 assert(ID != notIntrinsic && "Invalid intrinsic for Tcgen05MMASparseOp.");
5490
5491 if (hasScaleInputD)
5492 args.push_back(ScaleInputD);
5493
5494 if (hasDisableOutputLane)
5495 args.push_back(DisableOutputLane);
5496
5497 args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
5498
5499 if (!hasDisableOutputLane)
5500 args.push_back(builder.getInt32(ctaGroup));
5501
5502 args.push_back(
5503 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
5504
5505 return {ID, args};
5506}
5507
5508LogicalResult Tcgen05MMASparseOp::verify() {
5509 return verifyTcgen05MMAOp(isa<LLVM::LLVMPointerType>(getMatrixA().getType()),
5510 getDisableOutputLane(), getCtaGroup(), getAShift(),
5511 getCollectorOp(), getLoc());
5512}
5513
5514//===----------------------------------------------------------------------===//
5515// NVVM tcgen05.mma.block_scale functions
5516//===----------------------------------------------------------------------===//
5517
5518mlir::NVVM::IDArgPair Tcgen05MMABlockScaleOp::getIntrinsicIDAndArgs(
5519 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
5520
5521 auto thisOp = cast<NVVM::Tcgen05MMABlockScaleOp>(op);
5523
5524 args.push_back(mt.lookupValue(thisOp.getMatrixD()));
5525
5526 llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
5527 bool isATensor = isa<llvm::PointerType>(A->getType());
5528 args.push_back(A);
5529
5530 args.push_back(mt.lookupValue(thisOp.getMatrixB()));
5531 args.push_back(mt.lookupValue(thisOp.getIdesc()));
5532 args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
5533 args.push_back(mt.lookupValue(thisOp.getScaleA()));
5534 args.push_back(mt.lookupValue(thisOp.getScaleB()));
5535 args.push_back(builder.getInt32(
5536 static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup()))));
5537 args.push_back(
5538 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
5539
5540 auto kind = thisOp.getKind();
5541 auto blockScale = thisOp.getBlockScale();
5542 llvm::Intrinsic::ID ID = [&]() {
5543 if (kind == NVVM::Tcgen05MMAKind::MXF8F6F4) {
5544 if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
5545 return isATensor ? llvm::Intrinsic::
5546 nvvm_tcgen05_mma_tensor_mxf8f6f4_block_scale
5547 : llvm::Intrinsic::
5548 nvvm_tcgen05_mma_shared_mxf8f6f4_block_scale;
5549 } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
5550 return isATensor
5551 ? llvm::Intrinsic::
5552 nvvm_tcgen05_mma_tensor_mxf8f6f4_block_scale_block32
5553 : llvm::Intrinsic::
5554 nvvm_tcgen05_mma_shared_mxf8f6f4_block_scale_block32;
5555 }
5556 } else if (kind == NVVM::Tcgen05MMAKind::MXF4) {
5557 if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
5558 return isATensor
5559 ? llvm::Intrinsic::nvvm_tcgen05_mma_tensor_mxf4_block_scale
5560 : llvm::Intrinsic::nvvm_tcgen05_mma_shared_mxf4_block_scale;
5561 } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
5562 return isATensor ? llvm::Intrinsic::
5563 nvvm_tcgen05_mma_tensor_mxf4_block_scale_block32
5564 : llvm::Intrinsic::
5565 nvvm_tcgen05_mma_shared_mxf4_block_scale_block32;
5566 }
5567 } else if (kind == NVVM::Tcgen05MMAKind::MXF4NVF4) {
5568 if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
5569 return isATensor
5570 ? llvm::Intrinsic::
5571 nvvm_tcgen05_mma_tensor_mxf4nvf4_block_scale_block32
5572 : llvm::Intrinsic::
5573 nvvm_tcgen05_mma_shared_mxf4nvf4_block_scale_block32;
5574
5575 } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16) {
5576 return isATensor
5577 ? llvm::Intrinsic::
5578 nvvm_tcgen05_mma_tensor_mxf4nvf4_block_scale_block16
5579 : llvm::Intrinsic::
5580 nvvm_tcgen05_mma_shared_mxf4nvf4_block_scale_block16;
5581 }
5582 }
5583 llvm_unreachable("Invalid tcgen05.mma.block_scale attributes");
5584 }();
5585
5586 return {ID, args};
5587}
5588
5589static LogicalResult verifyTcgen05MMABlockScaleOp(
5590 NVVM::Tcgen05MMACollectorOp collectorOp, NVVM::Tcgen05MMAKind kind,
5591 NVVM::Tcgen05MMABlockScale blockScale, Location loc) {
5592 if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT &&
5593 kind == NVVM::Tcgen05MMAKind::MXF4NVF4)
5594 return emitError(loc, "mxf4nvf4 requires block scale attribute");
5595
5596 if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16 &&
5597 kind != NVVM::Tcgen05MMAKind::MXF4NVF4)
5598 return emitError(loc,
5599 llvm::formatv("{} kind does not support block16 attribute",
5600 stringifyEnum(kind)));
5601
5602 return success();
5603}
5604
5605LogicalResult Tcgen05MMABlockScaleOp::verify() {
5606 return verifyTcgen05MMABlockScaleOp(getCollectorOp(), getKind(),
5607 getBlockScale(), getLoc());
5608}
5609
5610//===----------------------------------------------------------------------===//
5611// NVVM tcgen05.mma.sp.block_scale functions
5612//===----------------------------------------------------------------------===//
5613
5614mlir::NVVM::IDArgPair Tcgen05MMASparseBlockScaleOp::getIntrinsicIDAndArgs(
5615 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
5616
5617 auto thisOp = cast<NVVM::Tcgen05MMASparseBlockScaleOp>(op);
5619
5620 args.push_back(mt.lookupValue(thisOp.getMatrixD()));
5621
5622 llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
5623 bool isATensor = isa<llvm::PointerType>(A->getType());
5624 args.push_back(A);
5625
5626 args.push_back(mt.lookupValue(thisOp.getMatrixB()));
5627 args.push_back(mt.lookupValue(thisOp.getIdesc()));
5628 args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
5629 args.push_back(mt.lookupValue(thisOp.getSparseMetadata()));
5630 args.push_back(mt.lookupValue(thisOp.getScaleA()));
5631 args.push_back(mt.lookupValue(thisOp.getScaleB()));
5632 args.push_back(builder.getInt32(
5633 static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup()))));
5634 args.push_back(
5635 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
5636
5637 auto kind = thisOp.getKind();
5638 auto blockScale = thisOp.getBlockScale();
5639 llvm::Intrinsic::ID ID = [&]() {
5640 if (kind == NVVM::Tcgen05MMAKind::MXF8F6F4) {
5641 if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
5642 return isATensor ? llvm::Intrinsic::
5643 nvvm_tcgen05_mma_sp_tensor_mxf8f6f4_block_scale
5644 : llvm::Intrinsic::
5645 nvvm_tcgen05_mma_sp_shared_mxf8f6f4_block_scale;
5646 } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
5647 return isATensor
5648 ? llvm::Intrinsic::
5649 nvvm_tcgen05_mma_sp_tensor_mxf8f6f4_block_scale_block32
5650 : llvm::Intrinsic::
5651 nvvm_tcgen05_mma_sp_shared_mxf8f6f4_block_scale_block32;
5652 }
5653 } else if (kind == NVVM::Tcgen05MMAKind::MXF4) {
5654 if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
5655 return isATensor ? llvm::Intrinsic::
5656 nvvm_tcgen05_mma_sp_tensor_mxf4_block_scale
5657 : llvm::Intrinsic::
5658 nvvm_tcgen05_mma_sp_shared_mxf4_block_scale;
5659 } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
5660 return isATensor
5661 ? llvm::Intrinsic::
5662 nvvm_tcgen05_mma_sp_tensor_mxf4_block_scale_block32
5663 : llvm::Intrinsic::
5664 nvvm_tcgen05_mma_sp_shared_mxf4_block_scale_block32;
5665 }
5666 } else if (kind == NVVM::Tcgen05MMAKind::MXF4NVF4) {
5667 if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
5668 return isATensor
5669 ? llvm::Intrinsic::
5670 nvvm_tcgen05_mma_sp_tensor_mxf4nvf4_block_scale_block32
5671 : llvm::Intrinsic::
5672 nvvm_tcgen05_mma_sp_shared_mxf4nvf4_block_scale_block32;
5673
5674 } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16) {
5675 return isATensor
5676 ? llvm::Intrinsic::
5677 nvvm_tcgen05_mma_sp_tensor_mxf4nvf4_block_scale_block16
5678 : llvm::Intrinsic::
5679 nvvm_tcgen05_mma_sp_shared_mxf4nvf4_block_scale_block16;
5680 }
5681 }
5682 llvm_unreachable("Invalid tcgen05.mma.sp.block_scale attributes");
5683 }();
5684
5685 return {ID, args};
5686}
5687
5688LogicalResult Tcgen05MMASparseBlockScaleOp::verify() {
5689 return verifyTcgen05MMABlockScaleOp(getCollectorOp(), getKind(),
5690 getBlockScale(), getLoc());
5691}
5692
5693//===----------------------------------------------------------------------===//
5694// NVVM tcgen05.mma.ws functions
5695//===----------------------------------------------------------------------===//
5696
5697mlir::NVVM::IDArgPair Tcgen05MMAWsOp::getIntrinsicIDAndArgs(
5698 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
5699
5700 auto thisOp = cast<NVVM::Tcgen05MMAWsOp>(op);
5702
5703 args.push_back(mt.lookupValue(thisOp.getMatrixD()));
5704
5705 llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
5706 bool isATensor = isa<llvm::PointerType>(A->getType());
5707 args.push_back(A);
5708
5709 args.push_back(mt.lookupValue(thisOp.getMatrixB()));
5710 args.push_back(mt.lookupValue(thisOp.getIdesc()));
5711 args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
5712
5713 mlir::Value ZeroColMask = thisOp.getZeroColMask();
5714 llvm::Intrinsic::ID ID = notIntrinsic;
5715 if (ZeroColMask) {
5716 args.push_back(mt.lookupValue(ZeroColMask));
5717 ID = isATensor ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_tensor_zero_col_mask
5718 : llvm::Intrinsic::nvvm_tcgen05_mma_ws_shared_zero_col_mask;
5719 } else
5720 ID = isATensor ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_tensor
5721 : llvm::Intrinsic::nvvm_tcgen05_mma_ws_shared;
5722
5723 args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
5724 args.push_back(
5725 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorBBuffer())));
5726 args.push_back(
5727 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
5728
5729 return {ID, args};
5730}
5731
5732//===----------------------------------------------------------------------===//
5733// NVVM tcgen05.mma.ws.sp functions
5734//===----------------------------------------------------------------------===//
5735
5736mlir::NVVM::IDArgPair Tcgen05MMAWsSparseOp::getIntrinsicIDAndArgs(
5737 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
5738
5739 auto thisOp = cast<NVVM::Tcgen05MMAWsSparseOp>(op);
5741
5742 args.push_back(mt.lookupValue(thisOp.getMatrixD()));
5743
5744 llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
5745 bool isATensor = isa<llvm::PointerType>(A->getType());
5746 args.push_back(A);
5747
5748 args.push_back(mt.lookupValue(thisOp.getMatrixB()));
5749 args.push_back(mt.lookupValue(thisOp.getIdesc()));
5750 args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
5751 args.push_back(mt.lookupValue(thisOp.getSparseMetadata()));
5752
5753 mlir::Value ZeroColMask = thisOp.getZeroColMask();
5754 llvm::Intrinsic::ID ID = notIntrinsic;
5755 if (ZeroColMask) {
5756 args.push_back(mt.lookupValue(ZeroColMask));
5757 ID = isATensor
5758 ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_tensor_zero_col_mask
5759 : llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_shared_zero_col_mask;
5760 } else
5761 ID = isATensor ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_tensor
5762 : llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_shared;
5763
5764 args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
5765 args.push_back(
5766 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorBBuffer())));
5767 args.push_back(
5768 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
5769
5770 return {ID, args};
5771}
5772
5773//===----------------------------------------------------------------------===//
5774// NVVM tcgen05.ld.red functions
5775//===----------------------------------------------------------------------===//
5776
5777#define TCGEN05LDRED(SHAPE, NUM, TYPE) \
5778 llvm::Intrinsic::nvvm_tcgen05_ld_red_##SHAPE##_##NUM##_##TYPE
5779
5780mlir::NVVM::IDArgPair NVVM::Tcgen05LdRedOp::getIntrinsicIDAndArgs(
5781 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
5782 auto thisOp = cast<NVVM::Tcgen05LdRedOp>(op);
5784
5785 mlir::VectorType VecResTy =
5786 cast<mlir::VectorType>(thisOp.getData().getType());
5787 unsigned Num = VecResTy.getNumElements();
5788 bool IsFloat = thisOp.getRedVal().getType().isF32();
5789
5790 llvm::Intrinsic::ID Shape32x32b[][2] = {
5792 {TCGEN05LDRED(32x32b, x2, i32), TCGEN05LDRED(32x32b, x2, f32)},
5793 {TCGEN05LDRED(32x32b, x4, i32), TCGEN05LDRED(32x32b, x4, f32)},
5794 {TCGEN05LDRED(32x32b, x8, i32), TCGEN05LDRED(32x32b, x8, f32)},
5795 {TCGEN05LDRED(32x32b, x16, i32), TCGEN05LDRED(32x32b, x16, f32)},
5796 {TCGEN05LDRED(32x32b, x32, i32), TCGEN05LDRED(32x32b, x32, f32)},
5797 {TCGEN05LDRED(32x32b, x64, i32), TCGEN05LDRED(32x32b, x64, f32)},
5798 {TCGEN05LDRED(32x32b, x128, i32), TCGEN05LDRED(32x32b, x128, f32)},
5799 };
5800
5801 llvm::Intrinsic::ID Shape16x32bx2[][2] = {
5803 {TCGEN05LDRED(16x32bx2, x2, i32), TCGEN05LDRED(16x32bx2, x2, f32)},
5804 {TCGEN05LDRED(16x32bx2, x4, i32), TCGEN05LDRED(16x32bx2, x4, f32)},
5805 {TCGEN05LDRED(16x32bx2, x8, i32), TCGEN05LDRED(16x32bx2, x8, f32)},
5806 {TCGEN05LDRED(16x32bx2, x16, i32), TCGEN05LDRED(16x32bx2, x16, f32)},
5807 {TCGEN05LDRED(16x32bx2, x32, i32), TCGEN05LDRED(16x32bx2, x32, f32)},
5808 {TCGEN05LDRED(16x32bx2, x64, i32), TCGEN05LDRED(16x32bx2, x64, f32)},
5809 {TCGEN05LDRED(16x32bx2, x128, i32), TCGEN05LDRED(16x32bx2, x128, f32)},
5810 };
5811
5812 NVVM::Tcgen05LdStShape shape = thisOp.getShape();
5813 unsigned ID = [&]() {
5814 // `num` contains the length of vector and log2 of `num` returns the index
5815 // into the shape array
5816 unsigned idx = std::log2(Num);
5817 switch (shape) {
5818 case NVVM::Tcgen05LdStShape::SHAPE_32X32B:
5819 return Shape32x32b[idx][IsFloat];
5820 case NVVM::Tcgen05LdStShape::SHAPE_16X32BX2:
5821 return Shape16x32bx2[idx][IsFloat];
5822 default:
5823 llvm_unreachable("unhandled tcgen05.ld lowering");
5824 }
5825 }();
5826
5827 args.push_back(mt.lookupValue(thisOp.getAddr()));
5828
5829 if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2)
5830 args.push_back(mt.lookupValue(thisOp.getOffset()));
5831
5832 args.push_back(
5833 builder.getInt32(thisOp.getOp() == NVVM::ReductionKind::MIN ? 0 : 1));
5834
5835 if (IsFloat) {
5836 args.push_back(builder.getInt1(static_cast<unsigned>(thisOp.getAbs())));
5837 args.push_back(builder.getInt1(static_cast<unsigned>(thisOp.getNan())));
5838 }
5839 return {ID, args};
5840}
5841
5842LogicalResult Tcgen05LdRedOp::verify() {
5843 VectorType data = cast<VectorType>(getData().getType());
5844 Type redVal = getRedVal().getType();
5845
5846 if (data.getElementType() != redVal)
5847 return emitError(
5848 "type of reduction value and element type of vector data should match");
5849
5850 if (getOp() != NVVM::ReductionKind::MIN &&
5851 getOp() != NVVM::ReductionKind::MAX)
5852 return emitError("only min and max reduction kinds are supported");
5853
5854 if (redVal.isInteger() && (getAbs() || getNan())) {
5855 return emitError("abs or nan is only applicable for f32 type");
5856 }
5857 return success();
5858}
5859
5860//===----------------------------------------------------------------------===//
5861// NVVMDialect initialization, type parsing, and registration.
5862//===----------------------------------------------------------------------===//
5863
5864// TODO: This should be the llvm.nvvm dialect once this is supported.
/// Registers all NVVM operations and attributes (both generated from
/// tablegen) with the dialect and declares the promised interfaces that are
/// implemented in separate libraries.
void NVVMDialect::initialize() {
  // Ops generated from NVVMOps.td.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
      >();
  // Attributes generated from NVVMOps.td.
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
      >();

  // Support unknown operations because not all NVVM operations are
  // registered.
  allowUnknownOperations();
  declarePromisedInterface<ConvertToLLVMPatternInterface, NVVMDialect>();
  declarePromisedInterface<gpu::TargetAttrInterface, NVVMTargetAttr>();
}
5881
/// Verifies NVVM dialect attributes attached to arbitrary operations:
///  * nvvm.kernel may only appear on llvm.func ops,
///  * maxntid / reqntid / cluster_dim must be non-empty i32 arrays of at
///    most 3 elements,
///  * minctasm / maxnreg / cluster_max_blocks must be integer attributes,
///  * blocksareclusters requires reqntid and cluster_dim to be present too.
LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
                                                    NamedAttribute attr) {
  StringAttr attrName = attr.getName();
  // Kernel function attribute should be attached to functions.
  if (attrName == NVVMDialect::getKernelFuncAttrName()) {
    if (!isa<LLVM::LLVMFuncOp>(op)) {
      return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName()
                             << "' attribute attached to unexpected op";
    }
  }
  // If maxntid / reqntid / cluster_dim exist, it must be an array with max 3
  // dim
  if (attrName == NVVMDialect::getMaxntidAttrName() ||
      attrName == NVVMDialect::getReqntidAttrName() ||
      attrName == NVVMDialect::getClusterDimAttrName()) {
    auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue());
    if (!values || values.empty() || values.size() > 3) {
      return op->emitError()
             << "'" << attrName
             << "' attribute must be integer array with maximum 3 index";
    }
  }
  // If minctasm / maxnreg / cluster_max_blocks exist, it must be an integer
  // attribute
  if (attrName == NVVMDialect::getMinctasmAttrName() ||
      attrName == NVVMDialect::getMaxnregAttrName() ||
      attrName == NVVMDialect::getClusterMaxBlocksAttrName()) {
    if (!llvm::dyn_cast<IntegerAttr>(attr.getValue())) {
      return op->emitError()
             << "'" << attrName << "' attribute must be integer constant";
    }
  }
  // blocksareclusters must be used along with reqntid and cluster_dim
  if (attrName == NVVMDialect::getBlocksAreClustersAttrName()) {
    if (!op->hasAttr(NVVMDialect::getReqntidAttrName()) ||
        !op->hasAttr(NVVMDialect::getClusterDimAttrName())) {
      return op->emitError()
             << "'" << attrName << "' attribute must be used along with "
             << "'" << NVVMDialect::getReqntidAttrName() << "' and "
             << "'" << NVVMDialect::getClusterDimAttrName() << "'";
    }
  }

  return success();
}
5927
5928LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
5929 unsigned regionIndex,
5930 unsigned argIndex,
5931 NamedAttribute argAttr) {
5932 auto funcOp = dyn_cast<FunctionOpInterface>(op);
5933 if (!funcOp)
5934 return success();
5935
5936 bool isKernel = op->hasAttr(NVVMDialect::getKernelFuncAttrName());
5937 StringAttr attrName = argAttr.getName();
5938 if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
5939 if (!isKernel) {
5940 return op->emitError()
5941 << "'" << attrName
5942 << "' attribute must be present only on kernel arguments";
5943 }
5944 if (!isa<UnitAttr>(argAttr.getValue()))
5945 return op->emitError() << "'" << attrName << "' must be a unit attribute";
5946 if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName())) {
5947 return op->emitError()
5948 << "'" << attrName
5949 << "' attribute requires the argument to also have attribute '"
5950 << LLVM::LLVMDialect::getByValAttrName() << "'";
5951 }
5952 }
5953
5954 return success();
5955}
5956
5957//===----------------------------------------------------------------------===//
5958// NVVM Address Space Attr
5959//===----------------------------------------------------------------------===//
5960
/// Returns the memory-space enum value as the raw numeric address space.
unsigned NVVMMemorySpaceAttr::getAddressSpace() const {
  return static_cast<unsigned>(getValue());
}
5964
5965bool NVVMMemorySpaceAttr::isValidLoad(
5966 Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
5967 const ::mlir::DataLayout *dataLayout,
5969 return LLVM::detail::isValidLoadStoreImpl(type, ordering, alignment,
5970 dataLayout, emitError);
5971}
5972
5973bool NVVMMemorySpaceAttr::isValidStore(
5974 Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
5975 const ::mlir::DataLayout *dataLayout,
5977 return LLVM::detail::isValidLoadStoreImpl(type, ordering, alignment,
5978 dataLayout, emitError);
5979}
5980
5981bool NVVMMemorySpaceAttr::isValidAtomicOp(
5982 ptr::AtomicBinOp op, Type type, ptr::AtomicOrdering ordering,
5983 std::optional<int64_t> alignment, const ::mlir::DataLayout *dataLayout,
5985 // TODO: update this method once `ptr.atomic_rmw` is implemented.
5986 assert(false && "unimplemented, see TODO in the source.");
5987 return false;
5988}
5989
5990bool NVVMMemorySpaceAttr::isValidAtomicXchg(
5991 Type type, ptr::AtomicOrdering successOrdering,
5992 ptr::AtomicOrdering failureOrdering, std::optional<int64_t> alignment,
5993 const ::mlir::DataLayout *dataLayout,
5995 // TODO: update this method once `ptr.atomic_cmpxchg` is implemented.
5996 assert(false && "unimplemented, see TODO in the source.");
5997 return false;
5998}
5999
/// Address-space cast validity check. Unimplemented: asserts in debug builds
/// and conservatively returns false in release builds.
bool NVVMMemorySpaceAttr::isValidAddrSpaceCast(
    Type tgt, Type src, function_ref<InFlightDiagnostic()> emitError) const {
  // TODO: update this method once the `ptr.addrspace_cast` op is added to the
  // dialect.
  assert(false && "unimplemented, see TODO in the source.");
  return false;
}
6007
6008bool NVVMMemorySpaceAttr::isValidPtrIntCast(
6009 Type intLikeTy, Type ptrLikeTy,
6011 // TODO: update this method once the int-cast ops are added to the `ptr`
6012 // dialect.
6013 assert(false && "unimplemented, see TODO in the source.");
6014 return false;
6015}
6016
6017//===----------------------------------------------------------------------===//
6018// NVVM target attribute.
6019//===----------------------------------------------------------------------===//
6020LogicalResult
6021NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
6022 int optLevel, StringRef triple, StringRef chip,
6023 StringRef features, DictionaryAttr flags,
6024 ArrayAttr files, bool verifyTarget) {
6025 if (optLevel < 0 || optLevel > 3) {
6026 emitError() << "The optimization level must be a number between 0 and 3.";
6027 return failure();
6028 }
6029 if (triple.empty()) {
6030 emitError() << "The target triple cannot be empty.";
6031 return failure();
6032 }
6033 if (chip.empty()) {
6034 emitError() << "The target chip cannot be empty.";
6035 return failure();
6036 }
6037 if (files && !llvm::all_of(files, [](::mlir::Attribute attr) {
6038 return mlir::isa_and_nonnull<StringAttr>(attr);
6039 })) {
6040 emitError() << "All the elements in the `link` array must be strings.";
6041 return failure();
6042 }
6043 return success();
6044}
6045
6046LogicalResult NVVMTargetAttr::verifyTarget(Operation *gpuModule) {
6047 if (!getVerifyTarget())
6048 return success();
6049
6050 auto gpuModuleOp = llvm::dyn_cast<gpu::GPUModuleOp>(gpuModule);
6051 if (!gpuModuleOp) {
6052 return emitError(gpuModule->getLoc(),
6053 "NVVM target attribute must be attached to a GPU module");
6054 }
6055
6056 const NVVMCheckSMVersion targetSMVersion =
6058 if (!targetSMVersion.isMinimumSMVersion()) {
6059 return emitError(gpuModule->getLoc(),
6060 "Minimum NVVM target SM version is sm_20");
6061 }
6062
6063 if (gpuModuleOp
6064 ->walk([&](Operation *op) {
6065 if (auto reqOp = llvm::dyn_cast<NVVM::RequiresSMInterface>(op)) {
6066 const NVVMCheckSMVersion requirement =
6067 reqOp.getRequiredMinSMVersion();
6068 if (!requirement.isCompatibleWith(targetSMVersion)) {
6069 op->emitOpError() << "is not supported on " << getChip();
6070 return WalkResult::interrupt();
6071 }
6072 }
6073 return WalkResult::advance();
6074 })
6075 .wasInterrupted())
6076 return failure();
6077
6078 return success();
6079}
6080
6081#define GET_OP_CLASSES
6082#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
6083
6084#define GET_ATTRDEF_CLASSES
6085#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
for(Operation *op :ops)
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
ArrayAttr()
b getContext())
#define GET_TCGEN05_CP_ID(shape_mc, src_fmt, is_2cta)
static LogicalResult verifyTMALoadParams(size_t tensorDims, size_t numIm2colOff, TMALoadMode mode, Location loc)
static LogicalResult verifyTcgen05MMAOp(bool isATensor, mlir::Value disableOutputLane, NVVM::CTAGroupKind ctaGroup, bool hasAShift, NVVM::Tcgen05MMACollectorOp collectorOp, Location loc)
#define _none
static bool isPtrInAddrSpace(mlir::Value ptr, NVVMMemorySpace targetAS)
static bool isPtrInSharedCTASpace(mlir::Value ptr)
static LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA)
static llvm::nvvm::CTAGroupKind getNVVMCtaGroupKind(NVVM::CTAGroupKind ctaGroup)
static void addInferredMultiplicandTypes(MLIRContext *ctx, OperationState &result, ValueRange operandA, ValueRange operandB, std::optional< std::array< MMATypes, 2 > > multiplicandPtxTypes)
#define GET_CVT_F2TF32_ID(rnd, relu, sf)
static void addBlockScaleProperties(OpBuilder &builder, OperationState &result, ArrayRef< int64_t > shape, ScaleVecSize scaleVecSize, BlockScaleFormat blockScaleFormat, MMABlockScaleKind kind)
#define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf)
static llvm::Value * getParamCastedAddr(llvm::Value *addr, llvm::IRBuilderBase &builder)
static LogicalResult verifyAddSubFOp(OpType op)
static LogicalResult verifyTcgen05MMABlockScaleOp(NVVM::Tcgen05MMACollectorOp collectorOp, NVVM::Tcgen05MMAKind kind, NVVM::Tcgen05MMABlockScale blockScale, Location loc)
static llvm::Value * packValInto64Bits(llvm::IRBuilderBase &builder, llvm::Value *result, llvm::Value *field, unsigned sizeInBits, unsigned start)
Packs the given field into the result.
static void printOperandList(OpAsmPrinter &p, StringRef name, ArrayRef< Value > operands)
#define GET_F32x2_TO_F6x2_ID(type, has_relu)
static llvm::Value * getAsPackedI32(llvm::Value *arg, llvm::IRBuilderBase &builder)
#define GET_F16x2_TO_F8X2_ID(type, has_relu)
static LogicalResult verifyMBarrierArriveLikeOp(Operation *op, Value addr, NVVM::MemScopeKind scope, Value retVal=nullptr)
static llvm::Value * castPtrToAddrSpace(llvm::IRBuilderBase &builder, llvm::Value *ptr, NVVMMemorySpace targetAS)
static LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD, NVVM::WGMMATypes typeA, NVVM::WGMMATypes typeB)
static void inferAndSetMultiplicandTypes(MLIRContext *ctx, NamedAttrList &attrs, const SmallVectorImpl< Type > &operandTypes)
static LogicalResult parseMmaOperand(OpAsmParser &parser, StringRef operandName, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &regs)
static std::pair< mlir::Type, unsigned > inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n, int k, MLIRContext *context)
static bool isInt8PtxType(MMATypes type)
#define TCGEN05LDRED(SHAPE, NUM, TYPE)
static bool isInt4PtxType(MMATypes type)
static bool isIntegerPtxType(MMATypes type)
#define GET_F32x2_TO_F8X2_S_ID(type, has_relu)
static MMATypes inferPtxTypeFromResult(OpTy op)
static LogicalResult verifyConstantRangeAttr(Operation *op, std::optional< LLVM::ConstantRangeAttr > rangeAttr)
Verify the range attribute satisfies LLVM ConstantRange constructor requirements for NVVM SpecialRang...
static LogicalResult parseMmaTypeSignature(OpAsmParser &parser, SmallVectorImpl< Type > &operandTypes)
static FailureOr< int > getAllowedSizeK(NVVM::WGMMATypes typeA)
static bool isPtrInSharedClusterSpace(mlir::Value ptr)
#define GET_CP_ASYNC_ID(mod, size, has_cpsize)
static unsigned isValidVectorLength(NVVM::Tcgen05LdStShape shape, unsigned vecLen)
#define GET_TCGEN05_COMMIT_ID(cta_group, is_shared, has_mc)
static LogicalResult verifyConvertF32x2ToFP16x2Op(Twine dstType, FPRoundingMode rnd, bool hasRandomBits, Operation *op)
static void nvvmInferResultRanges(Operation *op, Value result, ArrayRef<::mlir::ConstantIntRanges > argRanges, SetIntRangeFn setResultRanges)
Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might have ConstantRangeAttr.
static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims, bool isIm2Col, size_t numIm2ColOffsets, Location loc)
static bool isPtrInGenericSpace(mlir::Value ptr)
static void processOperandFragments(Op &op, std::array< MMAOperandFragment, 3 > &frags, SmallVectorImpl< Type > &regTypes, SmallVectorImpl< StringRef > &ignoreAttrNames)
static constexpr unsigned notIntrinsic
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
@ OptionalSquare
Square brackets supporting zero or more ops, or nothing.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
void printArrowTypeList(TypeRange &&types)
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerType getI16Type()
Definition Builders.cpp:65
UnitAttr getUnitAttr()
Definition Builders.cpp:102
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:167
IntegerType getI32Type()
Definition Builders.cpp:67
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
MLIRContext * getContext() const
Definition Builders.h:56
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition Builders.h:100
This class represents a diagnostic that is inflight and set to be reported.
static IntegerValueRange getMaxRange(Value value)
Create a maximal range ([0, uint_max(t)] / [int_min(t), int_max(t)]) range that is used to mark the v...
Implementation class for module translation.
llvm::Value * lookupValue(Value value) const
Finds an LLVM IR value corresponding to the given MLIR value.
void mapValue(Value mlir, llvm::Value *llvm)
Stores the mapping between an MLIR value and its LLVM IR counterpart.
llvm::LLVMContext & getLLVMContext() const
Returns the LLVM context in which the IR is being constructed.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
std::optional< NamedAttribute > getNamed(StringRef name) const
Return the specified named attribute if present, std::nullopt otherwise.
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
Attribute set(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
This class helps build Operations.
Definition Builders.h:209
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
AttrClass getAttrOfType(StringAttr name)
Definition Operation.h:576
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Definition Operation.h:586
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:241
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isF64() const
Definition Types.cpp:41
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isF32() const
Definition Types.cpp:40
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
bool isF16() const
Definition Types.cpp:38
bool isBF16() const
Definition Types.cpp:37
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
bool isValidLoadStoreImpl(Type type, ptr::AtomicOrdering ordering, std::optional< int64_t > alignment, const ::mlir::DataLayout *dataLayout, function_ref< InFlightDiagnostic()> emitError)
Checks whether the given type is an LLVM type that can be loaded or stored.
Definition LLVMAttrs.cpp:60
SmallVector< int64_t, 4 > getCoordinates(ArrayRef< int64_t > basis, unsigned linearIndex)
@ Write
Write register with '=' modifier.
@ ReadWrite
ReadWrite register with '+' modifier.
@ Read
Read register with no modifier.
std::pair< mlir::Type, unsigned > inferMMAType(mlir::NVVM::MMATypes type, mlir::NVVM::MMAFrag frag, int nRow, int nCol, mlir::MLIRContext *context)
Return the element type and number of elements associated with a wmma matrix of given chracteristics.
std::pair< llvm::Intrinsic::ID, llvm::SmallVector< llvm::Value * > > IDArgPair
A pair type of LLVM's Intrinsic ID and args (which are llvm values).
Definition NVVMDialect.h:53
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition Visitors.h:102
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
uint64_t getN(LevelType lt)
Definition Enums.h:442
uint64_t getM(LevelType lt)
Definition Enums.h:443
Include the generated interface declarations.
llvm::function_ref< void(Value, const ConstantIntRanges &)> SetIntRangeFn
The type of the setResultRanges callback provided to ops implementing InferIntRangeInterface.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
LogicalResult matchAndRewrite(SubFOp op, PatternRewriter &rewriter) const override
bool isCompatibleWith(const NVVMCheckSMVersion &targetSM) const
static const NVVMCheckSMVersion getTargetSMVersionFromStr(StringRef smVersionString)
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.