1//===- NVGPUToNVVM.cpp - NVGPU to NVVM dialect conversion -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
10
11#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
12#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
13#include "mlir/Conversion/LLVMCommon/Pattern.h"
14#include "mlir/Dialect/Arith/IR/Arith.h"
15#include "mlir/Dialect/GPU/IR/GPUDialect.h"
16#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
17#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
18#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
19#include "mlir/Dialect/MemRef/IR/MemRef.h"
20#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
21#include "mlir/Dialect/SCF/Transforms/Patterns.h"
22#include "mlir/IR/BuiltinTypes.h"
23#include "mlir/IR/ImplicitLocOpBuilder.h"
24#include "mlir/IR/PatternMatch.h"
25#include "mlir/IR/TypeUtilities.h"
26#include "mlir/IR/Value.h"
27#include "mlir/Pass/Pass.h"
28#include "llvm/Support/Debug.h"
29#include "llvm/Support/DebugLog.h"
30#include "llvm/Support/ErrorHandling.h"
31#include "llvm/Support/raw_ostream.h"
32#include <optional>
33
34#define DEBUG_TYPE "nvgpu-to-nvvm"
35
36namespace mlir {
37#define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS
38#include "mlir/Conversion/Passes.h.inc"
39} // namespace mlir
40
41using namespace mlir;
42
43/// Number of bits that need to be excluded when building the matrix
44/// descriptor for wgmma operations.
45constexpr int exclude4LSB = 4;
46
47/// GPU has 32-bit registers; this function truncates values when a larger
48/// width is not needed.
49static Value truncToI32(ImplicitLocOpBuilder &b, Value value) {
50 Type type = value.getType();
51 assert(llvm::isa<IntegerType>(type) && "expected an integer Value");
52 if (type.getIntOrFloatBitWidth() <= 32)
53 return value;
54 return LLVM::TruncOp::create(b, b.getI32Type(), value);
55}
56
57/// Returns the type for the intrinsic given the vectorResultType of the
58/// `nvgpu.mma.sync` operation.
59static Type inferIntrinsicResultType(Type vectorResultType) {
60 MLIRContext *ctx = vectorResultType.getContext();
61 auto a = cast<LLVM::LLVMArrayType>(vectorResultType);
62 auto f16x2Ty = VectorType::get(2, Float16Type::get(ctx));
63 auto i32Ty = IntegerType::get(ctx, 32);
64 auto i32x2Ty = VectorType::get(2, i32Ty);
65 Type f64Ty = Float64Type::get(ctx);
66 Type f64x2Ty = VectorType::get(2, f64Ty);
67 Type f32Ty = Float32Type::get(ctx);
68 Type f32x2Ty = VectorType::get(2, f32Ty);
69 if (a.getElementType() == f16x2Ty) {
70 return LLVM::LLVMStructType::getLiteral(
71 ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty));
72 }
73 if (a.getElementType() == i32x2Ty) {
74 return LLVM::LLVMStructType::getLiteral(
75 ctx,
76 SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, i32Ty));
77 }
78 if (a.getElementType() == f64x2Ty) {
79 return LLVM::LLVMStructType::getLiteral(ctx, {f64Ty, f64Ty});
80 }
81 if (a.getElementType() == f32x2Ty) {
82 return LLVM::LLVMStructType::getLiteral(
83 ctx,
84 SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, f32Ty));
85 }
86 if (a.getElementType() == VectorType::get(1, f32Ty)) {
87 return LLVM::LLVMStructType::getLiteral(
88 ctx, SmallVector<Type>(static_cast<size_t>(a.getNumElements()), f32Ty));
89 }
90 return vectorResultType;
91}
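// Illustrative sketch (not part of the original lowering): for the i32x2 case
// above, a converted accumulator of type `!llvm.array<2 x vector<2xi32>>`
// maps to a flat struct of four i32 members. Assuming an MLIRContext `ctx`:
//
//   auto i32Ty = IntegerType::get(ctx, 32);
//   auto i32x2Ty = VectorType::get(2, i32Ty);
//   auto arrayTy = LLVM::LLVMArrayType::get(i32x2Ty, 2);
//   Type resTy = inferIntrinsicResultType(arrayTy);
//   // resTy == !llvm.struct<(i32, i32, i32, i32)>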
92
93/// Convert the SSA result of the NVVM intrinsic `nvvm.mma.sync` (which is
94/// always an LLVM struct) into a fragment that is compatible with the vector
95/// type of this operation. This involves extracting elements from the struct
96/// and inserting them into an LLVM array. These extra data-movement
97/// operations should be canonicalized away by the LLVM backend.
98static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
99 Type resultType, Value intrinsicResult,
100 RewriterBase &rewriter) {
101 MLIRContext *ctx = rewriter.getContext();
102 auto structType = dyn_cast<LLVM::LLVMStructType>(intrinsicResultType);
103 auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType);
104 Type i32Ty = rewriter.getI32Type();
105 Type f32Ty = rewriter.getF32Type();
106 Type f64Ty = rewriter.getF64Type();
107 Type f16x2Ty = VectorType::get(2, rewriter.getF16Type());
108 Type i32x2Ty = VectorType::get(2, i32Ty);
109 Type f64x2Ty = VectorType::get(2, f64Ty);
110 Type f32x2Ty = VectorType::get(2, f32Ty);
111 Type f32x1Ty = VectorType::get(1, f32Ty);
112
113 auto makeConst = [&](int32_t index) -> Value {
114 return LLVM::ConstantOp::create(rewriter, loc, IntegerType::get(ctx, 32),
115 rewriter.getI32IntegerAttr(index));
116 };
117
118 if (arrayType) {
119 SmallVector<Value, 4> elements;
120
121 // The intrinsic returns 32-bit wide elements in a form which can be
122 // directly bitcasted and inserted into the result vector.
123 if (arrayType.getElementType() == f16x2Ty ||
124 arrayType.getElementType() == f32x1Ty) {
125 for (unsigned i = 0; i < structType.getBody().size(); i++) {
126 Value el =
127 LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult, i);
128 el = rewriter.createOrFold<LLVM::BitcastOp>(
129 loc, arrayType.getElementType(), el);
130 elements.push_back(el);
131 }
132 }
133
134 // The intrinsic returns i32, f64, and f32 values as individual scalars,
135 // even when the result is notionally a 64-bit wide element (e.g. f32x2). We
136 // need to extract them from the struct and pack them into the 64-bit wide
137 // rows of the vector result.
138 if (arrayType.getElementType() == i32x2Ty ||
139 arrayType.getElementType() == f64x2Ty ||
140 arrayType.getElementType() == f32x2Ty) {
141
142 for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) {
143 Value vec =
144 LLVM::PoisonOp::create(rewriter, loc, arrayType.getElementType());
145 Value x1 =
146 LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult, i * 2);
147 Value x2 = LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult,
148 i * 2 + 1);
149 vec = LLVM::InsertElementOp::create(rewriter, loc, vec.getType(), vec,
150 x1, makeConst(0));
151 vec = LLVM::InsertElementOp::create(rewriter, loc, vec.getType(), vec,
152 x2, makeConst(1));
153 elements.push_back(vec);
154 }
155 }
156
157 // Create the final vectorized result.
158 Value result = LLVM::PoisonOp::create(rewriter, loc, arrayType);
159 for (const auto &el : llvm::enumerate(elements)) {
160 result = LLVM::InsertValueOp::create(rewriter, loc, result, el.value(),
161 el.index());
162 }
163 return result;
164 }
165
166 return intrinsicResult;
167}
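// Worked example (not from the original source): continuing the i32x2 case,
// an intrinsic result of type `!llvm.struct<(i32, i32, i32, i32)>` is repacked
// into a `!llvm.array<2 x vector<2xi32>>` result as follows: for each row r in
// [0, 2), elements 2*r and 2*r+1 are extracted from the struct, inserted into
// a poison vector<2xi32> at lanes 0 and 1, and that vector is then inserted
// into the result array at position r.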
168
169/// The `nvgpu.mma.sync` converter below expects the matrix fragment operands
170/// to be given as 2D `vectors` where the rows are 32b or 64b wide. The
171/// `nvvm.mma.sync` op expects these arguments to be given as a long list of
172/// scalars of certain types. This function helps unpack the `vector` arguments
173/// and cast them to the types expected by `nvvm.mma.sync`.
174static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b,
175 Value operand,
176 NVVM::MMATypes operandPtxType) {
177 SmallVector<Value> result;
178 Type i32Ty = b.getI32Type();
179 Type f64Ty = b.getF64Type();
180 Type f32Ty = b.getF32Type();
181 Type i64Ty = b.getI64Type();
182 Type i8x4Ty = VectorType::get(4, b.getI8Type());
183 Type i4x8Ty = VectorType::get(8, b.getIntegerType(4));
184 Type f32x1Ty = VectorType::get(1, f32Ty);
185 auto arrayTy = cast<LLVM::LLVMArrayType>(operand.getType());
186
187 for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
188 Value toUse = LLVM::ExtractValueOp::create(b, operand, i);
189
190 // The intrinsic expects 4xi8 and 8xi4 vectors (and 1xf32 vectors used as
191 // tf32 operands) to be provided as i32 scalars.
192 if (arrayTy.getElementType() == i8x4Ty ||
193 arrayTy.getElementType() == i4x8Ty ||
194 (arrayTy.getElementType() == f32x1Ty &&
195 operandPtxType == NVVM::MMATypes::tf32)) {
196 result.push_back(LLVM::BitcastOp::create(b, i32Ty, toUse));
197 continue;
198 }
199
200 // For some element types (i32, f32, f64), we need to unpack the inner
201 // vector/array type as well because the intrinsic expects individual
202 // scalars to be provided.
203 VectorType innerArrayTy = dyn_cast<VectorType>(arrayTy.getElementType());
204 if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty ||
205 innerArrayTy.getElementType() == f64Ty ||
206 innerArrayTy.getElementType() == f32Ty)) {
207 for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
208 idx < innerSize; idx++) {
209 result.push_back(LLVM::ExtractElementOp::create(
210 b, toUse,
211 LLVM::ConstantOp::create(b, i64Ty, b.getI64IntegerAttr(idx))));
212 }
213 continue;
214 }
215 result.push_back(toUse);
216 }
217 return result;
218}
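// Worked example (not from the original source): for an operand of type
// `!llvm.array<2 x vector<4xi8>>` with operandPtxType == s8, each 4xi8 row is
// bitcast to an i32 scalar, so two i32 values are returned. For
// `!llvm.array<2 x vector<2xf16>>` the f16x2 rows are passed through
// unchanged, while for `!llvm.array<2 x vector<2xf64>>` each row is split into
// two extractelement results, yielding four f64 scalars.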
219
220/// Returns whether the mbarrier object has a shared memory address space.
221static bool isMbarrierShared(nvgpu::MBarrierGroupType barrierType) {
222 return (mlir::nvgpu::NVGPUDialect::isSharedMemoryAddressSpace(
223 barrierType.getMemorySpace()));
224}
225
226/// Returns the memory space attribute of the mbarrier object.
227Attribute nvgpu::getMbarrierMemorySpace(MLIRContext *context,
228 nvgpu::MBarrierGroupType barrierType) {
229 Attribute memorySpace = {};
230 if (isMbarrierShared(barrierType)) {
231 memorySpace =
232 IntegerAttr::get(IntegerType::get(context, 64),
233 nvgpu::NVGPUDialect::kSharedMemoryAddressSpace);
234 }
235 return memorySpace;
236}
237
238/// Returns the memref type of the mbarrier object. The type is derived from
239/// the MBarrierGroupType.
240MemRefType nvgpu::getMBarrierMemrefType(MLIRContext *context,
241 nvgpu::MBarrierGroupType barrierType) {
242 Attribute memorySpace = nvgpu::getMbarrierMemorySpace(context, barrierType);
243 MemRefLayoutAttrInterface layout;
244 return MemRefType::get({barrierType.getNumBarriers()},
245 IntegerType::get(context, 64), layout, memorySpace);
246}
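// Illustrative sketch (assumed type syntax, not from the original source): a
// barrier group in workgroup (shared) memory with 4 barriers, e.g.
// `!nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>,
// num_barriers = 4>`, converts to `memref<4xi64, 3>`: one 64-bit mbarrier
// word per barrier, tagged with the shared memory address space (3).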
247
248namespace {
249
250struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
251 using ConvertOpToLLVMPattern<nvgpu::LdMatrixOp>::ConvertOpToLLVMPattern;
252
253 LogicalResult
254 matchAndRewrite(nvgpu::LdMatrixOp op, OpAdaptor adaptor,
255 ConversionPatternRewriter &rewriter) const override {
256 MLIRContext *ctx = getContext();
257 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
258
259 // The result type of ldmatrix will always be a struct of 32bit integer
260 // registers if more than one 32bit value is returned. Otherwise, the result
261 // is a single i32. The result type of the GPU operation is always a vector
262 // of shape (NumRegisters, VectorRegister) where VectorRegister is the
263 // vector type of the result and always 32 bits long. We bitcast the result
264 // of the NVVM::LdMatrix to this vector type.
265 auto vectorResultType = dyn_cast<VectorType>(op->getResultTypes()[0]);
266 if (!vectorResultType) {
267 return failure();
268 }
269 Type innerVectorType = VectorType::get(vectorResultType.getDimSize(1),
270 vectorResultType.getElementType());
271
272 int64_t num32BitRegs = vectorResultType.getDimSize(0);
273
274 Type ldMatrixResultType;
275 if (num32BitRegs > 1) {
276 ldMatrixResultType = LLVM::LLVMStructType::getLiteral(
277 ctx, SmallVector<Type>(num32BitRegs, rewriter.getI32Type()));
278 } else {
279 ldMatrixResultType = rewriter.getI32Type();
280 }
281
282 auto srcMemrefType = cast<MemRefType>(op.getSrcMemref().getType());
283 Value srcPtr =
284 getStridedElementPtr(rewriter, b.getLoc(), srcMemrefType,
285 adaptor.getSrcMemref(), adaptor.getIndices());
286 auto shape = NVVM::LdStMatrixShapeAttr::get(rewriter.getContext(), 8, 8);
287 Value ldMatrixResult = NVVM::LdMatrixOp::create(
288 b, ldMatrixResultType, srcPtr,
289 /*num=*/op.getNumTiles(),
290 /*layout=*/op.getTranspose() ? NVVM::MMALayout::col
291 : NVVM::MMALayout::row,
292 /*shape=*/shape, /*eltType=*/NVVM::LdStMatrixEltType::B16);
293
294 // The ldmatrix operation returns either a single i32 value or a struct of
295 // i32 values. Here we unpack those values and cast them back to their
296 // actual vector type (still of width 32b) and repack them into a result
297 // struct.
298 Type finalResultType = typeConverter->convertType(vectorResultType);
299 Value result = LLVM::PoisonOp::create(b, finalResultType);
300 for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) {
301 Value i32Register =
302 num32BitRegs > 1 ? LLVM::ExtractValueOp::create(b, ldMatrixResult, i)
303 : ldMatrixResult;
304 Value casted = LLVM::BitcastOp::create(b, innerVectorType, i32Register);
305 result = LLVM::InsertValueOp::create(b, result, casted, i);
306 }
307
308 rewriter.replaceOp(op, result);
309 return success();
310 }
311};
312
313/// Convert the given type into the corresponding PTX type (NVVM::MMATypes
314/// enum).
315static FailureOr<NVVM::MMATypes> getNvvmMmaType(Type t) {
316 Type elType = getElementTypeOrSelf(t);
317 if (elType.isInteger(8))
318 return NVVM::MMATypes::s8;
319 if (elType.isInteger(4))
320 return NVVM::MMATypes::s4;
321 if (elType.isF16())
322 return NVVM::MMATypes::f16;
323 if (elType.isF64())
324 return NVVM::MMATypes::f64;
325 if (elType.isF32())
326 return NVVM::MMATypes::tf32;
327 return failure();
328}
329
330struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
331 using ConvertOpToLLVMPattern<nvgpu::MmaSyncOp>::ConvertOpToLLVMPattern;
332
333 LogicalResult
334 matchAndRewrite(nvgpu::MmaSyncOp op, OpAdaptor adaptor,
335 ConversionPatternRewriter &rewriter) const override {
336 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
337 // Get the shapes of the MMAMatrix type being used. The shapes determine
338 // which intrinsic this op will be lowered to.
339 VectorType aType = op.getMatrixA().getType();
340 VectorType bType = op.getMatrixB().getType();
341 VectorType cType = op.getMatrixC().getType();
342
343 std::array<int64_t, 3> gemmShape = op.getMmaShapeAsArray();
344
345 // Tensor Cores (mma.sync) on F32 work only with TensorFloat32 (TF32).
346 bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
347 if (aType.getElementType().isF32() && !tf32Enabled)
348 return failure();
349
350 FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
351 if (failed(ptxTypeA))
352 return op->emitOpError("failed to deduce operand PTX types");
353 FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
354 if (failed(ptxTypeB))
355 return op->emitOpError("failed to deduce operand PTX types");
356 std::optional<NVVM::MMATypes> ptxTypeC =
357 NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
358 /*isAccumulator=*/true);
359 if (!ptxTypeC)
360 return op->emitError(
361 "could not infer the PTX type for the accumulator/result");
362
363 // TODO: add an attribute to the op to customize this behavior.
364 std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
365 if (isa<IntegerType>(aType.getElementType()))
366 overflow = NVVM::MMAIntOverflow::satfinite;
367
368 SmallVector<Value> matA =
369 unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
370 SmallVector<Value> matB =
371 unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
372 SmallVector<Value> matC =
373 unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);
374
375 Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
376 Type intrinsicResTy = inferIntrinsicResultType(
377 typeConverter->convertType(op->getResultTypes()[0]));
378 Value intrinsicResult =
379 NVVM::MmaOp::create(b, intrinsicResTy, matA, matB, matC,
380 /*shape=*/gemmShape,
381 /*b1Op=*/std::nullopt,
382 /*intOverflow=*/overflow,
383 /*multiplicandPtxTypes=*/
384 std::array<NVVM::MMATypes, 2>{*ptxTypeA, *ptxTypeB},
385 /*multiplicandLayouts=*/
386 std::array<NVVM::MMALayout, 2>{
387 NVVM::MMALayout::row, NVVM::MMALayout::col});
388 rewriter.replaceOp(op, convertIntrinsicResult(op.getLoc(), intrinsicResTy,
389 desiredRetTy, intrinsicResult,
390 rewriter));
391 return success();
392 }
393};
394
395struct ConvertNVGPUToNVVMPass
396 : public impl::ConvertNVGPUToNVVMPassBase<ConvertNVGPUToNVVMPass> {
397 using Base::Base;
398
399 void runOnOperation() override {
400 LowerToLLVMOptions options(&getContext());
401 RewritePatternSet patterns(&getContext());
402 LLVMTypeConverter converter(&getContext(), options);
403 IRRewriter rewriter(&getContext());
405
406 /// device-side async tokens cannot be materialized in nvvm. We just
407 /// convert them to a dummy i32 type in order to easily drop them during
408 /// conversion.
409 converter.addConversion([&](nvgpu::DeviceAsyncTokenType type) -> Type {
410 return converter.convertType(IntegerType::get(type.getContext(), 32));
411 });
412 converter.addConversion([&](nvgpu::WarpgroupAccumulatorType type) -> Type {
413 Type elemType = type.getFragmented().getElementType();
414 int64_t sizeM = type.getFragmented().getDimSize(0);
415 int64_t sizeN = type.getFragmented().getDimSize(1);
416
417 unsigned numMembers;
418 if (elemType.isF32() || elemType.isInteger(32))
419 numMembers = sizeN / 2;
420 else if (elemType.isF16())
421 numMembers = sizeN / 4;
422 else
423 llvm_unreachable("unsupported type for warpgroup accumulator");
424
425 SmallVector<Type> innerStructBody;
426 for (unsigned i = 0; i < numMembers; i++)
427 innerStructBody.push_back(elemType);
428 auto innerStructType =
429 LLVM::LLVMStructType::getLiteral(type.getContext(), innerStructBody);
430
431 SmallVector<Type> structBody;
432 for (int i = 0; i < sizeM; i += kWgmmaSizeM)
433 structBody.push_back(innerStructType);
434
435 auto convertedType =
436 LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
437 return converter.convertType(convertedType);
438 });
439 converter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {
440 return converter.convertType(IntegerType::get(type.getContext(), 64));
441 });
442 converter.addConversion(
443 [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type {
444 return converter.convertType(IntegerType::get(type.getContext(), 64));
445 });
446 converter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type {
447 return converter.convertType(
448 nvgpu::getMBarrierMemrefType(rewriter.getContext(), type));
449 });
450 converter.addConversion([&](nvgpu::TensorMapDescriptorType type) -> Type {
451 return LLVM::LLVMPointerType::get(type.getContext());
452 });
453 populateNVGPUToNVVMConversionPatterns(converter, patterns);
454 LLVMConversionTarget target(getContext());
455 target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
456 target.addLegalDialect<::mlir::arith::ArithDialect>();
457 target.addLegalDialect<::mlir::memref::MemRefDialect>();
458 target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
459 mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
460 converter, patterns, target);
461 if (failed(applyPartialConversion(getOperation(), target,
462 std::move(patterns))))
463 signalPassFailure();
464 }
465};
466
467/// Returns the constraints for the sparse MMA inline assembly instruction.
468static std::string buildMmaSparseAsmConstraintString(unsigned matASize,
469 unsigned matBSize,
470 unsigned matCSize) {
471 std::string str;
472 llvm::raw_string_ostream ss(str);
473 for (unsigned i = 0; i < matCSize; i++)
474 ss << "=r,";
475 for (unsigned i = 0; i < matASize + matBSize + matCSize; i++)
476 ss << "r,";
477 // The final operand is for the sparsity metadata.
478 // The sparsity selector appears as a direct literal.
479 ss << "r";
480 return str;
481}
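// Worked example (not from the original source): with hypothetical fragment
// sizes matASize = 4, matBSize = 2, and matCSize = 2, the constraint string is
//
//   "=r,=r,r,r,r,r,r,r,r,r,r"
//
// i.e. two "=r" outputs for the D fragment, eight "r" inputs for A, B, and C,
// and a final "r" for the sparsity metadata operand.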
482
483/// Returns the string for the `mma.sp.sync` instruction that corresponds to
484/// the given parameters. Note that this function doesn't do any validation;
485/// it's expected that the provided parameters correspond to a valid
486/// instruction.
487static std::string buildMmaSparseAsmString(
488 const std::array<int64_t, 3> &shape, unsigned matASize, unsigned matBSize,
489 unsigned matCSize, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
490 NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
491 std::optional<NVVM::MMAIntOverflow> overflow, unsigned metaDataSelector) {
492 auto ptxTypeStr = [](NVVM::MMATypes ptxType) {
493 return NVVM::stringifyMMATypes(ptxType);
494 };
495
496 std::string asmStr;
497 llvm::raw_string_ostream ss(asmStr);
498 ss << "mma.sp.sync.aligned.m" << shape[0] << "n" << shape[1] << "k"
499 << shape[2] << ".row.col.";
500
501 if (overflow)
502 ss << NVVM::stringifyMMAIntOverflow(*overflow) << ".";
503
504 ss << ptxTypeStr(ptxTypeD) << "." << ptxTypeStr(ptxTypeA) << "."
505 << ptxTypeStr(ptxTypeB) << "." << ptxTypeStr(ptxTypeC) << " ";
506 unsigned asmArgIdx = 0;
507
508 // The operand string has sections `{matD elements...}, {matA elements...},
509 // {matB elements...}, {matC elements...}`; D has the same count as C.
510 for (const auto arrSize : {matCSize, matASize, matBSize, matCSize}) {
511 ss << "{";
512 for (unsigned i = 0; i < arrSize; i++)
513 ss << "$" << asmArgIdx++ << (i < arrSize - 1 ? "," : "");
514 ss << "},";
515 }
516 ss << "$" << asmArgIdx++ << ",";
517 assert(metaDataSelector <= 1);
518 ss << "0x" << metaDataSelector << ";";
519 return asmStr;
520}
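// Worked example (not from the original source): for shape m16n8k32 with f16
// operands and accumulator, no overflow attribute, hypothetical fragment sizes
// matASize = 4, matBSize = 4, matCSize = 2, and metaDataSelector = 0, the
// generated instruction string is:
//
//   mma.sp.sync.aligned.m16n8k32.row.col.f16.f16.f16.f16
//     {$0,$1},{$2,$3,$4,$5},{$6,$7,$8,$9},{$10,$11},$12,0x0;
//
// where the groups bind the D (output), A, B, and C fragments, followed by the
// sparsity metadata register and the literal sparsity selector.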
521
522/// Builds an inline assembly operation corresponding to the specified MMA
523/// sparse sync operation.
524static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
525 ImplicitLocOpBuilder &b, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
526 NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
527 std::optional<NVVM::MMAIntOverflow> overflow, ArrayRef<Value> unpackedAData,
528 ArrayRef<Value> unpackedB, ArrayRef<Value> unpackedC, Value indexData,
529 int64_t metadataSelector, const std::array<int64_t, 3> &shape,
530 Type intrinsicResultType) {
531 auto asmDialectAttr =
532 LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_ATT);
533
534 const unsigned matASize = unpackedAData.size();
535 const unsigned matBSize = unpackedB.size();
536 const unsigned matCSize = unpackedC.size();
537
538 std::string asmStr = buildMmaSparseAsmString(
539 shape, matASize, matBSize, matCSize, ptxTypeA, ptxTypeB, ptxTypeC,
540 ptxTypeD, overflow, metadataSelector);
541 std::string constraintStr =
542 buildMmaSparseAsmConstraintString(matASize, matBSize, matCSize);
543
544 SmallVector<Value> asmVals;
545 asmVals.reserve(matASize + matBSize + matCSize + 1);
546 for (ArrayRef<Value> args : {unpackedAData, unpackedB, unpackedC})
547 llvm::append_range(asmVals, args);
548 asmVals.push_back(indexData);
549
550 return LLVM::InlineAsmOp::create(b,
551 /*resultTypes=*/intrinsicResultType,
552 /*operands=*/asmVals,
553 /*asm_string=*/asmStr,
554 /*constraints=*/constraintStr,
555 /*has_side_effects=*/true,
556 /*is_align_stack=*/false,
557 LLVM::TailCallKind::None,
558 /*asm_dialect=*/asmDialectAttr,
559 /*operand_attrs=*/ArrayAttr());
560}
561
562/// Lowers `nvgpu.mma.sp.sync` to inline assembly.
563struct NVGPUMmaSparseSyncLowering
564 : public ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp> {
565 using ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp>::ConvertOpToLLVMPattern;
566
567 LogicalResult
568 matchAndRewrite(nvgpu::MmaSparseSyncOp op, OpAdaptor adaptor,
569 ConversionPatternRewriter &rewriter) const override {
570 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
571 // Get the shapes of the MMAMatrix type being used. The shapes determine
572 // which intrinsic this op will be lowered to.
573 VectorType aType = op.getMatrixA().getType();
574 VectorType bType = op.getMatrixB().getType();
575 VectorType cType = op.getMatrixC().getType();
576
577 FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
578 if (failed(ptxTypeA))
579 return op->emitOpError("failed to deduce operand PTX types");
580 FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
581 if (failed(ptxTypeB))
582 return op->emitOpError("failed to deduce operand PTX types");
583 std::optional<NVVM::MMATypes> ptxTypeC =
584 NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
585 /*isAccumulator=*/true);
586 if (!ptxTypeC)
587 return op->emitError(
588 "could not infer the PTX type for the accumulator/result");
589
590 // As with `mma.sync`, F32 works only with TensorFloat32 (TF32).
591 bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
592 if (aType.getElementType().isF32() && !tf32Enabled)
593 return failure();
594
595 // TODO: add an attribute to the op to customize this behavior.
596 std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
597 if (isa<IntegerType>(aType.getElementType()))
598 overflow = NVVM::MMAIntOverflow::satfinite;
599
600 SmallVector<Value> matA =
601 unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
602 SmallVector<Value> matB =
603 unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
604 SmallVector<Value> matC =
605 unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);
606
607 Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
608 Type intrinsicResTy = inferIntrinsicResultType(
609 typeConverter->convertType(op->getResultTypes()[0]));
610
611 // Bitcast the sparse metadata from vector<2xi16> to an i32.
612 Value sparseMetadata = adaptor.getSparseMetadata();
613 if (sparseMetadata.getType() != VectorType::get(2, rewriter.getI16Type()))
614 return op->emitOpError() << "Expected metadata type to be LLVM "
615 "VectorType of 2 i16 elements";
616 sparseMetadata =
617 LLVM::BitcastOp::create(b, rewriter.getI32Type(), sparseMetadata);
618
619 FailureOr<LLVM::InlineAsmOp> intrinsicResult = emitMmaSparseSyncOpAsm(
620 b, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB,
621 matC, sparseMetadata, op.getSparsitySelector(), op.getMmaShapeAsArray(),
622 intrinsicResTy);
623 if (failed(intrinsicResult))
624 return failure();
625
626 assert((*intrinsicResult).getNumResults() == 1 &&
627 "expected inline asm op returns a single LLVM struct type");
628 rewriter.replaceOp(
629 op, convertIntrinsicResult(op.getLoc(), intrinsicResTy, desiredRetTy,
630 (*intrinsicResult)->getResult(0), rewriter));
631 return success();
632 }
633};
634
635struct NVGPUAsyncCopyLowering
636 : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCopyOp> {
637 using ConvertOpToLLVMPattern<
638 nvgpu::DeviceAsyncCopyOp>::ConvertOpToLLVMPattern;
639
640 LogicalResult
641 matchAndRewrite(nvgpu::DeviceAsyncCopyOp op, OpAdaptor adaptor,
642 ConversionPatternRewriter &rewriter) const override {
643 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
644 Location loc = op.getLoc();
645 auto dstMemrefType = cast<MemRefType>(op.getDst().getType());
646 Value dstPtr =
647 getStridedElementPtr(rewriter, b.getLoc(), dstMemrefType,
648 adaptor.getDst(), adaptor.getDstIndices());
649 FailureOr<unsigned> dstAddressSpace =
650 getTypeConverter()->getMemRefAddressSpace(dstMemrefType);
651 if (failed(dstAddressSpace))
652 return rewriter.notifyMatchFailure(
653 loc, "destination memref address space not convertible to integer");
654
655 auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
656 FailureOr<unsigned> srcAddressSpace =
657 getTypeConverter()->getMemRefAddressSpace(srcMemrefType);
658 if (failed(srcAddressSpace))
659 return rewriter.notifyMatchFailure(
660 loc, "source memref address space not convertible to integer");
661
662 Value scrPtr =
663 getStridedElementPtr(rewriter, loc, srcMemrefType, adaptor.getSrc(),
664 adaptor.getSrcIndices());
665 // The intrinsic takes a global pointer, so we need an address space cast.
666 auto srcPointerGlobalType = LLVM::LLVMPointerType::get(
667 op->getContext(), static_cast<unsigned>(NVVM::NVVMMemorySpace::Global));
668 scrPtr = LLVM::AddrSpaceCastOp::create(b, srcPointerGlobalType, scrPtr);
669 int64_t dstElements = adaptor.getDstElements().getZExtValue();
670 int64_t sizeInBytes =
671 (dstMemrefType.getElementTypeBitWidth() * dstElements) / 8;
672 // When the optional SrcElements argument is *not* present, the regular
673 // CpAsyncOp is generated. CpAsyncOp reads bytes from the source (global
674 // memory) to fill DstElements number of elements in the destination
675 // (shared memory).
676 Value srcBytes = adaptor.getSrcElements();
677 if (srcBytes) {
678 // When the optional SrcElements argument is present, the source (global
679 // memory) of CpAsyncOp is read only for SrcElements number of elements.
680 // The rest of the DstElements in the destination (shared memory) are
681 // filled with zeros.
682 Value c3I32 =
683 LLVM::ConstantOp::create(b, b.getI32Type(), b.getI32IntegerAttr(3));
684 Value bitwidth = LLVM::ConstantOp::create(
685 b, b.getI32Type(),
686 b.getI32IntegerAttr(srcMemrefType.getElementTypeBitWidth()));
687 Value srcElementsI32 = LLVM::TruncOp::create(b, b.getI32Type(), srcBytes);
688 srcBytes = LLVM::LShrOp::create(
689 b, LLVM::MulOp::create(b, bitwidth, srcElementsI32), c3I32);
690 }
691 // Cache global (.cg) for 16 dst bytes, Cache all (.ca) for sizes other than
692 // 16 dst bytes.
693 NVVM::LoadCacheModifierKind cacheModifier =
694 (op.getBypassL1().value_or(false) && sizeInBytes == 16)
695 ? NVVM::LoadCacheModifierKind::CG
696 : NVVM::LoadCacheModifierKind::CA;
697
698 NVVM::CpAsyncOp::create(
699 b, dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes),
700 NVVM::LoadCacheModifierKindAttr::get(op->getContext(), cacheModifier),
701 srcBytes);
702
703 // Drop the result token.
704 Value zero =
705 LLVM::ConstantOp::create(b, IntegerType::get(op.getContext(), 32),
706 rewriter.getI32IntegerAttr(0));
707 rewriter.replaceOp(op, zero);
708 return success();
709 }
710};
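// Worked example (not from the original source): copying 8 f16 elements with
// bypassL1 = true gives sizeInBytes = (16 * 8) / 8 = 16, so the .cg (cache
// global) modifier is chosen; any other size, or bypassL1 = false, falls back
// to .ca (cache all). When srcElements is provided, the number of bytes
// actually read is computed in-IR as (elementBitWidth * srcElements) >> 3,
// which is what the Mul/LShr sequence above emits.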
711
712struct NVGPUAsyncCreateGroupLowering
713 : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCreateGroupOp> {
714 using ConvertOpToLLVMPattern<
715 nvgpu::DeviceAsyncCreateGroupOp>::ConvertOpToLLVMPattern;
716
717 LogicalResult
718 matchAndRewrite(nvgpu::DeviceAsyncCreateGroupOp op, OpAdaptor adaptor,
719 ConversionPatternRewriter &rewriter) const override {
720 NVVM::CpAsyncCommitGroupOp::create(rewriter, op.getLoc());
721 // Drop the result token.
722 Value zero = LLVM::ConstantOp::create(rewriter, op->getLoc(),
723 IntegerType::get(op.getContext(), 32),
724 rewriter.getI32IntegerAttr(0));
725 rewriter.replaceOp(op, zero);
726 return success();
727 }
728};
729
730struct NVGPUAsyncWaitLowering
731 : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncWaitOp> {
732 using ConvertOpToLLVMPattern<
733 nvgpu::DeviceAsyncWaitOp>::ConvertOpToLLVMPattern;
734
735 LogicalResult
736 matchAndRewrite(nvgpu::DeviceAsyncWaitOp op, OpAdaptor adaptor,
737 ConversionPatternRewriter &rewriter) const override {
738 // If numGroups is not present, pick 0 as a conservatively correct value.
739 int32_t numGroups = adaptor.getNumGroups().value_or(0);
740 NVVM::CpAsyncWaitGroupOp::create(rewriter, op.getLoc(), numGroups);
741 rewriter.eraseOp(op);
742 return success();
743 }
744};
745
746/// Creates mbarrier object in shared memory
747struct NVGPUMBarrierCreateLowering
748 : public ConvertOpToLLVMPattern<nvgpu::MBarrierCreateOp> {
749 using ConvertOpToLLVMPattern<nvgpu::MBarrierCreateOp>::ConvertOpToLLVMPattern;
750
751 template <typename moduleT>
752 memref::GlobalOp generateGlobalBarrier(ConversionPatternRewriter &rewriter,
753 Operation *funcOp, moduleT moduleOp,
754 MemRefType barrierType) const {
755 SymbolTable symbolTable(moduleOp);
756 OpBuilder::InsertionGuard guard(rewriter);
757 rewriter.setInsertionPoint(&moduleOp.front());
758 auto global = memref::GlobalOp::create(
759 rewriter, funcOp->getLoc(), "__mbarrier",
760 /*sym_visibility=*/rewriter.getStringAttr("private"),
761 /*type=*/barrierType,
762 /*initial_value=*/ElementsAttr(),
763 /*constant=*/false,
764 /*alignment=*/rewriter.getI64IntegerAttr(8));
765 symbolTable.insert(global);
766 return global;
767 }
768
769 LogicalResult
770 matchAndRewrite(nvgpu::MBarrierCreateOp op, OpAdaptor adaptor,
771 ConversionPatternRewriter &rewriter) const override {
772 Operation *funcOp = op->getParentOp();
773 MemRefType barrierType = nvgpu::getMBarrierMemrefType(
774 rewriter.getContext(), op.getBarriers().getType());
775
776 memref::GlobalOp global;
777 if (auto moduleOp = funcOp->getParentOfType<gpu::GPUModuleOp>())
778 global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
779 else if (auto moduleOp = funcOp->getParentOfType<ModuleOp>())
780 global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
781
782 rewriter.setInsertionPoint(op);
783 rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(op, barrierType,
784 global.getName());
785 return success();
786 }
787};
788
789/// Base class for lowering mbarrier operations to nvvm intrinsics.
790template <typename SourceOp>
791struct MBarrierBasePattern : public ConvertOpToLLVMPattern<SourceOp> {
792public:
793 using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
794 /// Returns the base pointer of the mbarrier object.
795 Value getMbarrierPtr(ImplicitLocOpBuilder &b,
796 nvgpu::MBarrierGroupType mbarType, Value memrefDesc,
797 Value mbarId,
798 ConversionPatternRewriter &rewriter) const {
799 MemRefType mbarrierMemrefType =
800 nvgpu::getMBarrierMemrefType(rewriter.getContext(), mbarType);
801 return ConvertToLLVMPattern::getStridedElementPtr(
802 rewriter, b.getLoc(), mbarrierMemrefType, memrefDesc, {mbarId});
803 }
804};
805
806struct NVGPUMBarrierGetLowering
807 : public MBarrierBasePattern<nvgpu::MBarrierGetOp> {
808 using MBarrierBasePattern<nvgpu::MBarrierGetOp>::MBarrierBasePattern;
809
810 LogicalResult
811 matchAndRewrite(nvgpu::MBarrierGetOp op, OpAdaptor adaptor,
812 ConversionPatternRewriter &rewriter) const override {
813 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
814 nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType();
815 rewriter.setInsertionPoint(op);
816 Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(),
817 adaptor.getMbarId(), rewriter);
818 Type resType = op.getMbarrierPointer().getType();
819 rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(op, resType, barrier);
820 return success();
821 }
822};
823
824/// Lowers `nvgpu.mbarrier.init` to `nvvm.mbarrier.init`
825struct NVGPUMBarrierInitLowering
826 : public MBarrierBasePattern<nvgpu::MBarrierInitOp> {
827 using MBarrierBasePattern<nvgpu::MBarrierInitOp>::MBarrierBasePattern;
828
829 LogicalResult
830 matchAndRewrite(nvgpu::MBarrierInitOp op, OpAdaptor adaptor,
831 ConversionPatternRewriter &rewriter) const override {
832 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
833 nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType();
834 rewriter.setInsertionPoint(op);
835 Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(),
836 adaptor.getMbarId(), rewriter);
837 Value count = truncToI32(b, adaptor.getCount());
838 rewriter.replaceOpWithNewOp<NVVM::MBarrierInitOp>(op, barrier, count,
839 adaptor.getPredicate());
840 return success();
841 }
842};
843
844/// Lowers `nvgpu.mbarrier.arrive` to `nvvm.mbarrier.arrive`
845struct NVGPUMBarrierArriveLowering
846 : public MBarrierBasePattern<nvgpu::MBarrierArriveOp> {
847 using MBarrierBasePattern<nvgpu::MBarrierArriveOp>::MBarrierBasePattern;
848 LogicalResult
849 matchAndRewrite(nvgpu::MBarrierArriveOp op, OpAdaptor adaptor,
850 ConversionPatternRewriter &rewriter) const override {
851 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
852 Value barrier =
853 getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
854 adaptor.getMbarId(), rewriter);
855 Type tokenType = getTypeConverter()->convertType(
856 nvgpu::MBarrierTokenType::get(op->getContext()));
857 rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveOp>(op, tokenType, barrier);
858 return success();
859 }
860};
861
862/// Lowers `nvgpu.mbarrier.arrive.nocomplete` to
863/// `nvvm.mbarrier.arrive.nocomplete`
864struct NVGPUMBarrierArriveNoCompleteLowering
865 : public MBarrierBasePattern<nvgpu::MBarrierArriveNoCompleteOp> {
866 using MBarrierBasePattern<
867 nvgpu::MBarrierArriveNoCompleteOp>::MBarrierBasePattern;
868 LogicalResult
869 matchAndRewrite(nvgpu::MBarrierArriveNoCompleteOp op, OpAdaptor adaptor,
870 ConversionPatternRewriter &rewriter) const override {
871 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
872 Value barrier =
873 getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
874 adaptor.getMbarId(), rewriter);
875 Type tokenType = getTypeConverter()->convertType(
876 nvgpu::MBarrierTokenType::get(op->getContext()));
877 Value count = truncToI32(b, adaptor.getCount());
878 rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteOp>(
879 op, tokenType, barrier, count);
880 return success();
881 }
882};
883
884/// Lowers `nvgpu.mbarrier.test.wait` to `nvvm.mbarrier.test.wait`
885struct NVGPUMBarrierTestWaitLowering
886 : public MBarrierBasePattern<nvgpu::MBarrierTestWaitOp> {
887 using MBarrierBasePattern<nvgpu::MBarrierTestWaitOp>::MBarrierBasePattern;
888 LogicalResult
889 matchAndRewrite(nvgpu::MBarrierTestWaitOp op, OpAdaptor adaptor,
890 ConversionPatternRewriter &rewriter) const override {
891 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
892 Value barrier =
893 getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
894 adaptor.getMbarId(), rewriter);
895 Type retType = rewriter.getI1Type();
896 rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitOp>(op, retType, barrier,
897 adaptor.getToken());
898 return success();
899 }
900};
901
902struct NVGPUMBarrierArriveExpectTxLowering
903 : public MBarrierBasePattern<nvgpu::MBarrierArriveExpectTxOp> {
904 using MBarrierBasePattern<
905 nvgpu::MBarrierArriveExpectTxOp>::MBarrierBasePattern;
906 LogicalResult
907 matchAndRewrite(nvgpu::MBarrierArriveExpectTxOp op, OpAdaptor adaptor,
908 ConversionPatternRewriter &rewriter) const override {
909 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
910 Value barrier =
911 getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
912 adaptor.getMbarId(), rewriter);
913 Value txcount = truncToI32(b, adaptor.getTxcount());
914 rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>(
915 op, Type{}, // return-value is optional and is void by default
916 barrier, txcount, // barrier and txcount
917 NVVM::MemScopeKind::CTA, // default scope is CTA
918 false, // relaxed-semantics is false
919 adaptor.getPredicate());
920 return success();
921 }
922};
923
924struct NVGPUMBarrierTryWaitParityLowering
925 : public MBarrierBasePattern<nvgpu::MBarrierTryWaitParityOp> {
926 using MBarrierBasePattern<
927 nvgpu::MBarrierTryWaitParityOp>::MBarrierBasePattern;
928 LogicalResult
929 matchAndRewrite(nvgpu::MBarrierTryWaitParityOp op, OpAdaptor adaptor,
930 ConversionPatternRewriter &rewriter) const override {
931 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
932 Value barrier =
933 getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
934 adaptor.getMbarId(), rewriter);
935 Value ticks = truncToI32(b, adaptor.getTicks());
936 Value phase =
937 LLVM::ZExtOp::create(b, b.getI32Type(), adaptor.getPhaseParity());
938 rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParityOp>(op, barrier,
939 phase, ticks);
940 return success();
941 }
942};
943
944struct NVGPUTmaAsyncLoadOpLowering
945 : public MBarrierBasePattern<nvgpu::TmaAsyncLoadOp> {
946 using MBarrierBasePattern<nvgpu::TmaAsyncLoadOp>::MBarrierBasePattern;
947 LogicalResult
948 matchAndRewrite(nvgpu::TmaAsyncLoadOp op, OpAdaptor adaptor,
949 ConversionPatternRewriter &rewriter) const override {
950 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
951 auto srcMemrefType = cast<MemRefType>(op.getDst().getType());
952 Value dest = getStridedElementPtr(rewriter, op->getLoc(), srcMemrefType,
953 adaptor.getDst(), {});
954 // The intrinsic takes a shared-cluster pointer, so we need an
955 // address space cast from 3 to 7.
956 // TODO: Introduce AS(7) in NVGPU.
957 auto ptrSharedClusterType = LLVM::LLVMPointerType::get(
958 op->getContext(),
959 static_cast<unsigned>(NVVM::NVVMMemorySpace::SharedCluster));
960 dest = LLVM::AddrSpaceCastOp::create(b, ptrSharedClusterType, dest);
961
962 Value barrier =
963 getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
964 adaptor.getMbarId(), rewriter);
965
966 SmallVector<Value> coords = adaptor.getCoordinates();
967 for (auto [index, value] : llvm::enumerate(coords)) {
968 coords[index] = truncToI32(b, value);
969 }
970
971 // TODO: Enhance the NVGPU Op for other modes too
972 rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(
973 op, dest, adaptor.getTensorMapDescriptor(), coords, barrier,
974 ValueRange{}, adaptor.getMulticastMask(), Value{},
975 NVVM::TMALoadMode::TILE, // default is TILE mode
976 false, // default is cluster-scope
977 nullptr, // default is no cta-group
978 adaptor.getPredicate());
979 return success();
980 }
981};
982
983struct NVGPUTmaAsyncStoreOpLowering
984 : public MBarrierBasePattern<nvgpu::TmaAsyncStoreOp> {
985 using MBarrierBasePattern<nvgpu::TmaAsyncStoreOp>::MBarrierBasePattern;
986 LogicalResult
987 matchAndRewrite(nvgpu::TmaAsyncStoreOp op, OpAdaptor adaptor,
988 ConversionPatternRewriter &rewriter) const override {
989 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
990 auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
991 Value dest = getStridedElementPtr(rewriter, op->getLoc(), srcMemrefType,
992 adaptor.getSrc(), {});
993 SmallVector<Value> coords = adaptor.getCoordinates();
994 for (auto [index, value] : llvm::enumerate(coords)) {
995 coords[index] = truncToI32(b, value);
996 }
997
998 // TODO: Enhance the NVGPU Op for other modes too
999 rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(
1000 op, adaptor.getTensorMapDescriptor(), dest, coords, Value{},
1001 NVVM::TMAStoreMode::TILE, // default is TILE mode
1002 adaptor.getPredicate());
1003 return success();
1004 }
1005};
1006
1007struct NVGPUGenerateWarpgroupDescriptorLowering
1008 : public ConvertOpToLLVMPattern<nvgpu::WarpgroupGenerateDescriptorOp> {
1009 using ConvertOpToLLVMPattern<
1010 nvgpu::WarpgroupGenerateDescriptorOp>::ConvertOpToLLVMPattern;
1011
1012 LogicalResult
1013 matchAndRewrite(nvgpu::WarpgroupGenerateDescriptorOp op, OpAdaptor adaptor,
1014 ConversionPatternRewriter &rewriter) const override {
1015
1016 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1017
1018 nvgpu::TensorMapSwizzleKind swizzleKind =
1019 op.getTensorMap().getType().getSwizzle();
1020
1021 unsigned layout =
1022 (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 128
1023 : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 64
1024 : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 32
1025 : 1;
1026 unsigned swizzle =
1027 (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 1
1028 : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 2
1029 : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 3
1030 : 0;
1031
1032 auto ti64 = b.getIntegerType(64);
1033 auto makeConst = [&](uint64_t index) -> Value {
1034 return LLVM::ConstantOp::create(b, ti64, b.getI64IntegerAttr(index));
1035 };
1036 auto shiftLeft = [&](Value value, unsigned shift) -> Value {
1037 return LLVM::ShlOp::create(b, ti64, value, makeConst(shift));
1038 };
1039 auto shiftRight = [&](Value value, unsigned shift) -> Value {
1040 return LLVM::LShrOp::create(b, ti64, value, makeConst(shift));
1041 };
1042 auto insertBit = [&](Value desc, Value val, int startBit) {
1043 return LLVM::OrOp::create(b, ti64, desc, shiftLeft(val, startBit));
1044 };
1045
1046 int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0);
1047 uint64_t strideDimVal = (layout << 3) >> exclude4LSB;
1048 uint64_t leadDimVal = (sizeN * layout) >> exclude4LSB;
1049 uint64_t offsetVal = 0;
1050
1051 Value strideDim = makeConst(strideDimVal);
1052 Value leadDim = makeConst(leadDimVal);
1053
1054 Value baseAddr = getStridedElementPtr(
1055 rewriter, op->getLoc(), cast<MemRefType>(op.getTensor().getType()),
1056 adaptor.getTensor(), {});
1057 Value basePtr = LLVM::PtrToIntOp::create(b, ti64, baseAddr);
1058 // Keep only 14 bits of the base address (address bits [17:4]).
1059 Value basePtr14bit = shiftRight(shiftLeft(basePtr, 46), 50);
1060
1061 int startSwizzleBit = 62, startOffsetBit = 49, startStrideBit = 32,
1062 startLeadBit = 16, startBaseAddrBit = 0;
1063 Value dsc = makeConst(0);
1064 // [62,64) swizzle type
1065 dsc = insertBit(dsc, makeConst(swizzle), startSwizzleBit);
1066 // [49,52) base_offset
1067 dsc = insertBit(dsc, makeConst(offsetVal), startOffsetBit);
1068 // [32,46) stride
1069 dsc = insertBit(dsc, strideDim, startStrideBit);
1070 // [16,30) leading dimension
1071 dsc = insertBit(dsc, leadDim, startLeadBit);
1072 // [0,14) start_address
1073 dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit);
1074
1075 LDBG() << "Generating warpgroup.descriptor: " << "leading_off:"
1076 << leadDimVal << "\t" << "stride_off :" << strideDimVal << "\t"
1077 << "base_offset:" << offsetVal << "\t" << "layout_type:" << swizzle
1078 << " (" << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
1079 << ")\n start_addr : " << baseAddr;
1080
1081 rewriter.replaceOp(op, dsc);
1082 return success();
1083 }
1084};
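// Illustrative sketch (not part of the original lowering): the same 64-bit
// descriptor layout expressed as plain integer arithmetic, assuming a 128-byte
// swizzle (swizzle = 1, layout = 128), sizeN = 128, and a hypothetical
// shared-memory byte address `baseAddr`:
//
//   uint64_t swizzle = 1;                      // [62,64) swizzle kind
//   uint64_t offset = 0;                       // [49,52) base offset
//   uint64_t stride = (128ull << 3) >> 4;      // [32,46) stride = 64
//   uint64_t lead = (128ull * 128ull) >> 4;    // [16,30) leading dim = 1024
//   uint64_t base14 = (baseAddr << 46) >> 50;  // [0,14)  address bits [17:4]
//   uint64_t desc = (swizzle << 62) | (offset << 49) | (stride << 32) |
//                   (lead << 16) | base14;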
1085
1086static Value makeI64Const(ImplicitLocOpBuilder &b, int32_t index) {
1087 return LLVM::ConstantOp::create(b, b.getIntegerType(64),
1088 b.getI32IntegerAttr(index));
1089}
1090
1091/// Returns a Value that holds the data type enum expected by the CUDA driver.
1092static Value elementTypeAsLLVMConstant(ImplicitLocOpBuilder &b, Type type) {
1093 // Enum is from CUDA driver API
1094 // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html
1095 enum CUtensorMapDataTypeEnum {
1096 CU_TENSOR_MAP_DATA_TYPE_UINT8 = 0,
1097 CU_TENSOR_MAP_DATA_TYPE_UINT16,
1098 CU_TENSOR_MAP_DATA_TYPE_UINT32,
1099 CU_TENSOR_MAP_DATA_TYPE_INT32,
1100 CU_TENSOR_MAP_DATA_TYPE_UINT64,
1101 CU_TENSOR_MAP_DATA_TYPE_INT64,
1102 CU_TENSOR_MAP_DATA_TYPE_FLOAT16,
1103 CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
1104 CU_TENSOR_MAP_DATA_TYPE_FLOAT64,
1105 CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
1106 CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ,
1107 CU_TENSOR_MAP_DATA_TYPE_TFLOAT32,
1108 CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ
1109 };
1110
1111 if (type.isUnsignedInteger(8))
1112 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT8);
1113 if (type.isUnsignedInteger(16))
1114 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT16);
1115 if (type.isUnsignedInteger(32))
1116 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT32);
1117 if (type.isUnsignedInteger(64))
1118 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT64);
1119 if (type.isSignlessInteger(32))
1120 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT32);
1121 if (type.isSignlessInteger(64))
1122 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT64);
1123 if (type.isF16())
1124 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT16);
1125 if (type.isF32())
1126 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT32);
1127 if (type.isF64())
1128 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT64);
1129 if (type.isBF16())
1130 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16);
1131
1132 llvm_unreachable("Not supported data type");
1133}
1134
1135struct NVGPUTmaCreateDescriptorOpLowering
1136 : public ConvertOpToLLVMPattern<nvgpu::TmaCreateDescriptorOp> {
1137 using ConvertOpToLLVMPattern<
1138 nvgpu::TmaCreateDescriptorOp>::ConvertOpToLLVMPattern;
1139 LogicalResult
1140 matchAndRewrite(nvgpu::TmaCreateDescriptorOp op, OpAdaptor adaptor,
1141 ConversionPatternRewriter &rewriter) const override {
1142 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1143 auto llvmPointerType = LLVM::LLVMPointerType::get(op->getContext());
1144 Type llvmInt64Type = IntegerType::get(op->getContext(), 64);
1145
1146 Value tensorElementType =
1147 elementTypeAsLLVMConstant(b, op.getTensor().getType().getElementType());
1148 auto promotedOperands = getTypeConverter()->promoteOperands(
1149 b.getLoc(), op->getOperands(), adaptor.getOperands(), b);
1150
1151 Value boxArrayPtr = LLVM::AllocaOp::create(
1152 b, llvmPointerType, llvmInt64Type, makeI64Const(b, 5));
1153 for (auto [index, value] : llvm::enumerate(adaptor.getBoxDimensions())) {
1154 Value gep = LLVM::GEPOp::create(b, llvmPointerType, llvmPointerType,
1155 boxArrayPtr, makeI64Const(b, index));
1156 LLVM::StoreOp::create(b, value, gep);
1157 }
1158
1159 nvgpu::TensorMapDescriptorType desc = op.getTensorMap().getType();
1160 // Set Arguments for the function call
1161 SmallVector<Value> arguments;
1162 arguments.push_back(promotedOperands[0]); // rank
1163 arguments.push_back(promotedOperands[1]); // descriptor
1164 arguments.push_back(tensorElementType); // data type
1165 arguments.push_back(
1166 makeI64Const(b, (int)desc.getInterleave())); // interleave
1167 arguments.push_back(makeI64Const(b, (int)desc.getSwizzle())); // swizzle
1168 arguments.push_back(makeI64Const(b, (int)desc.getL2promo())); // l2promo
1169 arguments.push_back(makeI64Const(b, (int)desc.getOob())); // oob
1170 arguments.push_back(boxArrayPtr); // box dimensions
1171
1172 // Set data types of the arguments
1173 SmallVector<Type> argTypes = {
1174 llvmInt64Type, /* int64_t tensorRank */
1175 llvmPointerType, /* ptr */
1176 llvmInt64Type, /* int64_t */
1177 llvmInt64Type, /* int64_t */
1178 llvmInt64Type, /* int64_t */
1179 llvmInt64Type, /* int64_t */
1180 llvmInt64Type, /* int64_t */
1181 llvmPointerType /* ptr */
1182 };
1183 FunctionCallBuilder hostRegisterCallBuilder = {
1184 "mgpuTensorMapEncodeTiledMemref", llvmPointerType, argTypes};
1185 Value tensorMap =
1186 hostRegisterCallBuilder.create(b.getLoc(), b, arguments).getResult();
1187
1188 rewriter.replaceOp(op, tensorMap);
1189 return success();
1190 }
1191};
1192
1193struct NVGPUWarpgroupMmaOpLowering
1194 : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp> {
1195 using ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp>::ConvertOpToLLVMPattern;
1196
1197 /// This is a helper class to generate required NVVM Ops for warp-group level
1198 /// matrix multiplication.
1199 /// When the given GEMM shape is larger than the shape of
1200 /// a wgmma instruction in PTX, it can generate multiple NVVM::WgmmaMmaAsyncOp
1201 /// Op(s), group them, and execute them asynchronously. The class also handles
1202 /// waiting for completion and iterates through WarpgroupMatrixDescriptor to
1203 /// create descriptors for each instruction.
1204 ///
1205 /// For example this is the case when the shape of GEMM is 128x128x128
1206 ///
1207 /// nvvm.wgmma.fence.aligned
1208 ///
1209 /// nvvm.wgmma.mma.async descA, descB
1210 /// iterate(descA, descB)
1211 /// nvvm.wgmma.mma.async descA, descB
1212 /// [6x times more]
1213 ///
1214 /// nvvm.wgmma.group.sync.aligned
1215 /// nvvm.wgmma.wait.group.sync [groupId]
1216 ///
1217 class WarpgroupGemm {
1218 nvgpu::WarpgroupMmaOp op;
1219 ImplicitLocOpBuilder b;
1220 OpAdaptor adaptor;
1221
1222 // Entire shape of the given Op
1223 int64_t totalM, totalN, totalK;
1224
1225 // Shape of one wgmma instruction
1226 int wgmmaM = 0, wgmmaN = 0, wgmmaK = 0;
1227
1228 // Iteration counts for GEMM
1229 int iterationM = 0, iterationN = 0, iterationK = 0;
1230
1231 /// Computes the shape of a single wgmma instruction, as defined in the
1232 /// PTX programming guide.
1233 /// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shape
1234 void findWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType) {
1235 wgmmaM = 64;
1236 wgmmaN = sizeN;
1237 if (inputElemType.isTF32()) {
1238 wgmmaK = 8;
1239 } else if (inputElemType.isF16() || inputElemType.isBF16()) {
1240 wgmmaK = 16;
1241 } else if (isa<Float8E4M3FNType, Float8E5M2Type>(inputElemType) ||
1242 inputElemType.isInteger(16)) {
1243 wgmmaK = 32;
1244 } else if (inputElemType.isInteger(1)) {
1245 wgmmaK = 256;
1246 } else {
1247 llvm_unreachable("msg: not supported K shape");
1248 }
1249 LDBG() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
1250 << ", n = " << wgmmaN << ", k = " << wgmmaK << "]";
1251 }
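// Worked example (not from the original source): for a 128x128x64 GEMM with
// f16 inputs, the selected single-instruction shape is m64.n128.k16, so the
// surrounding loops issue iterationM = 128/64 = 2, iterationN = 128/128 = 1,
// and iterationK = 64/16 = 4 wgmma operations.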
1252
1253 /// Generates WGMMATypesAttr from MLIR Type
1254 NVVM::WGMMATypesAttr generateWgmmaType(Type type,
1255 bool useF32 = false) const {
1256 auto getWgmmaType = [=](Type elemType) {
1257 if (elemType.isF32() || elemType.isTF32())
1258 return useF32 ? NVVM::WGMMATypes::f32 : NVVM::WGMMATypes::tf32;
1259 if (elemType.isF16())
1260 return NVVM::WGMMATypes::f16;
1261 if (elemType.isBF16())
1262 return NVVM::WGMMATypes::bf16;
1263 if (isa<Float8E4M3FNType>(elemType))
1264 return NVVM::WGMMATypes::e4m3;
1265 if (isa<Float8E5M2Type>(elemType))
1266 return NVVM::WGMMATypes::e5m2;
1267 if (elemType.isInteger(1))
1268 return NVVM::WGMMATypes::b1;
1269 if (elemType.isInteger(8))
1270 return NVVM::WGMMATypes::s8;
1271 if (elemType.isUnsignedInteger(8))
1272 return NVVM::WGMMATypes::u8;
1273 if (elemType.isInteger(32))
1274 return NVVM::WGMMATypes::s32;
1275 llvm_unreachable("unsupported type");
1276 };
1277 return NVVM::WGMMATypesAttr::get(op->getContext(), getWgmmaType(type));
1278 }
1279
1280 /// Generates layout attribute for the input matrix for wgmma instruction
1281 NVVM::MMALayoutAttr
1282 generateWgmmaLayout(std::optional<bool> transpose) const {
1283 if (transpose.value_or(false))
1284 return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::col);
1285 return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::row);
1286 }
1287
1288 /// Generates shape attribute for wgmma instruction
1289 NVVM::MMAShapeAttr generateWgmmaShape() const {
1290 return NVVM::MMAShapeAttr::get(op->getContext(), wgmmaM, wgmmaN, wgmmaK);
1291 }
1292
1293 /// Generates scale attributes of output matrix for wgmma instruction
1294 NVVM::WGMMAScaleOutAttr generateScaleOut() const {
1295 return NVVM::WGMMAScaleOutAttr::get(op->getContext(),
1296 NVVM::WGMMAScaleOut::one);
1297 }
1298 /// Generates scale attributes of input matrix for wgmma instruction
1299 NVVM::WGMMAScaleInAttr generateScaleIn() const {
1300 return NVVM::WGMMAScaleInAttr::get(op->getContext(),
1301 NVVM::WGMMAScaleIn::one);
1302 }
1303
1304 /// Basic function to generate Add
1305 Value makeAdd(Value lhs, Value rhs) {
1306 return LLVM::AddOp::create(b, lhs.getType(), lhs, rhs);
1307 };
1308
1309 /// Moves the descriptor pointer of matrix-A for the next wgmma instruction.
1310 /// Currently, it only handles row-major.
1311 ///
1312 /// It moves the pointer like below for [128][64] size:
1313 /// +2 +4 +6
1314 /// ↓ ↓ ↓
1315 /// descA ---> +--+--+--+--+
1316 /// |->|->|->|->|
1317 /// | | | | |
1318 /// | | | | |
1319 /// | | | | |
1320 /// descA+512---> +-----------+
1321 /// | | | | |
1322 /// | | | | |
1323 /// | | | | |
1324 /// | | | | |
1325 /// +-----------+
1326 ///
1327 Value iterateDescriptorA(Value desc, int i, int j, int k) {
1328 MemRefType matrixTypeA = op.getDescriptorA().getType().getTensor();
1329 Type elemA = matrixTypeA.getElementType();
1330 int byte = elemA.getIntOrFloatBitWidth() / 8;
1331 int tileShapeA = matrixTypeA.getDimSize(1);
1332 int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte;
1333 incrementVal = incrementVal >> exclude4LSB;
1334 LDBG() << "\t\t[m: " << i << " n: " << j << " k: " << k
1335 << "] [wgmma descriptors] Descriptor A + " << incrementVal
1336 << " | \t ";
1337 if (!incrementVal)
1338 return desc;
1339 return makeAdd(desc, makeI64Const(b, incrementVal));
1340 }
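// Worked example (not from the original source): with f16 elements (byte = 2),
// tileShapeA = 64, totalK = 64, and i = 0, the increment is
// ((16 * k) * 2) >> 4, i.e. +2, +4, +6 for k = 1, 2, 3, which matches the
// offsets shown in the diagram above (in units of 16 bytes).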
1341
1342 /// Moves the descriptor pointer of matrix-B for the next wgmma instruction.
1343 /// Currently, it only handles column-major.
1344 ///
1345 /// It moves the pointer like below for [128][64] size:
1346 /// descB ---> +--+--+--+--+--+--+--+--+
1347 /// |↓ | | | | | | | |
1348 /// |↓ | | | | | | | |
1349 /// |↓ | | | | | | | |
1350 /// |↓ | | | | | | | |
1351 /// +--+--+--+--+--+--+--+--+
1352 ///
1353 Value iterateDescriptorB(Value desc, int i, int j, int k) {
1354 MemRefType matrixTypeB = op.getDescriptorB().getType().getTensor();
1355 Type elemB = matrixTypeB.getElementType();
1356 int byte = elemB.getIntOrFloatBitWidth() / 8;
1357 int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte;
1358 incrementVal = incrementVal >> exclude4LSB;
1359 LDBG() << "Descriptor B + " << incrementVal;
1360 if (!incrementVal)
1361 return desc;
1362 return makeAdd(desc, makeI64Const(b, incrementVal));
1363 }
1364
1365 /// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix
1366 /// descriptors and arranges them based on induction variables: i, j, and k.
1367 Value generateWgmma(int i, int j, int k, Value matrixC) {
1368 LDBG() << "\t wgmma." << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
1369 << "(A[" << (iterationM * wgmmaM) << ":"
1370 << (iterationM * wgmmaM) + wgmmaM << "][" << (iterationK * wgmmaK)
1371 << ":" << (iterationK * wgmmaK + wgmmaK) << "] * " << " B["
1372 << (iterationK * wgmmaK) << ":" << (iterationK * wgmmaK + wgmmaK)
1373 << "][" << 0 << ":" << wgmmaN << "])";
1374
1375 Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k);
1376 Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k);
1377
1378 Type elemA = op.getDescriptorA().getType().getTensor().getElementType();
1379 NVVM::WGMMATypesAttr itypeA = generateWgmmaType(elemA);
1380
1381 Type elemB = op.getDescriptorB().getType().getTensor().getElementType();
1382 NVVM::WGMMATypesAttr itypeB = generateWgmmaType(elemB);
1383
1384 Type elemD = op.getMatrixC().getType().getFragmented().getElementType();
1385 NVVM::WGMMATypesAttr itypeD = generateWgmmaType(elemD, true);
1386
1387 NVVM::MMAShapeAttr shape = generateWgmmaShape();
1388 NVVM::WGMMAScaleOutAttr scaleOut = generateScaleOut();
1389 NVVM::WGMMAScaleInAttr scaleIn = generateScaleIn();
1390 NVVM::MMALayoutAttr layoutA = generateWgmmaLayout(op.getTransposeA());
1391 NVVM::MMALayoutAttr layoutB = generateWgmmaLayout(!op.getTransposeB());
1392
1393 auto overflow = NVVM::MMAIntOverflowAttr::get(
1394 op->getContext(), NVVM::MMAIntOverflow::wrapped);
1395
1396 return NVVM::WgmmaMmaAsyncOp::create(
1397 b, matrixC.getType(), matrixC, descriptorA, descriptorB, shape,
1398 itypeA, itypeB, itypeD, scaleOut, scaleIn, scaleIn, layoutA, layoutB,
1399 overflow);
1400 }
1401
1402 /// Generates multiple wgmma instructions to complete the given GEMM shape
1403 Value generateWgmmaGroup() {
1404 Value wgmmaResult =
1405 LLVM::PoisonOp::create(b, adaptor.getMatrixC().getType());
1406
1407 // Perform GEMM
1408 SmallVector<Value> wgmmaResults;
1409 for (int i = 0; i < iterationM; ++i) {
1410 Value matrixC =
1411 LLVM::ExtractValueOp::create(b, adaptor.getMatrixC(), i);
1412 for (int j = 0; j < iterationN; ++j)
1413 for (int k = 0; k < iterationK; ++k)
1414 matrixC = generateWgmma(i, j, k, matrixC);
1415 wgmmaResults.push_back(matrixC);
1416 }
1417 for (auto [idx, matrix] : llvm::enumerate(wgmmaResults)) {
1418 wgmmaResult = LLVM::InsertValueOp::create(b, wgmmaResult.getType(),
1419 wgmmaResult, matrix, idx);
1420 }
1421 return wgmmaResult;
1422 }
1423
1424 public:
1425 WarpgroupGemm(nvgpu::WarpgroupMmaOp op, ImplicitLocOpBuilder &b,
1426 OpAdaptor adaptor)
1427 : op(op), b(b), adaptor(adaptor) {
1428 // Find the entire GEMM Shape
1429 totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
1430 totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
1431 totalK = op.getDescriptorA().getType().getTensor().getDimSize(1);
1432 LDBG() << "===--- GEMM D[" << totalM << "][" << totalN << "] += A["
1433 << totalM << "][" << totalK << "] * B[" << totalK << "][" << totalN
1434 << "] ---===";
1435
1436 // Find the shape for one wgmma instruction
1437 findWgmmaShape(
1438 totalM, totalN,
1439 op.getDescriptorA().getType().getTensor().getElementType());
1440
1441 // Iteration counts needed to cover the given GEMM shape with the wgmma shape
1442 iterationM = totalM / wgmmaM;
1443 iterationN = totalN / wgmmaN;
1444 iterationK = totalK / wgmmaK;
1445 }
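
  // As a concrete (assumed) example with f16 inputs for
  //   D[128][128] += A[128][64] * B[64][128]
  // and a wgmma tile of m64 n128 k16, the counts become
  //   iterationM = 128 / 64 = 2, iterationN = 128 / 128 = 1,
  //   iterationK =  64 / 16 = 4,
  // so generateWgmmaGroup() emits 2 * 1 * 4 = 8 WgmmaMmaAsyncOps in total,
  // chained through the accumulator of each m-row.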
1446
1447  /// Generates WgmmaMmaAsync ops to complete the specified GEMM shape. It
1448  /// emits a fence op (WgmmaFenceAlignedOp) before the wgmma instructions,
1449  /// and, after them, a group synchronization (WgmmaGroupSyncAlignedOp)
1450  /// followed by a wait for that group to complete
1451  /// (WgmmaWaitGroupSyncOp).
1452 Value generateWarpgroupMma() {
1453 NVVM::WgmmaFenceAlignedOp::create(b);
1454 Value wgmmaResult = generateWgmmaGroup();
1455 NVVM::WgmmaGroupSyncAlignedOp::create(b);
1456 NVVM::WgmmaWaitGroupSyncOp::create(b, op.getWaitGroup());
1457 return wgmmaResult;
1458 }
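
  // At the PTX level the emitted sequence corresponds roughly to:
  //   wgmma.fence.sync.aligned;
  //   wgmma.mma_async.sync.aligned...   // one per generated WgmmaMmaAsyncOp
  //   wgmma.commit_group.sync.aligned;
  //   wgmma.wait_group.sync.aligned N;  // N = op.getWaitGroup()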
1459 };
1460 LogicalResult
1461 matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
1462 ConversionPatternRewriter &rewriter) const override {
1463 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1464
1465 // Step 1. Build a helper class
1466 WarpgroupGemm warpgroupGemm(op, b, adaptor);
1467
1468 // Step 2. Generate wgmma instructions covering the entire GEMM shape
1469 Value wgmmaResult = warpgroupGemm.generateWarpgroupMma();
1470
1471 // Step 3. Replace fragmented result struct with the op results
1472 rewriter.replaceOp(op, wgmmaResult);
1473 return success();
1474 }
1475};
1476
1477struct NVGPUWarpgroupMmaStoreOpLowering
1478 : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaStoreOp> {
1479 using ConvertOpToLLVMPattern<
1480 nvgpu::WarpgroupMmaStoreOp>::ConvertOpToLLVMPattern;
1481
1482  /// This function stores a fragmented register matrix owned by a warp group
1483  /// (128 threads) into a memref. Each thread holds 64 registers, carried as
1484  /// the elements of a struct.
1485  /// Here is what each thread (T) holds; each `d` is a struct element with
1486  /// its index.
1487 ///
1488  /// Threads in the warp group (128 threads) and what they own in matrix D:
1489  /// 0-31 Warp-0 -> MatrixD[0:15 ][0:N]
1490  /// 32-63 Warp-1 -> MatrixD[16:31][0:N]
1491  /// 64-95 Warp-2 -> MatrixD[32:47][0:N]
1492  /// 96-127 Warp-3 -> MatrixD[48:63][0:N]
1493 ///
1494 /// Matrix-D:
1495 /// +______________________________________________________________________+
1496 /// | 0-1 | 2-3 | 4-5 | 6-7 | 8-9 | 10-11|..|N-8,N-7 |
1497 /// 0 | T0:d0-d1 |T1:d0-d1 |T2:d0-d1 |T3:d0-d1 |T0:d4-d5| T1:d4-d5..|T0:dX-dY|
1498 /// 1 | T4:d0-d1 |T5:d0-d1 |T6:d0-d1 |T7:d0-d1 |T4:d4-d5| T5:d4-d5..|T4:dX-dY|
1499 /// ..| .........|.........|.........|.........|........|...........|........|
1500 /// 8 | T0:d2-d3 |T1:d2-d3 |T2:d2-d3 |T3:d2-d3 |T0:d6-d7|T1:d6-d7,..|T0:dZ-dW|
1501 /// 9 | T4:d2-d3 |T5:d2-d3 |T6:d2-d3 |T7:d2-d3 |T4:d6-d7| T5:d6-d7..|T4:dZ-dW|
1502 /// ..| .........|.........|.........|.........|........|...........|........|
1503 /// 15| T28:d2-d3|T29:d2-d3|T30:d2-d3|T31:d2-d3|........|...........|........|
1504 /// 16| T32:d2-d3|T33:d2-d3|T34:d2-d3|T35:d2-d3|........|...........|........|
1505 /// ..| .........|.........|.........|.........|........|...........|........|
1506 /// 32| T64:d2-d3|T65:d2-d3|T66:d2-d3|T67:d2-d3|........|...........|........|
1507 /// ..| .........|.........|.........|.........|........|...........|........|
1508 /// 48| T96:d2-d3|T97:d2-d3|T98:d2-d3|T99:d2-d3|........|...........|........|
1509 /// ..| .........|.........|.........|.........|........|...........|........|
1510 /// +______________________________________________________________________+
1511 ///
1512  /// \param b: The builder used to create the store ops.
1513  /// \param matrixD: Result of the warp-group MMA operation (fragmented
1514  /// matrix). It is held by each thread as a struct with 64 elements.
1515  /// \param dstMemref: The memref where the registers will be stored.
1516  /// \param offset: The offset within the memref where the registers will be
1517  /// stored.
1518 void storeFragmentedMatrix(ImplicitLocOpBuilder &b, Value matrixD,
1519 TypedValue<MemRefType> dstMemref,
1520 int offset) const {
1521 Type i32 = b.getI32Type();
1522
1523 auto makeConst = [&](int32_t index) -> Value {
1524 return LLVM::ConstantOp::create(b, i32, b.getI32IntegerAttr(index));
1525 };
1526 Value c1 = makeConst(1);
1527 Value c2 = makeConst(2);
1528 Value c4 = makeConst(4);
1529 Value c8 = makeConst(8);
1530 Value c16 = makeConst(16);
1531 Value warpSize = makeConst(kWarpSize);
1532
1533 auto makeMul = [&](Value lhs, Value rhs) -> Value {
1534 return LLVM::MulOp::create(b, lhs.getType(), lhs, rhs);
1535 };
1536 auto makeAdd = [&](Value lhs, Value rhs) -> Value {
1537 return LLVM::AddOp::create(b, lhs.getType(), lhs, rhs);
1538 };
1539
1540     auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,
1541                                    TypedValue<::mlir::MemRefType> memref) {
1542       Type it = b.getIndexType();
1543 Value idx = arith::IndexCastOp::create(b, it, x);
1544 Value idy0 = arith::IndexCastOp::create(b, it, y);
1545 Value idy1 = arith::IndexCastOp::create(b, it, makeAdd(y, c1));
1546 Value d0 = LLVM::ExtractValueOp::create(b, wgmmaResult, i);
1547 Value d1 = LLVM::ExtractValueOp::create(b, wgmmaResult, i + 1);
1548 memref::StoreOp::create(b, d0, memref, ValueRange{idx, idy0});
1549 memref::StoreOp::create(b, d1, memref, ValueRange{idx, idy1});
1550 };
1551
1552 Value tidx = NVVM::ThreadIdXOp::create(b, i32);
1553 Value laneId = LLVM::URemOp::create(b, i32, tidx, warpSize);
1554 Value warpId = LLVM::UDivOp::create(b, i32, tidx, warpSize);
1555 Value lane4Id = LLVM::UDivOp::create(b, i32, laneId, c4);
1556 Value lane4modId = LLVM::URemOp::create(b, i32, laneId, c4);
1557
1558 Value tj = makeMul(lane4modId, c2);
1559 Value ti = makeAdd(lane4Id, makeMul(warpId, c16));
1560 if (offset)
1561 ti = makeAdd(ti, makeConst(offset));
1562
1563 auto structType = cast<LLVM::LLVMStructType>(matrixD.getType());
1564
1565     // Number of 32-bit registers owned per thread
1566     constexpr unsigned numAdjacentRegisters = 2;
1567     // Number of 8x8 matrices stacked one below another per warp
1568     constexpr unsigned numStackedMatrices = 2;
1569
1570 size_t storeCount = (structType.getBody().size() /
1571 (numStackedMatrices * numAdjacentRegisters));
1572
1573 for (size_t i = 0; i < numStackedMatrices; ++i) {
1574 Value idx = makeAdd(ti, makeMul(makeConst(i), c8));
1575 for (size_t j = 0; j < storeCount; ++j) {
1576 Value idy = makeAdd(tj, makeMul(makeConst(j), c8));
1577 size_t structIndex = (i * numAdjacentRegisters) +
1578 (j * (numStackedMatrices * numAdjacentRegisters));
1579 makeExtractAndStore(structIndex, matrixD, idx, idy, dstMemref);
1580 }
1581 }
1582 }
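
  // A worked example of the index math above: for thread id 37 (lane 5 of
  // Warp-1),
  //   warpId = 37 / 32 = 1, laneId = 5, lane4Id = 5 / 4 = 1, lane4modId = 1,
  //   ti = 1 + 1 * 16 = 17, tj = 1 * 2 = 2,
  // so (ignoring `offset`) the thread stores d0/d1 to dstMemref[17][2..3],
  // d2/d3 to dstMemref[25][2..3], d4/d5 to dstMemref[17][10..11], and so on,
  // staying inside Warp-1's row range [16:31] from the table above.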
1583
1584 LogicalResult
1585 matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor,
1586 ConversionPatternRewriter &rewriter) const override {
1587 int offset = 0;
1588 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1589     Value matrixDValue = adaptor.getMatrixD();
1590     auto stype = cast<LLVM::LLVMStructType>(matrixDValue.getType());
1591     for (auto [idx, matrixD] : llvm::enumerate(stype.getBody())) {
1592       auto structType = cast<LLVM::LLVMStructType>(matrixD);
1593       Value innerStructValue =
1594           LLVM::ExtractValueOp::create(b, matrixDValue, idx);
1595 storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset);
1596 offset += structType.getBody().size();
1597 }
1598 rewriter.eraseOp(op);
1599 return success();
1600 }
1601};
1602
1603struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
1604 : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaInitAccumulatorOp> {
1605 using ConvertOpToLLVMPattern<
1606 nvgpu::WarpgroupMmaInitAccumulatorOp>::ConvertOpToLLVMPattern;
1607 LogicalResult
1608 matchAndRewrite(nvgpu::WarpgroupMmaInitAccumulatorOp op, OpAdaptor adaptor,
1609 ConversionPatternRewriter &rewriter) const override {
1610 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1611 LLVM::LLVMStructType packStructType = cast<LLVM::LLVMStructType>(
1612 getTypeConverter()->convertType(op.getMatrixC().getType()));
1613 Type elemType = cast<LLVM::LLVMStructType>(packStructType.getBody().front())
1614 .getBody()
1615 .front();
1616 Value zero = LLVM::ConstantOp::create(b, elemType, b.getZeroAttr(elemType));
1617 Value packStruct = LLVM::PoisonOp::create(b, packStructType);
1618 SmallVector<Value> innerStructs;
1619 // Unpack the structs and set all values to zero
1620 for (auto [idx, s] : llvm::enumerate(packStructType.getBody())) {
1621 auto structType = cast<LLVM::LLVMStructType>(s);
1622 Value structValue = LLVM::ExtractValueOp::create(b, packStruct, idx);
1623 for (unsigned i = 0; i < structType.getBody().size(); ++i) {
1624 structValue = LLVM::InsertValueOp::create(b, structType, structValue,
1625 zero, ArrayRef<int64_t>({i}));
1626 }
1627 innerStructs.push_back(structValue);
1628 }
1629 // Pack the inner structs into a single struct
1630 for (auto [idx, matrix] : llvm::enumerate(innerStructs)) {
1631 packStruct = LLVM::InsertValueOp::create(b, packStruct.getType(),
1632 packStruct, matrix, idx);
1633 }
1634 rewriter.replaceOp(op, packStruct);
1635 return success();
1636 }
1637};
1638
1639struct NVGPUTmaFenceOpLowering
1640 : public ConvertOpToLLVMPattern<nvgpu::TmaFenceOp> {
1641 using ConvertOpToLLVMPattern<nvgpu::TmaFenceOp>::ConvertOpToLLVMPattern;
1642 LogicalResult
1643 matchAndRewrite(nvgpu::TmaFenceOp op, OpAdaptor adaptor,
1644 ConversionPatternRewriter &rewriter) const override {
1645 MLIRContext *ctx = op.getContext();
1646 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1647 auto i32Ty = b.getI32Type();
1648 Value tensormapSize =
1649 LLVM::ConstantOp::create(b, i32Ty, rewriter.getI32IntegerAttr(128));
1650
1651 auto memscope =
1652 NVVM::MemScopeKindAttr::get(ctx, ::mlir::NVVM::MemScopeKind::SYS);
1653
1654 rewriter.replaceOpWithNewOp<NVVM::FenceProxyAcquireOp>(
1655 op, memscope, adaptor.getTensorMapDescriptor(), tensormapSize);
1656
1657 return success();
1658 }
1659};
1660
1661struct NVGPUTmaPrefetchOpLowering
1662 : public ConvertOpToLLVMPattern<nvgpu::TmaPrefetchOp> {
1663 using ConvertOpToLLVMPattern<nvgpu::TmaPrefetchOp>::ConvertOpToLLVMPattern;
1664 LogicalResult
1665 matchAndRewrite(nvgpu::TmaPrefetchOp op, OpAdaptor adaptor,
1666 ConversionPatternRewriter &rewriter) const override {
1667 rewriter.replaceOpWithNewOp<NVVM::PrefetchOp>(
1668 op, /* CacheLevel */ nullptr, /* Cache Eviction Priority */ nullptr,
1669 adaptor.getTensorMapDescriptor(), adaptor.getPredicate(),
1670 /* Tensormap UnitAttr */ mlir::UnitAttr::get(op.getContext()));
1671 return success();
1672 }
1673};
1674
1675struct NVGPURcpOpLowering : public ConvertOpToLLVMPattern<nvgpu::RcpOp> {
1676 using ConvertOpToLLVMPattern<nvgpu::RcpOp>::ConvertOpToLLVMPattern;
1677 LogicalResult
1678 matchAndRewrite(nvgpu::RcpOp op, OpAdaptor adaptor,
1679 ConversionPatternRewriter &rewriter) const override {
1680 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1681 auto i64Ty = b.getI64Type();
1682 auto f32Ty = b.getF32Type();
1683 VectorType inTy = op.getIn().getType();
1684     // Apply rcp.approx.ftz.f on each element of the vector.
1685 auto convert1DVec = [&](Type llvm1DVectorTy, Value inVec) {
1686 Value ret1DVec = LLVM::PoisonOp::create(b, llvm1DVectorTy);
1687 int numElems = llvm::cast<VectorType>(llvm1DVectorTy).getNumElements();
1688 for (int i = 0; i < numElems; i++) {
1689 Value idx = LLVM::ConstantOp::create(b, i64Ty, b.getI64IntegerAttr(i));
1690 Value elem = LLVM::ExtractElementOp::create(b, inVec, idx);
1691 Value dst = NVVM::RcpApproxFtzF32Op::create(b, f32Ty, elem);
1692 ret1DVec = LLVM::InsertElementOp::create(b, ret1DVec, dst, idx);
1693 }
1694 return ret1DVec;
1695 };
1696 if (inTy.getRank() == 1) {
1697 rewriter.replaceOp(op, convert1DVec(inTy, adaptor.getIn()));
1698 return success();
1699     }
1700     return LLVM::detail::handleMultidimensionalVectors(
1701         op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()),
1702 [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
1703 OpAdaptor adaptor(operands);
1704 return convert1DVec(llvm1DVectorTy, adaptor.getIn());
1705 },
1706 rewriter);
1707 }
1708};
1709} // namespace
1710
1712 TypeConverter &typeConverter) {
1713 // NVVM uses alloca in the default address space to represent private
1714 // memory allocations, so drop private annotations. NVVM uses address
1715 // space 3 for shared memory. NVVM uses the default address space to
1716   // represent global memory.
1717   populateGpuMemorySpaceAttributeConversions(
1718       typeConverter, [](gpu::AddressSpace space) -> unsigned {
1719 switch (space) {
1720 case gpu::AddressSpace::Global:
1721 return static_cast<unsigned>(NVVM::NVVMMemorySpace::Global);
1722 case gpu::AddressSpace::Workgroup:
1723 return static_cast<unsigned>(NVVM::NVVMMemorySpace::Shared);
1724 case gpu::AddressSpace::Private:
1725 return 0;
1726 }
1727 llvm_unreachable("unknown address space enum value");
1728 });
1729}
1730
1731 void mlir::populateNVGPUToNVVMConversionPatterns(
1732     const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
1733 patterns.add<
1734 NVGPUMBarrierCreateLowering, // nvgpu.mbarrier.create
1735 NVGPUMBarrierInitLowering, // nvgpu.mbarrier.init
1736 NVGPUMBarrierGetLowering, // nvgpu.mbarrier.get
1737 NVGPUMBarrierArriveLowering, // nvgpu.mbarrier.arrive
1738 NVGPUMBarrierArriveNoCompleteLowering, // nvgpu.mbarrier.arrive.no_complete
1739 NVGPUMBarrierTestWaitLowering, // nvgpu.mbarrier.test_wait_parity
1740 NVGPUMBarrierTryWaitParityLowering, // nvgpu.mbarrier.try_wait_parity
1741 NVGPUTmaAsyncLoadOpLowering, // nvgpu.tma.async.load
1742 NVGPUTmaAsyncStoreOpLowering, // nvgpu.tma.async.store
1743 NVGPUTmaCreateDescriptorOpLowering, // nvgpu.tma.create.descriptor
1744 NVGPUTmaPrefetchOpLowering, // nvgpu.tma.prefetch.descriptor
1745 NVGPUTmaFenceOpLowering, // nvgpu.tma.fence.descriptor
1746 NVGPUMBarrierArriveExpectTxLowering, // nvgpu.mbarrier.arrive.expect_tx
1747 NVGPUGenerateWarpgroupDescriptorLowering, // nvgpu.warpgroup.generate.descriptor
1748 NVGPUWarpgroupMmaOpLowering, // nvgpu.warpgroup.mma
1749 NVGPUWarpgroupMmaStoreOpLowering, // nvgpu.warpgroup.mma.store
1750 NVGPUWarpgroupMmaInitAccumulatorOpLowering, // nvgpu.warpgroup.mma.init.accumulator
1751 MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
1752 NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
1753 NVGPUMmaSparseSyncLowering, NVGPURcpOpLowering>(converter);
1754}