MLIR 22.0.0git
NVGPUToNVVM.cpp
Go to the documentation of this file.
1//===- NVGPUToNVVM.cpp - NVGPU to NVVM dialect conversion -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
26#include "mlir/IR/Value.h"
27#include "mlir/Pass/Pass.h"
28#include "llvm/Support/Debug.h"
29#include "llvm/Support/DebugLog.h"
30#include "llvm/Support/ErrorHandling.h"
31#include "llvm/Support/raw_ostream.h"
32#include <optional>
33
34#define DEBUG_TYPE "nvgpu-to-nvvm"
35
36namespace mlir {
37#define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS
38#include "mlir/Conversion/Passes.h.inc"
39} // namespace mlir
40
41using namespace mlir;
42
/// Number of bits that needs to be excluded when building matrix descriptor for
/// wgmma operations.
/// NOTE(review): presumably the 4 LSBs are implicit because wgmma descriptor
/// fields encode 16-byte-aligned shared-memory addresses — confirm against the
/// PTX ISA matrix-descriptor format.
constexpr int exclude4LSB = 4;
46
47/// GPU has 32 bit registers, this function truncates values when larger width
48/// is not needed.
50 Type type = value.getType();
51 assert(llvm::isa<IntegerType>(type) && "expected an integer Value");
52 if (type.getIntOrFloatBitWidth() <= 32)
53 return value;
54 return LLVM::TruncOp::create(b, b.getI32Type(), value);
55}
56
57/// Returns the type for the intrinsic given the vectorResultType of the
58/// `gpu.mma.sync` operation.
59static Type inferIntrinsicResultType(Type vectorResultType) {
60 MLIRContext *ctx = vectorResultType.getContext();
61 auto a = cast<LLVM::LLVMArrayType>(vectorResultType);
62 auto f16x2Ty = VectorType::get(2, Float16Type::get(ctx));
63 auto i32Ty = IntegerType::get(ctx, 32);
64 auto i32x2Ty = VectorType::get(2, i32Ty);
65 Type f64Ty = Float64Type::get(ctx);
66 Type f64x2Ty = VectorType::get(2, f64Ty);
67 Type f32Ty = Float32Type::get(ctx);
68 Type f32x2Ty = VectorType::get(2, f32Ty);
69 if (a.getElementType() == f16x2Ty) {
70 return LLVM::LLVMStructType::getLiteral(
71 ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty));
72 }
73 if (a.getElementType() == i32x2Ty) {
74 return LLVM::LLVMStructType::getLiteral(
75 ctx,
76 SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, i32Ty));
77 }
78 if (a.getElementType() == f64x2Ty) {
79 return LLVM::LLVMStructType::getLiteral(ctx, {f64Ty, f64Ty});
80 }
81 if (a.getElementType() == f32x2Ty) {
82 return LLVM::LLVMStructType::getLiteral(
83 ctx,
84 SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, f32Ty));
85 }
86 if (a.getElementType() == VectorType::get(1, f32Ty)) {
87 return LLVM::LLVMStructType::getLiteral(
88 ctx, SmallVector<Type>(static_cast<size_t>(a.getNumElements()), f32Ty));
89 }
90 return vectorResultType;
91}
92
/// Convert the SSA result of the NVVM intrinsic `nvvm.mma.sync` (which is
/// always an LLVM struct) into a fragment that is compatible with the vector
/// type of this operation. This involves extracting elements from the struct
/// and inserting them into an LLVM array. These extra data-movement
/// operations should be canonicalized away by the LLVM backend.
///
/// \param intrinsicResultType the struct type produced by the intrinsic.
/// \param resultType          the converted (LLVM array) result type expected
///                            by callers of the original op.
/// \param intrinsicResult     the struct-typed SSA value to repack.
/// \returns an array-typed value when `resultType` is an LLVM array;
///          otherwise `intrinsicResult` is returned unchanged.
static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
                                    Type resultType, Value intrinsicResult,
                                    RewriterBase &rewriter) {
  MLIRContext *ctx = rewriter.getContext();
  // NOTE(review): `structType` is only dereferenced under `if (arrayType)`
  // below — assumes an array result always pairs with a struct intrinsic
  // result; confirm with callers.
  auto structType = dyn_cast<LLVM::LLVMStructType>(intrinsicResultType);
  auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType);
  Type i32Ty = rewriter.getI32Type();
  Type f32Ty = rewriter.getF32Type();
  Type f64Ty = rewriter.getF64Type();
  Type f16x2Ty = VectorType::get(2, rewriter.getF16Type());
  Type i32x2Ty = VectorType::get(2, i32Ty);
  Type f64x2Ty = VectorType::get(2, f64Ty);
  Type f32x2Ty = VectorType::get(2, f32Ty);
  Type f32x1Ty = VectorType::get(1, f32Ty);

  // Helper producing an i32 constant, used for insertelement indices.
  auto makeConst = [&](int32_t index) -> Value {
    return LLVM::ConstantOp::create(rewriter, loc, IntegerType::get(ctx, 32),
                                    rewriter.getI32IntegerAttr(index));
  };

  if (arrayType) {
    SmallVector<Value, 4> elements;

    // The intrinsic returns 32-bit wide elements in a form which can be
    // directly bitcasted and inserted into the result vector.
    if (arrayType.getElementType() == f16x2Ty ||
        arrayType.getElementType() == f32x1Ty) {
      for (unsigned i = 0; i < structType.getBody().size(); i++) {
        Value el =
            LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult, i);
        el = rewriter.createOrFold<LLVM::BitcastOp>(
            loc, arrayType.getElementType(), el);
        elements.push_back(el);
      }
    }

    // The intrinsic returns i32, f64, and f32 values as individual scalars,
    // even when the result is notionally a 64-bit wide element (e.g. f32x2). We
    // need to extract them from the struct and pack them into the 64-bit wide
    // rows of the vector result.
    if (arrayType.getElementType() == i32x2Ty ||
        arrayType.getElementType() == f64x2Ty ||
        arrayType.getElementType() == f32x2Ty) {

      // Two struct members are consumed per 2-element row.
      for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) {
        Value vec =
            LLVM::PoisonOp::create(rewriter, loc, arrayType.getElementType());
        Value x1 =
            LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult, i * 2);
        Value x2 = LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult,
                                                i * 2 + 1);
        vec = LLVM::InsertElementOp::create(rewriter, loc, vec.getType(), vec,
                                            x1, makeConst(0));
        vec = LLVM::InsertElementOp::create(rewriter, loc, vec.getType(), vec,
                                            x2, makeConst(1));
        elements.push_back(vec);
      }
    }

    // Create the final vectorized result.
    Value result = LLVM::PoisonOp::create(rewriter, loc, arrayType);
    for (const auto &el : llvm::enumerate(elements)) {
      result = LLVM::InsertValueOp::create(rewriter, loc, result, el.value(),
                                           el.index());
    }
    return result;
  }

  return intrinsicResult;
}
168
169/// The `gpu.mma.sync` converter below expects matrix fragment operands to be
170/// given as 2D `vectors` where the rows are 32b or 64b wide. The
171/// `nvvm.mma.sync` op expects these argments to be a given in a long list of
172/// scalars of certain types. This function helps unpack the `vector` arguments
173/// and cast them to the types expected by `nvvm.mma.sync`.
175 Value operand,
176 NVVM::MMATypes operandPtxType) {
178 Type i32Ty = b.getI32Type();
179 Type f64Ty = b.getF64Type();
180 Type f32Ty = b.getF32Type();
181 Type i64Ty = b.getI64Type();
182 Type i8x4Ty = VectorType::get(4, b.getI8Type());
183 Type i4x8Ty = VectorType::get(8, b.getIntegerType(4));
184 Type f32x1Ty = VectorType::get(1, f32Ty);
185 auto arrayTy = cast<LLVM::LLVMArrayType>(operand.getType());
186
187 for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
188 Value toUse = LLVM::ExtractValueOp::create(b, operand, i);
189
190 // For 4xi8 vectors, the intrinsic expects these to be provided as i32
191 // scalar types.
192 if (arrayTy.getElementType() == i8x4Ty ||
193 arrayTy.getElementType() == i4x8Ty ||
194 (arrayTy.getElementType() == f32x1Ty &&
195 operandPtxType == NVVM::MMATypes::tf32)) {
196 result.push_back(LLVM::BitcastOp::create(b, i32Ty, toUse));
197 continue;
198 }
199
200 // For some element types (i32, f32, f64), we need to unpack the inner
201 // vector/array type as well because the intrinsic expects individual
202 // scalars to be provided.
203 VectorType innerArrayTy = dyn_cast<VectorType>(arrayTy.getElementType());
204 if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty ||
205 innerArrayTy.getElementType() == f64Ty ||
206 innerArrayTy.getElementType() == f32Ty)) {
207 for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
208 idx < innerSize; idx++) {
209 result.push_back(LLVM::ExtractElementOp::create(
210 b, toUse,
211 LLVM::ConstantOp::create(b, i64Ty, b.getI64IntegerAttr(idx))));
212 }
213 continue;
214 }
215 result.push_back(toUse);
216 }
217 return result;
218}
219
/// Returns whether mbarrier object has shared memory address space.
/// Delegates the memory-space attribute check to the NVGPU dialect helper.
static bool isMbarrierShared(nvgpu::MBarrierGroupType barrierType) {
  return (mlir::nvgpu::NVGPUDialect::isSharedMemoryAddressSpace(
      barrierType.getMemorySpace()));
}
225
226/// Returns the memory space attribute of the mbarrier object.
228 nvgpu::MBarrierGroupType barrierType) {
229 Attribute memorySpace = {};
230 if (isMbarrierShared(barrierType)) {
231 memorySpace =
232 IntegerAttr::get(IntegerType::get(context, 64),
233 nvgpu::NVGPUDialect::kSharedMemoryAddressSpace);
234 }
235 return memorySpace;
236}
237
238/// Returns memref type of the mbarrier object. The type is defined in the
239/// MBarrierGroupType.
240MemRefType nvgpu::getMBarrierMemrefType(MLIRContext *context,
241 nvgpu::MBarrierGroupType barrierType) {
242 Attribute memorySpace = nvgpu::getMbarrierMemorySpace(context, barrierType);
243 MemRefLayoutAttrInterface layout;
244 return MemRefType::get({barrierType.getNumBarriers()},
245 IntegerType::get(context, 64), layout, memorySpace);
246}
247
248namespace {
249
/// Lowers `nvgpu.ldmatrix` to `nvvm.ldmatrix`, repacking the intrinsic's
/// i32-register results into the converted 2D vector result type.
struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
  using ConvertOpToLLVMPattern<nvgpu::LdMatrixOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::LdMatrixOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MLIRContext *ctx = getContext();
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);

    // The result type of ldmatrix will always be a struct of 32bit integer
    // registers if more than one 32bit value is returned. Otherwise, the result
    // is a single i32. The result type of the GPU operation is always a vector
    // of shape (NumRegisters, VectorRegister) where VectorRegister is the
    // vector type of the result and always 32 bits long. We bitcast the result
    // of the NVVM::LdMatrix to this vector type.
    auto vectorResultType = dyn_cast<VectorType>(op->getResultTypes()[0]);
    if (!vectorResultType) {
      return failure();
    }
    // Per-register vector type, e.g. vector<2xf16> for a 32-bit row.
    Type innerVectorType = VectorType::get(vectorResultType.getDimSize(1),
                                           vectorResultType.getElementType());

    int64_t num32BitRegs = vectorResultType.getDimSize(0);

    Type ldMatrixResultType;
    if (num32BitRegs > 1) {
      ldMatrixResultType = LLVM::LLVMStructType::getLiteral(
          ctx, SmallVector<Type>(num32BitRegs, rewriter.getI32Type()));
    } else {
      ldMatrixResultType = rewriter.getI32Type();
    }

    // Address of the first element addressed by the op's indices.
    auto srcMemrefType = cast<MemRefType>(op.getSrcMemref().getType());
    Value srcPtr =
        getStridedElementPtr(rewriter, b.getLoc(), srcMemrefType,
                             adaptor.getSrcMemref(), adaptor.getIndices());
    // ldmatrix always operates on 8x8 tiles of 16-bit elements here.
    auto shape = NVVM::LdStMatrixShapeAttr::get(rewriter.getContext(), 8, 8);
    Value ldMatrixResult = NVVM::LdMatrixOp::create(
        b, ldMatrixResultType, srcPtr,
        /*num=*/op.getNumTiles(),
        /*layout=*/op.getTranspose() ? NVVM::MMALayout::col
                                     : NVVM::MMALayout::row,
        /*shape=*/shape, /*eltType=*/NVVM::LdStMatrixEltType::B16);

    // The ldmatrix operation returns either a single i32 value or a struct of
    // i32 values. Here we unpack those values and cast them back to their
    // actual vector type (still of width 32b) and repack them into a result
    // struct.
    Type finalResultType = typeConverter->convertType(vectorResultType);
    Value result = LLVM::PoisonOp::create(b, finalResultType);
    for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) {
      Value i32Register =
          num32BitRegs > 1 ? LLVM::ExtractValueOp::create(b, ldMatrixResult, i)
                           : ldMatrixResult;
      Value casted = LLVM::BitcastOp::create(b, innerVectorType, i32Register);
      result = LLVM::InsertValueOp::create(b, result, casted, i);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};
