1 //===- NVGPUToNVVM.cpp - NVGPU to NVVM dialect conversion -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
22 #include "mlir/IR/BuiltinTypes.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/IR/TypeUtilities.h"
26 #include "mlir/IR/Value.h"
27 #include "mlir/Pass/Pass.h"
28 #include "llvm/Support/Debug.h"
29 #include "llvm/Support/ErrorHandling.h"
30 #include "llvm/Support/raw_ostream.h"
31 #include <optional>
32 
33 #define DEBUG_TYPE "nvgpu-to-nvvm"
34 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
35 #define DBGSE() (llvm::dbgs())
36 
37 namespace mlir {
38 #define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS
39 #include "mlir/Conversion/Passes.h.inc"
40 } // namespace mlir
41 
42 using namespace mlir;
43 
44 /// Number of bits that need to be excluded when building the matrix
45 /// descriptor for wgmma operations.
46 constexpr int exclude4LSB = 4;
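// For reference: the wgmma matrix descriptor encodes the start address,
// leading-dimension offset, and stride in 16-byte units, which is why byte
// values below are shifted right by `exclude4LSB` (e.g. 32 bytes encode as 2).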
47 
48 /// The GPU has 32-bit registers; this function truncates values wider than
49 /// 32 bits down to i32.
50 static Value truncToI32(ImplicitLocOpBuilder &b, Value value) {
51  Type type = value.getType();
52  assert(llvm::isa<IntegerType>(type) && "expected an integer Value");
53  if (type.getIntOrFloatBitWidth() <= 32)
54  return value;
55  return b.create<LLVM::TruncOp>(b.getI32Type(), value);
56 }
57 
58 /// Returns the type for the intrinsic given the vectorResultType of the
59 /// `gpu.mma.sync` operation.
60 static Type inferIntrinsicResultType(Type vectorResultType) {
61  MLIRContext *ctx = vectorResultType.getContext();
62  auto a = cast<LLVM::LLVMArrayType>(vectorResultType);
63  auto f16x2Ty = LLVM::getFixedVectorType(Float16Type::get(ctx), 2);
64  auto i32Ty = IntegerType::get(ctx, 32);
65  auto i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
66  Type f64Ty = Float64Type::get(ctx);
67  Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
68  Type f32Ty = Float32Type::get(ctx);
69  Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
70  if (a.getElementType() == f16x2Ty) {
71  return LLVM::LLVMStructType::getLiteral(
72  ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty));
73  }
74  if (a.getElementType() == i32x2Ty) {
75  return LLVM::LLVMStructType::getLiteral(
76  ctx,
77  SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, i32Ty));
78  }
79  if (a.getElementType() == f64x2Ty) {
80  return LLVM::LLVMStructType::getLiteral(ctx, {f64Ty, f64Ty});
81  }
82  if (a.getElementType() == f32x2Ty) {
83  return LLVM::LLVMStructType::getLiteral(
84  ctx,
85  SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, f32Ty));
86  }
87  if (a.getElementType() == LLVM::getFixedVectorType(f32Ty, 1)) {
88  return LLVM::LLVMStructType::getLiteral(
89  ctx, SmallVector<Type>(static_cast<size_t>(a.getNumElements()), f32Ty));
90  }
91  return vectorResultType;
92 }
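// Illustrative examples (types assumed for exposition): for a converted
// result type of !llvm.array<2 x vector<2xf16>> this returns
// !llvm.struct<(vector<2xf16>, vector<2xf16>)>; for
// !llvm.array<2 x vector<2xf32>> it returns
// !llvm.struct<(f32, f32, f32, f32)>, matching what the nvvm.mma.sync
// intrinsic produces.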
93 
94 /// Convert the SSA result of the NVVM intrinsic `nvvm.mma.sync` (which is
95 /// always an LLVM struct) into a fragment that is compatible with the vector
96 /// type of this operation. This involves extracting elements from the struct
97 /// and inserting them into an LLVM array. These extra data-movement
98 /// operations should be canonicalized away by the LLVM backend.
99 static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
100  Type resultType, Value intrinsicResult,
101  RewriterBase &rewriter) {
102  MLIRContext *ctx = rewriter.getContext();
103  auto structType = dyn_cast<LLVM::LLVMStructType>(intrinsicResultType);
104  auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType);
105  Type i32Ty = rewriter.getI32Type();
106  Type f32Ty = rewriter.getF32Type();
107  Type f64Ty = rewriter.getF64Type();
108  Type f16x2Ty = LLVM::getFixedVectorType(rewriter.getF16Type(), 2);
109  Type i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
110  Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
111  Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
112  Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
113 
114  auto makeConst = [&](int32_t index) -> Value {
115  return rewriter.create<LLVM::ConstantOp>(loc, IntegerType::get(ctx, 32),
116  rewriter.getI32IntegerAttr(index));
117  };
118 
119  if (arrayType) {
120  SmallVector<Value, 4> elements;
121 
122  // The intrinsic returns 32-bit wide elements in a form which can be
123  // directly bitcasted and inserted into the result vector.
124  if (arrayType.getElementType() == f16x2Ty ||
125  arrayType.getElementType() == f32x1Ty) {
126  for (unsigned i = 0; i < structType.getBody().size(); i++) {
127  Value el =
128  rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, i);
129  el = rewriter.createOrFold<LLVM::BitcastOp>(
130  loc, arrayType.getElementType(), el);
131  elements.push_back(el);
132  }
133  }
134 
135  // The intrinsic returns i32, f64, and f32 values as individual scalars,
136  // even when the result is notionally a 64-bit wide element (e.g. f32x2). We
137  // need to extract them from the struct and pack them into the 64-bit wide
138  // rows of the vector result.
139  if (arrayType.getElementType() == i32x2Ty ||
140  arrayType.getElementType() == f64x2Ty ||
141  arrayType.getElementType() == f32x2Ty) {
142 
143  for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) {
144  Value vec =
145  rewriter.create<LLVM::UndefOp>(loc, arrayType.getElementType());
146  Value x1 =
147  rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, i * 2);
148  Value x2 = rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult,
149  i * 2 + 1);
150  vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
151  x1, makeConst(0));
152  vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
153  x2, makeConst(1));
154  elements.push_back(vec);
155  }
156  }
157 
158  // Create the final vectorized result.
159  Value result = rewriter.create<LLVM::UndefOp>(loc, arrayType);
160  for (const auto &el : llvm::enumerate(elements)) {
161  result = rewriter.create<LLVM::InsertValueOp>(loc, result, el.value(),
162  el.index());
163  }
164  return result;
165  }
166 
167  return intrinsicResult;
168 }
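// Illustrative example (types assumed for exposition): with an intrinsic
// result of !llvm.struct<(f32, f32, f32, f32)> and a desired result of
// !llvm.array<2 x vector<2xf32>>, the loop above emits, per array row, one
// llvm.mlir.undef, two llvm.extractvalue and two llvm.insertelement ops, and
// finally two llvm.insertvalue ops that assemble the array.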
169 
170 /// The `gpu.mma.sync` converter below expects matrix fragment operands to be
171 /// given as 2D `vectors` where the rows are 32b or 64b wide. The
172 /// `nvvm.mma.sync` op expects these arguments to be given as a long list of
173 /// scalars of certain types. This function helps unpack the `vector` arguments
174 /// and cast them to the types expected by `nvvm.mma.sync`.
175 static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b,
176  Value operand,
177  NVVM::MMATypes operandPtxType) {
178  SmallVector<Value> result;
179  Type i32Ty = b.getI32Type();
180  Type f64Ty = b.getF64Type();
181  Type f32Ty = b.getF32Type();
182  Type i64Ty = b.getI64Type();
183  Type i8x4Ty = LLVM::getFixedVectorType(b.getI8Type(), 4);
184  Type i4x8Ty = LLVM::getFixedVectorType(b.getIntegerType(4), 8);
185  Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
186  auto arrayTy = cast<LLVM::LLVMArrayType>(operand.getType());
187 
188  for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
189  Value toUse = b.create<LLVM::ExtractValueOp>(operand, i);
190 
191  // For 4xi8 vectors, the intrinsic expects these to be provided as i32
192  // scalar types.
193  if (arrayTy.getElementType() == i8x4Ty ||
194  arrayTy.getElementType() == i4x8Ty ||
195  (arrayTy.getElementType() == f32x1Ty &&
196  operandPtxType == NVVM::MMATypes::tf32)) {
197  result.push_back(b.create<LLVM::BitcastOp>(i32Ty, toUse));
198  continue;
199  }
200 
201  // For some element types (i32, f32, f64), we need to unpack the inner
202  // vector/array type as well because the intrinsic expects individual
203  // scalars to be provided.
204  VectorType innerArrayTy = dyn_cast<VectorType>(arrayTy.getElementType());
205  if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty ||
206  innerArrayTy.getElementType() == f64Ty ||
207  innerArrayTy.getElementType() == f32Ty)) {
208  for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
209  idx < innerSize; idx++) {
210  result.push_back(b.create<LLVM::ExtractElementOp>(
211  toUse,
212  b.create<LLVM::ConstantOp>(i64Ty, b.getI64IntegerAttr(idx))));
213  }
214  continue;
215  }
216  result.push_back(toUse);
217  }
218  return result;
219 }
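// Illustrative examples (operand types assumed for exposition): an operand of
// type !llvm.array<2 x vector<2xf16>> is unpacked into two vector<2xf16>
// values; !llvm.array<1 x vector<4xi8>>, or a tf32 operand held as
// vector<1xf32>, is bitcast to a single i32, which is the scalar form that
// nvvm.mma.sync expects.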
220 
221 /// Returns whether the mbarrier object has a shared memory address space.
222 static bool isMbarrierShared(nvgpu::MBarrierGroupType barrierType) {
223  return (mlir::nvgpu::NVGPUDialect::isSharedMemoryAddressSpace(
224  barrierType.getMemorySpace()));
225 }
226 
227 /// Returns the memory space attribute of the mbarrier object.
228 Attribute nvgpu::getMbarrierMemorySpace(MLIRContext *context,
229  nvgpu::MBarrierGroupType barrierType) {
230  Attribute memorySpace = {};
231  if (isMbarrierShared(barrierType)) {
232  memorySpace =
233  IntegerAttr::get(IntegerType::get(context, 64),
234  nvgpu::NVGPUDialect::kSharedMemoryAddressSpace);
235  }
236  return memorySpace;
237 }
238 
239 /// Returns the memref type of the mbarrier object. The type is defined by
240 /// the MBarrierGroupType.
241 MemRefType nvgpu::getMBarrierMemrefType(MLIRContext *context,
242  nvgpu::MBarrierGroupType barrierType) {
243  Attribute memorySpace = nvgpu::getMbarrierMemorySpace(context, barrierType);
244  MemRefLayoutAttrInterface layout;
245  return MemRefType::get({barrierType.getNumBarriers()},
246  IntegerType::get(context, 64), layout, memorySpace);
247 }
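// Illustrative example (attribute values assumed): a barrier group type
// roughly like !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>,
// num_barriers = 4> maps to memref<4xi64, 3>, i.e. four 64-bit mbarrier words
// placed in the shared memory address space.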
248 
249 namespace {
250 
251 struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
253 
255  matchAndRewrite(nvgpu::LdMatrixOp op, OpAdaptor adaptor,
256  ConversionPatternRewriter &rewriter) const override {
257  MLIRContext *ctx = getContext();
258  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
259 
260  // The result type of ldmatrix will always be a struct of 32bit integer
261  // registers if more than one 32bit value is returned. Otherwise, the result
262  // is a single i32. The result type of the GPU operation is always a vector
263  // of shape (NumRegisters, VectorRegister) where VectorRegister is the
264  // vector type of the result and always 32 bits long. We bitcast the result
265  // of the NVVM::LdMatrix to this vector type.
266  auto vectorResultType = dyn_cast<VectorType>(op->getResultTypes()[0]);
267  if (!vectorResultType) {
268  return failure();
269  }
270  Type innerVectorType = LLVM::getFixedVectorType(
271  vectorResultType.getElementType(), vectorResultType.getDimSize(1));
272 
273  int64_t num32BitRegs = vectorResultType.getDimSize(0);
274 
275  Type ldMatrixResultType;
276  if (num32BitRegs > 1) {
277  ldMatrixResultType = LLVM::LLVMStructType::getLiteral(
278  ctx, SmallVector<Type>(num32BitRegs, rewriter.getI32Type()));
279  } else {
280  ldMatrixResultType = rewriter.getI32Type();
281  }
282 
283  auto srcMemrefType = cast<MemRefType>(op.getSrcMemref().getType());
284  Value srcPtr =
285  getStridedElementPtr(b.getLoc(), srcMemrefType, adaptor.getSrcMemref(),
286  adaptor.getIndices(), rewriter);
287  Value ldMatrixResult = b.create<NVVM::LdMatrixOp>(
288  ldMatrixResultType, srcPtr,
289  /*num=*/op.getNumTiles(),
290  /*layout=*/op.getTranspose() ? NVVM::MMALayout::col
291  : NVVM::MMALayout::row);
292 
293  // The ldmatrix operation returns either a single i32 value or a struct of
294  // i32 values. Here we unpack those values and cast them back to their
295  // actual vector type (still of width 32b) and repack them into a result
296  // struct.
297  Type finalResultType = typeConverter->convertType(vectorResultType);
298  Value result = b.create<LLVM::UndefOp>(finalResultType);
299  for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) {
300  Value i32Register =
301  num32BitRegs > 1 ? b.create<LLVM::ExtractValueOp>(ldMatrixResult, i)
302  : ldMatrixResult;
303  Value casted = b.create<LLVM::BitcastOp>(innerVectorType, i32Register);
304  result = b.create<LLVM::InsertValueOp>(result, casted, i);
305  }
306 
307  rewriter.replaceOp(op, result);
308  return success();
309  }
310 };
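// Illustrative lowering sketch (IR shapes assumed for exposition):
//   %f = nvgpu.ldmatrix %sm[%c0, %c0] {numTiles = 4 : i32, transpose = false}
//       : memref<16x16xf16, 3> -> vector<4x2xf16>
// becomes roughly
//   %r = nvvm.ldmatrix %ptr {num = 4 : i32, layout = #nvvm.mma_layout<row>}
//       : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
// followed by four llvm.extractvalue + llvm.bitcast (i32 to vector<2xf16>)
// pairs and llvm.insertvalue ops building !llvm.array<4 x vector<2xf16>>.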
311 
312 /// Convert the given type into the corresponding PTX type (NVVM::MMATypes
313 /// enum).
314 static FailureOr<NVVM::MMATypes> getNvvmMmaType(Type t) {
315  Type elType = getElementTypeOrSelf(t);
316  if (elType.isInteger(8))
317  return NVVM::MMATypes::s8;
318  if (elType.isInteger(4))
319  return NVVM::MMATypes::s4;
320  if (elType.isF16())
321  return NVVM::MMATypes::f16;
322  if (elType.isF64())
323  return NVVM::MMATypes::f64;
324  if (elType.isF32())
325  return NVVM::MMATypes::tf32;
326  return failure();
327 }
328 
329 struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
331 
333  matchAndRewrite(nvgpu::MmaSyncOp op, OpAdaptor adaptor,
334  ConversionPatternRewriter &rewriter) const override {
335  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
336  // Get the shapes of the MMAMatrix type being used. The shapes determine
337  // which intrinsic this op will be lowered to.
338  VectorType aType = op.getMatrixA().getType();
339  VectorType bType = op.getMatrixB().getType();
340  VectorType cType = op.getMatrixC().getType();
341 
342  std::array<int64_t, 3> gemmShape = op.getMmaShapeAsArray();
343 
344  // Tensor Cores (mma.sync) on F32 work only with TensorFloat32 (TF32).
345  bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
346  if (aType.getElementType().isF32() && !tf32Enabled)
347  return failure();
348 
349  FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
350  if (failed(ptxTypeA))
351  return op->emitOpError("failed to deduce operand PTX types");
352  FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
353  if (failed(ptxTypeB))
354  return op->emitOpError("failed to deduce operand PTX types");
355  std::optional<NVVM::MMATypes> ptxTypeC =
356  NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
357  /*isAccumulator=*/true);
358  if (!ptxTypeC)
359  return op->emitError(
360  "could not infer the PTX type for the accumulator/result");
361 
362  // TODO: add an attribute to the op to customize this behavior.
363  std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
364  if (isa<IntegerType>(aType.getElementType()))
365  overflow = NVVM::MMAIntOverflow::satfinite;
366 
367  SmallVector<Value> matA =
368  unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
369  SmallVector<Value> matB =
370  unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
371  SmallVector<Value> matC =
372  unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);
373 
374  Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
375  Type intrinsicResTy = inferIntrinsicResultType(
376  typeConverter->convertType(op->getResultTypes()[0]));
377  Value intrinsicResult = b.create<NVVM::MmaOp>(
378  intrinsicResTy, matA, matB, matC,
379  /*shape=*/gemmShape,
380  /*b1Op=*/std::nullopt,
381  /*intOverflow=*/overflow,
382  /*multiplicandPtxTypes=*/
383  std::array<NVVM::MMATypes, 2>{*ptxTypeA, *ptxTypeB},
384  /*multiplicandLayouts=*/
385  std::array<NVVM::MMALayout, 2>{NVVM::MMALayout::row,
386  NVVM::MMALayout::col});
387  rewriter.replaceOp(op, convertIntrinsicResult(op.getLoc(), intrinsicResTy,
388  desiredRetTy, intrinsicResult,
389  rewriter));
390  return success();
391  }
392 };
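// Illustrative lowering sketch (IR shapes assumed for exposition):
//   %d = nvgpu.mma.sync (%a, %b, %c) {mmaShape = [16, 8, 16]}
//       : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
// becomes roughly an nvvm.mma.sync with shape m16n8k16, row/col layouts, the
// unpacked vector<2xf16> fragments as operands, and an
// !llvm.struct<(vector<2xf16>, vector<2xf16>)> result that
// convertIntrinsicResult() repacks into !llvm.array<2 x vector<2xf16>>.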
393 
394 struct ConvertNVGPUToNVVMPass
395  : public impl::ConvertNVGPUToNVVMPassBase<ConvertNVGPUToNVVMPass> {
396  using Base::Base;
397 
398  void getDependentDialects(DialectRegistry &registry) const override {
399  registry.insert<memref::MemRefDialect, LLVM::LLVMDialect, NVVM::NVVMDialect,
400  arith::ArithDialect>();
401  }
402 
403  void runOnOperation() override {
405  RewritePatternSet patterns(&getContext());
406  LLVMTypeConverter converter(&getContext(), options);
407  IRRewriter rewriter(&getContext());
408  /// Device-side async tokens cannot be materialized in NVVM. We just
409  /// convert them to a dummy i32 type in order to easily drop them during
410  /// conversion.
411  converter.addConversion([&](nvgpu::DeviceAsyncTokenType type) -> Type {
412  return converter.convertType(IntegerType::get(type.getContext(), 32));
413  });
414  converter.addConversion([&](nvgpu::WarpgroupAccumulatorType type) -> Type {
415  Type elemType = type.getFragmented().getElementType();
416  int64_t sizeM = type.getFragmented().getDimSize(0);
417  int64_t sizeN = type.getFragmented().getDimSize(1);
418 
419  unsigned numMembers;
420  if (elemType.isF32() || elemType.isInteger(32))
421  numMembers = sizeN / 2;
422  else if (elemType.isF16())
423  numMembers = sizeN / 4;
424  else
425  llvm_unreachable("unsupported type for warpgroup accumulator");
426 
427  SmallVector<Type> innerStructBody;
428  for (unsigned i = 0; i < numMembers; i++)
429  innerStructBody.push_back(elemType);
430  auto innerStructType =
431  LLVM::LLVMStructType::getLiteral(type.getContext(), innerStructBody);
432 
433  SmallVector<Type> structBody;
434  for (int i = 0; i < sizeM; i += kWgmmaSizeM)
435  structBody.push_back(innerStructType);
436 
437  auto convertedType =
438  LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
439  return converter.convertType(convertedType);
440  });
441  converter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {
442  return converter.convertType(IntegerType::get(type.getContext(), 64));
443  });
444  converter.addConversion(
445  [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type {
446  return converter.convertType(IntegerType::get(type.getContext(), 64));
447  });
448  converter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type {
449  return converter.convertType(
450  nvgpu::getMBarrierMemrefType(rewriter.getContext(), type));
451  });
452  converter.addConversion([&](nvgpu::TensorMapDescriptorType type) -> Type {
453  return LLVM::LLVMPointerType::get(type.getContext());
454  });
455  populateNVGPUToNVVMConversionPatterns(converter, patterns);
457  target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
458  target.addLegalDialect<::mlir::arith::ArithDialect>();
459  target.addLegalDialect<::mlir::memref::MemRefDialect>();
460  target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
462  converter, patterns, target);
463  if (failed(applyPartialConversion(getOperation(), target,
464  std::move(patterns))))
465  signalPassFailure();
466  }
467 };
468 
469 /// Returns the constraints for the sparse MMA inline assembly instruction.
470 static std::string buildMmaSparseAsmConstraintString(unsigned matASize,
471  unsigned matBSize,
472  unsigned matCSize) {
473  std::string str;
474  llvm::raw_string_ostream ss(str);
475  for (unsigned i = 0; i < matCSize; i++)
476  ss << "=r,";
477  for (unsigned i = 0; i < matASize + matBSize + matCSize; i++)
478  ss << "r,";
479  // The final operand is for the sparsity metadata.
480  // The sparsity selector appears as a direct literal.
481  ss << "r";
482  ss.flush();
483  return str;
484 }
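// Worked example (operand counts assumed): matASize = 4, matBSize = 2,
// matCSize = 2 yields "=r,=r,r,r,r,r,r,r,r,r,r" -- two write-only result
// registers, eight read registers for the A/B/C fragments, and one trailing
// register for the sparsity metadata.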
485 
486 /// Returns the string for the `mma.sp.sync` instruction that corresponds to
487 /// the given parameters. Note that this function doesn't do any validation;
488 /// it's expected that the provided parameters correspond to a valid
489 /// instruction.
490 static std::string buildMmaSparseAsmString(
491  const std::array<int64_t, 3> &shape, unsigned matASize, unsigned matBSize,
492  unsigned matCSize, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
493  NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
494  std::optional<NVVM::MMAIntOverflow> overflow, unsigned metaDataSelector) {
495  auto ptxTypeStr = [](NVVM::MMATypes ptxType) {
496  return NVVM::stringifyMMATypes(ptxType);
497  };
498 
499  std::string asmStr;
500  llvm::raw_string_ostream ss(asmStr);
501  ss << "mma.sp.sync.aligned.m" << shape[0] << "n" << shape[1] << "k"
502  << shape[2] << ".row.col.";
503 
504  if (overflow)
505  ss << NVVM::stringifyMMAIntOverflow(*overflow) << ".";
506 
507  ss << ptxTypeStr(ptxTypeD) << "." << ptxTypeStr(ptxTypeA) << "."
508  << ptxTypeStr(ptxTypeB) << "." << ptxTypeStr(ptxTypeC) << " ";
509  unsigned asmArgIdx = 0;
510 
511  // The operand string is structured into sections `{matC elements...},
512  // {matA elements...}, {matB elements...}, {matC elements}`.
513  for (const auto arrSize : {matCSize, matASize, matBSize, matCSize}) {
514  ss << "{";
515  for (unsigned i = 0; i < arrSize; i++)
516  ss << "$" << asmArgIdx++ << (i < arrSize - 1 ? "," : "");
517  ss << "},";
518  }
519  ss << "$" << asmArgIdx++ << ",";
520  assert(metaDataSelector <= 1);
521  ss << "0x" << metaDataSelector << ";";
522  ss.flush();
523  return asmStr;
524 }
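// Worked example (shape and types assumed): shape = {16, 8, 32},
// matASize = matBSize = matCSize = 2, s8 inputs, an s32 accumulator, and
// satfinite overflow produce roughly
//   mma.sp.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32
//   {$0,$1},{$2,$3},{$4,$5},{$6,$7},$8,0x0;
// where $8 is the sparsity metadata operand and 0x0 the sparsity selector.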
525 
526 /// Builds an inline assembly operation corresponding to the specified MMA
527 /// sparse sync operation.
528 static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
529  ImplicitLocOpBuilder &b, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
530  NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
531  std::optional<NVVM::MMAIntOverflow> overflow, ArrayRef<Value> unpackedAData,
532  ArrayRef<Value> unpackedB, ArrayRef<Value> unpackedC, Value indexData,
533  int64_t metadataSelector, const std::array<int64_t, 3> &shape,
534  Type intrinsicResultType) {
535  auto asmDialectAttr =
536  LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_ATT);
537 
538  const unsigned matASize = unpackedAData.size();
539  const unsigned matBSize = unpackedB.size();
540  const unsigned matCSize = unpackedC.size();
541 
542  std::string asmStr = buildMmaSparseAsmString(
543  shape, matASize, matBSize, matCSize, ptxTypeA, ptxTypeB, ptxTypeC,
544  ptxTypeD, overflow, metadataSelector);
545  std::string constraintStr =
546  buildMmaSparseAsmConstraintString(matASize, matBSize, matCSize);
547 
548  SmallVector<Value> asmVals;
549  asmVals.reserve(matASize + matBSize + matCSize + 1);
550  for (ArrayRef<Value> args : {unpackedAData, unpackedB, unpackedC})
551  llvm::append_range(asmVals, args);
552  asmVals.push_back(indexData);
553 
554  return b.create<LLVM::InlineAsmOp>(
555  /*resultTypes=*/intrinsicResultType,
556  /*operands=*/asmVals,
557  /*asm_string=*/asmStr,
558  /*constraints=*/constraintStr,
559  /*has_side_effects=*/true,
560  /*is_align_stack=*/false,
561  /*asm_dialect=*/asmDialectAttr,
562  /*operand_attrs=*/ArrayAttr());
563 }
564 
565 /// Lowers `nvgpu.mma.sp.sync` to inline assembly.
566 struct NVGPUMmaSparseSyncLowering
567  : public ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp> {
569 
571  matchAndRewrite(nvgpu::MmaSparseSyncOp op, OpAdaptor adaptor,
572  ConversionPatternRewriter &rewriter) const override {
573  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
574  // Get the shapes of the MMAMatrix type being used. The shapes determine
575  // which intrinsic this op will be lowered to.
576  VectorType aType = op.getMatrixA().getType();
577  VectorType bType = op.getMatrixB().getType();
578  VectorType cType = op.getMatrixC().getType();
579 
580  FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
581  if (failed(ptxTypeA))
582  return op->emitOpError("failed to deduce operand PTX types");
583  FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
584  if (failed(ptxTypeB))
585  return op->emitOpError("failed to deduce operand PTX types");
586  std::optional<NVVM::MMATypes> ptxTypeC =
587  NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
588  /*isAccumulator=*/true);
589  if (!ptxTypeC)
590  return op->emitError(
591  "could not infer the PTX type for the accumulator/result");
592 
593  // As with `mma.sync`, F32 works only with TensorFloat32 (TF32).
594  bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
595  if (aType.getElementType().isF32() && !tf32Enabled)
596  return failure();
597 
598  // TODO: add an attribute to the op to customize this behavior.
599  std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
600  if (isa<IntegerType>(aType.getElementType()))
601  overflow = NVVM::MMAIntOverflow::satfinite;
602 
603  SmallVector<Value> matA =
604  unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
605  SmallVector<Value> matB =
606  unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
607  SmallVector<Value> matC =
608  unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);
609 
610  Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
611  Type intrinsicResTy = inferIntrinsicResultType(
612  typeConverter->convertType(op->getResultTypes()[0]));
613 
614  // Bitcast the sparse metadata from vector<2xi16> to an i32.
615  Value sparseMetadata = adaptor.getSparseMetadata();
616  if (sparseMetadata.getType() !=
617  LLVM::getFixedVectorType(rewriter.getI16Type(), 2))
618  return op->emitOpError() << "Expected metadata type to be LLVM "
619  "VectorType of 2 i16 elements";
620  sparseMetadata =
621  b.create<LLVM::BitcastOp>(rewriter.getI32Type(), sparseMetadata);
622 
623  FailureOr<LLVM::InlineAsmOp> intrinsicResult = emitMmaSparseSyncOpAsm(
624  b, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB,
625  matC, sparseMetadata, op.getSparsitySelector(), op.getMmaShapeAsArray(),
626  intrinsicResTy);
627  if (failed(intrinsicResult))
628  return failure();
629 
630  assert((*intrinsicResult).getNumResults() == 1 &&
631  "expected inline asm op returns a single LLVM struct type");
632  rewriter.replaceOp(
633  op, convertIntrinsicResult(op.getLoc(), intrinsicResTy, desiredRetTy,
634  (*intrinsicResult)->getResult(0), rewriter));
635  return success();
636  }
637 };
638 
639 struct NVGPUAsyncCopyLowering
640  : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCopyOp> {
642  nvgpu::DeviceAsyncCopyOp>::ConvertOpToLLVMPattern;
643 
645  matchAndRewrite(nvgpu::DeviceAsyncCopyOp op, OpAdaptor adaptor,
646  ConversionPatternRewriter &rewriter) const override {
647  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
648  Location loc = op.getLoc();
649  auto dstMemrefType = cast<MemRefType>(op.getDst().getType());
650  Value dstPtr =
651  getStridedElementPtr(b.getLoc(), dstMemrefType, adaptor.getDst(),
652  adaptor.getDstIndices(), rewriter);
653  FailureOr<unsigned> dstAddressSpace =
654  getTypeConverter()->getMemRefAddressSpace(dstMemrefType);
655  if (failed(dstAddressSpace))
656  return rewriter.notifyMatchFailure(
657  loc, "destination memref address space not convertible to integer");
658 
659  auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
660  FailureOr<unsigned> srcAddressSpace =
661  getTypeConverter()->getMemRefAddressSpace(srcMemrefType);
662  if (failed(srcAddressSpace))
663  return rewriter.notifyMatchFailure(
664  loc, "source memref address space not convertible to integer");
665 
666  Value scrPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.getSrc(),
667  adaptor.getSrcIndices(), rewriter);
668  // The intrinsic takes a global pointer, so we need an address space cast.
669  auto srcPointerGlobalType = LLVM::LLVMPointerType::get(
671  scrPtr = b.create<LLVM::AddrSpaceCastOp>(srcPointerGlobalType, scrPtr);
672  int64_t dstElements = adaptor.getDstElements().getZExtValue();
673  int64_t sizeInBytes =
674  (dstMemrefType.getElementTypeBitWidth() * dstElements) / 8;
675  // When the optional SrcElements argument is *not* present, the regular
676  // CpAsyncOp is generated. CpAsyncOp reads bytes from the source (global
677  // memory) to fill DstElements number of elements in the destination
678  // (shared memory).
679  Value srcBytes = adaptor.getSrcElements();
680  if (srcBytes) {
681  // When the optional SrcElements argument is present, the source (global
682  // memory) of CpAsyncOp is read only for SrcElements number of elements.
683  // The rest of the DstElements in the destination (shared memory) are
684  // filled with zeros.
685  Value c3I32 =
686  b.create<LLVM::ConstantOp>(b.getI32Type(), b.getI32IntegerAttr(3));
687  Value bitwidth = b.create<LLVM::ConstantOp>(
688  b.getI32Type(),
689  b.getI32IntegerAttr(srcMemrefType.getElementTypeBitWidth()));
690  Value srcElementsI32 = b.create<LLVM::TruncOp>(b.getI32Type(), srcBytes);
691  srcBytes = b.create<LLVM::LShrOp>(
692  b.create<LLVM::MulOp>(bitwidth, srcElementsI32), c3I32);
693  }
694  // Use cache-global (.cg) when L1 is bypassed and 16 bytes are written to
695  // the destination; use cache-all (.ca) otherwise.
696  NVVM::LoadCacheModifierKind cacheModifier =
697  (op.getBypassL1().value_or(false) && sizeInBytes == 16)
698  ? NVVM::LoadCacheModifierKind::CG
699  : NVVM::LoadCacheModifierKind::CA;
700 
701  b.create<NVVM::CpAsyncOp>(
702  dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes),
703  NVVM::LoadCacheModifierKindAttr::get(op->getContext(), cacheModifier),
704  srcBytes);
705 
706  // Drop the result token.
707  Value zero = b.create<LLVM::ConstantOp>(
708  IntegerType::get(op.getContext(), 32), rewriter.getI32IntegerAttr(0));
709  rewriter.replaceOp(op, zero);
710  return success();
711  }
712 };
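// Illustrative lowering sketch (IR shapes assumed for exposition):
//   %t = nvgpu.device_async_copy %src[%i, %j], %dst[%k, %l], 8 {bypassL1}
//       : memref<128x128xf16> to memref<64x64xf16, 3>
// becomes roughly an nvvm.cp.async.shared.global of 16 bytes (8 f16 elements)
// with the .cg cache modifier, and the !nvgpu.device.async.token result is
// replaced by a dummy i32 constant.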
713 
714 struct NVGPUAsyncCreateGroupLowering
715  : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCreateGroupOp> {
717  nvgpu::DeviceAsyncCreateGroupOp>::ConvertOpToLLVMPattern;
718 
720  matchAndRewrite(nvgpu::DeviceAsyncCreateGroupOp op, OpAdaptor adaptor,
721  ConversionPatternRewriter &rewriter) const override {
722  rewriter.create<NVVM::CpAsyncCommitGroupOp>(op.getLoc());
723  // Drop the result token.
724  Value zero = rewriter.create<LLVM::ConstantOp>(
725  op->getLoc(), IntegerType::get(op.getContext(), 32),
726  rewriter.getI32IntegerAttr(0));
727  rewriter.replaceOp(op, zero);
728  return success();
729  }
730 };
731 
732 struct NVGPUAsyncWaitLowering
733  : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncWaitOp> {
735  nvgpu::DeviceAsyncWaitOp>::ConvertOpToLLVMPattern;
736 
738  matchAndRewrite(nvgpu::DeviceAsyncWaitOp op, OpAdaptor adaptor,
739  ConversionPatternRewriter &rewriter) const override {
740  // If numGroups is not present, pick 0 as a conservatively correct value.
741  int32_t numGroups = adaptor.getNumGroups().value_or(0);
742  rewriter.create<NVVM::CpAsyncWaitGroupOp>(op.getLoc(), numGroups);
743  rewriter.eraseOp(op);
744  return success();
745  }
746 };
747 
748 /// Creates an mbarrier object in shared memory.
749 struct NVGPUMBarrierCreateLowering
750  : public ConvertOpToLLVMPattern<nvgpu::MBarrierCreateOp> {
752 
753  template <typename moduleT>
754  memref::GlobalOp generateGlobalBarrier(ConversionPatternRewriter &rewriter,
755  Operation *funcOp, moduleT moduleOp,
756  MemRefType barrierType) const {
757  SymbolTable symbolTable(moduleOp);
758  OpBuilder::InsertionGuard guard(rewriter);
759  rewriter.setInsertionPoint(&moduleOp.front());
760  auto global = rewriter.create<memref::GlobalOp>(
761  funcOp->getLoc(), "__mbarrier",
762  /*sym_visibility=*/rewriter.getStringAttr("private"),
763  /*type=*/barrierType,
764  /*initial_value=*/ElementsAttr(),
765  /*constant=*/false,
766  /*alignment=*/rewriter.getI64IntegerAttr(8));
767  symbolTable.insert(global);
768  return global;
769  }
770 
772  matchAndRewrite(nvgpu::MBarrierCreateOp op, OpAdaptor adaptor,
773  ConversionPatternRewriter &rewriter) const override {
774  Operation *funcOp = op->getParentOp();
775  MemRefType barrierType = nvgpu::getMBarrierMemrefType(
776  rewriter.getContext(), op.getBarriers().getType());
777 
778  memref::GlobalOp global;
779  if (auto moduleOp = funcOp->getParentOfType<gpu::GPUModuleOp>())
780  global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
781  else if (auto moduleOp = funcOp->getParentOfType<ModuleOp>())
782  global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
783 
784  rewriter.setInsertionPoint(op);
785  rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(op, barrierType,
786  global.getName());
787  return success();
788  }
789 };
790 
791 /// Base class for lowering mbarrier operations to nvvm intrinsics.
792 template <typename SourceOp>
793 struct MBarrierBasePattern : public ConvertOpToLLVMPattern<SourceOp> {
794 public:
796  /// Returns the base pointer of the mbarrier object.
797  Value getMbarrierPtr(ImplicitLocOpBuilder &b,
798  nvgpu::MBarrierGroupType mbarType, Value memrefDesc,
799  Value mbarId,
800  ConversionPatternRewriter &rewriter) const {
801  MemRefType mbarrierMemrefType =
802  nvgpu::getMBarrierMemrefType(rewriter.getContext(), mbarType);
803  return ConvertToLLVMPattern::getStridedElementPtr(
804  b.getLoc(), mbarrierMemrefType, memrefDesc, {mbarId}, rewriter);
805  }
806 };
807 
808 /// Lowers `nvgpu.mbarrier.init` to `nvvm.mbarrier.init`
809 struct NVGPUMBarrierInitLowering
810  : public MBarrierBasePattern<nvgpu::MBarrierInitOp> {
811  using MBarrierBasePattern<nvgpu::MBarrierInitOp>::MBarrierBasePattern;
812 
814  matchAndRewrite(nvgpu::MBarrierInitOp op, OpAdaptor adaptor,
815  ConversionPatternRewriter &rewriter) const override {
816  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
817  nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType();
818  rewriter.setInsertionPoint(op);
819  Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(),
820  adaptor.getMbarId(), rewriter);
821  Value count = truncToI32(b, adaptor.getCount());
822  if (isMbarrierShared(mbarrierType)) {
823  rewriter.replaceOpWithNewOp<NVVM::MBarrierInitSharedOp>(
824  op, barrier, count, adaptor.getPredicate());
825  } else {
826  rewriter.replaceOpWithNewOp<NVVM::MBarrierInitOp>(op, barrier, count,
827  adaptor.getPredicate());
828  }
829  return success();
830  }
831 };
832 
833 /// Lowers `nvgpu.mbarrier.arrive` to `nvvm.mbarrier.arrive`
834 struct NVGPUMBarrierArriveLowering
835  : public MBarrierBasePattern<nvgpu::MBarrierArriveOp> {
836  using MBarrierBasePattern<nvgpu::MBarrierArriveOp>::MBarrierBasePattern;
838  matchAndRewrite(nvgpu::MBarrierArriveOp op, OpAdaptor adaptor,
839  ConversionPatternRewriter &rewriter) const override {
840  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
841  Value barrier =
842  getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
843  adaptor.getMbarId(), rewriter);
844  Type tokenType = getTypeConverter()->convertType(
845  nvgpu::MBarrierTokenType::get(op->getContext()));
846  if (isMbarrierShared(op.getBarriers().getType())) {
847  rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveSharedOp>(op, tokenType,
848  barrier);
849  } else {
850  rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveOp>(op, tokenType,
851  barrier);
852  }
853  return success();
854  }
855 };
856 
857 /// Lowers `nvgpu.mbarrier.arrive.nocomplete` to
858 /// `nvvm.mbarrier.arrive.nocomplete`
859 struct NVGPUMBarrierArriveNoCompleteLowering
860  : public MBarrierBasePattern<nvgpu::MBarrierArriveNoCompleteOp> {
861  using MBarrierBasePattern<
862  nvgpu::MBarrierArriveNoCompleteOp>::MBarrierBasePattern;
864  matchAndRewrite(nvgpu::MBarrierArriveNoCompleteOp op, OpAdaptor adaptor,
865  ConversionPatternRewriter &rewriter) const override {
866  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
867  Value barrier =
868  getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
869  adaptor.getMbarId(), rewriter);
870  Type tokenType = getTypeConverter()->convertType(
871  nvgpu::MBarrierTokenType::get(op->getContext()));
872  Value count = truncToI32(b, adaptor.getCount());
873  if (isMbarrierShared(op.getBarriers().getType())) {
874  rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteSharedOp>(
875  op, tokenType, barrier, count);
876  } else {
877  rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteOp>(
878  op, tokenType, barrier, count);
879  }
880  return success();
881  }
882 };
883 
884 /// Lowers `nvgpu.mbarrier.test.wait` to `nvvm.mbarrier.test.wait`
885 struct NVGPUMBarrierTestWaitLowering
886  : public MBarrierBasePattern<nvgpu::MBarrierTestWaitOp> {
887  using MBarrierBasePattern<nvgpu::MBarrierTestWaitOp>::MBarrierBasePattern;
889  matchAndRewrite(nvgpu::MBarrierTestWaitOp op, OpAdaptor adaptor,
890  ConversionPatternRewriter &rewriter) const override {
891  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
892  Value barrier =
893  getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
894  adaptor.getMbarId(), rewriter);
895  Type retType = rewriter.getI1Type();
896  if (isMbarrierShared(op.getBarriers().getType())) {
897  rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitSharedOp>(
898  op, retType, barrier, adaptor.getToken());
899  } else {
900  rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitOp>(
901  op, retType, barrier, adaptor.getToken());
902  }
903  return success();
904  }
905 };
906 
907 struct NVGPUMBarrierArriveExpectTxLowering
908  : public MBarrierBasePattern<nvgpu::MBarrierArriveExpectTxOp> {
909  using MBarrierBasePattern<
910  nvgpu::MBarrierArriveExpectTxOp>::MBarrierBasePattern;
912  matchAndRewrite(nvgpu::MBarrierArriveExpectTxOp op, OpAdaptor adaptor,
913  ConversionPatternRewriter &rewriter) const override {
914  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
915  Value barrier =
916  getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
917  adaptor.getMbarId(), rewriter);
918  Value txcount = truncToI32(b, adaptor.getTxcount());
919 
920  if (isMbarrierShared(op.getBarriers().getType())) {
921  rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxSharedOp>(
922  op, barrier, txcount, adaptor.getPredicate());
923  return success();
924  }
925 
926  rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>(
927  op, barrier, txcount, adaptor.getPredicate());
928  return success();
929  }
930 };
931 
932 struct NVGPUMBarrierTryWaitParityLowering
933  : public MBarrierBasePattern<nvgpu::MBarrierTryWaitParityOp> {
934  using MBarrierBasePattern<
935  nvgpu::MBarrierTryWaitParityOp>::MBarrierBasePattern;
937  matchAndRewrite(nvgpu::MBarrierTryWaitParityOp op, OpAdaptor adaptor,
938  ConversionPatternRewriter &rewriter) const override {
939  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
940  Value barrier =
941  getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
942  adaptor.getMbarId(), rewriter);
943  Value ticks = truncToI32(b, adaptor.getTicks());
944  Value phase = truncToI32(b, adaptor.getPhase());
945 
946  if (isMbarrierShared(op.getBarriers().getType())) {
947  rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParitySharedOp>(
948  op, barrier, phase, ticks);
949  return success();
950  }
951 
952  rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParityOp>(op, barrier,
953  phase, ticks);
954  return success();
955  }
956 };
957 
958 struct NVGPUTmaAsyncLoadOpLowering
959  : public MBarrierBasePattern<nvgpu::TmaAsyncLoadOp> {
960  using MBarrierBasePattern<nvgpu::TmaAsyncLoadOp>::MBarrierBasePattern;
962  matchAndRewrite(nvgpu::TmaAsyncLoadOp op, OpAdaptor adaptor,
963  ConversionPatternRewriter &rewriter) const override {
964  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
965  auto srcMemrefType = cast<MemRefType>(op.getDst().getType());
966  Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType,
967  adaptor.getDst(), {}, rewriter);
968  Value barrier =
969  getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
970  adaptor.getMbarId(), rewriter);
971 
972  SmallVector<Value> coords = adaptor.getCoordinates();
973  for (auto [index, value] : llvm::enumerate(coords)) {
974  coords[index] = truncToI32(b, value);
975  }
976  rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(
977  op, dest, adaptor.getTensorMapDescriptor(), coords, barrier,
978  ValueRange{}, Value{}, Value{}, adaptor.getPredicate());
979  return success();
980  }
981 };
982 struct NVGPUGenerateWarpgroupDescriptorLowering
983  : public ConvertOpToLLVMPattern<nvgpu::WarpgroupGenerateDescriptorOp> {
985  nvgpu::WarpgroupGenerateDescriptorOp>::ConvertOpToLLVMPattern;
986 
988  matchAndRewrite(nvgpu::WarpgroupGenerateDescriptorOp op, OpAdaptor adaptor,
989  ConversionPatternRewriter &rewriter) const override {
990 
991  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
992 
993  nvgpu::TensorMapSwizzleKind swizzleKind =
994  op.getTensorMap().getType().getSwizzle();
995 
996  unsigned layout =
997  (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 128
998  : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 64
999  : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 32
1000  : 1;
1001  unsigned swizzle =
1002  (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 1
1003  : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 2
1004  : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 3
1005  : 0;
1006 
1007  auto ti64 = b.getIntegerType(64);
1008  auto makeConst = [&](uint64_t index) -> Value {
1009  return b.create<LLVM::ConstantOp>(ti64, b.getI64IntegerAttr(index));
1010  };
1011  auto shiftLeft = [&](Value value, unsigned shift) -> Value {
1012  return b.create<LLVM::ShlOp>(ti64, value, makeConst(shift));
1013  };
1014  auto shiftRight = [&](Value value, unsigned shift) -> Value {
1015  return b.create<LLVM::LShrOp>(ti64, value, makeConst(shift));
1016  };
1017  auto insertBit = [&](Value desc, Value val, int startBit) {
1018  return b.create<LLVM::OrOp>(ti64, desc, shiftLeft(val, startBit));
1019  };
1020 
1021  int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0);
1022  uint64_t strideDimVal = (layout << 3) >> exclude4LSB;
1023  uint64_t leadDimVal = (sizeN * layout) >> exclude4LSB;
1024  uint64_t offsetVal = 0;
1025 
1026  Value strideDim = makeConst(strideDimVal);
1027  Value leadDim = makeConst(leadDimVal);
1028 
1029  Value baseAddr = getStridedElementPtr(
1030  op->getLoc(), cast<MemRefType>(op.getTensor().getType()),
1031  adaptor.getTensor(), {}, rewriter);
1032  Value basePtr = b.create<LLVM::PtrToIntOp>(ti64, baseAddr);
1033  // Just use 14 bits for base address
1034  Value basePtr14bit = shiftRight(shiftLeft(basePtr, 46), 50);
1035 
1036  int startSwizzleBit = 62, startOffsetBit = 49, startStrideBit = 32,
1037  startLeadBit = 16, startBaseAddrBit = 0;
1038  Value dsc = makeConst(0);
1039  // [62,64) swizzle type
1040  dsc = insertBit(dsc, makeConst(swizzle), startSwizzleBit);
1041  // [49,52) base_offset
1042  dsc = insertBit(dsc, makeConst(offsetVal), startOffsetBit);
1043  // [32,46) stride
1044  dsc = insertBit(dsc, strideDim, startStrideBit);
1045  // [16,30) leading dimension
1046  dsc = insertBit(dsc, leadDim, startLeadBit);
1047  // [0,14) start_address
1048  dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit);
1049 
1050  LLVM_DEBUG(DBGS() << "Generating warpgroup.descriptor: "
1051  << "leading_off:" << leadDimVal << "\t"
1052  << "stride_off :" << strideDimVal << "\t"
1053  << "base_offset:" << offsetVal << "\t"
1054  << "layout_type:" << swizzle << " ("
1055  << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
1056  << ")\n start_addr : " << baseAddr << "\n");
1057 
1058  rewriter.replaceOp(op, dsc);
1059  return success();
1060  }
1061 };
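// Worked example (tensor shape assumed): for a 64-row tile with 128B swizzle,
// layout = 128 and swizzle = 1, so strideDimVal = (128 << 3) >> 4 = 64 and
// leadDimVal = (64 * 128) >> 4 = 512; these land in bits [32,46) and [16,30)
// of the descriptor, with the swizzle kind in bits [62,64) and the 14-bit
// shared-memory address in bits [0,14).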
1062 
1063 static Value makeI64Const(ImplicitLocOpBuilder &b, int32_t index) {
1064  return b.create<LLVM::ConstantOp>(b.getIntegerType(64),
1065  b.getI32IntegerAttr(index));
1066 }
1067 
1068 /// Returns a Value holding the data type enum expected by the CUDA driver.
1069 static Value elementTypeAsLLVMConstant(ImplicitLocOpBuilder &b, Type type) {
1070  // Enum is from CUDA driver API
1071  // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html
1072  enum CUtensorMapDataTypeEnum {
1073  CU_TENSOR_MAP_DATA_TYPE_UINT8 = 0,
1074  CU_TENSOR_MAP_DATA_TYPE_UINT16,
1075  CU_TENSOR_MAP_DATA_TYPE_UINT32,
1076  CU_TENSOR_MAP_DATA_TYPE_INT32,
1077  CU_TENSOR_MAP_DATA_TYPE_UINT64,
1078  CU_TENSOR_MAP_DATA_TYPE_INT64,
1079  CU_TENSOR_MAP_DATA_TYPE_FLOAT16,
1080  CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
1081  CU_TENSOR_MAP_DATA_TYPE_FLOAT64,
1082  CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
1083  CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ,
1084  CU_TENSOR_MAP_DATA_TYPE_TFLOAT32,
1085  CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ
1086  };
1087 
1088  if (type.isUnsignedInteger(8))
1089  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT8);
1090  if (type.isUnsignedInteger(16))
1091  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT16);
1092  if (type.isUnsignedInteger(32))
1093  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT32);
1094  if (type.isUnsignedInteger(64))
1095  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT64);
1096  if (type.isSignlessInteger(32))
1097  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT32);
1098  if (type.isSignlessInteger(64))
1099  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT64);
1100  if (type.isF16())
1101  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT16);
1102  if (type.isF32())
1103  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT32);
1104  if (type.isF64())
1105  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT64);
1106  if (type.isBF16())
1107  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16);
1108 
1109  llvm_unreachable("Not supported data type");
1110 }
1111 
1112 struct NVGPUTmaCreateDescriptorOpLowering
1113  : public ConvertOpToLLVMPattern<nvgpu::TmaCreateDescriptorOp> {
1114  using ConvertOpToLLVMPattern<
1115  nvgpu::TmaCreateDescriptorOp>::ConvertOpToLLVMPattern;
1117  matchAndRewrite(nvgpu::TmaCreateDescriptorOp op, OpAdaptor adaptor,
1118  ConversionPatternRewriter &rewriter) const override {
1119  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1120  auto llvmPointerType = LLVM::LLVMPointerType::get(op->getContext());
1121  Type llvmInt64Type = IntegerType::get(op->getContext(), 64);
1122 
1123  Value tensorElementType =
1124  elementTypeAsLLVMConstant(b, op.getTensor().getType().getElementType());
1125  auto promotedOperands = getTypeConverter()->promoteOperands(
1126  b.getLoc(), op->getOperands(), adaptor.getOperands(), b);
1127 
1128  Value boxArrayPtr = b.create<LLVM::AllocaOp>(llvmPointerType, llvmInt64Type,
1129  makeI64Const(b, 5));
1130  for (auto [index, value] : llvm::enumerate(adaptor.getBoxDimensions())) {
1131  Value gep = b.create<LLVM::GEPOp>(llvmPointerType, llvmPointerType,
1132  boxArrayPtr, makeI64Const(b, index));
1133  b.create<LLVM::StoreOp>(value, gep);
1134  }
1135 
1136  nvgpu::TensorMapDescriptorType desc = op.getTensorMap().getType();
1137  // Set Arguments for the function call
1138  SmallVector<Value> arguments;
1139  arguments.push_back(promotedOperands[0]); // rank
1140  arguments.push_back(promotedOperands[1]); // descriptor
1141  arguments.push_back(tensorElementType); // data type
1142  arguments.push_back(
1143  makeI64Const(b, (int)desc.getInterleave())); // interleave
1144  arguments.push_back(makeI64Const(b, (int)desc.getSwizzle())); // swizzle
1145  arguments.push_back(makeI64Const(b, (int)desc.getL2promo())); // l2promo
1146  arguments.push_back(makeI64Const(b, (int)desc.getOob())); // oob
1147  arguments.push_back(boxArrayPtr); // box dimensions
1148 
1149  // Set data types of the arguments
1150  SmallVector<Type> argTypes = {
1151  llvmInt64Type, /* int64_t tensorRank */
1152  llvmPointerType, /* ptr */
1153  llvmInt64Type, /* int64_t */
1154  llvmInt64Type, /* int64_t */
1155  llvmInt64Type, /* int64_t */
1156  llvmInt64Type, /* int64_t */
1157  llvmInt64Type, /* int64_t */
1158  llvmPointerType /* ptr */
1159  };
1160  FunctionCallBuilder hostRegisterCallBuilder = {
1161  "mgpuTensorMapEncodeTiledMemref", llvmPointerType, argTypes};
1162  Value tensorMap =
1163  hostRegisterCallBuilder.create(b.getLoc(), b, arguments).getResult();
1164 
1165  rewriter.replaceOp(op, tensorMap);
1166  return success();
1167  }
1168 };
1169 
1170 struct NVGPUWarpgroupMmaOpLowering
1171  : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp> {
1173 
1174  /// This is a helper class to generate the NVVM Ops required for warp-group
1175  /// level matrix multiplication.
1176  /// When the given GEMM shape is larger than the shape of
1177  /// a wgmma instruction in PTX, it generates multiple NVVM::WgmmaMmaAsyncOp
1178  /// ops, grouping and executing them asynchronously. The class also handles
1179  /// waiting for completion and iterates through WarpgroupMatrixDescriptor to
1180  /// create descriptors for each instruction.
1181  ///
1182  /// For example this is the case when the shape of GEMM is 128x128x128
1183  ///
1184  /// nvvm.wgmma.fence.aligned
1185  ///
1186  /// nvvm.wgmma.mma.async descA, descB
1187  /// iterate(descA, descB)
1188  /// nvvm.wgmma.mma.async descA, descB
1189  /// [6x times more]
1190  ///
1191  /// nvvm.wgmma.group.sync.aligned
1192  /// nvvm.wgmma.wait.group.sync [groupId]
1193  ///
1194  class WarpgroupGemm {
1195  nvgpu::WarpgroupMmaOp op;
1196  ImplicitLocOpBuilder &b;
1197  OpAdaptor adaptor;
1198 
1199  // Entire shape of the given Op
1200  int64_t totalM, totalN, totalK;
1201 
1202  // Shape of one wgmma instruction
1203  int wgmmaM = 0, wgmmaN = 0, wgmmaK = 0;
1204 
1205  // Iteration counts for GEMM
1206  int iterationM = 0, iterationN = 0, iterationK = 0;
1207 
1208  /// Returns the shape of the wgmma instruction as defined in the PTX
1209  /// programming guide:
1210  /// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shape
1211  void findWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType) {
1212  wgmmaM = 64;
1213  wgmmaN = sizeN;
1214  if (inputElemType.isTF32()) {
1215  wgmmaK = 8;
1216  } else if (inputElemType.isF16() || inputElemType.isBF16()) {
1217  wgmmaK = 16;
1218  } else if (inputElemType.isFloat8E4M3FN() ||
1219  inputElemType.isFloat8E5M2() || inputElemType.isInteger(16)) {
1220  wgmmaK = 32;
1221  } else if (inputElemType.isInteger(1)) {
1222  wgmmaK = 256;
1223  } else {
1224  llvm_unreachable("msg: not supported K shape");
1225  }
1226  LLVM_DEBUG(DBGS() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
1227  << ", n = " << wgmmaN << ", k = " << wgmmaK << "]\n");
1228  }
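// Worked example (GEMM shape assumed): a 128x128x64 f16 GEMM gives
// wgmmaM = 64, wgmmaN = 128, wgmmaK = 16, hence iterationM = 2,
// iterationN = 1, iterationK = 4, and generateWgmmaGroup() below emits
// 2 * 1 * 4 = 8 NVVM::WgmmaMmaAsyncOp instructions.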
1229 
1230  /// Generates WGMMATypesAttr from MLIR Type
1231  NVVM::WGMMATypesAttr generateWgmmaType(Type type) const {
1232  auto getWgmmaType = [](Type elemType) {
1233  if (elemType.isF32() || elemType.isTF32())
1234  return NVVM::WGMMATypes::tf32;
1235  if (elemType.isF16())
1236  return NVVM::WGMMATypes::f16;
1237  if (elemType.isBF16())
1238  return NVVM::WGMMATypes::bf16;
1239  if (elemType.isFloat8E4M3FN())
1240  return NVVM::WGMMATypes::e4m3;
1241  if (elemType.isFloat8E5M2())
1242  return NVVM::WGMMATypes::e5m2;
1243  if (elemType.isInteger(1))
1244  return NVVM::WGMMATypes::b1;
1245  if (elemType.isInteger(8))
1246  return NVVM::WGMMATypes::s8;
1247  if (elemType.isUnsignedInteger(8))
1248  return NVVM::WGMMATypes::u8;
1249  llvm_unreachable("unsupported type");
1250  };
1251  return NVVM::WGMMATypesAttr::get(op->getContext(), getWgmmaType(type));
1252  }
1253 
1254  /// Generates layout attribute for the input matrix for wgmma instruction
1255  NVVM::MMALayoutAttr
1256  generateWgmmaLayout(std::optional<bool> transpose) const {
1257  if (transpose.value_or(false))
1258  return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::col);
1259  return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::row);
1260  }
1261 
1262  /// Generates shape attribute for wgmma instruction
1263  NVVM::MMAShapeAttr generateWgmmaShape() const {
1264  return NVVM::MMAShapeAttr::get(op->getContext(), wgmmaM, wgmmaN, wgmmaK);
1265  }
1266 
1267  /// Generates scale attributes of output matrix for wgmma instruction
1268  NVVM::WGMMAScaleOutAttr generateScaleOut() const {
1269  return NVVM::WGMMAScaleOutAttr::get(op->getContext(),
1270  NVVM::WGMMAScaleOut::one);
1271  }
1272  /// Generates scale attributes of input matrix for wgmma instruction
1273  NVVM::WGMMAScaleInAttr generateScaleIn() const {
1274  return NVVM::WGMMAScaleInAttr::get(op->getContext(),
1275  NVVM::WGMMAScaleIn::one);
1276  }
1277 
1278  /// Basic function to generate Add
1279  Value makeAdd(Value lhs, Value rhs) {
1280  return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);
1281  };
1282 
1283  /// Moves the descriptor pointer of matrix-A for the next wgmma instruction.
1284  /// Currently, it only handles row-major.
1285  ///
1286  /// It moves the pointer like below for [128][64] size:
1287  /// +2 +4 +6
1288  /// ↓ ↓ ↓
1289  /// descA ---> +--+--+--+--+
1290  /// |->|->|->|->|
1291  /// | | | | |
1292  /// | | | | |
1293  /// | | | | |
1294  /// descA+512---> +-----------+
1295  /// | | | | |
1296  /// | | | | |
1297  /// | | | | |
1298  /// | | | | |
1299  /// +-----------+
1300  ///
1301  Value iterateDescriptorA(Value desc, int i, int j, int k) {
1302  MemRefType matrixTypeA = op.getDescriptorA().getType().getTensor();
1303  Type elemA = matrixTypeA.getElementType();
1304  int byte = elemA.getIntOrFloatBitWidth() / 8;
1305  int tileShapeA = matrixTypeA.getDimSize(1);
1306  int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte;
1307  incrementVal = incrementVal >> exclude4LSB;
1308  LLVM_DEBUG(DBGS() << "\t\t[m: " << i << " n: " << j << " k: " << k
1309  << "] [wgmma descriptors] Descriptor A + "
1310  << incrementVal << " | \t ");
1311  if (!incrementVal)
1312  return desc;
1313  return makeAdd(desc, makeI64Const(b, incrementVal));
1314  }
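// Worked example (f16 A tile of shape [128][64] assumed): byte = 2,
// tileShapeA = 64, wgmmaK = 16, so with i = 0 successive k values give
// incrementVal = (16 * k * 2) >> 4 = 2 * k, matching the +2/+4/+6 steps
// shown in the diagram above.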
1315 
1316  /// Moves the descriptor pointer of matrix-B for the next wgmma instruction.
1317  /// Currently, it only handles column-major.
1318  ///
1319  /// It moves the pointer like below for [128][64] size:
1320  /// descB ---> +--+--+--+--+--+--+--+--+
1321  /// |↓ | | | | | | | |
1322  /// |↓ | | | | | | | |
1323  /// |↓ | | | | | | | |
1324  /// |↓ | | | | | | | |
1325  /// +--+--+--+--+--+--+--+--+
1326  ///
1327  Value iterateDescriptorB(Value desc, int i, int j, int k) {
1328  MemRefType matrixTypeB = op.getDescriptorB().getType().getTensor();
1329  Type elemB = matrixTypeB.getElementType();
1330  int byte = elemB.getIntOrFloatBitWidth() / 8;
1331  int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte;
1332  incrementVal = incrementVal >> exclude4LSB;
1333  LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n");
1334  if (!incrementVal)
1335  return desc;
1336  return makeAdd(desc, makeI64Const(b, incrementVal));
1337  }
1338 
1339  /// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix
1340  /// descriptors and arranges them based on induction variables: i, j, and k.
1341  Value generateWgmma(int i, int j, int k, Value matrixC) {
1342  LLVM_DEBUG(DBGS() << "\t wgmma."
1343  << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
1344  << "(A[" << (iterationM * wgmmaM) << ":"
1345  << (iterationM * wgmmaM) + wgmmaM << "]["
1346  << (iterationK * wgmmaK) << ":"
1347  << (iterationK * wgmmaK + wgmmaK) << "] * "
1348  << " B[" << (iterationK * wgmmaK) << ":"
1349  << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":"
1350  << wgmmaN << "])\n");
1351 
1352  Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k);
1353  Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k);
1354 
1355  Type elemA = op.getDescriptorA().getType().getTensor().getElementType();
1356  NVVM::WGMMATypesAttr itypeA = generateWgmmaType(elemA);
1357 
1358  Type elemB = op.getDescriptorB().getType().getTensor().getElementType();
1359  NVVM::WGMMATypesAttr itypeB = generateWgmmaType(elemB);
1360 
1361  NVVM::MMAShapeAttr shape = generateWgmmaShape();
1362  NVVM::WGMMAScaleOutAttr scaleOut = generateScaleOut();
1363  NVVM::WGMMAScaleInAttr scaleIn = generateScaleIn();
1364  NVVM::MMALayoutAttr layoutA = generateWgmmaLayout(op.getTransposeA());
1365  NVVM::MMALayoutAttr layoutB = generateWgmmaLayout(op.getTransposeB());
1366 
1367  auto overflow = NVVM::MMAIntOverflowAttr::get(
1368  op->getContext(), NVVM::MMAIntOverflow::wrapped);
1369 
1370  return b.create<NVVM::WgmmaMmaAsyncOp>(
1371  matrixC.getType(), matrixC, descriptorA, descriptorB, shape, itypeA,
1372  itypeB, scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow);
1373  }
1374 
1375  /// Generates multiple wgmma instructions to complete the given GEMM shape
1376  Value generateWgmmaGroup() {
1377  Value wgmmaResult =
1378  b.create<LLVM::UndefOp>(adaptor.getMatrixC().getType());
1379 
1380  // Perform GEMM
1381  SmallVector<Value> wgmmaResults;
1382  for (int i = 0; i < iterationM; ++i) {
1383  Value matrixC = b.create<LLVM::ExtractValueOp>(adaptor.getMatrixC(), i);
1384  for (int j = 0; j < iterationN; ++j)
1385  for (int k = 0; k < iterationK; ++k)
1386  matrixC = generateWgmma(i, j, k, matrixC);
1387  wgmmaResults.push_back(matrixC);
1388  }
1389  for (auto [idx, matrix] : llvm::enumerate(wgmmaResults)) {
1390  wgmmaResult = b.create<LLVM::InsertValueOp>(wgmmaResult.getType(),
1391  wgmmaResult, matrix, idx);
1392  }
1393  return wgmmaResult;
1394  }
1395 
1396  public:
1397  WarpgroupGemm(nvgpu::WarpgroupMmaOp op, ImplicitLocOpBuilder &b,
1398  OpAdaptor adaptor)
1399  : op(op), b(b), adaptor(adaptor) {
1400  // Find the entire GEMM Shape
1401  totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
1402  totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
1403  totalK = op.getDescriptorA().getType().getTensor().getDimSize(1);
1404  LLVM_DEBUG(DBGS() << "===--- GEMM D[" << totalM << "][" << totalN
1405  << "] += A[" << totalM << "][" << totalK << "] * B["
1406  << totalK << "][" << totalN << "] ---===\n");
1407 
1408  // Find the shape for one wgmma instruction
1409  findWgmmaShape(
1410  totalM, totalN,
1411  op.getDescriptorA().getType().getTensor().getElementType());
1412 
1413  // Iteration counts needed to cover the given shape with the wgmma shape
1414  iterationM = totalM / wgmmaM;
1415  iterationN = totalN / wgmmaN;
1416  iterationK = totalK / wgmmaK;
1417  }
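 // Worked example (illustrative, not part of the upstream file): for
 // D[128][128] += A[128][64] * B[64][128] with f16 inputs, and assuming
 // findWgmmaShape above picks wgmmaM = 64, wgmmaN = 128, wgmmaK = 16, this
 // yields iterationM = 2, iterationN = 1, iterationK = 4, i.e. eight
 // WgmmaMmaAsyncOps per warp group.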
1418 
1419  /// Generates WgmmaMmaAsync ops to complete the specified GEMM shape. A
1420  /// fence (WgmmaFenceAlignedOp) is emitted before the instructions; after
1421  /// them the wgmma group is committed for synchronization
1422  /// (WgmmaGroupSyncAlignedOp) and its completion is waited upon
1423  /// (WgmmaWaitGroupSyncOp).
1424  Value generateWarpgroupMma() {
1425  b.create<NVVM::WgmmaFenceAlignedOp>();
1426  Value wgmmaResult = generateWgmmaGroup();
1427  b.create<NVVM::WgmmaGroupSyncAlignedOp>();
1428  b.create<NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup());
1429  return wgmmaResult;
1430  }
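 // The emitted sequence therefore looks roughly as follows (ops named by
 // their NVVM dialect classes; an illustrative sketch, not verbatim IR):
 //
 //   WgmmaFenceAlignedOp                       // fence before the group
 //   WgmmaMmaAsyncOp  x (iterationM * iterationN * iterationK)
 //   WgmmaGroupSyncAlignedOp                   // commit the wgmma group
 //   WgmmaWaitGroupSyncOp(op.getWaitGroup())   // wait for its completion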
1431  };
1432  LogicalResult
1433  matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
1434  ConversionPatternRewriter &rewriter) const override {
1435  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1436 
1437  // Step 1. Build a helper class
1438  WarpgroupGemm warpgroupGemm(op, b, adaptor);
1439 
1440  // Step 2. Generate wgmma instructions covering the entire GEMM shape
1441  Value wgmmaResult = warpgroupGemm.generateWarpgroupMma();
1442 
1443  // Step 3. Replace fragmented result struct with the op results
1444  rewriter.replaceOp(op, wgmmaResult);
1445  return success();
1446  }
1447 };
1448 
1449 struct NVGPUWarpgroupMmaStoreOpLowering
1450  : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaStoreOp> {
1451  using ConvertOpToLLVMPattern<
1452  nvgpu::WarpgroupMmaStoreOp>::ConvertOpToLLVMPattern;
1453 
1454  /// This function stores a fragmented register matrix owned by a warp group
1455  /// (128 threads) into a memref. Each thread holds its fragment as a struct
1456  /// of 64 register-sized elements.
1457  /// Here is what each thread (T) holds; each `d` is one element of that
1458  /// struct, identified by its index.
1459  ///
1460  /// Threads in the warp-group (128 threads) and what they own in matrix D:
1461  /// 0-31 Warp-0 -> MatrixD[0:15 ][0:N]
1462  /// 32-63 Warp-1 -> MatrixD[16:31][0:N]
1463  /// 64-95 Warp-2 -> MatrixD[32:47][0:N]
1464  /// 96-127 Warp-3 -> MatrixD[48:63][0:N]
1465  ///
1466  /// Matrix-D:
1467  /// +______________________________________________________________________+
1468  /// | 0-1 | 2-3 | 4-5 | 6-7 | 8-9 | 10-11|..|N-8,N-7 |
1469  /// 0 | T0:d0-d1 |T1:d0-d1 |T2:d0-d1 |T3:d0-d1 |T0:d4-d5| T1:d4-d5..|T0:dX-dY|
1470  /// 1 | T4:d0-d1 |T5:d0-d1 |T6:d0-d1 |T7:d0-d1 |T4:d4-d5| T5:d4-d5..|T4:dX-dY|
1471  /// ..| .........|.........|.........|.........|........|...........|........|
1472  /// 8 | T0:d2-d3 |T1:d2-d3 |T2:d2-d3 |T3:d2-d3 |T0:d6-d7|T1:d6-d7,..|T0:dZ-dW|
1473  /// 9 | T4:d2-d3 |T5:d2-d3 |T6:d2-d3 |T7:d2-d3 |T4:d6-d7| T5:d6-d7..|T4:dZ-dW|
1474  /// ..| .........|.........|.........|.........|........|...........|........|
1475  /// 15| T28:d2-d3|T29:d2-d3|T30:d2-d3|T31:d2-d3|........|...........|........|
1476  /// 16| T32:d2-d3|T33:d2-d3|T34:d2-d3|T35:d2-d3|........|...........|........|
1477  /// ..| .........|.........|.........|.........|........|...........|........|
1478  /// 32| T64:d2-d3|T65:d2-d3|T66:d2-d3|T67:d2-d3|........|...........|........|
1479  /// ..| .........|.........|.........|.........|........|...........|........|
1480  /// 48| T96:d2-d3|T97:d2-d3|T98:d2-d3|T99:d2-d3|........|...........|........|
1481  /// ..| .........|.........|.........|.........|........|...........|........|
1482  /// +______________________________________________________________________+
1483  ///
1484  /// \param b: The ImplicitLocOpBuilder used to create the operations.
1485  /// \param matrixD: Result of the warp-group MMA operation (fragmented
1486  /// matrix). It is held by each thread as a struct with 64 elements.
1487  /// \param dstMemref: The memref where the registers will be stored.
1488  /// \param offset: The row offset within the memref at which the registers
1489  /// will be stored.
1490  void storeFragmentedMatrix(ImplicitLocOpBuilder &b, Value matrixD,
1491  TypedValue<MemRefType> dstMemref,
1492  int offset) const {
1493  Type i32 = b.getI32Type();
1494 
1495  auto makeConst = [&](int32_t index) -> Value {
1496  return b.create<LLVM::ConstantOp>(i32, b.getI32IntegerAttr(index));
1497  };
1498  Value c1 = makeConst(1);
1499  Value c2 = makeConst(2);
1500  Value c4 = makeConst(4);
1501  Value c8 = makeConst(8);
1502  Value c16 = makeConst(16);
1503  Value warpSize = makeConst(kWarpSize);
1504 
1505  auto makeMul = [&](Value lhs, Value rhs) -> Value {
1506  return b.create<LLVM::MulOp>(lhs.getType(), lhs, rhs);
1507  };
1508  auto makeAdd = [&](Value lhs, Value rhs) -> Value {
1509  return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);
1510  };
1511 
1512  Value tidx = b.create<NVVM::ThreadIdXOp>(i32);
1513  Value laneId = b.create<LLVM::URemOp>(i32, tidx, warpSize);
1514  Value warpId = b.create<LLVM::UDivOp>(i32, tidx, warpSize);
1515  Value lane4Id = b.create<LLVM::UDivOp>(i32, laneId, c4);
1516  Value lane4modId = b.create<LLVM::URemOp>(i32, laneId, c4);
1517 
1518  auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,
1519  TypedValue<MemRefType> memref) {
1520  Type it = b.getIndexType();
1521  Value idx = b.create<arith::IndexCastOp>(it, x);
1522  Value idy0 = b.create<arith::IndexCastOp>(it, y);
1523  Value idy1 = b.create<arith::IndexCastOp>(it, makeAdd(y, c1));
1524  Value d0 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i);
1525  Value d1 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i + 1);
1526  b.create<memref::StoreOp>(d0, memref, ValueRange{idx, idy0});
1527  b.create<memref::StoreOp>(d1, memref, ValueRange{idx, idy1});
1528  };
1529 
1530  Value tj = makeMul(lane4modId, c2);
1531  Value ti = makeAdd(lane4Id, makeMul(warpId, c16));
1532  if (offset)
1533  ti = makeAdd(ti, makeConst(offset));
1534  for (int i = 0; i < 2; ++i) {
1535  Value idx = makeAdd(ti, makeMul(makeConst(i), c8));
1536  for (int j = 0; j < 16; ++j) {
1537  Value idy = makeAdd(tj, makeMul(makeConst(j), c8));
1538  int sIndex = i * 2 + j * 4;
1539  makeExtractAndStore(sIndex, matrixD, idx, idy, dstMemref);
1540  }
1541  }
1542  }
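 // Worked example (illustrative, not part of the upstream file): thread 5
 // (warp 0, lane 5) with offset 0 gets lane4Id = 1, lane4modId = 1, hence
 // ti = 1 and tj = 2. The loops then store d0/d1 to [1][2..3], d2/d3 to
 // [9][2..3], d4/d5 to [1][10..11], and so on, matching the Matrix-D layout
 // documented above.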
1543 
1544  LogicalResult
1545  matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor,
1546  ConversionPatternRewriter &rewriter) const override {
1547  int offset = 0;
1548  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1549  Value matrixDValue = adaptor.getMatrixD();
1550  auto stype = matrixDValue.getType().cast<LLVM::LLVMStructType>();
1551  for (auto [idx, matrixD] : llvm::enumerate(stype.getBody())) {
1552  auto structType = matrixD.cast<LLVM::LLVMStructType>();
1553  Value innerStructValue = b.create<LLVM::ExtractValueOp>(matrixDValue, idx);
1554  storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset);
1555  offset += structType.getBody().size();
1556  }
1557  rewriter.eraseOp(op);
1558  return success();
1559  }
1560 };
1561 
1562 struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
1563  : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaInitAccumulatorOp> {
1564  using ConvertOpToLLVMPattern<
1565  nvgpu::WarpgroupMmaInitAccumulatorOp>::ConvertOpToLLVMPattern;
1566  LogicalResult
1567  matchAndRewrite(nvgpu::WarpgroupMmaInitAccumulatorOp op, OpAdaptor adaptor,
1568  ConversionPatternRewriter &rewriter) const override {
1569  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1570  LLVM::LLVMStructType packStructType =
1571  getTypeConverter()
1572  ->convertType(op.getMatrixC().getType())
1573  .cast<LLVM::LLVMStructType>();
1574  Type elemType = packStructType.getBody()
1575  .front()
1576  .cast<LLVM::LLVMStructType>()
1577  .getBody()
1578  .front();
1579  Value zero = b.create<LLVM::ConstantOp>(elemType, b.getZeroAttr(elemType));
1580  Value packStruct = b.create<LLVM::UndefOp>(packStructType);
1581  SmallVector<Value> innerStructs;
1582  // Unpack the structs and set all values to zero
1583  for (auto [idx, s] : llvm::enumerate(packStructType.getBody())) {
1584  auto structType = s.cast<LLVM::LLVMStructType>();
1585  Value structValue = b.create<LLVM::ExtractValueOp>(packStruct, idx);
1586  for (unsigned i = 0; i < structType.getBody().size(); ++i) {
1587  structValue = b.create<LLVM::InsertValueOp>(
1588  structType, structValue, zero, ArrayRef<int64_t>({i}));
1589  }
1590  innerStructs.push_back(structValue);
1591  }
1592  // Pack the inner structs into a single struct
1593  for (auto [idx, matrix] : llvm::enumerate(innerStructs)) {
1594  packStruct = b.create<LLVM::InsertValueOp>(packStruct.getType(),
1595  packStruct, matrix, idx);
1596  }
1597  rewriter.replaceOp(op, packStruct);
1598  return success();
1599  }
1600 };
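// Illustrative sketch (not part of the upstream file): for an accumulator
// type that converts to !llvm.struct<(struct<(f32, f32)>, struct<(f32, f32)>)>,
// the pattern above creates a single f32 zero constant, fills every scalar
// slot of each inner struct through LLVM::InsertValueOps, and then packs the
// zeroed inner structs back into the outer struct that replaces the
// nvgpu.warpgroup.mma.init.accumulator result.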
1601 
1602 struct NVGPUTmaPrefetchOpLowering
1603  : public ConvertOpToLLVMPattern<nvgpu::TmaPrefetchOp> {
1604  using ConvertOpToLLVMPattern<nvgpu::TmaPrefetchOp>::ConvertOpToLLVMPattern;
1605  LogicalResult
1606  matchAndRewrite(nvgpu::TmaPrefetchOp op, OpAdaptor adaptor,
1607  ConversionPatternRewriter &rewriter) const override {
1608  rewriter.replaceOpWithNewOp<NVVM::PrefetchTensorMapOp>(
1609  op, adaptor.getTensorMapDescriptor(), adaptor.getPredicate());
1610  return success();
1611  }
1612 };
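// Illustrative example (not verbatim IR): this is a direct 1:1 rewrite; an
//
//   nvgpu.tma.prefetch.descriptor %desc : !nvgpu.tensormap.descriptor<...>
//
// op becomes a single NVVM::PrefetchTensorMapOp on the converted descriptor
// value, with the optional predicate forwarded unchanged.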
1613 
1614 } // namespace
1615 
1616 void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
1617  RewritePatternSet &patterns) {
1618  patterns.add<
1619  NVGPUMBarrierCreateLowering, // nvgpu.mbarrier.create
1620  NVGPUMBarrierInitLowering, // nvgpu.mbarrier.init
1621  NVGPUMBarrierArriveLowering, // nvgpu.mbarrier.arrive
1622  NVGPUMBarrierArriveNoCompleteLowering, // nvgpu.mbarrier.arrive.no_complete
1623  NVGPUMBarrierTestWaitLowering, // nvgpu.mbarrier.test_wait_parity
1624  NVGPUMBarrierTryWaitParityLowering, // nvgpu.mbarrier.try_wait_parity
1625  NVGPUTmaAsyncLoadOpLowering, // nvgpu.tma.async.load
1626  NVGPUTmaCreateDescriptorOpLowering, // nvgpu.tma.create.descriptor
1627  NVGPUTmaPrefetchOpLowering, // nvgpu.tma.prefetch.descriptor
1628  NVGPUMBarrierArriveExpectTxLowering, // nvgpu.mbarrier.arrive.expect_tx
1629  NVGPUGenerateWarpgroupDescriptorLowering, // nvgpu.warpgroup.generate.descriptor
1630  NVGPUWarpgroupMmaOpLowering, // nvgpu.warpgroup.mma
1631  NVGPUWarpgroupMmaStoreOpLowering, // nvgpu.warpgroup.mma.store
1632  NVGPUWarpgroupMmaInitAccumulatorOpLowering, // nvgpu.warpgroup.mma.init.accumulator
1633  MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
1634  NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
1635  NVGPUMmaSparseSyncLowering>(converter);
1636 }
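// A minimal sketch of how these patterns are typically consumed (assumed,
// standard MLIR dialect-conversion wiring; `lowerNVGPUToNVVM` is a
// hypothetical helper, not part of this file):
//
//   LogicalResult lowerNVGPUToNVVM(ModuleOp module) {
//     MLIRContext *ctx = module.getContext();
//     LLVMTypeConverter converter(ctx);
//     RewritePatternSet patterns(ctx);
//     ConversionTarget target(*ctx);
//     target.addLegalDialect<LLVM::LLVMDialect, NVVM::NVVMDialect>();
//     target.addIllegalDialect<nvgpu::NVGPUDialect>();
//     populateNVGPUToNVVMConversionPatterns(converter, patterns);
//     return applyPartialConversion(module, target, std::move(patterns));
//   }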