MLIR  19.0.0git
NVGPUToNVVM.cpp
Go to the documentation of this file.
1 //===- NVGPUToNVVM.cpp - NVGPU to NVVM dialect conversion -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
22 #include "mlir/IR/BuiltinTypes.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/IR/TypeUtilities.h"
26 #include "mlir/IR/Value.h"
27 #include "mlir/Pass/Pass.h"
28 #include "llvm/Support/Debug.h"
29 #include "llvm/Support/ErrorHandling.h"
30 #include "llvm/Support/raw_ostream.h"
31 #include <optional>
32 
33 #define DEBUG_TYPE "nvgpu-to-nvvm"
34 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
35 #define DBGSE() (llvm::dbgs())
36 
37 namespace mlir {
38 #define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS
39 #include "mlir/Conversion/Passes.h.inc"
40 } // namespace mlir
41 
42 using namespace mlir;
43 
/// Number of bits that needs to be excluded when building matrix descriptor for
/// wgmma operations.
constexpr int exclude4LSB = 4;
47 
48 /// GPU has 32 bit registers, this function truncates values when larger width
49 /// is not needed.
51  Type type = value.getType();
52  assert(llvm::isa<IntegerType>(type) && "expected an integer Value");
53  if (type.getIntOrFloatBitWidth() <= 32)
54  return value;
55  return b.create<LLVM::TruncOp>(b.getI32Type(), value);
56 }
57 
58 /// Returns the type for the intrinsic given the vectorResultType of the
59 /// `gpu.mma.sync` operation.
60 static Type inferIntrinsicResultType(Type vectorResultType) {
61  MLIRContext *ctx = vectorResultType.getContext();
62  auto a = cast<LLVM::LLVMArrayType>(vectorResultType);
63  auto f16x2Ty = LLVM::getFixedVectorType(Float16Type::get(ctx), 2);
64  auto i32Ty = IntegerType::get(ctx, 32);
65  auto i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
66  Type f64Ty = Float64Type::get(ctx);
67  Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
68  Type f32Ty = Float32Type::get(ctx);
69  Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
70  if (a.getElementType() == f16x2Ty) {
72  ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty));
73  }
74  if (a.getElementType() == i32x2Ty) {
76  ctx,
77  SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, i32Ty));
78  }
79  if (a.getElementType() == f64x2Ty) {
80  return LLVM::LLVMStructType::getLiteral(ctx, {f64Ty, f64Ty});
81  }
82  if (a.getElementType() == f32x2Ty) {
84  ctx,
85  SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, f32Ty));
86  }
87  if (a.getElementType() == LLVM::getFixedVectorType(f32Ty, 1)) {
89  ctx, SmallVector<Type>(static_cast<size_t>(a.getNumElements()), f32Ty));
90  }
91  return vectorResultType;
92 }
93 
94 /// Convert the SSA result of the NVVM intrinsic `nvvm.mma.sync` (which is
95 /// always an LLVM struct) into a fragment that is compatible with the vector
96 /// type of this operation. This involves extracting elements from the struct
97 /// and inserting them into an LLVM array. These extra data-movement
98 /// operations should be canonicalized away by the LLVM backend.
99 static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
100  Type resultType, Value intrinsicResult,
101  RewriterBase &rewriter) {
102  MLIRContext *ctx = rewriter.getContext();
103  auto structType = dyn_cast<LLVM::LLVMStructType>(intrinsicResultType);
104  auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType);
105  Type i32Ty = rewriter.getI32Type();
106  Type f32Ty = rewriter.getF32Type();
107  Type f64Ty = rewriter.getF64Type();
108  Type f16x2Ty = LLVM::getFixedVectorType(rewriter.getF16Type(), 2);
109  Type i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
110  Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
111  Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
112  Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
113 
114  auto makeConst = [&](int32_t index) -> Value {
115  return rewriter.create<LLVM::ConstantOp>(loc, IntegerType::get(ctx, 32),
116  rewriter.getI32IntegerAttr(index));
117  };
118 
119  if (arrayType) {
120  SmallVector<Value, 4> elements;
121 
122  // The intrinsic returns 32-bit wide elements in a form which can be
123  // directly bitcasted and inserted into the result vector.
124  if (arrayType.getElementType() == f16x2Ty ||
125  arrayType.getElementType() == f32x1Ty) {
126  for (unsigned i = 0; i < structType.getBody().size(); i++) {
127  Value el =
128  rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, i);
129  el = rewriter.createOrFold<LLVM::BitcastOp>(
130  loc, arrayType.getElementType(), el);
131  elements.push_back(el);
132  }
133  }
134 
135  // The intrinsic returns i32, f64, and f32 values as individual scalars,
136  // even when the result is notionally a 64-bit wide element (e.g. f32x2). We
137  // need to extract them from the struct and pack them into the 64-bit wide
138  // rows of the vector result.
139  if (arrayType.getElementType() == i32x2Ty ||
140  arrayType.getElementType() == f64x2Ty ||
141  arrayType.getElementType() == f32x2Ty) {
142 
143  for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) {
144  Value vec =
145  rewriter.create<LLVM::UndefOp>(loc, arrayType.getElementType());
146  Value x1 =
147  rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, i * 2);
148  Value x2 = rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult,
149  i * 2 + 1);
150  vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
151  x1, makeConst(0));
152  vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
153  x2, makeConst(1));
154  elements.push_back(vec);
155  }
156  }
157 
158  // Create the final vectorized result.
159  Value result = rewriter.create<LLVM::UndefOp>(loc, arrayType);
160  for (const auto &el : llvm::enumerate(elements)) {
161  result = rewriter.create<LLVM::InsertValueOp>(loc, result, el.value(),
162  el.index());
163  }
164  return result;
165  }
166 
167  return intrinsicResult;
168 }
169 
170 /// The `gpu.mma.sync` converter below expects matrix fragment operands to be
171 /// given as 2D `vectors` where the rows are 32b or 64b wide. The
172 /// `nvvm.mma.sync` op expects these argments to be a given in a long list of
173 /// scalars of certain types. This function helps unpack the `vector` arguments
174 /// and cast them to the types expected by `nvvm.mma.sync`.
176  Value operand,
177  NVVM::MMATypes operandPtxType) {
178  SmallVector<Value> result;
179  Type i32Ty = b.getI32Type();
180  Type f64Ty = b.getF64Type();
181  Type f32Ty = b.getF32Type();
182  Type i64Ty = b.getI64Type();
183  Type i8x4Ty = LLVM::getFixedVectorType(b.getI8Type(), 4);
184  Type i4x8Ty = LLVM::getFixedVectorType(b.getIntegerType(4), 8);
185  Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
186  auto arrayTy = cast<LLVM::LLVMArrayType>(operand.getType());
187 
188  for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
189  Value toUse = b.create<LLVM::ExtractValueOp>(operand, i);
190 
191  // For 4xi8 vectors, the intrinsic expects these to be provided as i32
192  // scalar types.
193  if (arrayTy.getElementType() == i8x4Ty ||
194  arrayTy.getElementType() == i4x8Ty ||
195  (arrayTy.getElementType() == f32x1Ty &&
196  operandPtxType == NVVM::MMATypes::tf32)) {
197  result.push_back(b.create<LLVM::BitcastOp>(i32Ty, toUse));
198  continue;
199  }
200 
201  // For some element types (i32, f32, f64), we need to unpack the inner
202  // vector/array type as well because the intrinsic expects individual
203  // scalars to be provided.
204  VectorType innerArrayTy = dyn_cast<VectorType>(arrayTy.getElementType());
205  if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty ||
206  innerArrayTy.getElementType() == f64Ty ||
207  innerArrayTy.getElementType() == f32Ty)) {
208  for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
209  idx < innerSize; idx++) {
210  result.push_back(b.create<LLVM::ExtractElementOp>(
211  toUse,
212  b.create<LLVM::ConstantOp>(i64Ty, b.getI64IntegerAttr(idx))));
213  }
214  continue;
215  }
216  result.push_back(toUse);
217  }
218  return result;
219 }
220 
221 /// Returns whether mbarrier object has shared memory address space.
222 static bool isMbarrierShared(nvgpu::MBarrierGroupType barrierType) {
223  return (mlir::nvgpu::NVGPUDialect::isSharedMemoryAddressSpace(
224  barrierType.getMemorySpace()));
225 }
226 
227 /// Returns the memory space attribute of the mbarrier object.
229  nvgpu::MBarrierGroupType barrierType) {
230  Attribute memorySpace = {};
231  if (isMbarrierShared(barrierType)) {
232  memorySpace =
233  IntegerAttr::get(IntegerType::get(context, 64),
234  nvgpu::NVGPUDialect::kSharedMemoryAddressSpace);
235  }
236  return memorySpace;
237 }
238 
239 /// Returns memref type of the mbarrier object. The type is defined in the
240 /// MBarrierGroupType.
241 MemRefType nvgpu::getMBarrierMemrefType(MLIRContext *context,
242  nvgpu::MBarrierGroupType barrierType) {
243  Attribute memorySpace = nvgpu::getMbarrierMemorySpace(context, barrierType);
244  MemRefLayoutAttrInterface layout;
245  return MemRefType::get({barrierType.getNumBarriers()},
246  IntegerType::get(context, 64), layout, memorySpace);
247 }
248 
249 namespace {
250 
251 struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
253 
255  matchAndRewrite(nvgpu::LdMatrixOp op, OpAdaptor adaptor,
256  ConversionPatternRewriter &rewriter) const override {
257  MLIRContext *ctx = getContext();
258  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
259 
260  // The result type of ldmatrix will always be a struct of 32bit integer
261  // registers if more than one 32bit value is returned. Otherwise, the result
262  // is a single i32. The result type of the GPU operation is always a vector
263  // of shape (NumRegisters, VectorRegister) where VectorRegister is the
264  // vector type of the result and always 32 bits long. We bitcast the result
265  // of the NVVM::LdMatrix to this vector type.
266  auto vectorResultType = dyn_cast<VectorType>(op->getResultTypes()[0]);
267  if (!vectorResultType) {
268  return failure();
269  }
270  Type innerVectorType = LLVM::getFixedVectorType(
271  vectorResultType.getElementType(), vectorResultType.getDimSize(1));
272 
273  int64_t num32BitRegs = vectorResultType.getDimSize(0);
274 
275  Type ldMatrixResultType;
276  if (num32BitRegs > 1) {
277  ldMatrixResultType = LLVM::LLVMStructType::getLiteral(
278  ctx, SmallVector<Type>(num32BitRegs, rewriter.getI32Type()));
279  } else {
280  ldMatrixResultType = rewriter.getI32Type();
281  }
282 
283  auto srcMemrefType = cast<MemRefType>(op.getSrcMemref().getType());
284  Value srcPtr =
285  getStridedElementPtr(b.getLoc(), srcMemrefType, adaptor.getSrcMemref(),
286  adaptor.getIndices(), rewriter);
287  Value ldMatrixResult = b.create<NVVM::LdMatrixOp>(
288  ldMatrixResultType, srcPtr,
289  /*num=*/op.getNumTiles(),
290  /*layout=*/op.getTranspose() ? NVVM::MMALayout::col
291  : NVVM::MMALayout::row);
292 
293  // The ldmatrix operation returns either a single i32 value or a struct of
294  // i32 values. Here we unpack those values and cast them back to their
295  // actual vector type (still of width 32b) and repack them into a result
296  // struct.
297  Type finalResultType = typeConverter->convertType(vectorResultType);
298  Value result = b.create<LLVM::UndefOp>(finalResultType);
299  for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) {
300  Value i32Register =
301  num32BitRegs > 1 ? b.create<LLVM::ExtractValueOp>(ldMatrixResult, i)
302  : ldMatrixResult;
303  Value casted = b.create<LLVM::BitcastOp>(innerVectorType, i32Register);
304  result = b.create<LLVM::InsertValueOp>(result, casted, i);
305  }
306 
307  rewriter.replaceOp(op, result);
308  return success();
309  }
310 };
311 
312 /// Convert the given type into the corresponding PTX type (NVVM::MMATypes
313 /// enum).
314 static FailureOr<NVVM::MMATypes> getNvvmMmaType(Type t) {
315  Type elType = getElementTypeOrSelf(t);
316  if (elType.isInteger(8))
317  return NVVM::MMATypes::s8;
318  if (elType.isInteger(4))
319  return NVVM::MMATypes::s4;
320  if (elType.isF16())
321  return NVVM::MMATypes::f16;
322  if (elType.isF64())
323  return NVVM::MMATypes::f64;
324  if (elType.isF32())
325  return NVVM::MMATypes::tf32;
326  return failure();
327 }
328 
329 struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
331 
333  matchAndRewrite(nvgpu::MmaSyncOp op, OpAdaptor adaptor,
334  ConversionPatternRewriter &rewriter) const override {
335  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
336  // Get the shapes of the MMAMatrix type being used. The shapes will
337  // choose which intrinsic this op will be lowered to.
338  VectorType aType = op.getMatrixA().getType();
339  VectorType bType = op.getMatrixA().getType();
340  VectorType cType = op.getMatrixC().getType();
341 
342  std::array<int64_t, 3> gemmShape = op.getMmaShapeAsArray();
343 
344  // Tensor Cores (mma.sync) on F32 works only with TensorFloat32 (TF32).
345  bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
346  if (aType.getElementType().isF32() && !tf32Enabled)
347  return failure();
348 
349  FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
350  if (failed(ptxTypeA))
351  return op->emitOpError("failed to deduce operand PTX types");
352  FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
353  if (failed(ptxTypeB))
354  return op->emitOpError("failed to deduce operand PTX types");
355  std::optional<NVVM::MMATypes> ptxTypeC =
356  NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
357  /*isAccumulator=*/true);
358  if (!ptxTypeC)
359  return op->emitError(
360  "could not infer the PTX type for the accumulator/result");
361 
362  // TODO: add an attribute to the op to customize this behavior.
363  std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
364  if (isa<IntegerType>(aType.getElementType()))
365  overflow = NVVM::MMAIntOverflow::satfinite;
366 
367  SmallVector<Value> matA =
368  unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
369  SmallVector<Value> matB =
370  unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
371  SmallVector<Value> matC =
372  unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);
373 
374  Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
375  Type intrinsicResTy = inferIntrinsicResultType(
376  typeConverter->convertType(op->getResultTypes()[0]));
377  Value intrinsicResult = b.create<NVVM::MmaOp>(
378  intrinsicResTy, matA, matB, matC,
379  /*shape=*/gemmShape,
380  /*b1Op=*/std::nullopt,
381  /*intOverflow=*/overflow,
382  /*multiplicandPtxTypes=*/
383  std::array<NVVM::MMATypes, 2>{*ptxTypeA, *ptxTypeB},
384  /*multiplicandLayouts=*/
385  std::array<NVVM::MMALayout, 2>{NVVM::MMALayout::row,
386  NVVM::MMALayout::col});
387  rewriter.replaceOp(op, convertIntrinsicResult(op.getLoc(), intrinsicResTy,
388  desiredRetTy, intrinsicResult,
389  rewriter));
390  return success();
391  }
392 };
393 
394 struct ConvertNVGPUToNVVMPass
395  : public impl::ConvertNVGPUToNVVMPassBase<ConvertNVGPUToNVVMPass> {
396  using Base::Base;
397 
398  void getDependentDialects(DialectRegistry &registry) const override {
399  registry.insert<memref::MemRefDialect, LLVM::LLVMDialect, NVVM::NVVMDialect,
400  arith::ArithDialect>();
401  }
402 
403  void runOnOperation() override {
405  RewritePatternSet patterns(&getContext());
406  LLVMTypeConverter converter(&getContext(), options);
407  IRRewriter rewriter(&getContext());
409  converter, [](gpu::AddressSpace space) -> unsigned {
410  switch (space) {
411  case gpu::AddressSpace::Global:
412  return static_cast<unsigned>(
414  case gpu::AddressSpace::Workgroup:
415  return static_cast<unsigned>(
417  case gpu::AddressSpace::Private:
418  return 0;
419  }
420  llvm_unreachable("unknown address space enum value");
421  return 0;
422  });
423  /// device-side async tokens cannot be materialized in nvvm. We just
424  /// convert them to a dummy i32 type in order to easily drop them during
425  /// conversion.
426  converter.addConversion([&](nvgpu::DeviceAsyncTokenType type) -> Type {
427  return converter.convertType(IntegerType::get(type.getContext(), 32));
428  });
429  converter.addConversion([&](nvgpu::WarpgroupAccumulatorType type) -> Type {
430  Type elemType = type.getFragmented().getElementType();
431  int64_t sizeM = type.getFragmented().getDimSize(0);
432  int64_t sizeN = type.getFragmented().getDimSize(1);
433 
434  unsigned numMembers;
435  if (elemType.isF32() || elemType.isInteger(32))
436  numMembers = sizeN / 2;
437  else if (elemType.isF16())
438  numMembers = sizeN / 4;
439  else
440  llvm_unreachable("unsupported type for warpgroup accumulator");
441 
442  SmallVector<Type> innerStructBody;
443  for (unsigned i = 0; i < numMembers; i++)
444  innerStructBody.push_back(elemType);
445  auto innerStructType =
446  LLVM::LLVMStructType::getLiteral(type.getContext(), innerStructBody);
447 
448  SmallVector<Type> structBody;
449  for (int i = 0; i < sizeM; i += kWgmmaSizeM)
450  structBody.push_back(innerStructType);
451 
452  auto convertedType =
453  LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
454  return converter.convertType(convertedType);
455  });
456  converter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {
457  return converter.convertType(IntegerType::get(type.getContext(), 64));
458  });
459  converter.addConversion(
460  [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type {
461  return converter.convertType(IntegerType::get(type.getContext(), 64));
462  });
463  converter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type {
464  return converter.convertType(
465  nvgpu::getMBarrierMemrefType(rewriter.getContext(), type));
466  });
467  converter.addConversion([&](nvgpu::TensorMapDescriptorType type) -> Type {
468  return LLVM::LLVMPointerType::get(type.getContext());
469  });
470  populateNVGPUToNVVMConversionPatterns(converter, patterns);
472  target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
473  target.addLegalDialect<::mlir::arith::ArithDialect>();
474  target.addLegalDialect<::mlir::memref::MemRefDialect>();
475  target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
477  converter, patterns, target);
478  if (failed(applyPartialConversion(getOperation(), target,
479  std::move(patterns))))
480  signalPassFailure();
481  }
482 };
483 
/// Returns the constraints for the sparse MMA inline assembly instruction:
/// one "=r" output per matC register, one "r" input per matA/matB/matC
/// register, and a final "r" for the sparsity metadata. (The sparsity
/// selector appears as a direct literal, so it gets no constraint.)
static std::string buildMmaSparseAsmConstraintString(unsigned matASize,
                                                     unsigned matBSize,
                                                     unsigned matCSize) {
  // Plain string appends; the raw_string_ostream previously used here was
  // redundant (it is unbuffered, and flush() is a deprecated no-op).
  std::string str;
  for (unsigned i = 0; i < matCSize; i++)
    str += "=r,";
  for (unsigned i = 0; i < matASize + matBSize + matCSize; i++)
    str += "r,";
  // The final operand is for the sparsity metadata.
  str += "r";
  return str;
}
500 
501 /// Returns the string for the `mma.sp.sync` instruction that corresponds to
502 /// the given parameters. Note that this function doesn't do any validation,
503 /// it's expected that the provided parameters correspond to a valid
504 /// instruction.
505 static std::string buildMmaSparseAsmString(
506  const std::array<int64_t, 3> &shape, unsigned matASize, unsigned matBSize,
507  unsigned matCSize, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
508  NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
509  std::optional<NVVM::MMAIntOverflow> overflow, unsigned metaDataSelector) {
510  auto ptxTypeStr = [](NVVM::MMATypes ptxType) {
511  return NVVM::stringifyMMATypes(ptxType);
512  };
513 
514  std::string asmStr;
515  llvm::raw_string_ostream ss(asmStr);
516  ss << "mma.sp.sync.aligned.m" << shape[0] << "n" << shape[1] << "k"
517  << shape[2] << ".row.col.";
518 
519  if (overflow)
520  ss << NVVM::stringifyMMAIntOverflow(*overflow) << ".";
521 
522  ss << ptxTypeStr(ptxTypeD) << "." << ptxTypeStr(ptxTypeA) << "."
523  << ptxTypeStr(ptxTypeB) << "." << ptxTypeStr(ptxTypeC) << " ";
524  unsigned asmArgIdx = 0;
525 
526  // The operand string is structured into sections `{matC elements...},
527  // {matA elements...}, {matB elements...}, {matC elements}`.
528  for (const auto arrSize : {matCSize, matASize, matBSize, matCSize}) {
529  ss << "{";
530  for (unsigned i = 0; i < arrSize; i++)
531  ss << "$" << asmArgIdx++ << (i < arrSize - 1 ? "," : "");
532  ss << "},";
533  }
534  ss << "$" << asmArgIdx++ << ",";
535  assert(metaDataSelector <= 1);
536  ss << "0x" << metaDataSelector << ";";
537  ss.flush();
538  return asmStr;
539 }
540 
541 /// Builds an inline assembly operation corresponding to the specified MMA
542 /// sparse sync operation.
543 static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
544  ImplicitLocOpBuilder &b, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
545  NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
546  std::optional<NVVM::MMAIntOverflow> overflow, ArrayRef<Value> unpackedAData,
547  ArrayRef<Value> unpackedB, ArrayRef<Value> unpackedC, Value indexData,
548  int64_t metadataSelector, const std::array<int64_t, 3> &shape,
549  Type intrinsicResultType) {
550  auto asmDialectAttr =
551  LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_ATT);
552 
553  const unsigned matASize = unpackedAData.size();
554  const unsigned matBSize = unpackedB.size();
555  const unsigned matCSize = unpackedC.size();
556 
557  std::string asmStr = buildMmaSparseAsmString(
558  shape, matASize, matBSize, matCSize, ptxTypeA, ptxTypeB, ptxTypeC,
559  ptxTypeD, overflow, metadataSelector);
560  std::string constraintStr =
561  buildMmaSparseAsmConstraintString(matASize, matBSize, matCSize);
562 
563  SmallVector<Value> asmVals;
564  asmVals.reserve(matASize + matBSize + matCSize + 1);
565  for (ArrayRef<Value> args : {unpackedAData, unpackedB, unpackedC})
566  llvm::append_range(asmVals, args);
567  asmVals.push_back(indexData);
568 
569  return b.create<LLVM::InlineAsmOp>(
570  /*resultTypes=*/intrinsicResultType,
571  /*operands=*/asmVals,
572  /*asm_string=*/asmStr,
573  /*constraints=*/constraintStr,
574  /*has_side_effects=*/true,
575  /*is_align_stack=*/false,
576  /*asm_dialect=*/asmDialectAttr,
577  /*operand_attrs=*/ArrayAttr());
578 }
579 
580 /// Lowers `nvgpu.mma.sp.sync` to inline assembly.
581 struct NVGPUMmaSparseSyncLowering
582  : public ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp> {
584 
586  matchAndRewrite(nvgpu::MmaSparseSyncOp op, OpAdaptor adaptor,
587  ConversionPatternRewriter &rewriter) const override {
588  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
589  // Get the shapes of the MMAMatrix type being used. The shapes will
590  // choose which intrinsic this op will be lowered to.
591  VectorType aType = op.getMatrixA().getType();
592  VectorType bType = op.getMatrixB().getType();
593  VectorType cType = op.getMatrixC().getType();
594 
595  FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
596  if (failed(ptxTypeA))
597  return op->emitOpError("failed to deduce operand PTX types");
598  FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
599  if (failed(ptxTypeB))
600  return op->emitOpError("failed to deduce operand PTX types");
601  std::optional<NVVM::MMATypes> ptxTypeC =
602  NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
603  /*isAccumulator=*/true);
604  if (!ptxTypeC)
605  return op->emitError(
606  "could not infer the PTX type for the accumulator/result");
607 
608  // Same as `mma.sync`, F32 works only with TensorFloat32 (TF32).
609  bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
610  if (aType.getElementType().isF32() && !tf32Enabled)
611  return failure();
612 
613  // TODO: add an attribute to the op to customize this behavior.
614  std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
615  if (isa<IntegerType>(aType.getElementType()))
616  overflow = NVVM::MMAIntOverflow::satfinite;
617 
618  SmallVector<Value> matA =
619  unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
620  SmallVector<Value> matB =
621  unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
622  SmallVector<Value> matC =
623  unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);
624 
625  Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
626  Type intrinsicResTy = inferIntrinsicResultType(
627  typeConverter->convertType(op->getResultTypes()[0]));
628 
629  // Bitcast the sparse metadata from vector<2xf16> to an i32.
630  Value sparseMetadata = adaptor.getSparseMetadata();
631  if (sparseMetadata.getType() !=
632  LLVM::getFixedVectorType(rewriter.getI16Type(), 2))
633  return op->emitOpError() << "Expected metadata type to be LLVM "
634  "VectorType of 2 i16 elements";
635  sparseMetadata =
636  b.create<LLVM::BitcastOp>(rewriter.getI32Type(), sparseMetadata);
637 
638  FailureOr<LLVM::InlineAsmOp> intrinsicResult = emitMmaSparseSyncOpAsm(
639  b, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB,
640  matC, sparseMetadata, op.getSparsitySelector(), op.getMmaShapeAsArray(),
641  intrinsicResTy);
642  if (failed(intrinsicResult))
643  return failure();
644 
645  assert((*intrinsicResult).getNumResults() == 1 &&
646  "expected inline asm op returns a single LLVM struct type");
647  rewriter.replaceOp(
648  op, convertIntrinsicResult(op.getLoc(), intrinsicResTy, desiredRetTy,
649  (*intrinsicResult)->getResult(0), rewriter));
650  return success();
651  }
652 };
653 
654 struct NVGPUAsyncCopyLowering
655  : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCopyOp> {
657  nvgpu::DeviceAsyncCopyOp>::ConvertOpToLLVMPattern;
658 
660  matchAndRewrite(nvgpu::DeviceAsyncCopyOp op, OpAdaptor adaptor,
661  ConversionPatternRewriter &rewriter) const override {
662  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
663  Location loc = op.getLoc();
664  auto dstMemrefType = cast<MemRefType>(op.getDst().getType());
665  Value dstPtr =
666  getStridedElementPtr(b.getLoc(), dstMemrefType, adaptor.getDst(),
667  adaptor.getDstIndices(), rewriter);
668  FailureOr<unsigned> dstAddressSpace =
669  getTypeConverter()->getMemRefAddressSpace(dstMemrefType);
670  if (failed(dstAddressSpace))
671  return rewriter.notifyMatchFailure(
672  loc, "destination memref address space not convertible to integer");
673 
674  auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
675  FailureOr<unsigned> srcAddressSpace =
676  getTypeConverter()->getMemRefAddressSpace(srcMemrefType);
677  if (failed(srcAddressSpace))
678  return rewriter.notifyMatchFailure(
679  loc, "source memref address space not convertible to integer");
680 
681  Value scrPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.getSrc(),
682  adaptor.getSrcIndices(), rewriter);
683  // Intrinsics takes a global pointer so we need an address space cast.
684  auto srcPointerGlobalType = LLVM::LLVMPointerType::get(
686  scrPtr = b.create<LLVM::AddrSpaceCastOp>(srcPointerGlobalType, scrPtr);
687  int64_t dstElements = adaptor.getDstElements().getZExtValue();
688  int64_t sizeInBytes =
689  (dstMemrefType.getElementTypeBitWidth() * dstElements) / 8;
690  // When the optional SrcElements argument is *not* present, the regular
691  // CpAsyncOp is generated. CopyAsyncOp reads bytes from source (global
692  // memory) to fill DstElements number of elements in the destination
693  // (shared memory).
694  Value srcBytes = adaptor.getSrcElements();
695  if (srcBytes) {
696  // When the optional SrcElements argument is present, the source (global
697  // memory) of CpAsyncOp is read only for SrcElements number of elements.
698  // The rest of the DstElements in the destination (shared memory) are
699  // filled with zeros.
700  Value c3I32 =
701  b.create<LLVM::ConstantOp>(b.getI32Type(), b.getI32IntegerAttr(3));
702  Value bitwidth = b.create<LLVM::ConstantOp>(
703  b.getI32Type(),
704  b.getI32IntegerAttr(srcMemrefType.getElementTypeBitWidth()));
705  Value srcElementsI32 = b.create<LLVM::TruncOp>(b.getI32Type(), srcBytes);
706  srcBytes = b.create<LLVM::LShrOp>(
707  b.create<LLVM::MulOp>(bitwidth, srcElementsI32), c3I32);
708  }
709  // Cache global (.cg) for 16 dst bytes, Cache all (.ca) for sizes other than
710  // 16 dst bytes.
711  NVVM::LoadCacheModifierKind cacheModifier =
712  (op.getBypassL1().value_or(false) && sizeInBytes == 16)
713  ? NVVM::LoadCacheModifierKind::CG
714  : NVVM::LoadCacheModifierKind::CA;
715 
716  b.create<NVVM::CpAsyncOp>(
717  dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes),
718  NVVM::LoadCacheModifierKindAttr::get(op->getContext(), cacheModifier),
719  srcBytes);
720 
721  // Drop the result token.
722  Value zero = b.create<LLVM::ConstantOp>(
723  IntegerType::get(op.getContext(), 32), rewriter.getI32IntegerAttr(0));
724  rewriter.replaceOp(op, zero);
725  return success();
726  }
727 };
728 
729 struct NVGPUAsyncCreateGroupLowering
730  : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCreateGroupOp> {
732  nvgpu::DeviceAsyncCreateGroupOp>::ConvertOpToLLVMPattern;
733 
735  matchAndRewrite(nvgpu::DeviceAsyncCreateGroupOp op, OpAdaptor adaptor,
736  ConversionPatternRewriter &rewriter) const override {
737  rewriter.create<NVVM::CpAsyncCommitGroupOp>(op.getLoc());
738  // Drop the result token.
739  Value zero = rewriter.create<LLVM::ConstantOp>(
740  op->getLoc(), IntegerType::get(op.getContext(), 32),
741  rewriter.getI32IntegerAttr(0));
742  rewriter.replaceOp(op, zero);
743  return success();
744  }
745 };
746 
747 struct NVGPUAsyncWaitLowering
748  : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncWaitOp> {
750  nvgpu::DeviceAsyncWaitOp>::ConvertOpToLLVMPattern;
751 
753  matchAndRewrite(nvgpu::DeviceAsyncWaitOp op, OpAdaptor adaptor,
754  ConversionPatternRewriter &rewriter) const override {
755  // If numGroup is not present pick 0 as a conservative correct value.
756  int32_t numGroups = adaptor.getNumGroups().value_or(0);
757  rewriter.create<NVVM::CpAsyncWaitGroupOp>(op.getLoc(), numGroups);
758  rewriter.eraseOp(op);
759  return success();
760  }
761 };
762 
763 /// Creates mbarrier object in shared memory
764 struct NVGPUMBarrierCreateLowering
765  : public ConvertOpToLLVMPattern<nvgpu::MBarrierCreateOp> {
767 
768  template <typename moduleT>
769  memref::GlobalOp generateGlobalBarrier(ConversionPatternRewriter &rewriter,
770  Operation *funcOp, moduleT moduleOp,
771  MemRefType barrierType) const {
772  SymbolTable symbolTable(moduleOp);
773  OpBuilder::InsertionGuard guard(rewriter);
774  rewriter.setInsertionPoint(&moduleOp.front());
775  auto global = rewriter.create<memref::GlobalOp>(
776  funcOp->getLoc(), "__mbarrier",
777  /*sym_visibility=*/rewriter.getStringAttr("private"),
778  /*type=*/barrierType,
779  /*initial_value=*/ElementsAttr(),
780  /*constant=*/false,
781  /*alignment=*/rewriter.getI64IntegerAttr(8));
782  symbolTable.insert(global);
783  return global;
784  }
785 
787  matchAndRewrite(nvgpu::MBarrierCreateOp op, OpAdaptor adaptor,
788  ConversionPatternRewriter &rewriter) const override {
789  Operation *funcOp = op->getParentOp();
790  MemRefType barrierType = nvgpu::getMBarrierMemrefType(
791  rewriter.getContext(), op.getBarriers().getType());
792 
793  memref::GlobalOp global;
794  if (auto moduleOp = funcOp->getParentOfType<gpu::GPUModuleOp>())
795  global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
796  else if (auto moduleOp = funcOp->getParentOfType<ModuleOp>())
797  global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
798 
799  rewriter.setInsertionPoint(op);
800  rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(op, barrierType,
801  global.getName());
802  return success();
803  }
804 };
805 
806 /// Base class for lowering mbarrier operations to nvvm intrinsics.
807 template <typename SourceOp>
808 struct MBarrierBasePattern : public ConvertOpToLLVMPattern<SourceOp> {
809 public:
811  /// Returns the base pointer of the mbarrier object.
812  Value getMbarrierPtr(ImplicitLocOpBuilder &b,
813  nvgpu::MBarrierGroupType mbarType, Value memrefDesc,
814  Value mbarId,
815  ConversionPatternRewriter &rewriter) const {
816  MemRefType mbarrierMemrefType =
817  nvgpu::getMBarrierMemrefType(rewriter.getContext(), mbarType);
819  b.getLoc(), mbarrierMemrefType, memrefDesc, {mbarId}, rewriter);
820  }
821 };
822 
823 /// Lowers `nvgpu.mbarrier.init` to `nvvm.mbarrier.init`
824 struct NVGPUMBarrierInitLowering
825  : public MBarrierBasePattern<nvgpu::MBarrierInitOp> {
826  using MBarrierBasePattern<nvgpu::MBarrierInitOp>::MBarrierBasePattern;
827 
829  matchAndRewrite(nvgpu::MBarrierInitOp op, OpAdaptor adaptor,
830  ConversionPatternRewriter &rewriter) const override {
831  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
832  nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType();
833  rewriter.setInsertionPoint(op);
834  Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(),
835  adaptor.getMbarId(), rewriter);
836  Value count = truncToI32(b, adaptor.getCount());
837  if (isMbarrierShared(mbarrierType)) {
838  rewriter.replaceOpWithNewOp<NVVM::MBarrierInitSharedOp>(
839  op, barrier, count, adaptor.getPredicate());
840  } else {
841  rewriter.replaceOpWithNewOp<NVVM::MBarrierInitOp>(op, barrier, count,
842  adaptor.getPredicate());
843  }
844  return success();
845  }
846 };
847 
848 /// Lowers `nvgpu.mbarrier.arrive` to `nvvm.mbarrier.arrive`
849 struct NVGPUMBarrierArriveLowering
850  : public MBarrierBasePattern<nvgpu::MBarrierArriveOp> {
851  using MBarrierBasePattern<nvgpu::MBarrierArriveOp>::MBarrierBasePattern;
853  matchAndRewrite(nvgpu::MBarrierArriveOp op, OpAdaptor adaptor,
854  ConversionPatternRewriter &rewriter) const override {
855  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
856  Value barrier =
857  getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
858  adaptor.getMbarId(), rewriter);
859  Type tokenType = getTypeConverter()->convertType(
861  if (isMbarrierShared(op.getBarriers().getType())) {
862  rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveSharedOp>(op, tokenType,
863  barrier);
864  } else {
865  rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveOp>(op, tokenType,
866  barrier);
867  }
868  return success();
869  }
870 };
871 
872 /// Lowers `nvgpu.mbarrier.arrive.nocomplete` to
873 /// `nvvm.mbarrier.arrive.nocomplete`
874 struct NVGPUMBarrierArriveNoCompleteLowering
875  : public MBarrierBasePattern<nvgpu::MBarrierArriveNoCompleteOp> {
876  using MBarrierBasePattern<
877  nvgpu::MBarrierArriveNoCompleteOp>::MBarrierBasePattern;
879  matchAndRewrite(nvgpu::MBarrierArriveNoCompleteOp op, OpAdaptor adaptor,
880  ConversionPatternRewriter &rewriter) const override {
881  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
882  Value barrier =
883  getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
884  adaptor.getMbarId(), rewriter);
885  Type tokenType = getTypeConverter()->convertType(
887  Value count = truncToI32(b, adaptor.getCount());
888  if (isMbarrierShared(op.getBarriers().getType())) {
889  rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteSharedOp>(
890  op, tokenType, barrier, count);
891  } else {
892  rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteOp>(
893  op, tokenType, barrier, count);
894  }
895  return success();
896  }
897 };
898 
899 /// Lowers `nvgpu.mbarrier.test.wait` to `nvvm.mbarrier.test.wait`
900 struct NVGPUMBarrierTestWaitLowering
901  : public MBarrierBasePattern<nvgpu::MBarrierTestWaitOp> {
902  using MBarrierBasePattern<nvgpu::MBarrierTestWaitOp>::MBarrierBasePattern;
904  matchAndRewrite(nvgpu::MBarrierTestWaitOp op, OpAdaptor adaptor,
905  ConversionPatternRewriter &rewriter) const override {
906  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
907  Value barrier =
908  getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
909  adaptor.getMbarId(), rewriter);
910  Type retType = rewriter.getI1Type();
911  if (isMbarrierShared(op.getBarriers().getType())) {
912  rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitSharedOp>(
913  op, retType, barrier, adaptor.getToken());
914  } else {
915  rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitOp>(
916  op, retType, barrier, adaptor.getToken());
917  }
918  return success();
919  }
920 };
921 
922 struct NVGPUMBarrierArriveExpectTxLowering
923  : public MBarrierBasePattern<nvgpu::MBarrierArriveExpectTxOp> {
924  using MBarrierBasePattern<
925  nvgpu::MBarrierArriveExpectTxOp>::MBarrierBasePattern;
927  matchAndRewrite(nvgpu::MBarrierArriveExpectTxOp op, OpAdaptor adaptor,
928  ConversionPatternRewriter &rewriter) const override {
929  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
930  Value barrier =
931  getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
932  adaptor.getMbarId(), rewriter);
933  Value txcount = truncToI32(b, adaptor.getTxcount());
934 
935  if (isMbarrierShared(op.getBarriers().getType())) {
936  rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxSharedOp>(
937  op, barrier, txcount, adaptor.getPredicate());
938  return success();
939  }
940 
941  rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>(
942  op, barrier, txcount, adaptor.getPredicate());
943  return success();
944  }
945 };
946 
947 struct NVGPUMBarrierTryWaitParityLowering
948  : public MBarrierBasePattern<nvgpu::MBarrierTryWaitParityOp> {
949  using MBarrierBasePattern<
950  nvgpu::MBarrierTryWaitParityOp>::MBarrierBasePattern;
952  matchAndRewrite(nvgpu::MBarrierTryWaitParityOp op, OpAdaptor adaptor,
953  ConversionPatternRewriter &rewriter) const override {
954  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
955  Value barrier =
956  getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
957  adaptor.getMbarId(), rewriter);
958  Value ticks = truncToI32(b, adaptor.getTicks());
959  Value phase =
960  b.create<LLVM::ZExtOp>(b.getI32Type(), adaptor.getPhaseParity());
961 
962  if (isMbarrierShared(op.getBarriers().getType())) {
963  rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParitySharedOp>(
964  op, barrier, phase, ticks);
965  return success();
966  }
967 
968  rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParityOp>(op, barrier,
969  phase, ticks);
970  return success();
971  }
972 };
973 
974 struct NVGPUTmaAsyncLoadOpLowering
975  : public MBarrierBasePattern<nvgpu::TmaAsyncLoadOp> {
976  using MBarrierBasePattern<nvgpu::TmaAsyncLoadOp>::MBarrierBasePattern;
978  matchAndRewrite(nvgpu::TmaAsyncLoadOp op, OpAdaptor adaptor,
979  ConversionPatternRewriter &rewriter) const override {
980  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
981  auto srcMemrefType = cast<MemRefType>(op.getDst().getType());
982  Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType,
983  adaptor.getDst(), {}, rewriter);
984  Value barrier =
985  getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
986  adaptor.getMbarId(), rewriter);
987 
988  SmallVector<Value> coords = adaptor.getCoordinates();
989  for (auto [index, value] : llvm::enumerate(coords)) {
990  coords[index] = truncToI32(b, value);
991  }
992  rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(
993  op, dest, adaptor.getTensorMapDescriptor(), coords, barrier,
994  ValueRange{}, adaptor.getMulticastMask(), Value{},
995  adaptor.getPredicate());
996  return success();
997  }
998 };
999 
1000 struct NVGPUTmaAsyncStoreOpLowering
1001  : public MBarrierBasePattern<nvgpu::TmaAsyncStoreOp> {
1002  using MBarrierBasePattern<nvgpu::TmaAsyncStoreOp>::MBarrierBasePattern;
1004  matchAndRewrite(nvgpu::TmaAsyncStoreOp op, OpAdaptor adaptor,
1005  ConversionPatternRewriter &rewriter) const override {
1006  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1007  auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
1008  Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType,
1009  adaptor.getSrc(), {}, rewriter);
1010  SmallVector<Value> coords = adaptor.getCoordinates();
1011  for (auto [index, value] : llvm::enumerate(coords)) {
1012  coords[index] = truncToI32(b, value);
1013  }
1014 
1015  rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(
1016  op, adaptor.getTensorMapDescriptor(), dest, coords,
1017  adaptor.getPredicate());
1018  return success();
1019  }
1020 };
1021 
1022 struct NVGPUGenerateWarpgroupDescriptorLowering
1023  : public ConvertOpToLLVMPattern<nvgpu::WarpgroupGenerateDescriptorOp> {
1024  using ConvertOpToLLVMPattern<
1025  nvgpu::WarpgroupGenerateDescriptorOp>::ConvertOpToLLVMPattern;
1026 
1028  matchAndRewrite(nvgpu::WarpgroupGenerateDescriptorOp op, OpAdaptor adaptor,
1029  ConversionPatternRewriter &rewriter) const override {
1030 
1031  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1032 
1033  nvgpu::TensorMapSwizzleKind swizzleKind =
1034  op.getTensorMap().getType().getSwizzle();
1035 
1036  unsigned layout =
1037  (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 128
1038  : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 64
1039  : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 32
1040  : 1;
1041  unsigned swizzle =
1042  (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 1
1043  : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 2
1044  : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 3
1045  : 0;
1046 
1047  auto ti64 = b.getIntegerType(64);
1048  auto makeConst = [&](uint64_t index) -> Value {
1049  return b.create<LLVM::ConstantOp>(ti64, b.getI64IntegerAttr(index));
1050  };
1051  auto shiftLeft = [&](Value value, unsigned shift) -> Value {
1052  return b.create<LLVM::ShlOp>(ti64, value, makeConst(shift));
1053  };
1054  auto shiftRight = [&](Value value, unsigned shift) -> Value {
1055  return b.create<LLVM::LShrOp>(ti64, value, makeConst(shift));
1056  };
1057  auto insertBit = [&](Value desc, Value val, int startBit) {
1058  return b.create<LLVM::OrOp>(ti64, desc, shiftLeft(val, startBit));
1059  };
1060 
1061  int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0);
1062  uint64_t strideDimVal = (layout << 3) >> exclude4LSB;
1063  uint64_t leadDimVal = (sizeN * layout) >> exclude4LSB;
1064  uint64_t offsetVal = 0;
1065 
1066  Value strideDim = makeConst(strideDimVal);
1067  Value leadDim = makeConst(leadDimVal);
1068 
1069  Value baseAddr = getStridedElementPtr(
1070  op->getLoc(), cast<MemRefType>(op.getTensor().getType()),
1071  adaptor.getTensor(), {}, rewriter);
1072  Value basePtr = b.create<LLVM::PtrToIntOp>(ti64, baseAddr);
1073  // Just use 14 bits for base address
1074  Value basePtr14bit = shiftRight(shiftLeft(basePtr, 46), 50);
1075 
1076  int startSwizzleBit = 62, startOffsetBit = 49, startStrideBit = 32,
1077  startLeadBit = 16, startBaseAddrBit = 0;
1078  Value dsc = makeConst(0);
1079  // // [62,64) swizzle type
1080  dsc = insertBit(dsc, makeConst(swizzle), startSwizzleBit);
1081  // // [49,52) base_offset
1082  dsc = insertBit(dsc, makeConst(offsetVal), startOffsetBit);
1083  // // [32,46) stride
1084  dsc = insertBit(dsc, strideDim, startStrideBit);
1085  // // [16,30) leading dimension
1086  dsc = insertBit(dsc, leadDim, startLeadBit);
1087  // // [0,14) start_address
1088  dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit);
1089 
1090  LLVM_DEBUG(DBGS() << "Generating warpgroup.descriptor: "
1091  << "leading_off:" << leadDimVal << "\t"
1092  << "stride_off :" << strideDimVal << "\t"
1093  << "base_offset:" << offsetVal << "\t"
1094  << "layout_type:" << swizzle << " ("
1095  << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
1096  << ")\n start_addr : " << baseAddr << "\n");
1097 
1098  rewriter.replaceOp(op, dsc);
1099  return success();
1100  }
1101 };
1102 
/// Builds an `llvm.mlir.constant` of i64 type holding `index`.
/// NOTE(review): the result type is i64 but the attribute is built with
/// getI32IntegerAttr — presumably the ConstantOp lowering widens the value;
/// confirm this is intended rather than getI64IntegerAttr.
static Value makeI64Const(ImplicitLocOpBuilder &b, int32_t index) {
  return b.create<LLVM::ConstantOp>(b.getIntegerType(64),
                                    b.getI32IntegerAttr(index));
}
1107 
1108 /// Returns a Value that holds data type enum that is expected by CUDA driver.
1109 static Value elementTypeAsLLVMConstant(ImplicitLocOpBuilder &b, Type type) {
1110  // Enum is from CUDA driver API
1111  // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html
1112  enum CUtensorMapDataTypeEnum {
1113  CU_TENSOR_MAP_DATA_TYPE_UINT8 = 0,
1114  CU_TENSOR_MAP_DATA_TYPE_UINT16,
1115  CU_TENSOR_MAP_DATA_TYPE_UINT32,
1116  CU_TENSOR_MAP_DATA_TYPE_INT32,
1117  CU_TENSOR_MAP_DATA_TYPE_UINT64,
1118  CU_TENSOR_MAP_DATA_TYPE_INT64,
1119  CU_TENSOR_MAP_DATA_TYPE_FLOAT16,
1120  CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
1121  CU_TENSOR_MAP_DATA_TYPE_FLOAT64,
1122  CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
1123  CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ,
1124  CU_TENSOR_MAP_DATA_TYPE_TFLOAT32,
1125  CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ
1126  };
1127 
1128  if (type.isUnsignedInteger(8))
1129  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT8);
1130  if (type.isUnsignedInteger(16))
1131  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT16);
1132  if (type.isUnsignedInteger(32))
1133  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT32);
1134  if (type.isUnsignedInteger(64))
1135  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT64);
1136  if (type.isSignlessInteger(32))
1137  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT32);
1138  if (type.isSignlessInteger(64))
1139  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT64);
1140  if (type.isF16())
1141  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT16);
1142  if (type.isF32())
1143  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT32);
1144  if (type.isF64())
1145  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT64);
1146  if (type.isBF16())
1147  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16);
1148 
1149  llvm_unreachable("Not supported data type");
1150 }
1151 
1152 struct NVGPUTmaCreateDescriptorOpLowering
1153  : public ConvertOpToLLVMPattern<nvgpu::TmaCreateDescriptorOp> {
1154  using ConvertOpToLLVMPattern<
1155  nvgpu::TmaCreateDescriptorOp>::ConvertOpToLLVMPattern;
1157  matchAndRewrite(nvgpu::TmaCreateDescriptorOp op, OpAdaptor adaptor,
1158  ConversionPatternRewriter &rewriter) const override {
1159  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1160  auto llvmPointerType = LLVM::LLVMPointerType::get(op->getContext());
1161  Type llvmInt64Type = IntegerType::get(op->getContext(), 64);
1162 
1163  Value tensorElementType =
1164  elementTypeAsLLVMConstant(b, op.getTensor().getType().getElementType());
1165  auto promotedOperands = getTypeConverter()->promoteOperands(
1166  b.getLoc(), op->getOperands(), adaptor.getOperands(), b);
1167 
1168  Value boxArrayPtr = b.create<LLVM::AllocaOp>(llvmPointerType, llvmInt64Type,
1169  makeI64Const(b, 5));
1170  for (auto [index, value] : llvm::enumerate(adaptor.getBoxDimensions())) {
1171  Value gep = b.create<LLVM::GEPOp>(llvmPointerType, llvmPointerType,
1172  boxArrayPtr, makeI64Const(b, index));
1173  b.create<LLVM::StoreOp>(value, gep);
1174  }
1175 
1176  nvgpu::TensorMapDescriptorType desc = op.getTensorMap().getType();
1177  // Set Arguments for the function call
1178  SmallVector<Value> arguments;
1179  arguments.push_back(promotedOperands[0]); // rank
1180  arguments.push_back(promotedOperands[1]); // descriptor
1181  arguments.push_back(tensorElementType); // data type
1182  arguments.push_back(
1183  makeI64Const(b, (int)desc.getInterleave())); // interleave
1184  arguments.push_back(makeI64Const(b, (int)desc.getSwizzle())); // swizzle
1185  arguments.push_back(makeI64Const(b, (int)desc.getL2promo())); // l2promo
1186  arguments.push_back(makeI64Const(b, (int)desc.getOob())); // oob
1187  arguments.push_back(boxArrayPtr); // box dimensions
1188 
1189  // Set data types of the arguments
1190  SmallVector<Type> argTypes = {
1191  llvmInt64Type, /* int64_t tensorRank */
1192  llvmPointerType, /* ptr */
1193  llvmInt64Type, /* int64_t */
1194  llvmInt64Type, /* int64_t */
1195  llvmInt64Type, /* int64_t */
1196  llvmInt64Type, /* int64_t */
1197  llvmInt64Type, /* int64_t */
1198  llvmPointerType /* ptr */
1199  };
1200  FunctionCallBuilder hostRegisterCallBuilder = {
1201  "mgpuTensorMapEncodeTiledMemref", llvmPointerType, argTypes};
1202  Value tensorMap =
1203  hostRegisterCallBuilder.create(b.getLoc(), b, arguments).getResult();
1204 
1205  rewriter.replaceOp(op, tensorMap);
1206  return success();
1207  }
1208 };
1209 
1210 struct NVGPUWarpgroupMmaOpLowering
1211  : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp> {
1213 
1214  /// This is a helper class to generate required NVVM Ops for warp-group level
1215  /// matrix multiplication.
1216  /// When the given GEMM shape is larger than the shape of
1217  /// a wgmma instrution in PTX, it can generate multiple NVVM::WgmmaMmaAsyncOp
1218  /// Op(s), group and execute them asynchronously. The class also handles
1219  /// waiting for completion and iterates through WarpgroupMatrixDescriptor to
1220  /// create descriptors for each instruction.
1221  ///
1222  /// For example this is the case when the shape of GEMM is 128x128x128
1223  ///
1224  /// nvvm.wgmma.fence.aligned
1225  ///
1226  /// nvvm.wgmma.mma.async descA, descB
1227  /// iterate(descA, descB)
1228  /// nvvm.wgmma.mma.async descA, descB
1229  /// [6x times more]
1230  ///
1231  /// nvvm.wgmma.group.sync.aligned
1232  /// nvvm.wgmma.wait.group.sync [groupId]
1233  ///
1234  class WarpgroupGemm {
1235  nvgpu::WarpgroupMmaOp op;
1237  OpAdaptor adaptor;
1238 
1239  // Entire shape of the given Op
1240  int64_t totalM, totalN, totalK;
1241 
1242  // Shape of one wgmma instruction
1243  int wgmmaM = 0, wgmmaN = 0, wgmmaK = 0;
1244 
1245  // Iteration counts for GEMM
1246  int iterationM = 0, iterationN = 0, iterationK = 0;
1247 
1248  /// The function returns the shape of wgmma instruction that is defined in
1249  /// PTX programming guide.
1250  /// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shape
1251  void findWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType) {
1252  wgmmaM = 64;
1253  wgmmaN = sizeN;
1254  if (inputElemType.isTF32()) {
1255  wgmmaK = 8;
1256  } else if (inputElemType.isF16() || inputElemType.isBF16()) {
1257  wgmmaK = 16;
1258  } else if (inputElemType.isFloat8E4M3FN() ||
1259  inputElemType.isFloat8E5M2() || inputElemType.isInteger(16)) {
1260  wgmmaK = 32;
1261  } else if (inputElemType.isInteger(1)) {
1262  wgmmaK = 256;
1263  } else {
1264  llvm_unreachable("msg: not supported K shape");
1265  }
1266  LLVM_DEBUG(DBGS() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
1267  << ", n = " << wgmmaN << ", k = " << wgmmaK << "]\n");
1268  }
1269 
1270  /// Generates WGMMATypesAttr from MLIR Type
1271  NVVM::WGMMATypesAttr generateWgmmaType(Type type,
1272  bool useF32 = false) const {
1273  auto getWgmmaType = [=](Type elemType) {
1274  if (elemType.isF32() || elemType.isTF32())
1275  return useF32 ? NVVM::WGMMATypes::f32 : NVVM::WGMMATypes::tf32;
1276  if (elemType.isF16())
1277  return NVVM::WGMMATypes::f16;
1278  if (elemType.isBF16())
1279  return NVVM::WGMMATypes::bf16;
1280  if (elemType.isFloat8E4M3FN())
1281  return NVVM::WGMMATypes::e4m3;
1282  if (elemType.isFloat8E5M2())
1283  return NVVM::WGMMATypes::e5m2;
1284  if (elemType.isInteger(1))
1285  return NVVM::WGMMATypes::b1;
1286  if (elemType.isInteger(8))
1287  return NVVM::WGMMATypes::s8;
1288  if (elemType.isUnsignedInteger(8))
1289  return NVVM::WGMMATypes::u8;
1290  if (elemType.isInteger(32))
1291  return NVVM::WGMMATypes::s32;
1292  llvm_unreachable("unsupported type");
1293  };
1294  return NVVM::WGMMATypesAttr::get(op->getContext(), getWgmmaType(type));
1295  }
1296 
1297  /// Generates layout attribute for the input matrix for wgmma instruction
1298  NVVM::MMALayoutAttr
1299  generateWgmmaLayout(std::optional<bool> transpose) const {
1300  if (transpose.value_or(false))
1301  return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::col);
1302  return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::row);
1303  }
1304 
1305  /// Generates shape attribute for wgmma instruction
1306  NVVM::MMAShapeAttr generateWgmmaShape() const {
1307  return NVVM::MMAShapeAttr::get(op->getContext(), wgmmaM, wgmmaN, wgmmaK);
1308  }
1309 
1310  /// Generates scale attributes of output matrix for wgmma instruction
1311  NVVM::WGMMAScaleOutAttr generateScaleOut() const {
1313  NVVM::WGMMAScaleOut::one);
1314  }
1315  /// Generates scale attributes of input matrix for wgmma instruction
1316  NVVM::WGMMAScaleInAttr generateScaleIn() const {
1318  NVVM::WGMMAScaleIn::one);
1319  }
1320 
1321  /// Basic function to generate Add
1322  Value makeAdd(Value lhs, Value rhs) {
1323  return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);
1324  };
1325 
1326  /// Moves the descriptor pointer of matrix-A for the next wgmma instruction.
1327  /// Currently, it only handles row-major.
1328  ///
1329  /// It moves the pointer like below for [128][64] size:
1330  /// +2 +4 +6
1331  /// ↓ ↓ ↓
1332  /// descA ---> +--+--+--+--+
1333  /// |->|->|->|->|
1334  /// | | | | |
1335  /// | | | | |
1336  /// | | | | |
1337  /// descA+512---> +-----------+
1338  /// | | | | |
1339  /// | | | | |
1340  /// | | | | |
1341  /// | | | | |
1342  /// +-----------+
1343  ///
1344  Value iterateDescriptorA(Value desc, int i, int j, int k) {
1345  MemRefType matrixTypeA = op.getDescriptorA().getType().getTensor();
1346  Type elemA = matrixTypeA.getElementType();
1347  int byte = elemA.getIntOrFloatBitWidth() / 8;
1348  int tileShapeA = matrixTypeA.getDimSize(1);
1349  int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte;
1350  incrementVal = incrementVal >> exclude4LSB;
1351  LLVM_DEBUG(DBGS() << "\t\t[m: " << i << " n: " << j << " k: " << k
1352  << "] [wgmma descriptors] Descriptor A + "
1353  << incrementVal << " | \t ");
1354  if (!incrementVal)
1355  return desc;
1356  return makeAdd(desc, makeI64Const(b, incrementVal));
1357  }
1358 
1359  /// Moves the descriptor pointer of matrix-B for the next wgmma instruction.
1360  /// Currently, it only handles column-major.
1361  ///
1362  /// It moves the pointer like below for [128][64] size:
1363  /// descB ---> +--+--+--+--+--+--+--+--+
1364  /// |↓ | | | | | | | |
1365  /// |↓ | | | | | | | |
1366  /// |↓ | | | | | | | |
1367  /// |↓ | | | | | | | |
1368  /// +--+--+--+--+--+--+--+--+
1369  ///
1370  Value iterateDescriptorB(Value desc, int i, int j, int k) {
1371  MemRefType matrixTypeB = op.getDescriptorB().getType().getTensor();
1372  Type elemB = matrixTypeB.getElementType();
1373  int byte = elemB.getIntOrFloatBitWidth() / 8;
1374  int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte;
1375  incrementVal = incrementVal >> exclude4LSB;
1376  LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n");
1377  if (!incrementVal)
1378  return desc;
1379  return makeAdd(desc, makeI64Const(b, incrementVal));
1380  }
1381 
1382  /// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix
1383  /// descriptors and arranges them based on induction variables: i, j, and k.
1384  Value generateWgmma(int i, int j, int k, Value matrixC) {
1385  LLVM_DEBUG(DBGS() << "\t wgmma."
1386  << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
1387  << "(A[" << (iterationM * wgmmaM) << ":"
1388  << (iterationM * wgmmaM) + wgmmaM << "]["
1389  << (iterationK * wgmmaK) << ":"
1390  << (iterationK * wgmmaK + wgmmaK) << "] * "
1391  << " B[" << (iterationK * wgmmaK) << ":"
1392  << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":"
1393  << wgmmaN << "])\n");
1394 
1395  Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k);
1396  Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k);
1397 
1398  Type elemA = op.getDescriptorA().getType().getTensor().getElementType();
1399  NVVM::WGMMATypesAttr itypeA = generateWgmmaType(elemA);
1400 
1401  Type elemB = op.getDescriptorB().getType().getTensor().getElementType();
1402  NVVM::WGMMATypesAttr itypeB = generateWgmmaType(elemB);
1403 
1404  Type elemD = op.getMatrixC().getType().getFragmented().getElementType();
1405  NVVM::WGMMATypesAttr itypeD = generateWgmmaType(elemD, true);
1406 
1407  NVVM::MMAShapeAttr shape = generateWgmmaShape();
1408  NVVM::WGMMAScaleOutAttr scaleOut = generateScaleOut();
1409  NVVM::WGMMAScaleInAttr scaleIn = generateScaleIn();
1410  NVVM::MMALayoutAttr layoutA = generateWgmmaLayout(op.getTransposeA());
1411  NVVM::MMALayoutAttr layoutB = generateWgmmaLayout(!op.getTransposeB());
1412 
1413  auto overflow = NVVM::MMAIntOverflowAttr::get(
1414  op->getContext(), NVVM::MMAIntOverflow::wrapped);
1415 
1416  return b.create<NVVM::WgmmaMmaAsyncOp>(
1417  matrixC.getType(), matrixC, descriptorA, descriptorB, shape, itypeA,
1418  itypeB, itypeD, scaleOut, scaleIn, scaleIn, layoutA, layoutB,
1419  overflow);
1420  }
1421 
1422  /// Generates multiple wgmma instructions to complete the given GEMM shape
1423  Value generateWgmmaGroup() {
1424  Value wgmmaResult =
1425  b.create<LLVM::UndefOp>(adaptor.getMatrixC().getType());
1426 
1427  // Perform GEMM
1428  SmallVector<Value> wgmmaResults;
1429  for (int i = 0; i < iterationM; ++i) {
1430  Value matrixC = b.create<LLVM::ExtractValueOp>(adaptor.getMatrixC(), i);
1431  for (int j = 0; j < iterationN; ++j)
1432  for (int k = 0; k < iterationK; ++k)
1433  matrixC = generateWgmma(i, j, k, matrixC);
1434  wgmmaResults.push_back(matrixC);
1435  }
1436  for (auto [idx, matrix] : llvm::enumerate(wgmmaResults)) {
1437  wgmmaResult = b.create<LLVM::InsertValueOp>(wgmmaResult.getType(),
1438  wgmmaResult, matrix, idx);
1439  }
1440  return wgmmaResult;
1441  }
1442 
1443  public:
1444  WarpgroupGemm(nvgpu::WarpgroupMmaOp op, ImplicitLocOpBuilder &b,
1445  OpAdaptor adaptor)
1446  : op(op), b(b), adaptor(adaptor) {
1447  // Find the entire GEMM Shape
1448  totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
1449  totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
1450  totalK = op.getDescriptorA().getType().getTensor().getDimSize(1);
1451  LLVM_DEBUG(DBGS() << "===--- GEMM D[" << totalM << "][" << totalN
1452  << "] += A[" << totalM << "][" << totalK << "] * B["
1453  << totalK << "][" << totalN << "] ---===\n");
1454 
1455  // Find the shape for one wgmma instruction
1456  findWgmmaShape(
1457  totalM, totalN,
1458  op.getDescriptorA().getType().getTensor().getElementType());
1459 
1460  // Iterations counts to complete the given shape with wgmma shape
1461  iterationM = totalM / wgmmaM;
1462  iterationN = totalN / wgmmaN;
1463  iterationK = totalK / wgmmaK;
1464  }
1465 
1466  /// Generates WgmmaMmaAsync Ops to complete the specified GEMM shape. It
1467  /// includes generating a fence Op (WgmmaFenceAlignedOp) before the
1468  /// instructions and group synchronization, as well as waiting
1469  /// (WgmmaGroupSyncAlignedOp) for group synchronization
1470  /// (WgmmaWaitGroupSyncOp) after the instructions.
1471  Value generateWarpgroupMma() {
1472  b.create<NVVM::WgmmaFenceAlignedOp>();
1473  Value wgmmaResult = generateWgmmaGroup();
1474  b.create<NVVM::WgmmaGroupSyncAlignedOp>();
1475  b.create<NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup());
1476  return wgmmaResult;
1477  }
1478  };
1480  matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
1481  ConversionPatternRewriter &rewriter) const override {
1482  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1483 
1484  // Step 1. Build a helper class
1485  WarpgroupGemm warpgroupGemm(op, b, adaptor);
1486 
1487  // Step 2. Get the entire GEMM Shape
1488  Value wgmmaResult = warpgroupGemm.generateWarpgroupMma();
1489 
1490  // Step 3. Replace fragmented result struct with the op results
1491  rewriter.replaceOp(op, wgmmaResult);
1492  return success();
1493  }
1494 };
1495 
1496 struct NVGPUWarpgroupMmaStoreOpLowering
1497  : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaStoreOp> {
1498  using ConvertOpToLLVMPattern<
1499  nvgpu::WarpgroupMmaStoreOp>::ConvertOpToLLVMPattern;
1500 
1501  /// This function stores a fragmented register matrix owned by a warp group
1502  /// (128 threads) into a memref. Each thread has 64 registers, each the size
1503  /// of a struct.
1504  /// Here is what each threads (T) holds, each `d` is struct value with a
1505  /// number.
1506  ///
1507  /// Threads in warp-group (128 threads) and what they owns in the matrixD:
1508  /// 0-31 Warp-0 -> MatrixD[0:15 ][0:N]
1509  /// 32-63 Warp-1 -> MatrixD[16:31][0:N]
1510  /// 64-95 Warp-2 -> MatrixD[32:47][0:N]
1511  /// 96-127 Warp-3 -> MatrixD[48:64][0:N]
1512  ///
1513  /// Matrix-D:
1514  /// +______________________________________________________________________+
1515  /// | 0-1 | 2-3 | 4-5 | 6-7 | 8-9 | 10-11|..|N-8,N-7 |
1516  /// 0 | T0:d0-d1 |T1:d0-d1 |T2:d0-d1 |T3:d0-d1 |T0:d4-d5| T1:d4-d5..|T0:dX-dY|
1517  /// 1 | T4:d0-d1 |T5:d0-d1 |T6:d0-d1 |T7:d0-d1 |T4:d4-d5| T5:d4-d5..|T4:dX-dY|
1518  /// ..| .........|.........|.........|.........|........|...........|........|
1519  /// 8 | T0:d2-d3 |T1:d2-d3 |T2:d2-d3 |T3:d2-d3 |T0:d6-d7|T1:d6-d7,..|T0:dZ-dW|
1520  /// 9 | T4:d2-d3 |T5:d2-d3 |T6:d2-d3 |T7:d2-d3 |T4:d6-d7| T5:d6-d7..|T4:dZ-dW|
1521  /// ..| .........|.........|.........|.........|........|...........|........|
1522  /// 15| T28:d2-d3|T29:d2-d3|T30:d2-d3|T31:d2-d3|........|...........|........|
1523  /// 16| T32:d2-d3|T33:d2-d3|T34:d2-d3|T35:d2-d3|........|...........|........|
1524  /// ..| .........|.........|.........|.........|........|...........|........|
1525  /// 32| T64:d2-d3|T65:d2-d3|T66:d2-d3|T67:d2-d3|........|...........|........|
1526  /// ..| .........|.........|.........|.........|........|...........|........|
1527  /// 48| T96:d2-d3|T97:d2-d3|T98:d2-d3|T99:d2-d3|........|...........|........|
1528  /// ..| .........|.........|.........|.........|........|...........|........|
1529  /// +______________________________________________________________________+
1530  ///
1531  /// \param rewriter: The pattern rewriter.
1532  /// \param matrixD: Result of the warp-group MMA operation (fragmented
1533  /// matrix). It is holded by a thread and a struct with 64 elements.
1534  /// \param dstMemref: The memref where the registers will be stored.
1535  /// \param offset: the offset within the memref where the registers will be
1536  /// stored.
1537  void storeFragmentedMatrix(ImplicitLocOpBuilder &b, Value matrixD,
1538  TypedValue<MemRefType> dstMemref,
1539  int offset) const {
1540  Type i32 = b.getI32Type();
1541 
1542  auto makeConst = [&](int32_t index) -> Value {
1543  return b.create<LLVM::ConstantOp>(i32, b.getI32IntegerAttr(index));
1544  };
1545  Value c1 = makeConst(1);
1546  Value c2 = makeConst(2);
1547  Value c4 = makeConst(4);
1548  Value c8 = makeConst(8);
1549  Value c16 = makeConst(16);
1550  Value warpSize = makeConst(kWarpSize);
1551 
1552  auto makeMul = [&](Value lhs, Value rhs) -> Value {
1553  return b.create<LLVM::MulOp>(lhs.getType(), lhs, rhs);
1554  };
1555  auto makeAdd = [&](Value lhs, Value rhs) -> Value {
1556  return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);
1557  };
1558 
1559  auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,
1561  Type it = b.getIndexType();
1562  Value idx = b.create<arith::IndexCastOp>(it, x);
1563  Value idy0 = b.create<arith::IndexCastOp>(it, y);
1564  Value idy1 = b.create<arith::IndexCastOp>(it, makeAdd(y, c1));
1565  Value d0 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i);
1566  Value d1 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i + 1);
1567  b.create<memref::StoreOp>(d0, memref, ValueRange{idx, idy0});
1568  b.create<memref::StoreOp>(d1, memref, ValueRange{idx, idy1});
1569  };
1570 
1571  Value tidx = b.create<NVVM::ThreadIdXOp>(i32);
1572  Value laneId = b.create<LLVM::URemOp>(i32, tidx, warpSize);
1573  Value warpId = b.create<LLVM::UDivOp>(i32, tidx, warpSize);
1574  Value lane4Id = b.create<LLVM::UDivOp>(i32, laneId, c4);
1575  Value lane4modId = b.create<LLVM::URemOp>(i32, laneId, c4);
1576 
1577  Value tj = makeMul(lane4modId, c2);
1578  Value ti = makeAdd(lane4Id, makeMul(warpId, c16));
1579  if (offset)
1580  ti = makeAdd(ti, makeConst(offset));
1581 
1582  auto structType = matrixD.getType().cast<LLVM::LLVMStructType>();
1583 
1584  // Number of 32-bit registers owns per thread
1585  constexpr unsigned numAdjacentRegisters = 2;
1586  // Number of 8x8 matrices one below another per warp
1587  constexpr unsigned numStackedMatrices = 2;
1588 
1589  size_t storeCount = (structType.getBody().size() /
1590  (numStackedMatrices * numAdjacentRegisters));
1591 
1592  for (size_t i = 0; i < numStackedMatrices; ++i) {
1593  Value idx = makeAdd(ti, makeMul(makeConst(i), c8));
1594  for (size_t j = 0; j < storeCount; ++j) {
1595  Value idy = makeAdd(tj, makeMul(makeConst(j), c8));
1596  size_t structIndex = (i * numAdjacentRegisters) +
1597  (j * (numStackedMatrices * numAdjacentRegisters));
1598  makeExtractAndStore(structIndex, matrixD, idx, idy, dstMemref);
1599  }
1600  }
1601  }
1602 
1604  matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor,
1605  ConversionPatternRewriter &rewriter) const override {
1606  int offset = 0;
1607  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1608  Value matriDValue = adaptor.getMatrixD();
1609  auto stype = matriDValue.getType().cast<LLVM::LLVMStructType>();
1610  for (auto [idx, matrixD] : llvm::enumerate(stype.getBody())) {
1611  auto structType = matrixD.cast<LLVM::LLVMStructType>();
1612  Value innerStructValue = b.create<LLVM::ExtractValueOp>(matriDValue, idx);
1613  storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset);
1614  offset += structType.getBody().size();
1615  }
1616  rewriter.eraseOp(op);
1617  return success();
1618  }
1619 };
1620 
1621 struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
1622  : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaInitAccumulatorOp> {
1623  using ConvertOpToLLVMPattern<
1624  nvgpu::WarpgroupMmaInitAccumulatorOp>::ConvertOpToLLVMPattern;
1626  matchAndRewrite(nvgpu::WarpgroupMmaInitAccumulatorOp op, OpAdaptor adaptor,
1627  ConversionPatternRewriter &rewriter) const override {
1628  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1629  LLVM::LLVMStructType packStructType =
1630  getTypeConverter()
1631  ->convertType(op.getMatrixC().getType())
1632  .cast<LLVM::LLVMStructType>();
1633  Type elemType = packStructType.getBody()
1634  .front()
1635  .cast<LLVM::LLVMStructType>()
1636  .getBody()
1637  .front();
1638  Value zero = b.create<LLVM::ConstantOp>(elemType, b.getZeroAttr(elemType));
1639  Value packStruct = b.create<LLVM::UndefOp>(packStructType);
1640  SmallVector<Value> innerStructs;
1641  // Unpack the structs and set all values to zero
1642  for (auto [idx, s] : llvm::enumerate(packStructType.getBody())) {
1643  auto structType = s.cast<LLVM::LLVMStructType>();
1644  Value structValue = b.create<LLVM::ExtractValueOp>(packStruct, idx);
1645  for (unsigned i = 0; i < structType.getBody().size(); ++i) {
1646  structValue = b.create<LLVM::InsertValueOp>(
1647  structType, structValue, zero, ArrayRef<int64_t>({i}));
1648  }
1649  innerStructs.push_back(structValue);
1650  }
1651  // Pack the inner structs into a single struct
1652  for (auto [idx, matrix] : llvm::enumerate(innerStructs)) {
1653  packStruct = b.create<LLVM::InsertValueOp>(packStruct.getType(),
1654  packStruct, matrix, idx);
1655  }
1656  rewriter.replaceOp(op, packStruct);
1657  return success();
1658  }
1659 };
1660 
1661 struct NVGPUTmaPrefetchOpLowering
1662  : public ConvertOpToLLVMPattern<nvgpu::TmaPrefetchOp> {
1665  matchAndRewrite(nvgpu::TmaPrefetchOp op, OpAdaptor adaptor,
1666  ConversionPatternRewriter &rewriter) const override {
1667  rewriter.replaceOpWithNewOp<NVVM::PrefetchTensorMapOp>(
1668  op, adaptor.getTensorMapDescriptor(), adaptor.getPredicate());
1669  return success();
1670  }
1671 };
1672 
1673 } // namespace
1674 
1676  RewritePatternSet &patterns) {
1677  patterns.add<
1678  NVGPUMBarrierCreateLowering, // nvgpu.mbarrier.create
1679  NVGPUMBarrierInitLowering, // nvgpu.mbarrier.init
1680  NVGPUMBarrierArriveLowering, // nvgpu.mbarrier.arrive
1681  NVGPUMBarrierArriveNoCompleteLowering, // nvgpu.mbarrier.arrive.no_complete
1682  NVGPUMBarrierTestWaitLowering, // nvgpu.mbarrier.test_wait_parity
1683  NVGPUMBarrierTryWaitParityLowering, // nvgpu.mbarrier.try_wait_parity
1684  NVGPUTmaAsyncLoadOpLowering, // nvgpu.tma.async.load
1685  NVGPUTmaAsyncStoreOpLowering, // nvgpu.tma.async.store
1686  NVGPUTmaCreateDescriptorOpLowering, // nvgpu.tma.create.descriptor
1687  NVGPUTmaPrefetchOpLowering, // nvgpu.tma.prefetch.descriptor
1688  NVGPUMBarrierArriveExpectTxLowering, // nvgpu.mbarrier.arrive.expect_tx
1689  NVGPUGenerateWarpgroupDescriptorLowering, // nvgpu.warpgroup.generate.descriptor
1690  NVGPUWarpgroupMmaOpLowering, // nvgpu.warpgroup.mma
1691  NVGPUWarpgroupMmaStoreOpLowering, // nvgpu.warpgroup.mma.store
1692  NVGPUWarpgroupMmaInitAccumulatorOpLowering, // nvgpu.warpgroup.mma.init.accumulator
1693  MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
1694  NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
1695  NVGPUMmaSparseSyncLowering>(converter);
1696 }
static constexpr int64_t kSharedMemorySpace
static MLIRContext * getContext(OpFoldResult val)
constexpr int kWgmmaSizeM
M size of wgmma.mma_async instruction.
Definition: NVGPUDialect.h:27
constexpr int kWarpSize
Definition: NVGPUDialect.h:24
static SmallVector< Value > unpackOperandVector(ImplicitLocOpBuilder &b, Value operand, NVVM::MMATypes operandPtxType)
The gpu.mma.sync converter below expects matrix fragment operands to be given as 2D vectors where the...
static Value truncToI32(ImplicitLocOpBuilder &b, Value value)
GPU has 32 bit registers, this function truncates values when larger width is not needed.
Definition: NVGPUToNVVM.cpp:50
#define DBGSE()
Definition: NVGPUToNVVM.cpp:35
static Type inferIntrinsicResultType(Type vectorResultType)
Returns the type for the intrinsic given the vectorResultType of the gpu.mma.sync operation.
Definition: NVGPUToNVVM.cpp:60
constexpr int exclude4LSB
Number of bits that needs to be excluded when building matrix descriptor for wgmma operations.
Definition: NVGPUToNVVM.cpp:46
static bool isMbarrierShared(nvgpu::MBarrierGroupType barrierType)
Returns whether mbarrier object has shared memory address space.
#define DBGS()
Definition: NVGPUToNVVM.cpp:34
static Value convertIntrinsicResult(Location loc, Type intrinsicResultType, Type resultType, Value intrinsicResult, RewriterBase &rewriter)
Convert the SSA result of the NVVM intrinsic nvvm.mma.sync (which is always an LLVM struct) into a fr...
Definition: NVGPUToNVVM.cpp:99
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
Definition: Attributes.h:25
IntegerType getI16Type()
Definition: Builders.cpp:81
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:216
FloatType getF32Type()
Definition: Builders.cpp:63
IntegerType getI64Type()
Definition: Builders.cpp:85
IntegerType getI32Type()
Definition: Builders.cpp:83
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:128
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:87
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:269
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:331
FloatType getF16Type()
Definition: Builders.cpp:59
MLIRContext * getContext() const
Definition: Builders.h:55
IntegerType getI1Type()
Definition: Builders.cpp:73
IndexType getIndexType()
Definition: Builders.cpp:71
IntegerType getI8Type()
Definition: Builders.cpp:79
FloatType getF64Type()
Definition: Builders.cpp:65
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:139
Value getStridedElementPtr(Location loc, MemRefType type, Value memRefDesc, ValueRange indices, ConversionPatternRewriter &rewriter) const
Definition: Pattern.cpp:61
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:766
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Location getLoc() const
Accessors for the implied location.
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:34
LLVM dialect structure type representing a collection of different-typed elements manipulated togethe...
Definition: LLVMTypes.h:109
ArrayRef< Type > getBody() const
Returns the list of element types contained in a non-opaque struct.
Definition: LLVMTypes.cpp:490
static LLVMStructType getLiteral(MLIRContext *context, ArrayRef< Type > types, bool isPacked=false)
Gets or creates a literal struct with the given body in the provided context.
Definition: LLVMTypes.cpp:453
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:522
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Definition: Operation.h:555
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition: SymbolTable.h:24
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isF64() const
Definition: Types.cpp:52
bool isTF32() const
Definition: Types.cpp:50
U cast() const
Definition: Types.h:340
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
bool isFloat8E4M3FN() const
Definition: Types.cpp:38
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
Definition: Types.cpp:67
bool isF32() const
Definition: Types.cpp:51
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition: Types.cpp:91
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:58
bool isFloat8E5M2() const
Definition: Types.cpp:37
bool isF16() const
Definition: Types.cpp:49
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:125
bool isBF16() const
Definition: Types.cpp:48
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
U cast() const
Definition: Value.h:116
Type getFixedVectorType(Type elementType, unsigned numElements)
Creates an LLVM dialect-compatible type with the given element type and length.
Definition: LLVMTypes.cpp:959
@ kGlobalMemorySpace
Global memory space identifier.
Definition: NVVMDialect.h:36
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
MemRefType getMBarrierMemrefType(MLIRContext *context, MBarrierGroupType barrierType)
Return the memref type that can be used to represent an mbarrier object.
Attribute getMbarrierMemorySpace(MLIRContext *context, MBarrierGroupType barrierType)
Returns the memory space attribute of the mbarrier object.
void populateSCFStructuralTypeConversionsAndLegality(TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target)
Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with...
static void transpose(llvm::ArrayRef< int64_t > trans, SmallVector< int64_t > &shape)
Definition: XeGPUOps.cpp:21
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:494
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
void populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void populateGpuMemorySpaceAttributeConversions(TypeConverter &typeConverter, const MemorySpaceMapping &mapping)
Populates memory space attribute conversion rules for lowering gpu.address_space to integer values.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
LLVM::CallOp create(Location loc, OpBuilder &builder, ArrayRef< Value > arguments) const
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.