1//===- NVGPUToNVVM.cpp - NVGPU to NVVM dialect conversion -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
26#include "mlir/IR/Value.h"
27#include "mlir/Pass/Pass.h"
28#include "llvm/Support/Debug.h"
29#include "llvm/Support/DebugLog.h"
30#include "llvm/Support/ErrorHandling.h"
31#include "llvm/Support/raw_ostream.h"
32#include <optional>
33
34#define DEBUG_TYPE "nvgpu-to-nvvm"
35
36namespace mlir {
37#define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS
38#include "mlir/Conversion/Passes.h.inc"
39} // namespace mlir
40
41using namespace mlir;
42
43/// Number of bits that need to be excluded when building the matrix descriptor
44/// for wgmma operations.
45constexpr int exclude4LSB = 4;
46
47/// GPUs have 32-bit registers; this function truncates values when a larger
48/// width is not needed.
49static Value truncToI32(ImplicitLocOpBuilder &b, Value value) {
50 Type type = value.getType();
51 assert(llvm::isa<IntegerType>(type) && "expected an integer Value");
52 if (type.getIntOrFloatBitWidth() <= 32)
53 return value;
54 return LLVM::TruncOp::create(b, b.getI32Type(), value);
55}
56
57/// Returns the type for the intrinsic given the vectorResultType of the
58/// `gpu.mma.sync` operation.
59static Type inferIntrinsicResultType(Type vectorResultType) {
60 MLIRContext *ctx = vectorResultType.getContext();
61 auto a = cast<LLVM::LLVMArrayType>(vectorResultType);
62 auto f16x2Ty = VectorType::get(2, Float16Type::get(ctx));
63 auto i32Ty = IntegerType::get(ctx, 32);
64 auto i32x2Ty = VectorType::get(2, i32Ty);
65 Type f64Ty = Float64Type::get(ctx);
66 Type f64x2Ty = VectorType::get(2, f64Ty);
67 Type f32Ty = Float32Type::get(ctx);
68 Type f32x2Ty = VectorType::get(2, f32Ty);
69 if (a.getElementType() == f16x2Ty) {
70 return LLVM::LLVMStructType::getLiteral(
71 ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty));
72 }
73 if (a.getElementType() == i32x2Ty) {
74 return LLVM::LLVMStructType::getLiteral(
75 ctx,
76 SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, i32Ty));
77 }
78 if (a.getElementType() == f64x2Ty) {
79 return LLVM::LLVMStructType::getLiteral(ctx, {f64Ty, f64Ty});
80 }
81 if (a.getElementType() == f32x2Ty) {
82 return LLVM::LLVMStructType::getLiteral(
83 ctx,
84 SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, f32Ty));
85 }
86 if (a.getElementType() == VectorType::get(1, f32Ty)) {
87 return LLVM::LLVMStructType::getLiteral(
88 ctx, SmallVector<Type>(static_cast<size_t>(a.getNumElements()), f32Ty));
89 }
90 return vectorResultType;
91}
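// For example, a sketch of the mapping implemented above (element counts follow
// the cases in this function):
//   !llvm.array<2 x vector<2xf16>> -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
//   !llvm.array<2 x vector<2xi32>> -> !llvm.struct<(i32, i32, i32, i32)>
//   !llvm.array<1 x vector<2xf64>> -> !llvm.struct<(f64, f64)>
//   !llvm.array<2 x vector<2xf32>> -> !llvm.struct<(f32, f32, f32, f32)>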
92
93/// Convert the SSA result of the NVVM intrinsic `nvvm.mma.sync` (which is
94/// always an LLVM struct) into a fragment that is compatible with the vector
95/// type of this operation. This involves extracting elements from the struct
96/// and inserting them into an LLVM array. These extra data-movement
97/// operations should be canonicalized away by the LLVM backend.
98static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
99 Type resultType, Value intrinsicResult,
100 RewriterBase &rewriter) {
101 MLIRContext *ctx = rewriter.getContext();
102 auto structType = dyn_cast<LLVM::LLVMStructType>(intrinsicResultType);
103 auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType);
104 Type i32Ty = rewriter.getI32Type();
105 Type f32Ty = rewriter.getF32Type();
106 Type f64Ty = rewriter.getF64Type();
107 Type f16x2Ty = VectorType::get(2, rewriter.getF16Type());
108 Type i32x2Ty = VectorType::get(2, i32Ty);
109 Type f64x2Ty = VectorType::get(2, f64Ty);
110 Type f32x2Ty = VectorType::get(2, f32Ty);
111 Type f32x1Ty = VectorType::get(1, f32Ty);
112
113 auto makeConst = [&](int32_t index) -> Value {
114 return LLVM::ConstantOp::create(rewriter, loc, IntegerType::get(ctx, 32),
115 rewriter.getI32IntegerAttr(index));
116 };
117
118 if (arrayType) {
119 SmallVector<Value, 4> elements;
120
121 // The intrinsic returns 32-bit wide elements in a form which can be
122 // directly bitcasted and inserted into the result vector.
123 if (arrayType.getElementType() == f16x2Ty ||
124 arrayType.getElementType() == f32x1Ty) {
125 for (unsigned i = 0; i < structType.getBody().size(); i++) {
126 Value el =
127 LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult, i);
128 el = rewriter.createOrFold<LLVM::BitcastOp>(
129 loc, arrayType.getElementType(), el);
130 elements.push_back(el);
131 }
132 }
133
134 // The intrinsic returns i32, f64, and f32 values as individual scalars,
135 // even when the result is notionally a 64-bit wide element (e.g. f32x2). We
136 // need to extract them from the struct and pack them into the 64-bit wide
137 // rows of the vector result.
138 if (arrayType.getElementType() == i32x2Ty ||
139 arrayType.getElementType() == f64x2Ty ||
140 arrayType.getElementType() == f32x2Ty) {
141
142 for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) {
143 Value vec =
144 LLVM::PoisonOp::create(rewriter, loc, arrayType.getElementType());
145 Value x1 =
146 LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult, i * 2);
147 Value x2 = LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult,
148 i * 2 + 1);
149 vec = LLVM::InsertElementOp::create(rewriter, loc, vec.getType(), vec,
150 x1, makeConst(0));
151 vec = LLVM::InsertElementOp::create(rewriter, loc, vec.getType(), vec,
152 x2, makeConst(1));
153 elements.push_back(vec);
154 }
155 }
156
157 // Create the final vectorized result.
158 Value result = LLVM::PoisonOp::create(rewriter, loc, arrayType);
159 for (const auto &el : llvm::enumerate(elements)) {
160 result = LLVM::InsertValueOp::create(rewriter, loc, result, el.value(),
161 el.index());
162 }
163 return result;
164 }
165
166 return intrinsicResult;
167}
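// For example, a sketch: an intrinsic result of type
// !llvm.struct<(i32, i32, i32, i32)> with a desired result type
// !llvm.array<2 x vector<2xi32>> is rebuilt by extracting the four scalars and
// inserting them pairwise into two vector<2xi32> rows of the array.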
168
169/// The `gpu.mma.sync` converter below expects matrix fragment operands to be
170/// given as 2D `vectors` where the rows are 32b or 64b wide. The
171/// `nvvm.mma.sync` op expects these arguments to be given as a long list of
172/// scalars of certain types. This function helps unpack the `vector` arguments
173/// and cast them to the types expected by `nvvm.mma.sync`.
174static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b,
175 Value operand,
176 NVVM::MMATypes operandPtxType) {
177 SmallVector<Value> result;
178 Type i32Ty = b.getI32Type();
179 Type f64Ty = b.getF64Type();
180 Type f32Ty = b.getF32Type();
181 Type i64Ty = b.getI64Type();
182 Type bf16x2Ty = VectorType::get(2, b.getBF16Type());
183 Type i8x4Ty = VectorType::get(4, b.getI8Type());
184 Type i4x8Ty = VectorType::get(8, b.getIntegerType(4));
185 Type f32x1Ty = VectorType::get(1, f32Ty);
186 auto arrayTy = cast<LLVM::LLVMArrayType>(operand.getType());
187
188 for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
189 Value toUse = LLVM::ExtractValueOp::create(b, operand, i);
190
191 // For 4xi8 vectors, the intrinsic expects these to be provided as i32
192 // scalar types.
193 if (arrayTy.getElementType() == i8x4Ty ||
194 arrayTy.getElementType() == i4x8Ty ||
195 (arrayTy.getElementType() == bf16x2Ty &&
196 operandPtxType == NVVM::MMATypes::bf16) ||
197 (arrayTy.getElementType() == f32x1Ty &&
198 operandPtxType == NVVM::MMATypes::tf32)) {
199 result.push_back(LLVM::BitcastOp::create(b, i32Ty, toUse));
200 continue;
201 }
202
203 // For some element types (i32, f32, f64), we need to unpack the inner
204 // vector/array type as well because the intrinsic expects individual
205 // scalars to be provided.
206 VectorType innerArrayTy = dyn_cast<VectorType>(arrayTy.getElementType());
207 if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty ||
208 innerArrayTy.getElementType() == f64Ty ||
209 innerArrayTy.getElementType() == f32Ty)) {
210 for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
211 idx < innerSize; idx++) {
212 result.push_back(LLVM::ExtractElementOp::create(
213 b, toUse,
214 LLVM::ConstantOp::create(b, i64Ty, b.getI64IntegerAttr(idx))));
215 }
216 continue;
217 }
218 result.push_back(toUse);
219 }
220 return result;
221}
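// For example, a sketch of the unpacking above: an operand of type
// !llvm.array<2 x vector<4xi8>> with operandPtxType = s8 becomes two i32 values
// (one bitcast per vector<4xi8> element), while !llvm.array<2 x vector<2xf32>>
// becomes four individual f32 scalars.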
222
223/// Returns whether the mbarrier object has a shared memory address space.
224static bool isMbarrierShared(nvgpu::MBarrierGroupType barrierType) {
225 return (mlir::nvgpu::NVGPUDialect::isSharedMemoryAddressSpace(
226 barrierType.getMemorySpace()));
227}
228
229/// Returns the memory space attribute of the mbarrier object.
230Attribute nvgpu::getMbarrierMemorySpace(MLIRContext *context,
231 nvgpu::MBarrierGroupType barrierType) {
232 Attribute memorySpace = {};
233 if (isMbarrierShared(barrierType)) {
234 memorySpace =
235 IntegerAttr::get(IntegerType::get(context, 64),
236 nvgpu::NVGPUDialect::kSharedMemoryAddressSpace);
237 }
238 return memorySpace;
239}
240
241/// Returns the memref type of the mbarrier object. The type is defined by the
242/// MBarrierGroupType.
243MemRefType nvgpu::getMBarrierMemrefType(MLIRContext *context,
244 nvgpu::MBarrierGroupType barrierType) {
245 Attribute memorySpace = nvgpu::getMbarrierMemorySpace(context, barrierType);
246 MemRefLayoutAttrInterface layout;
247 return MemRefType::get({barrierType.getNumBarriers()},
248 IntegerType::get(context, 64), layout, memorySpace);
249}
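// For example, a sketch assuming a shared-memory barrier group with
// num_barriers = 4: the returned type is memref<4xi64, 3>, where 3 stands for
// kSharedMemoryAddressSpace.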
250
251namespace {
252
253struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
254 using ConvertOpToLLVMPattern<nvgpu::LdMatrixOp>::ConvertOpToLLVMPattern;
255
256 LogicalResult
257 matchAndRewrite(nvgpu::LdMatrixOp op, OpAdaptor adaptor,
258 ConversionPatternRewriter &rewriter) const override {
259 MLIRContext *ctx = getContext();
260 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
261
262 // The result type of ldmatrix will always be a struct of 32-bit integer
263 // registers if more than one 32-bit value is returned. Otherwise, the result
264 // is a single i32. The result type of the GPU operation is always a vector
265 // of shape (NumRegisters, VectorRegister) where VectorRegister is the
266 // vector type of the result and is always 32 bits wide. We bitcast the result
267 // of the NVVM::LdMatrixOp to this vector type.
268 auto vectorResultType = dyn_cast<VectorType>(op->getResultTypes()[0]);
269 if (!vectorResultType) {
270 return failure();
271 }
272 Type innerVectorType = VectorType::get(vectorResultType.getDimSize(1),
273 vectorResultType.getElementType());
274
275 int64_t num32BitRegs = vectorResultType.getDimSize(0);
276
277 Type ldMatrixResultType;
278 if (num32BitRegs > 1) {
279 ldMatrixResultType = LLVM::LLVMStructType::getLiteral(
280 ctx, SmallVector<Type>(num32BitRegs, rewriter.getI32Type()));
281 } else {
282 ldMatrixResultType = rewriter.getI32Type();
283 }
284
285 auto srcMemrefType = cast<MemRefType>(op.getSrcMemref().getType());
286 Value srcPtr =
287 getStridedElementPtr(rewriter, b.getLoc(), srcMemrefType,
288 adaptor.getSrcMemref(), adaptor.getIndices());
289 auto shape = NVVM::LdStMatrixShapeAttr::get(rewriter.getContext(), 8, 8);
290 Value ldMatrixResult = NVVM::LdMatrixOp::create(
291 b, ldMatrixResultType, srcPtr,
292 /*num=*/op.getNumTiles(),
293 /*layout=*/op.getTranspose() ? NVVM::MMALayout::col
294 : NVVM::MMALayout::row,
295 /*shape=*/shape, /*eltType=*/NVVM::LdStMatrixEltType::B16);
296
297 // The ldmatrix operation returns either a single i32 value or a struct of
298 // i32 values. Here we unpack those values and cast them back to their
299 // actual vector type (still of width 32b) and repack them into a result
300 // struct.
301 Type finalResultType = typeConverter->convertType(vectorResultType);
302 Value result = LLVM::PoisonOp::create(b, finalResultType);
303 for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) {
304 Value i32Register =
305 num32BitRegs > 1 ? LLVM::ExtractValueOp::create(b, ldMatrixResult, i)
306 : ldMatrixResult;
307 Value casted = LLVM::BitcastOp::create(b, innerVectorType, i32Register);
308 result = LLVM::InsertValueOp::create(b, result, casted, i);
309 }
310
311 rewriter.replaceOp(op, result);
312 return success();
313 }
314};
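// For example, a sketch of this lowering: an `nvgpu.ldmatrix` with numTiles = 4
// producing vector<4x2xf16> becomes an `nvvm.ldmatrix` that returns
// !llvm.struct<(i32, i32, i32, i32)>; each member is bitcast back to
// vector<2xf16> and packed into !llvm.array<4 x vector<2xf16>>.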
315
316/// Convert the given type into the corresponding PTX type (NVVM::MMATypes
317/// enum).
318static FailureOr<NVVM::MMATypes> getNvvmMmaType(Type t) {
319 Type elType = getElementTypeOrSelf(t);
320 if (elType.isInteger(8))
321 return NVVM::MMATypes::s8;
322 if (elType.isInteger(4))
323 return NVVM::MMATypes::s4;
324 if (elType.isF16())
325 return NVVM::MMATypes::f16;
326 if (elType.isBF16())
327 return NVVM::MMATypes::bf16;
328 if (elType.isF64())
329 return NVVM::MMATypes::f64;
330 if (elType.isF32())
331 return NVVM::MMATypes::tf32;
332 return failure();
333}
334
335struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
336 using ConvertOpToLLVMPattern<nvgpu::MmaSyncOp>::ConvertOpToLLVMPattern;
337
338 LogicalResult
339 matchAndRewrite(nvgpu::MmaSyncOp op, OpAdaptor adaptor,
340 ConversionPatternRewriter &rewriter) const override {
341 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
342 // Get the shapes of the MMAMatrix type being used. The shapes determine
343 // which intrinsic this op will be lowered to.
344 VectorType aType = op.getMatrixA().getType();
345 VectorType bType = op.getMatrixB().getType();
346 VectorType cType = op.getMatrixC().getType();
347
348 std::array<int64_t, 3> gemmShape = op.getMmaShapeAsArray();
349
350 // Tensor Cores (mma.sync) on F32 work only with TensorFloat32 (TF32).
351 bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
352 if (aType.getElementType().isF32() && !tf32Enabled)
353 return failure();
354
355 FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
356 if (failed(ptxTypeA))
357 return op->emitOpError("failed to deduce operand PTX types");
358 FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
359 if (failed(ptxTypeB))
360 return op->emitOpError("failed to deduce operand PTX types");
361 std::optional<NVVM::MMATypes> ptxTypeC =
362 NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
363 /*isAccumulator=*/true);
364 if (!ptxTypeC)
365 return op->emitError(
366 "could not infer the PTX type for the accumulator/result");
367
368 // TODO: add an attribute to the op to customize this behavior.
369 std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
370 if (isa<IntegerType>(aType.getElementType()))
371 overflow = NVVM::MMAIntOverflow::satfinite;
372
373 SmallVector<Value> matA =
374 unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
375 SmallVector<Value> matB =
376 unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
377 SmallVector<Value> matC =
378 unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);
379
380 Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
381 Type intrinsicResTy = inferIntrinsicResultType(
382 typeConverter->convertType(op->getResultTypes()[0]));
383 Value intrinsicResult =
384 NVVM::MmaOp::create(b, intrinsicResTy, matA, matB, matC,
385 /*shape=*/gemmShape,
386 /*b1Op=*/std::nullopt,
387 /*intOverflow=*/overflow,
388 /*multiplicandPtxTypes=*/
389 std::array<NVVM::MMATypes, 2>{*ptxTypeA, *ptxTypeB},
390 /*multiplicandLayouts=*/
391 std::array<NVVM::MMALayout, 2>{
392 NVVM::MMALayout::row, NVVM::MMALayout::col});
393 rewriter.replaceOp(op, convertIntrinsicResult(op.getLoc(), intrinsicResTy,
394 desiredRetTy, intrinsicResult,
395 rewriter));
396 return success();
397 }
398};
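// For example, a sketch of this lowering for an f16 `nvgpu.mma.sync` with
// mmaShape = [16, 8, 16] (shape chosen for illustration): the fragments are
// unpacked into vector<2xf16> scalars, an `nvvm.mma.sync` with row-major A and
// column-major B is emitted, and its struct result is repacked into the
// converted accumulator type by convertIntrinsicResult().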
399
400struct ConvertNVGPUToNVVMPass
401 : public impl::ConvertNVGPUToNVVMPassBase<ConvertNVGPUToNVVMPass> {
402 using Base::Base;
403
404 void runOnOperation() override {
405 LowerToLLVMOptions options(&getContext());
406 RewritePatternSet patterns(&getContext());
407 LLVMTypeConverter converter(&getContext(), options);
408 IRRewriter rewriter(&getContext());
410
411 /// device-side async tokens cannot be materialized in nvvm. We just
412 /// convert them to a dummy i32 type in order to easily drop them during
413 /// conversion.
414 converter.addConversion([&](nvgpu::DeviceAsyncTokenType type) -> Type {
415 return converter.convertType(IntegerType::get(type.getContext(), 32));
416 });
417 converter.addConversion([&](nvgpu::WarpgroupAccumulatorType type) -> Type {
418 Type elemType = type.getFragmented().getElementType();
419 int64_t sizeM = type.getFragmented().getDimSize(0);
420 int64_t sizeN = type.getFragmented().getDimSize(1);
421
422 unsigned numMembers;
423 if (elemType.isF32() || elemType.isInteger(32))
424 numMembers = sizeN / 2;
425 else if (elemType.isF16())
426 numMembers = sizeN / 4;
427 else
428 llvm_unreachable("unsupported type for warpgroup accumulator");
429
430 SmallVector<Type> innerStructBody;
431 for (unsigned i = 0; i < numMembers; i++)
432 innerStructBody.push_back(elemType);
433 auto innerStructType =
434 LLVM::LLVMStructType::getLiteral(type.getContext(), innerStructBody);
435
436 SmallVector<Type> structBody;
437 for (int i = 0; i < sizeM; i += kWgmmaSizeM)
438 structBody.push_back(innerStructType);
439
440 auto convertedType =
441 LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
442 return converter.convertType(convertedType);
443 });
444 converter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {
445 return converter.convertType(IntegerType::get(type.getContext(), 64));
446 });
447 converter.addConversion(
448 [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type {
449 return converter.convertType(IntegerType::get(type.getContext(), 64));
450 });
451 converter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type {
452 return converter.convertType(
453 nvgpu::getMBarrierMemrefType(rewriter.getContext(), type));
454 });
455 converter.addConversion([&](nvgpu::TensorMapDescriptorType type) -> Type {
456 return LLVM::LLVMPointerType::get(type.getContext());
457 });
458 populateNVGPUToNVVMConversionPatterns(converter, patterns);
459 LLVMConversionTarget target(getContext());
460 target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
461 target.addLegalDialect<::mlir::arith::ArithDialect>();
462 target.addLegalDialect<::mlir::memref::MemRefDialect>();
463 target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
464 mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
465 converter, patterns, target);
466 if (failed(applyPartialConversion(getOperation(), target,
467 std::move(patterns))))
468 signalPassFailure();
469 }
470};
471
472/// Returns the constraints for the sparse MMA inline assembly instruction.
473static std::string buildMmaSparseAsmConstraintString(unsigned matASize,
474 unsigned matBSize,
475 unsigned matCSize) {
476 std::string str;
477 llvm::raw_string_ostream ss(str);
478 for (unsigned i = 0; i < matCSize; i++)
479 ss << "=r,";
480 for (unsigned i = 0; i < matASize + matBSize + matCSize; i++)
481 ss << "r,";
482 // The final operand is for the sparsity metadata.
483 // The sparsity selector appears as a direct literal.
484 ss << "r";
485 return str;
486}
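// For example, with hypothetical operand counts matASize = 4, matBSize = 2, and
// matCSize = 2, the constraint string produced above is
// "=r,=r,r,r,r,r,r,r,r,r,r" (two outputs, eight register inputs, plus the
// trailing metadata operand).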
487
488/// Returns the string for the `mma.sp.sync` instruction that corresponds to
489/// the given parameters. Note that this function doesn't do any validation;
490/// it's expected that the provided parameters correspond to a valid
491/// instruction.
492static std::string buildMmaSparseAsmString(
493 const std::array<int64_t, 3> &shape, unsigned matASize, unsigned matBSize,
494 unsigned matCSize, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
495 NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
496 std::optional<NVVM::MMAIntOverflow> overflow, unsigned metaDataSelector) {
497 auto ptxTypeStr = [](NVVM::MMATypes ptxType) {
498 return NVVM::stringifyMMATypes(ptxType);
499 };
500
501 std::string asmStr;
502 llvm::raw_string_ostream ss(asmStr);
503 ss << "mma.sp.sync.aligned.m" << shape[0] << "n" << shape[1] << "k"
504 << shape[2] << ".row.col.";
505
506 if (overflow)
507 ss << NVVM::stringifyMMAIntOverflow(*overflow) << ".";
508
509 ss << ptxTypeStr(ptxTypeD) << "." << ptxTypeStr(ptxTypeA) << "."
510 << ptxTypeStr(ptxTypeB) << "." << ptxTypeStr(ptxTypeC) << " ";
511 unsigned asmArgIdx = 0;
512
513 // The operand string is structured into sections `{matC elements...},
514 // {matA elements...}, {matB elements...}, {matC elements}`.
515 for (const auto arrSize : {matCSize, matASize, matBSize, matCSize}) {
516 ss << "{";
517 for (unsigned i = 0; i < arrSize; i++)
518 ss << "$" << asmArgIdx++ << (i < arrSize - 1 ? "," : "");
519 ss << "},";
520 }
521 ss << "$" << asmArgIdx++ << ",";
522 assert(metaDataSelector <= 1);
523 ss << "0x" << metaDataSelector << ";";
524 return asmStr;
525}
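// For example, a sketch with the same hypothetical sizes (matASize = 4,
// matBSize = 2, matCSize = 2), shape m16n8k32, s8/s8/s32 types, satfinite
// overflow, and selector 0; the string built above would be (wrapped here for
// readability):
//   mma.sp.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32
//     {$0,$1},{$2,$3,$4,$5},{$6,$7},{$8,$9},$10,0x0;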
526
527/// Builds an inline assembly operation corresponding to the specified MMA
528/// sparse sync operation.
529static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
530 ImplicitLocOpBuilder &b, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
531 NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
532 std::optional<NVVM::MMAIntOverflow> overflow, ArrayRef<Value> unpackedAData,
533 ArrayRef<Value> unpackedB, ArrayRef<Value> unpackedC, Value indexData,
534 int64_t metadataSelector, const std::array<int64_t, 3> &shape,
535 Type intrinsicResultType) {
536 auto asmDialectAttr =
537 LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_ATT);
538
539 const unsigned matASize = unpackedAData.size();
540 const unsigned matBSize = unpackedB.size();
541 const unsigned matCSize = unpackedC.size();
542
543 std::string asmStr = buildMmaSparseAsmString(
544 shape, matASize, matBSize, matCSize, ptxTypeA, ptxTypeB, ptxTypeC,
545 ptxTypeD, overflow, metadataSelector);
546 std::string constraintStr =
547 buildMmaSparseAsmConstraintString(matASize, matBSize, matCSize);
548
549 SmallVector<Value> asmVals;
550 asmVals.reserve(matASize + matBSize + matCSize + 1);
551 for (ArrayRef<Value> args : {unpackedAData, unpackedB, unpackedC})
552 llvm::append_range(asmVals, args);
553 asmVals.push_back(indexData);
554
555 return LLVM::InlineAsmOp::create(b,
556 /*resultTypes=*/intrinsicResultType,
557 /*operands=*/asmVals,
558 /*asm_string=*/asmStr,
559 /*constraints=*/constraintStr,
560 /*has_side_effects=*/true,
561 /*is_align_stack=*/false,
562 LLVM::TailCallKind::None,
563 /*asm_dialect=*/asmDialectAttr,
564 /*operand_attrs=*/ArrayAttr());
565}
566
567/// Lowers `nvgpu.mma.sp.sync` to inline assembly.
568struct NVGPUMmaSparseSyncLowering
569 : public ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp> {
570 using ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp>::ConvertOpToLLVMPattern;
571
572 LogicalResult
573 matchAndRewrite(nvgpu::MmaSparseSyncOp op, OpAdaptor adaptor,
574 ConversionPatternRewriter &rewriter) const override {
575 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
576 // Get the shapes of the MMAMatrix type being used. The shapes determine
577 // which intrinsic this op will be lowered to.
578 VectorType aType = op.getMatrixA().getType();
579 VectorType bType = op.getMatrixB().getType();
580 VectorType cType = op.getMatrixC().getType();
581
582 FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
583 if (failed(ptxTypeA))
584 return op->emitOpError("failed to deduce operand PTX types");
585 FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
586 if (failed(ptxTypeB))
587 return op->emitOpError("failed to deduce operand PTX types");
588 std::optional<NVVM::MMATypes> ptxTypeC =
589 NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
590 /*isAccumulator=*/true);
591 if (!ptxTypeC)
592 return op->emitError(
593 "could not infer the PTX type for the accumulator/result");
594
595 // Same as `mma.sync`, F32 works only with TensorFloat32 (TF32).
596 bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
597 if (aType.getElementType().isF32() && !tf32Enabled)
598 return failure();
599
600 // TODO: add an attribute to the op to customize this behavior.
601 std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
602 if (isa<IntegerType>(aType.getElementType()))
603 overflow = NVVM::MMAIntOverflow::satfinite;
604
605 SmallVector<Value> matA =
606 unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
607 SmallVector<Value> matB =
608 unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
609 SmallVector<Value> matC =
610 unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);
611
612 Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
613 Type intrinsicResTy = inferIntrinsicResultType(
614 typeConverter->convertType(op->getResultTypes()[0]));
615
616 // Bitcast the sparse metadata from vector<2xi16> to an i32.
617 Value sparseMetadata = adaptor.getSparseMetadata();
618 if (sparseMetadata.getType() != VectorType::get(2, rewriter.getI16Type()))
619 return op->emitOpError() << "Expected metadata type to be LLVM "
620 "VectorType of 2 i16 elements";
621 sparseMetadata =
622 LLVM::BitcastOp::create(b, rewriter.getI32Type(), sparseMetadata);
623
624 FailureOr<LLVM::InlineAsmOp> intrinsicResult = emitMmaSparseSyncOpAsm(
625 b, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB,
626 matC, sparseMetadata, op.getSparsitySelector(), op.getMmaShapeAsArray(),
627 intrinsicResTy);
628 if (failed(intrinsicResult))
629 return failure();
630
631 assert((*intrinsicResult).getNumResults() == 1 &&
632 "expected inline asm op returns a single LLVM struct type");
633 rewriter.replaceOp(
634 op, convertIntrinsicResult(op.getLoc(), intrinsicResTy, desiredRetTy,
635 (*intrinsicResult)->getResult(0), rewriter));
636 return success();
637 }
638};
639
640struct NVGPUAsyncCopyLowering
641 : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCopyOp> {
642 using ConvertOpToLLVMPattern<
643 nvgpu::DeviceAsyncCopyOp>::ConvertOpToLLVMPattern;
644
645 LogicalResult
646 matchAndRewrite(nvgpu::DeviceAsyncCopyOp op, OpAdaptor adaptor,
647 ConversionPatternRewriter &rewriter) const override {
648 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
649 Location loc = op.getLoc();
650 auto dstMemrefType = cast<MemRefType>(op.getDst().getType());
651 Value dstPtr =
652 getStridedElementPtr(rewriter, b.getLoc(), dstMemrefType,
653 adaptor.getDst(), adaptor.getDstIndices());
654 FailureOr<unsigned> dstAddressSpace =
655 getTypeConverter()->getMemRefAddressSpace(dstMemrefType);
656 if (failed(dstAddressSpace))
657 return rewriter.notifyMatchFailure(
658 loc, "destination memref address space not convertible to integer");
659
660 auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
661 FailureOr<unsigned> srcAddressSpace =
662 getTypeConverter()->getMemRefAddressSpace(srcMemrefType);
663 if (failed(srcAddressSpace))
664 return rewriter.notifyMatchFailure(
665 loc, "source memref address space not convertible to integer");
666
667 Value scrPtr =
668 getStridedElementPtr(rewriter, loc, srcMemrefType, adaptor.getSrc(),
669 adaptor.getSrcIndices());
670 // The intrinsic takes a global pointer, so we need an address space cast.
671 auto srcPointerGlobalType = LLVM::LLVMPointerType::get(
672 op->getContext(), static_cast<unsigned>(NVVM::NVVMMemorySpace::Global));
673 scrPtr = LLVM::AddrSpaceCastOp::create(b, srcPointerGlobalType, scrPtr);
674 int64_t dstElements = adaptor.getDstElements().getZExtValue();
675 int64_t sizeInBytes =
676 (dstMemrefType.getElementTypeBitWidth() * dstElements) / 8;
677 // When the optional SrcElements argument is *not* present, the regular
678 // CpAsyncOp is generated. CpAsyncOp reads bytes from the source (global
679 // memory) to fill DstElements elements in the destination
680 // (shared memory).
681 Value srcBytes = adaptor.getSrcElements();
682 if (srcBytes) {
683 // When the optional SrcElements argument is present, the source (global
684 // memory) of CpAsyncOp is read for only SrcElements elements.
685 // The rest of the DstElements in the destination (shared memory) are
686 // filled with zeros.
687 Value c3I32 =
688 LLVM::ConstantOp::create(b, b.getI32Type(), b.getI32IntegerAttr(3));
689 Value bitwidth = LLVM::ConstantOp::create(
690 b, b.getI32Type(),
691 b.getI32IntegerAttr(srcMemrefType.getElementTypeBitWidth()));
692 Value srcElementsI32 = LLVM::TruncOp::create(b, b.getI32Type(), srcBytes);
693 srcBytes = LLVM::LShrOp::create(
694 b, LLVM::MulOp::create(b, bitwidth, srcElementsI32), c3I32);
695 }
696 // Cache global (.cg) for 16 dst bytes, Cache all (.ca) for sizes other than
697 // 16 dst bytes.
698 NVVM::LoadCacheModifierKind cacheModifier =
699 (op.getBypassL1().value_or(false) && sizeInBytes == 16)
700 ? NVVM::LoadCacheModifierKind::CG
701 : NVVM::LoadCacheModifierKind::CA;
702
703 NVVM::CpAsyncOp::create(
704 b, dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes),
705 NVVM::LoadCacheModifierKindAttr::get(op->getContext(), cacheModifier),
706 srcBytes);
707
708 // Drop the result token.
709 Value zero =
710 LLVM::ConstantOp::create(b, IntegerType::get(op.getContext(), 32),
711 rewriter.getI32IntegerAttr(0));
712 rewriter.replaceOp(op, zero);
713 return success();
714 }
715};
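// For example, a sketch of this lowering: copying 8 f16 elements
// (16 * 8 / 8 = 16 bytes) with bypassL1 set produces an nvvm.cp.async with the
// .cg cache modifier; any other byte size falls back to .ca.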
716
717struct NVGPUAsyncCreateGroupLowering
718 : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCreateGroupOp> {
719 using ConvertOpToLLVMPattern<
720 nvgpu::DeviceAsyncCreateGroupOp>::ConvertOpToLLVMPattern;
721
722 LogicalResult
723 matchAndRewrite(nvgpu::DeviceAsyncCreateGroupOp op, OpAdaptor adaptor,
724 ConversionPatternRewriter &rewriter) const override {
725 NVVM::CpAsyncCommitGroupOp::create(rewriter, op.getLoc());
726 // Drop the result token.
727 Value zero = LLVM::ConstantOp::create(rewriter, op->getLoc(),
728 IntegerType::get(op.getContext(), 32),
729 rewriter.getI32IntegerAttr(0));
730 rewriter.replaceOp(op, zero);
731 return success();
732 }
733};
734
735struct NVGPUAsyncWaitLowering
736 : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncWaitOp> {
737 using ConvertOpToLLVMPattern<
738 nvgpu::DeviceAsyncWaitOp>::ConvertOpToLLVMPattern;
739
740 LogicalResult
741 matchAndRewrite(nvgpu::DeviceAsyncWaitOp op, OpAdaptor adaptor,
742 ConversionPatternRewriter &rewriter) const override {
743 // If numGroups is not present, pick 0 as a conservatively correct value.
744 int32_t numGroups = adaptor.getNumGroups().value_or(0);
745 NVVM::CpAsyncWaitGroupOp::create(rewriter, op.getLoc(), numGroups);
746 rewriter.eraseOp(op);
747 return success();
748 }
749};
750
751/// Creates an mbarrier object in shared memory
752struct NVGPUMBarrierCreateLowering
753 : public ConvertOpToLLVMPattern<nvgpu::MBarrierCreateOp> {
754 using ConvertOpToLLVMPattern<nvgpu::MBarrierCreateOp>::ConvertOpToLLVMPattern;
755
756 template <typename moduleT>
757 memref::GlobalOp generateGlobalBarrier(ConversionPatternRewriter &rewriter,
758 Operation *funcOp, moduleT moduleOp,
759 MemRefType barrierType) const {
760 SymbolTable symbolTable(moduleOp);
761 OpBuilder::InsertionGuard guard(rewriter);
762 rewriter.setInsertionPoint(&moduleOp.front());
763 auto global = memref::GlobalOp::create(
764 rewriter, funcOp->getLoc(), "__mbarrier",
765 /*sym_visibility=*/rewriter.getStringAttr("private"),
766 /*type=*/barrierType,
767 /*initial_value=*/ElementsAttr(),
768 /*constant=*/false,
769 /*alignment=*/rewriter.getI64IntegerAttr(8));
770 symbolTable.insert(global);
771 return global;
772 }
773
774 LogicalResult
775 matchAndRewrite(nvgpu::MBarrierCreateOp op, OpAdaptor adaptor,
776 ConversionPatternRewriter &rewriter) const override {
777 Operation *funcOp = op->getParentOp();
778 MemRefType barrierType = nvgpu::getMBarrierMemrefType(
779 rewriter.getContext(), op.getBarriers().getType());
780
781 memref::GlobalOp global;
782 if (auto moduleOp = funcOp->getParentOfType<gpu::GPUModuleOp>())
783 global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
784 else if (auto moduleOp = funcOp->getParentOfType<ModuleOp>())
785 global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
786
787 rewriter.setInsertionPoint(op);
788 rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(op, barrierType,
789 global.getName());
790 return success();
791 }
792};
793
794/// Base class for lowering mbarrier operations to nvvm intrinsics.
795template <typename SourceOp>
796struct MBarrierBasePattern : public ConvertOpToLLVMPattern<SourceOp> {
797public:
798 using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
799 /// Returns the base pointer of the mbarrier object.
800 Value getMbarrierPtr(ImplicitLocOpBuilder &b,
801 nvgpu::MBarrierGroupType mbarType, Value memrefDesc,
802 Value mbarId,
803 ConversionPatternRewriter &rewriter) const {
804 MemRefType mbarrierMemrefType =
805 nvgpu::getMBarrierMemrefType(rewriter.getContext(), mbarType);
806 return ConvertToLLVMPattern::getStridedElementPtr(
807 rewriter, b.getLoc(), mbarrierMemrefType, memrefDesc, {mbarId});
808 }
809};
810
811struct NVGPUMBarrierGetLowering
812 : public MBarrierBasePattern<nvgpu::MBarrierGetOp> {
813 using MBarrierBasePattern<nvgpu::MBarrierGetOp>::MBarrierBasePattern;
814
815 LogicalResult
816 matchAndRewrite(nvgpu::MBarrierGetOp op, OpAdaptor adaptor,
817 ConversionPatternRewriter &rewriter) const override {
818 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
819 nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType();
820 rewriter.setInsertionPoint(op);
821 Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(),
822 adaptor.getMbarId(), rewriter);
823 Type resType = op.getMbarrierPointer().getType();
824 rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(op, resType, barrier);
825 return success();
826 }
827};
828
829/// Lowers `nvgpu.mbarrier.init` to `nvvm.mbarrier.init`
830struct NVGPUMBarrierInitLowering
831 : public MBarrierBasePattern<nvgpu::MBarrierInitOp> {
832 using MBarrierBasePattern<nvgpu::MBarrierInitOp>::MBarrierBasePattern;
833
834 LogicalResult
835 matchAndRewrite(nvgpu::MBarrierInitOp op, OpAdaptor adaptor,
836 ConversionPatternRewriter &rewriter) const override {
837 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
838 nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType();
839 rewriter.setInsertionPoint(op);
840 Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(),
841 adaptor.getMbarId(), rewriter);
842 Value count = truncToI32(b, adaptor.getCount());
843 rewriter.replaceOpWithNewOp<NVVM::MBarrierInitOp>(op, barrier, count,
844 adaptor.getPredicate());
845 return success();
846 }
847};
848
849/// Lowers `nvgpu.mbarrier.arrive` to `nvvm.mbarrier.arrive`
850struct NVGPUMBarrierArriveLowering
851 : public MBarrierBasePattern<nvgpu::MBarrierArriveOp> {
852 using MBarrierBasePattern<nvgpu::MBarrierArriveOp>::MBarrierBasePattern;
853 LogicalResult
854 matchAndRewrite(nvgpu::MBarrierArriveOp op, OpAdaptor adaptor,
855 ConversionPatternRewriter &rewriter) const override {
856 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
857 Value barrier =
858 getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
859 adaptor.getMbarId(), rewriter);
860 rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveOp>(op, barrier);
861 return success();
862 }
863};
864
865/// Lowers `nvgpu.mbarrier.arrive.nocomplete` to
866/// `nvvm.mbarrier.arrive.nocomplete`
867struct NVGPUMBarrierArriveNoCompleteLowering
868 : public MBarrierBasePattern<nvgpu::MBarrierArriveNoCompleteOp> {
869 using MBarrierBasePattern<
870 nvgpu::MBarrierArriveNoCompleteOp>::MBarrierBasePattern;
871 LogicalResult
872 matchAndRewrite(nvgpu::MBarrierArriveNoCompleteOp op, OpAdaptor adaptor,
873 ConversionPatternRewriter &rewriter) const override {
874 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
875 Value barrier =
876 getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
877 adaptor.getMbarId(), rewriter);
878 Type tokenType = getTypeConverter()->convertType(
879 nvgpu::MBarrierTokenType::get(op->getContext()));
880 Value count = truncToI32(b, adaptor.getCount());
881 rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteOp>(
882 op, tokenType, barrier, count);
883 return success();
884 }
885};
886
887/// Lowers `nvgpu.mbarrier.test.wait` to `nvvm.mbarrier.test.wait`
888struct NVGPUMBarrierTestWaitLowering
889 : public MBarrierBasePattern<nvgpu::MBarrierTestWaitOp> {
890 using MBarrierBasePattern<nvgpu::MBarrierTestWaitOp>::MBarrierBasePattern;
891 LogicalResult
892 matchAndRewrite(nvgpu::MBarrierTestWaitOp op, OpAdaptor adaptor,
893 ConversionPatternRewriter &rewriter) const override {
894 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
895 Value barrier =
896 getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
897 adaptor.getMbarId(), rewriter);
898 Type retType = rewriter.getI1Type();
899 rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitOp>(op, retType, barrier,
900 adaptor.getToken());
901 return success();
902 }
903};
904
905struct NVGPUMBarrierArriveExpectTxLowering
906 : public MBarrierBasePattern<nvgpu::MBarrierArriveExpectTxOp> {
907 using MBarrierBasePattern<
908 nvgpu::MBarrierArriveExpectTxOp>::MBarrierBasePattern;
909 LogicalResult
910 matchAndRewrite(nvgpu::MBarrierArriveExpectTxOp op, OpAdaptor adaptor,
911 ConversionPatternRewriter &rewriter) const override {
912 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
913 Value barrier =
914 getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
915 adaptor.getMbarId(), rewriter);
916 Value txcount = truncToI32(b, adaptor.getTxcount());
917 NVVM::MBarrierArriveExpectTxOp::create(
918 rewriter, op->getLoc(), barrier, txcount, // barrier and txcount
919 NVVM::MemScopeKind::CTA, // default scope is CTA
920 false, // relaxed-semantics is false
921 adaptor.getPredicate());
922 rewriter.eraseOp(op);
923 return success();
924 }
925};
926
927struct NVGPUMBarrierTryWaitParityLowering
928 : public MBarrierBasePattern<nvgpu::MBarrierTryWaitParityOp> {
929 using MBarrierBasePattern<
930 nvgpu::MBarrierTryWaitParityOp>::MBarrierBasePattern;
931 LogicalResult
932 matchAndRewrite(nvgpu::MBarrierTryWaitParityOp op, OpAdaptor adaptor,
933 ConversionPatternRewriter &rewriter) const override {
934 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
935 Value barrier =
936 getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
937 adaptor.getMbarId(), rewriter);
938 Value ticks = truncToI32(b, adaptor.getTicks());
939 Value phase =
940 LLVM::ZExtOp::create(b, b.getI32Type(), adaptor.getPhaseParity());
941 rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParityOp>(op, barrier,
942 phase, ticks);
943 return success();
944 }
945};
946
947struct NVGPUTmaAsyncLoadOpLowering
948 : public MBarrierBasePattern<nvgpu::TmaAsyncLoadOp> {
949 using MBarrierBasePattern<nvgpu::TmaAsyncLoadOp>::MBarrierBasePattern;
950 LogicalResult
951 matchAndRewrite(nvgpu::TmaAsyncLoadOp op, OpAdaptor adaptor,
952 ConversionPatternRewriter &rewriter) const override {
953 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
954 auto srcMemrefType = cast<MemRefType>(op.getDst().getType());
955 Value dest = getStridedElementPtr(rewriter, op->getLoc(), srcMemrefType,
956 adaptor.getDst(), {});
957 // The intrinsic takes a shared-cluster pointer, so we need an
958 // address space cast from 3 to 7.
959 // TODO: Introduce AS(7) in NVGPU.
960 auto ptrSharedClusterType = LLVM::LLVMPointerType::get(
961 op->getContext(),
962 static_cast<unsigned>(NVVM::NVVMMemorySpace::SharedCluster));
963 dest = LLVM::AddrSpaceCastOp::create(b, ptrSharedClusterType, dest);
964
965 Value barrier =
966 getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
967 adaptor.getMbarId(), rewriter);
968
969 SmallVector<Value> coords = adaptor.getCoordinates();
970 for (auto [index, value] : llvm::enumerate(coords)) {
971 coords[index] = truncToI32(b, value);
972 }
973
974 // TODO: Enhance the NVGPU Op for other modes too
975 rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(
976 op, dest, adaptor.getTensorMapDescriptor(), coords, barrier,
977 ValueRange{}, adaptor.getMulticastMask(), Value{},
978 NVVM::TMALoadMode::TILE, // default is TILE mode
979 false, // default is cluster-scope
980 nullptr, // default is no cta-group
981 adaptor.getPredicate());
982 return success();
983 }
984};
985
986struct NVGPUTmaAsyncStoreOpLowering
987 : public MBarrierBasePattern<nvgpu::TmaAsyncStoreOp> {
988 using MBarrierBasePattern<nvgpu::TmaAsyncStoreOp>::MBarrierBasePattern;
989 LogicalResult
990 matchAndRewrite(nvgpu::TmaAsyncStoreOp op, OpAdaptor adaptor,
991 ConversionPatternRewriter &rewriter) const override {
992 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
993 auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
994 Value dest = getStridedElementPtr(rewriter, op->getLoc(), srcMemrefType,
995 adaptor.getSrc(), {});
996 SmallVector<Value> coords = adaptor.getCoordinates();
997 for (auto [index, value] : llvm::enumerate(coords)) {
998 coords[index] = truncToI32(b, value);
999 }
1000
1001 // TODO: Enhance the NVGPU Op for other modes too
1002 rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(
1003 op, adaptor.getTensorMapDescriptor(), dest, coords, Value{},
1004 NVVM::TMAStoreMode::TILE, // default is TILE mode
1005 adaptor.getPredicate());
1006 return success();
1007 }
1008};
1009
1010struct NVGPUGenerateWarpgroupDescriptorLowering
1011 : public ConvertOpToLLVMPattern<nvgpu::WarpgroupGenerateDescriptorOp> {
1012 using ConvertOpToLLVMPattern<
1013 nvgpu::WarpgroupGenerateDescriptorOp>::ConvertOpToLLVMPattern;
1014
1015 LogicalResult
1016 matchAndRewrite(nvgpu::WarpgroupGenerateDescriptorOp op, OpAdaptor adaptor,
1017 ConversionPatternRewriter &rewriter) const override {
1018
1019 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1020
1021 nvgpu::TensorMapSwizzleKind swizzleKind =
1022 op.getTensorMap().getType().getSwizzle();
1023
1024 unsigned layout =
1025 (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 128
1026 : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 64
1027 : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 32
1028 : 1;
1029 unsigned swizzle =
1030 (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 1
1031 : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 2
1032 : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 3
1033 : 0;
1034
1035 auto ti64 = b.getIntegerType(64);
1036 auto makeConst = [&](uint64_t index) -> Value {
1037 return LLVM::ConstantOp::create(b, ti64, b.getI64IntegerAttr(index));
1038 };
1039 auto shiftLeft = [&](Value value, unsigned shift) -> Value {
1040 return LLVM::ShlOp::create(b, ti64, value, makeConst(shift));
1041 };
1042 auto shiftRight = [&](Value value, unsigned shift) -> Value {
1043 return LLVM::LShrOp::create(b, ti64, value, makeConst(shift));
1044 };
1045 auto insertBit = [&](Value desc, Value val, int startBit) {
1046 return LLVM::OrOp::create(b, ti64, desc, shiftLeft(val, startBit));
1047 };
1048
1049 int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0);
1050 uint64_t strideDimVal = (layout << 3) >> exclude4LSB;
1051 uint64_t leadDimVal = (sizeN * layout) >> exclude4LSB;
1052 uint64_t offsetVal = 0;
1053
1054 Value strideDim = makeConst(strideDimVal);
1055 Value leadDim = makeConst(leadDimVal);
1056
1057 Value baseAddr = getStridedElementPtr(
1058 rewriter, op->getLoc(), cast<MemRefType>(op.getTensor().getType()),
1059 adaptor.getTensor(), {});
1060 Value basePtr = LLVM::PtrToIntOp::create(b, ti64, baseAddr);
1061 // Just use 14 bits for base address
1062 Value basePtr14bit = shiftRight(shiftLeft(basePtr, 46), 50);
1063
1064 int startSwizzleBit = 62, startOffsetBit = 49, startStrideBit = 32,
1065 startLeadBit = 16, startBaseAddrBit = 0;
1066 Value dsc = makeConst(0);
1067 // [62,64) swizzle type
1068 dsc = insertBit(dsc, makeConst(swizzle), startSwizzleBit);
1069 // [49,52) base_offset
1070 dsc = insertBit(dsc, makeConst(offsetVal), startOffsetBit);
1071 // [32,46) stride
1072 dsc = insertBit(dsc, strideDim, startStrideBit);
1073 // [16,30) leading dimension
1074 dsc = insertBit(dsc, leadDim, startLeadBit);
1075 // [0,14) start_address
1076 dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit);
1077
1078 LDBG() << "Generating warpgroup.descriptor: " << "leading_off:"
1079 << leadDimVal << "\t" << "stride_off :" << strideDimVal << "\t"
1080 << "base_offset:" << offsetVal << "\t" << "layout_type:" << swizzle
1081 << " (" << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
1082 << ")\n start_addr : " << baseAddr;
1083
1084 rewriter.replaceOp(op, dsc);
1085 return success();
1086 }
1087};
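// For example, a sketch of the field math above for SWIZZLE_128B (layout = 128)
// and sizeN = 64: the encoded stride is (128 << 3) >> 4 = 64 and the leading
// dimension is (64 * 128) >> 4 = 512, packed into bit ranges [32,46) and
// [16,30) of the descriptor respectively.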
1088
1089static Value makeI64Const(ImplicitLocOpBuilder &b, int32_t index) {
1090 return LLVM::ConstantOp::create(b, b.getIntegerType(64),
1091 b.getI32IntegerAttr(index));
1092}
1093
1094/// Returns a Value that holds the data type enum expected by the CUDA driver.
1095static Value elementTypeAsLLVMConstant(ImplicitLocOpBuilder &b, Type type) {
1096 // Enum is from CUDA driver API
1097 // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html
1098 enum CUtensorMapDataTypeEnum {
1099 CU_TENSOR_MAP_DATA_TYPE_UINT8 = 0,
1100 CU_TENSOR_MAP_DATA_TYPE_UINT16,
1101 CU_TENSOR_MAP_DATA_TYPE_UINT32,
1102 CU_TENSOR_MAP_DATA_TYPE_INT32,
1103 CU_TENSOR_MAP_DATA_TYPE_UINT64,
1104 CU_TENSOR_MAP_DATA_TYPE_INT64,
1105 CU_TENSOR_MAP_DATA_TYPE_FLOAT16,
1106 CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
1107 CU_TENSOR_MAP_DATA_TYPE_FLOAT64,
1108 CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
1109 CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ,
1110 CU_TENSOR_MAP_DATA_TYPE_TFLOAT32,
1111 CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ
1112 };
1113
1114 if (type.isUnsignedInteger(8))
1115 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT8);
1116 if (type.isUnsignedInteger(16))
1117 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT16);
1118 if (type.isUnsignedInteger(32))
1119 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT32);
1120 if (type.isUnsignedInteger(64))
1121 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT64);
1122 if (type.isSignlessInteger(32))
1123 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT32);
1124 if (type.isSignlessInteger(64))
1125 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT64);
1126 if (type.isF16())
1127 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT16);
1128 if (type.isF32())
1129 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT32);
1130 if (type.isF64())
1131 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT64);
1132 if (type.isBF16())
1133 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16);
1134
1135 llvm_unreachable("Not supported data type");
1136}
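// For example, following the enum above: f16 maps to
// CU_TENSOR_MAP_DATA_TYPE_FLOAT16 (6) and bf16 to
// CU_TENSOR_MAP_DATA_TYPE_BFLOAT16 (9), each emitted as an i64 constant.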
1137
1138struct NVGPUTmaCreateDescriptorOpLowering
1139 : public ConvertOpToLLVMPattern<nvgpu::TmaCreateDescriptorOp> {
1140 using ConvertOpToLLVMPattern<
1141 nvgpu::TmaCreateDescriptorOp>::ConvertOpToLLVMPattern;
1142 LogicalResult
1143 matchAndRewrite(nvgpu::TmaCreateDescriptorOp op, OpAdaptor adaptor,
1144 ConversionPatternRewriter &rewriter) const override {
1145 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1146 auto llvmPointerType = LLVM::LLVMPointerType::get(op->getContext());
1147 Type llvmInt64Type = IntegerType::get(op->getContext(), 64);
1148
1149 Value tensorElementType =
1150 elementTypeAsLLVMConstant(b, op.getTensor().getType().getElementType());
1151 auto promotedOperands = getTypeConverter()->promoteOperands(
1152 b.getLoc(), op->getOperands(), adaptor.getOperands(), b);
1153
1154 Value boxArrayPtr = LLVM::AllocaOp::create(
1155 b, llvmPointerType, llvmInt64Type, makeI64Const(b, 5));
1156 for (auto [index, value] : llvm::enumerate(adaptor.getBoxDimensions())) {
1157 Value gep = LLVM::GEPOp::create(b, llvmPointerType, llvmPointerType,
1158 boxArrayPtr, makeI64Const(b, index));
1159 LLVM::StoreOp::create(b, value, gep);
1160 }
1161
1162 nvgpu::TensorMapDescriptorType desc = op.getTensorMap().getType();
1163 // Set Arguments for the function call
1164 SmallVector<Value> arguments;
1165 arguments.push_back(promotedOperands[0]); // rank
1166 arguments.push_back(promotedOperands[1]); // descriptor
1167 arguments.push_back(tensorElementType); // data type
1168 arguments.push_back(
1169 makeI64Const(b, (int)desc.getInterleave())); // interleave
1170 arguments.push_back(makeI64Const(b, (int)desc.getSwizzle())); // swizzle
1171 arguments.push_back(makeI64Const(b, (int)desc.getL2promo())); // l2promo
1172 arguments.push_back(makeI64Const(b, (int)desc.getOob())); // oob
1173 arguments.push_back(boxArrayPtr); // box dimensions
1174
1175 // Set data types of the arguments
1176 SmallVector<Type> argTypes = {
1177 llvmInt64Type, /* int64_t tensorRank */
1178 llvmPointerType, /* ptr */
1179 llvmInt64Type, /* int64_t */
1180 llvmInt64Type, /* int64_t */
1181 llvmInt64Type, /* int64_t */
1182 llvmInt64Type, /* int64_t */
1183 llvmInt64Type, /* int64_t */
1184 llvmPointerType /* ptr */
1185 };
1186 FunctionCallBuilder hostRegisterCallBuilder = {
1187 "mgpuTensorMapEncodeTiledMemref", llvmPointerType, argTypes};
1188 Value tensorMap =
1189 hostRegisterCallBuilder.create(b.getLoc(), b, arguments).getResult();
1190
1191 rewriter.replaceOp(op, tensorMap);
1192 return success();
1193 }
1194};
1195
1196struct NVGPUWarpgroupMmaOpLowering
1197 : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp> {
1198 using ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp>::ConvertOpToLLVMPattern;
1199
1200 /// This is a helper class to generate required NVVM Ops for warp-group level
1201 /// matrix multiplication.
1202 /// When the given GEMM shape is larger than the shape of
1203/// a wgmma instruction in PTX, it can generate multiple NVVM::WgmmaMmaAsyncOp
1204 /// Op(s), group and execute them asynchronously. The class also handles
1205 /// waiting for completion and iterates through WarpgroupMatrixDescriptor to
1206 /// create descriptors for each instruction.
1207 ///
1208 /// For example this is the case when the shape of GEMM is 128x128x128
1209 ///
1210 /// nvvm.wgmma.fence.aligned
1211 ///
1212 /// nvvm.wgmma.mma.async descA, descB
1213 /// iterate(descA, descB)
1214 /// nvvm.wgmma.mma.async descA, descB
1215 /// [6x times more]
1216 ///
1217 /// nvvm.wgmma.group.sync.aligned
1218 /// nvvm.wgmma.wait.group.sync [groupId]
1219 ///
1220 class WarpgroupGemm {
1221 nvgpu::WarpgroupMmaOp op;
1222 ImplicitLocOpBuilder b;
1223 OpAdaptor adaptor;
1224
1225 // Entire shape of the given Op
1226 int64_t totalM, totalN, totalK;
1227
1228 // Shape of one wgmma instruction
1229 int wgmmaM = 0, wgmmaN = 0, wgmmaK = 0;
1230
1231 // Iteration counts for GEMM
1232 int iterationM = 0, iterationN = 0, iterationK = 0;
1233
1234 /// The function returns the shape of the wgmma instruction, as defined in the
1235 /// PTX programming guide.
1236 /// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shape
1237 void findWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType) {
1238 wgmmaM = 64;
1239 wgmmaN = sizeN;
1240 if (inputElemType.isTF32()) {
1241 wgmmaK = 8;
1242 } else if (inputElemType.isF16() || inputElemType.isBF16()) {
1243 wgmmaK = 16;
1244 } else if (isa<Float8E4M3FNType, Float8E5M2Type>(inputElemType) ||
1245 inputElemType.isInteger(16)) {
1246 wgmmaK = 32;
1247 } else if (inputElemType.isInteger(1)) {
1248 wgmmaK = 256;
1249 } else {
1250 llvm_unreachable("unsupported input element type for wgmma K shape");
1251 }
1252 LDBG() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
1253 << ", n = " << wgmmaN << ", k = " << wgmmaK << "]";
1254 }
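// For example, a sketch: a 128x128x64 GEMM with f16 inputs selects wgmma tiles
// of m64 n128 k16, which later yields iterationM = 2, iterationN = 1, and
// iterationK = 4 in the constructor below.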
1255
1256 /// Generates WGMMATypesAttr from MLIR Type
1257 NVVM::WGMMATypesAttr generateWgmmaType(Type type,
1258 bool useF32 = false) const {
1259 auto getWgmmaType = [=](Type elemType) {
1260 if (elemType.isF32() || elemType.isTF32())
1261 return useF32 ? NVVM::WGMMATypes::f32 : NVVM::WGMMATypes::tf32;
1262 if (elemType.isF16())
1263 return NVVM::WGMMATypes::f16;
1264 if (elemType.isBF16())
1265 return NVVM::WGMMATypes::bf16;
1266 if (isa<Float8E4M3FNType>(elemType))
1267 return NVVM::WGMMATypes::e4m3;
1268 if (isa<Float8E5M2Type>(elemType))
1269 return NVVM::WGMMATypes::e5m2;
1270 if (elemType.isInteger(1))
1271 return NVVM::WGMMATypes::b1;
1272 if (elemType.isInteger(8))
1273 return NVVM::WGMMATypes::s8;
1274 if (elemType.isUnsignedInteger(8))
1275 return NVVM::WGMMATypes::u8;
1276 if (elemType.isInteger(32))
1277 return NVVM::WGMMATypes::s32;
1278 llvm_unreachable("unsupported type");
1279 };
1280 return NVVM::WGMMATypesAttr::get(op->getContext(), getWgmmaType(type));
1281 }
1282
1283 /// Generates layout attribute for the input matrix for wgmma instruction
1284 NVVM::MMALayoutAttr
1285 generateWgmmaLayout(std::optional<bool> transpose) const {
1286 if (transpose.value_or(false))
1287 return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::col);
1288 return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::row);
1289 }
1290
1291 /// Generates shape attribute for wgmma instruction
1292 NVVM::MMAShapeAttr generateWgmmaShape() const {
1293 return NVVM::MMAShapeAttr::get(op->getContext(), wgmmaM, wgmmaN, wgmmaK);
1294 }
1295
1296 /// Generates scale attributes of output matrix for wgmma instruction
1297 NVVM::WGMMAScaleOutAttr generateScaleOut() const {
1298 return NVVM::WGMMAScaleOutAttr::get(op->getContext(),
1299 NVVM::WGMMAScaleOut::one);
1300 }
1301 /// Generates scale attributes of input matrix for wgmma instruction
1302 NVVM::WGMMAScaleInAttr generateScaleIn() const {
1303 return NVVM::WGMMAScaleInAttr::get(op->getContext(),
1304 NVVM::WGMMAScaleIn::one);
1305 }
1306
1307 /// Basic function to generate Add
1308 Value makeAdd(Value lhs, Value rhs) {
1309 return LLVM::AddOp::create(b, lhs.getType(), lhs, rhs);
1310 };
1311
1312 /// Moves the descriptor pointer of matrix-A for the next wgmma instruction.
1313 /// Currently, it only handles row-major.
1314 ///
1315 /// It moves the pointer like below for [128][64] size:
1316 /// +2 +4 +6
1317 /// ↓ ↓ ↓
1318 /// descA ---> +--+--+--+--+
1319 /// |->|->|->|->|
1320 /// | | | | |
1321 /// | | | | |
1322 /// | | | | |
1323 /// descA+512---> +-----------+
1324 /// | | | | |
1325 /// | | | | |
1326 /// | | | | |
1327 /// | | | | |
1328 /// +-----------+
1329 ///
1330 Value iterateDescriptorA(Value desc, int i, int j, int k) {
1331 MemRefType matrixTypeA = op.getDescriptorA().getType().getTensor();
1332 Type elemA = matrixTypeA.getElementType();
1333 int byte = elemA.getIntOrFloatBitWidth() / 8;
1334 int tileShapeA = matrixTypeA.getDimSize(1);
1335 int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte;
1336 incrementVal = incrementVal >> exclude4LSB;
1337 LDBG() << "\t\t[m: " << i << " n: " << j << " k: " << k
1338 << "] [wgmma descriptors] Descriptor A + " << incrementVal
1339 << " | \t ";
1340 if (!incrementVal)
1341 return desc;
1342 return makeAdd(desc, makeI64Const(b, incrementVal));
1343 }
1344
1345 /// Moves the descriptor pointer of matrix-B for the next wgmma instruction.
1346 /// Currently, it only handles column-major.
1347 ///
1348 /// It moves the pointer like below for [128][64] size:
1349 /// descB ---> +--+--+--+--+--+--+--+--+
1350 /// |↓ | | | | | | | |
1351 /// |↓ | | | | | | | |
1352 /// |↓ | | | | | | | |
1353 /// |↓ | | | | | | | |
1354 /// +--+--+--+--+--+--+--+--+
1355 ///
1356 Value iterateDescriptorB(Value desc, int i, int j, int k) {
1357 MemRefType matrixTypeB = op.getDescriptorB().getType().getTensor();
1358 Type elemB = matrixTypeB.getElementType();
1359 int byte = elemB.getIntOrFloatBitWidth() / 8;
1360 int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte;
1361 incrementVal = incrementVal >> exclude4LSB;
1362 LDBG() << "Descriptor B + " << incrementVal;
1363 if (!incrementVal)
1364 return desc;
1365 return makeAdd(desc, makeI64Const(b, incrementVal));
1366 }
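// For example, a sketch of the increment above: with f16 elements (2 bytes), a
// B tile whose dimension 0 is 64, wgmmaK = 16, and k = 1, the descriptor
// advances by (64 * 16 * 1 * 2) >> 4 = 128.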
1367
1368 /// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix
1369 /// descriptors and arranges them based on induction variables: i, j, and k.
1370 Value generateWgmma(int i, int j, int k, Value matrixC) {
1371 LDBG() << "\t wgmma." << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
1372 << "(A[" << (iterationM * wgmmaM) << ":"
1373 << (iterationM * wgmmaM) + wgmmaM << "][" << (iterationK * wgmmaK)
1374 << ":" << (iterationK * wgmmaK + wgmmaK) << "] * " << " B["
1375 << (iterationK * wgmmaK) << ":" << (iterationK * wgmmaK + wgmmaK)
1376 << "][" << 0 << ":" << wgmmaN << "])";
1377
1378 Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k);
1379 Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k);
1380
1381 Type elemA = op.getDescriptorA().getType().getTensor().getElementType();
1382 NVVM::WGMMATypesAttr itypeA = generateWgmmaType(elemA);
1383
1384 Type elemB = op.getDescriptorB().getType().getTensor().getElementType();
1385 NVVM::WGMMATypesAttr itypeB = generateWgmmaType(elemB);
1386
1387 Type elemD = op.getMatrixC().getType().getFragmented().getElementType();
1388 NVVM::WGMMATypesAttr itypeD = generateWgmmaType(elemD, true);
1389
1390 NVVM::MMAShapeAttr shape = generateWgmmaShape();
1391 NVVM::WGMMAScaleOutAttr scaleOut = generateScaleOut();
1392 NVVM::WGMMAScaleInAttr scaleIn = generateScaleIn();
1393 NVVM::MMALayoutAttr layoutA = generateWgmmaLayout(op.getTransposeA());
1394 NVVM::MMALayoutAttr layoutB = generateWgmmaLayout(!op.getTransposeB());
1395
1396 auto overflow = NVVM::MMAIntOverflowAttr::get(
1397 op->getContext(), NVVM::MMAIntOverflow::wrapped);
1398
1399 return NVVM::WgmmaMmaAsyncOp::create(
1400 b, matrixC.getType(), matrixC, descriptorA, descriptorB, shape,
1401 itypeA, itypeB, itypeD, scaleOut, scaleIn, scaleIn, layoutA, layoutB,
1402 overflow);
1403 }
1404
1405 /// Generates multiple wgmma instructions to complete the given GEMM shape
1406 Value generateWgmmaGroup() {
1407 Value wgmmaResult =
1408 LLVM::PoisonOp::create(b, adaptor.getMatrixC().getType());
1409
1410 // Perform GEMM
1411 SmallVector<Value> wgmmaResults;
1412 for (int i = 0; i < iterationM; ++i) {
1413 Value matrixC =
1414 LLVM::ExtractValueOp::create(b, adaptor.getMatrixC(), i);
1415 for (int j = 0; j < iterationN; ++j)
1416 for (int k = 0; k < iterationK; ++k)
1417 matrixC = generateWgmma(i, j, k, matrixC);
1418 wgmmaResults.push_back(matrixC);
1419 }
1420 for (auto [idx, matrix] : llvm::enumerate(wgmmaResults)) {
1421 wgmmaResult = LLVM::InsertValueOp::create(b, wgmmaResult.getType(),
1422 wgmmaResult, matrix, idx);
1423 }
1424 return wgmmaResult;
1425 }
1426
1427 public:
1428 WarpgroupGemm(nvgpu::WarpgroupMmaOp op, ImplicitLocOpBuilder &b,
1429 OpAdaptor adaptor)
1430 : op(op), b(b), adaptor(adaptor) {
1431 // Find the entire GEMM Shape
1432 totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
1433 totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
1434 totalK = op.getDescriptorA().getType().getTensor().getDimSize(1);
1435 LDBG() << "===--- GEMM D[" << totalM << "][" << totalN << "] += A["
1436 << totalM << "][" << totalK << "] * B[" << totalK << "][" << totalN
1437 << "] ---===";
1438
1439 // Find the shape for one wgmma instruction
1440 findWgmmaShape(
1441 totalM, totalN,
1442 op.getDescriptorA().getType().getTensor().getElementType());
1443
 1444 // Iteration counts needed to cover the given GEMM shape with the wgmma shape
1445 iterationM = totalM / wgmmaM;
1446 iterationN = totalN / wgmmaN;
1447 iterationK = totalK / wgmmaK;
1448 }
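  // Note (added for exposition): as an assumed example, for f16 operands with
  // A = [128][64] and B = [64][128] this computes totalM = 128, totalN = 128
  // and totalK = 64. If findWgmmaShape selects an m64n128k16 tile, the
  // iteration counts become iterationM = 2, iterationN = 1, iterationK = 4,
  // i.e. 2 * 1 * 4 = 8 wgmma instructions cover the whole GEMM.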
1449
 1450 /// Generates WgmmaMmaAsync ops to complete the specified GEMM shape. It
 1451 /// emits a fence (WgmmaFenceAlignedOp) before the wgmma instructions,
 1452 /// then commits them as a group (WgmmaGroupSyncAlignedOp) and waits for
 1453 /// that group to complete (WgmmaWaitGroupSyncOp) after the instructions
 1454 /// have been emitted.
1455 Value generateWarpgroupMma() {
1456 NVVM::WgmmaFenceAlignedOp::create(b);
1457 Value wgmmaResult = generateWgmmaGroup();
1458 NVVM::WgmmaGroupSyncAlignedOp::create(b);
1459 NVVM::WgmmaWaitGroupSyncOp::create(b, op.getWaitGroup());
1460 return wgmmaResult;
1461 }
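  // Note (added for exposition): the generated sequence therefore looks
  // roughly like the following pseudo-IR (op spellings approximate):
  //   nvvm.wgmma.fence.aligned
  //   %d = nvvm.wgmma.mma_async ...   // repeated per (i, j, k) iteration
  //   nvvm.wgmma.commit.group.sync.aligned
  //   nvvm.wgmma.wait.group.sync.aligned ...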
1462 };
1463 LogicalResult
1464 matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
1465 ConversionPatternRewriter &rewriter) const override {
1466 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1467
1468 // Step 1. Build a helper class
1469 WarpgroupGemm warpgroupGemm(op, b, adaptor);
1470
 1471 // Step 2. Generate wgmma instructions covering the entire GEMM shape
1472 Value wgmmaResult = warpgroupGemm.generateWarpgroupMma();
1473
1474 // Step 3. Replace fragmented result struct with the op results
1475 rewriter.replaceOp(op, wgmmaResult);
1476 return success();
1477 }
1478};
1479
1480struct NVGPUWarpgroupMmaStoreOpLowering
1481 : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaStoreOp> {
1482 using ConvertOpToLLVMPattern<
1483 nvgpu::WarpgroupMmaStoreOp>::ConvertOpToLLVMPattern;
1484
 1485 /// This function stores a fragmented register matrix owned by a warp group
 1486 /// (128 threads) into a memref. Each thread holds 64 registers, packed
 1487 /// together as a single struct value.
 1488 /// Below is what each thread (T) holds; each `d` is one numbered struct
 1489 /// element.
 1490 ///
 1491 /// Threads in the warp-group (128 threads) and what they own in matrix D:
 1492 /// 0-31 Warp-0 -> MatrixD[0:15 ][0:N]
 1493 /// 32-63 Warp-1 -> MatrixD[16:31][0:N]
 1494 /// 64-95 Warp-2 -> MatrixD[32:47][0:N]
 1495 /// 96-127 Warp-3 -> MatrixD[48:63][0:N]
1496 ///
1497 /// Matrix-D:
1498 /// +______________________________________________________________________+
1499 /// | 0-1 | 2-3 | 4-5 | 6-7 | 8-9 | 10-11|..|N-8,N-7 |
1500 /// 0 | T0:d0-d1 |T1:d0-d1 |T2:d0-d1 |T3:d0-d1 |T0:d4-d5| T1:d4-d5..|T0:dX-dY|
1501 /// 1 | T4:d0-d1 |T5:d0-d1 |T6:d0-d1 |T7:d0-d1 |T4:d4-d5| T5:d4-d5..|T4:dX-dY|
1502 /// ..| .........|.........|.........|.........|........|...........|........|
1503 /// 8 | T0:d2-d3 |T1:d2-d3 |T2:d2-d3 |T3:d2-d3 |T0:d6-d7|T1:d6-d7,..|T0:dZ-dW|
1504 /// 9 | T4:d2-d3 |T5:d2-d3 |T6:d2-d3 |T7:d2-d3 |T4:d6-d7| T5:d6-d7..|T4:dZ-dW|
1505 /// ..| .........|.........|.........|.........|........|...........|........|
1506 /// 15| T28:d2-d3|T29:d2-d3|T30:d2-d3|T31:d2-d3|........|...........|........|
1507 /// 16| T32:d2-d3|T33:d2-d3|T34:d2-d3|T35:d2-d3|........|...........|........|
1508 /// ..| .........|.........|.........|.........|........|...........|........|
1509 /// 32| T64:d2-d3|T65:d2-d3|T66:d2-d3|T67:d2-d3|........|...........|........|
1510 /// ..| .........|.........|.........|.........|........|...........|........|
1511 /// 48| T96:d2-d3|T97:d2-d3|T98:d2-d3|T99:d2-d3|........|...........|........|
1512 /// ..| .........|.........|.........|.........|........|...........|........|
1513 /// +______________________________________________________________________+
1514 ///
 1515 /// \param b: The builder used to create the store operations.
 1516 /// \param matrixD: Result of the warp-group MMA operation (fragmented
 1517 /// matrix). It is held by each thread as a struct with 64 elements.
 1518 /// \param dstMemref: The memref where the registers will be stored.
 1519 /// \param offset: The offset within the memref where the registers will be
 1520 /// stored.
1521 void storeFragmentedMatrix(ImplicitLocOpBuilder &b, Value matrixD,
1522 TypedValue<MemRefType> dstMemref,
1523 int offset) const {
1524 Type i32 = b.getI32Type();
1525
1526 auto makeConst = [&](int32_t index) -> Value {
1527 return LLVM::ConstantOp::create(b, i32, b.getI32IntegerAttr(index));
1528 };
1529 Value c1 = makeConst(1);
1530 Value c2 = makeConst(2);
1531 Value c4 = makeConst(4);
1532 Value c8 = makeConst(8);
1533 Value c16 = makeConst(16);
1534 Value warpSize = makeConst(kWarpSize);
1535
1536 auto makeMul = [&](Value lhs, Value rhs) -> Value {
1537 return LLVM::MulOp::create(b, lhs.getType(), lhs, rhs);
1538 };
1539 auto makeAdd = [&](Value lhs, Value rhs) -> Value {
1540 return LLVM::AddOp::create(b, lhs.getType(), lhs, rhs);
1541 };
1542
 1543 auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,
 1544 TypedValue<MemRefType> memref) {
 1545 Type it = b.getIndexType();
1546 Value idx = arith::IndexCastOp::create(b, it, x);
1547 Value idy0 = arith::IndexCastOp::create(b, it, y);
1548 Value idy1 = arith::IndexCastOp::create(b, it, makeAdd(y, c1));
1549 Value d0 = LLVM::ExtractValueOp::create(b, wgmmaResult, i);
1550 Value d1 = LLVM::ExtractValueOp::create(b, wgmmaResult, i + 1);
1551 memref::StoreOp::create(b, d0, memref, ValueRange{idx, idy0});
1552 memref::StoreOp::create(b, d1, memref, ValueRange{idx, idy1});
1553 };
1554
1555 Value tidx = NVVM::ThreadIdXOp::create(b, i32);
1556 Value laneId = LLVM::URemOp::create(b, i32, tidx, warpSize);
1557 Value warpId = LLVM::UDivOp::create(b, i32, tidx, warpSize);
1558 Value lane4Id = LLVM::UDivOp::create(b, i32, laneId, c4);
1559 Value lane4modId = LLVM::URemOp::create(b, i32, laneId, c4);
1560
1561 Value tj = makeMul(lane4modId, c2);
1562 Value ti = makeAdd(lane4Id, makeMul(warpId, c16));
1563 if (offset)
1564 ti = makeAdd(ti, makeConst(offset));
1565
1566 auto structType = cast<LLVM::LLVMStructType>(matrixD.getType());
1567
 1568 // Number of 32-bit registers owned by each thread
1569 constexpr unsigned numAdjacentRegisters = 2;
 1570 // Number of 8x8 matrices stacked one below another per warp
1571 constexpr unsigned numStackedMatrices = 2;
1572
1573 size_t storeCount = (structType.getBody().size() /
1574 (numStackedMatrices * numAdjacentRegisters));
1575
1576 for (size_t i = 0; i < numStackedMatrices; ++i) {
1577 Value idx = makeAdd(ti, makeMul(makeConst(i), c8));
1578 for (size_t j = 0; j < storeCount; ++j) {
1579 Value idy = makeAdd(tj, makeMul(makeConst(j), c8));
1580 size_t structIndex = (i * numAdjacentRegisters) +
1581 (j * (numStackedMatrices * numAdjacentRegisters));
1582 makeExtractAndStore(structIndex, matrixD, idx, idy, dstMemref);
1583 }
1584 }
1585 }
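  // Note (added for exposition): an example of the index arithmetic above for
  // thread tidx = 37: laneId = 37 % 32 = 5, warpId = 37 / 32 = 1,
  // lane4Id = 5 / 4 = 1 and lane4modId = 5 % 4 = 1, so the starting row is
  // ti = 1 + 1 * 16 = 17 (plus `offset`) and the starting column is
  // tj = 1 * 2 = 2. The loops then advance the row and column in steps of 8,
  // storing two adjacent registers per iteration.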
1586
1587 LogicalResult
1588 matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor,
1589 ConversionPatternRewriter &rewriter) const override {
1590 int offset = 0;
1591 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
 1592 Value matrixDValue = adaptor.getMatrixD();
 1593 auto stype = cast<LLVM::LLVMStructType>(matrixDValue.getType());
 1594 for (auto [idx, matrixD] : llvm::enumerate(stype.getBody())) {
 1595 auto structType = cast<LLVM::LLVMStructType>(matrixD);
 1596 Value innerStructValue =
 1597 LLVM::ExtractValueOp::create(b, matrixDValue, idx);
1598 storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset);
1599 offset += structType.getBody().size();
1600 }
1601 rewriter.eraseOp(op);
1602 return success();
1603 }
1604};
1605
1606struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
1607 : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaInitAccumulatorOp> {
1608 using ConvertOpToLLVMPattern<
1609 nvgpu::WarpgroupMmaInitAccumulatorOp>::ConvertOpToLLVMPattern;
1610 LogicalResult
1611 matchAndRewrite(nvgpu::WarpgroupMmaInitAccumulatorOp op, OpAdaptor adaptor,
1612 ConversionPatternRewriter &rewriter) const override {
1613 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1614 LLVM::LLVMStructType packStructType = cast<LLVM::LLVMStructType>(
1615 getTypeConverter()->convertType(op.getMatrixC().getType()));
1616 Type elemType = cast<LLVM::LLVMStructType>(packStructType.getBody().front())
1617 .getBody()
1618 .front();
1619 Value zero = LLVM::ConstantOp::create(b, elemType, b.getZeroAttr(elemType));
1620 Value packStruct = LLVM::PoisonOp::create(b, packStructType);
1621 SmallVector<Value> innerStructs;
1622 // Unpack the structs and set all values to zero
1623 for (auto [idx, s] : llvm::enumerate(packStructType.getBody())) {
1624 auto structType = cast<LLVM::LLVMStructType>(s);
1625 Value structValue = LLVM::ExtractValueOp::create(b, packStruct, idx);
1626 for (unsigned i = 0; i < structType.getBody().size(); ++i) {
1627 structValue = LLVM::InsertValueOp::create(b, structType, structValue,
1628 zero, ArrayRef<int64_t>({i}));
1629 }
1630 innerStructs.push_back(structValue);
1631 }
1632 // Pack the inner structs into a single struct
1633 for (auto [idx, matrix] : llvm::enumerate(innerStructs)) {
1634 packStruct = LLVM::InsertValueOp::create(b, packStruct.getType(),
1635 packStruct, matrix, idx);
1636 }
1637 rewriter.replaceOp(op, packStruct);
1638 return success();
1639 }
1640};
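// Note (added for exposition): for an accumulator that converts to, say, a
// struct of two inner structs with 64 f32 fields each (the exact counts
// depend on the accumulator type and are only an example), this lowering
// creates a single f32 zero, inserts it into every field of each inner
// struct, and repacks the inner structs into the outer struct that replaces
// the op.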
1641
1642struct NVGPUTmaFenceOpLowering
1643 : public ConvertOpToLLVMPattern<nvgpu::TmaFenceOp> {
1644 using ConvertOpToLLVMPattern<nvgpu::TmaFenceOp>::ConvertOpToLLVMPattern;
1645 LogicalResult
1646 matchAndRewrite(nvgpu::TmaFenceOp op, OpAdaptor adaptor,
1647 ConversionPatternRewriter &rewriter) const override {
1648 MLIRContext *ctx = op.getContext();
1649 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1650 auto i32Ty = b.getI32Type();
1651 Value tensormapSize =
1652 LLVM::ConstantOp::create(b, i32Ty, rewriter.getI32IntegerAttr(128));
1653
1654 auto memscope =
1655 NVVM::MemScopeKindAttr::get(ctx, ::mlir::NVVM::MemScopeKind::SYS);
1656
1657 rewriter.replaceOpWithNewOp<NVVM::FenceProxyAcquireOp>(
1658 op, memscope, adaptor.getTensorMapDescriptor(), tensormapSize);
1659
1660 return success();
1661 }
1662};
1663
1664struct NVGPUTmaPrefetchOpLowering
1665 : public ConvertOpToLLVMPattern<nvgpu::TmaPrefetchOp> {
1666 using ConvertOpToLLVMPattern<nvgpu::TmaPrefetchOp>::ConvertOpToLLVMPattern;
1667 LogicalResult
1668 matchAndRewrite(nvgpu::TmaPrefetchOp op, OpAdaptor adaptor,
1669 ConversionPatternRewriter &rewriter) const override {
1670 rewriter.replaceOpWithNewOp<NVVM::PrefetchOp>(
1671 op, /* CacheLevel */ nullptr, /* Cache Eviction Priority */ nullptr,
1672 adaptor.getTensorMapDescriptor(), adaptor.getPredicate(),
1673 /* Tensormap UnitAttr */ mlir::UnitAttr::get(op.getContext()));
1674 return success();
1675 }
1676};
1677
1678struct NVGPURcpOpLowering : public ConvertOpToLLVMPattern<nvgpu::RcpOp> {
1679 using ConvertOpToLLVMPattern<nvgpu::RcpOp>::ConvertOpToLLVMPattern;
1680 LogicalResult
1681 matchAndRewrite(nvgpu::RcpOp op, OpAdaptor adaptor,
1682 ConversionPatternRewriter &rewriter) const override {
1683 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1684 auto i64Ty = b.getI64Type();
1685 auto f32Ty = b.getF32Type();
1686 VectorType inTy = op.getIn().getType();
 1687 // Apply rcp.approx.ftz.f to each element of the vector.
1688 auto convert1DVec = [&](Type llvm1DVectorTy, Value inVec) {
1689 Value ret1DVec = LLVM::PoisonOp::create(b, llvm1DVectorTy);
1690 int numElems = llvm::cast<VectorType>(llvm1DVectorTy).getNumElements();
1691 for (int i = 0; i < numElems; i++) {
1692 Value idx = LLVM::ConstantOp::create(b, i64Ty, b.getI64IntegerAttr(i));
1693 Value elem = LLVM::ExtractElementOp::create(b, inVec, idx);
1694 Value dst = NVVM::RcpApproxFtzF32Op::create(b, f32Ty, elem);
1695 ret1DVec = LLVM::InsertElementOp::create(b, ret1DVec, dst, idx);
1696 }
1697 return ret1DVec;
1698 };
1699 if (inTy.getRank() == 1) {
1700 rewriter.replaceOp(op, convert1DVec(inTy, adaptor.getIn()));
1701 return success();
1702 }
 1703 return LLVM::detail::handleMultidimensionalVectors(
 1704 op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()),
1705 [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
1706 OpAdaptor adaptor(operands);
1707 return convert1DVec(llvm1DVectorTy, adaptor.getIn());
1708 },
1709 rewriter);
1710 }
1711};
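// Note (added for exposition): for a rank-1 input such as vector<4xf32> (an
// illustrative shape), the lambda above unrolls into four
// extractelement / rcp.approx.ftz.f / insertelement triples; for higher-rank
// vectors, handleMultidimensionalVectors applies the same 1-D conversion to
// every innermost vector.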
1712} // namespace
1713
1715 TypeConverter &typeConverter) {
1716 // NVVM uses alloca in the default address space to represent private
1717 // memory allocations, so drop private annotations. NVVM uses address
1718 // space 3 for shared memory. NVVM uses the default address space to
1719 // represent global memory.
 1720 populateGpuMemorySpaceAttributeConversions(
 1721 typeConverter, [](gpu::AddressSpace space) -> unsigned {
1722 switch (space) {
1723 case gpu::AddressSpace::Global:
1724 return static_cast<unsigned>(NVVM::NVVMMemorySpace::Global);
1725 case gpu::AddressSpace::Workgroup:
1726 return static_cast<unsigned>(NVVM::NVVMMemorySpace::Shared);
1727 case gpu::AddressSpace::Private:
1728 return 0;
1729 case gpu::AddressSpace::Constant:
1730 return static_cast<unsigned>(NVVM::NVVMMemorySpace::Constant);
1731 }
1732 llvm_unreachable("unknown address space enum value");
1733 });
1734}
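// Note (added for exposition): with this mapping, a type such as
// memref<16xf32, #gpu.address_space<workgroup>> (an illustrative example)
// ends up in the NVVM shared address space, #gpu.address_space<global> in the
// global space, #gpu.address_space<constant> in the constant space, and
// private annotations are dropped to the default address space 0.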
1735
 1736 void mlir::populateNVGPUToNVVMConversionPatterns(
 1737 const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
1738 patterns.add<
1739 NVGPUMBarrierCreateLowering, // nvgpu.mbarrier.create
1740 NVGPUMBarrierInitLowering, // nvgpu.mbarrier.init
1741 NVGPUMBarrierGetLowering, // nvgpu.mbarrier.get
1742 NVGPUMBarrierArriveLowering, // nvgpu.mbarrier.arrive
1743 NVGPUMBarrierArriveNoCompleteLowering, // nvgpu.mbarrier.arrive.no_complete
1744 NVGPUMBarrierTestWaitLowering, // nvgpu.mbarrier.test_wait_parity
1745 NVGPUMBarrierTryWaitParityLowering, // nvgpu.mbarrier.try_wait_parity
1746 NVGPUTmaAsyncLoadOpLowering, // nvgpu.tma.async.load
1747 NVGPUTmaAsyncStoreOpLowering, // nvgpu.tma.async.store
1748 NVGPUTmaCreateDescriptorOpLowering, // nvgpu.tma.create.descriptor
1749 NVGPUTmaPrefetchOpLowering, // nvgpu.tma.prefetch.descriptor
1750 NVGPUTmaFenceOpLowering, // nvgpu.tma.fence.descriptor
1751 NVGPUMBarrierArriveExpectTxLowering, // nvgpu.mbarrier.arrive.expect_tx
1752 NVGPUGenerateWarpgroupDescriptorLowering, // nvgpu.warpgroup.generate.descriptor
1753 NVGPUWarpgroupMmaOpLowering, // nvgpu.warpgroup.mma
1754 NVGPUWarpgroupMmaStoreOpLowering, // nvgpu.warpgroup.mma.store
1755 NVGPUWarpgroupMmaInitAccumulatorOpLowering, // nvgpu.warpgroup.mma.init.accumulator
1756 MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
1757 NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
1758 NVGPUMmaSparseSyncLowering, NVGPURcpOpLowering>(converter);
1759}
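// Note (added for exposition): a minimal usage sketch. A conversion pass
// would typically combine the address-space conversions above with these
// patterns roughly as follows; the variable names and the legality setup are
// assumptions for illustration only.
//
//   LLVMTypeConverter converter(ctx);
//   RewritePatternSet patterns(ctx);
//   populateNVGPUToNVVMConversionPatterns(converter, patterns);
//   LLVMConversionTarget target(*ctx);
//   target.addLegalDialect<NVVM::NVVMDialect>();
//   if (failed(applyPartialConversion(op, target, std::move(patterns))))
//     return failure();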