312
313/// Convert the given type into the corresponding PTX type (NVVM::MMATypes
314/// enum).
315static FailureOr<NVVM::MMATypes> getNvvmMmaType(Type t) {
316 Type elType = getElementTypeOrSelf(t);
317 if (elType.isInteger(8))
318 return NVVM::MMATypes::s8;
319 if (elType.isInteger(4))
320 return NVVM::MMATypes::s4;
321 if (elType.isF16())
322 return NVVM::MMATypes::f16;
323 if (elType.isF64())
324 return NVVM::MMATypes::f64;
325 if (elType.isF32())
326 return NVVM::MMATypes::tf32;
327 return failure();
328}
329
330struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
331 using ConvertOpToLLVMPattern<nvgpu::MmaSyncOp>::ConvertOpToLLVMPattern;
332
333 LogicalResult
334 matchAndRewrite(nvgpu::MmaSyncOp op, OpAdaptor adaptor,
335 ConversionPatternRewriter &rewriter) const override {
336 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
337 // Get the shapes of the MMAMatrix type being used. The shapes will
338 // choose which intrinsic this op will be lowered to.
339 VectorType aType = op.getMatrixA().getType();
340 VectorType bType = op.getMatrixA().getType();
341 VectorType cType = op.getMatrixC().getType();
342
343 std::array<int64_t, 3> gemmShape = op.getMmaShapeAsArray();
344
345 // Tensor Cores (mma.sync) on F32 works only with TensorFloat32 (TF32).
346 bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
347 if (aType.getElementType().isF32() && !tf32Enabled)
348 return failure();
349
350 FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
351 if (failed(ptxTypeA))
352 return op->emitOpError("failed to deduce operand PTX types");
353 FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
354 if (failed(ptxTypeB))
355 return op->emitOpError("failed to deduce operand PTX types");
356 std::optional<NVVM::MMATypes> ptxTypeC =
357 NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
358 /*isAccumulator=*/true);
359 if (!ptxTypeC)
360 return op->emitError(
361 "could not infer the PTX type for the accumulator/result");
362
363 // TODO: add an attribute to the op to customize this behavior.
364 std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
365 if (isa<IntegerType>(aType.getElementType()))
366 overflow = NVVM::MMAIntOverflow::satfinite;
367
368 SmallVector<Value> matA =
369 unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
370 SmallVector<Value> matB =
371 unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
372 SmallVector<Value> matC =
373 unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);
374
375 Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
376 Type intrinsicResTy = inferIntrinsicResultType(
377 typeConverter->convertType(op->getResultTypes()[0]));
378 Value intrinsicResult =
379 NVVM::MmaOp::create(b, intrinsicResTy, matA, matB, matC,
380 /*shape=*/gemmShape,
381 /*b1Op=*/std::nullopt,
382 /*intOverflow=*/overflow,
383 /*multiplicandPtxTypes=*/
384 std::array<NVVM::MMATypes, 2>{*ptxTypeA, *ptxTypeB},
385 /*multiplicandLayouts=*/
386 std::array<NVVM::MMALayout, 2>{
387 NVVM::MMALayout::row, NVVM::MMALayout::col});
388 rewriter.replaceOp(op, convertIntrinsicResult(op.getLoc(), intrinsicResTy,
389 desiredRetTy, intrinsicResult,
390 rewriter));
391 return success();
392 }
393};
394
395struct ConvertNVGPUToNVVMPass
396 : public impl::ConvertNVGPUToNVVMPassBase<ConvertNVGPUToNVVMPass> {
397 using Base::Base;
398
399 void runOnOperation() override {
400 LowerToLLVMOptions options(&getContext());
401 RewritePatternSet patterns(&getContext());
402 LLVMTypeConverter converter(&getContext(), options);
403 IRRewriter rewriter(&getContext());
405 converter, [](gpu::AddressSpace space) -> unsigned {
406 switch (space) {
407 case gpu::AddressSpace::Global:
408 return static_cast<unsigned>(NVVM::NVVMMemorySpace::Global);
409 case gpu::AddressSpace::Workgroup:
410 return static_cast<unsigned>(NVVM::NVVMMemorySpace::Shared);
411 case gpu::AddressSpace::Private:
412 return 0;
413 }
414 llvm_unreachable("unknown address space enum value");
415 return static_cast<unsigned>(NVVM::NVVMMemorySpace::Generic);
416 });
417 /// device-side async tokens cannot be materialized in nvvm. We just
418 /// convert them to a dummy i32 type in order to easily drop them during
419 /// conversion.
420 converter.addConversion([&](nvgpu::DeviceAsyncTokenType type) -> Type {
421 return converter.convertType(IntegerType::get(type.getContext(), 32));
422 });
423 converter.addConversion([&](nvgpu::WarpgroupAccumulatorType type) -> Type {
424 Type elemType = type.getFragmented().getElementType();
425 int64_t sizeM = type.getFragmented().getDimSize(0);
426 int64_t sizeN = type.getFragmented().getDimSize(1);
427
428 unsigned numMembers;
429 if (elemType.isF32() || elemType.isInteger(32))
430 numMembers = sizeN / 2;
431 else if (elemType.isF16())
432 numMembers = sizeN / 4;
433 else
434 llvm_unreachable("unsupported type for warpgroup accumulator");
435
436 SmallVector<Type> innerStructBody;
437 for (unsigned i = 0; i < numMembers; i++)
438 innerStructBody.push_back(elemType);
439 auto innerStructType =
440 LLVM::LLVMStructType::getLiteral(type.getContext(), innerStructBody);
441
442 SmallVector<Type> structBody;
443 for (int i = 0; i < sizeM; i += kWgmmaSizeM)
444 structBody.push_back(innerStructType);
445
446 auto convertedType =
447 LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
448 return converter.convertType(convertedType);
449 });
450 converter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {
451 return converter.convertType(IntegerType::get(type.getContext(), 64));
452 });
453 converter.addConversion(
454 [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type {
455 return converter.convertType(IntegerType::get(type.getContext(), 64));
456 });
457 converter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type {
458 return converter.convertType(
459 nvgpu::getMBarrierMemrefType(rewriter.getContext(), type));
460 });
461 converter.addConversion([&](nvgpu::TensorMapDescriptorType type) -> Type {
462 return LLVM::LLVMPointerType::get(type.getContext());
463 });
465 LLVMConversionTarget target(getContext());
466 target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
467 target.addLegalDialect<::mlir::arith::ArithDialect>();
468 target.addLegalDialect<::mlir::memref::MemRefDialect>();
469 target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
471 converter, patterns, target);
472 if (failed(applyPartialConversion(getOperation(), target,
473 std::move(patterns))))
474 signalPassFailure();
475 }
476};
477
478/// Returns the constraints for the sparse MMA inline assembly instruction.
479static std::string buildMmaSparseAsmConstraintString(unsigned matASize,
480 unsigned matBSize,
481 unsigned matCSize) {
482 std::string str;
483 llvm::raw_string_ostream ss(str);
484 for (unsigned i = 0; i < matCSize; i++)
485 ss << "=r,";
486 for (unsigned i = 0; i < matASize + matBSize + matCSize; i++)
487 ss << "r,";
488 // The final operand is for the sparsity metadata.
489 // The sparsity selector appears as direct literal.
490 ss << "r";
491 return str;
492}
493
/// Returns the string for the `mma.sp.sync` instruction that corresponds to
/// the given parameters. Note that this function doesn't do any validation,
/// it's expected that the provided parameters correspond to a valid
/// instruction.
///
/// \param shape            the (m, n, k) GEMM tile shape.
/// \param matASize/matBSize/matCSize  number of registers per fragment; these
///        drive how many `$N` asm placeholders are emitted per operand group.
/// \param overflow         when set, inserts the int-overflow modifier (e.g.
///                         `satfinite`) into the mnemonic.
/// \param metaDataSelector sparsity selector emitted as a direct literal
///                         (`0x0` or `0x1`).
static std::string buildMmaSparseAsmString(
    const std::array<int64_t, 3> &shape, unsigned matASize, unsigned matBSize,
    unsigned matCSize, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
    NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
    std::optional<NVVM::MMAIntOverflow> overflow, unsigned metaDataSelector) {
  auto ptxTypeStr = [](NVVM::MMATypes ptxType) {
    return NVVM::stringifyMMATypes(ptxType);
  };

  std::string asmStr;
  llvm::raw_string_ostream ss(asmStr);
  // Mnemonic: mma.sp.sync.aligned.m<M>n<N>k<K>.row.col.[<overflow>.]D.A.B.C
  ss << "mma.sp.sync.aligned.m" << shape[0] << "n" << shape[1] << "k"
     << shape[2] << ".row.col.";

  if (overflow)
    ss << NVVM::stringifyMMAIntOverflow(*overflow) << ".";

  ss << ptxTypeStr(ptxTypeD) << "." << ptxTypeStr(ptxTypeA) << "."
     << ptxTypeStr(ptxTypeB) << "." << ptxTypeStr(ptxTypeC) << " ";
  // Running index of the `$N` inline-asm placeholders across all groups.
  unsigned asmArgIdx = 0;

  // The operand string is structured into sections `{matC elements...},
  // {matA elements...}, {matB elements...}, {matC elements}`.
  for (const auto arrSize : {matCSize, matASize, matBSize, matCSize}) {
    ss << "{";
    for (unsigned i = 0; i < arrSize; i++)
      ss << "$" << asmArgIdx++ << (i < arrSize - 1 ? "," : "");
    ss << "},";
  }
  // Trailing operands: the sparsity metadata register and the selector
  // literal (restricted to 0 or 1 by the assert).
  ss << "$" << asmArgIdx++ << ",";
  assert(metaDataSelector <= 1);
  ss << "0x" << metaDataSelector << ";";
  return asmStr;
}
532
/// Builds an inline assembly operation corresponding to the specified MMA
/// sparse sync operation.
///
/// The asm string and constraint string are generated from the fragment
/// sizes; operands are passed in the order {A..., B..., C..., metadata} to
/// match the constraint list.
/// \returns the created llvm.inline_asm op (marked has_side_effects).
static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
    ImplicitLocOpBuilder &b, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
    NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
    std::optional<NVVM::MMAIntOverflow> overflow, ArrayRef<Value> unpackedAData,
    ArrayRef<Value> unpackedB, ArrayRef<Value> unpackedC, Value indexData,
    int64_t metadataSelector, const std::array<int64_t, 3> &shape,
    Type intrinsicResultType) {
  auto asmDialectAttr =
      LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_ATT);

  const unsigned matASize = unpackedAData.size();
  const unsigned matBSize = unpackedB.size();
  const unsigned matCSize = unpackedC.size();

  std::string asmStr = buildMmaSparseAsmString(
      shape, matASize, matBSize, matCSize, ptxTypeA, ptxTypeB, ptxTypeC,
      ptxTypeD, overflow, metadataSelector);
  std::string constraintStr =
      buildMmaSparseAsmConstraintString(matASize, matBSize, matCSize);

  // Operand order must match the constraint string: A, B, C fragments, then
  // the sparsity metadata register.
  SmallVector<Value> asmVals;
  asmVals.reserve(matASize + matBSize + matCSize + 1);
  for (ArrayRef<Value> args : {unpackedAData, unpackedB, unpackedC})
    llvm::append_range(asmVals, args);
  asmVals.push_back(indexData);

  return LLVM::InlineAsmOp::create(b,
                                   /*resultTypes=*/intrinsicResultType,
                                   /*operands=*/asmVals,
                                   /*asm_string=*/asmStr,
                                   /*constraints=*/constraintStr,
                                   /*has_side_effects=*/true,
                                   /*is_align_stack=*/false,
                                   LLVM::TailCallKind::None,
                                   /*asm_dialect=*/asmDialectAttr,
                                   /*operand_attrs=*/ArrayAttr());
}
572
/// Lowers `nvgpu.mma.sp.sync` to inline assembly.
struct NVGPUMmaSparseSyncLowering
    : public ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp> {
  using ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::MmaSparseSyncOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    // Get the shapes of the MMAMatrix type being used. The shapes will
    // choose which intrinsic this op will be lowered to.
    VectorType aType = op.getMatrixA().getType();
    VectorType bType = op.getMatrixB().getType();
    VectorType cType = op.getMatrixC().getType();

    FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
    if (failed(ptxTypeA))
      return op->emitOpError("failed to deduce operand PTX types");
    FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
    if (failed(ptxTypeB))
      return op->emitOpError("failed to deduce operand PTX types");
    std::optional<NVVM::MMATypes> ptxTypeC =
        NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
                                         /*isAccumulator=*/true);
    if (!ptxTypeC)
      return op->emitError(
          "could not infer the PTX type for the accumulator/result");

    // Same as `mma.sync`, F32 works only with TensorFloat32 (TF32).
    bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
    if (aType.getElementType().isF32() && !tf32Enabled)
      return failure();

    // Integer MMAs saturate on overflow.
    // TODO: add an attribute to the op to customize this behavior.
    std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
    if (isa<IntegerType>(aType.getElementType()))
      overflow = NVVM::MMAIntOverflow::satfinite;

    // Flatten the vector fragments into the scalar register lists the asm
    // expects.
    SmallVector<Value> matA =
        unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
    SmallVector<Value> matB =
        unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
    SmallVector<Value> matC =
        unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);

    Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
    Type intrinsicResTy = inferIntrinsicResultType(
        typeConverter->convertType(op->getResultTypes()[0]));

    // Bitcast the sparse metadata from vector<2xi16> to an i32.
    Value sparseMetadata = adaptor.getSparseMetadata();
    if (sparseMetadata.getType() != VectorType::get(2, rewriter.getI16Type()))
      return op->emitOpError() << "Expected metadata type to be LLVM "
                                  "VectorType of 2 i16 elements";
    sparseMetadata =
        LLVM::BitcastOp::create(b, rewriter.getI32Type(), sparseMetadata);

    // The D (result) type reuses the accumulator's PTX type (ptxTypeC is
    // passed twice).
    FailureOr<LLVM::InlineAsmOp> intrinsicResult = emitMmaSparseSyncOpAsm(
        b, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB,
        matC, sparseMetadata, op.getSparsitySelector(), op.getMmaShapeAsArray(),
        intrinsicResTy);
    if (failed(intrinsicResult))
      return failure();

    assert((*intrinsicResult).getNumResults() == 1 &&
           "expected inline asm op returns a single LLVM struct type");
    rewriter.replaceOp(
        op, convertIntrinsicResult(op.getLoc(), intrinsicResTy, desiredRetTy,
                                   (*intrinsicResult)->getResult(0), rewriter));
    return success();
  }
};
645
646struct NVGPUAsyncCopyLowering
647 : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCopyOp> {
648 using ConvertOpToLLVMPattern<
649 nvgpu::DeviceAsyncCopyOp>::ConvertOpToLLVMPattern;
650
651 LogicalResult
652 matchAndRewrite(nvgpu::DeviceAsyncCopyOp op, OpAdaptor adaptor,
653 ConversionPatternRewriter &rewriter) const override {
654 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
655 Location loc = op.getLoc();
656 auto dstMemrefType = cast<MemRefType>(op.getDst().getType());
657 Value dstPtr =
658 getStridedElementPtr(rewriter, b.getLoc(), dstMemrefType,
659 adaptor.getDst(), adaptor.getDstIndices());
660 FailureOr<unsigned> dstAddressSpace =
661 getTypeConverter()->getMemRefAddressSpace(dstMemrefType);
662 if (failed(dstAddressSpace))
663 return rewriter.notifyMatchFailure(
664 loc, "destination memref address space not convertible to integer");
665
666 auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
667 FailureOr<unsigned> srcAddressSpace =
668 getTypeConverter()->getMemRefAddressSpace(srcMemrefType);
669 if (failed(srcAddressSpace))
670 return rewriter.notifyMatchFailure(
671 loc, "source memref address space not convertible to integer");
672
673 Value scrPtr =
674 getStridedElementPtr(rewriter, loc, srcMemrefType, adaptor.getSrc(),
675 adaptor.getSrcIndices());
676 // Intrinsics takes a global pointer so we need an address space cast.
677 auto srcPointerGlobalType = LLVM::LLVMPointerType::get(
678 op->getContext(), static_cast<unsigned>(NVVM::NVVMMemorySpace::Global));
679 scrPtr = LLVM::AddrSpaceCastOp::create(b, srcPointerGlobalType, scrPtr);
680 int64_t dstElements = adaptor.getDstElements().getZExtValue();
681 int64_t sizeInBytes =
682 (dstMemrefType.getElementTypeBitWidth() * dstElements) / 8;
683 // When the optional SrcElements argument is *not* present, the regular
684 // CpAsyncOp is generated. CopyAsyncOp reads bytes from source (global
685 // memory) to fill DstElements number of elements in the destination
686 // (shared memory).
687 Value srcBytes = adaptor.getSrcElements();
688 if (srcBytes) {
689 // When the optional SrcElements argument is present, the source (global
690 // memory) of CpAsyncOp is read only for SrcElements number of elements.
691 // The rest of the DstElements in the destination (shared memory) are
692 // filled with zeros.
693 Value c3I32 =
694 LLVM::ConstantOp::create(b, b.getI32Type(), b.getI32IntegerAttr(3));
695 Value bitwidth = LLVM::ConstantOp::create(
696 b, b.getI32Type(),
697 b.getI32IntegerAttr(srcMemrefType.getElementTypeBitWidth()));
698 Value srcElementsI32 = LLVM::TruncOp::create(b, b.getI32Type(), srcBytes);
699 srcBytes = LLVM::LShrOp::create(
700 b, LLVM::MulOp::create(b, bitwidth, srcElementsI32), c3I32);
701 }
702 // Cache global (.cg) for 16 dst bytes, Cache all (.ca) for sizes other than
703 // 16 dst bytes.
704 NVVM::LoadCacheModifierKind cacheModifier =
705 (op.getBypassL1().value_or(false) && sizeInBytes == 16)
706 ? NVVM::LoadCacheModifierKind::CG
707 : NVVM::LoadCacheModifierKind::CA;
708
709 NVVM::CpAsyncOp::create(
710 b, dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes),
711 NVVM::LoadCacheModifierKindAttr::get(op->getContext(), cacheModifier),
712 srcBytes);
713
714 // Drop the result token.
715 Value zero =
716 LLVM::ConstantOp::create(b, IntegerType::get(op.getContext(), 32),
717 rewriter.getI32IntegerAttr(0));
718 rewriter.replaceOp(op, zero);
719 return success();
720 }
721};
722
723struct NVGPUAsyncCreateGroupLowering
724 : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCreateGroupOp> {
725 using ConvertOpToLLVMPattern<
726 nvgpu::DeviceAsyncCreateGroupOp>::ConvertOpToLLVMPattern;
727
728 LogicalResult
729 matchAndRewrite(nvgpu::DeviceAsyncCreateGroupOp op, OpAdaptor adaptor,
730 ConversionPatternRewriter &rewriter) const override {
731 NVVM::CpAsyncCommitGroupOp::create(rewriter, op.getLoc());
732 // Drop the result token.
733 Value zero = LLVM::ConstantOp::create(rewriter, op->getLoc(),
734 IntegerType::get(op.getContext(), 32),
735 rewriter.getI32IntegerAttr(0));
736 rewriter.replaceOp(op, zero);
737 return success();
738 }
739};
740
741struct NVGPUAsyncWaitLowering
742 : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncWaitOp> {
743 using ConvertOpToLLVMPattern<
744 nvgpu::DeviceAsyncWaitOp>::ConvertOpToLLVMPattern;
745
746 LogicalResult
747 matchAndRewrite(nvgpu::DeviceAsyncWaitOp op, OpAdaptor adaptor,
748 ConversionPatternRewriter &rewriter) const override {
749 // If numGroup is not present pick 0 as a conservative correct value.
750 int32_t numGroups = adaptor.getNumGroups().value_or(0);
751 NVVM::CpAsyncWaitGroupOp::create(rewriter, op.getLoc(), numGroups);
752 rewriter.eraseOp(op);
753 return success();
754 }
755};
756
/// Creates mbarrier object in shared memory
struct NVGPUMBarrierCreateLowering
    : public ConvertOpToLLVMPattern<nvgpu::MBarrierCreateOp> {
  using ConvertOpToLLVMPattern<nvgpu::MBarrierCreateOp>::ConvertOpToLLVMPattern;

  /// Creates a module-level `memref.global` ("__mbarrier", 8-byte aligned,
  /// private visibility) backing the barrier group. The SymbolTable insert
  /// uniques the symbol name if it already exists.
  template <typename moduleT>
  memref::GlobalOp generateGlobalBarrier(ConversionPatternRewriter &rewriter,
                                         Operation *funcOp, moduleT moduleOp,
                                         MemRefType barrierType) const {
    SymbolTable symbolTable(moduleOp);
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(&moduleOp.front());
    auto global = memref::GlobalOp::create(
        rewriter, funcOp->getLoc(), "__mbarrier",
        /*sym_visibility=*/rewriter.getStringAttr("private"),
        /*type=*/barrierType,
        /*initial_value=*/ElementsAttr(),
        /*constant=*/false,
        /*alignment=*/rewriter.getI64IntegerAttr(8));
    symbolTable.insert(global);
    return global;
  }

  LogicalResult
  matchAndRewrite(nvgpu::MBarrierCreateOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Operation *funcOp = op->getParentOp();
    MemRefType barrierType = nvgpu::getMBarrierMemrefType(
        rewriter.getContext(), op.getBarriers().getType());

    // NOTE(review): if funcOp is nested in neither a gpu.module nor a
    // builtin.module, `global` stays null and getName() below would crash —
    // presumably unreachable in practice; confirm.
    memref::GlobalOp global;
    if (auto moduleOp = funcOp->getParentOfType<gpu::GPUModuleOp>())
      global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
    else if (auto moduleOp = funcOp->getParentOfType<ModuleOp>())
      global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);

    rewriter.setInsertionPoint(op);
    rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(op, barrierType,
                                                     global.getName());
    return success();
  }
};
799
800/// Base class for lowering mbarrier operations to nvvm intrinsics.
801template <typename SourceOp>
802struct MBarrierBasePattern : public ConvertOpToLLVMPattern<SourceOp> {
803public:
804 using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
805 /// Returns the base pointer of the mbarrier object.
806 Value getMbarrierPtr(ImplicitLocOpBuilder &b,
807 nvgpu::MBarrierGroupType mbarType, Value memrefDesc,
808 Value mbarId,
809 ConversionPatternRewriter &rewriter) const {
810 MemRefType mbarrierMemrefType =
811 nvgpu::getMBarrierMemrefType(rewriter.getContext(), mbarType);
813 rewriter, b.getLoc(), mbarrierMemrefType, memrefDesc, {mbarId});
814 }
815};
816
817struct NVGPUMBarrierGetLowering
818 : public MBarrierBasePattern<nvgpu::MBarrierGetOp> {
819 using MBarrierBasePattern<nvgpu::MBarrierGetOp>::MBarrierBasePattern;
820
821 LogicalResult
822 matchAndRewrite(nvgpu::MBarrierGetOp op, OpAdaptor adaptor,
823 ConversionPatternRewriter &rewriter) const override {
824 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
825 nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType();
826 rewriter.setInsertionPoint(op);
827 Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(),
828 adaptor.getMbarId(), rewriter);
829 Type resType = op.getMbarrierPointer().getType();
830 rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(op, resType, barrier);
831 return success();
832 }
833};
834
835/// Lowers `nvgpu.mbarrier.init` to `nvvm.mbarrier.init`
836struct NVGPUMBarrierInitLowering
837 : public MBarrierBasePattern<nvgpu::MBarrierInitOp> {
838 using MBarrierBasePattern<nvgpu::MBarrierInitOp>::MBarrierBasePattern;
839
840 LogicalResult
841 matchAndRewrite(nvgpu::MBarrierInitOp op, OpAdaptor adaptor,
842 ConversionPatternRewriter &rewriter) const override {
843 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
844 nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType();
845 rewriter.setInsertionPoint(op);
846 Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(),
847 adaptor.getMbarId(), rewriter);
848 Value count = truncToI32(b, adaptor.getCount());
849 rewriter.replaceOpWithNewOp<NVVM::MBarrierInitOp>(op, barrier, count,
850 adaptor.getPredicate());
851 return success();
852 }
853};
854
855/// Lowers `nvgpu.mbarrier.arrive` to `nvvm.mbarrier.arrive`
856struct NVGPUMBarrierArriveLowering
857 : public MBarrierBasePattern<nvgpu::MBarrierArriveOp> {
858 using MBarrierBasePattern<nvgpu::MBarrierArriveOp>::MBarrierBasePattern;
859 LogicalResult
860 matchAndRewrite(nvgpu::MBarrierArriveOp op, OpAdaptor adaptor,
861 ConversionPatternRewriter &rewriter) const override {
862 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
863 Value barrier =
864 getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
865 adaptor.getMbarId(), rewriter);
866 Type tokenType = getTypeConverter()->convertType(
867 nvgpu::MBarrierTokenType::get(op->getContext()));
868 rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveOp>(op, tokenType, barrier);
869 return success();
870 }
871};
872
873/// Lowers `nvgpu.mbarrier.arrive.nocomplete` to
874/// `nvvm.mbarrier.arrive.nocomplete`
875struct NVGPUMBarrierArriveNoCompleteLowering
876 : public MBarrierBasePattern<nvgpu::MBarrierArriveNoCompleteOp> {
877 using MBarrierBasePattern<
878 nvgpu::MBarrierArriveNoCompleteOp>::MBarrierBasePattern;
879 LogicalResult
880 matchAndRewrite(nvgpu::MBarrierArriveNoCompleteOp op, OpAdaptor adaptor,
881 ConversionPatternRewriter &rewriter) const override {
882 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
883 Value barrier =
884 getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
885 adaptor.getMbarId(), rewriter);
886 Type tokenType = getTypeConverter()->convertType(
887 nvgpu::MBarrierTokenType::get(op->getContext()));
888 Value count = truncToI32(b, adaptor.getCount());
889 rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteOp>(
890 op, tokenType, barrier, count);
891 return success();
892 }
893};
894
895/// Lowers `nvgpu.mbarrier.test.wait` to `nvvm.mbarrier.test.wait`
896struct NVGPUMBarrierTestWaitLowering
897 : public MBarrierBasePattern<nvgpu::MBarrierTestWaitOp> {
898 using MBarrierBasePattern<nvgpu::MBarrierTestWaitOp>::MBarrierBasePattern;
899 LogicalResult
900 matchAndRewrite(nvgpu::MBarrierTestWaitOp op, OpAdaptor adaptor,
901 ConversionPatternRewriter &rewriter) const override {
902 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
903 Value barrier =
904 getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
905 adaptor.getMbarId(), rewriter);
906 Type retType = rewriter.getI1Type();
907 rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitOp>(op, retType, barrier,
908 adaptor.getToken());
909 return success();
910 }
911};
912
913struct NVGPUMBarrierArriveExpectTxLowering
914 : public MBarrierBasePattern<nvgpu::MBarrierArriveExpectTxOp> {
915 using MBarrierBasePattern<
916 nvgpu::MBarrierArriveExpectTxOp>::MBarrierBasePattern;
917 LogicalResult
918 matchAndRewrite(nvgpu::MBarrierArriveExpectTxOp op, OpAdaptor adaptor,
919 ConversionPatternRewriter &rewriter) const override {
920 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
921 Value barrier =
922 getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
923 adaptor.getMbarId(), rewriter);
924 Value txcount = truncToI32(b, adaptor.getTxcount());
925 rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>(
926 op, Type{}, // return-value is optional and is void by default
927 barrier, txcount, // barrier and txcount
928 NVVM::MemScopeKind::CTA, // default scope is CTA
929 false, // relaxed-semantics is false
930 adaptor.getPredicate());
931 return success();
932 }
933};
934
935struct NVGPUMBarrierTryWaitParityLowering
936 : public MBarrierBasePattern<nvgpu::MBarrierTryWaitParityOp> {
937 using MBarrierBasePattern<
938 nvgpu::MBarrierTryWaitParityOp>::MBarrierBasePattern;
939 LogicalResult
940 matchAndRewrite(nvgpu::MBarrierTryWaitParityOp op, OpAdaptor adaptor,
941 ConversionPatternRewriter &rewriter) const override {
942 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
943 Value barrier =
944 getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
945 adaptor.getMbarId(), rewriter);
946 Value ticks = truncToI32(b, adaptor.getTicks());
947 Value phase =
948 LLVM::ZExtOp::create(b, b.getI32Type(), adaptor.getPhaseParity());
949 rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParityOp>(op, barrier,
950 phase, ticks);
951 return success();
952 }
953};
954
955struct NVGPUTmaAsyncLoadOpLowering
956 : public MBarrierBasePattern<nvgpu::TmaAsyncLoadOp> {
957 using MBarrierBasePattern<nvgpu::TmaAsyncLoadOp>::MBarrierBasePattern;
958 LogicalResult
959 matchAndRewrite(nvgpu::TmaAsyncLoadOp op, OpAdaptor adaptor,
960 ConversionPatternRewriter &rewriter) const override {
961 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
962 auto srcMemrefType = cast<MemRefType>(op.getDst().getType());
963 Value dest = getStridedElementPtr(rewriter, op->getLoc(), srcMemrefType,
964 adaptor.getDst(), {});
965 // Intrinsics takes a shared-cluster pointer so we need an
966 // address space cast from 3 to 7.
967 // TODO: Introduce AS(7) in NVGPU.
968 auto ptrSharedClusterType = LLVM::LLVMPointerType::get(
969 op->getContext(),
970 static_cast<unsigned>(NVVM::NVVMMemorySpace::SharedCluster));
971 dest = LLVM::AddrSpaceCastOp::create(b, ptrSharedClusterType, dest);
972
973 Value barrier =
974 getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
975 adaptor.getMbarId(), rewriter);
976
977 SmallVector<Value> coords = adaptor.getCoordinates();
978 for (auto [index, value] : llvm::enumerate(coords)) {
979 coords[index] = truncToI32(b, value);
980 }
981
982 // TODO: Enhance the NVGPU Op for other modes too
983 rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(
984 op, dest, adaptor.getTensorMapDescriptor(), coords, barrier,
985 ValueRange{}, adaptor.getMulticastMask(), Value{},
986 NVVM::TMALoadMode::TILE, // default is TILE mode
987 false, // default is cluster-scope
988 nullptr, // default is no cta-group
989 adaptor.getPredicate());
990 return success();
991 }
992};
993
994struct NVGPUTmaAsyncStoreOpLowering
995 : public MBarrierBasePattern<nvgpu::TmaAsyncStoreOp> {
996 using MBarrierBasePattern<nvgpu::TmaAsyncStoreOp>::MBarrierBasePattern;
997 LogicalResult
998 matchAndRewrite(nvgpu::TmaAsyncStoreOp op, OpAdaptor adaptor,
999 ConversionPatternRewriter &rewriter) const override {
1000 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1001 auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
1002 Value dest = getStridedElementPtr(rewriter, op->getLoc(), srcMemrefType,
1003 adaptor.getSrc(), {});
1004 SmallVector<Value> coords = adaptor.getCoordinates();
1005 for (auto [index, value] : llvm::enumerate(coords)) {
1006 coords[index] = truncToI32(b, value);
1007 }
1008
1009 // TODO: Enhance the NVGPU Op for other modes too
1010 rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(
1011 op, adaptor.getTensorMapDescriptor(), dest, coords, Value{},
1012 NVVM::TMAStoreMode::TILE, // default is TILE mode
1013 adaptor.getPredicate());
1014 return success();
1015 }
1016};
1017
struct NVGPUGenerateWarpgroupDescriptorLowering
    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupGenerateDescriptorOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::WarpgroupGenerateDescriptorOp>::ConvertOpToLLVMPattern;

  /// Builds the 64-bit wgmma matrix descriptor by packing swizzle mode, base
  /// offset, stride, leading dimension, and the matrix base address into the
  /// bit ranges annotated below, then replaces the op with that i64 value.
  LogicalResult
  matchAndRewrite(nvgpu::WarpgroupGenerateDescriptorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    ImplicitLocOpBuilder b(op->getLoc(), rewriter);

    nvgpu::TensorMapSwizzleKind swizzleKind =
        op.getTensorMap().getType().getSwizzle();

    // Swizzle atom size in bytes (1 when no swizzle is requested).
    unsigned layout =
        (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 128
        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 64
        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 32
        : 1;
    // Encoded swizzle-mode field of the descriptor (0 means no swizzle).
    unsigned swizzle =
        (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 1
        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 2
        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 3
        : 0;

    auto ti64 = b.getIntegerType(64);
    // Helpers emitting i64 constant / shift / or ops used to pack the fields.
    auto makeConst = [&](uint64_t index) -> Value {
      return LLVM::ConstantOp::create(b, ti64, b.getI64IntegerAttr(index));
    };
    auto shiftLeft = [&](Value value, unsigned shift) -> Value {
      return LLVM::ShlOp::create(b, ti64, value, makeConst(shift));
    };
    auto shiftRight = [&](Value value, unsigned shift) -> Value {
      return LLVM::LShrOp::create(b, ti64, value, makeConst(shift));
    };
    auto insertBit = [&](Value desc, Value val, int startBit) {
      return LLVM::OrOp::create(b, ti64, desc, shiftLeft(val, startBit));
    };

    int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0);
    // Stride and leading-dimension byte values are stored with their 4 LSBs
    // dropped, as required by the wgmma descriptor encoding (exclude4LSB).
    uint64_t strideDimVal = (layout << 3) >> exclude4LSB;
    uint64_t leadDimVal = (sizeN * layout) >> exclude4LSB;
    uint64_t offsetVal = 0;

    Value strideDim = makeConst(strideDimVal);
    Value leadDim = makeConst(leadDimVal);

    Value baseAddr = getStridedElementPtr(
        rewriter, op->getLoc(), cast<MemRefType>(op.getTensor().getType()),
        adaptor.getTensor(), {});
    Value basePtr = LLVM::PtrToIntOp::create(b, ti64, baseAddr);
    // Just use 14 bits for base address: shift out the upper 18 bits, then
    // shift right past the low 4 bits as well.
    Value basePtr14bit = shiftRight(shiftLeft(basePtr, 46), 50);

    int startSwizzleBit = 62, startOffsetBit = 49, startStrideBit = 32,
        startLeadBit = 16, startBaseAddrBit = 0;
    Value dsc = makeConst(0);
    // [62,64) swizzle type
    dsc = insertBit(dsc, makeConst(swizzle), startSwizzleBit);
    // [49,52) base_offset
    dsc = insertBit(dsc, makeConst(offsetVal), startOffsetBit);
    // [32,46) stride
    dsc = insertBit(dsc, strideDim, startStrideBit);
    // [16,30) leading dimension
    dsc = insertBit(dsc, leadDim, startLeadBit);
    // [0,14) start_address
    dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit);

    LDBG() << "Generating warpgroup.descriptor: " << "leading_off:"
           << leadDimVal << "\t" << "stride_off :" << strideDimVal << "\t"
           << "base_offset:" << offsetVal << "\t" << "layout_type:" << swizzle
           << " (" << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
           << ")\n start_addr : " << baseAddr;

    rewriter.replaceOp(op, dsc);
    return success();
  }
};
1096
/// Emits an LLVM dialect constant of type i64 holding `index`.
/// NOTE(review): the attribute is built with getI32IntegerAttr while the
/// result type is i64 — presumably the constant op reconciles the widths;
/// confirm this mismatch is intended.
static Value makeI64Const(ImplicitLocOpBuilder &b, int32_t index) {
  return LLVM::ConstantOp::create(b, b.getIntegerType(64),
                                  b.getI32IntegerAttr(index));
}
1101
1102/// Returns a Value that holds data type enum that is expected by CUDA driver.
1103static Value elementTypeAsLLVMConstant(ImplicitLocOpBuilder &b, Type type) {
1104 // Enum is from CUDA driver API
1105 // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html
1106 enum CUtensorMapDataTypeEnum {
1107 CU_TENSOR_MAP_DATA_TYPE_UINT8 = 0,
1108 CU_TENSOR_MAP_DATA_TYPE_UINT16,
1109 CU_TENSOR_MAP_DATA_TYPE_UINT32,
1110 CU_TENSOR_MAP_DATA_TYPE_INT32,
1111 CU_TENSOR_MAP_DATA_TYPE_UINT64,
1112 CU_TENSOR_MAP_DATA_TYPE_INT64,
1113 CU_TENSOR_MAP_DATA_TYPE_FLOAT16,
1114 CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
1115 CU_TENSOR_MAP_DATA_TYPE_FLOAT64,
1116 CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
1117 CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ,
1118 CU_TENSOR_MAP_DATA_TYPE_TFLOAT32,
1119 CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ
1120 };
1121
1122 if (type.isUnsignedInteger(8))
1123 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT8);
1124 if (type.isUnsignedInteger(16))
1125 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT16);
1126 if (type.isUnsignedInteger(32))
1127 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT32);
1128 if (type.isUnsignedInteger(64))
1129 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT64);
1130 if (type.isSignlessInteger(32))
1131 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT32);
1132 if (type.isSignlessInteger(64))
1133 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT64);
1134 if (type.isF16())
1135 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT16);
1136 if (type.isF32())
1137 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT32);
1138 if (type.isF64())
1139 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT64);
1140 if (type.isBF16())
1141 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16);
1142
1143 llvm_unreachable("Not supported data type");
1144}
1145
struct NVGPUTmaCreateDescriptorOpLowering
    : public ConvertOpToLLVMPattern<nvgpu::TmaCreateDescriptorOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::TmaCreateDescriptorOp>::ConvertOpToLLVMPattern;
  /// Lowers `nvgpu.tma.create.descriptor` into a call to the runtime helper
  /// `mgpuTensorMapEncodeTiledMemref`. Box dimensions are spilled into a
  /// stack-allocated array that is passed to the helper by pointer.
  LogicalResult
  matchAndRewrite(nvgpu::TmaCreateDescriptorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    auto llvmPointerType = LLVM::LLVMPointerType::get(op->getContext());
    Type llvmInt64Type = IntegerType::get(op->getContext(), 64);

    // Driver-API enum constant describing the tensor's element type.
    Value tensorElementType =
        elementTypeAsLLVMConstant(b, op.getTensor().getType().getElementType());
    auto promotedOperands = getTypeConverter()->promoteOperands(
        b.getLoc(), op->getOperands(), adaptor.getOperands(), b);

    // Stack storage for up to 5 box-dimension values.
    Value boxArrayPtr = LLVM::AllocaOp::create(
        b, llvmPointerType, llvmInt64Type, makeI64Const(b, 5));
    for (auto [index, value] : llvm::enumerate(adaptor.getBoxDimensions())) {
      // NOTE(review): the GEP element type is `llvmPointerType` even though
      // the alloca was sized in units of `llvmInt64Type`; presumably the two
      // have the same size on the target — confirm this is intended.
      Value gep = LLVM::GEPOp::create(b, llvmPointerType, llvmPointerType,
                                      boxArrayPtr, makeI64Const(b, index));
      LLVM::StoreOp::create(b, value, gep);
    }

    nvgpu::TensorMapDescriptorType desc = op.getTensorMap().getType();
    // Set Arguments for the function call
    SmallVector<Value> arguments;
    arguments.push_back(promotedOperands[0]); // rank
    arguments.push_back(promotedOperands[1]); // descriptor
    arguments.push_back(tensorElementType);   // data type
    arguments.push_back(
        makeI64Const(b, (int)desc.getInterleave()));              // interleave
    arguments.push_back(makeI64Const(b, (int)desc.getSwizzle())); // swizzle
    arguments.push_back(makeI64Const(b, (int)desc.getL2promo())); // l2promo
    arguments.push_back(makeI64Const(b, (int)desc.getOob()));     // oob
    arguments.push_back(boxArrayPtr); // box dimensions

    // Set data types of the arguments
    SmallVector<Type> argTypes = {
        llvmInt64Type,   /* int64_t tensorRank */
        llvmPointerType, /* ptr */
        llvmInt64Type,   /* int64_t */
        llvmInt64Type,   /* int64_t */
        llvmInt64Type,   /* int64_t */
        llvmInt64Type,   /* int64_t */
        llvmInt64Type,   /* int64_t */
        llvmPointerType  /* ptr */
    };
    FunctionCallBuilder hostRegisterCallBuilder = {
        "mgpuTensorMapEncodeTiledMemref", llvmPointerType, argTypes};
    Value tensorMap =
        hostRegisterCallBuilder.create(b.getLoc(), b, arguments).getResult();

    rewriter.replaceOp(op, tensorMap);
    return success();
  }
};
1203
struct NVGPUWarpgroupMmaOpLowering
    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp> {
  using ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp>::ConvertOpToLLVMPattern;

  /// This is a helper class to generate required NVVM Ops for warp-group level
  /// matrix multiplication.
  /// When the given GEMM shape is larger than the shape of
  /// a wgmma instruction in PTX, it can generate multiple NVVM::WgmmaMmaAsyncOp
  /// Op(s), group and execute them asynchronously. The class also handles
  /// waiting for completion and iterates through WarpgroupMatrixDescriptor to
  /// create descriptors for each instruction.
  ///
  /// For example this is the case when the shape of GEMM is 128x128x128
  ///
  ///    nvvm.wgmma.fence.aligned
  ///
  ///    nvvm.wgmma.mma.async descA, descB
  ///    iterate(descA, descB)
  ///    nvvm.wgmma.mma.async descA, descB
  ///    [6x times more]
  ///
  ///    nvvm.wgmma.group.sync.aligned
  ///    nvvm.wgmma.wait.group.sync [groupId]
  ///
  class WarpgroupGemm {
    nvgpu::WarpgroupMmaOp op;
    ImplicitLocOpBuilder b;
    OpAdaptor adaptor;

    // Entire shape of the given Op
    int64_t totalM, totalN, totalK;

    // Shape of one wgmma instruction
    int wgmmaM = 0, wgmmaN = 0, wgmmaK = 0;

    // Iteration counts for GEMM
    int iterationM = 0, iterationN = 0, iterationK = 0;

    /// The function returns the shape of wgmma instruction that is defined in
    /// PTX programming guide.
    /// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shape
    void findWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType) {
      wgmmaM = 64;
      wgmmaN = sizeN;
      if (inputElemType.isTF32()) {
        wgmmaK = 8;
      } else if (inputElemType.isF16() || inputElemType.isBF16()) {
        wgmmaK = 16;
      } else if (isa<Float8E4M3FNType, Float8E5M2Type>(inputElemType) ||
                 inputElemType.isInteger(16)) {
        // NOTE(review): `isInteger(16)` in the k=32 bucket looks suspicious
        // (the 8-bit integer types are handled via WGMMATypes s8/u8 below) —
        // confirm the intended bit width.
        wgmmaK = 32;
      } else if (inputElemType.isInteger(1)) {
        wgmmaK = 256;
      } else {
        llvm_unreachable("msg: not supported K shape");
      }
      LDBG() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
             << ", n = " << wgmmaN << ", k = " << wgmmaK << "]";
    }

    /// Generates WGMMATypesAttr from MLIR Type. When `useF32` is set, f32 and
    /// tf32 element types map to the f32 WGMMA type (used for the result
    /// matrix) instead of tf32 (used for the inputs).
    NVVM::WGMMATypesAttr generateWgmmaType(Type type,
                                           bool useF32 = false) const {
      auto getWgmmaType = [=](Type elemType) {
        if (elemType.isF32() || elemType.isTF32())
          return useF32 ? NVVM::WGMMATypes::f32 : NVVM::WGMMATypes::tf32;
        if (elemType.isF16())
          return NVVM::WGMMATypes::f16;
        if (elemType.isBF16())
          return NVVM::WGMMATypes::bf16;
        if (isa<Float8E4M3FNType>(elemType))
          return NVVM::WGMMATypes::e4m3;
        if (isa<Float8E5M2Type>(elemType))
          return NVVM::WGMMATypes::e5m2;
        if (elemType.isInteger(1))
          return NVVM::WGMMATypes::b1;
        if (elemType.isInteger(8))
          return NVVM::WGMMATypes::s8;
        if (elemType.isUnsignedInteger(8))
          return NVVM::WGMMATypes::u8;
        if (elemType.isInteger(32))
          return NVVM::WGMMATypes::s32;
        llvm_unreachable("unsupported type");
      };
      return NVVM::WGMMATypesAttr::get(op->getContext(), getWgmmaType(type));
    }

    /// Generates layout attribute for the input matrix for wgmma instruction:
    /// col when the transpose flag is set, row otherwise (including when the
    /// optional flag is absent).
    NVVM::MMALayoutAttr
    generateWgmmaLayout(std::optional<bool> transpose) const {
      if (transpose.value_or(false))
        return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::col);
      return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::row);
    }

    /// Generates shape attribute for wgmma instruction
    NVVM::MMAShapeAttr generateWgmmaShape() const {
      return NVVM::MMAShapeAttr::get(op->getContext(), wgmmaM, wgmmaN, wgmmaK);
    }

    /// Generates scale attributes of output matrix for wgmma instruction
    NVVM::WGMMAScaleOutAttr generateScaleOut() const {
      return NVVM::WGMMAScaleOutAttr::get(op->getContext(),
                                          NVVM::WGMMAScaleOut::one);
    }
    /// Generates scale attributes of input matrix for wgmma instruction
    NVVM::WGMMAScaleInAttr generateScaleIn() const {
      return NVVM::WGMMAScaleInAttr::get(op->getContext(),
                                         NVVM::WGMMAScaleIn::one);
    }

    /// Basic function to generate Add
    Value makeAdd(Value lhs, Value rhs) {
      return LLVM::AddOp::create(b, lhs.getType(), lhs, rhs);
    };

    /// Moves the descriptor pointer of matrix-A for the next wgmma instruction.
    /// Currently, it only handles row-major.
    ///
    /// It moves the pointer like below for [128][64] size:
    ///                 +2 +4 +6
    ///                  ↓ ↓ ↓
    /// descA     ---> +--+--+--+--+
    ///                |->|->|->|->|
    ///                |  |  |  |  |
    ///                |  |  |  |  |
    ///                |  |  |  |  |
    /// descA+512---> +-----------+
    ///                |  |  |  |  |
    ///                |  |  |  |  |
    ///                |  |  |  |  |
    ///                |  |  |  |  |
    ///                +-----------+
    ///
    Value iterateDescriptorA(Value desc, int i, int j, int k) {
      MemRefType matrixTypeA = op.getDescriptorA().getType().getTensor();
      Type elemA = matrixTypeA.getElementType();
      int byte = elemA.getIntOrFloatBitWidth() / 8;
      int tileShapeA = matrixTypeA.getDimSize(1);
      // Byte offset for the (i, k)-th tile; descriptors drop the 4 LSBs.
      int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte;
      incrementVal = incrementVal >> exclude4LSB;
      LDBG() << "\t\t[m: " << i << " n: " << j << " k: " << k
             << "] [wgmma descriptors] Descriptor A + " << incrementVal
             << " | \t ";
      if (!incrementVal)
        return desc;
      return makeAdd(desc, makeI64Const(b, incrementVal));
    }

    /// Moves the descriptor pointer of matrix-B for the next wgmma instruction.
    /// Currently, it only handles column-major.
    ///
    /// It moves the pointer like below for [128][64] size:
    /// descB     ---> +--+--+--+--+--+--+--+--+
    ///                |↓ |  |  |  |  |  |  |  |
    ///                |↓ |  |  |  |  |  |  |  |
    ///                |↓ |  |  |  |  |  |  |  |
    ///                |↓ |  |  |  |  |  |  |  |
    ///                +--+--+--+--+--+--+--+--+
    ///
    Value iterateDescriptorB(Value desc, int i, int j, int k) {
      MemRefType matrixTypeB = op.getDescriptorB().getType().getTensor();
      Type elemB = matrixTypeB.getElementType();
      int byte = elemB.getIntOrFloatBitWidth() / 8;
      // Byte offset for the k-th tile; descriptors drop the 4 LSBs.
      int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte;
      incrementVal = incrementVal >> exclude4LSB;
      LDBG() << "Descriptor B + " << incrementVal;
      if (!incrementVal)
        return desc;
      return makeAdd(desc, makeI64Const(b, incrementVal));
    }

    /// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix
    /// descriptors and arranges them based on induction variables: i, j, and k.
    Value generateWgmma(int i, int j, int k, Value matrixC) {
      LDBG() << "\t wgmma." << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
             << "(A[" << (iterationM * wgmmaM) << ":"
             << (iterationM * wgmmaM) + wgmmaM << "][" << (iterationK * wgmmaK)
             << ":" << (iterationK * wgmmaK + wgmmaK) << "] * " << " B["
             << (iterationK * wgmmaK) << ":" << (iterationK * wgmmaK + wgmmaK)
             << "][" << 0 << ":" << wgmmaN << "])";

      Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k);
      Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k);

      Type elemA = op.getDescriptorA().getType().getTensor().getElementType();
      NVVM::WGMMATypesAttr itypeA = generateWgmmaType(elemA);

      Type elemB = op.getDescriptorB().getType().getTensor().getElementType();
      NVVM::WGMMATypesAttr itypeB = generateWgmmaType(elemB);

      // The accumulator uses the f32 (not tf32) WGMMA type for f32 elements.
      Type elemD = op.getMatrixC().getType().getFragmented().getElementType();
      NVVM::WGMMATypesAttr itypeD = generateWgmmaType(elemD, true);

      NVVM::MMAShapeAttr shape = generateWgmmaShape();
      NVVM::WGMMAScaleOutAttr scaleOut = generateScaleOut();
      NVVM::WGMMAScaleInAttr scaleIn = generateScaleIn();
      NVVM::MMALayoutAttr layoutA = generateWgmmaLayout(op.getTransposeA());
      // NOTE(review): B's transpose flag is negated before choosing the
      // layout, unlike A's — confirm this inversion is intended.
      NVVM::MMALayoutAttr layoutB = generateWgmmaLayout(!op.getTransposeB());

      auto overflow = NVVM::MMAIntOverflowAttr::get(
          op->getContext(), NVVM::MMAIntOverflow::wrapped);

      return NVVM::WgmmaMmaAsyncOp::create(
          b, matrixC.getType(), matrixC, descriptorA, descriptorB, shape,
          itypeA, itypeB, itypeD, scaleOut, scaleIn, scaleIn, layoutA, layoutB,
          overflow);
    }

    /// Generates multiple wgmma instructions to complete the given GEMM shape
    Value generateWgmmaGroup() {
      Value wgmmaResult =
          LLVM::PoisonOp::create(b, adaptor.getMatrixC().getType());

      // Perform GEMM
      SmallVector<Value> wgmmaResults;
      for (int i = 0; i < iterationM; ++i) {
        // Accumulator fragment for the i-th 64-row strip of the result.
        Value matrixC =
            LLVM::ExtractValueOp::create(b, adaptor.getMatrixC(), i);
        for (int j = 0; j < iterationN; ++j)
          for (int k = 0; k < iterationK; ++k)
            matrixC = generateWgmma(i, j, k, matrixC);
        wgmmaResults.push_back(matrixC);
      }
      // Re-assemble the per-strip results into the struct-of-structs result.
      for (auto [idx, matrix] : llvm::enumerate(wgmmaResults)) {
        wgmmaResult = LLVM::InsertValueOp::create(b, wgmmaResult.getType(),
                                                  wgmmaResult, matrix, idx);
      }
      return wgmmaResult;
    }

  public:
    WarpgroupGemm(nvgpu::WarpgroupMmaOp op, ImplicitLocOpBuilder &b,
                  OpAdaptor adaptor)
        : op(op), b(b), adaptor(adaptor) {
      // Find the entire GEMM Shape
      totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
      totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
      totalK = op.getDescriptorA().getType().getTensor().getDimSize(1);
      LDBG() << "===--- GEMM D[" << totalM << "][" << totalN << "] += A["
             << totalM << "][" << totalK << "] * B[" << totalK << "][" << totalN
             << "] ---===";

      // Find the shape for one wgmma instruction
      findWgmmaShape(
          totalM, totalN,
          op.getDescriptorA().getType().getTensor().getElementType());

      // Iterations counts to complete the given shape with wgmma shape
      iterationM = totalM / wgmmaM;
      iterationN = totalN / wgmmaN;
      iterationK = totalK / wgmmaK;
    }

    /// Generates WgmmaMmaAsync Ops to complete the specified GEMM shape. It
    /// includes generating a fence Op (WgmmaFenceAlignedOp) before the
    /// instructions and group synchronization, as well as waiting
    /// (WgmmaGroupSyncAlignedOp) for group synchronization
    /// (WgmmaWaitGroupSyncOp) after the instructions.
    Value generateWarpgroupMma() {
      NVVM::WgmmaFenceAlignedOp::create(b);
      Value wgmmaResult = generateWgmmaGroup();
      NVVM::WgmmaGroupSyncAlignedOp::create(b);
      NVVM::WgmmaWaitGroupSyncOp::create(b, op.getWaitGroup());
      return wgmmaResult;
    }
  };
  LogicalResult
  matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);

    // Step 1. Build a helper class
    WarpgroupGemm warpgroupGemm(op, b, adaptor);

    // Step 2. Get the entire GEMM Shape
    Value wgmmaResult = warpgroupGemm.generateWarpgroupMma();

    // Step 3. Replace fragmented result struct with the op results
    rewriter.replaceOp(op, wgmmaResult);
    return success();
  }
};
1487
1488struct NVGPUWarpgroupMmaStoreOpLowering
1489 : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaStoreOp> {
1490 using ConvertOpToLLVMPattern<
1491 nvgpu::WarpgroupMmaStoreOp>::ConvertOpToLLVMPattern;
1492
1493 /// This function stores a fragmented register matrix owned by a warp group
1494 /// (128 threads) into a memref. Each thread has 64 registers, each the size
1495 /// of a struct.
1496 /// Here is what each threads (T) holds, each `d` is struct value with a
1497 /// number.
1498 ///
1499 /// Threads in warp-group (128 threads) and what they owns in the matrixD:
1500 /// 0-31 Warp-0 -> MatrixD[0:15 ][0:N]
1501 /// 32-63 Warp-1 -> MatrixD[16:31][0:N]
1502 /// 64-95 Warp-2 -> MatrixD[32:47][0:N]
1503 /// 96-127 Warp-3 -> MatrixD[48:64][0:N]
1504 ///
1505 /// Matrix-D:
1506 /// +______________________________________________________________________+
1507 /// | 0-1 | 2-3 | 4-5 | 6-7 | 8-9 | 10-11|..|N-8,N-7 |
1508 /// 0 | T0:d0-d1 |T1:d0-d1 |T2:d0-d1 |T3:d0-d1 |T0:d4-d5| T1:d4-d5..|T0:dX-dY|
1509 /// 1 | T4:d0-d1 |T5:d0-d1 |T6:d0-d1 |T7:d0-d1 |T4:d4-d5| T5:d4-d5..|T4:dX-dY|
1510 /// ..| .........|.........|.........|.........|........|...........|........|
1511 /// 8 | T0:d2-d3 |T1:d2-d3 |T2:d2-d3 |T3:d2-d3 |T0:d6-d7|T1:d6-d7,..|T0:dZ-dW|
1512 /// 9 | T4:d2-d3 |T5:d2-d3 |T6:d2-d3 |T7:d2-d3 |T4:d6-d7| T5:d6-d7..|T4:dZ-dW|
1513 /// ..| .........|.........|.........|.........|........|...........|........|
1514 /// 15| T28:d2-d3|T29:d2-d3|T30:d2-d3|T31:d2-d3|........|...........|........|
1515 /// 16| T32:d2-d3|T33:d2-d3|T34:d2-d3|T35:d2-d3|........|...........|........|
1516 /// ..| .........|.........|.........|.........|........|...........|........|
1517 /// 32| T64:d2-d3|T65:d2-d3|T66:d2-d3|T67:d2-d3|........|...........|........|
1518 /// ..| .........|.........|.........|.........|........|...........|........|
1519 /// 48| T96:d2-d3|T97:d2-d3|T98:d2-d3|T99:d2-d3|........|...........|........|
1520 /// ..| .........|.........|.........|.........|........|...........|........|
1521 /// +______________________________________________________________________+
1522 ///
1523 /// \param rewriter: The pattern rewriter.
1524 /// \param matrixD: Result of the warp-group MMA operation (fragmented
1525 /// matrix). It is holded by a thread and a struct with 64 elements.
1526 /// \param dstMemref: The memref where the registers will be stored.
1527 /// \param offset: the offset within the memref where the registers will be
1528 /// stored.
1529 void storeFragmentedMatrix(ImplicitLocOpBuilder &b, Value matrixD,
1530 TypedValue<MemRefType> dstMemref,
1531 int offset) const {
1532 Type i32 = b.getI32Type();
1533
1534 auto makeConst = [&](int32_t index) -> Value {
1535 return LLVM::ConstantOp::create(b, i32, b.getI32IntegerAttr(index));
1536 };
1537 Value c1 = makeConst(1);
1538 Value c2 = makeConst(2);
1539 Value c4 = makeConst(4);
1540 Value c8 = makeConst(8);
1541 Value c16 = makeConst(16);
1542 Value warpSize = makeConst(kWarpSize);
1543
1544 auto makeMul = [&](Value lhs, Value rhs) -> Value {
1545 return LLVM::MulOp::create(b, lhs.getType(), lhs, rhs);
1546 };
1547 auto makeAdd = [&](Value lhs, Value rhs) -> Value {
1548 return LLVM::AddOp::create(b, lhs.getType(), lhs, rhs);
1549 };
1550
1551 auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,
1553 Type it = b.getIndexType();
1554 Value idx = arith::IndexCastOp::create(b, it, x);
1555 Value idy0 = arith::IndexCastOp::create(b, it, y);
1556 Value idy1 = arith::IndexCastOp::create(b, it, makeAdd(y, c1));
1557 Value d0 = LLVM::ExtractValueOp::create(b, wgmmaResult, i);
1558 Value d1 = LLVM::ExtractValueOp::create(b, wgmmaResult, i + 1);
1559 memref::StoreOp::create(b, d0, memref, ValueRange{idx, idy0});
1560 memref::StoreOp::create(b, d1, memref, ValueRange{idx, idy1});
1561 };
1562
1563 Value tidx = NVVM::ThreadIdXOp::create(b, i32);
1564 Value laneId = LLVM::URemOp::create(b, i32, tidx, warpSize);
1565 Value warpId = LLVM::UDivOp::create(b, i32, tidx, warpSize);
1566 Value lane4Id = LLVM::UDivOp::create(b, i32, laneId, c4);
1567 Value lane4modId = LLVM::URemOp::create(b, i32, laneId, c4);
1568
1569 Value tj = makeMul(lane4modId, c2);
1570 Value ti = makeAdd(lane4Id, makeMul(warpId, c16));
1571 if (offset)
1572 ti = makeAdd(ti, makeConst(offset));
1573
1574 auto structType = cast<LLVM::LLVMStructType>(matrixD.getType());
1575
1576 // Number of 32-bit registers owns per thread
1577 constexpr unsigned numAdjacentRegisters = 2;
1578 // Number of 8x8 matrices one below another per warp
1579 constexpr unsigned numStackedMatrices = 2;
1580
1581 size_t storeCount = (structType.getBody().size() /
1582 (numStackedMatrices * numAdjacentRegisters));
1583
1584 for (size_t i = 0; i < numStackedMatrices; ++i) {
1585 Value idx = makeAdd(ti, makeMul(makeConst(i), c8));
1586 for (size_t j = 0; j < storeCount; ++j) {
1587 Value idy = makeAdd(tj, makeMul(makeConst(j), c8));
1588 size_t structIndex = (i * numAdjacentRegisters) +
1589 (j * (numStackedMatrices * numAdjacentRegisters));
1590 makeExtractAndStore(structIndex, matrixD, idx, idy, dstMemref);
1591 }
1592 }
1593 }
1594
1595 LogicalResult
1596 matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor,
1597 ConversionPatternRewriter &rewriter) const override {
1598 int offset = 0;
1599 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1600 Value matriDValue = adaptor.getMatrixD();
1601 auto stype = cast<LLVM::LLVMStructType>(matriDValue.getType());
1602 for (auto [idx, matrixD] : llvm::enumerate(stype.getBody())) {
1603 auto structType = cast<LLVM::LLVMStructType>(matrixD);
1604 Value innerStructValue =
1605 LLVM::ExtractValueOp::create(b, matriDValue, idx);
1606 storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset);
1607 offset += structType.getBody().size();
1608 }
1609 rewriter.eraseOp(op);
1610 return success();
1611 }
1612};
1613
1614struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
1615 : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaInitAccumulatorOp> {
1616 using ConvertOpToLLVMPattern<
1617 nvgpu::WarpgroupMmaInitAccumulatorOp>::ConvertOpToLLVMPattern;
1618 LogicalResult
1619 matchAndRewrite(nvgpu::WarpgroupMmaInitAccumulatorOp op, OpAdaptor adaptor,
1620 ConversionPatternRewriter &rewriter) const override {
1621 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1622 LLVM::LLVMStructType packStructType = cast<LLVM::LLVMStructType>(
1623 getTypeConverter()->convertType(op.getMatrixC().getType()));
1624 Type elemType = cast<LLVM::LLVMStructType>(packStructType.getBody().front())
1625 .getBody()
1626 .front();
1627 Value zero = LLVM::ConstantOp::create(b, elemType, b.getZeroAttr(elemType));
1628 Value packStruct = LLVM::PoisonOp::create(b, packStructType);
1629 SmallVector<Value> innerStructs;
1630 // Unpack the structs and set all values to zero
1631 for (auto [idx, s] : llvm::enumerate(packStructType.getBody())) {
1632 auto structType = cast<LLVM::LLVMStructType>(s);
1633 Value structValue = LLVM::ExtractValueOp::create(b, packStruct, idx);
1634 for (unsigned i = 0; i < structType.getBody().size(); ++i) {
1635 structValue = LLVM::InsertValueOp::create(b, structType, structValue,
1636 zero, ArrayRef<int64_t>({i}));
1637 }
1638 innerStructs.push_back(structValue);
1639 }
1640 // Pack the inner structs into a single struct
1641 for (auto [idx, matrix] : llvm::enumerate(innerStructs)) {
1642 packStruct = LLVM::InsertValueOp::create(b, packStruct.getType(),
1643 packStruct, matrix, idx);
1644 }
1645 rewriter.replaceOp(op, packStruct);
1646 return success();
1647 }
1648};
1649
1650struct NVGPUTmaFenceOpLowering
1651 : public ConvertOpToLLVMPattern<nvgpu::TmaFenceOp> {
1652 using ConvertOpToLLVMPattern<nvgpu::TmaFenceOp>::ConvertOpToLLVMPattern;
1653 LogicalResult
1654 matchAndRewrite(nvgpu::TmaFenceOp op, OpAdaptor adaptor,
1655 ConversionPatternRewriter &rewriter) const override {
1656 MLIRContext *ctx = op.getContext();
1657 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1658 auto i32Ty = b.getI32Type();
1659 Value tensormapSize =
1660 LLVM::ConstantOp::create(b, i32Ty, rewriter.getI32IntegerAttr(128));
1661
1662 auto memscope =
1663 NVVM::MemScopeKindAttr::get(ctx, ::mlir::NVVM::MemScopeKind::SYS);
1664
1665 rewriter.replaceOpWithNewOp<NVVM::FenceProxyAcquireOp>(
1666 op, memscope, adaptor.getTensorMapDescriptor(), tensormapSize);
1667
1668 return success();
1669 }
1670};
1671
1672struct NVGPUTmaPrefetchOpLowering
1673 : public ConvertOpToLLVMPattern<nvgpu::TmaPrefetchOp> {
1674 using ConvertOpToLLVMPattern<nvgpu::TmaPrefetchOp>::ConvertOpToLLVMPattern;
1675 LogicalResult
1676 matchAndRewrite(nvgpu::TmaPrefetchOp op, OpAdaptor adaptor,
1677 ConversionPatternRewriter &rewriter) const override {
1678 rewriter.replaceOpWithNewOp<NVVM::PrefetchOp>(
1679 op, /* CacheLevel */ nullptr, /* Cache Eviction Priority */ nullptr,
1680 adaptor.getTensorMapDescriptor(), adaptor.getPredicate(),
1681 /* Tensormap UnitAttr */ mlir::UnitAttr::get(op.getContext()));
1682 return success();
1683 }
1684};
1685
1686struct NVGPURcpOpLowering : public ConvertOpToLLVMPattern<nvgpu::RcpOp> {
1687 using ConvertOpToLLVMPattern<nvgpu::RcpOp>::ConvertOpToLLVMPattern;
1688 LogicalResult
1689 matchAndRewrite(nvgpu::RcpOp op, OpAdaptor adaptor,
1690 ConversionPatternRewriter &rewriter) const override {
1691 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1692 auto i64Ty = b.getI64Type();
1693 auto f32Ty = b.getF32Type();
1694 VectorType inTy = op.getIn().getType();
1695 // apply rcp.approx.ftz.f on each element in vector.
1696 auto convert1DVec = [&](Type llvm1DVectorTy, Value inVec) {
1697 Value ret1DVec = LLVM::PoisonOp::create(b, llvm1DVectorTy);
1698 int numElems = llvm::cast<VectorType>(llvm1DVectorTy).getNumElements();
1699 for (int i = 0; i < numElems; i++) {
1700 Value idx = LLVM::ConstantOp::create(b, i64Ty, b.getI64IntegerAttr(i));
1701 Value elem = LLVM::ExtractElementOp::create(b, inVec, idx);
1702 Value dst = NVVM::RcpApproxFtzF32Op::create(b, f32Ty, elem);
1703 ret1DVec = LLVM::InsertElementOp::create(b, ret1DVec, dst, idx);
1704 }
1705 return ret1DVec;
1706 };
1707 if (inTy.getRank() == 1) {
1708 rewriter.replaceOp(op, convert1DVec(inTy, adaptor.getIn()));
1709 return success();
1710 }
1712 op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()),
1713 [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
1714 OpAdaptor adaptor(operands);
1715 return convert1DVec(llvm1DVectorTy, adaptor.getIn());
1716 },
1717 rewriter);
1718 }
1719};
1720} // namespace
1721
1723 const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
1724 patterns.add<
1725 NVGPUMBarrierCreateLowering, // nvgpu.mbarrier.create
1726 NVGPUMBarrierInitLowering, // nvgpu.mbarrier.init
1727 NVGPUMBarrierGetLowering, // nvgpu.mbarrier.get
1728 NVGPUMBarrierArriveLowering, // nvgpu.mbarrier.arrive
1729 NVGPUMBarrierArriveNoCompleteLowering, // nvgpu.mbarrier.arrive.no_complete
1730 NVGPUMBarrierTestWaitLowering, // nvgpu.mbarrier.test_wait_parity
1731 NVGPUMBarrierTryWaitParityLowering, // nvgpu.mbarrier.try_wait_parity
1732 NVGPUTmaAsyncLoadOpLowering, // nvgpu.tma.async.load
1733 NVGPUTmaAsyncStoreOpLowering, // nvgpu.tma.async.store
1734 NVGPUTmaCreateDescriptorOpLowering, // nvgpu.tma.create.descriptor
1735 NVGPUTmaPrefetchOpLowering, // nvgpu.tma.prefetch.descriptor
1736 NVGPUTmaFenceOpLowering, // nvgpu.tma.fence.descriptor
1737 NVGPUMBarrierArriveExpectTxLowering, // nvgpu.mbarrier.arrive.expect_tx
1738 NVGPUGenerateWarpgroupDescriptorLowering, // nvgpu.warpgroup.generate.descriptor
1739 NVGPUWarpgroupMmaOpLowering, // nvgpu.warpgroup.mma
1740 NVGPUWarpgroupMmaStoreOpLowering, // nvgpu.warpgroup.mma.store
1741 NVGPUWarpgroupMmaInitAccumulatorOpLowering, // nvgpu.warpgroup.mma.init.accumulator
1742 MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
1743 NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
1744 NVGPUMmaSparseSyncLowering, NVGPURcpOpLowering>(converter);
1745}
return success()
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
b getContext())
constexpr int kWgmmaSizeM
M size of wgmma.mma_async instruction.
constexpr int kWarpSize
static Value truncToI32(ImplicitLocOpBuilder &b, Value value)
GPU has 32 bit registers, this function truncates values when larger width is not needed.
static SmallVector< Value > unpackOperandVector(ImplicitLocOpBuilder &b, Value operand, NVVM::MMATypes operandPtxType)
The gpu.mma.sync converter below expects matrix fragment operands to be given as 2D vectors where the...
static Type inferIntrinsicResultType(Type vectorResultType)
Returns the type for the intrinsic given the vectorResultType of the gpu.mma.sync operation.
constexpr int exclude4LSB
Number of bits that needs to be excluded when building matrix descriptor for wgmma operations.
static bool isMbarrierShared(nvgpu::MBarrierGroupType barrierType)
Returns whether mbarrier object has shared memory address space.
static Value convertIntrinsicResult(Location loc, Type intrinsicResultType, Type resultType, Value intrinsicResult, RewriterBase &rewriter)
Convert the SSA result of the NVVM intrinsic nvvm.mma.sync (which is always an LLVM struct) into a fr...
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
Definition Attributes.h:25
IntegerAttr getI32IntegerAttr(int32_t value)
Definition Builders.cpp:200
FloatType getF32Type()
Definition Builders.cpp:43
IntegerType getI32Type()
Definition Builders.cpp:63
FloatType getF16Type()
Definition Builders.cpp:39
MLIRContext * getContext() const
Definition Builders.h:56
FloatType getF64Type()
Definition Builders.cpp:45
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition Pattern.h:207
Value getStridedElementPtr(ConversionPatternRewriter &rewriter, Location loc, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none) const
Convenience wrapper for the corresponding helper utility.
Definition Pattern.cpp:64
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Definition Builders.h:630
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:526
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:238
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isF64() const
Definition Types.cpp:41
bool isTF32() const
Definition Types.cpp:39
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
Definition Types.cpp:64
bool isF32() const
Definition Types.cpp:40
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition Types.cpp:88
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:56
bool isF16() const
Definition Types.cpp:38
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:122
bool isBF16() const
Definition Types.cpp:37
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDes...
Definition Pattern.cpp:471
MemRefType getMBarrierMemrefType(MLIRContext *context, MBarrierGroupType barrierType)
Return the memref type that can be used to represent an mbarrier object.
Attribute getMbarrierMemorySpace(MLIRContext *context, MBarrierGroupType barrierType)
Returns the memory space attribute of the mbarrier object.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
void populateSCFStructuralTypeConversionsAndLegality(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, PatternBenefit benefit=1)
Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with...
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
void populateNVGPUToNVVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:497
const FrozenRewritePatternSet & patterns
void populateGpuMemorySpaceAttributeConversions(TypeConverter &typeConverter, const MemorySpaceMapping &mapping)
Populates memory space attribute conversion rules for lowering gpu.address_space to integer values.
LLVM::CallOp create(Location loc, OpBuilder &builder, ArrayRef< Value > arguments) const