NVGPUToNVVM.cpp
1//===- NVGPUToNVVM.cpp - NVGPU to NVVM dialect conversion -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
26#include "mlir/IR/Value.h"
27#include "mlir/Pass/Pass.h"
28#include "llvm/Support/Debug.h"
29#include "llvm/Support/DebugLog.h"
30#include "llvm/Support/ErrorHandling.h"
31#include "llvm/Support/raw_ostream.h"
32#include <optional>
33
34#define DEBUG_TYPE "nvgpu-to-nvvm"
35
36namespace mlir {
37#define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS
38#include "mlir/Conversion/Passes.h.inc"
39} // namespace mlir
40
41using namespace mlir;
42
43/// Number of bits that need to be excluded when building the matrix
44/// descriptor for wgmma operations.
45constexpr int exclude4LSB = 4;
46
47/// GPUs have 32-bit registers; this function truncates values when a larger
48/// width is not needed.
49 static Value truncToI32(ImplicitLocOpBuilder &b, Value value) {
50 Type type = value.getType();
51 assert(llvm::isa<IntegerType>(type) && "expected an integer Value");
52 if (type.getIntOrFloatBitWidth() <= 32)
53 return value;
54 return LLVM::TruncOp::create(b, b.getI32Type(), value);
55}
56
57/// Returns the type for the intrinsic given the vectorResultType of the
58/// `gpu.mma.sync` operation.
59static Type inferIntrinsicResultType(Type vectorResultType) {
60 MLIRContext *ctx = vectorResultType.getContext();
61 auto a = cast<LLVM::LLVMArrayType>(vectorResultType);
62 auto f16x2Ty = VectorType::get(2, Float16Type::get(ctx));
63 auto i32Ty = IntegerType::get(ctx, 32);
64 auto i32x2Ty = VectorType::get(2, i32Ty);
65 Type f64Ty = Float64Type::get(ctx);
66 Type f64x2Ty = VectorType::get(2, f64Ty);
67 Type f32Ty = Float32Type::get(ctx);
68 Type f32x2Ty = VectorType::get(2, f32Ty);
69 if (a.getElementType() == f16x2Ty) {
70 return LLVM::LLVMStructType::getLiteral(
71 ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty));
72 }
73 if (a.getElementType() == i32x2Ty) {
74 return LLVM::LLVMStructType::getLiteral(
75 ctx,
76 SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, i32Ty));
77 }
78 if (a.getElementType() == f64x2Ty) {
79 return LLVM::LLVMStructType::getLiteral(ctx, {f64Ty, f64Ty});
80 }
81 if (a.getElementType() == f32x2Ty) {
82 return LLVM::LLVMStructType::getLiteral(
83 ctx,
84 SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, f32Ty));
85 }
86 if (a.getElementType() == VectorType::get(1, f32Ty)) {
87 return LLVM::LLVMStructType::getLiteral(
88 ctx, SmallVector<Type>(static_cast<size_t>(a.getNumElements()), f32Ty));
89 }
90 return vectorResultType;
91}
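// Illustrative type mappings derived from the cases above (example types only,
// not an exhaustive list):
//   !llvm.array<2 x vector<2xf16>> -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
//   !llvm.array<2 x vector<2xi32>> -> !llvm.struct<(i32, i32, i32, i32)>
//   !llvm.array<1 x vector<2xf64>> -> !llvm.struct<(f64, f64)>
//   !llvm.array<4 x vector<1xf32>> -> !llvm.struct<(f32, f32, f32, f32)>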
92
93/// Convert the SSA result of the NVVM intrinsic `nvvm.mma.sync` (which is
94/// always an LLVM struct) into a fragment that is compatible with the vector
95/// type of this operation. This involves extracting elements from the struct
96/// and inserting them into an LLVM array. These extra data-movement
97/// operations should be canonicalized away by the LLVM backend.
98static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
99 Type resultType, Value intrinsicResult,
100 RewriterBase &rewriter) {
101 MLIRContext *ctx = rewriter.getContext();
102 auto structType = dyn_cast<LLVM::LLVMStructType>(intrinsicResultType);
103 auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType);
104 Type i32Ty = rewriter.getI32Type();
105 Type f32Ty = rewriter.getF32Type();
106 Type f64Ty = rewriter.getF64Type();
107 Type f16x2Ty = VectorType::get(2, rewriter.getF16Type());
108 Type i32x2Ty = VectorType::get(2, i32Ty);
109 Type f64x2Ty = VectorType::get(2, f64Ty);
110 Type f32x2Ty = VectorType::get(2, f32Ty);
111 Type f32x1Ty = VectorType::get(1, f32Ty);
112
113 auto makeConst = [&](int32_t index) -> Value {
114 return LLVM::ConstantOp::create(rewriter, loc, IntegerType::get(ctx, 32),
115 rewriter.getI32IntegerAttr(index));
116 };
117
118 if (arrayType) {
119 SmallVector<Value, 4> elements;
120
121 // The intrinsic returns 32-bit wide elements in a form which can be
122 // directly bitcasted and inserted into the result vector.
123 if (arrayType.getElementType() == f16x2Ty ||
124 arrayType.getElementType() == f32x1Ty) {
125 for (unsigned i = 0; i < structType.getBody().size(); i++) {
126 Value el =
127 LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult, i);
128 el = rewriter.createOrFold<LLVM::BitcastOp>(
129 loc, arrayType.getElementType(), el);
130 elements.push_back(el);
131 }
132 }
133
134 // The intrinsic returns i32, f64, and f32 values as individual scalars,
135 // even when the result is notionally a 64-bit wide element (e.g. f32x2). We
136 // need to extract them from the struct and pack them into the 64-bit wide
137 // rows of the vector result.
138 if (arrayType.getElementType() == i32x2Ty ||
139 arrayType.getElementType() == f64x2Ty ||
140 arrayType.getElementType() == f32x2Ty) {
141
142 for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) {
143 Value vec =
144 LLVM::PoisonOp::create(rewriter, loc, arrayType.getElementType());
145 Value x1 =
146 LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult, i * 2);
147 Value x2 = LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult,
148 i * 2 + 1);
149 vec = LLVM::InsertElementOp::create(rewriter, loc, vec.getType(), vec,
150 x1, makeConst(0));
151 vec = LLVM::InsertElementOp::create(rewriter, loc, vec.getType(), vec,
152 x2, makeConst(1));
153 elements.push_back(vec);
154 }
155 }
156
157 // Create the final vectorized result.
158 Value result = LLVM::PoisonOp::create(rewriter, loc, arrayType);
159 for (const auto &el : llvm::enumerate(elements)) {
160 result = LLVM::InsertValueOp::create(rewriter, loc, result, el.value(),
161 el.index());
162 }
163 return result;
164 }
165
166 return intrinsicResult;
167}
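// As a sketch of the repacking above: an intrinsic result of type
// !llvm.struct<(i32, i32, i32, i32)> destined for a result of type
// !llvm.array<2 x vector<2xi32>> is lowered to two pairs of llvm.extractvalue,
// each pair packed into a vector<2xi32> with llvm.insertelement, and the two
// vectors then placed into the array with llvm.insertvalue.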
168
169/// The `gpu.mma.sync` converter below expects matrix fragment operands to be
170/// given as 2D `vectors` where the rows are 32b or 64b wide. The
171/// `nvvm.mma.sync` op expects these arguments to be given in a long list of
172/// scalars of certain types. This function helps unpack the `vector` arguments
173/// and cast them to the types expected by `nvvm.mma.sync`.
174 static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b,
175 Value operand,
176 NVVM::MMATypes operandPtxType) {
177 SmallVector<Value> result;
178 Type i32Ty = b.getI32Type();
179 Type f64Ty = b.getF64Type();
180 Type f32Ty = b.getF32Type();
181 Type i64Ty = b.getI64Type();
182 Type i8x4Ty = VectorType::get(4, b.getI8Type());
183 Type i4x8Ty = VectorType::get(8, b.getIntegerType(4));
184 Type f32x1Ty = VectorType::get(1, f32Ty);
185 auto arrayTy = cast<LLVM::LLVMArrayType>(operand.getType());
186
187 for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
188 Value toUse = LLVM::ExtractValueOp::create(b, operand, i);
189
190 // 4xi8 and 8xi4 vectors, as well as tf32 values held in 1xf32 vectors, are
191 // provided to the intrinsic as plain i32 scalars, so bitcast them.
192 if (arrayTy.getElementType() == i8x4Ty ||
193 arrayTy.getElementType() == i4x8Ty ||
194 (arrayTy.getElementType() == f32x1Ty &&
195 operandPtxType == NVVM::MMATypes::tf32)) {
196 result.push_back(LLVM::BitcastOp::create(b, i32Ty, toUse));
197 continue;
198 }
199
200 // For some element types (i32, f32, f64), we need to unpack the inner
201 // vector/array type as well because the intrinsic expects individual
202 // scalars to be provided.
203 VectorType innerArrayTy = dyn_cast<VectorType>(arrayTy.getElementType());
204 if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty ||
205 innerArrayTy.getElementType() == f64Ty ||
206 innerArrayTy.getElementType() == f32Ty)) {
207 for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
208 idx < innerSize; idx++) {
209 result.push_back(LLVM::ExtractElementOp::create(
210 b, toUse,
211 LLVM::ConstantOp::create(b, i64Ty, b.getI64IntegerAttr(idx))));
212 }
213 continue;
214 }
215 result.push_back(toUse);
216 }
217 return result;
218}
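// For example (derived from the branches above): an operand of type
// !llvm.array<4 x vector<4xi8>> becomes four i32 values via bitcasts, while
// !llvm.array<2 x vector<2xi32>> is flattened into four individual i32
// scalars with llvm.extractelement.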
219
220/// Returns whether mbarrier object has shared memory address space.
221static bool isMbarrierShared(nvgpu::MBarrierGroupType barrierType) {
222 return (mlir::nvgpu::NVGPUDialect::isSharedMemoryAddressSpace(
223 barrierType.getMemorySpace()));
224}
225
226/// Returns the memory space attribute of the mbarrier object.
227 Attribute nvgpu::getMbarrierMemorySpace(MLIRContext *context,
228 nvgpu::MBarrierGroupType barrierType) {
229 Attribute memorySpace = {};
230 if (isMbarrierShared(barrierType)) {
231 memorySpace =
232 IntegerAttr::get(IntegerType::get(context, 64),
233 nvgpu::NVGPUDialect::kSharedMemoryAddressSpace);
234 }
235 return memorySpace;
236}
237
238/// Returns memref type of the mbarrier object. The type is defined in the
239/// MBarrierGroupType.
240MemRefType nvgpu::getMBarrierMemrefType(MLIRContext *context,
241 nvgpu::MBarrierGroupType barrierType) {
242 Attribute memorySpace = nvgpu::getMbarrierMemorySpace(context, barrierType);
243 MemRefLayoutAttrInterface layout;
244 return MemRefType::get({barrierType.getNumBarriers()},
245 IntegerType::get(context, 64), layout, memorySpace);
246}
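// For instance, assuming the usual assembly syntax of the nvgpu type, a
// barrier group such as
//   !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>, num_barriers = 4>
// maps to memref<4xi64, 3>, i.e. four 64-bit mbarrier words placed in the
// shared memory address space (3).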
247
248namespace {
249
250struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
251 using ConvertOpToLLVMPattern<nvgpu::LdMatrixOp>::ConvertOpToLLVMPattern;
252
253 LogicalResult
254 matchAndRewrite(nvgpu::LdMatrixOp op, OpAdaptor adaptor,
255 ConversionPatternRewriter &rewriter) const override {
256 MLIRContext *ctx = getContext();
257 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
258
259 // The result type of ldmatrix will always be a struct of 32bit integer
260 // registers if more than one 32bit value is returned. Otherwise, the result
261 // is a single i32. The result type of the GPU operation is always a vector
262 // of shape (NumRegisters, VectorRegister) where VectorRegister is the
263 // vector type of the result and always 32 bits long. We bitcast the result
264 // of the NVVM::LdMatrix to this vector type.
265 auto vectorResultType = dyn_cast<VectorType>(op->getResultTypes()[0]);
266 if (!vectorResultType) {
267 return failure();
268 }
269 Type innerVectorType = VectorType::get(vectorResultType.getDimSize(1),
270 vectorResultType.getElementType());
271
272 int64_t num32BitRegs = vectorResultType.getDimSize(0);
273
274 Type ldMatrixResultType;
275 if (num32BitRegs > 1) {
276 ldMatrixResultType = LLVM::LLVMStructType::getLiteral(
277 ctx, SmallVector<Type>(num32BitRegs, rewriter.getI32Type()));
278 } else {
279 ldMatrixResultType = rewriter.getI32Type();
280 }
281
282 auto srcMemrefType = cast<MemRefType>(op.getSrcMemref().getType());
283 Value srcPtr =
284 getStridedElementPtr(rewriter, b.getLoc(), srcMemrefType,
285 adaptor.getSrcMemref(), adaptor.getIndices());
286 auto shape = NVVM::LdStMatrixShapeAttr::get(rewriter.getContext(), 8, 8);
287 Value ldMatrixResult = NVVM::LdMatrixOp::create(
288 b, ldMatrixResultType, srcPtr,
289 /*num=*/op.getNumTiles(),
290 /*layout=*/op.getTranspose() ? NVVM::MMALayout::col
291 : NVVM::MMALayout::row,
292 /*shape=*/shape, /*eltType=*/NVVM::LdStMatrixEltType::B16);
293
294 // The ldmatrix operation returns either a single i32 value or a struct of
295 // i32 values. Here we unpack those values and cast them back to their
296 // actual vector type (still of width 32b) and repack them into a result
297 // struct.
298 Type finalResultType = typeConverter->convertType(vectorResultType);
299 Value result = LLVM::PoisonOp::create(b, finalResultType);
300 for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) {
301 Value i32Register =
302 num32BitRegs > 1 ? LLVM::ExtractValueOp::create(b, ldMatrixResult, i)
303 : ldMatrixResult;
304 Value casted = LLVM::BitcastOp::create(b, innerVectorType, i32Register);
305 result = LLVM::InsertValueOp::create(b, result, casted, i);
306 }
307
308 rewriter.replaceOp(op, result);
309 return success();
310 }
311};
312
313/// Convert the given type into the corresponding PTX type (NVVM::MMATypes
314/// enum).
315static FailureOr<NVVM::MMATypes> getNvvmMmaType(Type t) {
316 Type elType = getElementTypeOrSelf(t);
317 if (elType.isInteger(8))
318 return NVVM::MMATypes::s8;
319 if (elType.isInteger(4))
320 return NVVM::MMATypes::s4;
321 if (elType.isF16())
322 return NVVM::MMATypes::f16;
323 if (elType.isF64())
324 return NVVM::MMATypes::f64;
325 if (elType.isF32())
326 return NVVM::MMATypes::tf32;
327 return failure();
328}
329
330struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
331 using ConvertOpToLLVMPattern<nvgpu::MmaSyncOp>::ConvertOpToLLVMPattern;
332
333 LogicalResult
334 matchAndRewrite(nvgpu::MmaSyncOp op, OpAdaptor adaptor,
335 ConversionPatternRewriter &rewriter) const override {
336 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
337 // Get the shapes of the MMAMatrix type being used. The shapes will
338 // determine which intrinsic this op will be lowered to.
339 VectorType aType = op.getMatrixA().getType();
340 VectorType bType = op.getMatrixB().getType();
341 VectorType cType = op.getMatrixC().getType();
342
343 std::array<int64_t, 3> gemmShape = op.getMmaShapeAsArray();
344
345 // Tensor Cores (mma.sync) on F32 work only with TensorFloat32 (TF32).
346 bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
347 if (aType.getElementType().isF32() && !tf32Enabled)
348 return failure();
349
350 FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
351 if (failed(ptxTypeA))
352 return op->emitOpError("failed to deduce operand PTX types");
353 FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
354 if (failed(ptxTypeB))
355 return op->emitOpError("failed to deduce operand PTX types");
356 std::optional<NVVM::MMATypes> ptxTypeC =
357 NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
358 /*isAccumulator=*/true);
359 if (!ptxTypeC)
360 return op->emitError(
361 "could not infer the PTX type for the accumulator/result");
362
363 // TODO: add an attribute to the op to customize this behavior.
364 std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
365 if (isa<IntegerType>(aType.getElementType()))
366 overflow = NVVM::MMAIntOverflow::satfinite;
367
368 SmallVector<Value> matA =
369 unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
370 SmallVector<Value> matB =
371 unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
372 SmallVector<Value> matC =
373 unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);
374
375 Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
376 Type intrinsicResTy = inferIntrinsicResultType(
377 typeConverter->convertType(op->getResultTypes()[0]));
378 Value intrinsicResult =
379 NVVM::MmaOp::create(b, intrinsicResTy, matA, matB, matC,
380 /*shape=*/gemmShape,
381 /*b1Op=*/std::nullopt,
382 /*intOverflow=*/overflow,
383 /*multiplicandPtxTypes=*/
384 std::array<NVVM::MMATypes, 2>{*ptxTypeA, *ptxTypeB},
385 /*multiplicandLayouts=*/
386 std::array<NVVM::MMALayout, 2>{
387 NVVM::MMALayout::row, NVVM::MMALayout::col});
388 rewriter.replaceOp(op, convertIntrinsicResult(op.getLoc(), intrinsicResTy,
389 desiredRetTy, intrinsicResult,
390 rewriter));
391 return success();
392 }
393};
394
395struct ConvertNVGPUToNVVMPass
396 : public impl::ConvertNVGPUToNVVMPassBase<ConvertNVGPUToNVVMPass> {
397 using Base::Base;
398
399 void runOnOperation() override {
400 LowerToLLVMOptions options(&getContext());
401 RewritePatternSet patterns(&getContext());
402 LLVMTypeConverter converter(&getContext(), options);
403 IRRewriter rewriter(&getContext());
404 populateGpuMemorySpaceAttributeConversions(
405 converter, [](gpu::AddressSpace space) -> unsigned {
406 switch (space) {
407 case gpu::AddressSpace::Global:
408 return static_cast<unsigned>(NVVM::NVVMMemorySpace::Global);
409 case gpu::AddressSpace::Workgroup:
410 return static_cast<unsigned>(NVVM::NVVMMemorySpace::Shared);
411 case gpu::AddressSpace::Private:
412 return 0;
413 }
414 llvm_unreachable("unknown address space enum value");
415 return static_cast<unsigned>(NVVM::NVVMMemorySpace::Generic);
416 });
417 /// Device-side async tokens cannot be materialized in NVVM. We just
418 /// convert them to a dummy i32 type in order to easily drop them during
419 /// conversion.
420 converter.addConversion([&](nvgpu::DeviceAsyncTokenType type) -> Type {
421 return converter.convertType(IntegerType::get(type.getContext(), 32));
422 });
423 converter.addConversion([&](nvgpu::WarpgroupAccumulatorType type) -> Type {
424 Type elemType = type.getFragmented().getElementType();
425 int64_t sizeM = type.getFragmented().getDimSize(0);
426 int64_t sizeN = type.getFragmented().getDimSize(1);
427
428 unsigned numMembers;
429 if (elemType.isF32() || elemType.isInteger(32))
430 numMembers = sizeN / 2;
431 else if (elemType.isF16())
432 numMembers = sizeN / 4;
433 else
434 llvm_unreachable("unsupported type for warpgroup accumulator");
435
436 SmallVector<Type> innerStructBody;
437 for (unsigned i = 0; i < numMembers; i++)
438 innerStructBody.push_back(elemType);
439 auto innerStructType =
440 LLVM::LLVMStructType::getLiteral(type.getContext(), innerStructBody);
441
442 SmallVector<Type> structBody;
443 for (int i = 0; i < sizeM; i += kWgmmaSizeM)
444 structBody.push_back(innerStructType);
445
446 auto convertedType =
447 LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
448 return converter.convertType(convertedType);
449 });
450 converter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {
451 return converter.convertType(IntegerType::get(type.getContext(), 64));
452 });
453 converter.addConversion(
454 [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type {
455 return converter.convertType(IntegerType::get(type.getContext(), 64));
456 });
457 converter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type {
458 return converter.convertType(
459 nvgpu::getMBarrierMemrefType(rewriter.getContext(), type));
460 });
461 converter.addConversion([&](nvgpu::TensorMapDescriptorType type) -> Type {
462 return LLVM::LLVMPointerType::get(type.getContext());
463 });
464 populateNVGPUToNVVMConversionPatterns(converter, patterns);
465 LLVMConversionTarget target(getContext());
466 target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
467 target.addLegalDialect<::mlir::arith::ArithDialect>();
468 target.addLegalDialect<::mlir::memref::MemRefDialect>();
469 target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
470 mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
471 converter, patterns, target);
472 if (failed(applyPartialConversion(getOperation(), target,
473 std::move(patterns))))
474 signalPassFailure();
475 }
476};
477
478/// Returns the constraints for the sparse MMA inline assembly instruction.
479static std::string buildMmaSparseAsmConstraintString(unsigned matASize,
480 unsigned matBSize,
481 unsigned matCSize) {
482 std::string str;
483 llvm::raw_string_ostream ss(str);
484 for (unsigned i = 0; i < matCSize; i++)
485 ss << "=r,";
486 for (unsigned i = 0; i < matASize + matBSize + matCSize; i++)
487 ss << "r,";
488 // The final operand is for the sparsity metadata.
489 // The sparsity selector appears as a direct literal.
490 ss << "r";
491 return str;
492}
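// For hypothetical fragment sizes matASize = 4, matBSize = 2, matCSize = 4,
// the constraint string produced above is:
//   "=r,=r,=r,=r,r,r,r,r,r,r,r,r,r,r,r"
// (four outputs, then 4 + 2 + 4 inputs, plus one trailing "r" for the
// sparsity metadata operand).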
493
494/// Returns the string for the `mma.sp.sync` instruction that corresponds to
495/// the given parameters. Note that this function doesn't do any validation;
496/// it is expected that the provided parameters correspond to a valid
497/// instruction.
498static std::string buildMmaSparseAsmString(
499 const std::array<int64_t, 3> &shape, unsigned matASize, unsigned matBSize,
500 unsigned matCSize, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
501 NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
502 std::optional<NVVM::MMAIntOverflow> overflow, unsigned metaDataSelector) {
503 auto ptxTypeStr = [](NVVM::MMATypes ptxType) {
504 return NVVM::stringifyMMATypes(ptxType);
505 };
506
507 std::string asmStr;
508 llvm::raw_string_ostream ss(asmStr);
509 ss << "mma.sp.sync.aligned.m" << shape[0] << "n" << shape[1] << "k"
510 << shape[2] << ".row.col.";
511
512 if (overflow)
513 ss << NVVM::stringifyMMAIntOverflow(*overflow) << ".";
514
515 ss << ptxTypeStr(ptxTypeD) << "." << ptxTypeStr(ptxTypeA) << "."
516 << ptxTypeStr(ptxTypeB) << "." << ptxTypeStr(ptxTypeC) << " ";
517 unsigned asmArgIdx = 0;
518
519 // The operand string is structured into sections `{matC elements...},
520 // {matA elements...}, {matB elements...}, {matC elements}`.
521 for (const auto arrSize : {matCSize, matASize, matBSize, matCSize}) {
522 ss << "{";
523 for (unsigned i = 0; i < arrSize; i++)
524 ss << "$" << asmArgIdx++ << (i < arrSize - 1 ? "," : "");
525 ss << "},";
526 }
527 ss << "$" << asmArgIdx++ << ",";
528 assert(metaDataSelector <= 1);
529 ss << "0x" << metaDataSelector << ";";
530 return asmStr;
531}
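// As a sketch, for shape m16n8k32 with s8 A/B fragments, an s32 accumulator,
// satfinite overflow, metaDataSelector = 0, and hypothetical register counts
// matASize = 4, matBSize = 2, matCSize = 4, the function yields roughly:
//   mma.sp.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32
//     {$0,$1,$2,$3},{$4,$5,$6,$7},{$8,$9},{$10,$11,$12,$13},$14,0x0;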
532
533/// Builds an inline assembly operation corresponding to the specified MMA
534/// sparse sync operation.
535static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
536 ImplicitLocOpBuilder &b, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
537 NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
538 std::optional<NVVM::MMAIntOverflow> overflow, ArrayRef<Value> unpackedAData,
539 ArrayRef<Value> unpackedB, ArrayRef<Value> unpackedC, Value indexData,
540 int64_t metadataSelector, const std::array<int64_t, 3> &shape,
541 Type intrinsicResultType) {
542 auto asmDialectAttr =
543 LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_ATT);
544
545 const unsigned matASize = unpackedAData.size();
546 const unsigned matBSize = unpackedB.size();
547 const unsigned matCSize = unpackedC.size();
548
549 std::string asmStr = buildMmaSparseAsmString(
550 shape, matASize, matBSize, matCSize, ptxTypeA, ptxTypeB, ptxTypeC,
551 ptxTypeD, overflow, metadataSelector);
552 std::string constraintStr =
553 buildMmaSparseAsmConstraintString(matASize, matBSize, matCSize);
554
555 SmallVector<Value> asmVals;
556 asmVals.reserve(matASize + matBSize + matCSize + 1);
557 for (ArrayRef<Value> args : {unpackedAData, unpackedB, unpackedC})
558 llvm::append_range(asmVals, args);
559 asmVals.push_back(indexData);
560
561 return LLVM::InlineAsmOp::create(b,
562 /*resultTypes=*/intrinsicResultType,
563 /*operands=*/asmVals,
564 /*asm_string=*/asmStr,
565 /*constraints=*/constraintStr,
566 /*has_side_effects=*/true,
567 /*is_align_stack=*/false,
568 LLVM::TailCallKind::None,
569 /*asm_dialect=*/asmDialectAttr,
570 /*operand_attrs=*/ArrayAttr());
571}
572
573/// Lowers `nvgpu.mma.sp.sync` to inline assembly.
574struct NVGPUMmaSparseSyncLowering
575 : public ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp> {
576 using ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp>::ConvertOpToLLVMPattern;
577
578 LogicalResult
579 matchAndRewrite(nvgpu::MmaSparseSyncOp op, OpAdaptor adaptor,
580 ConversionPatternRewriter &rewriter) const override {
581 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
582 // Get the shapes of the MMAMatrix type being used. The shapes will
583 // determine which intrinsic this op will be lowered to.
584 VectorType aType = op.getMatrixA().getType();
585 VectorType bType = op.getMatrixB().getType();
586 VectorType cType = op.getMatrixC().getType();
587
588 FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
589 if (failed(ptxTypeA))
590 return op->emitOpError("failed to deduce operand PTX types");
591 FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
592 if (failed(ptxTypeB))
593 return op->emitOpError("failed to deduce operand PTX types");
594 std::optional<NVVM::MMATypes> ptxTypeC =
595 NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
596 /*isAccumulator=*/true);
597 if (!ptxTypeC)
598 return op->emitError(
599 "could not infer the PTX type for the accumulator/result");
600
601 // Same as `mma.sync`, F32 works only with TensorFloat32 (TF32).
602 bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
603 if (aType.getElementType().isF32() && !tf32Enabled)
604 return failure();
605
606 // TODO: add an attribute to the op to customize this behavior.
607 std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
608 if (isa<IntegerType>(aType.getElementType()))
609 overflow = NVVM::MMAIntOverflow::satfinite;
610
611 SmallVector<Value> matA =
612 unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
613 SmallVector<Value> matB =
614 unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
615 SmallVector<Value> matC =
616 unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);
617
618 Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
619 Type intrinsicResTy = inferIntrinsicResultType(
620 typeConverter->convertType(op->getResultTypes()[0]));
621
622 // Bitcast the sparse metadata from vector<2xi16> to an i32.
623 Value sparseMetadata = adaptor.getSparseMetadata();
624 if (sparseMetadata.getType() != VectorType::get(2, rewriter.getI16Type()))
625 return op->emitOpError() << "Expected metadata type to be LLVM "
626 "VectorType of 2 i16 elements";
627 sparseMetadata =
628 LLVM::BitcastOp::create(b, rewriter.getI32Type(), sparseMetadata);
629
630 FailureOr<LLVM::InlineAsmOp> intrinsicResult = emitMmaSparseSyncOpAsm(
631 b, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB,
632 matC, sparseMetadata, op.getSparsitySelector(), op.getMmaShapeAsArray(),
633 intrinsicResTy);
634 if (failed(intrinsicResult))
635 return failure();
636
637 assert((*intrinsicResult).getNumResults() == 1 &&
638 "expected inline asm op returns a single LLVM struct type");
639 rewriter.replaceOp(
640 op, convertIntrinsicResult(op.getLoc(), intrinsicResTy, desiredRetTy,
641 (*intrinsicResult)->getResult(0), rewriter));
642 return success();
643 }
644};
645
646struct NVGPUAsyncCopyLowering
647 : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCopyOp> {
648 using ConvertOpToLLVMPattern<
649 nvgpu::DeviceAsyncCopyOp>::ConvertOpToLLVMPattern;
650
651 LogicalResult
652 matchAndRewrite(nvgpu::DeviceAsyncCopyOp op, OpAdaptor adaptor,
653 ConversionPatternRewriter &rewriter) const override {
654 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
655 Location loc = op.getLoc();
656 auto dstMemrefType = cast<MemRefType>(op.getDst().getType());
657 Value dstPtr =
658 getStridedElementPtr(rewriter, b.getLoc(), dstMemrefType,
659 adaptor.getDst(), adaptor.getDstIndices());
660 FailureOr<unsigned> dstAddressSpace =
661 getTypeConverter()->getMemRefAddressSpace(dstMemrefType);
662 if (failed(dstAddressSpace))
663 return rewriter.notifyMatchFailure(
664 loc, "destination memref address space not convertible to integer");
665
666 auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
667 FailureOr<unsigned> srcAddressSpace =
668 getTypeConverter()->getMemRefAddressSpace(srcMemrefType);
669 if (failed(srcAddressSpace))
670 return rewriter.notifyMatchFailure(
671 loc, "source memref address space not convertible to integer");
672
673 Value scrPtr =
674 getStridedElementPtr(rewriter, loc, srcMemrefType, adaptor.getSrc(),
675 adaptor.getSrcIndices());
676 // The intrinsic takes a global pointer, so we need an address space cast.
677 auto srcPointerGlobalType = LLVM::LLVMPointerType::get(
678 op->getContext(), static_cast<unsigned>(NVVM::NVVMMemorySpace::Global));
679 scrPtr = LLVM::AddrSpaceCastOp::create(b, srcPointerGlobalType, scrPtr);
680 int64_t dstElements = adaptor.getDstElements().getZExtValue();
681 int64_t sizeInBytes =
682 (dstMemrefType.getElementTypeBitWidth() * dstElements) / 8;
683 // When the optional SrcElements argument is *not* present, the regular
684 // CpAsyncOp is generated. CpAsyncOp reads bytes from the source (global
685 // memory) to fill DstElements number of elements in the destination
686 // (shared memory).
687 Value srcBytes = adaptor.getSrcElements();
688 if (srcBytes) {
689 // When the optional SrcElements argument is present, the source (global
690 // memory) of CpAsyncOp is read only for SrcElements number of elements.
691 // The rest of the DstElements in the destination (shared memory) are
692 // filled with zeros.
693 Value c3I32 =
694 LLVM::ConstantOp::create(b, b.getI32Type(), b.getI32IntegerAttr(3));
695 Value bitwidth = LLVM::ConstantOp::create(
696 b, b.getI32Type(),
697 b.getI32IntegerAttr(srcMemrefType.getElementTypeBitWidth()));
698 Value srcElementsI32 = LLVM::TruncOp::create(b, b.getI32Type(), srcBytes);
699 srcBytes = LLVM::LShrOp::create(
700 b, LLVM::MulOp::create(b, bitwidth, srcElementsI32), c3I32);
701 }
702 // Use cache-global (.cg) when 16 bytes are written to the destination, and
703 // cache-all (.ca) for any other size.
704 NVVM::LoadCacheModifierKind cacheModifier =
705 (op.getBypassL1().value_or(false) && sizeInBytes == 16)
706 ? NVVM::LoadCacheModifierKind::CG
707 : NVVM::LoadCacheModifierKind::CA;
708
709 NVVM::CpAsyncOp::create(
710 b, dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes),
711 NVVM::LoadCacheModifierKindAttr::get(op->getContext(), cacheModifier),
712 srcBytes);
713
714 // Drop the result token.
715 Value zero =
716 LLVM::ConstantOp::create(b, IntegerType::get(op.getContext(), 32),
717 rewriter.getI32IntegerAttr(0));
718 rewriter.replaceOp(op, zero);
719 return success();
720 }
721};
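// A worked size computation for this lowering (assumed example values): for an
// f32 destination with dstElements = 4, sizeInBytes = (32 * 4) / 8 = 16, so
// the copy uses the .cg cache modifier when bypassL1 is set; any other size
// falls back to .ca. When srcElements is provided, the source byte count is
// (elementBitWidth * srcElements) >> 3, i.e. bits converted to bytes.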
722
723struct NVGPUAsyncCreateGroupLowering
724 : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCreateGroupOp> {
725 using ConvertOpToLLVMPattern<
726 nvgpu::DeviceAsyncCreateGroupOp>::ConvertOpToLLVMPattern;
727
728 LogicalResult
729 matchAndRewrite(nvgpu::DeviceAsyncCreateGroupOp op, OpAdaptor adaptor,
730 ConversionPatternRewriter &rewriter) const override {
731 NVVM::CpAsyncCommitGroupOp::create(rewriter, op.getLoc());
732 // Drop the result token.
733 Value zero = LLVM::ConstantOp::create(rewriter, op->getLoc(),
734 IntegerType::get(op.getContext(), 32),
735 rewriter.getI32IntegerAttr(0));
736 rewriter.replaceOp(op, zero);
737 return success();
738 }
739};
740
741struct NVGPUAsyncWaitLowering
742 : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncWaitOp> {
743 using ConvertOpToLLVMPattern<
744 nvgpu::DeviceAsyncWaitOp>::ConvertOpToLLVMPattern;
745
746 LogicalResult
747 matchAndRewrite(nvgpu::DeviceAsyncWaitOp op, OpAdaptor adaptor,
748 ConversionPatternRewriter &rewriter) const override {
749 // If numGroups is not present, pick 0 as a conservative, correct value.
750 int32_t numGroups = adaptor.getNumGroups().value_or(0);
751 NVVM::CpAsyncWaitGroupOp::create(rewriter, op.getLoc(), numGroups);
752 rewriter.eraseOp(op);
753 return success();
754 }
755};
756
757/// Creates mbarrier object in shared memory
758struct NVGPUMBarrierCreateLowering
759 : public ConvertOpToLLVMPattern<nvgpu::MBarrierCreateOp> {
760 using ConvertOpToLLVMPattern<nvgpu::MBarrierCreateOp>::ConvertOpToLLVMPattern;
761
762 template <typename moduleT>
763 memref::GlobalOp generateGlobalBarrier(ConversionPatternRewriter &rewriter,
764 Operation *funcOp, moduleT moduleOp,
765 MemRefType barrierType) const {
766 SymbolTable symbolTable(moduleOp);
767 OpBuilder::InsertionGuard guard(rewriter);
768 rewriter.setInsertionPoint(&moduleOp.front());
769 auto global = memref::GlobalOp::create(
770 rewriter, funcOp->getLoc(), "__mbarrier",
771 /*sym_visibility=*/rewriter.getStringAttr("private"),
772 /*type=*/barrierType,
773 /*initial_value=*/ElementsAttr(),
774 /*constant=*/false,
775 /*alignment=*/rewriter.getI64IntegerAttr(8));
776 symbolTable.insert(global);
777 return global;
778 }
779
780 LogicalResult
781 matchAndRewrite(nvgpu::MBarrierCreateOp op, OpAdaptor adaptor,
782 ConversionPatternRewriter &rewriter) const override {
783 Operation *funcOp = op->getParentOp();
784 MemRefType barrierType = nvgpu::getMBarrierMemrefType(
785 rewriter.getContext(), op.getBarriers().getType());
786
787 memref::GlobalOp global;
788 if (auto moduleOp = funcOp->getParentOfType<gpu::GPUModuleOp>())
789 global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
790 else if (auto moduleOp = funcOp->getParentOfType<ModuleOp>())
791 global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
792
793 rewriter.setInsertionPoint(op);
794 rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(op, barrierType,
795 global.getName());
796 return success();
797 }
798};
799
800/// Base class for lowering mbarrier operations to nvvm intrinsics.
801template <typename SourceOp>
802struct MBarrierBasePattern : public ConvertOpToLLVMPattern<SourceOp> {
803public:
804 using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
805 /// Returns the base pointer of the mbarrier object.
806 Value getMbarrierPtr(ImplicitLocOpBuilder &b,
807 nvgpu::MBarrierGroupType mbarType, Value memrefDesc,
808 Value mbarId,
809 ConversionPatternRewriter &rewriter) const {
810 MemRefType mbarrierMemrefType =
811 nvgpu::getMBarrierMemrefType(rewriter.getContext(), mbarType);
812 return ConvertToLLVMPattern::getStridedElementPtr(
813 rewriter, b.getLoc(), mbarrierMemrefType, memrefDesc, {mbarId});
814 }
815};
816
817struct NVGPUMBarrierGetLowering
818 : public MBarrierBasePattern<nvgpu::MBarrierGetOp> {
819 using MBarrierBasePattern<nvgpu::MBarrierGetOp>::MBarrierBasePattern;
820
821 LogicalResult
822 matchAndRewrite(nvgpu::MBarrierGetOp op, OpAdaptor adaptor,
823 ConversionPatternRewriter &rewriter) const override {
824 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
825 nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType();
826 rewriter.setInsertionPoint(op);
827 Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(),
828 adaptor.getMbarId(), rewriter);
829 Type resType = op.getMbarrierPointer().getType();
830 rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(op, resType, barrier);
831 return success();
832 }
833};
834
835/// Lowers `nvgpu.mbarrier.init` to `nvvm.mbarrier.init`
836struct NVGPUMBarrierInitLowering
837 : public MBarrierBasePattern<nvgpu::MBarrierInitOp> {
838 using MBarrierBasePattern<nvgpu::MBarrierInitOp>::MBarrierBasePattern;
839
840 LogicalResult
841 matchAndRewrite(nvgpu::MBarrierInitOp op, OpAdaptor adaptor,
842 ConversionPatternRewriter &rewriter) const override {
843 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
844 nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType();
845 rewriter.setInsertionPoint(op);
846 Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(),
847 adaptor.getMbarId(), rewriter);
848 Value count = truncToI32(b, adaptor.getCount());
849 rewriter.replaceOpWithNewOp<NVVM::MBarrierInitOp>(op, barrier, count,
850 adaptor.getPredicate());
851 return success();
852 }
853};
854
855/// Lowers `nvgpu.mbarrier.arrive` to `nvvm.mbarrier.arrive`
856struct NVGPUMBarrierArriveLowering
857 : public MBarrierBasePattern<nvgpu::MBarrierArriveOp> {
858 using MBarrierBasePattern<nvgpu::MBarrierArriveOp>::MBarrierBasePattern;
859 LogicalResult
860 matchAndRewrite(nvgpu::MBarrierArriveOp op, OpAdaptor adaptor,
861 ConversionPatternRewriter &rewriter) const override {
862 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
863 Value barrier =
864 getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
865 adaptor.getMbarId(), rewriter);
866 Type tokenType = getTypeConverter()->convertType(
867 nvgpu::MBarrierTokenType::get(op->getContext()));
868 rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveOp>(op, tokenType, barrier);
869 return success();
870 }
871};
872
873/// Lowers `nvgpu.mbarrier.arrive.nocomplete` to
874/// `nvvm.mbarrier.arrive.nocomplete`
875struct NVGPUMBarrierArriveNoCompleteLowering
876 : public MBarrierBasePattern<nvgpu::MBarrierArriveNoCompleteOp> {
877 using MBarrierBasePattern<
878 nvgpu::MBarrierArriveNoCompleteOp>::MBarrierBasePattern;
879 LogicalResult
880 matchAndRewrite(nvgpu::MBarrierArriveNoCompleteOp op, OpAdaptor adaptor,
881 ConversionPatternRewriter &rewriter) const override {
882 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
883 Value barrier =
884 getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
885 adaptor.getMbarId(), rewriter);
886 Type tokenType = getTypeConverter()->convertType(
887 nvgpu::MBarrierTokenType::get(op->getContext()));
888 Value count = truncToI32(b, adaptor.getCount());
889 rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteOp>(
890 op, tokenType, barrier, count);
891 return success();
892 }
893};
894
895/// Lowers `nvgpu.mbarrier.test.wait` to `nvvm.mbarrier.test.wait`
896struct NVGPUMBarrierTestWaitLowering
897 : public MBarrierBasePattern<nvgpu::MBarrierTestWaitOp> {
898 using MBarrierBasePattern<nvgpu::MBarrierTestWaitOp>::MBarrierBasePattern;
899 LogicalResult
900 matchAndRewrite(nvgpu::MBarrierTestWaitOp op, OpAdaptor adaptor,
901 ConversionPatternRewriter &rewriter) const override {
902 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
903 Value barrier =
904 getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
905 adaptor.getMbarId(), rewriter);
906 Type retType = rewriter.getI1Type();
907 rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitOp>(op, retType, barrier,
908 adaptor.getToken());
909 return success();
910 }
911};
912
913struct NVGPUMBarrierArriveExpectTxLowering
914 : public MBarrierBasePattern<nvgpu::MBarrierArriveExpectTxOp> {
915 using MBarrierBasePattern<
916 nvgpu::MBarrierArriveExpectTxOp>::MBarrierBasePattern;
917 LogicalResult
918 matchAndRewrite(nvgpu::MBarrierArriveExpectTxOp op, OpAdaptor adaptor,
919 ConversionPatternRewriter &rewriter) const override {
920 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
921 Value barrier =
922 getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
923 adaptor.getMbarId(), rewriter);
924 Value txcount = truncToI32(b, adaptor.getTxcount());
925 rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>(
926 op, barrier, txcount, adaptor.getPredicate());
927 return success();
928 }
929};
930
931struct NVGPUMBarrierTryWaitParityLowering
932 : public MBarrierBasePattern<nvgpu::MBarrierTryWaitParityOp> {
933 using MBarrierBasePattern<
934 nvgpu::MBarrierTryWaitParityOp>::MBarrierBasePattern;
935 LogicalResult
936 matchAndRewrite(nvgpu::MBarrierTryWaitParityOp op, OpAdaptor adaptor,
937 ConversionPatternRewriter &rewriter) const override {
938 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
939 Value barrier =
940 getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
941 adaptor.getMbarId(), rewriter);
942 Value ticks = truncToI32(b, adaptor.getTicks());
943 Value phase =
944 LLVM::ZExtOp::create(b, b.getI32Type(), adaptor.getPhaseParity());
945 rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParityOp>(op, barrier,
946 phase, ticks);
947 return success();
948 }
949};
950
951struct NVGPUTmaAsyncLoadOpLowering
952 : public MBarrierBasePattern<nvgpu::TmaAsyncLoadOp> {
953 using MBarrierBasePattern<nvgpu::TmaAsyncLoadOp>::MBarrierBasePattern;
954 LogicalResult
955 matchAndRewrite(nvgpu::TmaAsyncLoadOp op, OpAdaptor adaptor,
956 ConversionPatternRewriter &rewriter) const override {
957 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
958 auto srcMemrefType = cast<MemRefType>(op.getDst().getType());
959 Value dest = getStridedElementPtr(rewriter, op->getLoc(), srcMemrefType,
960 adaptor.getDst(), {});
961 // The intrinsic takes a shared-cluster pointer, so we need an
962 // address space cast from address space 3 to address space 7.
963 // TODO: Introduce AS(7) in NVGPU.
964 auto ptrSharedClusterType = LLVM::LLVMPointerType::get(
965 op->getContext(),
966 static_cast<unsigned>(NVVM::NVVMMemorySpace::SharedCluster));
967 dest = LLVM::AddrSpaceCastOp::create(b, ptrSharedClusterType, dest);
968
969 Value barrier =
970 getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
971 adaptor.getMbarId(), rewriter);
972
973 SmallVector<Value> coords = adaptor.getCoordinates();
974 for (auto [index, value] : llvm::enumerate(coords)) {
975 coords[index] = truncToI32(b, value);
976 }
977
978 // TODO: Enhance the NVGPU Op for other modes too
979 rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(
980 op, dest, adaptor.getTensorMapDescriptor(), coords, barrier,
981 ValueRange{}, adaptor.getMulticastMask(), Value{},
982 NVVM::TMALoadMode::TILE, // default is TILE mode
983 false, // default is cluster-scope
984 nullptr, // default is no cta-group
985 adaptor.getPredicate());
986 return success();
987 }
988};
989
990struct NVGPUTmaAsyncStoreOpLowering
991 : public MBarrierBasePattern<nvgpu::TmaAsyncStoreOp> {
992 using MBarrierBasePattern<nvgpu::TmaAsyncStoreOp>::MBarrierBasePattern;
993 LogicalResult
994 matchAndRewrite(nvgpu::TmaAsyncStoreOp op, OpAdaptor adaptor,
995 ConversionPatternRewriter &rewriter) const override {
996 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
997 auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
998 Value dest = getStridedElementPtr(rewriter, op->getLoc(), srcMemrefType,
999 adaptor.getSrc(), {});
1000 SmallVector<Value> coords = adaptor.getCoordinates();
1001 for (auto [index, value] : llvm::enumerate(coords)) {
1002 coords[index] = truncToI32(b, value);
1003 }
1004
1005 // TODO: Enhance the NVGPU Op for other modes too
1006 rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(
1007 op, adaptor.getTensorMapDescriptor(), dest, coords, Value{},
1008 NVVM::TMAStoreMode::TILE, // default is TILE mode
1009 adaptor.getPredicate());
1010 return success();
1011 }
1012};
1013
1014struct NVGPUGenerateWarpgroupDescriptorLowering
1015 : public ConvertOpToLLVMPattern<nvgpu::WarpgroupGenerateDescriptorOp> {
1016 using ConvertOpToLLVMPattern<
1017 nvgpu::WarpgroupGenerateDescriptorOp>::ConvertOpToLLVMPattern;
1018
1019 LogicalResult
1020 matchAndRewrite(nvgpu::WarpgroupGenerateDescriptorOp op, OpAdaptor adaptor,
1021 ConversionPatternRewriter &rewriter) const override {
1022
1023 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1024
1025 nvgpu::TensorMapSwizzleKind swizzleKind =
1026 op.getTensorMap().getType().getSwizzle();
1027
1028 unsigned layout =
1029 (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 128
1030 : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 64
1031 : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 32
1032 : 1;
1033 unsigned swizzle =
1034 (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 1
1035 : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 2
1036 : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 3
1037 : 0;
1038
1039 auto ti64 = b.getIntegerType(64);
1040 auto makeConst = [&](uint64_t index) -> Value {
1041 return LLVM::ConstantOp::create(b, ti64, b.getI64IntegerAttr(index));
1042 };
1043 auto shiftLeft = [&](Value value, unsigned shift) -> Value {
1044 return LLVM::ShlOp::create(b, ti64, value, makeConst(shift));
1045 };
1046 auto shiftRight = [&](Value value, unsigned shift) -> Value {
1047 return LLVM::LShrOp::create(b, ti64, value, makeConst(shift));
1048 };
1049 auto insertBit = [&](Value desc, Value val, int startBit) {
1050 return LLVM::OrOp::create(b, ti64, desc, shiftLeft(val, startBit));
1051 };
1052
1053 int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0);
1054 uint64_t strideDimVal = (layout << 3) >> exclude4LSB;
1055 uint64_t leadDimVal = (sizeN * layout) >> exclude4LSB;
1056 uint64_t offsetVal = 0;
1057
1058 Value strideDim = makeConst(strideDimVal);
1059 Value leadDim = makeConst(leadDimVal);
1060
1061 Value baseAddr = getStridedElementPtr(
1062 rewriter, op->getLoc(), cast<MemRefType>(op.getTensor().getType()),
1063 adaptor.getTensor(), {});
1064 Value basePtr = LLVM::PtrToIntOp::create(b, ti64, baseAddr);
1065 // Just use 14 bits for base address
1066 Value basePtr14bit = shiftRight(shiftLeft(basePtr, 46), 50);
1067
1068 int startSwizzleBit = 62, startOffsetBit = 49, startStrideBit = 32,
1069 startLeadBit = 16, startBaseAddrBit = 0;
1070 Value dsc = makeConst(0);
1071 // [62,64) swizzle type
1072 dsc = insertBit(dsc, makeConst(swizzle), startSwizzleBit);
1073 // [49,52) base_offset
1074 dsc = insertBit(dsc, makeConst(offsetVal), startOffsetBit);
1075 // [32,46) stride
1076 dsc = insertBit(dsc, strideDim, startStrideBit);
1077 // [16,30) leading dimension
1078 dsc = insertBit(dsc, leadDim, startLeadBit);
1079 // [0,14) start_address
1080 dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit);
1081
1082 LDBG() << "Generating warpgroup.descriptor: " << "leading_off:"
1083 << leadDimVal << "\t" << "stride_off :" << strideDimVal << "\t"
1084 << "base_offset:" << offsetVal << "\t" << "layout_type:" << swizzle
1085 << " (" << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
1086 << ")\n start_addr : " << baseAddr;
1087
1088 rewriter.replaceOp(op, dsc);
1089 return success();
1090 }
1091};
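// A worked example of the descriptor encoding above (assumed values): for
// SWIZZLE_128B and a tensor map whose first dimension is 64, layout = 128 and
// swizzle = 1, so strideDimVal = (128 << 3) >> 4 = 64 and
// leadDimVal = (64 * 128) >> 4 = 512; these fields land in bit ranges
// [32,46) and [16,30) of the 64-bit descriptor, with the swizzle kind in
// [62,64) and the 14-bit shared-memory address in [0,14).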
1092
1093static Value makeI64Const(ImplicitLocOpBuilder &b, int32_t index) {
1094 return LLVM::ConstantOp::create(b, b.getIntegerType(64),
1095 b.getI32IntegerAttr(index));
1096}
1097
1098/// Returns a Value that holds data type enum that is expected by CUDA driver.
1099static Value elementTypeAsLLVMConstant(ImplicitLocOpBuilder &b, Type type) {
1100 // Enum is from CUDA driver API
1101 // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html
1102 enum CUtensorMapDataTypeEnum {
1103 CU_TENSOR_MAP_DATA_TYPE_UINT8 = 0,
1104 CU_TENSOR_MAP_DATA_TYPE_UINT16,
1105 CU_TENSOR_MAP_DATA_TYPE_UINT32,
1106 CU_TENSOR_MAP_DATA_TYPE_INT32,
1107 CU_TENSOR_MAP_DATA_TYPE_UINT64,
1108 CU_TENSOR_MAP_DATA_TYPE_INT64,
1109 CU_TENSOR_MAP_DATA_TYPE_FLOAT16,
1110 CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
1111 CU_TENSOR_MAP_DATA_TYPE_FLOAT64,
1112 CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
1113 CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ,
1114 CU_TENSOR_MAP_DATA_TYPE_TFLOAT32,
1115 CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ
1116 };
1117
1118 if (type.isUnsignedInteger(8))
1119 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT8);
1120 if (type.isUnsignedInteger(16))
1121 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT16);
1122 if (type.isUnsignedInteger(32))
1123 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT32);
1124 if (type.isUnsignedInteger(64))
1125 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT64);
1126 if (type.isSignlessInteger(32))
1127 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT32);
1128 if (type.isSignlessInteger(64))
1129 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT64);
1130 if (type.isF16())
1131 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT16);
1132 if (type.isF32())
1133 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT32);
1134 if (type.isF64())
1135 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT64);
1136 if (type.isBF16())
1137 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16);
1138
1139 llvm_unreachable("Not supported data type");
1140}
1141
1142struct NVGPUTmaCreateDescriptorOpLowering
1143 : public ConvertOpToLLVMPattern<nvgpu::TmaCreateDescriptorOp> {
1144 using ConvertOpToLLVMPattern<
1145 nvgpu::TmaCreateDescriptorOp>::ConvertOpToLLVMPattern;
1146 LogicalResult
1147 matchAndRewrite(nvgpu::TmaCreateDescriptorOp op, OpAdaptor adaptor,
1148 ConversionPatternRewriter &rewriter) const override {
1149 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1150 auto llvmPointerType = LLVM::LLVMPointerType::get(op->getContext());
1151 Type llvmInt64Type = IntegerType::get(op->getContext(), 64);
1152
1153 Value tensorElementType =
1154 elementTypeAsLLVMConstant(b, op.getTensor().getType().getElementType());
1155 auto promotedOperands = getTypeConverter()->promoteOperands(
1156 b.getLoc(), op->getOperands(), adaptor.getOperands(), b);
1157
1158 Value boxArrayPtr = LLVM::AllocaOp::create(
1159 b, llvmPointerType, llvmInt64Type, makeI64Const(b, 5));
1160 for (auto [index, value] : llvm::enumerate(adaptor.getBoxDimensions())) {
1161 Value gep = LLVM::GEPOp::create(b, llvmPointerType, llvmPointerType,
1162 boxArrayPtr, makeI64Const(b, index));
1163 LLVM::StoreOp::create(b, value, gep);
1164 }
1165
1166 nvgpu::TensorMapDescriptorType desc = op.getTensorMap().getType();
1167 // Set Arguments for the function call
1168 SmallVector<Value> arguments;
1169 arguments.push_back(promotedOperands[0]); // rank
1170 arguments.push_back(promotedOperands[1]); // descriptor
1171 arguments.push_back(tensorElementType); // data type
1172 arguments.push_back(
1173 makeI64Const(b, (int)desc.getInterleave())); // interleave
1174 arguments.push_back(makeI64Const(b, (int)desc.getSwizzle())); // swizzle
1175 arguments.push_back(makeI64Const(b, (int)desc.getL2promo())); // l2promo
1176 arguments.push_back(makeI64Const(b, (int)desc.getOob())); // oob
1177 arguments.push_back(boxArrayPtr); // box dimensions
1178
1179 // Set data types of the arguments
1180 SmallVector<Type> argTypes = {
1181 llvmInt64Type, /* int64_t tensorRank */
1182 llvmPointerType, /* ptr */
1183 llvmInt64Type, /* int64_t */
1184 llvmInt64Type, /* int64_t */
1185 llvmInt64Type, /* int64_t */
1186 llvmInt64Type, /* int64_t */
1187 llvmInt64Type, /* int64_t */
1188 llvmPointerType /* ptr */
1189 };
1190 FunctionCallBuilder hostRegisterCallBuilder = {
1191 "mgpuTensorMapEncodeTiledMemref", llvmPointerType, argTypes};
1192 Value tensorMap =
1193 hostRegisterCallBuilder.create(b.getLoc(), b, arguments).getResult();
1194
1195 rewriter.replaceOp(op, tensorMap);
1196 return success();
1197 }
1198};
1199
1200struct NVGPUWarpgroupMmaOpLowering
1201 : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp> {
1202 using ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp>::ConvertOpToLLVMPattern;
1203
1204 /// This is a helper class to generate required NVVM Ops for warp-group level
1205 /// matrix multiplication.
1206 /// When the given GEMM shape is larger than the shape of
1207 /// a wgmma instruction in PTX, it generates multiple NVVM::WgmmaMmaAsyncOp
1208 /// ops, groups them, and executes them asynchronously. The class also handles
1209 /// waiting for completion and iterates through WarpgroupMatrixDescriptor to
1210 /// create descriptors for each instruction.
1211 ///
1212 /// For example this is the case when the shape of GEMM is 128x128x128
1213 ///
1214 /// nvvm.wgmma.fence.aligned
1215 ///
1216 /// nvvm.wgmma.mma.async descA, descB
1217 /// iterate(descA, descB)
1218 /// nvvm.wgmma.mma.async descA, descB
1219 /// [6x times more]
1220 ///
1221 /// nvvm.wgmma.group.sync.aligned
1222 /// nvvm.wgmma.wait.group.sync [groupId]
1223 ///
1224 class WarpgroupGemm {
1225 nvgpu::WarpgroupMmaOp op;
1226 ImplicitLocOpBuilder b;
1227 OpAdaptor adaptor;
1228
1229 // Entire shape of the given Op
1230 int64_t totalM, totalN, totalK;
1231
1232 // Shape of one wgmma instruction
1233 int wgmmaM = 0, wgmmaN = 0, wgmmaK = 0;
1234
1235 // Iteration counts for GEMM
1236 int iterationM = 0, iterationN = 0, iterationK = 0;
1237
1238 /// Computes the shape of the wgmma instruction as defined in the
1239 /// PTX programming guide.
1240 /// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shape
1241 void findWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType) {
1242 wgmmaM = 64;
1243 wgmmaN = sizeN;
1244 if (inputElemType.isTF32()) {
1245 wgmmaK = 8;
1246 } else if (inputElemType.isF16() || inputElemType.isBF16()) {
1247 wgmmaK = 16;
1248 } else if (isa<Float8E4M3FNType, Float8E5M2Type>(inputElemType) ||
1249 inputElemType.isInteger(16)) {
1250 wgmmaK = 32;
1251 } else if (inputElemType.isInteger(1)) {
1252 wgmmaK = 256;
1253 } else {
1254 llvm_unreachable("msg: not supported K shape");
1255 }
1256 LDBG() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
1257 << ", n = " << wgmmaN << ", k = " << wgmmaK << "]";
1258 }
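// Example (assumed shapes): for f16 or bf16 inputs this picks a 64 x sizeN x 16
// wgmma tile; with a 128x128x64 GEMM the constructor below then derives
// iterationM = 2, iterationN = 1, iterationK = 4, i.e. eight
// NVVM::WgmmaMmaAsyncOp per nvgpu.warpgroup.mma.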
1259
1260 /// Generates WGMMATypesAttr from MLIR Type
1261 NVVM::WGMMATypesAttr generateWgmmaType(Type type,
1262 bool useF32 = false) const {
1263 auto getWgmmaType = [=](Type elemType) {
1264 if (elemType.isF32() || elemType.isTF32())
1265 return useF32 ? NVVM::WGMMATypes::f32 : NVVM::WGMMATypes::tf32;
1266 if (elemType.isF16())
1267 return NVVM::WGMMATypes::f16;
1268 if (elemType.isBF16())
1269 return NVVM::WGMMATypes::bf16;
1270 if (isa<Float8E4M3FNType>(elemType))
1271 return NVVM::WGMMATypes::e4m3;
1272 if (isa<Float8E5M2Type>(elemType))
1273 return NVVM::WGMMATypes::e5m2;
1274 if (elemType.isInteger(1))
1275 return NVVM::WGMMATypes::b1;
1276 if (elemType.isInteger(8))
1277 return NVVM::WGMMATypes::s8;
1278 if (elemType.isUnsignedInteger(8))
1279 return NVVM::WGMMATypes::u8;
1280 if (elemType.isInteger(32))
1281 return NVVM::WGMMATypes::s32;
1282 llvm_unreachable("unsupported type");
1283 };
1284 return NVVM::WGMMATypesAttr::get(op->getContext(), getWgmmaType(type));
1285 }
1286
1287 /// Generates layout attribute for the input matrix for wgmma instruction
1288 NVVM::MMALayoutAttr
1289 generateWgmmaLayout(std::optional<bool> transpose) const {
1290 if (transpose.value_or(false))
1291 return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::col);
1292 return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::row);
1293 }
1294
1295 /// Generates shape attribute for wgmma instruction
1296 NVVM::MMAShapeAttr generateWgmmaShape() const {
1297 return NVVM::MMAShapeAttr::get(op->getContext(), wgmmaM, wgmmaN, wgmmaK);
1298 }
1299
1300 /// Generates scale attributes of output matrix for wgmma instruction
1301 NVVM::WGMMAScaleOutAttr generateScaleOut() const {
1302 return NVVM::WGMMAScaleOutAttr::get(op->getContext(),
1303 NVVM::WGMMAScaleOut::one);
1304 }
1305 /// Generates scale attributes of input matrix for wgmma instruction
1306 NVVM::WGMMAScaleInAttr generateScaleIn() const {
1307 return NVVM::WGMMAScaleInAttr::get(op->getContext(),
1308 NVVM::WGMMAScaleIn::one);
1309 }
1310
1311 /// Basic function to generate Add
1312 Value makeAdd(Value lhs, Value rhs) {
1313 return LLVM::AddOp::create(b, lhs.getType(), lhs, rhs);
1314 };
1315
1316 /// Moves the descriptor pointer of matrix-A for the next wgmma instruction.
1317 /// Currently, it only handles row-major.
1318 ///
1319 /// It moves the pointer like below for [128][64] size:
1320 /// +2 +4 +6
1321 /// ↓ ↓ ↓
1322 /// descA ---> +--+--+--+--+
1323 /// |->|->|->|->|
1324 /// | | | | |
1325 /// | | | | |
1326 /// | | | | |
1327 /// descA+512---> +-----------+
1328 /// | | | | |
1329 /// | | | | |
1330 /// | | | | |
1331 /// | | | | |
1332 /// +-----------+
1333 ///
1334 Value iterateDescriptorA(Value desc, int i, int j, int k) {
1335 MemRefType matrixTypeA = op.getDescriptorA().getType().getTensor();
1336 Type elemA = matrixTypeA.getElementType();
1337 int byte = elemA.getIntOrFloatBitWidth() / 8;
1338 int tileShapeA = matrixTypeA.getDimSize(1);
1339 int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte;
1340 incrementVal = incrementVal >> exclude4LSB;
1341 LDBG() << "\t\t[m: " << i << " n: " << j << " k: " << k
1342 << "] [wgmma descriptors] Descriptor A + " << incrementVal
1343 << " | \t ";
1344 if (!incrementVal)
1345 return desc;
1346 return makeAdd(desc, makeI64Const(b, incrementVal));
1347 }
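// Plugging in the [128][64] f16 example from the diagram above (2 bytes per
// element, tileShapeA = 64, totalK = 64, wgmmaK = 16): the increment is
// ((16 * k) + (64 * 64 * i)) * 2 >> 4, i.e. +2 per k-step and +512 per
// m-step, matching the descA / descA+512 arrows shown.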
1348
1349 /// Moves the descriptor pointer of matrix-B for the next wgmma instruction.
1350 /// Currently, it only handles column-major.
1351 ///
1352 /// It moves the pointer like below for [128][64] size:
1353 /// descB ---> +--+--+--+--+--+--+--+--+
1354 /// |↓ | | | | | | | |
1355 /// |↓ | | | | | | | |
1356 /// |↓ | | | | | | | |
1357 /// |↓ | | | | | | | |
1358 /// +--+--+--+--+--+--+--+--+
1359 ///
1360 Value iterateDescriptorB(Value desc, int i, int j, int k) {
1361 MemRefType matrixTypeB = op.getDescriptorB().getType().getTensor();
1362 Type elemB = matrixTypeB.getElementType();
1363 int byte = elemB.getIntOrFloatBitWidth() / 8;
1364 int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte;
1365 incrementVal = incrementVal >> exclude4LSB;
1366 LDBG() << "Descriptor B + " << incrementVal;
1367 if (!incrementVal)
1368 return desc;
1369 return makeAdd(desc, makeI64Const(b, incrementVal));
1370 }
1371
1372 /// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix
1373 /// descriptors and arranges them based on induction variables: i, j, and k.
1374 Value generateWgmma(int i, int j, int k, Value matrixC) {
1375 LDBG() << "\t wgmma." << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
1376 << "(A[" << (iterationM * wgmmaM) << ":"
1377 << (iterationM * wgmmaM) + wgmmaM << "][" << (iterationK * wgmmaK)
1378 << ":" << (iterationK * wgmmaK + wgmmaK) << "] * " << " B["
1379 << (iterationK * wgmmaK) << ":" << (iterationK * wgmmaK + wgmmaK)
1380 << "][" << 0 << ":" << wgmmaN << "])";
1381
1382 Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k);
1383 Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k);
1384
1385 Type elemA = op.getDescriptorA().getType().getTensor().getElementType();
1386 NVVM::WGMMATypesAttr itypeA = generateWgmmaType(elemA);
1387
1388 Type elemB = op.getDescriptorB().getType().getTensor().getElementType();
1389 NVVM::WGMMATypesAttr itypeB = generateWgmmaType(elemB);
1390
1391 Type elemD = op.getMatrixC().getType().getFragmented().getElementType();
1392 NVVM::WGMMATypesAttr itypeD = generateWgmmaType(elemD, true);
1393
1394 NVVM::MMAShapeAttr shape = generateWgmmaShape();
1395 NVVM::WGMMAScaleOutAttr scaleOut = generateScaleOut();
1396 NVVM::WGMMAScaleInAttr scaleIn = generateScaleIn();
1397 NVVM::MMALayoutAttr layoutA = generateWgmmaLayout(op.getTransposeA());
1398 NVVM::MMALayoutAttr layoutB = generateWgmmaLayout(!op.getTransposeB());
1399
1400 auto overflow = NVVM::MMAIntOverflowAttr::get(
1401 op->getContext(), NVVM::MMAIntOverflow::wrapped);
1402
1403 return NVVM::WgmmaMmaAsyncOp::create(
1404 b, matrixC.getType(), matrixC, descriptorA, descriptorB, shape,
1405 itypeA, itypeB, itypeD, scaleOut, scaleIn, scaleIn, layoutA, layoutB,
1406 overflow);
1407 }
1408
1409 /// Generates multiple wgmma instructions to complete the given GEMM shape
1410 Value generateWgmmaGroup() {
1411 Value wgmmaResult =
1412 LLVM::PoisonOp::create(b, adaptor.getMatrixC().getType());
1413
1414 // Perform GEMM
1415 SmallVector<Value> wgmmaResults;
1416 for (int i = 0; i < iterationM; ++i) {
1417 Value matrixC =
1418 LLVM::ExtractValueOp::create(b, adaptor.getMatrixC(), i);
1419 for (int j = 0; j < iterationN; ++j)
1420 for (int k = 0; k < iterationK; ++k)
1421 matrixC = generateWgmma(i, j, k, matrixC);
1422 wgmmaResults.push_back(matrixC);
1423 }
1424 for (auto [idx, matrix] : llvm::enumerate(wgmmaResults)) {
1425 wgmmaResult = LLVM::InsertValueOp::create(b, wgmmaResult.getType(),
1426 wgmmaResult, matrix, idx);
1427 }
1428 return wgmmaResult;
1429 }
1430
1431 public:
1432 WarpgroupGemm(nvgpu::WarpgroupMmaOp op, ImplicitLocOpBuilder &b,
1433 OpAdaptor adaptor)
1434 : op(op), b(b), adaptor(adaptor) {
1435 // Find the entire GEMM Shape
1436 totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
1437 totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
1438 totalK = op.getDescriptorA().getType().getTensor().getDimSize(1);
1439 LDBG() << "===--- GEMM D[" << totalM << "][" << totalN << "] += A["
1440 << totalM << "][" << totalK << "] * B[" << totalK << "][" << totalN
1441 << "] ---===";
1442
1443 // Find the shape for one wgmma instruction
1444 findWgmmaShape(
1445 totalM, totalN,
1446 op.getDescriptorA().getType().getTensor().getElementType());
1447
1448 // Iteration counts needed to cover the given GEMM shape with the wgmma shape
1449 iterationM = totalM / wgmmaM;
1450 iterationN = totalN / wgmmaN;
1451 iterationK = totalK / wgmmaK;
1452 }
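// Illustrative sizing (assumed values, for this example only): suppose
// findWgmmaShape selects wgmmaM = 64, wgmmaN = 128 and wgmmaK = 16 for an
// f16 GEMM D[128][128] += A[128][64] * B[64][128]. Then iterationM = 2,
// iterationN = 1 and iterationK = 4, so generateWgmmaGroup() emits
// 2 * 1 * 4 = 8 WgmmaMmaAsyncOps, chaining the accumulator across k.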
1453
1454 /// Generates WgmmaMmaAsync Ops to complete the specified GEMM shape. It
1455 /// emits a fence Op (WgmmaFenceAlignedOp) before the wgmma instructions,
1456 /// commits them as a group (WgmmaGroupSyncAlignedOp) after the
1457 /// instructions, and then waits for the group's completion
1458 /// (WgmmaWaitGroupSyncOp).
1459 Value generateWarpgroupMma() {
1460 NVVM::WgmmaFenceAlignedOp::create(b);
1461 Value wgmmaResult = generateWgmmaGroup();
1462 NVVM::WgmmaGroupSyncAlignedOp::create(b);
1463 NVVM::WgmmaWaitGroupSyncOp::create(b, op.getWaitGroup());
1464 return wgmmaResult;
1465 }
1466 };
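// A rough sketch of the NVVM op sequence the helper above is expected to
// produce for one warpgroup GEMM (assembly mnemonics are approximate and
// shown only for orientation, not verbatim output):
//
//   nvvm.wgmma.fence.aligned
//   %acc = nvvm.wgmma.mma_async %descA, %descB, %acc_in, ... // per (i, j, k)
//   ...
//   nvvm.wgmma.commit.group.sync.aligned
//   nvvm.wgmma.wait.group.sync.aligned <waitGroup>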
1467 LogicalResult
1468 matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
1469 ConversionPatternRewriter &rewriter) const override {
1470 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1471
1472 // Step 1. Build a helper class
1473 WarpgroupGemm warpgroupGemm(op, b, adaptor);
1474
1475 // Step 2. Generate wgmma instructions covering the entire GEMM shape
1476 Value wgmmaResult = warpgroupGemm.generateWarpgroupMma();
1477
1478 // Step 3. Replace fragmented result struct with the op results
1479 rewriter.replaceOp(op, wgmmaResult);
1480 return success();
1481 }
1482};
1483
1484struct NVGPUWarpgroupMmaStoreOpLowering
1485 : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaStoreOp> {
1486 using ConvertOpToLLVMPattern<
1487 nvgpu::WarpgroupMmaStoreOp>::ConvertOpToLLVMPattern;
1488
1489 /// This function stores a fragmented register matrix owned by a warp group
1490 /// (128 threads) into a memref. Each thread holds 64 registers, carried as
1491 /// the elements of a single struct value.
1492 /// Below is what each thread (T) holds; each `d` is a struct element with a
1493 /// number.
1494 ///
1495 /// Threads in the warp-group (128 threads) and what they own in matrix-D:
1496 /// 0-31 Warp-0 -> MatrixD[0:15 ][0:N]
1497 /// 32-63 Warp-1 -> MatrixD[16:31][0:N]
1498 /// 64-95 Warp-2 -> MatrixD[32:47][0:N]
1499 /// 96-127 Warp-3 -> MatrixD[48:63][0:N]
1500 ///
1501 /// Matrix-D:
1502 /// +______________________________________________________________________+
1503 /// | 0-1 | 2-3 | 4-5 | 6-7 | 8-9 | 10-11|..|N-8,N-7 |
1504 /// 0 | T0:d0-d1 |T1:d0-d1 |T2:d0-d1 |T3:d0-d1 |T0:d4-d5| T1:d4-d5..|T0:dX-dY|
1505 /// 1 | T4:d0-d1 |T5:d0-d1 |T6:d0-d1 |T7:d0-d1 |T4:d4-d5| T5:d4-d5..|T4:dX-dY|
1506 /// ..| .........|.........|.........|.........|........|...........|........|
1507 /// 8 | T0:d2-d3 |T1:d2-d3 |T2:d2-d3 |T3:d2-d3 |T0:d6-d7|T1:d6-d7,..|T0:dZ-dW|
1508 /// 9 | T4:d2-d3 |T5:d2-d3 |T6:d2-d3 |T7:d2-d3 |T4:d6-d7| T5:d6-d7..|T4:dZ-dW|
1509 /// ..| .........|.........|.........|.........|........|...........|........|
1510 /// 15| T28:d2-d3|T29:d2-d3|T30:d2-d3|T31:d2-d3|........|...........|........|
1511 /// 16| T32:d2-d3|T33:d2-d3|T34:d2-d3|T35:d2-d3|........|...........|........|
1512 /// ..| .........|.........|.........|.........|........|...........|........|
1513 /// 32| T64:d2-d3|T65:d2-d3|T66:d2-d3|T67:d2-d3|........|...........|........|
1514 /// ..| .........|.........|.........|.........|........|...........|........|
1515 /// 48| T96:d2-d3|T97:d2-d3|T98:d2-d3|T99:d2-d3|........|...........|........|
1516 /// ..| .........|.........|.........|.........|........|...........|........|
1517 /// +______________________________________________________________________+
1518 ///
1519 /// \param rewriter: The pattern rewriter.
1520 /// \param matrixD: Result of the warp-group MMA operation (fragmented
1521 /// matrix). It is held by each thread as a struct with 64 elements.
1522 /// \param dstMemref: The memref where the registers will be stored.
1523 /// \param offset: the offset within the memref where the registers will be
1524 /// stored.
1525 void storeFragmentedMatrix(ImplicitLocOpBuilder &b, Value matrixD,
1526 TypedValue<MemRefType> dstMemref,
1527 int offset) const {
1528 Type i32 = b.getI32Type();
1529
1530 auto makeConst = [&](int32_t index) -> Value {
1531 return LLVM::ConstantOp::create(b, i32, b.getI32IntegerAttr(index));
1532 };
1533 Value c1 = makeConst(1);
1534 Value c2 = makeConst(2);
1535 Value c4 = makeConst(4);
1536 Value c8 = makeConst(8);
1537 Value c16 = makeConst(16);
1538 Value warpSize = makeConst(kWarpSize);
1539
1540 auto makeMul = [&](Value lhs, Value rhs) -> Value {
1541 return LLVM::MulOp::create(b, lhs.getType(), lhs, rhs);
1542 };
1543 auto makeAdd = [&](Value lhs, Value rhs) -> Value {
1544 return LLVM::AddOp::create(b, lhs.getType(), lhs, rhs);
1545 };
1546
1547 auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,
1548 TypedValue<::mlir::MemRefType> memref) {
1549 Type it = b.getIndexType();
1550 Value idx = arith::IndexCastOp::create(b, it, x);
1551 Value idy0 = arith::IndexCastOp::create(b, it, y);
1552 Value idy1 = arith::IndexCastOp::create(b, it, makeAdd(y, c1));
1553 Value d0 = LLVM::ExtractValueOp::create(b, wgmmaResult, i);
1554 Value d1 = LLVM::ExtractValueOp::create(b, wgmmaResult, i + 1);
1555 memref::StoreOp::create(b, d0, memref, ValueRange{idx, idy0});
1556 memref::StoreOp::create(b, d1, memref, ValueRange{idx, idy1});
1557 };
1558
1559 Value tidx = NVVM::ThreadIdXOp::create(b, i32);
1560 Value laneId = LLVM::URemOp::create(b, i32, tidx, warpSize);
1561 Value warpId = LLVM::UDivOp::create(b, i32, tidx, warpSize);
1562 Value lane4Id = LLVM::UDivOp::create(b, i32, laneId, c4);
1563 Value lane4modId = LLVM::URemOp::create(b, i32, laneId, c4);
1564
1565 Value tj = makeMul(lane4modId, c2);
1566 Value ti = makeAdd(lane4Id, makeMul(warpId, c16));
1567 if (offset)
1568 ti = makeAdd(ti, makeConst(offset));
1569
1570 auto structType = cast<LLVM::LLVMStructType>(matrixD.getType());
1571
1572 // Number of adjacent 32-bit registers owned by each thread
1573 constexpr unsigned numAdjacentRegisters = 2;
1574 // Number of 8x8 matrices stacked one below another per warp
1575 constexpr unsigned numStackedMatrices = 2;
1576
1577 size_t storeCount = (structType.getBody().size() /
1578 (numStackedMatrices * numAdjacentRegisters));
1579
1580 for (size_t i = 0; i < numStackedMatrices; ++i) {
1581 Value idx = makeAdd(ti, makeMul(makeConst(i), c8));
1582 for (size_t j = 0; j < storeCount; ++j) {
1583 Value idy = makeAdd(tj, makeMul(makeConst(j), c8));
1584 size_t structIndex = (i * numAdjacentRegisters) +
1585 (j * (numStackedMatrices * numAdjacentRegisters));
1586 makeExtractAndStore(structIndex, matrixD, idx, idy, dstMemref);
1587 }
1588 }
1589 }
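// Index-math check against the layout diagram above, using thread tidx = 5
// in Warp-0 as an example: laneId = 5, warpId = 0, lane4Id = 5 / 4 = 1,
// lane4modId = 5 % 4 = 1, hence ti = 1 + 0 * 16 = 1 and tj = 1 * 2 = 2.
// The i = 0 iteration stores d0/d1 to dstMemref[1][2..3] (T5:d0-d1 at row 1,
// columns 2-3 in the diagram), and i = 1 uses idx = 1 + 8 = 9, which is
// where the diagram places T5:d2-d3.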
1590
1591 LogicalResult
1592 matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor,
1593 ConversionPatternRewriter &rewriter) const override {
1594 int offset = 0;
1595 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1596 Value matriDValue = adaptor.getMatrixD();
1597 auto stype = cast<LLVM::LLVMStructType>(matriDValue.getType());
1598 for (auto [idx, matrixD] : llvm::enumerate(stype.getBody())) {
1599 auto structType = cast<LLVM::LLVMStructType>(matrixD);
1600 Value innerStructValue =
1601 LLVM::ExtractValueOp::create(b, matriDValue, idx);
1602 storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset);
1603 offset += structType.getBody().size();
1604 }
1605 rewriter.eraseOp(op);
1606 return success();
1607 }
1608};
1609
1610struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
1611 : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaInitAccumulatorOp> {
1612 using ConvertOpToLLVMPattern<
1613 nvgpu::WarpgroupMmaInitAccumulatorOp>::ConvertOpToLLVMPattern;
1614 LogicalResult
1615 matchAndRewrite(nvgpu::WarpgroupMmaInitAccumulatorOp op, OpAdaptor adaptor,
1616 ConversionPatternRewriter &rewriter) const override {
1617 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1618 LLVM::LLVMStructType packStructType = cast<LLVM::LLVMStructType>(
1619 getTypeConverter()->convertType(op.getMatrixC().getType()));
1620 Type elemType = cast<LLVM::LLVMStructType>(packStructType.getBody().front())
1621 .getBody()
1622 .front();
1623 Value zero = LLVM::ConstantOp::create(b, elemType, b.getZeroAttr(elemType));
1624 Value packStruct = LLVM::PoisonOp::create(b, packStructType);
1625 SmallVector<Value> innerStructs;
1626 // Unpack the structs and set all values to zero
1627 for (auto [idx, s] : llvm::enumerate(packStructType.getBody())) {
1628 auto structType = cast<LLVM::LLVMStructType>(s);
1629 Value structValue = LLVM::ExtractValueOp::create(b, packStruct, idx);
1630 for (unsigned i = 0; i < structType.getBody().size(); ++i) {
1631 structValue = LLVM::InsertValueOp::create(b, structType, structValue,
1632 zero, ArrayRef<int64_t>({i}));
1633 }
1634 innerStructs.push_back(structValue);
1635 }
1636 // Pack the inner structs into a single struct
1637 for (auto [idx, matrix] : llvm::enumerate(innerStructs)) {
1638 packStruct = LLVM::InsertValueOp::create(b, packStruct.getType(),
1639 packStruct, matrix, idx);
1640 }
1641 rewriter.replaceOp(op, packStruct);
1642 return success();
1643 }
1644};
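// As a hedged illustration of the pattern above: if the accumulator type
// converts to a struct of two inner structs of 64 f32 each (an assumed shape
// for this example), the lowering materializes one f32 zero constant, fills
// each inner struct with 64 llvm.insertvalue ops, and finally packs the two
// inner structs back into the outer struct that replaces the op.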
1645
1646struct NVGPUTmaFenceOpLowering
1647 : public ConvertOpToLLVMPattern<nvgpu::TmaFenceOp> {
1648 using ConvertOpToLLVMPattern<nvgpu::TmaFenceOp>::ConvertOpToLLVMPattern;
1649 LogicalResult
1650 matchAndRewrite(nvgpu::TmaFenceOp op, OpAdaptor adaptor,
1651 ConversionPatternRewriter &rewriter) const override {
1652 MLIRContext *ctx = op.getContext();
1653 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1654 auto i32Ty = b.getI32Type();
1655 Value tensormapSize =
1656 LLVM::ConstantOp::create(b, i32Ty, rewriter.getI32IntegerAttr(128));
1657
1658 auto memscope =
1659 NVVM::MemScopeKindAttr::get(ctx, ::mlir::NVVM::MemScopeKind::SYS);
1660
1661 rewriter.replaceOpWithNewOp<NVVM::FenceProxyAcquireOp>(
1662 op, memscope, adaptor.getTensorMapDescriptor(), tensormapSize);
1663
1664 return success();
1665 }
1666};
1667
1668struct NVGPUTmaPrefetchOpLowering
1669 : public ConvertOpToLLVMPattern<nvgpu::TmaPrefetchOp> {
1670 using ConvertOpToLLVMPattern<nvgpu::TmaPrefetchOp>::ConvertOpToLLVMPattern;
1671 LogicalResult
1672 matchAndRewrite(nvgpu::TmaPrefetchOp op, OpAdaptor adaptor,
1673 ConversionPatternRewriter &rewriter) const override {
1674 rewriter.replaceOpWithNewOp<NVVM::PrefetchOp>(
1675 op, /* CacheLevel */ nullptr, /* Cache Eviction Priority */ nullptr,
1676 adaptor.getTensorMapDescriptor(), adaptor.getPredicate(),
1677 /* Tensormap UnitAttr */ mlir::UnitAttr::get(op.getContext()));
1678 return success();
1679 }
1680};
1681
1682struct NVGPURcpOpLowering : public ConvertOpToLLVMPattern<nvgpu::RcpOp> {
1683 using ConvertOpToLLVMPattern<nvgpu::RcpOp>::ConvertOpToLLVMPattern;
1684 LogicalResult
1685 matchAndRewrite(nvgpu::RcpOp op, OpAdaptor adaptor,
1686 ConversionPatternRewriter &rewriter) const override {
1687 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1688 auto i64Ty = b.getI64Type();
1689 auto f32Ty = b.getF32Type();
1690 VectorType inTy = op.getIn().getType();
1691 // Apply rcp.approx.ftz.f32 to each element of the vector.
1692 auto convert1DVec = [&](Type llvm1DVectorTy, Value inVec) {
1693 Value ret1DVec = LLVM::PoisonOp::create(b, llvm1DVectorTy);
1694 int numElems = llvm::cast<VectorType>(llvm1DVectorTy).getNumElements();
1695 for (int i = 0; i < numElems; i++) {
1696 Value idx = LLVM::ConstantOp::create(b, i64Ty, b.getI64IntegerAttr(i));
1697 Value elem = LLVM::ExtractElementOp::create(b, inVec, idx);
1698 Value dst = NVVM::RcpApproxFtzF32Op::create(b, f32Ty, elem);
1699 ret1DVec = LLVM::InsertElementOp::create(b, ret1DVec, dst, idx);
1700 }
1701 return ret1DVec;
1702 };
1703 if (inTy.getRank() == 1) {
1704 rewriter.replaceOp(op, convert1DVec(inTy, adaptor.getIn()));
1705 return success();
1706 }
1707 return LLVM::detail::handleMultidimensionalVectors(
1708 op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()),
1709 [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
1710 OpAdaptor adaptor(operands);
1711 return convert1DVec(llvm1DVectorTy, adaptor.getIn());
1712 },
1713 rewriter);
1714 }
1715};
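// For orientation, a sketch of what the rank-1 path above yields for a
// vector<4xf32> input (illustrative, not verbatim output): four
// llvm.extractelement / nvvm.rcp.approx.ftz.f / llvm.insertelement triples
// that rebuild the result vector element by element; higher-rank vectors are
// first split into 1-D slices via handleMultidimensionalVectors.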
1716} // namespace
1717
1718 void mlir::populateNVGPUToNVVMConversionPatterns(
1719 const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
1720 patterns.add<
1721 NVGPUMBarrierCreateLowering, // nvgpu.mbarrier.create
1722 NVGPUMBarrierInitLowering, // nvgpu.mbarrier.init
1723 NVGPUMBarrierGetLowering, // nvgpu.mbarrier.get
1724 NVGPUMBarrierArriveLowering, // nvgpu.mbarrier.arrive
1725 NVGPUMBarrierArriveNoCompleteLowering, // nvgpu.mbarrier.arrive.no_complete
1726 NVGPUMBarrierTestWaitLowering, // nvgpu.mbarrier.test_wait_parity
1727 NVGPUMBarrierTryWaitParityLowering, // nvgpu.mbarrier.try_wait_parity
1728 NVGPUTmaAsyncLoadOpLowering, // nvgpu.tma.async.load
1729 NVGPUTmaAsyncStoreOpLowering, // nvgpu.tma.async.store
1730 NVGPUTmaCreateDescriptorOpLowering, // nvgpu.tma.create.descriptor
1731 NVGPUTmaPrefetchOpLowering, // nvgpu.tma.prefetch.descriptor
1732 NVGPUTmaFenceOpLowering, // nvgpu.tma.fence.descriptor
1733 NVGPUMBarrierArriveExpectTxLowering, // nvgpu.mbarrier.arrive.expect_tx
1734 NVGPUGenerateWarpgroupDescriptorLowering, // nvgpu.warpgroup.generate.descriptor
1735 NVGPUWarpgroupMmaOpLowering, // nvgpu.warpgroup.mma
1736 NVGPUWarpgroupMmaStoreOpLowering, // nvgpu.warpgroup.mma.store
1737 NVGPUWarpgroupMmaInitAccumulatorOpLowering, // nvgpu.warpgroup.mma.init.accumulator
1738 MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
1739 NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
1740 NVGPUMmaSparseSyncLowering, NVGPURcpOpLowering>(converter);
1741}