1 //===- NVGPUToNVVM.cpp - NVGPU to NVVM dialect conversion -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
10 
11 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
12 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
13 #include "mlir/Conversion/LLVMCommon/Pattern.h"
14 #include "mlir/Dialect/Arith/IR/Arith.h"
15 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
16 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
17 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
18 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
19 #include "mlir/Dialect/MemRef/IR/MemRef.h"
20 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
21 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
23 #include "mlir/IR/BuiltinTypes.h"
24 #include "mlir/IR/ImplicitLocOpBuilder.h"
25 #include "mlir/IR/PatternMatch.h"
26 #include "mlir/IR/TypeUtilities.h"
27 #include "mlir/IR/Value.h"
28 #include "mlir/Pass/Pass.h"
29 #include "llvm/Support/Debug.h"
30 #include "llvm/Support/ErrorHandling.h"
31 #include "llvm/Support/raw_ostream.h"
32 #include <optional>
33 
34 #define DEBUG_TYPE "nvgpu-to-nvvm"
35 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
36 #define DBGSE() (llvm::dbgs())
37 
38 namespace mlir {
39 #define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS
40 #include "mlir/Conversion/Passes.h.inc"
41 } // namespace mlir
42 
43 using namespace mlir;
44 
45 /// Number of bits that need to be excluded when building the matrix descriptor
46 /// for wgmma operations.
47 constexpr int exclude4LSB = 4;
48 
49 /// GPUs have 32-bit registers; this function truncates values when a larger
50 /// width is not needed.
51 static Value truncToI32(ImplicitLocOpBuilder &b, Value value) {
52   Type type = value.getType();
53  assert(llvm::isa<IntegerType>(type) && "expected an integer Value");
54  if (type.getIntOrFloatBitWidth() <= 32)
55  return value;
56  return b.create<LLVM::TruncOp>(b.getI32Type(), value);
57 }
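// Illustrative usage sketch (hypothetical values, not taken from the patterns
// below): truncating a wide counter before handing it to an NVVM op that
// expects a 32-bit register.
//
//   Value count64 = ...;                    // some i64 SSA value
//   Value count32 = truncToI32(b, count64); // emits llvm.trunc : i64 to i32
//   // An i32 (or narrower) input is returned unchanged.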
58 
59 /// Returns the type for the intrinsic given the vectorResultType of the
60 /// `gpu.mma.sync` operation.
61 static Type inferIntrinsicResultType(Type vectorResultType) {
62  MLIRContext *ctx = vectorResultType.getContext();
63  auto a = cast<LLVM::LLVMArrayType>(vectorResultType);
64  auto f16x2Ty = LLVM::getFixedVectorType(Float16Type::get(ctx), 2);
65  auto i32Ty = IntegerType::get(ctx, 32);
66  auto i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
67  Type f64Ty = Float64Type::get(ctx);
68  Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
69  Type f32Ty = Float32Type::get(ctx);
70  Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
71  if (a.getElementType() == f16x2Ty) {
72  return LLVM::LLVMStructType::getLiteral(
73  ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty));
74  }
75  if (a.getElementType() == i32x2Ty) {
76  return LLVM::LLVMStructType::getLiteral(
77  ctx,
78  SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, i32Ty));
79  }
80  if (a.getElementType() == f64x2Ty) {
81  return LLVM::LLVMStructType::getLiteral(ctx, {f64Ty, f64Ty});
82  }
83  if (a.getElementType() == f32x2Ty) {
84  return LLVM::LLVMStructType::getLiteral(
85  ctx,
86  SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, f32Ty));
87  }
88  if (a.getElementType() == LLVM::getFixedVectorType(f32Ty, 1)) {
89  return LLVM::LLVMStructType::getLiteral(
90  ctx, SmallVector<Type>(static_cast<size_t>(a.getNumElements()), f32Ty));
91  }
92  return vectorResultType;
93 }
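// For example (hypothetical operand shapes): a `vectorResultType` of
// `!llvm.array<2 x vector<2xf16>>` maps to
// `!llvm.struct<(vector<2xf16>, vector<2xf16>)>`, while
// `!llvm.array<2 x vector<2xf32>>` maps to a struct of four `f32` scalars,
// mirroring how the `nvvm.mma.sync` intrinsic returns its accumulator.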
94 
95 /// Convert the SSA result of the NVVM intrinsic `nvvm.mma.sync` (which is
96 /// always an LLVM struct) into a fragment that is compatible with the vector
97 /// type of this operation. This involves extracting elements from the struct
98 /// and inserting them into an LLVM array. These extra data-movement
99 /// operations should be canonicalized away by the LLVM backend.
100 static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
101  Type resultType, Value intrinsicResult,
102  RewriterBase &rewriter) {
103  MLIRContext *ctx = rewriter.getContext();
104  auto structType = dyn_cast<LLVM::LLVMStructType>(intrinsicResultType);
105  auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType);
106  Type i32Ty = rewriter.getI32Type();
107  Type f32Ty = rewriter.getF32Type();
108  Type f64Ty = rewriter.getF64Type();
109  Type f16x2Ty = LLVM::getFixedVectorType(rewriter.getF16Type(), 2);
110  Type i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
111  Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
112  Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
113  Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
114 
115  auto makeConst = [&](int32_t index) -> Value {
116  return rewriter.create<LLVM::ConstantOp>(loc, IntegerType::get(ctx, 32),
117  rewriter.getI32IntegerAttr(index));
118  };
119 
120  if (arrayType) {
121  SmallVector<Value, 4> elements;
122 
123  // The intrinsic returns 32-bit wide elements in a form which can be
124  // directly bitcasted and inserted into the result vector.
125  if (arrayType.getElementType() == f16x2Ty ||
126  arrayType.getElementType() == f32x1Ty) {
127  for (unsigned i = 0; i < structType.getBody().size(); i++) {
128  Value el =
129  rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, i);
130  el = rewriter.createOrFold<LLVM::BitcastOp>(
131  loc, arrayType.getElementType(), el);
132  elements.push_back(el);
133  }
134  }
135 
136  // The intrinsic returns i32, f64, and f32 values as individual scalars,
137  // even when the result is notionally a 64-bit wide element (e.g. f32x2). We
138  // need to extract them from the struct and pack them into the 64-bit wide
139  // rows of the vector result.
140  if (arrayType.getElementType() == i32x2Ty ||
141  arrayType.getElementType() == f64x2Ty ||
142  arrayType.getElementType() == f32x2Ty) {
143 
144  for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) {
145  Value vec =
146  rewriter.create<LLVM::UndefOp>(loc, arrayType.getElementType());
147  Value x1 =
148  rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, i * 2);
149  Value x2 = rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult,
150  i * 2 + 1);
151  vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
152  x1, makeConst(0));
153  vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
154  x2, makeConst(1));
155  elements.push_back(vec);
156  }
157  }
158 
159  // Create the final vectorized result.
160  Value result = rewriter.create<LLVM::UndefOp>(loc, arrayType);
161  for (const auto &el : llvm::enumerate(elements)) {
162  result = rewriter.create<LLVM::InsertValueOp>(loc, result, el.value(),
163  el.index());
164  }
165  return result;
166  }
167 
168  return intrinsicResult;
169 }
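// Sketch of the conversion above, assuming a hypothetical f64 accumulator: an
// intrinsic result `!llvm.struct<(f64, f64)>` is unpacked with
// `llvm.extractvalue`, the two scalars are packed into a `vector<2xf64>` with
// `llvm.insertelement`, and that vector becomes element 0 of the final
// `!llvm.array<1 x vector<2xf64>>` result.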
170 
171 /// The `gpu.mma.sync` converter below expects matrix fragment operands to be
172 /// given as 2D `vectors` where the rows are 32b or 64b wide. The
173 /// `nvvm.mma.sync` op expects these arguments to be given as a long list of
174 /// scalars of certain types. This function helps unpack the `vector` arguments
175 /// and cast them to the types expected by `nvvm.mma.sync`.
176 static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b,
177                                               Value operand,
178  NVVM::MMATypes operandPtxType) {
179  SmallVector<Value> result;
180  Type i32Ty = b.getI32Type();
181  Type f64Ty = b.getF64Type();
182  Type f32Ty = b.getF32Type();
183  Type i64Ty = b.getI64Type();
184  Type i8x4Ty = LLVM::getFixedVectorType(b.getI8Type(), 4);
185  Type i4x8Ty = LLVM::getFixedVectorType(b.getIntegerType(4), 8);
186  Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
187  auto arrayTy = cast<LLVM::LLVMArrayType>(operand.getType());
188 
189  for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
190  Value toUse = b.create<LLVM::ExtractValueOp>(operand, i);
191 
192    // For 4xi8 vectors, 8xi4 vectors, and tf32 operands, the intrinsic
193    // expects these to be provided as i32 scalars.
194  if (arrayTy.getElementType() == i8x4Ty ||
195  arrayTy.getElementType() == i4x8Ty ||
196  (arrayTy.getElementType() == f32x1Ty &&
197  operandPtxType == NVVM::MMATypes::tf32)) {
198  result.push_back(b.create<LLVM::BitcastOp>(i32Ty, toUse));
199  continue;
200  }
201 
202  // For some element types (i32, f32, f64), we need to unpack the inner
203  // vector/array type as well because the intrinsic expects individual
204  // scalars to be provided.
205  VectorType innerArrayTy = dyn_cast<VectorType>(arrayTy.getElementType());
206  if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty ||
207  innerArrayTy.getElementType() == f64Ty ||
208  innerArrayTy.getElementType() == f32Ty)) {
209  for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
210  idx < innerSize; idx++) {
211  result.push_back(b.create<LLVM::ExtractElementOp>(
212  toUse,
213  b.create<LLVM::ConstantOp>(i64Ty, b.getI64IntegerAttr(idx))));
214  }
215  continue;
216  }
217  result.push_back(toUse);
218  }
219  return result;
220 }
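// As an illustration (hypothetical operand types): an operand of type
// `!llvm.array<2 x vector<4xi8>>` is unpacked into two `i32` values via
// `llvm.bitcast`, whereas an operand of type `!llvm.array<1 x vector<2xf64>>`
// is unpacked into two `f64` scalars via `llvm.extractelement`.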
221 
222 /// Returns whether the mbarrier object has a shared memory address space.
223 static bool isMbarrierShared(nvgpu::MBarrierGroupType barrierType) {
224  return (mlir::nvgpu::NVGPUDialect::isSharedMemoryAddressSpace(
225  barrierType.getMemorySpace()));
226 }
227 
228 /// Returns the memory space attribute of the mbarrier object.
229 Attribute nvgpu::getMbarrierMemorySpace(MLIRContext *context,
230                                         nvgpu::MBarrierGroupType barrierType) {
231  Attribute memorySpace = {};
232  if (isMbarrierShared(barrierType)) {
233  memorySpace =
234  IntegerAttr::get(IntegerType::get(context, 64),
235  nvgpu::NVGPUDialect::kSharedMemoryAddressSpace);
236  }
237  return memorySpace;
238 }
239 
240 /// Returns memref type of the mbarrier object. The type is defined in the
241 /// MBarrierGroupType.
242 MemRefType nvgpu::getMBarrierMemrefType(MLIRContext *context,
243  nvgpu::MBarrierGroupType barrierType) {
244  Attribute memorySpace = nvgpu::getMbarrierMemorySpace(context, barrierType);
245  MemRefLayoutAttrInterface layout;
246  return MemRefType::get({barrierType.getNumBarriers()},
247  IntegerType::get(context, 64), layout, memorySpace);
248 }
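// For instance, assuming a barrier group declared in workgroup (shared) memory
// with 4 barriers, i.e.
// `!nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>,
// num_barriers = 4>`, the returned type is `memref<4xi64, 3>`, address space 3
// being kSharedMemoryAddressSpace.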
249 
250 namespace {
251 
252 struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
253   using ConvertOpToLLVMPattern<nvgpu::LdMatrixOp>::ConvertOpToLLVMPattern;
254 
255  LogicalResult
256  matchAndRewrite(nvgpu::LdMatrixOp op, OpAdaptor adaptor,
257  ConversionPatternRewriter &rewriter) const override {
258  MLIRContext *ctx = getContext();
259  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
260 
261  // The result type of ldmatrix will always be a struct of 32bit integer
262  // registers if more than one 32bit value is returned. Otherwise, the result
263  // is a single i32. The result type of the GPU operation is always a vector
264  // of shape (NumRegisters, VectorRegister) where VectorRegister is the
265  // vector type of the result and always 32 bits long. We bitcast the result
266  // of the NVVM::LdMatrix to this vector type.
267  auto vectorResultType = dyn_cast<VectorType>(op->getResultTypes()[0]);
268  if (!vectorResultType) {
269  return failure();
270  }
271  Type innerVectorType = LLVM::getFixedVectorType(
272  vectorResultType.getElementType(), vectorResultType.getDimSize(1));
273 
274  int64_t num32BitRegs = vectorResultType.getDimSize(0);
275 
276  Type ldMatrixResultType;
277  if (num32BitRegs > 1) {
278  ldMatrixResultType = LLVM::LLVMStructType::getLiteral(
279  ctx, SmallVector<Type>(num32BitRegs, rewriter.getI32Type()));
280  } else {
281  ldMatrixResultType = rewriter.getI32Type();
282  }
283 
284  auto srcMemrefType = cast<MemRefType>(op.getSrcMemref().getType());
285  Value srcPtr =
286  getStridedElementPtr(b.getLoc(), srcMemrefType, adaptor.getSrcMemref(),
287  adaptor.getIndices(), rewriter);
288  Value ldMatrixResult = b.create<NVVM::LdMatrixOp>(
289  ldMatrixResultType, srcPtr,
290  /*num=*/op.getNumTiles(),
291  /*layout=*/op.getTranspose() ? NVVM::MMALayout::col
292  : NVVM::MMALayout::row);
293 
294  // The ldmatrix operation returns either a single i32 value or a struct of
295  // i32 values. Here we unpack those values and cast them back to their
296  // actual vector type (still of width 32b) and repack them into a result
297  // struct.
298  Type finalResultType = typeConverter->convertType(vectorResultType);
299  Value result = b.create<LLVM::UndefOp>(finalResultType);
300  for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) {
301  Value i32Register =
302  num32BitRegs > 1 ? b.create<LLVM::ExtractValueOp>(ldMatrixResult, i)
303  : ldMatrixResult;
304  Value casted = b.create<LLVM::BitcastOp>(innerVectorType, i32Register);
305  result = b.create<LLVM::InsertValueOp>(result, casted, i);
306  }
307 
308  rewriter.replaceOp(op, result);
309  return success();
310  }
311 };
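// A rough before/after sketch of this pattern (operand types assumed for
// illustration):
//
//   %f = nvgpu.ldmatrix %shm[%i, %j] {numTiles = 4 : i32, transpose = false}
//        : memref<?x?xf16, 3> -> vector<4x2xf16>
//
// becomes, approximately,
//
//   %r = nvvm.ldmatrix %ptr {num = 4 : i32, layout = #nvvm.mma_layout<row>}
//        : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
//
// followed by llvm.extractvalue + llvm.bitcast to vector<2xf16> and
// llvm.insertvalue into !llvm.array<4 x vector<2xf16>>.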
312 
313 /// Convert the given type into the corresponding PTX type (NVVM::MMATypes
314 /// enum).
315 static FailureOr<NVVM::MMATypes> getNvvmMmaType(Type t) {
316  Type elType = getElementTypeOrSelf(t);
317  if (elType.isInteger(8))
318  return NVVM::MMATypes::s8;
319  if (elType.isInteger(4))
320  return NVVM::MMATypes::s4;
321  if (elType.isF16())
322  return NVVM::MMATypes::f16;
323  if (elType.isF64())
324  return NVVM::MMATypes::f64;
325  if (elType.isF32())
326  return NVVM::MMATypes::tf32;
327  return failure();
328 }
329 
330 struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
331   using ConvertOpToLLVMPattern<nvgpu::MmaSyncOp>::ConvertOpToLLVMPattern;
332 
333  LogicalResult
334  matchAndRewrite(nvgpu::MmaSyncOp op, OpAdaptor adaptor,
335  ConversionPatternRewriter &rewriter) const override {
336  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
337  // Get the shapes of the MMAMatrix type being used. The shapes will
338  // choose which intrinsic this op will be lowered to.
339  VectorType aType = op.getMatrixA().getType();
340    VectorType bType = op.getMatrixB().getType();
341  VectorType cType = op.getMatrixC().getType();
342 
343  std::array<int64_t, 3> gemmShape = op.getMmaShapeAsArray();
344 
345    // Tensor Cores (mma.sync) on F32 work only with TensorFloat32 (TF32).
346  bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
347  if (aType.getElementType().isF32() && !tf32Enabled)
348  return failure();
349 
350  FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
351  if (failed(ptxTypeA))
352  return op->emitOpError("failed to deduce operand PTX types");
353  FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
354  if (failed(ptxTypeB))
355  return op->emitOpError("failed to deduce operand PTX types");
356  std::optional<NVVM::MMATypes> ptxTypeC =
357  NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
358  /*isAccumulator=*/true);
359  if (!ptxTypeC)
360  return op->emitError(
361  "could not infer the PTX type for the accumulator/result");
362 
363  // TODO: add an attribute to the op to customize this behavior.
364  std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
365  if (isa<IntegerType>(aType.getElementType()))
366  overflow = NVVM::MMAIntOverflow::satfinite;
367 
368  SmallVector<Value> matA =
369  unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
370  SmallVector<Value> matB =
371  unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
372  SmallVector<Value> matC =
373  unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);
374 
375  Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
376  Type intrinsicResTy = inferIntrinsicResultType(
377  typeConverter->convertType(op->getResultTypes()[0]));
378  Value intrinsicResult = b.create<NVVM::MmaOp>(
379  intrinsicResTy, matA, matB, matC,
380  /*shape=*/gemmShape,
381  /*b1Op=*/std::nullopt,
382  /*intOverflow=*/overflow,
383  /*multiplicandPtxTypes=*/
384  std::array<NVVM::MMATypes, 2>{*ptxTypeA, *ptxTypeB},
385  /*multiplicandLayouts=*/
386  std::array<NVVM::MMALayout, 2>{NVVM::MMALayout::row,
387  NVVM::MMALayout::col});
388  rewriter.replaceOp(op, convertIntrinsicResult(op.getLoc(), intrinsicResTy,
389  desiredRetTy, intrinsicResult,
390  rewriter));
391  return success();
392  }
393 };
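// Sketch of the rewrite performed by this pattern, with assumed 16x8x16 f16
// operands: an `nvgpu.mma.sync` whose operands were converted to LLVM arrays
// of vector<2xf16> is rewritten into a single `nvvm.mma.sync` with row/col
// layouts, the scalar operands produced by unpackOperandVector(), and the
// struct result converted back to the vector form via convertIntrinsicResult().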
394 
395 struct ConvertNVGPUToNVVMPass
396  : public impl::ConvertNVGPUToNVVMPassBase<ConvertNVGPUToNVVMPass> {
397  using Base::Base;
398 
399  void getDependentDialects(DialectRegistry &registry) const override {
400  registry.insert<memref::MemRefDialect, LLVM::LLVMDialect, NVVM::NVVMDialect,
401  arith::ArithDialect>();
402  }
403 
404   void runOnOperation() override {
405     LowerToLLVMOptions options(&getContext());
406     RewritePatternSet patterns(&getContext());
407     LLVMTypeConverter converter(&getContext(), options);
408     IRRewriter rewriter(&getContext());
409     populateGpuMemorySpaceAttributeConversions(
410         converter, [](gpu::AddressSpace space) -> unsigned {
411           switch (space) {
412           case gpu::AddressSpace::Global:
413             return static_cast<unsigned>(
414                 NVVM::NVVMMemorySpace::kGlobalMemorySpace);
415           case gpu::AddressSpace::Workgroup:
416             return static_cast<unsigned>(
417                 NVVM::NVVMMemorySpace::kSharedMemorySpace);
418           case gpu::AddressSpace::Private:
419  return 0;
420  }
421  llvm_unreachable("unknown address space enum value");
422  return 0;
423  });
424  /// device-side async tokens cannot be materialized in nvvm. We just
425  /// convert them to a dummy i32 type in order to easily drop them during
426  /// conversion.
427  converter.addConversion([&](nvgpu::DeviceAsyncTokenType type) -> Type {
428  return converter.convertType(IntegerType::get(type.getContext(), 32));
429  });
430  converter.addConversion([&](nvgpu::WarpgroupAccumulatorType type) -> Type {
431  Type elemType = type.getFragmented().getElementType();
432  int64_t sizeM = type.getFragmented().getDimSize(0);
433  int64_t sizeN = type.getFragmented().getDimSize(1);
434 
435  unsigned numMembers;
436  if (elemType.isF32() || elemType.isInteger(32))
437  numMembers = sizeN / 2;
438  else if (elemType.isF16())
439  numMembers = sizeN / 4;
440  else
441  llvm_unreachable("unsupported type for warpgroup accumulator");
442 
443  SmallVector<Type> innerStructBody;
444  for (unsigned i = 0; i < numMembers; i++)
445  innerStructBody.push_back(elemType);
446  auto innerStructType =
447  LLVM::LLVMStructType::getLiteral(type.getContext(), innerStructBody);
448 
449  SmallVector<Type> structBody;
450  for (int i = 0; i < sizeM; i += kWgmmaSizeM)
451  structBody.push_back(innerStructType);
452 
453  auto convertedType =
454  LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
455  return converter.convertType(convertedType);
456  });
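    // Worked example for the conversion above, assuming an accumulator of type
    // vector<128x128xf32>: numMembers = 128 / 2 = 64, so the inner struct has
    // 64 f32 members; sizeM / kWgmmaSizeM = 128 / 64 = 2, so the converted
    // type is a struct containing two such inner structs, one per 64-row
    // wgmma tile.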
457  converter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {
458  return converter.convertType(IntegerType::get(type.getContext(), 64));
459  });
460  converter.addConversion(
461  [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type {
462  return converter.convertType(IntegerType::get(type.getContext(), 64));
463  });
464  converter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type {
465  return converter.convertType(
466  nvgpu::getMBarrierMemrefType(rewriter.getContext(), type));
467  });
468  converter.addConversion([&](nvgpu::TensorMapDescriptorType type) -> Type {
469  return LLVM::LLVMPointerType::get(type.getContext());
470  });
473  target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
474  target.addLegalDialect<::mlir::arith::ArithDialect>();
475  target.addLegalDialect<::mlir::memref::MemRefDialect>();
476  target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
477     mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
478         converter, patterns, target);
479  if (failed(applyPartialConversion(getOperation(), target,
480  std::move(patterns))))
481  signalPassFailure();
482  }
483 };
484 
485 /// Returns the constraints for the sparse MMA inline assembly instruction.
486 static std::string buildMmaSparseAsmConstraintString(unsigned matASize,
487  unsigned matBSize,
488  unsigned matCSize) {
489  std::string str;
490  llvm::raw_string_ostream ss(str);
491  for (unsigned i = 0; i < matCSize; i++)
492  ss << "=r,";
493  for (unsigned i = 0; i < matASize + matBSize + matCSize; i++)
494  ss << "r,";
495  // The final operand is for the sparsity metadata.
496   // The sparsity selector appears as a direct literal.
497  ss << "r";
498  return str;
499 }
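// Example output (hypothetical fragment sizes matASize = 2, matBSize = 1,
// matCSize = 2): "=r,=r,r,r,r,r,r,r", i.e. two output registers for matC,
// five input registers covering A, B, and C, and a final "r" for the
// sparsity metadata operand.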
500 
501 /// Returns the string for the `mma.sp.sync` instruction that corresponds to
502 /// the given parameters. Note that this function doesn't do any validation,
503 /// it's expected that the provided parameters correspond to a valid
504 /// instruction.
505 static std::string buildMmaSparseAsmString(
506  const std::array<int64_t, 3> &shape, unsigned matASize, unsigned matBSize,
507  unsigned matCSize, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
508  NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
509  std::optional<NVVM::MMAIntOverflow> overflow, unsigned metaDataSelector) {
510  auto ptxTypeStr = [](NVVM::MMATypes ptxType) {
511  return NVVM::stringifyMMATypes(ptxType);
512  };
513 
514  std::string asmStr;
515  llvm::raw_string_ostream ss(asmStr);
516  ss << "mma.sp.sync.aligned.m" << shape[0] << "n" << shape[1] << "k"
517  << shape[2] << ".row.col.";
518 
519  if (overflow)
520  ss << NVVM::stringifyMMAIntOverflow(*overflow) << ".";
521 
522  ss << ptxTypeStr(ptxTypeD) << "." << ptxTypeStr(ptxTypeA) << "."
523  << ptxTypeStr(ptxTypeB) << "." << ptxTypeStr(ptxTypeC) << " ";
524  unsigned asmArgIdx = 0;
525 
526  // The operand string is structured into sections `{matC elements...},
527  // {matA elements...}, {matB elements...}, {matC elements}`.
528  for (const auto arrSize : {matCSize, matASize, matBSize, matCSize}) {
529  ss << "{";
530  for (unsigned i = 0; i < arrSize; i++)
531  ss << "$" << asmArgIdx++ << (i < arrSize - 1 ? "," : "");
532  ss << "},";
533  }
534  ss << "$" << asmArgIdx++ << ",";
535  assert(metaDataSelector <= 1);
536  ss << "0x" << metaDataSelector << ";";
537  return asmStr;
538 }
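// For example, with an assumed 16x8x16 f16 shape, matASize = 2, matBSize = 1,
// matCSize = 2, no overflow attribute, and metadata selector 0, this returns
// the single-line string:
//   mma.sp.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16
//       {$0,$1},{$2,$3},{$4},{$5,$6},$7,0x0;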
539 
540 /// Builds an inline assembly operation corresponding to the specified MMA
541 /// sparse sync operation.
542 static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
543  ImplicitLocOpBuilder &b, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
544  NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
545  std::optional<NVVM::MMAIntOverflow> overflow, ArrayRef<Value> unpackedAData,
546  ArrayRef<Value> unpackedB, ArrayRef<Value> unpackedC, Value indexData,
547  int64_t metadataSelector, const std::array<int64_t, 3> &shape,
548  Type intrinsicResultType) {
549  auto asmDialectAttr =
550  LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_ATT);
551 
552  const unsigned matASize = unpackedAData.size();
553  const unsigned matBSize = unpackedB.size();
554  const unsigned matCSize = unpackedC.size();
555 
556  std::string asmStr = buildMmaSparseAsmString(
557  shape, matASize, matBSize, matCSize, ptxTypeA, ptxTypeB, ptxTypeC,
558  ptxTypeD, overflow, metadataSelector);
559  std::string constraintStr =
560  buildMmaSparseAsmConstraintString(matASize, matBSize, matCSize);
561 
562  SmallVector<Value> asmVals;
563  asmVals.reserve(matASize + matBSize + matCSize + 1);
564  for (ArrayRef<Value> args : {unpackedAData, unpackedB, unpackedC})
565  llvm::append_range(asmVals, args);
566  asmVals.push_back(indexData);
567 
568  return b.create<LLVM::InlineAsmOp>(
569  /*resultTypes=*/intrinsicResultType,
570  /*operands=*/asmVals,
571  /*asm_string=*/asmStr,
572  /*constraints=*/constraintStr,
573  /*has_side_effects=*/true,
574  /*is_align_stack=*/false,
575  /*asm_dialect=*/asmDialectAttr,
576  /*operand_attrs=*/ArrayAttr());
577 }
578 
579 /// Lowers `nvgpu.mma.sp.sync` to inline assembly.
580 struct NVGPUMmaSparseSyncLowering
581     : public ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp> {
582   using ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp>::ConvertOpToLLVMPattern;
583 
584  LogicalResult
585  matchAndRewrite(nvgpu::MmaSparseSyncOp op, OpAdaptor adaptor,
586  ConversionPatternRewriter &rewriter) const override {
587  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
588  // Get the shapes of the MMAMatrix type being used. The shapes will
589  // choose which intrinsic this op will be lowered to.
590  VectorType aType = op.getMatrixA().getType();
591  VectorType bType = op.getMatrixB().getType();
592  VectorType cType = op.getMatrixC().getType();
593 
594  FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
595  if (failed(ptxTypeA))
596  return op->emitOpError("failed to deduce operand PTX types");
597  FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
598  if (failed(ptxTypeB))
599  return op->emitOpError("failed to deduce operand PTX types");
600  std::optional<NVVM::MMATypes> ptxTypeC =
601  NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
602  /*isAccumulator=*/true);
603  if (!ptxTypeC)
604  return op->emitError(
605  "could not infer the PTX type for the accumulator/result");
606 
607  // Same as `mma.sync`, F32 works only with TensorFloat32 (TF32).
608  bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
609  if (aType.getElementType().isF32() && !tf32Enabled)
610  return failure();
611 
612  // TODO: add an attribute to the op to customize this behavior.
613  std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
614  if (isa<IntegerType>(aType.getElementType()))
615  overflow = NVVM::MMAIntOverflow::satfinite;
616 
617  SmallVector<Value> matA =
618  unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
619  SmallVector<Value> matB =
620  unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
621  SmallVector<Value> matC =
622  unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);
623 
624  Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
625  Type intrinsicResTy = inferIntrinsicResultType(
626  typeConverter->convertType(op->getResultTypes()[0]));
627 
628  // Bitcast the sparse metadata from vector<2xf16> to an i32.
629  Value sparseMetadata = adaptor.getSparseMetadata();
630  if (sparseMetadata.getType() !=
631  LLVM::getFixedVectorType(rewriter.getI16Type(), 2))
632  return op->emitOpError() << "Expected metadata type to be LLVM "
633  "VectorType of 2 i16 elements";
634  sparseMetadata =
635  b.create<LLVM::BitcastOp>(rewriter.getI32Type(), sparseMetadata);
636 
637  FailureOr<LLVM::InlineAsmOp> intrinsicResult = emitMmaSparseSyncOpAsm(
638  b, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB,
639  matC, sparseMetadata, op.getSparsitySelector(), op.getMmaShapeAsArray(),
640  intrinsicResTy);
641  if (failed(intrinsicResult))
642  return failure();
643 
644  assert((*intrinsicResult).getNumResults() == 1 &&
645  "expected inline asm op returns a single LLVM struct type");
646  rewriter.replaceOp(
647  op, convertIntrinsicResult(op.getLoc(), intrinsicResTy, desiredRetTy,
648  (*intrinsicResult)->getResult(0), rewriter));
649  return success();
650  }
651 };
652 
653 struct NVGPUAsyncCopyLowering
654  : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCopyOp> {
655   using ConvertOpToLLVMPattern<
656       nvgpu::DeviceAsyncCopyOp>::ConvertOpToLLVMPattern;
657 
658  LogicalResult
659  matchAndRewrite(nvgpu::DeviceAsyncCopyOp op, OpAdaptor adaptor,
660  ConversionPatternRewriter &rewriter) const override {
661  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
662  Location loc = op.getLoc();
663  auto dstMemrefType = cast<MemRefType>(op.getDst().getType());
664  Value dstPtr =
665  getStridedElementPtr(b.getLoc(), dstMemrefType, adaptor.getDst(),
666  adaptor.getDstIndices(), rewriter);
667  FailureOr<unsigned> dstAddressSpace =
668  getTypeConverter()->getMemRefAddressSpace(dstMemrefType);
669  if (failed(dstAddressSpace))
670  return rewriter.notifyMatchFailure(
671  loc, "destination memref address space not convertible to integer");
672 
673  auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
674  FailureOr<unsigned> srcAddressSpace =
675  getTypeConverter()->getMemRefAddressSpace(srcMemrefType);
676  if (failed(srcAddressSpace))
677  return rewriter.notifyMatchFailure(
678  loc, "source memref address space not convertible to integer");
679 
680  Value scrPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.getSrc(),
681  adaptor.getSrcIndices(), rewriter);
682     // The intrinsic takes a global pointer, so we need an address space cast.
683     auto srcPointerGlobalType = LLVM::LLVMPointerType::get(
684         op->getContext(), NVVM::NVVMMemorySpace::kGlobalMemorySpace);
685     scrPtr = b.create<LLVM::AddrSpaceCastOp>(srcPointerGlobalType, scrPtr);
686  int64_t dstElements = adaptor.getDstElements().getZExtValue();
687  int64_t sizeInBytes =
688  (dstMemrefType.getElementTypeBitWidth() * dstElements) / 8;
689     // When the optional SrcElements argument is *not* present, the regular
690     // CpAsyncOp is generated. It reads enough bytes from the source (global
691     // memory) to fill DstElements elements in the destination (shared
692     // memory).
693  Value srcBytes = adaptor.getSrcElements();
694  if (srcBytes) {
695  // When the optional SrcElements argument is present, the source (global
696  // memory) of CpAsyncOp is read only for SrcElements number of elements.
697  // The rest of the DstElements in the destination (shared memory) are
698  // filled with zeros.
699  Value c3I32 =
700  b.create<LLVM::ConstantOp>(b.getI32Type(), b.getI32IntegerAttr(3));
701  Value bitwidth = b.create<LLVM::ConstantOp>(
702  b.getI32Type(),
703  b.getI32IntegerAttr(srcMemrefType.getElementTypeBitWidth()));
704  Value srcElementsI32 = b.create<LLVM::TruncOp>(b.getI32Type(), srcBytes);
705  srcBytes = b.create<LLVM::LShrOp>(
706  b.create<LLVM::MulOp>(bitwidth, srcElementsI32), c3I32);
707  }
708  // Cache global (.cg) for 16 dst bytes, Cache all (.ca) for sizes other than
709  // 16 dst bytes.
710  NVVM::LoadCacheModifierKind cacheModifier =
711  (op.getBypassL1().value_or(false) && sizeInBytes == 16)
712  ? NVVM::LoadCacheModifierKind::CG
713  : NVVM::LoadCacheModifierKind::CA;
714 
715  b.create<NVVM::CpAsyncOp>(
716  dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes),
717  NVVM::LoadCacheModifierKindAttr::get(op->getContext(), cacheModifier),
718  srcBytes);
719 
720  // Drop the result token.
721  Value zero = b.create<LLVM::ConstantOp>(
722  IntegerType::get(op.getContext(), 32), rewriter.getI32IntegerAttr(0));
723  rewriter.replaceOp(op, zero);
724  return success();
725  }
726 };
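// Example of the size computation above, assuming a copy of 8 f16 elements
// into shared memory: sizeInBytes = (16 * 8) / 8 = 16, so with bypassL1 set
// the generated nvvm.cp.async uses the .cg cache modifier; any other size
// falls back to .ca.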
727 
728 struct NVGPUAsyncCreateGroupLowering
729  : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCreateGroupOp> {
730   using ConvertOpToLLVMPattern<
731       nvgpu::DeviceAsyncCreateGroupOp>::ConvertOpToLLVMPattern;
732 
733  LogicalResult
734  matchAndRewrite(nvgpu::DeviceAsyncCreateGroupOp op, OpAdaptor adaptor,
735  ConversionPatternRewriter &rewriter) const override {
736  rewriter.create<NVVM::CpAsyncCommitGroupOp>(op.getLoc());
737  // Drop the result token.
738  Value zero = rewriter.create<LLVM::ConstantOp>(
739  op->getLoc(), IntegerType::get(op.getContext(), 32),
740  rewriter.getI32IntegerAttr(0));
741  rewriter.replaceOp(op, zero);
742  return success();
743  }
744 };
745 
746 struct NVGPUAsyncWaitLowering
747  : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncWaitOp> {
748   using ConvertOpToLLVMPattern<
749       nvgpu::DeviceAsyncWaitOp>::ConvertOpToLLVMPattern;
750 
751  LogicalResult
752  matchAndRewrite(nvgpu::DeviceAsyncWaitOp op, OpAdaptor adaptor,
753  ConversionPatternRewriter &rewriter) const override {
754     // If numGroups is not present, pick 0 as a conservative but correct value.
755  int32_t numGroups = adaptor.getNumGroups().value_or(0);
756  rewriter.create<NVVM::CpAsyncWaitGroupOp>(op.getLoc(), numGroups);
757  rewriter.eraseOp(op);
758  return success();
759  }
760 };
761 
762 /// Creates an mbarrier object in shared memory.
763 struct NVGPUMBarrierCreateLowering
764     : public ConvertOpToLLVMPattern<nvgpu::MBarrierCreateOp> {
765   using ConvertOpToLLVMPattern<nvgpu::MBarrierCreateOp>::ConvertOpToLLVMPattern;
766 
767  template <typename moduleT>
768  memref::GlobalOp generateGlobalBarrier(ConversionPatternRewriter &rewriter,
769  Operation *funcOp, moduleT moduleOp,
770  MemRefType barrierType) const {
771  SymbolTable symbolTable(moduleOp);
772  OpBuilder::InsertionGuard guard(rewriter);
773  rewriter.setInsertionPoint(&moduleOp.front());
774  auto global = rewriter.create<memref::GlobalOp>(
775  funcOp->getLoc(), "__mbarrier",
776  /*sym_visibility=*/rewriter.getStringAttr("private"),
777  /*type=*/barrierType,
778  /*initial_value=*/ElementsAttr(),
779  /*constant=*/false,
780  /*alignment=*/rewriter.getI64IntegerAttr(8));
781  symbolTable.insert(global);
782  return global;
783  }
784 
785  LogicalResult
786  matchAndRewrite(nvgpu::MBarrierCreateOp op, OpAdaptor adaptor,
787  ConversionPatternRewriter &rewriter) const override {
788  Operation *funcOp = op->getParentOp();
789  MemRefType barrierType = nvgpu::getMBarrierMemrefType(
790  rewriter.getContext(), op.getBarriers().getType());
791 
792  memref::GlobalOp global;
793  if (auto moduleOp = funcOp->getParentOfType<gpu::GPUModuleOp>())
794  global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
795  else if (auto moduleOp = funcOp->getParentOfType<ModuleOp>())
796  global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
797 
798  rewriter.setInsertionPoint(op);
799  rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(op, barrierType,
800  global.getName());
801  return success();
802  }
803 };
804 
805 /// Base class for lowering mbarrier operations to nvvm intrinsics.
806 template <typename SourceOp>
807 struct MBarrierBasePattern : public ConvertOpToLLVMPattern<SourceOp> {
808 public:
809   using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
810  /// Returns the base pointer of the mbarrier object.
811  Value getMbarrierPtr(ImplicitLocOpBuilder &b,
812  nvgpu::MBarrierGroupType mbarType, Value memrefDesc,
813  Value mbarId,
814  ConversionPatternRewriter &rewriter) const {
815  MemRefType mbarrierMemrefType =
816  nvgpu::getMBarrierMemrefType(rewriter.getContext(), mbarType);
817     return ConvertToLLVMPattern::getStridedElementPtr(
818         b.getLoc(), mbarrierMemrefType, memrefDesc, {mbarId}, rewriter);
819  }
820 };
821 
822 /// Lowers `nvgpu.mbarrier.init` to `nvvm.mbarrier.init`
823 struct NVGPUMBarrierInitLowering
824  : public MBarrierBasePattern<nvgpu::MBarrierInitOp> {
825  using MBarrierBasePattern<nvgpu::MBarrierInitOp>::MBarrierBasePattern;
826 
827  LogicalResult
828  matchAndRewrite(nvgpu::MBarrierInitOp op, OpAdaptor adaptor,
829  ConversionPatternRewriter &rewriter) const override {
830  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
831  nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType();
832  rewriter.setInsertionPoint(op);
833  Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(),
834  adaptor.getMbarId(), rewriter);
835  Value count = truncToI32(b, adaptor.getCount());
836  if (isMbarrierShared(mbarrierType)) {
837  rewriter.replaceOpWithNewOp<NVVM::MBarrierInitSharedOp>(
838  op, barrier, count, adaptor.getPredicate());
839  } else {
840  rewriter.replaceOpWithNewOp<NVVM::MBarrierInitOp>(op, barrier, count,
841  adaptor.getPredicate());
842  }
843  return success();
844  }
845 };
846 
847 /// Lowers `nvgpu.mbarrier.arrive` to `nvvm.mbarrier.arrive`
848 struct NVGPUMBarrierArriveLowering
849  : public MBarrierBasePattern<nvgpu::MBarrierArriveOp> {
850  using MBarrierBasePattern<nvgpu::MBarrierArriveOp>::MBarrierBasePattern;
851  LogicalResult
852  matchAndRewrite(nvgpu::MBarrierArriveOp op, OpAdaptor adaptor,
853  ConversionPatternRewriter &rewriter) const override {
854  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
855  Value barrier =
856  getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
857  adaptor.getMbarId(), rewriter);
858  Type tokenType = getTypeConverter()->convertType(
859  nvgpu::MBarrierTokenType::get(op->getContext()));
860  if (isMbarrierShared(op.getBarriers().getType())) {
861  rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveSharedOp>(op, tokenType,
862  barrier);
863  } else {
864  rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveOp>(op, tokenType,
865  barrier);
866  }
867  return success();
868  }
869 };
870 
871 /// Lowers `nvgpu.mbarrier.arrive.nocomplete` to
872 /// `nvvm.mbarrier.arrive.nocomplete`
873 struct NVGPUMBarrierArriveNoCompleteLowering
874  : public MBarrierBasePattern<nvgpu::MBarrierArriveNoCompleteOp> {
875  using MBarrierBasePattern<
876  nvgpu::MBarrierArriveNoCompleteOp>::MBarrierBasePattern;
877  LogicalResult
878  matchAndRewrite(nvgpu::MBarrierArriveNoCompleteOp op, OpAdaptor adaptor,
879  ConversionPatternRewriter &rewriter) const override {
880  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
881  Value barrier =
882  getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
883  adaptor.getMbarId(), rewriter);
884  Type tokenType = getTypeConverter()->convertType(
885  nvgpu::MBarrierTokenType::get(op->getContext()));
886  Value count = truncToI32(b, adaptor.getCount());
887  if (isMbarrierShared(op.getBarriers().getType())) {
888  rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteSharedOp>(
889  op, tokenType, barrier, count);
890  } else {
891  rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteOp>(
892  op, tokenType, barrier, count);
893  }
894  return success();
895  }
896 };
897 
898 /// Lowers `nvgpu.mbarrier.test.wait` to `nvvm.mbarrier.test.wait`
899 struct NVGPUMBarrierTestWaitLowering
900  : public MBarrierBasePattern<nvgpu::MBarrierTestWaitOp> {
901  using MBarrierBasePattern<nvgpu::MBarrierTestWaitOp>::MBarrierBasePattern;
902  LogicalResult
903  matchAndRewrite(nvgpu::MBarrierTestWaitOp op, OpAdaptor adaptor,
904  ConversionPatternRewriter &rewriter) const override {
905  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
906  Value barrier =
907  getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
908  adaptor.getMbarId(), rewriter);
909  Type retType = rewriter.getI1Type();
910  if (isMbarrierShared(op.getBarriers().getType())) {
911  rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitSharedOp>(
912  op, retType, barrier, adaptor.getToken());
913  } else {
914  rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitOp>(
915  op, retType, barrier, adaptor.getToken());
916  }
917  return success();
918  }
919 };
920 
921 struct NVGPUMBarrierArriveExpectTxLowering
922  : public MBarrierBasePattern<nvgpu::MBarrierArriveExpectTxOp> {
923  using MBarrierBasePattern<
924  nvgpu::MBarrierArriveExpectTxOp>::MBarrierBasePattern;
925  LogicalResult
926  matchAndRewrite(nvgpu::MBarrierArriveExpectTxOp op, OpAdaptor adaptor,
927  ConversionPatternRewriter &rewriter) const override {
928  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
929  Value barrier =
930  getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
931  adaptor.getMbarId(), rewriter);
932  Value txcount = truncToI32(b, adaptor.getTxcount());
933 
934  if (isMbarrierShared(op.getBarriers().getType())) {
935  rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxSharedOp>(
936  op, barrier, txcount, adaptor.getPredicate());
937  return success();
938  }
939 
940  rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>(
941  op, barrier, txcount, adaptor.getPredicate());
942  return success();
943  }
944 };
945 
946 struct NVGPUMBarrierTryWaitParityLowering
947  : public MBarrierBasePattern<nvgpu::MBarrierTryWaitParityOp> {
948  using MBarrierBasePattern<
949  nvgpu::MBarrierTryWaitParityOp>::MBarrierBasePattern;
950  LogicalResult
951  matchAndRewrite(nvgpu::MBarrierTryWaitParityOp op, OpAdaptor adaptor,
952  ConversionPatternRewriter &rewriter) const override {
953  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
954  Value barrier =
955  getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
956  adaptor.getMbarId(), rewriter);
957  Value ticks = truncToI32(b, adaptor.getTicks());
958  Value phase =
959  b.create<LLVM::ZExtOp>(b.getI32Type(), adaptor.getPhaseParity());
960 
961  if (isMbarrierShared(op.getBarriers().getType())) {
962  rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParitySharedOp>(
963  op, barrier, phase, ticks);
964  return success();
965  }
966 
967  rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParityOp>(op, barrier,
968  phase, ticks);
969  return success();
970  }
971 };
972 
973 struct NVGPUTmaAsyncLoadOpLowering
974  : public MBarrierBasePattern<nvgpu::TmaAsyncLoadOp> {
975  using MBarrierBasePattern<nvgpu::TmaAsyncLoadOp>::MBarrierBasePattern;
976  LogicalResult
977  matchAndRewrite(nvgpu::TmaAsyncLoadOp op, OpAdaptor adaptor,
978  ConversionPatternRewriter &rewriter) const override {
979  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
980     auto dstMemrefType = cast<MemRefType>(op.getDst().getType());
981     Value dest = getStridedElementPtr(op->getLoc(), dstMemrefType,
982                                       adaptor.getDst(), {}, rewriter);
983  Value barrier =
984  getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
985  adaptor.getMbarId(), rewriter);
986 
987  SmallVector<Value> coords = adaptor.getCoordinates();
988  for (auto [index, value] : llvm::enumerate(coords)) {
989  coords[index] = truncToI32(b, value);
990  }
991  rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(
992  op, dest, adaptor.getTensorMapDescriptor(), coords, barrier,
993  ValueRange{}, adaptor.getMulticastMask(), Value{},
994  adaptor.getPredicate());
995  return success();
996  }
997 };
998 
999 struct NVGPUTmaAsyncStoreOpLowering
1000  : public MBarrierBasePattern<nvgpu::TmaAsyncStoreOp> {
1001  using MBarrierBasePattern<nvgpu::TmaAsyncStoreOp>::MBarrierBasePattern;
1002  LogicalResult
1003  matchAndRewrite(nvgpu::TmaAsyncStoreOp op, OpAdaptor adaptor,
1004  ConversionPatternRewriter &rewriter) const override {
1005  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1006  auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
1007  Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType,
1008  adaptor.getSrc(), {}, rewriter);
1009  SmallVector<Value> coords = adaptor.getCoordinates();
1010  for (auto [index, value] : llvm::enumerate(coords)) {
1011  coords[index] = truncToI32(b, value);
1012  }
1013 
1014  rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(
1015  op, adaptor.getTensorMapDescriptor(), dest, coords,
1016  adaptor.getPredicate());
1017  return success();
1018  }
1019 };
1020 
1021 struct NVGPUGenerateWarpgroupDescriptorLowering
1022  : public ConvertOpToLLVMPattern<nvgpu::WarpgroupGenerateDescriptorOp> {
1023  using ConvertOpToLLVMPattern<
1024  nvgpu::WarpgroupGenerateDescriptorOp>::ConvertOpToLLVMPattern;
1025 
1026  LogicalResult
1027  matchAndRewrite(nvgpu::WarpgroupGenerateDescriptorOp op, OpAdaptor adaptor,
1028  ConversionPatternRewriter &rewriter) const override {
1029 
1030  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1031 
1032  nvgpu::TensorMapSwizzleKind swizzleKind =
1033  op.getTensorMap().getType().getSwizzle();
1034 
1035  unsigned layout =
1036  (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 128
1037  : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 64
1038  : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 32
1039  : 1;
1040  unsigned swizzle =
1041  (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 1
1042  : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 2
1043  : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 3
1044  : 0;
1045 
1046  auto ti64 = b.getIntegerType(64);
1047  auto makeConst = [&](uint64_t index) -> Value {
1048  return b.create<LLVM::ConstantOp>(ti64, b.getI64IntegerAttr(index));
1049  };
1050  auto shiftLeft = [&](Value value, unsigned shift) -> Value {
1051  return b.create<LLVM::ShlOp>(ti64, value, makeConst(shift));
1052  };
1053  auto shiftRight = [&](Value value, unsigned shift) -> Value {
1054  return b.create<LLVM::LShrOp>(ti64, value, makeConst(shift));
1055  };
1056  auto insertBit = [&](Value desc, Value val, int startBit) {
1057  return b.create<LLVM::OrOp>(ti64, desc, shiftLeft(val, startBit));
1058  };
1059 
1060  int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0);
1061  uint64_t strideDimVal = (layout << 3) >> exclude4LSB;
1062  uint64_t leadDimVal = (sizeN * layout) >> exclude4LSB;
1063  uint64_t offsetVal = 0;
1064 
1065  Value strideDim = makeConst(strideDimVal);
1066  Value leadDim = makeConst(leadDimVal);
1067 
1068  Value baseAddr = getStridedElementPtr(
1069  op->getLoc(), cast<MemRefType>(op.getTensor().getType()),
1070  adaptor.getTensor(), {}, rewriter);
1071  Value basePtr = b.create<LLVM::PtrToIntOp>(ti64, baseAddr);
1072  // Just use 14 bits for base address
1073  Value basePtr14bit = shiftRight(shiftLeft(basePtr, 46), 50);
1074 
1075  int startSwizzleBit = 62, startOffsetBit = 49, startStrideBit = 32,
1076  startLeadBit = 16, startBaseAddrBit = 0;
1077  Value dsc = makeConst(0);
1078  // // [62,64) swizzle type
1079  dsc = insertBit(dsc, makeConst(swizzle), startSwizzleBit);
1080  // // [49,52) base_offset
1081  dsc = insertBit(dsc, makeConst(offsetVal), startOffsetBit);
1082  // // [32,46) stride
1083  dsc = insertBit(dsc, strideDim, startStrideBit);
1084  // // [16,30) leading dimension
1085  dsc = insertBit(dsc, leadDim, startLeadBit);
1086  // // [0,14) start_address
1087  dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit);
1088 
1089  LLVM_DEBUG(DBGS() << "Generating warpgroup.descriptor: "
1090  << "leading_off:" << leadDimVal << "\t"
1091  << "stride_off :" << strideDimVal << "\t"
1092  << "base_offset:" << offsetVal << "\t"
1093  << "layout_type:" << swizzle << " ("
1094  << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
1095  << ")\n start_addr : " << baseAddr << "\n");
1096 
1097  rewriter.replaceOp(op, dsc);
1098  return success();
1099  }
1100 };
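// Worked example for the descriptor built above, assuming a 64x64 f16 tile
// with SWIZZLE_128B: layout = 128 and swizzle = 1, so
// strideDimVal = (128 << 3) >> 4 = 64 and leadDimVal = (64 * 128) >> 4 = 512.
// These values are OR-ed into bit ranges [32,46) and [16,30) of the 64-bit
// matrix descriptor, together with the swizzle kind at [62,64) and the 14-bit
// shared-memory base address (with its 4 LSBs dropped) at [0,14).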
1101 
1102 static Value makeI64Const(ImplicitLocOpBuilder &b, int32_t index) {
1103  return b.create<LLVM::ConstantOp>(b.getIntegerType(64),
1104  b.getI32IntegerAttr(index));
1105 }
1106 
1107 /// Returns a Value that holds the data type enum that the CUDA driver expects.
1108 static Value elementTypeAsLLVMConstant(ImplicitLocOpBuilder &b, Type type) {
1109  // Enum is from CUDA driver API
1110  // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html
1111  enum CUtensorMapDataTypeEnum {
1112  CU_TENSOR_MAP_DATA_TYPE_UINT8 = 0,
1113  CU_TENSOR_MAP_DATA_TYPE_UINT16,
1114  CU_TENSOR_MAP_DATA_TYPE_UINT32,
1115  CU_TENSOR_MAP_DATA_TYPE_INT32,
1116  CU_TENSOR_MAP_DATA_TYPE_UINT64,
1117  CU_TENSOR_MAP_DATA_TYPE_INT64,
1118  CU_TENSOR_MAP_DATA_TYPE_FLOAT16,
1119  CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
1120  CU_TENSOR_MAP_DATA_TYPE_FLOAT64,
1121  CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
1122  CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ,
1123  CU_TENSOR_MAP_DATA_TYPE_TFLOAT32,
1124  CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ
1125  };
1126 
1127  if (type.isUnsignedInteger(8))
1128  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT8);
1129  if (type.isUnsignedInteger(16))
1130  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT16);
1131  if (type.isUnsignedInteger(32))
1132  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT32);
1133  if (type.isUnsignedInteger(64))
1134  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT64);
1135  if (type.isSignlessInteger(32))
1136  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT32);
1137  if (type.isSignlessInteger(64))
1138  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT64);
1139  if (type.isF16())
1140  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT16);
1141  if (type.isF32())
1142  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT32);
1143  if (type.isF64())
1144  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT64);
1145  if (type.isBF16())
1146  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16);
1147 
1148  llvm_unreachable("Not supported data type");
1149 }
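// For instance, an f16 element type maps to CU_TENSOR_MAP_DATA_TYPE_FLOAT16
// (enum value 6) and a signless i32 maps to CU_TENSOR_MAP_DATA_TYPE_INT32
// (enum value 3), each materialized as an i64 llvm.mlir.constant.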
1150 
1151 struct NVGPUTmaCreateDescriptorOpLowering
1152  : public ConvertOpToLLVMPattern<nvgpu::TmaCreateDescriptorOp> {
1153  using ConvertOpToLLVMPattern<
1154  nvgpu::TmaCreateDescriptorOp>::ConvertOpToLLVMPattern;
1155  LogicalResult
1156  matchAndRewrite(nvgpu::TmaCreateDescriptorOp op, OpAdaptor adaptor,
1157  ConversionPatternRewriter &rewriter) const override {
1158  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1159  auto llvmPointerType = LLVM::LLVMPointerType::get(op->getContext());
1160  Type llvmInt64Type = IntegerType::get(op->getContext(), 64);
1161 
1162  Value tensorElementType =
1163  elementTypeAsLLVMConstant(b, op.getTensor().getType().getElementType());
1164  auto promotedOperands = getTypeConverter()->promoteOperands(
1165  b.getLoc(), op->getOperands(), adaptor.getOperands(), b);
1166 
1167  Value boxArrayPtr = b.create<LLVM::AllocaOp>(llvmPointerType, llvmInt64Type,
1168  makeI64Const(b, 5));
1169  for (auto [index, value] : llvm::enumerate(adaptor.getBoxDimensions())) {
1170  Value gep = b.create<LLVM::GEPOp>(llvmPointerType, llvmPointerType,
1171  boxArrayPtr, makeI64Const(b, index));
1172  b.create<LLVM::StoreOp>(value, gep);
1173  }
1174 
1175  nvgpu::TensorMapDescriptorType desc = op.getTensorMap().getType();
1176  // Set Arguments for the function call
1177  SmallVector<Value> arguments;
1178  arguments.push_back(promotedOperands[0]); // rank
1179  arguments.push_back(promotedOperands[1]); // descriptor
1180  arguments.push_back(tensorElementType); // data type
1181  arguments.push_back(
1182  makeI64Const(b, (int)desc.getInterleave())); // interleave
1183  arguments.push_back(makeI64Const(b, (int)desc.getSwizzle())); // swizzle
1184  arguments.push_back(makeI64Const(b, (int)desc.getL2promo())); // l2promo
1185  arguments.push_back(makeI64Const(b, (int)desc.getOob())); // oob
1186  arguments.push_back(boxArrayPtr); // box dimensions
1187 
1188  // Set data types of the arguments
1189  SmallVector<Type> argTypes = {
1190  llvmInt64Type, /* int64_t tensorRank */
1191  llvmPointerType, /* ptr */
1192  llvmInt64Type, /* int64_t */
1193  llvmInt64Type, /* int64_t */
1194  llvmInt64Type, /* int64_t */
1195  llvmInt64Type, /* int64_t */
1196  llvmInt64Type, /* int64_t */
1197  llvmPointerType /* ptr */
1198  };
1199  FunctionCallBuilder hostRegisterCallBuilder = {
1200  "mgpuTensorMapEncodeTiledMemref", llvmPointerType, argTypes};
1201  Value tensorMap =
1202  hostRegisterCallBuilder.create(b.getLoc(), b, arguments).getResult();
1203 
1204  rewriter.replaceOp(op, tensorMap);
1205  return success();
1206  }
1207 };
1208 
1209 struct NVGPUWarpgroupMmaOpLowering
1210     : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp> {
1211   using ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp>::ConvertOpToLLVMPattern;
1212 
1213  /// This is a helper class to generate required NVVM Ops for warp-group level
1214  /// matrix multiplication.
1215   /// When the given GEMM shape is larger than the shape of a single wgmma
1216   /// instruction in PTX, it generates multiple NVVM::WgmmaMmaAsyncOp ops,
1217   /// groups them, and executes them asynchronously. The class also handles
1218   /// waiting for completion and iterates through the WarpgroupMatrixDescriptor
1219   /// to create descriptors for each instruction.
1220  ///
1221  /// For example this is the case when the shape of GEMM is 128x128x128
1222  ///
1223  /// nvvm.wgmma.fence.aligned
1224  ///
1225  /// nvvm.wgmma.mma.async descA, descB
1226  /// iterate(descA, descB)
1227  /// nvvm.wgmma.mma.async descA, descB
1228  /// [6x times more]
1229  ///
1230  /// nvvm.wgmma.group.sync.aligned
1231  /// nvvm.wgmma.wait.group.sync [groupId]
1232  ///
1233  class WarpgroupGemm {
1234  nvgpu::WarpgroupMmaOp op;
1235   ImplicitLocOpBuilder &b;
1236   OpAdaptor adaptor;
1237 
1238  // Entire shape of the given Op
1239  int64_t totalM, totalN, totalK;
1240 
1241  // Shape of one wgmma instruction
1242  int wgmmaM = 0, wgmmaN = 0, wgmmaK = 0;
1243 
1244  // Iteration counts for GEMM
1245  int iterationM = 0, iterationN = 0, iterationK = 0;
1246 
1247   /// Computes the shape of a single wgmma instruction, as defined in the
1248   /// PTX programming guide, and stores it in wgmmaM/N/K.
1249  /// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shape
1250  void findWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType) {
1251  wgmmaM = 64;
1252  wgmmaN = sizeN;
1253  if (inputElemType.isTF32()) {
1254  wgmmaK = 8;
1255  } else if (inputElemType.isF16() || inputElemType.isBF16()) {
1256  wgmmaK = 16;
1257  } else if (inputElemType.isFloat8E4M3FN() ||
1258  inputElemType.isFloat8E5M2() || inputElemType.isInteger(16)) {
1259  wgmmaK = 32;
1260  } else if (inputElemType.isInteger(1)) {
1261  wgmmaK = 256;
1262  } else {
1263  llvm_unreachable("msg: not supported K shape");
1264  }
1265  LLVM_DEBUG(DBGS() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
1266  << ", n = " << wgmmaN << ", k = " << wgmmaK << "]\n");
1267  }
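    // Example (hypothetical GEMM): for a 128x128x64 f16 problem this selects
    // wgmma.m64n128k16, and the surrounding WarpgroupGemm logic then issues
    // 2 x 1 x 4 = 8 WgmmaMmaAsyncOp in total (4 per 64-row accumulator tile).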
1268 
1269  /// Generates WGMMATypesAttr from MLIR Type
1270  NVVM::WGMMATypesAttr generateWgmmaType(Type type,
1271  bool useF32 = false) const {
1272  auto getWgmmaType = [=](Type elemType) {
1273  if (elemType.isF32() || elemType.isTF32())
1274  return useF32 ? NVVM::WGMMATypes::f32 : NVVM::WGMMATypes::tf32;
1275  if (elemType.isF16())
1276  return NVVM::WGMMATypes::f16;
1277  if (elemType.isBF16())
1278  return NVVM::WGMMATypes::bf16;
1279  if (elemType.isFloat8E4M3FN())
1280  return NVVM::WGMMATypes::e4m3;
1281  if (elemType.isFloat8E5M2())
1282  return NVVM::WGMMATypes::e5m2;
1283  if (elemType.isInteger(1))
1284  return NVVM::WGMMATypes::b1;
1285  if (elemType.isInteger(8))
1286  return NVVM::WGMMATypes::s8;
1287  if (elemType.isUnsignedInteger(8))
1288  return NVVM::WGMMATypes::u8;
1289  if (elemType.isInteger(32))
1290  return NVVM::WGMMATypes::s32;
1291  llvm_unreachable("unsupported type");
1292  };
1293  return NVVM::WGMMATypesAttr::get(op->getContext(), getWgmmaType(type));
1294  }
1295 
1296  /// Generates layout attribute for the input matrix for wgmma instruction
1297  NVVM::MMALayoutAttr
1298  generateWgmmaLayout(std::optional<bool> transpose) const {
1299  if (transpose.value_or(false))
1300  return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::col);
1301  return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::row);
1302  }
1303 
1304  /// Generates shape attribute for wgmma instruction
1305  NVVM::MMAShapeAttr generateWgmmaShape() const {
1306  return NVVM::MMAShapeAttr::get(op->getContext(), wgmmaM, wgmmaN, wgmmaK);
1307  }
1308 
1309  /// Generates scale attributes of output matrix for wgmma instruction
1310  NVVM::WGMMAScaleOutAttr generateScaleOut() const {
1311  return NVVM::WGMMAScaleOutAttr::get(op->getContext(),
1312  NVVM::WGMMAScaleOut::one);
1313  }
1314  /// Generates scale attributes of input matrix for wgmma instruction
1315  NVVM::WGMMAScaleInAttr generateScaleIn() const {
1316  return NVVM::WGMMAScaleInAttr::get(op->getContext(),
1317  NVVM::WGMMAScaleIn::one);
1318  }
1319 
1320  /// Basic function to generate Add
1321  Value makeAdd(Value lhs, Value rhs) {
1322  return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);
1323  };
1324 
1325  /// Moves the descriptor pointer of matrix-A for the next wgmma instruction.
1326  /// Currently, it only handles row-major.
1327  ///
1328  /// It moves the pointer like below for [128][64] size:
1329  /// +2 +4 +6
1330  /// ↓ ↓ ↓
1331  /// descA ---> +--+--+--+--+
1332  /// |->|->|->|->|
1333  /// | | | | |
1334  /// | | | | |
1335  /// | | | | |
1336  /// descA+512---> +-----------+
1337  /// | | | | |
1338  /// | | | | |
1339  /// | | | | |
1340  /// | | | | |
1341  /// +-----------+
1342  ///
1343  Value iterateDescriptorA(Value desc, int i, int j, int k) {
1344  MemRefType matrixTypeA = op.getDescriptorA().getType().getTensor();
1345  Type elemA = matrixTypeA.getElementType();
1346  int byte = elemA.getIntOrFloatBitWidth() / 8;
1347  int tileShapeA = matrixTypeA.getDimSize(1);
1348  int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte;
1349  incrementVal = incrementVal >> exclude4LSB;
1350  LLVM_DEBUG(DBGS() << "\t\t[m: " << i << " n: " << j << " k: " << k
1351  << "] [wgmma descriptors] Descriptor A + "
1352  << incrementVal << " | \t ");
1353  if (!incrementVal)
1354  return desc;
1355  return makeAdd(desc, makeI64Const(b, incrementVal));
1356  }
1357 
1358  /// Moves the descriptor pointer of matrix-B for the next wgmma instruction.
1359  /// Currently, it only handles column-major.
1360  ///
1361  /// It moves the pointer like below for [128][64] size:
1362  /// descB ---> +--+--+--+--+--+--+--+--+
1363  /// |↓ | | | | | | | |
1364  /// |↓ | | | | | | | |
1365  /// |↓ | | | | | | | |
1366  /// |↓ | | | | | | | |
1367  /// +--+--+--+--+--+--+--+--+
1368  ///
1369  Value iterateDescriptorB(Value desc, int i, int j, int k) {
1370  MemRefType matrixTypeB = op.getDescriptorB().getType().getTensor();
1371  Type elemB = matrixTypeB.getElementType();
1372  int byte = elemB.getIntOrFloatBitWidth() / 8;
1373  int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte;
1374  incrementVal = incrementVal >> exclude4LSB;
1375  LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n");
1376  if (!incrementVal)
1377  return desc;
1378  return makeAdd(desc, makeI64Const(b, incrementVal));
1379  }
1380 
1381  /// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix
1382  /// descriptors and arranges them based on induction variables: i, j, and k.
1383  Value generateWgmma(int i, int j, int k, Value matrixC) {
1384  LLVM_DEBUG(DBGS() << "\t wgmma."
1385  << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
1386  << "(A[" << (iterationM * wgmmaM) << ":"
1387  << (iterationM * wgmmaM) + wgmmaM << "]["
1388  << (iterationK * wgmmaK) << ":"
1389  << (iterationK * wgmmaK + wgmmaK) << "] * "
1390  << " B[" << (iterationK * wgmmaK) << ":"
1391  << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":"
1392  << wgmmaN << "])\n");
1393 
1394  Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k);
1395  Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k);
1396 
1397  Type elemA = op.getDescriptorA().getType().getTensor().getElementType();
1398  NVVM::WGMMATypesAttr itypeA = generateWgmmaType(elemA);
1399 
1400  Type elemB = op.getDescriptorB().getType().getTensor().getElementType();
1401  NVVM::WGMMATypesAttr itypeB = generateWgmmaType(elemB);
1402 
1403  Type elemD = op.getMatrixC().getType().getFragmented().getElementType();
1404  NVVM::WGMMATypesAttr itypeD = generateWgmmaType(elemD, true);
1405 
1406  NVVM::MMAShapeAttr shape = generateWgmmaShape();
1407  NVVM::WGMMAScaleOutAttr scaleOut = generateScaleOut();
1408  NVVM::WGMMAScaleInAttr scaleIn = generateScaleIn();
1409  NVVM::MMALayoutAttr layoutA = generateWgmmaLayout(op.getTransposeA());
1410  NVVM::MMALayoutAttr layoutB = generateWgmmaLayout(!op.getTransposeB());
1411 
1412  auto overflow = NVVM::MMAIntOverflowAttr::get(
1413  op->getContext(), NVVM::MMAIntOverflow::wrapped);
1414 
1415  return b.create<NVVM::WgmmaMmaAsyncOp>(
1416  matrixC.getType(), matrixC, descriptorA, descriptorB, shape, itypeA,
1417  itypeB, itypeD, scaleOut, scaleIn, scaleIn, layoutA, layoutB,
1418  overflow);
1419  }
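  // For illustration (assuming the 128x128x64 f16 GEMM used in the examples
  // above, i.e. a wgmma.m64n128k16 shape): generateWgmma(1, 0, 2, acc) emits a
  // single WgmmaMmaAsyncOp that multiplies the A tile [64:128][32:48] with the
  // B tile [32:48][0:128] and accumulates into `acc`, which is threaded
  // through as both accumulator input and result.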
1420 
1421  /// Generates multiple wgmma instructions to complete the given GEMM shape
1422  Value generateWgmmaGroup() {
1423  Value wgmmaResult =
1424  b.create<LLVM::UndefOp>(adaptor.getMatrixC().getType());
1425 
1426  // Perform GEMM
1427  SmallVector<Value> wgmmaResults;
1428  for (int i = 0; i < iterationM; ++i) {
1429  Value matrixC = b.create<LLVM::ExtractValueOp>(adaptor.getMatrixC(), i);
1430  for (int j = 0; j < iterationN; ++j)
1431  for (int k = 0; k < iterationK; ++k)
1432  matrixC = generateWgmma(i, j, k, matrixC);
1433  wgmmaResults.push_back(matrixC);
1434  }
1435  for (auto [idx, matrix] : llvm::enumerate(wgmmaResults)) {
1436  wgmmaResult = b.create<LLVM::InsertValueOp>(wgmmaResult.getType(),
1437  wgmmaResult, matrix, idx);
1438  }
1439  return wgmmaResult;
1440  }
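  // Shape of the generated code for the running 128x128x64 f16 example
  // (iterationM = 2, iterationN = 1, iterationK = 4; these counts assume the
  // wgmma.m64n128k16 shape mentioned above): two accumulator fragments are
  // extracted from matrixC, each is updated by a chain of 4 WgmmaMmaAsyncOps
  // (8 in total), and the results are packed back into a single struct.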
1441 
1442  public:
1443  WarpgroupGemm(nvgpu::WarpgroupMmaOp op, ImplicitLocOpBuilder &b,
1444  OpAdaptor adaptor)
1445  : op(op), b(b), adaptor(adaptor) {
1446  // Find the entire GEMM Shape
1447  totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
1448  totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
1449  totalK = op.getDescriptorA().getType().getTensor().getDimSize(1);
1450  LLVM_DEBUG(DBGS() << "===--- GEMM D[" << totalM << "][" << totalN
1451  << "] += A[" << totalM << "][" << totalK << "] * B["
1452  << totalK << "][" << totalN << "] ---===\n");
1453 
1454  // Find the shape for one wgmma instruction
1455  findWgmmaShape(
1456  totalM, totalN,
1457  op.getDescriptorA().getType().getTensor().getElementType());
1458 
 1459  // Iteration counts needed to cover the given GEMM shape with the wgmma shape
1460  iterationM = totalM / wgmmaM;
1461  iterationN = totalN / wgmmaN;
1462  iterationK = totalK / wgmmaK;
1463  }
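  // Worked example (illustrative): for D[128][128] += A[128][64] * B[64][128]
  // with f16 operands, and assuming findWgmmaShape (defined earlier in this
  // file, not shown here) selects wgmma.m64n128k16 with wgmmaM fixed to
  // kWgmmaSizeM = 64, the iteration counts become
  //   iterationM = 128 / 64  = 2
  //   iterationN = 128 / 128 = 1
  //   iterationK = 64  / 16  = 4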
1464 
 1465  /// Generates WgmmaMmaAsync Ops to complete the specified GEMM shape. It
 1466  /// emits a fence (WgmmaFenceAlignedOp) before the generated wgmma
 1467  /// instructions, commits them as a group (WgmmaGroupSyncAlignedOp) after
 1468  /// the instructions, and finally waits for the group to complete
 1469  /// (WgmmaWaitGroupSyncOp).
1470  Value generateWarpgroupMma() {
1471  b.create<NVVM::WgmmaFenceAlignedOp>();
1472  Value wgmmaResult = generateWgmmaGroup();
1473  b.create<NVVM::WgmmaGroupSyncAlignedOp>();
1474  b.create<NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup());
1475  return wgmmaResult;
1476  }
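  // The emitted op order is therefore:
  //   WgmmaFenceAlignedOp
  //   WgmmaMmaAsyncOp x (iterationM * iterationN * iterationK)
  //   WgmmaGroupSyncAlignedOp
  //   WgmmaWaitGroupSyncOp(op.getWaitGroup())
  // which corresponds to the PTX wgmma fence / mma_async / commit_group /
  // wait_group sequence (the PTX mapping is an informal note, not taken from
  // this file).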
1477  };
1478  LogicalResult
1479  matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
1480  ConversionPatternRewriter &rewriter) const override {
1481  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1482 
1483  // Step 1. Build a helper class
1484  WarpgroupGemm warpgroupGemm(op, b, adaptor);
1485 
 1486  // Step 2. Generate the wgmma instructions covering the entire GEMM shape
1487  Value wgmmaResult = warpgroupGemm.generateWarpgroupMma();
1488 
1489  // Step 3. Replace fragmented result struct with the op results
1490  rewriter.replaceOp(op, wgmmaResult);
1491  return success();
1492  }
1493 };
1494 
1495 struct NVGPUWarpgroupMmaStoreOpLowering
1496  : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaStoreOp> {
1497  using ConvertOpToLLVMPattern<
1498  nvgpu::WarpgroupMmaStoreOp>::ConvertOpToLLVMPattern;
1499 
 1500  /// This function stores a fragmented register matrix owned by a warp group
 1501  /// (128 threads) into a memref. Each thread holds 64 registers, passed to
 1502  /// this function as the elements of a struct.
 1503  /// Below is what each thread (T) holds; each `d` is a numbered struct
 1504  /// element.
1505  ///
 1506  /// Threads in the warp-group (128 threads) and what they own in matrixD:
 1507  /// 0-31 Warp-0 -> MatrixD[0:15 ][0:N]
 1508  /// 32-63 Warp-1 -> MatrixD[16:31][0:N]
 1509  /// 64-95 Warp-2 -> MatrixD[32:47][0:N]
 1510  /// 96-127 Warp-3 -> MatrixD[48:63][0:N]
1511  ///
1512  /// Matrix-D:
1513  /// +______________________________________________________________________+
1514  /// | 0-1 | 2-3 | 4-5 | 6-7 | 8-9 | 10-11|..|N-8,N-7 |
1515  /// 0 | T0:d0-d1 |T1:d0-d1 |T2:d0-d1 |T3:d0-d1 |T0:d4-d5| T1:d4-d5..|T0:dX-dY|
1516  /// 1 | T4:d0-d1 |T5:d0-d1 |T6:d0-d1 |T7:d0-d1 |T4:d4-d5| T5:d4-d5..|T4:dX-dY|
1517  /// ..| .........|.........|.........|.........|........|...........|........|
1518  /// 8 | T0:d2-d3 |T1:d2-d3 |T2:d2-d3 |T3:d2-d3 |T0:d6-d7|T1:d6-d7,..|T0:dZ-dW|
1519  /// 9 | T4:d2-d3 |T5:d2-d3 |T6:d2-d3 |T7:d2-d3 |T4:d6-d7| T5:d6-d7..|T4:dZ-dW|
1520  /// ..| .........|.........|.........|.........|........|...........|........|
1521  /// 15| T28:d2-d3|T29:d2-d3|T30:d2-d3|T31:d2-d3|........|...........|........|
1522  /// 16| T32:d2-d3|T33:d2-d3|T34:d2-d3|T35:d2-d3|........|...........|........|
1523  /// ..| .........|.........|.........|.........|........|...........|........|
1524  /// 32| T64:d2-d3|T65:d2-d3|T66:d2-d3|T67:d2-d3|........|...........|........|
1525  /// ..| .........|.........|.........|.........|........|...........|........|
1526  /// 48| T96:d2-d3|T97:d2-d3|T98:d2-d3|T99:d2-d3|........|...........|........|
1527  /// ..| .........|.........|.........|.........|........|...........|........|
1528  /// +______________________________________________________________________+
1529  ///
 1530  /// \param b: The builder used to create the store ops.
 1531  /// \param matrixD: Result of the warp-group MMA operation (fragmented
 1532  /// matrix). Each thread holds its fragment as a struct with 64 elements.
 1533  /// \param dstMemref: The memref where the registers will be stored.
 1534  /// \param offset: The row offset within the memref where the registers will
 1535  /// be stored.
1536  void storeFragmentedMatrix(ImplicitLocOpBuilder &b, Value matrixD,
1537  TypedValue<MemRefType> dstMemref,
1538  int offset) const {
1539  Type i32 = b.getI32Type();
1540 
1541  auto makeConst = [&](int32_t index) -> Value {
1542  return b.create<LLVM::ConstantOp>(i32, b.getI32IntegerAttr(index));
1543  };
1544  Value c1 = makeConst(1);
1545  Value c2 = makeConst(2);
1546  Value c4 = makeConst(4);
1547  Value c8 = makeConst(8);
1548  Value c16 = makeConst(16);
1549  Value warpSize = makeConst(kWarpSize);
1550 
1551  auto makeMul = [&](Value lhs, Value rhs) -> Value {
1552  return b.create<LLVM::MulOp>(lhs.getType(), lhs, rhs);
1553  };
1554  auto makeAdd = [&](Value lhs, Value rhs) -> Value {
1555  return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);
1556  };
1557 
1558  auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,
 1559  TypedValue<MemRefType> memref) {
 1560  Type it = b.getIndexType();
1561  Value idx = b.create<arith::IndexCastOp>(it, x);
1562  Value idy0 = b.create<arith::IndexCastOp>(it, y);
1563  Value idy1 = b.create<arith::IndexCastOp>(it, makeAdd(y, c1));
1564  Value d0 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i);
1565  Value d1 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i + 1);
1566  b.create<memref::StoreOp>(d0, memref, ValueRange{idx, idy0});
1567  b.create<memref::StoreOp>(d1, memref, ValueRange{idx, idy1});
1568  };
1569 
1570  Value tidx = b.create<NVVM::ThreadIdXOp>(i32);
1571  Value laneId = b.create<LLVM::URemOp>(i32, tidx, warpSize);
1572  Value warpId = b.create<LLVM::UDivOp>(i32, tidx, warpSize);
1573  Value lane4Id = b.create<LLVM::UDivOp>(i32, laneId, c4);
1574  Value lane4modId = b.create<LLVM::URemOp>(i32, laneId, c4);
1575 
1576  Value tj = makeMul(lane4modId, c2);
1577  Value ti = makeAdd(lane4Id, makeMul(warpId, c16));
1578  if (offset)
1579  ti = makeAdd(ti, makeConst(offset));
1580 
1581  auto structType = cast<LLVM::LLVMStructType>(matrixD.getType());
1582 
 1583  // Number of 32-bit registers owned per thread
1584  constexpr unsigned numAdjacentRegisters = 2;
1585  // Number of 8x8 matrices one below another per warp
1586  constexpr unsigned numStackedMatrices = 2;
1587 
1588  size_t storeCount = (structType.getBody().size() /
1589  (numStackedMatrices * numAdjacentRegisters));
1590 
1591  for (size_t i = 0; i < numStackedMatrices; ++i) {
1592  Value idx = makeAdd(ti, makeMul(makeConst(i), c8));
1593  for (size_t j = 0; j < storeCount; ++j) {
1594  Value idy = makeAdd(tj, makeMul(makeConst(j), c8));
1595  size_t structIndex = (i * numAdjacentRegisters) +
1596  (j * (numStackedMatrices * numAdjacentRegisters));
1597  makeExtractAndStore(structIndex, matrixD, idx, idy, dstMemref);
1598  }
1599  }
1600  }
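  // A worked example of the index math above for thread T1 (tidx = 1) of
  // warp 0, assuming offset == 0, matching the Matrix-D sketch in the
  // documentation comment:
  //   laneId = 1, warpId = 0, lane4Id = 0, lane4modId = 1  =>  ti = 0, tj = 2
  //   (i=0, j=0): structIndex 0 -> d0,d1 stored at [0][2], [0][3]
  //   (i=1, j=0): structIndex 2 -> d2,d3 stored at [8][2], [8][3]
  //   (i=0, j=1): structIndex 4 -> d4,d5 stored at [0][10], [0][11]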
1601 
1602  LogicalResult
1603  matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor,
1604  ConversionPatternRewriter &rewriter) const override {
1605  int offset = 0;
1606  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
 1607  Value matrixDValue = adaptor.getMatrixD();
 1608  auto stype = cast<LLVM::LLVMStructType>(matrixDValue.getType());
 1609  for (auto [idx, matrixD] : llvm::enumerate(stype.getBody())) {
 1610  auto structType = cast<LLVM::LLVMStructType>(matrixD);
 1611  Value innerStructValue = b.create<LLVM::ExtractValueOp>(matrixDValue, idx);
1612  storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset);
1613  offset += structType.getBody().size();
1614  }
1615  rewriter.eraseOp(op);
1616  return success();
1617  }
1618 };
1619 
1620 struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
1621  : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaInitAccumulatorOp> {
1622  using ConvertOpToLLVMPattern<
1623  nvgpu::WarpgroupMmaInitAccumulatorOp>::ConvertOpToLLVMPattern;
1624  LogicalResult
1625  matchAndRewrite(nvgpu::WarpgroupMmaInitAccumulatorOp op, OpAdaptor adaptor,
1626  ConversionPatternRewriter &rewriter) const override {
1627  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1628  LLVM::LLVMStructType packStructType = cast<LLVM::LLVMStructType>(
1629  getTypeConverter()->convertType(op.getMatrixC().getType()));
1630  Type elemType = cast<LLVM::LLVMStructType>(packStructType.getBody().front())
1631  .getBody()
1632  .front();
1633  Value zero = b.create<LLVM::ConstantOp>(elemType, b.getZeroAttr(elemType));
1634  Value packStruct = b.create<LLVM::UndefOp>(packStructType);
1635  SmallVector<Value> innerStructs;
1636  // Unpack the structs and set all values to zero
1637  for (auto [idx, s] : llvm::enumerate(packStructType.getBody())) {
1638  auto structType = cast<LLVM::LLVMStructType>(s);
1639  Value structValue = b.create<LLVM::ExtractValueOp>(packStruct, idx);
1640  for (unsigned i = 0; i < structType.getBody().size(); ++i) {
1641  structValue = b.create<LLVM::InsertValueOp>(
1642  structType, structValue, zero, ArrayRef<int64_t>({i}));
1643  }
1644  innerStructs.push_back(structValue);
1645  }
1646  // Pack the inner structs into a single struct
1647  for (auto [idx, matrix] : llvm::enumerate(innerStructs)) {
1648  packStruct = b.create<LLVM::InsertValueOp>(packStruct.getType(),
1649  packStruct, matrix, idx);
1650  }
1651  rewriter.replaceOp(op, packStruct);
1652  return success();
1653  }
1654 };
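// Illustrative shape of this lowering (the concrete converted type is an
// assumption, not taken from this file): if op.getMatrixC() converts to
//   !llvm.struct<(struct<(f32, f32, ..., f32)>, struct<(f32, f32, ..., f32)>)>
// the pattern extracts each inner struct from an undef outer struct, writes
// the f32 zero constant into every field, and re-inserts the inner structs,
// replacing the op with a fully zero-initialized accumulator value.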
1655 
1656 struct NVGPUTmaPrefetchOpLowering
1657  : public ConvertOpToLLVMPattern<nvgpu::TmaPrefetchOp> {
 1658  using ConvertOpToLLVMPattern<nvgpu::TmaPrefetchOp>::ConvertOpToLLVMPattern;
 1659  LogicalResult
1660  matchAndRewrite(nvgpu::TmaPrefetchOp op, OpAdaptor adaptor,
1661  ConversionPatternRewriter &rewriter) const override {
1662  rewriter.replaceOpWithNewOp<NVVM::PrefetchTensorMapOp>(
1663  op, adaptor.getTensorMapDescriptor(), adaptor.getPredicate());
1664  return success();
1665  }
1666 };
1667 
1668 struct NVGPURcpOpLowering : public ConvertOpToLLVMPattern<nvgpu::RcpOp> {
 1669  using ConvertOpToLLVMPattern<nvgpu::RcpOp>::ConvertOpToLLVMPattern;
 1670  LogicalResult
1671  matchAndRewrite(nvgpu::RcpOp op, OpAdaptor adaptor,
1672  ConversionPatternRewriter &rewriter) const override {
1673  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1674  auto i64Ty = b.getI64Type();
1675  auto f32Ty = b.getF32Type();
1676  VectorType inTy = op.getIn().getType();
1677  // apply rcp.approx.ftz.f on each element in vector.
1678  auto convert1DVec = [&](Type llvm1DVectorTy, Value inVec) {
1679  Value ret1DVec = b.create<LLVM::UndefOp>(llvm1DVectorTy);
1680  int numElems = llvm::cast<VectorType>(llvm1DVectorTy).getNumElements();
1681  for (int i = 0; i < numElems; i++) {
1682  Value idx = b.create<LLVM::ConstantOp>(i64Ty, b.getI64IntegerAttr(i));
1683  Value elem = b.create<LLVM::ExtractElementOp>(inVec, idx);
1684  Value dst = b.create<NVVM::RcpApproxFtzF32Op>(f32Ty, elem);
1685  ret1DVec = b.create<LLVM::InsertElementOp>(ret1DVec, dst, idx);
1686  }
1687  return ret1DVec;
1688  };
1689  if (inTy.getRank() == 1) {
1690  rewriter.replaceOp(op, convert1DVec(inTy, adaptor.getIn()));
1691  return success();
1692  }
 1693  return LLVM::detail::handleMultidimensionalVectors(
 1694  op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()),
1695  [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
1696  OpAdaptor adaptor(operands);
1697  return convert1DVec(llvm1DVectorTy, adaptor.getIn());
1698  },
1699  rewriter);
1700  }
1701 };
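// Illustrative result of this lowering (the vector shapes are assumptions):
// for a vector<4xf32> input the pattern emits, starting from llvm.mlir.undef,
// four extractelement / nvvm.rcp.approx.ftz.f / insertelement triples; for a
// multi-dimensional vector such as vector<2x4xf32>,
// handleMultidimensionalVectors applies the same 1-D expansion to each inner
// vector<4xf32>.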
1702 } // namespace
1703 
 1704 void mlir::populateNVGPUToNVVMConversionPatterns(
 1705  const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
1706  patterns.add<
1707  NVGPUMBarrierCreateLowering, // nvgpu.mbarrier.create
1708  NVGPUMBarrierInitLowering, // nvgpu.mbarrier.init
1709  NVGPUMBarrierArriveLowering, // nvgpu.mbarrier.arrive
1710  NVGPUMBarrierArriveNoCompleteLowering, // nvgpu.mbarrier.arrive.no_complete
1711  NVGPUMBarrierTestWaitLowering, // nvgpu.mbarrier.test_wait_parity
1712  NVGPUMBarrierTryWaitParityLowering, // nvgpu.mbarrier.try_wait_parity
1713  NVGPUTmaAsyncLoadOpLowering, // nvgpu.tma.async.load
1714  NVGPUTmaAsyncStoreOpLowering, // nvgpu.tma.async.store
1715  NVGPUTmaCreateDescriptorOpLowering, // nvgpu.tma.create.descriptor
1716  NVGPUTmaPrefetchOpLowering, // nvgpu.tma.prefetch.descriptor
1717  NVGPUMBarrierArriveExpectTxLowering, // nvgpu.mbarrier.arrive.expect_tx
1718  NVGPUGenerateWarpgroupDescriptorLowering, // nvgpu.warpgroup.generate.descriptor
1719  NVGPUWarpgroupMmaOpLowering, // nvgpu.warpgroup.mma
1720  NVGPUWarpgroupMmaStoreOpLowering, // nvgpu.warpgroup.mma.store
1721  NVGPUWarpgroupMmaInitAccumulatorOpLowering, // nvgpu.warpgroup.mma.init.accumulator
1722  MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
1723  NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
1724  NVGPUMmaSparseSyncLowering, NVGPURcpOpLowering>(converter);
1725 }
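// A minimal usage sketch, not part of this file: how a hypothetical conversion
// pass could drive these patterns from its runOnOperation(). The surrounding
// pass class, the legality setup, and any option handling are assumptions; the
// real ConvertNVGPUToNVVMPass generated from Passes.td performs additional
// type-converter setup beyond what is shown here.

void runOnOperation() {
  MLIRContext &ctx = getContext();
  // Default options; the converter maps nvgpu types to LLVM-compatible types.
  LLVMTypeConverter converter(&ctx);
  RewritePatternSet patterns(&ctx);
  populateNVGPUToNVVMConversionPatterns(converter, patterns);
  // Anything still in the NVGPU dialect after conversion is illegal.
  LLVMConversionTarget target(ctx);
  target.addLegalDialect<NVVM::NVVMDialect>();
  if (failed(applyPartialConversion(getOperation(), target,
                                    std::move(patterns))))
    signalPassFailure();
}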