1 //===- NVGPUToNVVM.cpp - NVGPU to NVVM dialect conversion -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
23 #include "mlir/IR/BuiltinTypes.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/IR/TypeUtilities.h"
26 #include "mlir/IR/Value.h"
27 #include "mlir/Pass/Pass.h"
28 #include "llvm/Support/Debug.h"
29 #include "llvm/Support/ErrorHandling.h"
30 #include "llvm/Support/raw_ostream.h"
31 #include <optional>
32 
33 #define DEBUG_TYPE "nvgpu-to-nvvm"
34 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
35 #define DBGSE() (llvm::dbgs())
36 
37 namespace mlir {
38 #define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS
39 #include "mlir/Conversion/Passes.h.inc"
40 } // namespace mlir
41 
42 using namespace mlir;
43 
44 /// Number of bits that need to be excluded when building the matrix descriptor
45 /// for wgmma operations.
46 constexpr int exclude4LSB = 4;
47 
48 /// GPUs have 32-bit registers; this function truncates values when a larger
49 /// width is not needed.
50 static Value truncToI32(ImplicitLocOpBuilder &b, Value value) {
51  Type type = value.getType();
52  assert(llvm::isa<IntegerType>(type) && "expected an integer Value");
53  if (type.getIntOrFloatBitWidth() <= 32)
54  return value;
55  return LLVM::TruncOp::create(b, b.getI32Type(), value);
56 }
57 
58 /// Returns the type for the intrinsic given the vectorResultType of the
59 /// `gpu.mma.sync` operation.
60 static Type inferIntrinsicResultType(Type vectorResultType) {
61  MLIRContext *ctx = vectorResultType.getContext();
62  auto a = cast<LLVM::LLVMArrayType>(vectorResultType);
63  auto f16x2Ty = VectorType::get(2, Float16Type::get(ctx));
64  auto i32Ty = IntegerType::get(ctx, 32);
65  auto i32x2Ty = VectorType::get(2, i32Ty);
66  Type f64Ty = Float64Type::get(ctx);
67  Type f64x2Ty = VectorType::get(2, f64Ty);
68  Type f32Ty = Float32Type::get(ctx);
69  Type f32x2Ty = VectorType::get(2, f32Ty);
70  if (a.getElementType() == f16x2Ty) {
71  return LLVM::LLVMStructType::getLiteral(
72  ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty));
73  }
74  if (a.getElementType() == i32x2Ty) {
75  return LLVM::LLVMStructType::getLiteral(
76  ctx,
77  SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, i32Ty));
78  }
79  if (a.getElementType() == f64x2Ty) {
80  return LLVM::LLVMStructType::getLiteral(ctx, {f64Ty, f64Ty});
81  }
82  if (a.getElementType() == f32x2Ty) {
83  return LLVM::LLVMStructType::getLiteral(
84  ctx,
85  SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, f32Ty));
86  }
87  if (a.getElementType() == VectorType::get(1, f32Ty)) {
88  return LLVM::LLVMStructType::getLiteral(
89  ctx, SmallVector<Type>(static_cast<size_t>(a.getNumElements()), f32Ty));
90  }
91  return vectorResultType;
92 }
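// Illustrative examples of the mapping above (not from the original source): an
// `!llvm.array<2 x vector<2xf16>>` result maps to the intrinsic result type
// `!llvm.struct<(vector<2xf16>, vector<2xf16>)>`, while an
// `!llvm.array<2 x vector<2xi32>>` result maps to
// `!llvm.struct<(i32, i32, i32, i32)>`, because the intrinsic returns those
// 32-bit halves as individual scalars.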
93 
94 /// Convert the SSA result of the NVVM intrinsic `nvvm.mma.sync` (which is
95 /// always an LLVM struct) into a fragment that is compatible with the vector
96 /// type of this operation. This involves extracting elements from the struct
97 /// and inserting them into an LLVM array. These extra data-movement
98 /// operations should be canonicalized away by the LLVM backend.
99 static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
100  Type resultType, Value intrinsicResult,
101  RewriterBase &rewriter) {
102  MLIRContext *ctx = rewriter.getContext();
103  auto structType = dyn_cast<LLVM::LLVMStructType>(intrinsicResultType);
104  auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType);
105  Type i32Ty = rewriter.getI32Type();
106  Type f32Ty = rewriter.getF32Type();
107  Type f64Ty = rewriter.getF64Type();
108  Type f16x2Ty = VectorType::get(2, rewriter.getF16Type());
109  Type i32x2Ty = VectorType::get(2, i32Ty);
110  Type f64x2Ty = VectorType::get(2, f64Ty);
111  Type f32x2Ty = VectorType::get(2, f32Ty);
112  Type f32x1Ty = VectorType::get(1, f32Ty);
113 
114  auto makeConst = [&](int32_t index) -> Value {
115  return LLVM::ConstantOp::create(rewriter, loc, IntegerType::get(ctx, 32),
116  rewriter.getI32IntegerAttr(index));
117  };
118 
119  if (arrayType) {
120  SmallVector<Value, 4> elements;
121 
122  // The intrinsic returns 32-bit wide elements in a form which can be
123  // directly bitcasted and inserted into the result vector.
124  if (arrayType.getElementType() == f16x2Ty ||
125  arrayType.getElementType() == f32x1Ty) {
126  for (unsigned i = 0; i < structType.getBody().size(); i++) {
127  Value el =
128  LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult, i);
129  el = rewriter.createOrFold<LLVM::BitcastOp>(
130  loc, arrayType.getElementType(), el);
131  elements.push_back(el);
132  }
133  }
134 
135  // The intrinsic returns i32, f64, and f32 values as individual scalars,
136  // even when the result is notionally a 64-bit wide element (e.g. f32x2). We
137  // need to extract them from the struct and pack them into the 64-bit wide
138  // rows of the vector result.
139  if (arrayType.getElementType() == i32x2Ty ||
140  arrayType.getElementType() == f64x2Ty ||
141  arrayType.getElementType() == f32x2Ty) {
142 
143  for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) {
144  Value vec =
145  LLVM::PoisonOp::create(rewriter, loc, arrayType.getElementType());
146  Value x1 =
147  LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult, i * 2);
148  Value x2 = LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult,
149  i * 2 + 1);
150  vec = LLVM::InsertElementOp::create(rewriter, loc, vec.getType(), vec,
151  x1, makeConst(0));
152  vec = LLVM::InsertElementOp::create(rewriter, loc, vec.getType(), vec,
153  x2, makeConst(1));
154  elements.push_back(vec);
155  }
156  }
157 
158  // Create the final vectorized result.
159  Value result = LLVM::PoisonOp::create(rewriter, loc, arrayType);
160  for (const auto &el : llvm::enumerate(elements)) {
161  result = LLVM::InsertValueOp::create(rewriter, loc, result, el.value(),
162  el.index());
163  }
164  return result;
165  }
166 
167  return intrinsicResult;
168 }
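// For example (illustrative, not part of the original file): when the op's
// result converts to `!llvm.array<2 x vector<2xi32>>`, the intrinsic yields an
// `!llvm.struct<(i32, i32, i32, i32)>`; the code above extracts the four
// scalars, packs consecutive pairs into `vector<2xi32>` values with
// insertelement, and inserts those rows into the result array.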
169 
170 /// The `gpu.mma.sync` converter below expects matrix fragment operands to be
171 /// given as 2D `vectors` where the rows are 32b or 64b wide. The
172 /// `nvvm.mma.sync` op expects these arguments to be given as a long list of
173 /// scalars of certain types. This function helps unpack the `vector` arguments
174 /// and cast them to the types expected by `nvvm.mma.sync`.
175 static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b,
176  Value operand,
177  NVVM::MMATypes operandPtxType) {
178  SmallVector<Value> result;
179  Type i32Ty = b.getI32Type();
180  Type f64Ty = b.getF64Type();
181  Type f32Ty = b.getF32Type();
182  Type i64Ty = b.getI64Type();
183  Type i8x4Ty = VectorType::get(4, b.getI8Type());
184  Type i4x8Ty = VectorType::get(8, b.getIntegerType(4));
185  Type f32x1Ty = VectorType::get(1, f32Ty);
186  auto arrayTy = cast<LLVM::LLVMArrayType>(operand.getType());
187 
188  for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
189  Value toUse = LLVM::ExtractValueOp::create(b, operand, i);
190 
191  // For 4xi8 and 8xi4 vectors (and f32 operands used as tf32), the intrinsic
192  // expects these to be provided as i32 scalars.
193  if (arrayTy.getElementType() == i8x4Ty ||
194  arrayTy.getElementType() == i4x8Ty ||
195  (arrayTy.getElementType() == f32x1Ty &&
196  operandPtxType == NVVM::MMATypes::tf32)) {
197  result.push_back(LLVM::BitcastOp::create(b, i32Ty, toUse));
198  continue;
199  }
200 
201  // For some element types (i32, f32, f64), we need to unpack the inner
202  // vector/array type as well because the intrinsic expects individual
203  // scalars to be provided.
204  VectorType innerArrayTy = dyn_cast<VectorType>(arrayTy.getElementType());
205  if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty ||
206  innerArrayTy.getElementType() == f64Ty ||
207  innerArrayTy.getElementType() == f32Ty)) {
208  for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
209  idx < innerSize; idx++) {
210  result.push_back(LLVM::ExtractElementOp::create(
211  b, toUse,
212  LLVM::ConstantOp::create(b, i64Ty, b.getI64IntegerAttr(idx))));
213  }
214  continue;
215  }
216  result.push_back(toUse);
217  }
218  return result;
219 }
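// A couple of illustrative cases (not from the original source): an operand of
// type `!llvm.array<2 x vector<2xf16>>` is unpacked into two `vector<2xf16>`
// values and passed through unchanged; `!llvm.array<1 x vector<4xi8>>` becomes
// a single i32 via bitcast; and `!llvm.array<1 x vector<2xf64>>` is expanded
// into two f64 scalars with extractelement.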
220 
221 /// Returns whether the mbarrier object has a shared memory address space.
222 static bool isMbarrierShared(nvgpu::MBarrierGroupType barrierType) {
223  return (mlir::nvgpu::NVGPUDialect::isSharedMemoryAddressSpace(
224  barrierType.getMemorySpace()));
225 }
226 
227 /// Returns the memory space attribute of the mbarrier object.
228 Attribute nvgpu::getMbarrierMemorySpace(MLIRContext *context,
229  nvgpu::MBarrierGroupType barrierType) {
230  Attribute memorySpace = {};
231  if (isMbarrierShared(barrierType)) {
232  memorySpace =
233  IntegerAttr::get(IntegerType::get(context, 64),
234  nvgpu::NVGPUDialect::kSharedMemoryAddressSpace);
235  }
236  return memorySpace;
237 }
238 
239 /// Returns the memref type of the mbarrier object. The type is derived from
240 /// the MBarrierGroupType.
241 MemRefType nvgpu::getMBarrierMemrefType(MLIRContext *context,
242  nvgpu::MBarrierGroupType barrierType) {
243  Attribute memorySpace = nvgpu::getMbarrierMemorySpace(context, barrierType);
244  MemRefLayoutAttrInterface layout;
245  return MemRefType::get({barrierType.getNumBarriers()},
246  IntegerType::get(context, 64), layout, memorySpace);
247 }
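// Illustrative example (not from the original file): a barrier group typed as
// `!nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>,
// num_barriers = 4>` lowers to a `memref<4xi64>` placed in the shared memory
// address space (3), i.e. one 64-bit mbarrier word per barrier.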
248 
249 namespace {
250 
251 struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
252  using ConvertOpToLLVMPattern<nvgpu::LdMatrixOp>::ConvertOpToLLVMPattern;
253 
254  LogicalResult
255  matchAndRewrite(nvgpu::LdMatrixOp op, OpAdaptor adaptor,
256  ConversionPatternRewriter &rewriter) const override {
257  MLIRContext *ctx = getContext();
258  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
259 
260  // The result type of ldmatrix will always be a struct of 32bit integer
261  // registers if more than one 32bit value is returned. Otherwise, the result
262  // is a single i32. The result type of the GPU operation is always a vector
263  // of shape (NumRegisters, VectorRegister) where VectorRegister is the
264  // vector type of the result and always 32 bits long. We bitcast the result
265  // of the NVVM::LdMatrix to this vector type.
266  auto vectorResultType = dyn_cast<VectorType>(op->getResultTypes()[0]);
267  if (!vectorResultType) {
268  return failure();
269  }
270  Type innerVectorType = VectorType::get(vectorResultType.getDimSize(1),
271  vectorResultType.getElementType());
272 
273  int64_t num32BitRegs = vectorResultType.getDimSize(0);
274 
275  Type ldMatrixResultType;
276  if (num32BitRegs > 1) {
277  ldMatrixResultType = LLVM::LLVMStructType::getLiteral(
278  ctx, SmallVector<Type>(num32BitRegs, rewriter.getI32Type()));
279  } else {
280  ldMatrixResultType = rewriter.getI32Type();
281  }
282 
283  auto srcMemrefType = cast<MemRefType>(op.getSrcMemref().getType());
284  Value srcPtr =
285  getStridedElementPtr(rewriter, b.getLoc(), srcMemrefType,
286  adaptor.getSrcMemref(), adaptor.getIndices());
287  Value ldMatrixResult = NVVM::LdMatrixOp::create(
288  b, ldMatrixResultType, srcPtr,
289  /*num=*/op.getNumTiles(),
290  /*layout=*/op.getTranspose() ? NVVM::MMALayout::col
291  : NVVM::MMALayout::row);
292 
293  // The ldmatrix operation returns either a single i32 value or a struct of
294  // i32 values. Here we unpack those values and cast them back to their
295  // actual vector type (still of width 32b) and repack them into a result
296  // struct.
297  Type finalResultType = typeConverter->convertType(vectorResultType);
298  Value result = LLVM::PoisonOp::create(b, finalResultType);
299  for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) {
300  Value i32Register =
301  num32BitRegs > 1 ? LLVM::ExtractValueOp::create(b, ldMatrixResult, i)
302  : ldMatrixResult;
303  Value casted = LLVM::BitcastOp::create(b, innerVectorType, i32Register);
304  result = LLVM::InsertValueOp::create(b, result, casted, i);
305  }
306 
307  rewriter.replaceOp(op, result);
308  return success();
309  }
310 };
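// As an illustration (not part of the original file): an `nvgpu.ldmatrix` that
// produces `vector<4x2xf16>` lowers to an `nvvm.ldmatrix` with num = 4 whose
// result is an `!llvm.struct<(i32, i32, i32, i32)>`; each i32 is then bitcast
// back to `vector<2xf16>` and inserted into the converted 2-D result.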
311 
312 /// Convert the given type into the corresponding PTX type (NVVM::MMATypes
313 /// enum).
314 static FailureOr<NVVM::MMATypes> getNvvmMmaType(Type t) {
315  Type elType = getElementTypeOrSelf(t);
316  if (elType.isInteger(8))
317  return NVVM::MMATypes::s8;
318  if (elType.isInteger(4))
319  return NVVM::MMATypes::s4;
320  if (elType.isF16())
321  return NVVM::MMATypes::f16;
322  if (elType.isF64())
323  return NVVM::MMATypes::f64;
324  if (elType.isF32())
325  return NVVM::MMATypes::tf32;
326  return failure();
327 }
328 
329 struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
330  using ConvertOpToLLVMPattern<nvgpu::MmaSyncOp>::ConvertOpToLLVMPattern;
331 
332  LogicalResult
333  matchAndRewrite(nvgpu::MmaSyncOp op, OpAdaptor adaptor,
334  ConversionPatternRewriter &rewriter) const override {
335  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
336  // Get the shapes of the MMAMatrix type being used. The shapes determine
337  // which intrinsic this op will be lowered to.
338  VectorType aType = op.getMatrixA().getType();
339  VectorType bType = op.getMatrixB().getType();
340  VectorType cType = op.getMatrixC().getType();
341 
342  std::array<int64_t, 3> gemmShape = op.getMmaShapeAsArray();
343 
344  // Tensor Cores (mma.sync) on F32 works only with TensorFloat32 (TF32).
345  bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
346  if (aType.getElementType().isF32() && !tf32Enabled)
347  return failure();
348 
349  FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
350  if (failed(ptxTypeA))
351  return op->emitOpError("failed to deduce operand PTX types");
352  FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
353  if (failed(ptxTypeB))
354  return op->emitOpError("failed to deduce operand PTX types");
355  std::optional<NVVM::MMATypes> ptxTypeC =
356  NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
357  /*isAccumulator=*/true);
358  if (!ptxTypeC)
359  return op->emitError(
360  "could not infer the PTX type for the accumulator/result");
361 
362  // TODO: add an attribute to the op to customize this behavior.
363  std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
364  if (isa<IntegerType>(aType.getElementType()))
365  overflow = NVVM::MMAIntOverflow::satfinite;
366 
367  SmallVector<Value> matA =
368  unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
369  SmallVector<Value> matB =
370  unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
371  SmallVector<Value> matC =
372  unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);
373 
374  Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
375  Type intrinsicResTy = inferIntrinsicResultType(
376  typeConverter->convertType(op->getResultTypes()[0]));
377  Value intrinsicResult =
378  NVVM::MmaOp::create(b, intrinsicResTy, matA, matB, matC,
379  /*shape=*/gemmShape,
380  /*b1Op=*/std::nullopt,
381  /*intOverflow=*/overflow,
382  /*multiplicandPtxTypes=*/
383  std::array<NVVM::MMATypes, 2>{*ptxTypeA, *ptxTypeB},
384  /*multiplicandLayouts=*/
385  std::array<NVVM::MMALayout, 2>{
386  NVVM::MMALayout::row, NVVM::MMALayout::col});
387  rewriter.replaceOp(op, convertIntrinsicResult(op.getLoc(), intrinsicResTy,
388  desiredRetTy, intrinsicResult,
389  rewriter));
390  return success();
391  }
392 };
393 
394 struct ConvertNVGPUToNVVMPass
395  : public impl::ConvertNVGPUToNVVMPassBase<ConvertNVGPUToNVVMPass> {
396  using Base::Base;
397 
398  void getDependentDialects(DialectRegistry &registry) const override {
399  registry.insert<memref::MemRefDialect, LLVM::LLVMDialect, NVVM::NVVMDialect,
400  arith::ArithDialect>();
401  }
402 
403  void runOnOperation() override {
404  LowerToLLVMOptions options(&getContext());
405  RewritePatternSet patterns(&getContext());
406  LLVMTypeConverter converter(&getContext(), options);
407  IRRewriter rewriter(&getContext());
408  populateGpuMemorySpaceAttributeConversions(
409  converter, [](gpu::AddressSpace space) -> unsigned {
410  switch (space) {
411  case gpu::AddressSpace::Global:
412  return static_cast<unsigned>(
413  NVVM::NVVMMemorySpace::kGlobalMemorySpace);
414  case gpu::AddressSpace::Workgroup:
415  return static_cast<unsigned>(
416  NVVM::NVVMMemorySpace::kSharedMemorySpace);
417  case gpu::AddressSpace::Private:
418  return 0;
419  }
420  llvm_unreachable("unknown address space enum value");
421  return 0;
422  });
423  /// device-side async tokens cannot be materialized in nvvm. We just
424  /// convert them to a dummy i32 type in order to easily drop them during
425  /// conversion.
426  converter.addConversion([&](nvgpu::DeviceAsyncTokenType type) -> Type {
427  return converter.convertType(IntegerType::get(type.getContext(), 32));
428  });
429  converter.addConversion([&](nvgpu::WarpgroupAccumulatorType type) -> Type {
430  Type elemType = type.getFragmented().getElementType();
431  int64_t sizeM = type.getFragmented().getDimSize(0);
432  int64_t sizeN = type.getFragmented().getDimSize(1);
433 
434  unsigned numMembers;
435  if (elemType.isF32() || elemType.isInteger(32))
436  numMembers = sizeN / 2;
437  else if (elemType.isF16())
438  numMembers = sizeN / 4;
439  else
440  llvm_unreachable("unsupported type for warpgroup accumulator");
441 
442  SmallVector<Type> innerStructBody;
443  for (unsigned i = 0; i < numMembers; i++)
444  innerStructBody.push_back(elemType);
445  auto innerStructType =
446  LLVM::LLVMStructType::getLiteral(type.getContext(), innerStructBody);
447 
448  SmallVector<Type> structBody;
449  for (int i = 0; i < sizeM; i += kWgmmaSizeM)
450  structBody.push_back(innerStructType);
451 
452  auto convertedType =
453  LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
454  return converter.convertType(convertedType);
455  });
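 // Worked example (illustrative): a warpgroup accumulator fragment of shape
 // 128x128xf32 gives numMembers = 128 / 2 = 64 and 128 / kWgmmaSizeM (= 64)
 // = 2 inner structs, so the converted type is a struct of two
 // `!llvm.struct<(f32, f32, ..., f32)>` members with 64 f32 fields each.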
456  converter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {
457  return converter.convertType(IntegerType::get(type.getContext(), 64));
458  });
459  converter.addConversion(
460  [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type {
461  return converter.convertType(IntegerType::get(type.getContext(), 64));
462  });
463  converter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type {
464  return converter.convertType(
465  nvgpu::getMBarrierMemrefType(rewriter.getContext(), type));
466  });
467  converter.addConversion([&](nvgpu::TensorMapDescriptorType type) -> Type {
468  return LLVM::LLVMPointerType::get(type.getContext());
469  });
470  populateNVGPUToNVVMConversionPatterns(converter, patterns);
471  LLVMConversionTarget target(getContext());
472  target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
473  target.addLegalDialect<::mlir::arith::ArithDialect>();
474  target.addLegalDialect<::mlir::memref::MemRefDialect>();
475  target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
476  mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
477  converter, patterns, target);
478  if (failed(applyPartialConversion(getOperation(), target,
479  std::move(patterns))))
480  signalPassFailure();
481  }
482 };
483 
484 /// Returns the constraints for the sparse MMA inline assembly instruction.
485 static std::string buildMmaSparseAsmConstraintString(unsigned matASize,
486  unsigned matBSize,
487  unsigned matCSize) {
488  std::string str;
489  llvm::raw_string_ostream ss(str);
490  for (unsigned i = 0; i < matCSize; i++)
491  ss << "=r,";
492  for (unsigned i = 0; i < matASize + matBSize + matCSize; i++)
493  ss << "r,";
494  // The final operand is for the sparsity metadata.
495  // The sparsity selector appears as a direct literal.
496  ss << "r";
497  return str;
498 }
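// For example (illustrative sizes, not from the original source): with
// matASize = 4, matBSize = 2, and matCSize = 2 the constraint string is
// "=r,=r,r,r,r,r,r,r,r,r,r" -- two outputs for the accumulator, eight inputs
// for A/B/C, and a trailing "r" for the sparsity metadata operand.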
499 
500 /// Returns the string for the `mma.sp.sync` instruction that corresponds to
501 /// the given parameters. Note that this function doesn't do any validation,
502 /// it's expected that the provided parameters correspond to a valid
503 /// instruction.
504 static std::string buildMmaSparseAsmString(
505  const std::array<int64_t, 3> &shape, unsigned matASize, unsigned matBSize,
506  unsigned matCSize, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
507  NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
508  std::optional<NVVM::MMAIntOverflow> overflow, unsigned metaDataSelector) {
509  auto ptxTypeStr = [](NVVM::MMATypes ptxType) {
510  return NVVM::stringifyMMATypes(ptxType);
511  };
512 
513  std::string asmStr;
514  llvm::raw_string_ostream ss(asmStr);
515  ss << "mma.sp.sync.aligned.m" << shape[0] << "n" << shape[1] << "k"
516  << shape[2] << ".row.col.";
517 
518  if (overflow)
519  ss << NVVM::stringifyMMAIntOverflow(*overflow) << ".";
520 
521  ss << ptxTypeStr(ptxTypeD) << "." << ptxTypeStr(ptxTypeA) << "."
522  << ptxTypeStr(ptxTypeB) << "." << ptxTypeStr(ptxTypeC) << " ";
523  unsigned asmArgIdx = 0;
524 
525  // The operand string is structured into sections `{matC elements...},
526  // {matA elements...}, {matB elements...}, {matC elements}`.
527  for (const auto arrSize : {matCSize, matASize, matBSize, matCSize}) {
528  ss << "{";
529  for (unsigned i = 0; i < arrSize; i++)
530  ss << "$" << asmArgIdx++ << (i < arrSize - 1 ? "," : "");
531  ss << "},";
532  }
533  ss << "$" << asmArgIdx++ << ",";
534  assert(metaDataSelector <= 1);
535  ss << "0x" << metaDataSelector << ";";
536  return asmStr;
537 }
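// An illustrative result (register counts chosen for the example, not taken
// from the original source): for shape m16n8k32 with s8 A/B, s32 C/D,
// satfinite overflow, matASize = 4, matBSize = 2, matCSize = 4, and
// metaDataSelector = 0, the builder produces:
//   mma.sp.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32
//     {$0,$1,$2,$3},{$4,$5,$6,$7},{$8,$9},{$10,$11,$12,$13},$14,0x0;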
538 
539 /// Builds an inline assembly operation corresponding to the specified MMA
540 /// sparse sync operation.
541 static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
542  ImplicitLocOpBuilder &b, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
543  NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
544  std::optional<NVVM::MMAIntOverflow> overflow, ArrayRef<Value> unpackedAData,
545  ArrayRef<Value> unpackedB, ArrayRef<Value> unpackedC, Value indexData,
546  int64_t metadataSelector, const std::array<int64_t, 3> &shape,
547  Type intrinsicResultType) {
548  auto asmDialectAttr =
549  LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_ATT);
550 
551  const unsigned matASize = unpackedAData.size();
552  const unsigned matBSize = unpackedB.size();
553  const unsigned matCSize = unpackedC.size();
554 
555  std::string asmStr = buildMmaSparseAsmString(
556  shape, matASize, matBSize, matCSize, ptxTypeA, ptxTypeB, ptxTypeC,
557  ptxTypeD, overflow, metadataSelector);
558  std::string constraintStr =
559  buildMmaSparseAsmConstraintString(matASize, matBSize, matCSize);
560 
561  SmallVector<Value> asmVals;
562  asmVals.reserve(matASize + matBSize + matCSize + 1);
563  for (ArrayRef<Value> args : {unpackedAData, unpackedB, unpackedC})
564  llvm::append_range(asmVals, args);
565  asmVals.push_back(indexData);
566 
567  return LLVM::InlineAsmOp::create(b,
568  /*resultTypes=*/intrinsicResultType,
569  /*operands=*/asmVals,
570  /*asm_string=*/asmStr,
571  /*constraints=*/constraintStr,
572  /*has_side_effects=*/true,
573  /*is_align_stack=*/false,
575  /*asm_dialect=*/asmDialectAttr,
576  /*operand_attrs=*/ArrayAttr());
577 }
578 
579 /// Lowers `nvgpu.mma.sp.sync` to inline assembly.
580 struct NVGPUMmaSparseSyncLowering
581  : public ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp> {
582  using ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp>::ConvertOpToLLVMPattern;
583 
584  LogicalResult
585  matchAndRewrite(nvgpu::MmaSparseSyncOp op, OpAdaptor adaptor,
586  ConversionPatternRewriter &rewriter) const override {
587  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
588  // Get the shapes of the MMAMatrix type being used. The shapes determine
589  // which intrinsic this op will be lowered to.
590  VectorType aType = op.getMatrixA().getType();
591  VectorType bType = op.getMatrixB().getType();
592  VectorType cType = op.getMatrixC().getType();
593 
594  FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
595  if (failed(ptxTypeA))
596  return op->emitOpError("failed to deduce operand PTX types");
597  FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
598  if (failed(ptxTypeB))
599  return op->emitOpError("failed to deduce operand PTX types");
600  std::optional<NVVM::MMATypes> ptxTypeC =
601  NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
602  /*isAccumulator=*/true);
603  if (!ptxTypeC)
604  return op->emitError(
605  "could not infer the PTX type for the accumulator/result");
606 
607  // Same as `mma.sync`, F32 works only with TensorFloat32 (TF32).
608  bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
609  if (aType.getElementType().isF32() && !tf32Enabled)
610  return failure();
611 
612  // TODO: add an attribute to the op to customize this behavior.
613  std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
614  if (isa<IntegerType>(aType.getElementType()))
615  overflow = NVVM::MMAIntOverflow::satfinite;
616 
617  SmallVector<Value> matA =
618  unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
619  SmallVector<Value> matB =
620  unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
621  SmallVector<Value> matC =
622  unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);
623 
624  Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
625  Type intrinsicResTy = inferIntrinsicResultType(
626  typeConverter->convertType(op->getResultTypes()[0]));
627 
628  // Bitcast the sparse metadata from vector<2xf16> to an i32.
629  Value sparseMetadata = adaptor.getSparseMetadata();
630  if (sparseMetadata.getType() != VectorType::get(2, rewriter.getI16Type()))
631  return op->emitOpError() << "Expected metadata type to be LLVM "
632  "VectorType of 2 i16 elements";
633  sparseMetadata =
634  LLVM::BitcastOp::create(b, rewriter.getI32Type(), sparseMetadata);
635 
636  FailureOr<LLVM::InlineAsmOp> intrinsicResult = emitMmaSparseSyncOpAsm(
637  b, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB,
638  matC, sparseMetadata, op.getSparsitySelector(), op.getMmaShapeAsArray(),
639  intrinsicResTy);
640  if (failed(intrinsicResult))
641  return failure();
642 
643  assert((*intrinsicResult).getNumResults() == 1 &&
644  "expected inline asm op returns a single LLVM struct type");
645  rewriter.replaceOp(
646  op, convertIntrinsicResult(op.getLoc(), intrinsicResTy, desiredRetTy,
647  (*intrinsicResult)->getResult(0), rewriter));
648  return success();
649  }
650 };
651 
652 struct NVGPUAsyncCopyLowering
653  : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCopyOp> {
654  using ConvertOpToLLVMPattern<
655  nvgpu::DeviceAsyncCopyOp>::ConvertOpToLLVMPattern;
656 
657  LogicalResult
658  matchAndRewrite(nvgpu::DeviceAsyncCopyOp op, OpAdaptor adaptor,
659  ConversionPatternRewriter &rewriter) const override {
660  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
661  Location loc = op.getLoc();
662  auto dstMemrefType = cast<MemRefType>(op.getDst().getType());
663  Value dstPtr =
664  getStridedElementPtr(rewriter, b.getLoc(), dstMemrefType,
665  adaptor.getDst(), adaptor.getDstIndices());
666  FailureOr<unsigned> dstAddressSpace =
667  getTypeConverter()->getMemRefAddressSpace(dstMemrefType);
668  if (failed(dstAddressSpace))
669  return rewriter.notifyMatchFailure(
670  loc, "destination memref address space not convertible to integer");
671 
672  auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
673  FailureOr<unsigned> srcAddressSpace =
674  getTypeConverter()->getMemRefAddressSpace(srcMemrefType);
675  if (failed(srcAddressSpace))
676  return rewriter.notifyMatchFailure(
677  loc, "source memref address space not convertible to integer");
678 
679  Value scrPtr =
680  getStridedElementPtr(rewriter, loc, srcMemrefType, adaptor.getSrc(),
681  adaptor.getSrcIndices());
682  // The intrinsic takes a global pointer, so we need an address space cast.
683  auto srcPointerGlobalType = LLVM::LLVMPointerType::get(
684  op->getContext(), NVVM::NVVMMemorySpace::kGlobalMemorySpace);
685  scrPtr = LLVM::AddrSpaceCastOp::create(b, srcPointerGlobalType, scrPtr);
686  int64_t dstElements = adaptor.getDstElements().getZExtValue();
687  int64_t sizeInBytes =
688  (dstMemrefType.getElementTypeBitWidth() * dstElements) / 8;
689  // When the optional SrcElements argument is *not* present, the regular
690  // CpAsyncOp is generated. CpAsyncOp reads bytes from the source (global
691  // memory) to fill DstElements number of elements in the destination
692  // (shared memory).
693  Value srcBytes = adaptor.getSrcElements();
694  if (srcBytes) {
695  // When the optional SrcElements argument is present, the source (global
696  // memory) of CpAsyncOp is read only for SrcElements number of elements.
697  // The rest of the DstElements in the destination (shared memory) are
698  // filled with zeros.
699  Value c3I32 =
700  LLVM::ConstantOp::create(b, b.getI32Type(), b.getI32IntegerAttr(3));
701  Value bitwidth = LLVM::ConstantOp::create(
702  b, b.getI32Type(),
703  b.getI32IntegerAttr(srcMemrefType.getElementTypeBitWidth()));
704  Value srcElementsI32 = LLVM::TruncOp::create(b, b.getI32Type(), srcBytes);
705  srcBytes = LLVM::LShrOp::create(
706  b, LLVM::MulOp::create(b, bitwidth, srcElementsI32), c3I32);
707  }
708  // Cache global (.cg) for 16 dst bytes, Cache all (.ca) for sizes other than
709  // 16 dst bytes.
710  NVVM::LoadCacheModifierKind cacheModifier =
711  (op.getBypassL1().value_or(false) && sizeInBytes == 16)
712  ? NVVM::LoadCacheModifierKind::CG
713  : NVVM::LoadCacheModifierKind::CA;
714 
715  NVVM::CpAsyncOp::create(
716  b, dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes),
717  NVVM::LoadCacheModifierKindAttr::get(op->getContext(), cacheModifier),
718  srcBytes);
719 
720  // Drop the result token.
721  Value zero =
722  LLVM::ConstantOp::create(b, IntegerType::get(op.getContext(), 32),
723  rewriter.getI32IntegerAttr(0));
724  rewriter.replaceOp(op, zero);
725  return success();
726  }
727 };
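// Illustrative lowering (not from the original source): copying 8 f16 elements
// gives sizeInBytes = (16 * 8) / 8 = 16, so with bypassL1 set the op becomes an
// nvvm.cp.async with the .cg (cache-global) modifier; any other size, or
// bypassL1 absent, falls back to the .ca (cache-all) modifier.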
728 
729 struct NVGPUAsyncCreateGroupLowering
730  : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCreateGroupOp> {
731  using ConvertOpToLLVMPattern<
732  nvgpu::DeviceAsyncCreateGroupOp>::ConvertOpToLLVMPattern;
733 
734  LogicalResult
735  matchAndRewrite(nvgpu::DeviceAsyncCreateGroupOp op, OpAdaptor adaptor,
736  ConversionPatternRewriter &rewriter) const override {
737  NVVM::CpAsyncCommitGroupOp::create(rewriter, op.getLoc());
738  // Drop the result token.
739  Value zero = LLVM::ConstantOp::create(rewriter, op->getLoc(),
740  IntegerType::get(op.getContext(), 32),
741  rewriter.getI32IntegerAttr(0));
742  rewriter.replaceOp(op, zero);
743  return success();
744  }
745 };
746 
747 struct NVGPUAsyncWaitLowering
748  : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncWaitOp> {
749  using ConvertOpToLLVMPattern<
750  nvgpu::DeviceAsyncWaitOp>::ConvertOpToLLVMPattern;
751 
752  LogicalResult
753  matchAndRewrite(nvgpu::DeviceAsyncWaitOp op, OpAdaptor adaptor,
754  ConversionPatternRewriter &rewriter) const override {
755  // If numGroups is not present, pick 0 as a conservatively correct value.
756  int32_t numGroups = adaptor.getNumGroups().value_or(0);
757  NVVM::CpAsyncWaitGroupOp::create(rewriter, op.getLoc(), numGroups);
758  rewriter.eraseOp(op);
759  return success();
760  }
761 };
762 
763 /// Creates mbarrier object in shared memory
764 struct NVGPUMBarrierCreateLowering
765  : public ConvertOpToLLVMPattern<nvgpu::MBarrierCreateOp> {
766  using ConvertOpToLLVMPattern<nvgpu::MBarrierCreateOp>::ConvertOpToLLVMPattern;
767 
768  template <typename moduleT>
769  memref::GlobalOp generateGlobalBarrier(ConversionPatternRewriter &rewriter,
770  Operation *funcOp, moduleT moduleOp,
771  MemRefType barrierType) const {
772  SymbolTable symbolTable(moduleOp);
773  OpBuilder::InsertionGuard guard(rewriter);
774  rewriter.setInsertionPoint(&moduleOp.front());
775  auto global = memref::GlobalOp::create(
776  rewriter, funcOp->getLoc(), "__mbarrier",
777  /*sym_visibility=*/rewriter.getStringAttr("private"),
778  /*type=*/barrierType,
779  /*initial_value=*/ElementsAttr(),
780  /*constant=*/false,
781  /*alignment=*/rewriter.getI64IntegerAttr(8));
782  symbolTable.insert(global);
783  return global;
784  }
785 
786  LogicalResult
787  matchAndRewrite(nvgpu::MBarrierCreateOp op, OpAdaptor adaptor,
788  ConversionPatternRewriter &rewriter) const override {
789  Operation *funcOp = op->getParentOp();
790  MemRefType barrierType = nvgpu::getMBarrierMemrefType(
791  rewriter.getContext(), op.getBarriers().getType());
792 
793  memref::GlobalOp global;
794  if (auto moduleOp = funcOp->getParentOfType<gpu::GPUModuleOp>())
795  global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
796  else if (auto moduleOp = funcOp->getParentOfType<ModuleOp>())
797  global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
798 
799  rewriter.setInsertionPoint(op);
800  rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(op, barrierType,
801  global.getName());
802  return success();
803  }
804 };
805 
806 /// Base class for lowering mbarrier operations to nvvm intrinsics.
807 template <typename SourceOp>
808 struct MBarrierBasePattern : public ConvertOpToLLVMPattern<SourceOp> {
809 public:
810  using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
811  /// Returns the base pointer of the mbarrier object.
812  Value getMbarrierPtr(ImplicitLocOpBuilder &b,
813  nvgpu::MBarrierGroupType mbarType, Value memrefDesc,
814  Value mbarId,
815  ConversionPatternRewriter &rewriter) const {
816  MemRefType mbarrierMemrefType =
817  nvgpu::getMBarrierMemrefType(rewriter.getContext(), mbarType);
818  return ConvertToLLVMPattern::getStridedElementPtr(
819  rewriter, b.getLoc(), mbarrierMemrefType, memrefDesc, {mbarId});
820  }
821 };
822 
823 struct NVGPUMBarrierGetLowering
824  : public MBarrierBasePattern<nvgpu::MBarrierGetOp> {
825  using MBarrierBasePattern<nvgpu::MBarrierGetOp>::MBarrierBasePattern;
826 
827  LogicalResult
828  matchAndRewrite(nvgpu::MBarrierGetOp op, OpAdaptor adaptor,
829  ConversionPatternRewriter &rewriter) const override {
830  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
831  nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType();
832  rewriter.setInsertionPoint(op);
833  Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(),
834  adaptor.getMbarId(), rewriter);
835  Type resType = op.getMbarrierPointer().getType();
836  rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(op, resType, barrier);
837  return success();
838  }
839 };
840 
841 /// Lowers `nvgpu.mbarrier.init` to `nvvm.mbarrier.init`
842 struct NVGPUMBarrierInitLowering
843  : public MBarrierBasePattern<nvgpu::MBarrierInitOp> {
844  using MBarrierBasePattern<nvgpu::MBarrierInitOp>::MBarrierBasePattern;
845 
846  LogicalResult
847  matchAndRewrite(nvgpu::MBarrierInitOp op, OpAdaptor adaptor,
848  ConversionPatternRewriter &rewriter) const override {
849  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
850  nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType();
851  rewriter.setInsertionPoint(op);
852  Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(),
853  adaptor.getMbarId(), rewriter);
854  Value count = truncToI32(b, adaptor.getCount());
855  if (isMbarrierShared(mbarrierType)) {
856  rewriter.replaceOpWithNewOp<NVVM::MBarrierInitSharedOp>(
857  op, barrier, count, adaptor.getPredicate());
858  } else {
859  rewriter.replaceOpWithNewOp<NVVM::MBarrierInitOp>(op, barrier, count,
860  adaptor.getPredicate());
861  }
862  return success();
863  }
864 };
865 
866 /// Lowers `nvgpu.mbarrier.arrive` to `nvvm.mbarrier.arrive`
867 struct NVGPUMBarrierArriveLowering
868  : public MBarrierBasePattern<nvgpu::MBarrierArriveOp> {
869  using MBarrierBasePattern<nvgpu::MBarrierArriveOp>::MBarrierBasePattern;
870  LogicalResult
871  matchAndRewrite(nvgpu::MBarrierArriveOp op, OpAdaptor adaptor,
872  ConversionPatternRewriter &rewriter) const override {
873  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
874  Value barrier =
875  getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
876  adaptor.getMbarId(), rewriter);
877  Type tokenType = getTypeConverter()->convertType(
878  nvgpu::MBarrierTokenType::get(op->getContext()));
879  if (isMbarrierShared(op.getBarriers().getType())) {
880  rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveSharedOp>(op, tokenType,
881  barrier);
882  } else {
883  rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveOp>(op, tokenType,
884  barrier);
885  }
886  return success();
887  }
888 };
889 
890 /// Lowers `nvgpu.mbarrier.arrive.nocomplete` to
891 /// `nvvm.mbarrier.arrive.nocomplete`
892 struct NVGPUMBarrierArriveNoCompleteLowering
893  : public MBarrierBasePattern<nvgpu::MBarrierArriveNoCompleteOp> {
894  using MBarrierBasePattern<
895  nvgpu::MBarrierArriveNoCompleteOp>::MBarrierBasePattern;
896  LogicalResult
897  matchAndRewrite(nvgpu::MBarrierArriveNoCompleteOp op, OpAdaptor adaptor,
898  ConversionPatternRewriter &rewriter) const override {
899  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
900  Value barrier =
901  getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
902  adaptor.getMbarId(), rewriter);
903  Type tokenType = getTypeConverter()->convertType(
904  nvgpu::MBarrierTokenType::get(op->getContext()));
905  Value count = truncToI32(b, adaptor.getCount());
906  if (isMbarrierShared(op.getBarriers().getType())) {
907  rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteSharedOp>(
908  op, tokenType, barrier, count);
909  } else {
910  rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteOp>(
911  op, tokenType, barrier, count);
912  }
913  return success();
914  }
915 };
916 
917 /// Lowers `nvgpu.mbarrier.test.wait` to `nvvm.mbarrier.test.wait`
918 struct NVGPUMBarrierTestWaitLowering
919  : public MBarrierBasePattern<nvgpu::MBarrierTestWaitOp> {
920  using MBarrierBasePattern<nvgpu::MBarrierTestWaitOp>::MBarrierBasePattern;
921  LogicalResult
922  matchAndRewrite(nvgpu::MBarrierTestWaitOp op, OpAdaptor adaptor,
923  ConversionPatternRewriter &rewriter) const override {
924  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
925  Value barrier =
926  getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
927  adaptor.getMbarId(), rewriter);
928  Type retType = rewriter.getI1Type();
929  if (isMbarrierShared(op.getBarriers().getType())) {
930  rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitSharedOp>(
931  op, retType, barrier, adaptor.getToken());
932  } else {
933  rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitOp>(
934  op, retType, barrier, adaptor.getToken());
935  }
936  return success();
937  }
938 };
939 
940 struct NVGPUMBarrierArriveExpectTxLowering
941  : public MBarrierBasePattern<nvgpu::MBarrierArriveExpectTxOp> {
942  using MBarrierBasePattern<
943  nvgpu::MBarrierArriveExpectTxOp>::MBarrierBasePattern;
944  LogicalResult
945  matchAndRewrite(nvgpu::MBarrierArriveExpectTxOp op, OpAdaptor adaptor,
946  ConversionPatternRewriter &rewriter) const override {
947  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
948  Value barrier =
949  getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
950  adaptor.getMbarId(), rewriter);
951  Value txcount = truncToI32(b, adaptor.getTxcount());
952 
953  if (isMbarrierShared(op.getBarriers().getType())) {
954  rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxSharedOp>(
955  op, barrier, txcount, adaptor.getPredicate());
956  return success();
957  }
958 
959  rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>(
960  op, barrier, txcount, adaptor.getPredicate());
961  return success();
962  }
963 };
964 
965 struct NVGPUMBarrierTryWaitParityLowering
966  : public MBarrierBasePattern<nvgpu::MBarrierTryWaitParityOp> {
967  using MBarrierBasePattern<
968  nvgpu::MBarrierTryWaitParityOp>::MBarrierBasePattern;
969  LogicalResult
970  matchAndRewrite(nvgpu::MBarrierTryWaitParityOp op, OpAdaptor adaptor,
971  ConversionPatternRewriter &rewriter) const override {
972  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
973  Value barrier =
974  getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
975  adaptor.getMbarId(), rewriter);
976  Value ticks = truncToI32(b, adaptor.getTicks());
977  Value phase =
978  LLVM::ZExtOp::create(b, b.getI32Type(), adaptor.getPhaseParity());
979 
980  if (isMbarrierShared(op.getBarriers().getType())) {
981  rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParitySharedOp>(
982  op, barrier, phase, ticks);
983  return success();
984  }
985 
986  rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParityOp>(op, barrier,
987  phase, ticks);
988  return success();
989  }
990 };
991 
992 struct NVGPUTmaAsyncLoadOpLowering
993  : public MBarrierBasePattern<nvgpu::TmaAsyncLoadOp> {
994  using MBarrierBasePattern<nvgpu::TmaAsyncLoadOp>::MBarrierBasePattern;
995  LogicalResult
996  matchAndRewrite(nvgpu::TmaAsyncLoadOp op, OpAdaptor adaptor,
997  ConversionPatternRewriter &rewriter) const override {
998  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
999  auto srcMemrefType = cast<MemRefType>(op.getDst().getType());
1000  Value dest = getStridedElementPtr(rewriter, op->getLoc(), srcMemrefType,
1001  adaptor.getDst(), {});
1002  Value barrier =
1003  getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
1004  adaptor.getMbarId(), rewriter);
1005 
1006  SmallVector<Value> coords = adaptor.getCoordinates();
1007  for (auto [index, value] : llvm::enumerate(coords)) {
1008  coords[index] = truncToI32(b, value);
1009  }
1010  rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(
1011  op, dest, adaptor.getTensorMapDescriptor(), coords, barrier,
1012  ValueRange{}, adaptor.getMulticastMask(), Value{},
1013  adaptor.getPredicate());
1014  return success();
1015  }
1016 };
1017 
1018 struct NVGPUTmaAsyncStoreOpLowering
1019  : public MBarrierBasePattern<nvgpu::TmaAsyncStoreOp> {
1020  using MBarrierBasePattern<nvgpu::TmaAsyncStoreOp>::MBarrierBasePattern;
1021  LogicalResult
1022  matchAndRewrite(nvgpu::TmaAsyncStoreOp op, OpAdaptor adaptor,
1023  ConversionPatternRewriter &rewriter) const override {
1024  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1025  auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
1026  Value dest = getStridedElementPtr(rewriter, op->getLoc(), srcMemrefType,
1027  adaptor.getSrc(), {});
1028  SmallVector<Value> coords = adaptor.getCoordinates();
1029  for (auto [index, value] : llvm::enumerate(coords)) {
1030  coords[index] = truncToI32(b, value);
1031  }
1032 
1033  rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(
1034  op, adaptor.getTensorMapDescriptor(), dest, coords,
1035  adaptor.getPredicate());
1036  return success();
1037  }
1038 };
1039 
1040 struct NVGPUGenerateWarpgroupDescriptorLowering
1041  : public ConvertOpToLLVMPattern<nvgpu::WarpgroupGenerateDescriptorOp> {
1042  using ConvertOpToLLVMPattern<
1043  nvgpu::WarpgroupGenerateDescriptorOp>::ConvertOpToLLVMPattern;
1044 
1045  LogicalResult
1046  matchAndRewrite(nvgpu::WarpgroupGenerateDescriptorOp op, OpAdaptor adaptor,
1047  ConversionPatternRewriter &rewriter) const override {
1048 
1049  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1050 
1051  nvgpu::TensorMapSwizzleKind swizzleKind =
1052  op.getTensorMap().getType().getSwizzle();
1053 
1054  unsigned layout =
1055  (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 128
1056  : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 64
1057  : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 32
1058  : 1;
1059  unsigned swizzle =
1060  (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 1
1061  : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 2
1062  : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 3
1063  : 0;
1064 
1065  auto ti64 = b.getIntegerType(64);
1066  auto makeConst = [&](uint64_t index) -> Value {
1067  return LLVM::ConstantOp::create(b, ti64, b.getI64IntegerAttr(index));
1068  };
1069  auto shiftLeft = [&](Value value, unsigned shift) -> Value {
1070  return LLVM::ShlOp::create(b, ti64, value, makeConst(shift));
1071  };
1072  auto shiftRight = [&](Value value, unsigned shift) -> Value {
1073  return LLVM::LShrOp::create(b, ti64, value, makeConst(shift));
1074  };
1075  auto insertBit = [&](Value desc, Value val, int startBit) {
1076  return LLVM::OrOp::create(b, ti64, desc, shiftLeft(val, startBit));
1077  };
1078 
1079  int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0);
1080  uint64_t strideDimVal = (layout << 3) >> exclude4LSB;
1081  uint64_t leadDimVal = (sizeN * layout) >> exclude4LSB;
1082  uint64_t offsetVal = 0;
1083 
1084  Value strideDim = makeConst(strideDimVal);
1085  Value leadDim = makeConst(leadDimVal);
1086 
1087  Value baseAddr = getStridedElementPtr(
1088  rewriter, op->getLoc(), cast<MemRefType>(op.getTensor().getType()),
1089  adaptor.getTensor(), {});
1090  Value basePtr = LLVM::PtrToIntOp::create(b, ti64, baseAddr);
1091  // Just use 14 bits for base address
1092  Value basePtr14bit = shiftRight(shiftLeft(basePtr, 46), 50);
1093 
1094  int startSwizzleBit = 62, startOffsetBit = 49, startStrideBit = 32,
1095  startLeadBit = 16, startBaseAddrBit = 0;
1096  Value dsc = makeConst(0);
1097  // [62,64) swizzle type
1098  dsc = insertBit(dsc, makeConst(swizzle), startSwizzleBit);
1099  // [49,52) base_offset
1100  dsc = insertBit(dsc, makeConst(offsetVal), startOffsetBit);
1101  // [32,46) stride
1102  dsc = insertBit(dsc, strideDim, startStrideBit);
1103  // [16,30) leading dimension
1104  dsc = insertBit(dsc, leadDim, startLeadBit);
1105  // [0,14) start_address
1106  dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit);
1107 
1108  LLVM_DEBUG(DBGS() << "Generating warpgroup.descriptor: "
1109  << "leading_off:" << leadDimVal << "\t"
1110  << "stride_off :" << strideDimVal << "\t"
1111  << "base_offset:" << offsetVal << "\t"
1112  << "layout_type:" << swizzle << " ("
1113  << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
1114  << ")\n start_addr : " << baseAddr << "\n");
1115 
1116  rewriter.replaceOp(op, dsc);
1117  return success();
1118  }
1119 };
1120 
1121 static Value makeI64Const(ImplicitLocOpBuilder &b, int32_t index) {
1122  return LLVM::ConstantOp::create(b, b.getIntegerType(64),
1123  b.getI32IntegerAttr(index));
1124 }
1125 
1126 /// Returns a Value holding the data type enum that is expected by the CUDA
1127 /// driver.
1127 static Value elementTypeAsLLVMConstant(ImplicitLocOpBuilder &b, Type type) {
1128  // Enum is from CUDA driver API
1129  // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html
1130  enum CUtensorMapDataTypeEnum {
1131  CU_TENSOR_MAP_DATA_TYPE_UINT8 = 0,
1132  CU_TENSOR_MAP_DATA_TYPE_UINT16,
1133  CU_TENSOR_MAP_DATA_TYPE_UINT32,
1134  CU_TENSOR_MAP_DATA_TYPE_INT32,
1135  CU_TENSOR_MAP_DATA_TYPE_UINT64,
1136  CU_TENSOR_MAP_DATA_TYPE_INT64,
1137  CU_TENSOR_MAP_DATA_TYPE_FLOAT16,
1138  CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
1139  CU_TENSOR_MAP_DATA_TYPE_FLOAT64,
1140  CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
1141  CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ,
1142  CU_TENSOR_MAP_DATA_TYPE_TFLOAT32,
1143  CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ
1144  };
1145 
1146  if (type.isUnsignedInteger(8))
1147  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT8);
1148  if (type.isUnsignedInteger(16))
1149  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT16);
1150  if (type.isUnsignedInteger(32))
1151  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT32);
1152  if (type.isUnsignedInteger(64))
1153  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT64);
1154  if (type.isSignlessInteger(32))
1155  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT32);
1156  if (type.isSignlessInteger(64))
1157  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT64);
1158  if (type.isF16())
1159  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT16);
1160  if (type.isF32())
1161  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT32);
1162  if (type.isF64())
1163  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT64);
1164  if (type.isBF16())
1165  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16);
1166 
1167  llvm_unreachable("Not supported data type");
1168 }
1169 
1170 struct NVGPUTmaCreateDescriptorOpLowering
1171  : public ConvertOpToLLVMPattern<nvgpu::TmaCreateDescriptorOp> {
1172  using ConvertOpToLLVMPattern<
1173  nvgpu::TmaCreateDescriptorOp>::ConvertOpToLLVMPattern;
1174  LogicalResult
1175  matchAndRewrite(nvgpu::TmaCreateDescriptorOp op, OpAdaptor adaptor,
1176  ConversionPatternRewriter &rewriter) const override {
1177  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1178  auto llvmPointerType = LLVM::LLVMPointerType::get(op->getContext());
1179  Type llvmInt64Type = IntegerType::get(op->getContext(), 64);
1180 
1181  Value tensorElementType =
1182  elementTypeAsLLVMConstant(b, op.getTensor().getType().getElementType());
1183  auto promotedOperands = getTypeConverter()->promoteOperands(
1184  b.getLoc(), op->getOperands(), adaptor.getOperands(), b);
1185 
1186  Value boxArrayPtr = LLVM::AllocaOp::create(
1187  b, llvmPointerType, llvmInt64Type, makeI64Const(b, 5));
1188  for (auto [index, value] : llvm::enumerate(adaptor.getBoxDimensions())) {
1189  Value gep = LLVM::GEPOp::create(b, llvmPointerType, llvmPointerType,
1190  boxArrayPtr, makeI64Const(b, index));
1191  LLVM::StoreOp::create(b, value, gep);
1192  }
1193 
1194  nvgpu::TensorMapDescriptorType desc = op.getTensorMap().getType();
1195  // Set Arguments for the function call
1196  SmallVector<Value> arguments;
1197  arguments.push_back(promotedOperands[0]); // rank
1198  arguments.push_back(promotedOperands[1]); // descriptor
1199  arguments.push_back(tensorElementType); // data type
1200  arguments.push_back(
1201  makeI64Const(b, (int)desc.getInterleave())); // interleave
1202  arguments.push_back(makeI64Const(b, (int)desc.getSwizzle())); // swizzle
1203  arguments.push_back(makeI64Const(b, (int)desc.getL2promo())); // l2promo
1204  arguments.push_back(makeI64Const(b, (int)desc.getOob())); // oob
1205  arguments.push_back(boxArrayPtr); // box dimensions
1206 
1207  // Set data types of the arguments
1208  SmallVector<Type> argTypes = {
1209  llvmInt64Type, /* int64_t tensorRank */
1210  llvmPointerType, /* ptr */
1211  llvmInt64Type, /* int64_t */
1212  llvmInt64Type, /* int64_t */
1213  llvmInt64Type, /* int64_t */
1214  llvmInt64Type, /* int64_t */
1215  llvmInt64Type, /* int64_t */
1216  llvmPointerType /* ptr */
1217  };
1218  FunctionCallBuilder hostRegisterCallBuilder = {
1219  "mgpuTensorMapEncodeTiledMemref", llvmPointerType, argTypes};
1220  Value tensorMap =
1221  hostRegisterCallBuilder.create(b.getLoc(), b, arguments).getResult();
1222 
1223  rewriter.replaceOp(op, tensorMap);
1224  return success();
1225  }
1226 };
1227 
1228 struct NVGPUWarpgroupMmaOpLowering
1229  : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp> {
1230  using ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp>::ConvertOpToLLVMPattern;
1231 
1232  /// This is a helper class to generate required NVVM Ops for warp-group level
1233  /// matrix multiplication.
1234  /// When the given GEMM shape is larger than the shape of
1235  /// a wgmma instruction in PTX, it can generate multiple NVVM::WgmmaMmaAsyncOp
1236  /// Op(s), group them, and execute them asynchronously. The class also handles
1237  /// waiting for completion and iterates through WarpgroupMatrixDescriptor to
1238  /// create descriptors for each instruction.
1239  ///
1240  /// For example this is the case when the shape of GEMM is 128x128x128
1241  ///
1242  /// nvvm.wgmma.fence.aligned
1243  ///
1244  /// nvvm.wgmma.mma.async descA, descB
1245  /// iterate(descA, descB)
1246  /// nvvm.wgmma.mma.async descA, descB
1247  /// [6x times more]
1248  ///
1249  /// nvvm.wgmma.group.sync.aligned
1250  /// nvvm.wgmma.wait.group.sync [groupId]
1251  ///
1252  class WarpgroupGemm {
1253  nvgpu::WarpgroupMmaOp op;
1254  ImplicitLocOpBuilder &b;
1255  OpAdaptor adaptor;
1256 
1257  // Entire shape of the given Op
1258  int64_t totalM, totalN, totalK;
1259 
1260  // Shape of one wgmma instruction
1261  int wgmmaM = 0, wgmmaN = 0, wgmmaK = 0;
1262 
1263  // Iteration counts for GEMM
1264  int iterationM = 0, iterationN = 0, iterationK = 0;
1265 
1266  /// Returns the shape of the wgmma instruction that is defined in the
1267  /// PTX programming guide.
1268  /// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shape
1269  void findWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType) {
1270  wgmmaM = 64;
1271  wgmmaN = sizeN;
1272  if (inputElemType.isTF32()) {
1273  wgmmaK = 8;
1274  } else if (inputElemType.isF16() || inputElemType.isBF16()) {
1275  wgmmaK = 16;
1276  } else if (isa<Float8E4M3FNType, Float8E5M2Type>(inputElemType) ||
1277  inputElemType.isInteger(16)) {
1278  wgmmaK = 32;
1279  } else if (inputElemType.isInteger(1)) {
1280  wgmmaK = 256;
1281  } else {
1282  llvm_unreachable("msg: not supported K shape");
1283  }
1284  LLVM_DEBUG(DBGS() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
1285  << ", n = " << wgmmaN << ", k = " << wgmmaK << "]\n");
1286  }
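 // Worked example (illustrative): for f16 inputs and a 128x128x64 GEMM, each
 // wgmma tile is m64.n128.k16, so the constructor below computes
 // iterationM = 128/64 = 2, iterationN = 128/128 = 1, and iterationK =
 // 64/16 = 4, i.e. eight NVVM::WgmmaMmaAsyncOp operations in total.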
1287 
1288  /// Generates WGMMATypesAttr from MLIR Type
1289  NVVM::WGMMATypesAttr generateWgmmaType(Type type,
1290  bool useF32 = false) const {
1291  auto getWgmmaType = [=](Type elemType) {
1292  if (elemType.isF32() || elemType.isTF32())
1293  return useF32 ? NVVM::WGMMATypes::f32 : NVVM::WGMMATypes::tf32;
1294  if (elemType.isF16())
1295  return NVVM::WGMMATypes::f16;
1296  if (elemType.isBF16())
1297  return NVVM::WGMMATypes::bf16;
1298  if (isa<Float8E4M3FNType>(elemType))
1299  return NVVM::WGMMATypes::e4m3;
1300  if (isa<Float8E5M2Type>(elemType))
1301  return NVVM::WGMMATypes::e5m2;
1302  if (elemType.isInteger(1))
1303  return NVVM::WGMMATypes::b1;
1304  if (elemType.isInteger(8))
1305  return NVVM::WGMMATypes::s8;
1306  if (elemType.isUnsignedInteger(8))
1307  return NVVM::WGMMATypes::u8;
1308  if (elemType.isInteger(32))
1309  return NVVM::WGMMATypes::s32;
1310  llvm_unreachable("unsupported type");
1311  };
1312  return NVVM::WGMMATypesAttr::get(op->getContext(), getWgmmaType(type));
1313  }
1314 
1315  /// Generates layout attribute for the input matrix for wgmma instruction
1316  NVVM::MMALayoutAttr
1317  generateWgmmaLayout(std::optional<bool> transpose) const {
1318  if (transpose.value_or(false))
1319  return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::col);
1320  return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::row);
1321  }
1322 
1323  /// Generates shape attribute for wgmma instruction
1324  NVVM::MMAShapeAttr generateWgmmaShape() const {
1325  return NVVM::MMAShapeAttr::get(op->getContext(), wgmmaM, wgmmaN, wgmmaK);
1326  }
1327 
1328  /// Generates scale attributes of output matrix for wgmma instruction
1329  NVVM::WGMMAScaleOutAttr generateScaleOut() const {
1330  return NVVM::WGMMAScaleOutAttr::get(op->getContext(),
1331  NVVM::WGMMAScaleOut::one);
1332  }
1333  /// Generates scale attributes of input matrix for wgmma instruction
1334  NVVM::WGMMAScaleInAttr generateScaleIn() const {
1335  return NVVM::WGMMAScaleInAttr::get(op->getContext(),
1336  NVVM::WGMMAScaleIn::one);
1337  }
1338 
1339  /// Basic function to generate Add
1340  Value makeAdd(Value lhs, Value rhs) {
1341  return LLVM::AddOp::create(b, lhs.getType(), lhs, rhs);
1342  };
1343 
1344  /// Moves the descriptor pointer of matrix-A for the next wgmma instruction.
1345  /// Currently, it only handles row-major.
1346  ///
1347  /// It moves the pointer like below for [128][64] size:
1348  /// +2 +4 +6
1349  /// ↓ ↓ ↓
1350  /// descA ---> +--+--+--+--+
1351  /// |->|->|->|->|
1352  /// | | | | |
1353  /// | | | | |
1354  /// | | | | |
1355  /// descA+512---> +-----------+
1356  /// | | | | |
1357  /// | | | | |
1358  /// | | | | |
1359  /// | | | | |
1360  /// +-----------+
1361  ///
1362  Value iterateDescriptorA(Value desc, int i, int j, int k) {
1363  MemRefType matrixTypeA = op.getDescriptorA().getType().getTensor();
1364  Type elemA = matrixTypeA.getElementType();
1365  int byte = elemA.getIntOrFloatBitWidth() / 8;
1366  int tileShapeA = matrixTypeA.getDimSize(1);
1367  int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte;
1368  incrementVal = incrementVal >> exclude4LSB;
1369  LLVM_DEBUG(DBGS() << "\t\t[m: " << i << " n: " << j << " k: " << k
1370  << "] [wgmma descriptors] Descriptor A + "
1371  << incrementVal << " | \t ");
1372  if (!incrementVal)
1373  return desc;
1374  return makeAdd(desc, makeI64Const(b, incrementVal));
1375  }
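 // Worked example (illustrative): for an f16 matrix A with totalK = 64 and
 // tileShapeA = 64 (byte = 2, wgmmaK = 16), step k = 1 at i = 0 adds
 // (16 * 2) >> 4 = 2 to the descriptor, while stepping to i = 1 adds
 // (64 * 64 * 2) >> 4 = 512 -- matching the "+2" and "descA+512" jumps in
 // the diagram above.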
1376 
1377  /// Moves the descriptor pointer of matrix-B for the next wgmma instruction.
1378  /// Currently, it only handles column-major.
1379  ///
1380  /// It moves the pointer like below for [128][64] size:
1381  /// descB ---> +--+--+--+--+--+--+--+--+
1382  /// |↓ | | | | | | | |
1383  /// |↓ | | | | | | | |
1384  /// |↓ | | | | | | | |
1385  /// |↓ | | | | | | | |
1386  /// +--+--+--+--+--+--+--+--+
1387  ///
1388  Value iterateDescriptorB(Value desc, int i, int j, int k) {
1389  MemRefType matrixTypeB = op.getDescriptorB().getType().getTensor();
1390  Type elemB = matrixTypeB.getElementType();
1391  int byte = elemB.getIntOrFloatBitWidth() / 8;
1392  int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte;
1393  incrementVal = incrementVal >> exclude4LSB;
1394  LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n");
1395  if (!incrementVal)
1396  return desc;
1397  return makeAdd(desc, makeI64Const(b, incrementVal));
1398  }
1399 
1400  /// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix
1401  /// descriptors and arranges them based on induction variables: i, j, and k.
1402  Value generateWgmma(int i, int j, int k, Value matrixC) {
1403  LLVM_DEBUG(DBGS() << "\t wgmma."
1404  << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
1405  << "(A[" << (i * wgmmaM) << ":"
1406  << (i * wgmmaM) + wgmmaM << "]["
1407  << (k * wgmmaK) << ":"
1408  << (k * wgmmaK + wgmmaK) << "] * "
1409  << " B[" << (k * wgmmaK) << ":"
1410  << (k * wgmmaK + wgmmaK) << "][" << 0 << ":"
1411  << wgmmaN << "])\n");
1412 
1413  Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k);
1414  Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k);
1415 
1416  Type elemA = op.getDescriptorA().getType().getTensor().getElementType();
1417  NVVM::WGMMATypesAttr itypeA = generateWgmmaType(elemA);
1418 
1419  Type elemB = op.getDescriptorB().getType().getTensor().getElementType();
1420  NVVM::WGMMATypesAttr itypeB = generateWgmmaType(elemB);
1421 
1422  Type elemD = op.getMatrixC().getType().getFragmented().getElementType();
1423  NVVM::WGMMATypesAttr itypeD = generateWgmmaType(elemD, true);
1424 
1425  NVVM::MMAShapeAttr shape = generateWgmmaShape();
1426  NVVM::WGMMAScaleOutAttr scaleOut = generateScaleOut();
1427  NVVM::WGMMAScaleInAttr scaleIn = generateScaleIn();
1428  NVVM::MMALayoutAttr layoutA = generateWgmmaLayout(op.getTransposeA());
1429  NVVM::MMALayoutAttr layoutB = generateWgmmaLayout(!op.getTransposeB());
1430 
1431  auto overflow = NVVM::MMAIntOverflowAttr::get(
1432  op->getContext(), NVVM::MMAIntOverflow::wrapped);
1433 
1434  return NVVM::WgmmaMmaAsyncOp::create(
1435  b, matrixC.getType(), matrixC, descriptorA, descriptorB, shape,
1436  itypeA, itypeB, itypeD, scaleOut, scaleIn, scaleIn, layoutA, layoutB,
1437  overflow);
1438  }
1439 
1440  /// Generates multiple wgmma instructions to complete the given GEMM shape
1441  Value generateWgmmaGroup() {
1442  Value wgmmaResult =
1443  LLVM::PoisonOp::create(b, adaptor.getMatrixC().getType());
1444 
1445  // Perform GEMM
1446  SmallVector<Value> wgmmaResults;
1447  for (int i = 0; i < iterationM; ++i) {
1448  Value matrixC =
1449  LLVM::ExtractValueOp::create(b, adaptor.getMatrixC(), i);
1450  for (int j = 0; j < iterationN; ++j)
1451  for (int k = 0; k < iterationK; ++k)
1452  matrixC = generateWgmma(i, j, k, matrixC);
1453  wgmmaResults.push_back(matrixC);
1454  }
1455  for (auto [idx, matrix] : llvm::enumerate(wgmmaResults)) {
1456  wgmmaResult = LLVM::InsertValueOp::create(b, wgmmaResult.getType(),
1457  wgmmaResult, matrix, idx);
1458  }
1459  return wgmmaResult;
1460  }
1461 
1462  public:
1463  WarpgroupGemm(nvgpu::WarpgroupMmaOp op, ImplicitLocOpBuilder &b,
1464  OpAdaptor adaptor)
1465  : op(op), b(b), adaptor(adaptor) {
1466  // Find the entire GEMM Shape
1467  totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
1468  totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
1469  totalK = op.getDescriptorA().getType().getTensor().getDimSize(1);
1470  LLVM_DEBUG(DBGS() << "===--- GEMM D[" << totalM << "][" << totalN
1471  << "] += A[" << totalM << "][" << totalK << "] * B["
1472  << totalK << "][" << totalN << "] ---===\n");
1473 
1474  // Find the shape for one wgmma instruction
1475  findWgmmaShape(
1476  totalM, totalN,
1477  op.getDescriptorA().getType().getTensor().getElementType());
1478 
1479  // Iteration counts needed to cover the given GEMM shape with the wgmma shape
1480  iterationM = totalM / wgmmaM;
1481  iterationN = totalN / wgmmaN;
1482  iterationK = totalK / wgmmaK;
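  // For example, assuming an f16 GEMM with totalM = 128, totalN = 128 and
  // totalK = 64, and that findWgmmaShape selects an m64n128k16 wgmma tile,
  // this gives iterationM = 2, iterationN = 1 and iterationK = 4, i.e.
  // 2 * 1 * 4 = 8 wgmma instructions per warpgroup GEMM.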
1483  }
1484 
1485  /// Generates WgmmaMmaAsync Ops to complete the specified GEMM shape. It
1486  /// also generates a fence (WgmmaFenceAlignedOp) before the wgmma
1487  /// instructions, commits the wgmma group (WgmmaGroupSyncAlignedOp) after
1488  /// them, and finally waits for the group to complete
1489  /// (WgmmaWaitGroupSyncOp).
1490  Value generateWarpgroupMma() {
1491  NVVM::WgmmaFenceAlignedOp::create(b);
1492  Value wgmmaResult = generateWgmmaGroup();
1493  NVVM::WgmmaGroupSyncAlignedOp::create(b);
1494  NVVM::WgmmaWaitGroupSyncOp::create(b, op.getWaitGroup());
1495  return wgmmaResult;
1496  }
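  // Schematically, the emitted sequence is (assuming the usual NVVM assembly
  // names; the exact printed form is determined by the NVVM dialect):
  //   nvvm.wgmma.fence.aligned
  //   %acc = nvvm.wgmma.mma_async ...        // repeated for every (i, j, k)
  //   nvvm.wgmma.commit.group.sync.aligned
  //   nvvm.wgmma.wait.group.sync.aligned <waitGroup>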
1497  };
1498  LogicalResult
1499  matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
1500  ConversionPatternRewriter &rewriter) const override {
1501  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1502 
1503  // Step 1. Build a helper class
1504  WarpgroupGemm warpgroupGemm(op, b, adaptor);
1505 
1506  // Step 2. Generate wgmma instructions covering the entire GEMM shape
1507  Value wgmmaResult = warpgroupGemm.generateWarpgroupMma();
1508 
1509  // Step 3. Replace fragmented result struct with the op results
1510  rewriter.replaceOp(op, wgmmaResult);
1511  return success();
1512  }
1513 };
1514 
1515 struct NVGPUWarpgroupMmaStoreOpLowering
1516  : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaStoreOp> {
1517  using ConvertOpToLLVMPattern<
1518  nvgpu::WarpgroupMmaStoreOp>::ConvertOpToLLVMPattern;
1519 
1520  /// This function stores a fragmented register matrix owned by a warp group
1521  /// (128 threads) into a memref. Each thread holds its fragment as a struct
1522  /// of 64 register values.
1523  /// Here is what each thread (T) holds; each `d` is one numbered struct
1524  /// element.
1525  ///
1526  /// Threads in the warp-group (128 threads) and what they own in matrix-D:
1527  /// 0-31 Warp-0 -> MatrixD[0:15 ][0:N]
1528  /// 32-63 Warp-1 -> MatrixD[16:31][0:N]
1529  /// 64-95 Warp-2 -> MatrixD[32:47][0:N]
1530  /// 96-127 Warp-3 -> MatrixD[48:63][0:N]
1531  ///
1532  /// Matrix-D:
1533  /// +______________________________________________________________________+
1534  /// | 0-1 | 2-3 | 4-5 | 6-7 | 8-9 | 10-11|..|N-8,N-7 |
1535  /// 0 | T0:d0-d1 |T1:d0-d1 |T2:d0-d1 |T3:d0-d1 |T0:d4-d5| T1:d4-d5..|T0:dX-dY|
1536  /// 1 | T4:d0-d1 |T5:d0-d1 |T6:d0-d1 |T7:d0-d1 |T4:d4-d5| T5:d4-d5..|T4:dX-dY|
1537  /// ..| .........|.........|.........|.........|........|...........|........|
1538  /// 8 | T0:d2-d3 |T1:d2-d3 |T2:d2-d3 |T3:d2-d3 |T0:d6-d7|T1:d6-d7,..|T0:dZ-dW|
1539  /// 9 | T4:d2-d3 |T5:d2-d3 |T6:d2-d3 |T7:d2-d3 |T4:d6-d7| T5:d6-d7..|T4:dZ-dW|
1540  /// ..| .........|.........|.........|.........|........|...........|........|
1541  /// 15| T28:d2-d3|T29:d2-d3|T30:d2-d3|T31:d2-d3|........|...........|........|
1542  /// 16| T32:d2-d3|T33:d2-d3|T34:d2-d3|T35:d2-d3|........|...........|........|
1543  /// ..| .........|.........|.........|.........|........|...........|........|
1544  /// 32| T64:d2-d3|T65:d2-d3|T66:d2-d3|T67:d2-d3|........|...........|........|
1545  /// ..| .........|.........|.........|.........|........|...........|........|
1546  /// 48| T96:d2-d3|T97:d2-d3|T98:d2-d3|T99:d2-d3|........|...........|........|
1547  /// ..| .........|.........|.........|.........|........|...........|........|
1548  /// +______________________________________________________________________+
1549  ///
1550  /// \param b: The ImplicitLocOpBuilder used to create the ops.
1551  /// \param matrixD: Result of the warp-group MMA operation (fragmented
1552  /// matrix). It is held by each thread as a struct with 64 elements.
1553  /// \param dstMemref: The memref where the registers will be stored.
1554  /// \param offset: The row offset within the memref where the registers
1555  /// will be stored.
1556  void storeFragmentedMatrix(ImplicitLocOpBuilder &b, Value matrixD,
1557  TypedValue<MemRefType> dstMemref,
1558  int offset) const {
1559  Type i32 = b.getI32Type();
1560 
1561  auto makeConst = [&](int32_t index) -> Value {
1562  return LLVM::ConstantOp::create(b, i32, b.getI32IntegerAttr(index));
1563  };
1564  Value c1 = makeConst(1);
1565  Value c2 = makeConst(2);
1566  Value c4 = makeConst(4);
1567  Value c8 = makeConst(8);
1568  Value c16 = makeConst(16);
1569  Value warpSize = makeConst(kWarpSize);
1570 
1571  auto makeMul = [&](Value lhs, Value rhs) -> Value {
1572  return LLVM::MulOp::create(b, lhs.getType(), lhs, rhs);
1573  };
1574  auto makeAdd = [&](Value lhs, Value rhs) -> Value {
1575  return LLVM::AddOp::create(b, lhs.getType(), lhs, rhs);
1576  };
1577 
1578  auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,
1579  TypedValue<MemRefType> memref) {
1580  Type it = b.getIndexType();
1581  Value idx = arith::IndexCastOp::create(b, it, x);
1582  Value idy0 = arith::IndexCastOp::create(b, it, y);
1583  Value idy1 = arith::IndexCastOp::create(b, it, makeAdd(y, c1));
1584  Value d0 = LLVM::ExtractValueOp::create(b, wgmmaResult, i);
1585  Value d1 = LLVM::ExtractValueOp::create(b, wgmmaResult, i + 1);
1586  memref::StoreOp::create(b, d0, memref, ValueRange{idx, idy0});
1587  memref::StoreOp::create(b, d1, memref, ValueRange{idx, idy1});
1588  };
1589 
1590  Value tidx = NVVM::ThreadIdXOp::create(b, i32);
1591  Value laneId = LLVM::URemOp::create(b, i32, tidx, warpSize);
1592  Value warpId = LLVM::UDivOp::create(b, i32, tidx, warpSize);
1593  Value lane4Id = LLVM::UDivOp::create(b, i32, laneId, c4);
1594  Value lane4modId = LLVM::URemOp::create(b, i32, laneId, c4);
1595 
1596  Value tj = makeMul(lane4modId, c2);
1597  Value ti = makeAdd(lane4Id, makeMul(warpId, c16));
1598  if (offset)
1599  ti = makeAdd(ti, makeConst(offset));
1600 
1601  auto structType = cast<LLVM::LLVMStructType>(matrixD.getType());
1602 
1603  // Number of adjacent 32-bit registers owned per thread
1604  constexpr unsigned numAdjacentRegisters = 2;
1605  // Number of 8x8 matrices stacked one below another per warp
1606  constexpr unsigned numStackedMatrices = 2;
1607 
1608  size_t storeCount = (structType.getBody().size() /
1609  (numStackedMatrices * numAdjacentRegisters));
1610 
1611  for (size_t i = 0; i < numStackedMatrices; ++i) {
1612  Value idx = makeAdd(ti, makeMul(makeConst(i), c8));
1613  for (size_t j = 0; j < storeCount; ++j) {
1614  Value idy = makeAdd(tj, makeMul(makeConst(j), c8));
1615  size_t structIndex = (i * numAdjacentRegisters) +
1616  (j * (numStackedMatrices * numAdjacentRegisters));
1617  makeExtractAndStore(structIndex, matrixD, idx, idy, dstMemref);
1618  }
1619  }
1620  }
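  // Worked example of the index math above for thread 0 (tidx = 0, offset = 0):
  // laneId = warpId = lane4Id = lane4modId = 0, so ti = 0 and tj = 0, and
  //   i = 0, j = 0 -> structIndex 0: d0,d1 stored to [0][0..1]
  //   i = 1, j = 0 -> structIndex 2: d2,d3 stored to [8][0..1]
  //   i = 0, j = 1 -> structIndex 4: d4,d5 stored to [0][8..9]
  // which matches the Matrix-D layout sketched in the comment above.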
1621 
1622  LogicalResult
1623  matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor,
1624  ConversionPatternRewriter &rewriter) const override {
1625  int offset = 0;
1626  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1627  Value matrixDValue = adaptor.getMatrixD();
1628  auto stype = cast<LLVM::LLVMStructType>(matrixDValue.getType());
1629  for (auto [idx, matrixD] : llvm::enumerate(stype.getBody())) {
1630  auto structType = cast<LLVM::LLVMStructType>(matrixD);
1631  Value innerStructValue =
1632  LLVM::ExtractValueOp::create(b, matrixDValue, idx);
1633  storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset);
1634  offset += structType.getBody().size();
1635  }
1636  rewriter.eraseOp(op);
1637  return success();
1638  }
1639 };
1640 
1641 struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
1642  : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaInitAccumulatorOp> {
1643  using ConvertOpToLLVMPattern<
1644  nvgpu::WarpgroupMmaInitAccumulatorOp>::ConvertOpToLLVMPattern;
1645  LogicalResult
1646  matchAndRewrite(nvgpu::WarpgroupMmaInitAccumulatorOp op, OpAdaptor adaptor,
1647  ConversionPatternRewriter &rewriter) const override {
1648  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1649  LLVM::LLVMStructType packStructType = cast<LLVM::LLVMStructType>(
1650  getTypeConverter()->convertType(op.getMatrixC().getType()));
1651  Type elemType = cast<LLVM::LLVMStructType>(packStructType.getBody().front())
1652  .getBody()
1653  .front();
1654  Value zero = LLVM::ConstantOp::create(b, elemType, b.getZeroAttr(elemType));
1655  Value packStruct = LLVM::PoisonOp::create(b, packStructType);
1656  SmallVector<Value> innerStructs;
1657  // Unpack the structs and set all values to zero
1658  for (auto [idx, s] : llvm::enumerate(packStructType.getBody())) {
1659  auto structType = cast<LLVM::LLVMStructType>(s);
1660  Value structValue = LLVM::ExtractValueOp::create(b, packStruct, idx);
1661  for (unsigned i = 0; i < structType.getBody().size(); ++i) {
1662  structValue = LLVM::InsertValueOp::create(b, structType, structValue,
1663  zero, ArrayRef<int64_t>({i}));
1664  }
1665  innerStructs.push_back(structValue);
1666  }
1667  // Pack the inner structs into a single struct
1668  for (auto [idx, matrix] : llvm::enumerate(innerStructs)) {
1669  packStruct = LLVM::InsertValueOp::create(b, packStruct.getType(),
1670  packStruct, matrix, idx);
1671  }
1672  rewriter.replaceOp(op, packStruct);
1673  return success();
1674  }
1675 };
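// As a concrete illustration (assuming the accumulator type converts to a
// struct of two inner structs holding 64 f32 values each, matching the
// fragment layout used by the store lowering above), this pattern emits a
// single f32 zero constant, one llvm.extractvalue per inner struct, 2 x 64
// llvm.insertvalue ops to zero the elements, and two final llvm.insertvalue
// ops to repack the result.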
1676 
1677 struct NVGPUTmaFenceOpLowering
1678  : public ConvertOpToLLVMPattern<nvgpu::TmaFenceOp> {
1679  using ConvertOpToLLVMPattern<nvgpu::TmaFenceOp>::ConvertOpToLLVMPattern;
1680  LogicalResult
1681  matchAndRewrite(nvgpu::TmaFenceOp op, OpAdaptor adaptor,
1682  ConversionPatternRewriter &rewriter) const override {
1683  MLIRContext *ctx = op.getContext();
1684  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1685  auto i32Ty = b.getI32Type();
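  // A CUDA TMA tensor map descriptor (CUtensorMap) is 128 bytes wide, so the
  // proxy fence must cover the full 128-byte descriptor.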
1686  Value tensormapSize =
1687  LLVM::ConstantOp::create(b, i32Ty, rewriter.getI32IntegerAttr(128));
1688 
1689  auto memscope =
1690  NVVM::MemScopeKindAttr::get(ctx, ::mlir::NVVM::MemScopeKind::SYS);
1691 
1692  rewriter.replaceOpWithNewOp<NVVM::FenceProxyAcquireOp>(
1693  op, memscope, adaptor.getTensorMapDescriptor(), tensormapSize);
1694 
1695  return success();
1696  }
1697 };
1698 
1699 struct NVGPUTmaPrefetchOpLowering
1700  : public ConvertOpToLLVMPattern<nvgpu::TmaPrefetchOp> {
1701  using ConvertOpToLLVMPattern<nvgpu::TmaPrefetchOp>::ConvertOpToLLVMPattern;
1702  LogicalResult
1703  matchAndRewrite(nvgpu::TmaPrefetchOp op, OpAdaptor adaptor,
1704  ConversionPatternRewriter &rewriter) const override {
1705  rewriter.replaceOpWithNewOp<NVVM::PrefetchTensorMapOp>(
1706  op, adaptor.getTensorMapDescriptor(), adaptor.getPredicate());
1707  return success();
1708  }
1709 };
1710 
1711 struct NVGPURcpOpLowering : public ConvertOpToLLVMPattern<nvgpu::RcpOp> {
1712  using ConvertOpToLLVMPattern<nvgpu::RcpOp>::ConvertOpToLLVMPattern;
1713  LogicalResult
1714  matchAndRewrite(nvgpu::RcpOp op, OpAdaptor adaptor,
1715  ConversionPatternRewriter &rewriter) const override {
1716  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1717  auto i64Ty = b.getI64Type();
1718  auto f32Ty = b.getF32Type();
1719  VectorType inTy = op.getIn().getType();
1720  // Apply rcp.approx.ftz.f to each element of the vector.
1721  auto convert1DVec = [&](Type llvm1DVectorTy, Value inVec) {
1722  Value ret1DVec = LLVM::PoisonOp::create(b, llvm1DVectorTy);
1723  int numElems = llvm::cast<VectorType>(llvm1DVectorTy).getNumElements();
1724  for (int i = 0; i < numElems; i++) {
1725  Value idx = LLVM::ConstantOp::create(b, i64Ty, b.getI64IntegerAttr(i));
1726  Value elem = LLVM::ExtractElementOp::create(b, inVec, idx);
1727  Value dst = NVVM::RcpApproxFtzF32Op::create(b, f32Ty, elem);
1728  ret1DVec = LLVM::InsertElementOp::create(b, ret1DVec, dst, idx);
1729  }
1730  return ret1DVec;
1731  };
1732  if (inTy.getRank() == 1) {
1733  rewriter.replaceOp(op, convert1DVec(inTy, adaptor.getIn()));
1734  return success();
1735  }
1736  return LLVM::detail::handleMultidimensionalVectors(
1737  op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()),
1738  [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
1739  OpAdaptor adaptor(operands);
1740  return convert1DVec(llvm1DVectorTy, adaptor.getIn());
1741  },
1742  rewriter);
1743  }
1744 };
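// For example, an nvgpu.rcp over vector<4xf32> becomes an LLVM poison vector
// followed by four extractelement / NVVM::RcpApproxFtzF32Op / insertelement
// triples; higher-rank vectors are first split into 1-D vectors by
// handleMultidimensionalVectors above.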
1745 } // namespace
1746 
1747 void mlir::populateNVGPUToNVVMConversionPatterns(
1748  const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
1749  patterns.add<
1750  NVGPUMBarrierCreateLowering, // nvgpu.mbarrier.create
1751  NVGPUMBarrierInitLowering, // nvgpu.mbarrier.init
1752  NVGPUMBarrierGetLowering, // nvgpu.mbarrier.get
1753  NVGPUMBarrierArriveLowering, // nvgpu.mbarrier.arrive
1754  NVGPUMBarrierArriveNoCompleteLowering, // nvgpu.mbarrier.arrive.no_complete
1755  NVGPUMBarrierTestWaitLowering, // nvgpu.mbarrier.test_wait_parity
1756  NVGPUMBarrierTryWaitParityLowering, // nvgpu.mbarrier.try_wait_parity
1757  NVGPUTmaAsyncLoadOpLowering, // nvgpu.tma.async.load
1758  NVGPUTmaAsyncStoreOpLowering, // nvgpu.tma.async.store
1759  NVGPUTmaCreateDescriptorOpLowering, // nvgpu.tma.create.descriptor
1760  NVGPUTmaPrefetchOpLowering, // nvgpu.tma.prefetch.descriptor
1761  NVGPUTmaFenceOpLowering, // nvgpu.tma.fence.descriptor
1762  NVGPUMBarrierArriveExpectTxLowering, // nvgpu.mbarrier.arrive.expect_tx
1763  NVGPUGenerateWarpgroupDescriptorLowering, // nvgpu.warpgroup.generate.descriptor
1764  NVGPUWarpgroupMmaOpLowering, // nvgpu.warpgroup.mma
1765  NVGPUWarpgroupMmaStoreOpLowering, // nvgpu.warpgroup.mma.store
1766  NVGPUWarpgroupMmaInitAccumulatorOpLowering, // nvgpu.warpgroup.mma.init.accumulator
1767  MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
1768  NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
1769  NVGPUMmaSparseSyncLowering, NVGPURcpOpLowering>(converter);
1770 }
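// A minimal sketch of how a conversion pass might drive these patterns
// (illustrative only: the function name and the boilerplate below are
// assumptions, not part of this file):
//
//   void runNVGPUToNVVM(Operation *root) {
//     MLIRContext *ctx = root->getContext();
//     LowerToLLVMOptions options(ctx);
//     LLVMTypeConverter converter(ctx, options);
//     RewritePatternSet patterns(ctx);
//     populateNVGPUToNVVMConversionPatterns(converter, patterns);
//     LLVMConversionTarget target(*ctx);
//     target.addLegalDialect<NVVM::NVVMDialect>();
//     if (failed(applyPartialConversion(root, target, std::move(patterns))))
//       root->emitError("NVGPU to NVVM conversion failed");
//   }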