NVGPUToNVVM.cpp
1 //===- NVGPUToNVVM.cpp - NVGPU to NVVM dialect conversion -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
23 #include "mlir/IR/BuiltinTypes.h"
25 #include "mlir/IR/PatternMatch.h"
26 #include "mlir/IR/TypeUtilities.h"
27 #include "mlir/IR/Value.h"
28 #include "mlir/Pass/Pass.h"
29 #include "llvm/Support/Debug.h"
30 #include "llvm/Support/ErrorHandling.h"
31 #include "llvm/Support/raw_ostream.h"
32 #include <optional>
33 
34 #define DEBUG_TYPE "nvgpu-to-nvvm"
35 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
36 #define DBGSE() (llvm::dbgs())
37 
38 namespace mlir {
39 #define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS
40 #include "mlir/Conversion/Passes.h.inc"
41 } // namespace mlir
42 
43 using namespace mlir;
44 
45 /// Number of bits that need to be excluded when building the matrix descriptor
46 /// for wgmma operations.
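/// The wgmma matrix descriptor encodes base addresses and strides in 16-byte
/// units, so the low 4 bits of the byte offsets are dropped before packing.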
47 constexpr int exclude4LSB = 4;
48 
49 /// GPUs have 32-bit registers; this function truncates values when a larger
50 /// width is not needed.
51 static Value truncToI32(ImplicitLocOpBuilder &b, Value value) {
52  Type type = value.getType();
53  assert(llvm::isa<IntegerType>(type) && "expected an integer Value");
54  if (type.getIntOrFloatBitWidth() <= 32)
55  return value;
56  return b.create<LLVM::TruncOp>(b.getI32Type(), value);
57 }
58 
59 /// Returns the type for the intrinsic given the vectorResultType of the
60 /// `gpu.mma.sync` operation.
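/// For example, an `!llvm.array<2 x vector<2xf16>>` result maps to
/// `!llvm.struct<(vector<2xf16>, vector<2xf16>)>`, and an
/// `!llvm.array<2 x vector<2xf32>>` result maps to
/// `!llvm.struct<(f32, f32, f32, f32)>`.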
61 static Type inferIntrinsicResultType(Type vectorResultType) {
62  MLIRContext *ctx = vectorResultType.getContext();
63  auto a = cast<LLVM::LLVMArrayType>(vectorResultType);
64  auto f16x2Ty = LLVM::getFixedVectorType(Float16Type::get(ctx), 2);
65  auto i32Ty = IntegerType::get(ctx, 32);
66  auto i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
67  Type f64Ty = Float64Type::get(ctx);
68  Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
69  Type f32Ty = Float32Type::get(ctx);
70  Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
71  if (a.getElementType() == f16x2Ty) {
72  return LLVM::LLVMStructType::getLiteral(
73  ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty));
74  }
75  if (a.getElementType() == i32x2Ty) {
76  return LLVM::LLVMStructType::getLiteral(
77  ctx,
78  SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, i32Ty));
79  }
80  if (a.getElementType() == f64x2Ty) {
81  return LLVM::LLVMStructType::getLiteral(ctx, {f64Ty, f64Ty});
82  }
83  if (a.getElementType() == f32x2Ty) {
84  return LLVM::LLVMStructType::getLiteral(
85  ctx,
86  SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, f32Ty));
87  }
88  if (a.getElementType() == LLVM::getFixedVectorType(f32Ty, 1)) {
89  return LLVM::LLVMStructType::getLiteral(
90  ctx, SmallVector<Type>(static_cast<size_t>(a.getNumElements()), f32Ty));
91  }
92  return vectorResultType;
93 }
94 
95 /// Convert the SSA result of the NVVM intrinsic `nvvm.mma.sync` (which is
96 /// always an LLVM struct) into a fragment that is compatible with the vector
97 /// type of this operation. This involves extracting elements from the struct
98 /// and inserting them into an LLVM array. These extra data-movement
99 /// operations should be canonicalized away by the LLVM backend.
100 static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
101  Type resultType, Value intrinsicResult,
102  RewriterBase &rewriter) {
103  MLIRContext *ctx = rewriter.getContext();
104  auto structType = dyn_cast<LLVM::LLVMStructType>(intrinsicResultType);
105  auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType);
106  Type i32Ty = rewriter.getI32Type();
107  Type f32Ty = rewriter.getF32Type();
108  Type f64Ty = rewriter.getF64Type();
109  Type f16x2Ty = LLVM::getFixedVectorType(rewriter.getF16Type(), 2);
110  Type i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
111  Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
112  Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
113  Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
114 
115  auto makeConst = [&](int32_t index) -> Value {
116  return rewriter.create<LLVM::ConstantOp>(loc, IntegerType::get(ctx, 32),
117  rewriter.getI32IntegerAttr(index));
118  };
119 
120  if (arrayType) {
121  SmallVector<Value, 4> elements;
122 
123  // The intrinsic returns 32-bit wide elements in a form which can be
124  // directly bitcasted and inserted into the result vector.
125  if (arrayType.getElementType() == f16x2Ty ||
126  arrayType.getElementType() == f32x1Ty) {
127  for (unsigned i = 0; i < structType.getBody().size(); i++) {
128  Value el =
129  rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, i);
130  el = rewriter.createOrFold<LLVM::BitcastOp>(
131  loc, arrayType.getElementType(), el);
132  elements.push_back(el);
133  }
134  }
135 
136  // The intrinsic returns i32, f64, and f32 values as individual scalars,
137  // even when the result is notionally a 64-bit wide element (e.g. f32x2). We
138  // need to extract them from the struct and pack them into the 64-bit wide
139  // rows of the vector result.
140  if (arrayType.getElementType() == i32x2Ty ||
141  arrayType.getElementType() == f64x2Ty ||
142  arrayType.getElementType() == f32x2Ty) {
143 
144  for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) {
145  Value vec =
146  rewriter.create<LLVM::PoisonOp>(loc, arrayType.getElementType());
147  Value x1 =
148  rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, i * 2);
149  Value x2 = rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult,
150  i * 2 + 1);
151  vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
152  x1, makeConst(0));
153  vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
154  x2, makeConst(1));
155  elements.push_back(vec);
156  }
157  }
158 
159  // Create the final vectorized result.
160  Value result = rewriter.create<LLVM::PoisonOp>(loc, arrayType);
161  for (const auto &el : llvm::enumerate(elements)) {
162  result = rewriter.create<LLVM::InsertValueOp>(loc, result, el.value(),
163  el.index());
164  }
165  return result;
166  }
167 
168  return intrinsicResult;
169 }
170 
171 /// The `gpu.mma.sync` converter below expects matrix fragment operands to be
172 /// given as 2D `vectors` where the rows are 32b or 64b wide. The
173 /// `nvvm.mma.sync` op expects these arguments to be given as a long list of
174 /// scalars of certain types. This function helps unpack the `vector` arguments
175 /// and cast them to the types expected by `nvvm.mma.sync`.
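/// For example, a tf32 operand converted to `!llvm.array<4 x vector<1xf32>>`
/// is unpacked into four `i32` scalars via bitcasts, whereas an
/// `!llvm.array<2 x vector<2xf16>>` operand is forwarded as two
/// `vector<2xf16>` values.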
176 static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b,
177  Value operand,
178  NVVM::MMATypes operandPtxType) {
179  SmallVector<Value> result;
180  Type i32Ty = b.getI32Type();
181  Type f64Ty = b.getF64Type();
182  Type f32Ty = b.getF32Type();
183  Type i64Ty = b.getI64Type();
184  Type i8x4Ty = LLVM::getFixedVectorType(b.getI8Type(), 4);
185  Type i4x8Ty = LLVM::getFixedVectorType(b.getIntegerType(4), 8);
186  Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
187  auto arrayTy = cast<LLVM::LLVMArrayType>(operand.getType());
188 
189  for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
190  Value toUse = b.create<LLVM::ExtractValueOp>(operand, i);
191 
192  // For 4xi8 vectors, the intrinsic expects these to be provided as i32
193  // scalar types.
194  if (arrayTy.getElementType() == i8x4Ty ||
195  arrayTy.getElementType() == i4x8Ty ||
196  (arrayTy.getElementType() == f32x1Ty &&
197  operandPtxType == NVVM::MMATypes::tf32)) {
198  result.push_back(b.create<LLVM::BitcastOp>(i32Ty, toUse));
199  continue;
200  }
201 
202  // For some element types (i32, f32, f64), we need to unpack the inner
203  // vector/array type as well because the intrinsic expects individual
204  // scalars to be provided.
205  VectorType innerArrayTy = dyn_cast<VectorType>(arrayTy.getElementType());
206  if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty ||
207  innerArrayTy.getElementType() == f64Ty ||
208  innerArrayTy.getElementType() == f32Ty)) {
209  for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
210  idx < innerSize; idx++) {
211  result.push_back(b.create<LLVM::ExtractElementOp>(
212  toUse,
213  b.create<LLVM::ConstantOp>(i64Ty, b.getI64IntegerAttr(idx))));
214  }
215  continue;
216  }
217  result.push_back(toUse);
218  }
219  return result;
220 }
221 
222 /// Returns whether mbarrier object has shared memory address space.
223 static bool isMbarrierShared(nvgpu::MBarrierGroupType barrierType) {
224  return (mlir::nvgpu::NVGPUDialect::isSharedMemoryAddressSpace(
225  barrierType.getMemorySpace()));
226 }
227 
228 /// Returns the memory space attribute of the mbarrier object.
229 Attribute nvgpu::getMbarrierMemorySpace(MLIRContext *context,
230  nvgpu::MBarrierGroupType barrierType) {
231  Attribute memorySpace = {};
232  if (isMbarrierShared(barrierType)) {
233  memorySpace =
234  IntegerAttr::get(IntegerType::get(context, 64),
235  nvgpu::NVGPUDialect::kSharedMemoryAddressSpace);
236  }
237  return memorySpace;
238 }
239 
240 /// Returns memref type of the mbarrier object. The type is defined in the
241 /// MBarrierGroupType.
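/// For example, a group of 4 mbarriers in shared (workgroup) memory maps to
/// `memref<4xi64, 3>`.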
242 MemRefType nvgpu::getMBarrierMemrefType(MLIRContext *context,
243  nvgpu::MBarrierGroupType barrierType) {
244  Attribute memorySpace = nvgpu::getMbarrierMemorySpace(context, barrierType);
245  MemRefLayoutAttrInterface layout;
246  return MemRefType::get({barrierType.getNumBarriers()},
247  IntegerType::get(context, 64), layout, memorySpace);
248 }
249 
250 namespace {
251 
252 struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
253  using ConvertOpToLLVMPattern<nvgpu::LdMatrixOp>::ConvertOpToLLVMPattern;
254 
255  LogicalResult
256  matchAndRewrite(nvgpu::LdMatrixOp op, OpAdaptor adaptor,
257  ConversionPatternRewriter &rewriter) const override {
258  MLIRContext *ctx = getContext();
259  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
260 
261  // The result type of ldmatrix will always be a struct of 32-bit integer
262  // registers if more than one 32-bit value is returned. Otherwise, the result
263  // is a single i32. The result type of the GPU operation is always a vector
264  // of shape (NumRegisters, VectorRegister) where VectorRegister is the
265  // vector type of the result and always 32 bits long. We bitcast the result
266  // of the NVVM::LdMatrix to this vector type.
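// For example, a `vector<4x2xf16>` result yields an
// `!llvm.struct<(i32, i32, i32, i32)>` from nvvm.ldmatrix; each i32 is then
// bitcast back to `vector<2xf16>` and inserted into the result array.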
267  auto vectorResultType = dyn_cast<VectorType>(op->getResultTypes()[0]);
268  if (!vectorResultType) {
269  return failure();
270  }
271  Type innerVectorType = LLVM::getFixedVectorType(
272  vectorResultType.getElementType(), vectorResultType.getDimSize(1));
273 
274  int64_t num32BitRegs = vectorResultType.getDimSize(0);
275 
276  Type ldMatrixResultType;
277  if (num32BitRegs > 1) {
278  ldMatrixResultType = LLVM::LLVMStructType::getLiteral(
279  ctx, SmallVector<Type>(num32BitRegs, rewriter.getI32Type()));
280  } else {
281  ldMatrixResultType = rewriter.getI32Type();
282  }
283 
284  auto srcMemrefType = cast<MemRefType>(op.getSrcMemref().getType());
285  Value srcPtr =
286  getStridedElementPtr(b.getLoc(), srcMemrefType, adaptor.getSrcMemref(),
287  adaptor.getIndices(), rewriter);
288  Value ldMatrixResult = b.create<NVVM::LdMatrixOp>(
289  ldMatrixResultType, srcPtr,
290  /*num=*/op.getNumTiles(),
291  /*layout=*/op.getTranspose() ? NVVM::MMALayout::col
292  : NVVM::MMALayout::row);
293 
294  // The ldmatrix operation returns either a single i32 value or a struct of
295  // i32 values. Here we unpack those values and cast them back to their
296  // actual vector type (still of width 32b) and repack them into a result
297  // struct.
298  Type finalResultType = typeConverter->convertType(vectorResultType);
299  Value result = b.create<LLVM::PoisonOp>(finalResultType);
300  for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) {
301  Value i32Register =
302  num32BitRegs > 1 ? b.create<LLVM::ExtractValueOp>(ldMatrixResult, i)
303  : ldMatrixResult;
304  Value casted = b.create<LLVM::BitcastOp>(innerVectorType, i32Register);
305  result = b.create<LLVM::InsertValueOp>(result, casted, i);
306  }
307 
308  rewriter.replaceOp(op, result);
309  return success();
310  }
311 };
312 
313 /// Convert the given type into the corresponding PTX type (NVVM::MMATypes
314 /// enum).
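/// Note that f32 maps to tf32 here; the mma.sync patterns below only accept
/// f32 operands when the op carries the tf32-enabled attribute.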
315 static FailureOr<NVVM::MMATypes> getNvvmMmaType(Type t) {
316  Type elType = getElementTypeOrSelf(t);
317  if (elType.isInteger(8))
318  return NVVM::MMATypes::s8;
319  if (elType.isInteger(4))
320  return NVVM::MMATypes::s4;
321  if (elType.isF16())
322  return NVVM::MMATypes::f16;
323  if (elType.isF64())
324  return NVVM::MMATypes::f64;
325  if (elType.isF32())
326  return NVVM::MMATypes::tf32;
327  return failure();
328 }
329 
330 struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
331  using ConvertOpToLLVMPattern<nvgpu::MmaSyncOp>::ConvertOpToLLVMPattern;
332 
333  LogicalResult
334  matchAndRewrite(nvgpu::MmaSyncOp op, OpAdaptor adaptor,
335  ConversionPatternRewriter &rewriter) const override {
336  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
337  // Get the shapes of the MMAMatrix type being used. The shapes determine
338  // which intrinsic this op will be lowered to.
339  VectorType aType = op.getMatrixA().getType();
340  VectorType bType = op.getMatrixB().getType();
341  VectorType cType = op.getMatrixC().getType();
342 
343  std::array<int64_t, 3> gemmShape = op.getMmaShapeAsArray();
344 
345  // Tensor Cores (mma.sync) on F32 works only with TensorFloat32 (TF32).
346  bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
347  if (aType.getElementType().isF32() && !tf32Enabled)
348  return failure();
349 
350  FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
351  if (failed(ptxTypeA))
352  return op->emitOpError("failed to deduce operand PTX types");
353  FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
354  if (failed(ptxTypeB))
355  return op->emitOpError("failed to deduce operand PTX types");
356  std::optional<NVVM::MMATypes> ptxTypeC =
357  NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
358  /*isAccumulator=*/true);
359  if (!ptxTypeC)
360  return op->emitError(
361  "could not infer the PTX type for the accumulator/result");
362 
363  // TODO: add an attribute to the op to customize this behavior.
364  std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
365  if (isa<IntegerType>(aType.getElementType()))
366  overflow = NVVM::MMAIntOverflow::satfinite;
367 
368  SmallVector<Value> matA =
369  unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
370  SmallVector<Value> matB =
371  unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
372  SmallVector<Value> matC =
373  unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);
374 
375  Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
376  Type intrinsicResTy = inferIntrinsicResultType(
377  typeConverter->convertType(op->getResultTypes()[0]));
378  Value intrinsicResult = b.create<NVVM::MmaOp>(
379  intrinsicResTy, matA, matB, matC,
380  /*shape=*/gemmShape,
381  /*b1Op=*/std::nullopt,
382  /*intOverflow=*/overflow,
383  /*multiplicandPtxTypes=*/
384  std::array<NVVM::MMATypes, 2>{*ptxTypeA, *ptxTypeB},
385  /*multiplicandLayouts=*/
386  std::array<NVVM::MMALayout, 2>{NVVM::MMALayout::row,
387  NVVM::MMALayout::col});
388  rewriter.replaceOp(op, convertIntrinsicResult(op.getLoc(), intrinsicResTy,
389  desiredRetTy, intrinsicResult,
390  rewriter));
391  return success();
392  }
393 };
394 
395 struct ConvertNVGPUToNVVMPass
396  : public impl::ConvertNVGPUToNVVMPassBase<ConvertNVGPUToNVVMPass> {
397  using Base::Base;
398 
399  void getDependentDialects(DialectRegistry &registry) const override {
400  registry.insert<memref::MemRefDialect, LLVM::LLVMDialect, NVVM::NVVMDialect,
401  arith::ArithDialect>();
402  }
403 
404  void runOnOperation() override {
405  LowerToLLVMOptions options(&getContext());
406  RewritePatternSet patterns(&getContext());
407  LLVMTypeConverter converter(&getContext(), options);
408  IRRewriter rewriter(&getContext());
409  populateGpuMemorySpaceAttributeConversions(
410  converter, [](gpu::AddressSpace space) -> unsigned {
411  switch (space) {
412  case gpu::AddressSpace::Global:
413  return static_cast<unsigned>(
414  NVVM::NVVMMemorySpace::kGlobalMemorySpace);
415  case gpu::AddressSpace::Workgroup:
416  return static_cast<unsigned>(
417  NVVM::NVVMMemorySpace::kSharedMemorySpace);
418  case gpu::AddressSpace::Private:
419  return 0;
420  }
421  llvm_unreachable("unknown address space enum value");
422  return 0;
423  });
424  /// device-side async tokens cannot be materialized in nvvm. We just
425  /// convert them to a dummy i32 type in order to easily drop them during
426  /// conversion.
427  converter.addConversion([&](nvgpu::DeviceAsyncTokenType type) -> Type {
428  return converter.convertType(IntegerType::get(type.getContext(), 32));
429  });
430  converter.addConversion([&](nvgpu::WarpgroupAccumulatorType type) -> Type {
431  Type elemType = type.getFragmented().getElementType();
432  int64_t sizeM = type.getFragmented().getDimSize(0);
433  int64_t sizeN = type.getFragmented().getDimSize(1);
434 
435  unsigned numMembers;
436  if (elemType.isF32() || elemType.isInteger(32))
437  numMembers = sizeN / 2;
438  else if (elemType.isF16())
439  numMembers = sizeN / 4;
440  else
441  llvm_unreachable("unsupported type for warpgroup accumulator");
442 
443  SmallVector<Type> innerStructBody;
444  for (unsigned i = 0; i < numMembers; i++)
445  innerStructBody.push_back(elemType);
446  auto innerStructType =
447  LLVM::LLVMStructType::getLiteral(type.getContext(), innerStructBody);
448 
449  SmallVector<Type> structBody;
450  for (int i = 0; i < sizeM; i += kWgmmaSizeM)
451  structBody.push_back(innerStructType);
452 
453  auto convertedType =
454  LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
455  return converter.convertType(convertedType);
456  });
457  converter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {
458  return converter.convertType(IntegerType::get(type.getContext(), 64));
459  });
460  converter.addConversion(
461  [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type {
462  return converter.convertType(IntegerType::get(type.getContext(), 64));
463  });
464  converter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type {
465  return converter.convertType(
466  nvgpu::getMBarrierMemrefType(rewriter.getContext(), type));
467  });
468  converter.addConversion([&](nvgpu::TensorMapDescriptorType type) -> Type {
469  return LLVM::LLVMPointerType::get(type.getContext());
470  });
471  populateNVGPUToNVVMConversionPatterns(converter, patterns);
472  LLVMConversionTarget target(getContext());
473  target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
474  target.addLegalDialect<::mlir::arith::ArithDialect>();
475  target.addLegalDialect<::mlir::memref::MemRefDialect>();
476  target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
477  mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
478  converter, patterns, target);
479  if (failed(applyPartialConversion(getOperation(), target,
480  std::move(patterns))))
481  signalPassFailure();
482  }
483 };
484 
485 /// Returns the constraints for the sparse MMA inline assembly instruction.
486 static std::string buildMmaSparseAsmConstraintString(unsigned matASize,
487  unsigned matBSize,
488  unsigned matCSize) {
489  std::string str;
490  llvm::raw_string_ostream ss(str);
491  for (unsigned i = 0; i < matCSize; i++)
492  ss << "=r,";
493  for (unsigned i = 0; i < matASize + matBSize + matCSize; i++)
494  ss << "r,";
495  // The final operand is for the sparsity metadata.
496  // The sparsity selector appears as a direct literal.
497  ss << "r";
498  return str;
499 }
500 
501 /// Returns the string for the `mma.sp.sync` instruction that corresponds to
502 /// the given parameters. Note that this function doesn't do any validation;
503 /// it's expected that the provided parameters correspond to a valid
504 /// instruction.
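/// For illustration, with hypothetical fragment sizes of 4 accumulator, 2 A,
/// and 2 B registers, shape 16x8x32, s8 inputs, an s32 accumulator, satfinite
/// overflow, and selector 0, the generated string would be (shown wrapped):
///   mma.sp.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32
///   {$0,$1,$2,$3},{$4,$5},{$6,$7},{$8,$9,$10,$11},$12,0x0;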
505 static std::string buildMmaSparseAsmString(
506  const std::array<int64_t, 3> &shape, unsigned matASize, unsigned matBSize,
507  unsigned matCSize, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
508  NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
509  std::optional<NVVM::MMAIntOverflow> overflow, unsigned metaDataSelector) {
510  auto ptxTypeStr = [](NVVM::MMATypes ptxType) {
511  return NVVM::stringifyMMATypes(ptxType);
512  };
513 
514  std::string asmStr;
515  llvm::raw_string_ostream ss(asmStr);
516  ss << "mma.sp.sync.aligned.m" << shape[0] << "n" << shape[1] << "k"
517  << shape[2] << ".row.col.";
518 
519  if (overflow)
520  ss << NVVM::stringifyMMAIntOverflow(*overflow) << ".";
521 
522  ss << ptxTypeStr(ptxTypeD) << "." << ptxTypeStr(ptxTypeA) << "."
523  << ptxTypeStr(ptxTypeB) << "." << ptxTypeStr(ptxTypeC) << " ";
524  unsigned asmArgIdx = 0;
525 
526  // The operand string is structured into sections `{matC elements...},
527  // {matA elements...}, {matB elements...}, {matC elements}`.
528  for (const auto arrSize : {matCSize, matASize, matBSize, matCSize}) {
529  ss << "{";
530  for (unsigned i = 0; i < arrSize; i++)
531  ss << "$" << asmArgIdx++ << (i < arrSize - 1 ? "," : "");
532  ss << "},";
533  }
534  ss << "$" << asmArgIdx++ << ",";
535  assert(metaDataSelector <= 1);
536  ss << "0x" << metaDataSelector << ";";
537  return asmStr;
538 }
539 
540 /// Builds an inline assembly operation corresponding to the specified MMA
541 /// sparse sync operation.
542 static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
543  ImplicitLocOpBuilder &b, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
544  NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
545  std::optional<NVVM::MMAIntOverflow> overflow, ArrayRef<Value> unpackedAData,
546  ArrayRef<Value> unpackedB, ArrayRef<Value> unpackedC, Value indexData,
547  int64_t metadataSelector, const std::array<int64_t, 3> &shape,
548  Type intrinsicResultType) {
549  auto asmDialectAttr =
550  LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_ATT);
551 
552  const unsigned matASize = unpackedAData.size();
553  const unsigned matBSize = unpackedB.size();
554  const unsigned matCSize = unpackedC.size();
555 
556  std::string asmStr = buildMmaSparseAsmString(
557  shape, matASize, matBSize, matCSize, ptxTypeA, ptxTypeB, ptxTypeC,
558  ptxTypeD, overflow, metadataSelector);
559  std::string constraintStr =
560  buildMmaSparseAsmConstraintString(matASize, matBSize, matCSize);
561 
562  SmallVector<Value> asmVals;
563  asmVals.reserve(matASize + matBSize + matCSize + 1);
564  for (ArrayRef<Value> args : {unpackedAData, unpackedB, unpackedC})
565  llvm::append_range(asmVals, args);
566  asmVals.push_back(indexData);
567 
568  return b.create<LLVM::InlineAsmOp>(
569  /*resultTypes=*/intrinsicResultType,
570  /*operands=*/asmVals,
571  /*asm_string=*/asmStr,
572  /*constraints=*/constraintStr,
573  /*has_side_effects=*/true,
574  /*is_align_stack=*/false,
575  /*asm_dialect=*/asmDialectAttr,
576  /*operand_attrs=*/ArrayAttr());
577 }
578 
579 /// Lowers `nvgpu.mma.sp.sync` to inline assembly.
580 struct NVGPUMmaSparseSyncLowering
581  : public ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp> {
582  using ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp>::ConvertOpToLLVMPattern;
583 
584  LogicalResult
585  matchAndRewrite(nvgpu::MmaSparseSyncOp op, OpAdaptor adaptor,
586  ConversionPatternRewriter &rewriter) const override {
587  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
588  // Get the shapes of the MMAMatrix type being used. The shapes determine
589  // which intrinsic this op will be lowered to.
590  VectorType aType = op.getMatrixA().getType();
591  VectorType bType = op.getMatrixB().getType();
592  VectorType cType = op.getMatrixC().getType();
593 
594  FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
595  if (failed(ptxTypeA))
596  return op->emitOpError("failed to deduce operand PTX types");
597  FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
598  if (failed(ptxTypeB))
599  return op->emitOpError("failed to deduce operand PTX types");
600  std::optional<NVVM::MMATypes> ptxTypeC =
601  NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
602  /*isAccumulator=*/true);
603  if (!ptxTypeC)
604  return op->emitError(
605  "could not infer the PTX type for the accumulator/result");
606 
607  // Same as `mma.sync`, F32 works only with TensorFloat32 (TF32).
608  bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
609  if (aType.getElementType().isF32() && !tf32Enabled)
610  return failure();
611 
612  // TODO: add an attribute to the op to customize this behavior.
613  std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
614  if (isa<IntegerType>(aType.getElementType()))
615  overflow = NVVM::MMAIntOverflow::satfinite;
616 
617  SmallVector<Value> matA =
618  unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
619  SmallVector<Value> matB =
620  unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
621  SmallVector<Value> matC =
622  unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);
623 
624  Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
625  Type intrinsicResTy = inferIntrinsicResultType(
626  typeConverter->convertType(op->getResultTypes()[0]));
627 
628  // Bitcast the sparse metadata from vector<2xf16> to an i32.
629  Value sparseMetadata = adaptor.getSparseMetadata();
630  if (sparseMetadata.getType() !=
631  LLVM::getFixedVectorType(rewriter.getI16Type(), 2))
632  return op->emitOpError() << "Expected metadata type to be LLVM "
633  "VectorType of 2 i16 elements";
634  sparseMetadata =
635  b.create<LLVM::BitcastOp>(rewriter.getI32Type(), sparseMetadata);
636 
637  FailureOr<LLVM::InlineAsmOp> intrinsicResult = emitMmaSparseSyncOpAsm(
638  b, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB,
639  matC, sparseMetadata, op.getSparsitySelector(), op.getMmaShapeAsArray(),
640  intrinsicResTy);
641  if (failed(intrinsicResult))
642  return failure();
643 
644  assert((*intrinsicResult).getNumResults() == 1 &&
645  "expected inline asm op returns a single LLVM struct type");
646  rewriter.replaceOp(
647  op, convertIntrinsicResult(op.getLoc(), intrinsicResTy, desiredRetTy,
648  (*intrinsicResult)->getResult(0), rewriter));
649  return success();
650  }
651 };
652 
653 struct NVGPUAsyncCopyLowering
654  : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCopyOp> {
655  using ConvertOpToLLVMPattern<
656  nvgpu::DeviceAsyncCopyOp>::ConvertOpToLLVMPattern;
657 
658  LogicalResult
659  matchAndRewrite(nvgpu::DeviceAsyncCopyOp op, OpAdaptor adaptor,
660  ConversionPatternRewriter &rewriter) const override {
661  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
662  Location loc = op.getLoc();
663  auto dstMemrefType = cast<MemRefType>(op.getDst().getType());
664  Value dstPtr =
665  getStridedElementPtr(b.getLoc(), dstMemrefType, adaptor.getDst(),
666  adaptor.getDstIndices(), rewriter);
667  FailureOr<unsigned> dstAddressSpace =
668  getTypeConverter()->getMemRefAddressSpace(dstMemrefType);
669  if (failed(dstAddressSpace))
670  return rewriter.notifyMatchFailure(
671  loc, "destination memref address space not convertible to integer");
672 
673  auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
674  FailureOr<unsigned> srcAddressSpace =
675  getTypeConverter()->getMemRefAddressSpace(srcMemrefType);
676  if (failed(srcAddressSpace))
677  return rewriter.notifyMatchFailure(
678  loc, "source memref address space not convertible to integer");
679 
680  Value scrPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.getSrc(),
681  adaptor.getSrcIndices(), rewriter);
682  // The intrinsic takes a global pointer, so we need an address space cast.
683  auto srcPointerGlobalType = LLVM::LLVMPointerType::get(
684  op->getContext(), NVVM::NVVMMemorySpace::kGlobalMemorySpace);
685  scrPtr = b.create<LLVM::AddrSpaceCastOp>(srcPointerGlobalType, scrPtr);
686  int64_t dstElements = adaptor.getDstElements().getZExtValue();
687  int64_t sizeInBytes =
688  (dstMemrefType.getElementTypeBitWidth() * dstElements) / 8;
689  // When the optional SrcElements argument is *not* present, the regular
690  // CpAsyncOp is generated. CpAsyncOp reads bytes from the source (global
691  // memory) to fill DstElements number of elements in the destination
692  // (shared memory).
693  Value srcBytes = adaptor.getSrcElements();
694  if (srcBytes) {
695  // When the optional SrcElements argument is present, the source (global
696  // memory) of CpAsyncOp is read only for SrcElements number of elements.
697  // The rest of the DstElements in the destination (shared memory) are
698  // filled with zeros.
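// The byte count is (element bit-width * srcElements) / 8; the shift right
// by 3 below performs the division by 8.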
699  Value c3I32 =
700  b.create<LLVM::ConstantOp>(b.getI32Type(), b.getI32IntegerAttr(3));
701  Value bitwidth = b.create<LLVM::ConstantOp>(
702  b.getI32Type(),
703  b.getI32IntegerAttr(srcMemrefType.getElementTypeBitWidth()));
704  Value srcElementsI32 = b.create<LLVM::TruncOp>(b.getI32Type(), srcBytes);
705  srcBytes = b.create<LLVM::LShrOp>(
706  b.create<LLVM::MulOp>(bitwidth, srcElementsI32), c3I32);
707  }
708  // Cache global (.cg) for 16 dst bytes, Cache all (.ca) for sizes other than
709  // 16 dst bytes.
710  NVVM::LoadCacheModifierKind cacheModifier =
711  (op.getBypassL1().value_or(false) && sizeInBytes == 16)
712  ? NVVM::LoadCacheModifierKind::CG
713  : NVVM::LoadCacheModifierKind::CA;
714 
715  b.create<NVVM::CpAsyncOp>(
716  dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes),
717  NVVM::LoadCacheModifierKindAttr::get(op->getContext(), cacheModifier),
718  srcBytes);
719 
720  // Drop the result token.
721  Value zero = b.create<LLVM::ConstantOp>(
722  IntegerType::get(op.getContext(), 32), rewriter.getI32IntegerAttr(0));
723  rewriter.replaceOp(op, zero);
724  return success();
725  }
726 };
727 
728 struct NVGPUAsyncCreateGroupLowering
729  : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCreateGroupOp> {
730  using ConvertOpToLLVMPattern<
731  nvgpu::DeviceAsyncCreateGroupOp>::ConvertOpToLLVMPattern;
732 
733  LogicalResult
734  matchAndRewrite(nvgpu::DeviceAsyncCreateGroupOp op, OpAdaptor adaptor,
735  ConversionPatternRewriter &rewriter) const override {
736  rewriter.create<NVVM::CpAsyncCommitGroupOp>(op.getLoc());
737  // Drop the result token.
738  Value zero = rewriter.create<LLVM::ConstantOp>(
739  op->getLoc(), IntegerType::get(op.getContext(), 32),
740  rewriter.getI32IntegerAttr(0));
741  rewriter.replaceOp(op, zero);
742  return success();
743  }
744 };
745 
746 struct NVGPUAsyncWaitLowering
747  : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncWaitOp> {
748  using ConvertOpToLLVMPattern<
749  nvgpu::DeviceAsyncWaitOp>::ConvertOpToLLVMPattern;
750 
751  LogicalResult
752  matchAndRewrite(nvgpu::DeviceAsyncWaitOp op, OpAdaptor adaptor,
753  ConversionPatternRewriter &rewriter) const override {
754  // If numGroups is not present, pick 0 as a conservatively correct value.
755  int32_t numGroups = adaptor.getNumGroups().value_or(0);
756  rewriter.create<NVVM::CpAsyncWaitGroupOp>(op.getLoc(), numGroups);
757  rewriter.eraseOp(op);
758  return success();
759  }
760 };
761 
762 /// Creates an mbarrier object in shared memory.
763 struct NVGPUMBarrierCreateLowering
764  : public ConvertOpToLLVMPattern<nvgpu::MBarrierCreateOp> {
765  using ConvertOpToLLVMPattern<nvgpu::MBarrierCreateOp>::ConvertOpToLLVMPattern;
766 
767  template <typename moduleT>
768  memref::GlobalOp generateGlobalBarrier(ConversionPatternRewriter &rewriter,
769  Operation *funcOp, moduleT moduleOp,
770  MemRefType barrierType) const {
771  SymbolTable symbolTable(moduleOp);
772  OpBuilder::InsertionGuard guard(rewriter);
773  rewriter.setInsertionPoint(&moduleOp.front());
774  auto global = rewriter.create<memref::GlobalOp>(
775  funcOp->getLoc(), "__mbarrier",
776  /*sym_visibility=*/rewriter.getStringAttr("private"),
777  /*type=*/barrierType,
778  /*initial_value=*/ElementsAttr(),
779  /*constant=*/false,
780  /*alignment=*/rewriter.getI64IntegerAttr(8));
781  symbolTable.insert(global);
782  return global;
783  }
784 
785  LogicalResult
786  matchAndRewrite(nvgpu::MBarrierCreateOp op, OpAdaptor adaptor,
787  ConversionPatternRewriter &rewriter) const override {
788  Operation *funcOp = op->getParentOp();
789  MemRefType barrierType = nvgpu::getMBarrierMemrefType(
790  rewriter.getContext(), op.getBarriers().getType());
791 
792  memref::GlobalOp global;
793  if (auto moduleOp = funcOp->getParentOfType<gpu::GPUModuleOp>())
794  global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
795  else if (auto moduleOp = funcOp->getParentOfType<ModuleOp>())
796  global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
797 
798  rewriter.setInsertionPoint(op);
799  rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(op, barrierType,
800  global.getName());
801  return success();
802  }
803 };
804 
805 /// Base class for lowering mbarrier operations to nvvm intrinsics.
806 template <typename SourceOp>
807 struct MBarrierBasePattern : public ConvertOpToLLVMPattern<SourceOp> {
808 public:
809  using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
810  /// Returns the base pointer of the mbarrier object.
811  Value getMbarrierPtr(ImplicitLocOpBuilder &b,
812  nvgpu::MBarrierGroupType mbarType, Value memrefDesc,
813  Value mbarId,
814  ConversionPatternRewriter &rewriter) const {
815  MemRefType mbarrierMemrefType =
816  nvgpu::getMBarrierMemrefType(rewriter.getContext(), mbarType);
817  return ConvertToLLVMPattern::getStridedElementPtr(
818  b.getLoc(), mbarrierMemrefType, memrefDesc, {mbarId}, rewriter);
819  }
820 };
821 
822 struct NVGPUMBarrierGetLowering
823  : public MBarrierBasePattern<nvgpu::MBarrierGetOp> {
824  using MBarrierBasePattern<nvgpu::MBarrierGetOp>::MBarrierBasePattern;
825 
826  LogicalResult
827  matchAndRewrite(nvgpu::MBarrierGetOp op, OpAdaptor adaptor,
828  ConversionPatternRewriter &rewriter) const override {
829  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
830  nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType();
831  rewriter.setInsertionPoint(op);
832  Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(),
833  adaptor.getMbarId(), rewriter);
834  Type resType = op.getMbarrierPointer().getType();
835  rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(op, resType, barrier);
836  return success();
837  }
838 };
839 
840 /// Lowers `nvgpu.mbarrier.init` to `nvvm.mbarrier.init`
841 struct NVGPUMBarrierInitLowering
842  : public MBarrierBasePattern<nvgpu::MBarrierInitOp> {
843  using MBarrierBasePattern<nvgpu::MBarrierInitOp>::MBarrierBasePattern;
844 
845  LogicalResult
846  matchAndRewrite(nvgpu::MBarrierInitOp op, OpAdaptor adaptor,
847  ConversionPatternRewriter &rewriter) const override {
848  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
849  nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType();
850  rewriter.setInsertionPoint(op);
851  Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(),
852  adaptor.getMbarId(), rewriter);
853  Value count = truncToI32(b, adaptor.getCount());
854  if (isMbarrierShared(mbarrierType)) {
855  rewriter.replaceOpWithNewOp<NVVM::MBarrierInitSharedOp>(
856  op, barrier, count, adaptor.getPredicate());
857  } else {
858  rewriter.replaceOpWithNewOp<NVVM::MBarrierInitOp>(op, barrier, count,
859  adaptor.getPredicate());
860  }
861  return success();
862  }
863 };
864 
865 /// Lowers `nvgpu.mbarrier.arrive` to `nvvm.mbarrier.arrive`
866 struct NVGPUMBarrierArriveLowering
867  : public MBarrierBasePattern<nvgpu::MBarrierArriveOp> {
868  using MBarrierBasePattern<nvgpu::MBarrierArriveOp>::MBarrierBasePattern;
869  LogicalResult
870  matchAndRewrite(nvgpu::MBarrierArriveOp op, OpAdaptor adaptor,
871  ConversionPatternRewriter &rewriter) const override {
872  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
873  Value barrier =
874  getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
875  adaptor.getMbarId(), rewriter);
876  Type tokenType = getTypeConverter()->convertType(
877  nvgpu::MBarrierTokenType::get(op->getContext()));
878  if (isMbarrierShared(op.getBarriers().getType())) {
879  rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveSharedOp>(op, tokenType,
880  barrier);
881  } else {
882  rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveOp>(op, tokenType,
883  barrier);
884  }
885  return success();
886  }
887 };
888 
889 /// Lowers `nvgpu.mbarrier.arrive.nocomplete` to
890 /// `nvvm.mbarrier.arrive.nocomplete`
891 struct NVGPUMBarrierArriveNoCompleteLowering
892  : public MBarrierBasePattern<nvgpu::MBarrierArriveNoCompleteOp> {
893  using MBarrierBasePattern<
894  nvgpu::MBarrierArriveNoCompleteOp>::MBarrierBasePattern;
895  LogicalResult
896  matchAndRewrite(nvgpu::MBarrierArriveNoCompleteOp op, OpAdaptor adaptor,
897  ConversionPatternRewriter &rewriter) const override {
898  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
899  Value barrier =
900  getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
901  adaptor.getMbarId(), rewriter);
902  Type tokenType = getTypeConverter()->convertType(
903  nvgpu::MBarrierTokenType::get(op->getContext()));
904  Value count = truncToI32(b, adaptor.getCount());
905  if (isMbarrierShared(op.getBarriers().getType())) {
906  rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteSharedOp>(
907  op, tokenType, barrier, count);
908  } else {
909  rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteOp>(
910  op, tokenType, barrier, count);
911  }
912  return success();
913  }
914 };
915 
916 /// Lowers `nvgpu.mbarrier.test.wait` to `nvvm.mbarrier.test.wait`
917 struct NVGPUMBarrierTestWaitLowering
918  : public MBarrierBasePattern<nvgpu::MBarrierTestWaitOp> {
919  using MBarrierBasePattern<nvgpu::MBarrierTestWaitOp>::MBarrierBasePattern;
920  LogicalResult
921  matchAndRewrite(nvgpu::MBarrierTestWaitOp op, OpAdaptor adaptor,
922  ConversionPatternRewriter &rewriter) const override {
923  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
924  Value barrier =
925  getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
926  adaptor.getMbarId(), rewriter);
927  Type retType = rewriter.getI1Type();
928  if (isMbarrierShared(op.getBarriers().getType())) {
929  rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitSharedOp>(
930  op, retType, barrier, adaptor.getToken());
931  } else {
932  rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitOp>(
933  op, retType, barrier, adaptor.getToken());
934  }
935  return success();
936  }
937 };
938 
939 struct NVGPUMBarrierArriveExpectTxLowering
940  : public MBarrierBasePattern<nvgpu::MBarrierArriveExpectTxOp> {
941  using MBarrierBasePattern<
942  nvgpu::MBarrierArriveExpectTxOp>::MBarrierBasePattern;
943  LogicalResult
944  matchAndRewrite(nvgpu::MBarrierArriveExpectTxOp op, OpAdaptor adaptor,
945  ConversionPatternRewriter &rewriter) const override {
946  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
947  Value barrier =
948  getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
949  adaptor.getMbarId(), rewriter);
950  Value txcount = truncToI32(b, adaptor.getTxcount());
951 
952  if (isMbarrierShared(op.getBarriers().getType())) {
953  rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxSharedOp>(
954  op, barrier, txcount, adaptor.getPredicate());
955  return success();
956  }
957 
958  rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>(
959  op, barrier, txcount, adaptor.getPredicate());
960  return success();
961  }
962 };
963 
964 struct NVGPUMBarrierTryWaitParityLowering
965  : public MBarrierBasePattern<nvgpu::MBarrierTryWaitParityOp> {
966  using MBarrierBasePattern<
967  nvgpu::MBarrierTryWaitParityOp>::MBarrierBasePattern;
968  LogicalResult
969  matchAndRewrite(nvgpu::MBarrierTryWaitParityOp op, OpAdaptor adaptor,
970  ConversionPatternRewriter &rewriter) const override {
971  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
972  Value barrier =
973  getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
974  adaptor.getMbarId(), rewriter);
975  Value ticks = truncToI32(b, adaptor.getTicks());
976  Value phase =
977  b.create<LLVM::ZExtOp>(b.getI32Type(), adaptor.getPhaseParity());
978 
979  if (isMbarrierShared(op.getBarriers().getType())) {
980  rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParitySharedOp>(
981  op, barrier, phase, ticks);
982  return success();
983  }
984 
985  rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParityOp>(op, barrier,
986  phase, ticks);
987  return success();
988  }
989 };
990 
991 struct NVGPUTmaAsyncLoadOpLowering
992  : public MBarrierBasePattern<nvgpu::TmaAsyncLoadOp> {
993  using MBarrierBasePattern<nvgpu::TmaAsyncLoadOp>::MBarrierBasePattern;
994  LogicalResult
995  matchAndRewrite(nvgpu::TmaAsyncLoadOp op, OpAdaptor adaptor,
996  ConversionPatternRewriter &rewriter) const override {
997  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
998  auto srcMemrefType = cast<MemRefType>(op.getDst().getType());
999  Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType,
1000  adaptor.getDst(), {}, rewriter);
1001  Value barrier =
1002  getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
1003  adaptor.getMbarId(), rewriter);
1004 
1005  SmallVector<Value> coords = adaptor.getCoordinates();
1006  for (auto [index, value] : llvm::enumerate(coords)) {
1007  coords[index] = truncToI32(b, value);
1008  }
1009  rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(
1010  op, dest, adaptor.getTensorMapDescriptor(), coords, barrier,
1011  ValueRange{}, adaptor.getMulticastMask(), Value{},
1012  adaptor.getPredicate());
1013  return success();
1014  }
1015 };
1016 
1017 struct NVGPUTmaAsyncStoreOpLowering
1018  : public MBarrierBasePattern<nvgpu::TmaAsyncStoreOp> {
1019  using MBarrierBasePattern<nvgpu::TmaAsyncStoreOp>::MBarrierBasePattern;
1020  LogicalResult
1021  matchAndRewrite(nvgpu::TmaAsyncStoreOp op, OpAdaptor adaptor,
1022  ConversionPatternRewriter &rewriter) const override {
1023  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1024  auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
1025  Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType,
1026  adaptor.getSrc(), {}, rewriter);
1027  SmallVector<Value> coords = adaptor.getCoordinates();
1028  for (auto [index, value] : llvm::enumerate(coords)) {
1029  coords[index] = truncToI32(b, value);
1030  }
1031 
1032  rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(
1033  op, adaptor.getTensorMapDescriptor(), dest, coords,
1034  adaptor.getPredicate());
1035  return success();
1036  }
1037 };
1038 
1039 struct NVGPUGenerateWarpgroupDescriptorLowering
1040  : public ConvertOpToLLVMPattern<nvgpu::WarpgroupGenerateDescriptorOp> {
1041  using ConvertOpToLLVMPattern<
1042  nvgpu::WarpgroupGenerateDescriptorOp>::ConvertOpToLLVMPattern;
1043 
1044  LogicalResult
1045  matchAndRewrite(nvgpu::WarpgroupGenerateDescriptorOp op, OpAdaptor adaptor,
1046  ConversionPatternRewriter &rewriter) const override {
1047 
1048  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1049 
1050  nvgpu::TensorMapSwizzleKind swizzleKind =
1051  op.getTensorMap().getType().getSwizzle();
1052 
1053  unsigned layout =
1054  (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 128
1055  : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 64
1056  : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 32
1057  : 1;
1058  unsigned swizzle =
1059  (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 1
1060  : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 2
1061  : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 3
1062  : 0;
1063 
1064  auto ti64 = b.getIntegerType(64);
1065  auto makeConst = [&](uint64_t index) -> Value {
1066  return b.create<LLVM::ConstantOp>(ti64, b.getI64IntegerAttr(index));
1067  };
1068  auto shiftLeft = [&](Value value, unsigned shift) -> Value {
1069  return b.create<LLVM::ShlOp>(ti64, value, makeConst(shift));
1070  };
1071  auto shiftRight = [&](Value value, unsigned shift) -> Value {
1072  return b.create<LLVM::LShrOp>(ti64, value, makeConst(shift));
1073  };
1074  auto insertBit = [&](Value desc, Value val, int startBit) {
1075  return b.create<LLVM::OrOp>(ti64, desc, shiftLeft(val, startBit));
1076  };
1077 
1078  int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0);
1079  uint64_t strideDimVal = (layout << 3) >> exclude4LSB;
1080  uint64_t leadDimVal = (sizeN * layout) >> exclude4LSB;
1081  uint64_t offsetVal = 0;
1082 
1083  Value strideDim = makeConst(strideDimVal);
1084  Value leadDim = makeConst(leadDimVal);
1085 
1086  Value baseAddr = getStridedElementPtr(
1087  op->getLoc(), cast<MemRefType>(op.getTensor().getType()),
1088  adaptor.getTensor(), {}, rewriter);
1089  Value basePtr = b.create<LLVM::PtrToIntOp>(ti64, baseAddr);
1090  // Just use 14 bits for base address
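// (x << 46) >> 50 keeps bits [4, 18) of the address, i.e. the 14-bit base
// address expressed in 16-byte units.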
1091  Value basePtr14bit = shiftRight(shiftLeft(basePtr, 46), 50);
1092 
1093  int startSwizzleBit = 62, startOffsetBit = 49, startStrideBit = 32,
1094  startLeadBit = 16, startBaseAddrBit = 0;
1095  Value dsc = makeConst(0);
1096  // // [62,64) swizzle type
1097  dsc = insertBit(dsc, makeConst(swizzle), startSwizzleBit);
1098  // // [49,52) base_offset
1099  dsc = insertBit(dsc, makeConst(offsetVal), startOffsetBit);
1100  // // [32,46) stride
1101  dsc = insertBit(dsc, strideDim, startStrideBit);
1102  // // [16,30) leading dimension
1103  dsc = insertBit(dsc, leadDim, startLeadBit);
1104  // // [0,14) start_address
1105  dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit);
1106 
1107  LLVM_DEBUG(DBGS() << "Generating warpgroup.descriptor: "
1108  << "leading_off:" << leadDimVal << "\t"
1109  << "stride_off :" << strideDimVal << "\t"
1110  << "base_offset:" << offsetVal << "\t"
1111  << "layout_type:" << swizzle << " ("
1112  << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
1113  << ")\n start_addr : " << baseAddr << "\n");
1114 
1115  rewriter.replaceOp(op, dsc);
1116  return success();
1117  }
1118 };
1119 
1120 static Value makeI64Const(ImplicitLocOpBuilder &b, int32_t index) {
1121  return b.create<LLVM::ConstantOp>(b.getIntegerType(64),
1122  b.getI32IntegerAttr(index));
1123 }
1124 
1125 /// Returns a Value holding the data type enum expected by the CUDA driver.
1126 static Value elementTypeAsLLVMConstant(ImplicitLocOpBuilder &b, Type type) {
1127  // Enum is from CUDA driver API
1128  // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html
1129  enum CUtensorMapDataTypeEnum {
1130  CU_TENSOR_MAP_DATA_TYPE_UINT8 = 0,
1131  CU_TENSOR_MAP_DATA_TYPE_UINT16,
1132  CU_TENSOR_MAP_DATA_TYPE_UINT32,
1133  CU_TENSOR_MAP_DATA_TYPE_INT32,
1134  CU_TENSOR_MAP_DATA_TYPE_UINT64,
1135  CU_TENSOR_MAP_DATA_TYPE_INT64,
1136  CU_TENSOR_MAP_DATA_TYPE_FLOAT16,
1137  CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
1138  CU_TENSOR_MAP_DATA_TYPE_FLOAT64,
1139  CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
1140  CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ,
1141  CU_TENSOR_MAP_DATA_TYPE_TFLOAT32,
1142  CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ
1143  };
1144 
1145  if (type.isUnsignedInteger(8))
1146  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT8);
1147  if (type.isUnsignedInteger(16))
1148  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT16);
1149  if (type.isUnsignedInteger(32))
1150  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT32);
1151  if (type.isUnsignedInteger(64))
1152  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT64);
1153  if (type.isSignlessInteger(32))
1154  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT32);
1155  if (type.isSignlessInteger(64))
1156  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT64);
1157  if (type.isF16())
1158  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT16);
1159  if (type.isF32())
1160  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT32);
1161  if (type.isF64())
1162  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT64);
1163  if (type.isBF16())
1164  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16);
1165 
1166  llvm_unreachable("Not supported data type");
1167 }
1168 
1169 struct NVGPUTmaCreateDescriptorOpLowering
1170  : public ConvertOpToLLVMPattern<nvgpu::TmaCreateDescriptorOp> {
1171  using ConvertOpToLLVMPattern<
1172  nvgpu::TmaCreateDescriptorOp>::ConvertOpToLLVMPattern;
1173  LogicalResult
1174  matchAndRewrite(nvgpu::TmaCreateDescriptorOp op, OpAdaptor adaptor,
1175  ConversionPatternRewriter &rewriter) const override {
1176  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1177  auto llvmPointerType = LLVM::LLVMPointerType::get(op->getContext());
1178  Type llvmInt64Type = IntegerType::get(op->getContext(), 64);
1179 
1180  Value tensorElementType =
1181  elementTypeAsLLVMConstant(b, op.getTensor().getType().getElementType());
1182  auto promotedOperands = getTypeConverter()->promoteOperands(
1183  b.getLoc(), op->getOperands(), adaptor.getOperands(), b);
1184 
1185  Value boxArrayPtr = b.create<LLVM::AllocaOp>(llvmPointerType, llvmInt64Type,
1186  makeI64Const(b, 5));
1187  for (auto [index, value] : llvm::enumerate(adaptor.getBoxDimensions())) {
1188  Value gep = b.create<LLVM::GEPOp>(llvmPointerType, llvmPointerType,
1189  boxArrayPtr, makeI64Const(b, index));
1190  b.create<LLVM::StoreOp>(value, gep);
1191  }
1192 
1193  nvgpu::TensorMapDescriptorType desc = op.getTensorMap().getType();
1194  // Set Arguments for the function call
1195  SmallVector<Value> arguments;
1196  arguments.push_back(promotedOperands[0]); // rank
1197  arguments.push_back(promotedOperands[1]); // descriptor
1198  arguments.push_back(tensorElementType); // data type
1199  arguments.push_back(
1200  makeI64Const(b, (int)desc.getInterleave())); // interleave
1201  arguments.push_back(makeI64Const(b, (int)desc.getSwizzle())); // swizzle
1202  arguments.push_back(makeI64Const(b, (int)desc.getL2promo())); // l2promo
1203  arguments.push_back(makeI64Const(b, (int)desc.getOob())); // oob
1204  arguments.push_back(boxArrayPtr); // box dimensions
1205 
1206  // Set data types of the arguments
1207  SmallVector<Type> argTypes = {
1208  llvmInt64Type, /* int64_t tensorRank */
1209  llvmPointerType, /* ptr */
1210  llvmInt64Type, /* int64_t */
1211  llvmInt64Type, /* int64_t */
1212  llvmInt64Type, /* int64_t */
1213  llvmInt64Type, /* int64_t */
1214  llvmInt64Type, /* int64_t */
1215  llvmPointerType /* ptr */
1216  };
1217  FunctionCallBuilder hostRegisterCallBuilder = {
1218  "mgpuTensorMapEncodeTiledMemref", llvmPointerType, argTypes};
1219  Value tensorMap =
1220  hostRegisterCallBuilder.create(b.getLoc(), b, arguments).getResult();
1221 
1222  rewriter.replaceOp(op, tensorMap);
1223  return success();
1224  }
1225 };
1226 
1227 struct NVGPUWarpgroupMmaOpLowering
1228  : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp> {
1229  using ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp>::ConvertOpToLLVMPattern;
1230 
1231  /// This is a helper class to generate required NVVM Ops for warp-group level
1232  /// matrix multiplication.
1233  /// When the given GEMM shape is larger than the shape of
1234  /// a wgmma instruction in PTX, it can generate multiple NVVM::WgmmaMmaAsyncOp
1235  /// Op(s), group and execute them asynchronously. The class also handles
1236  /// waiting for completion and iterates through WarpgroupMatrixDescriptor to
1237  /// create descriptors for each instruction.
1238  ///
1239  /// For example this is the case when the shape of GEMM is 128x128x128
1240  ///
1241  /// nvvm.wgmma.fence.aligned
1242  ///
1243  /// nvvm.wgmma.mma.async descA, descB
1244  /// iterate(descA, descB)
1245  /// nvvm.wgmma.mma.async descA, descB
1246  /// [6x times more]
1247  ///
1248  /// nvvm.wgmma.group.sync.aligned
1249  /// nvvm.wgmma.wait.group.sync [groupId]
1250  ///
1251  class WarpgroupGemm {
1252  nvgpu::WarpgroupMmaOp op;
1253  ImplicitLocOpBuilder &b;
1254  OpAdaptor adaptor;
1255 
1256  // Entire shape of the given Op
1257  int64_t totalM, totalN, totalK;
1258 
1259  // Shape of one wgmma instruction
1260  int wgmmaM = 0, wgmmaN = 0, wgmmaK = 0;
1261 
1262  // Iteration counts for GEMM
1263  int iterationM = 0, iterationN = 0, iterationK = 0;
1264 
1265  /// Determines the shape of one wgmma instruction, as defined in the PTX
1266  /// programming guide:
1267  /// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shape
1268  void findWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType) {
1269  wgmmaM = 64;
1270  wgmmaN = sizeN;
1271  if (inputElemType.isTF32()) {
1272  wgmmaK = 8;
1273  } else if (inputElemType.isF16() || inputElemType.isBF16()) {
1274  wgmmaK = 16;
1275  } else if (isa<Float8E4M3FNType, Float8E5M2Type>(inputElemType) ||
1276  inputElemType.isInteger(16)) {
1277  wgmmaK = 32;
1278  } else if (inputElemType.isInteger(1)) {
1279  wgmmaK = 256;
1280  } else {
1281  llvm_unreachable("msg: not supported K shape");
1282  }
1283  LLVM_DEBUG(DBGS() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
1284  << ", n = " << wgmmaN << ", k = " << wgmmaK << "]\n");
1285  }
1286 
1287  /// Generates WGMMATypesAttr from MLIR Type
1288  NVVM::WGMMATypesAttr generateWgmmaType(Type type,
1289  bool useF32 = false) const {
1290  auto getWgmmaType = [=](Type elemType) {
1291  if (elemType.isF32() || elemType.isTF32())
1292  return useF32 ? NVVM::WGMMATypes::f32 : NVVM::WGMMATypes::tf32;
1293  if (elemType.isF16())
1294  return NVVM::WGMMATypes::f16;
1295  if (elemType.isBF16())
1296  return NVVM::WGMMATypes::bf16;
1297  if (isa<Float8E4M3FNType>(elemType))
1298  return NVVM::WGMMATypes::e4m3;
1299  if (isa<Float8E5M2Type>(elemType))
1300  return NVVM::WGMMATypes::e5m2;
1301  if (elemType.isInteger(1))
1302  return NVVM::WGMMATypes::b1;
1303  if (elemType.isInteger(8))
1304  return NVVM::WGMMATypes::s8;
1305  if (elemType.isUnsignedInteger(8))
1306  return NVVM::WGMMATypes::u8;
1307  if (elemType.isInteger(32))
1308  return NVVM::WGMMATypes::s32;
1309  llvm_unreachable("unsupported type");
1310  };
1311  return NVVM::WGMMATypesAttr::get(op->getContext(), getWgmmaType(type));
1312  }
1313 
1314  /// Generates layout attribute for the input matrix for wgmma instruction
1315  NVVM::MMALayoutAttr
1316  generateWgmmaLayout(std::optional<bool> transpose) const {
1317  if (transpose.value_or(false))
1318  return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::col);
1319  return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::row);
1320  }
1321 
1322  /// Generates shape attribute for wgmma instruction
1323  NVVM::MMAShapeAttr generateWgmmaShape() const {
1324  return NVVM::MMAShapeAttr::get(op->getContext(), wgmmaM, wgmmaN, wgmmaK);
1325  }
1326 
1327  /// Generates scale attributes of output matrix for wgmma instruction
1328  NVVM::WGMMAScaleOutAttr generateScaleOut() const {
1329  return NVVM::WGMMAScaleOutAttr::get(op->getContext(),
1330  NVVM::WGMMAScaleOut::one);
1331  }
1332  /// Generates scale attributes of input matrix for wgmma instruction
1333  NVVM::WGMMAScaleInAttr generateScaleIn() const {
1334  return NVVM::WGMMAScaleInAttr::get(op->getContext(),
1335  NVVM::WGMMAScaleIn::one);
1336  }
1337 
1338  /// Basic function to generate Add
1339  Value makeAdd(Value lhs, Value rhs) {
1340  return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);
1341  };
1342 
1343  /// Moves the descriptor pointer of matrix-A for the next wgmma instruction.
1344  /// Currently, it only handles row-major.
1345  ///
1346  /// It moves the pointer like below for [128][64] size:
1347  /// +2 +4 +6
1348  /// ↓ ↓ ↓
1349  /// descA ---> +--+--+--+--+
1350  /// |->|->|->|->|
1351  /// | | | | |
1352  /// | | | | |
1353  /// | | | | |
1354  /// descA+512---> +-----------+
1355  /// | | | | |
1356  /// | | | | |
1357  /// | | | | |
1358  /// | | | | |
1359  /// +-----------+
1360  ///
1361  Value iterateDescriptorA(Value desc, int i, int j, int k) {
1362  MemRefType matrixTypeA = op.getDescriptorA().getType().getTensor();
1363  Type elemA = matrixTypeA.getElementType();
1364  int byte = elemA.getIntOrFloatBitWidth() / 8;
1365  int tileShapeA = matrixTypeA.getDimSize(1);
1366  int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte;
1367  incrementVal = incrementVal >> exclude4LSB;
1368  LLVM_DEBUG(DBGS() << "\t\t[m: " << i << " n: " << j << " k: " << k
1369  << "] [wgmma descriptors] Descriptor A + "
1370  << incrementVal << " | \t ");
1371  if (!incrementVal)
1372  return desc;
1373  return makeAdd(desc, makeI64Const(b, incrementVal));
1374  }
1375 
1376  /// Moves the descriptor pointer of matrix-B for the next wgmma instruction.
1377  /// Currently, it only handles column-major.
1378  ///
1379  /// It moves the pointer like below for [128][64] size:
1380  /// descB ---> +--+--+--+--+--+--+--+--+
1381  /// |↓ | | | | | | | |
1382  /// |↓ | | | | | | | |
1383  /// |↓ | | | | | | | |
1384  /// |↓ | | | | | | | |
1385  /// +--+--+--+--+--+--+--+--+
1386  ///
1387  Value iterateDescriptorB(Value desc, int i, int j, int k) {
1388  MemRefType matrixTypeB = op.getDescriptorB().getType().getTensor();
1389  Type elemB = matrixTypeB.getElementType();
1390  int byte = elemB.getIntOrFloatBitWidth() / 8;
1391  int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte;
1392  incrementVal = incrementVal >> exclude4LSB;
1393  LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n");
1394  if (!incrementVal)
1395  return desc;
1396  return makeAdd(desc, makeI64Const(b, incrementVal));
1397  }
1398 
1399  /// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix
1400  /// descriptors and arranges them based on induction variables: i, j, and k.
1401  Value generateWgmma(int i, int j, int k, Value matrixC) {
1402  LLVM_DEBUG(DBGS() << "\t wgmma."
1403  << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
1404  << "(A[" << (iterationM * wgmmaM) << ":"
1405  << (iterationM * wgmmaM) + wgmmaM << "]["
1406  << (iterationK * wgmmaK) << ":"
1407  << (iterationK * wgmmaK + wgmmaK) << "] * "
1408  << " B[" << (iterationK * wgmmaK) << ":"
1409  << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":"
1410  << wgmmaN << "])\n");
1411 
1412  Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k);
1413  Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k);
1414 
1415  Type elemA = op.getDescriptorA().getType().getTensor().getElementType();
1416  NVVM::WGMMATypesAttr itypeA = generateWgmmaType(elemA);
1417 
1418  Type elemB = op.getDescriptorB().getType().getTensor().getElementType();
1419  NVVM::WGMMATypesAttr itypeB = generateWgmmaType(elemB);
1420 
1421  Type elemD = op.getMatrixC().getType().getFragmented().getElementType();
1422  NVVM::WGMMATypesAttr itypeD = generateWgmmaType(elemD, true);
1423 
1424  NVVM::MMAShapeAttr shape = generateWgmmaShape();
1425  NVVM::WGMMAScaleOutAttr scaleOut = generateScaleOut();
1426  NVVM::WGMMAScaleInAttr scaleIn = generateScaleIn();
1427  NVVM::MMALayoutAttr layoutA = generateWgmmaLayout(op.getTransposeA());
1428  NVVM::MMALayoutAttr layoutB = generateWgmmaLayout(!op.getTransposeB());
1429 
1430  auto overflow = NVVM::MMAIntOverflowAttr::get(
1431  op->getContext(), NVVM::MMAIntOverflow::wrapped);
1432 
1433  return b.create<NVVM::WgmmaMmaAsyncOp>(
1434  matrixC.getType(), matrixC, descriptorA, descriptorB, shape, itypeA,
1435  itypeB, itypeD, scaleOut, scaleIn, scaleIn, layoutA, layoutB,
1436  overflow);
1437  }
1438 
1439  /// Generates multiple wgmma instructions to complete the given GEMM shape
1440  Value generateWgmmaGroup() {
1441  Value wgmmaResult =
1442  b.create<LLVM::PoisonOp>(adaptor.getMatrixC().getType());
1443 
1444  // Perform GEMM
1445  SmallVector<Value> wgmmaResults;
1446  for (int i = 0; i < iterationM; ++i) {
1447  Value matrixC = b.create<LLVM::ExtractValueOp>(adaptor.getMatrixC(), i);
1448  for (int j = 0; j < iterationN; ++j)
1449  for (int k = 0; k < iterationK; ++k)
1450  matrixC = generateWgmma(i, j, k, matrixC);
1451  wgmmaResults.push_back(matrixC);
1452  }
1453  for (auto [idx, matrix] : llvm::enumerate(wgmmaResults)) {
1454  wgmmaResult = b.create<LLVM::InsertValueOp>(wgmmaResult.getType(),
1455  wgmmaResult, matrix, idx);
1456  }
1457  return wgmmaResult;
1458  }
1459 
1460  public:
1461  WarpgroupGemm(nvgpu::WarpgroupMmaOp op, ImplicitLocOpBuilder &b,
1462  OpAdaptor adaptor)
1463  : op(op), b(b), adaptor(adaptor) {
1464  // Find the entire GEMM Shape
1465  totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
1466  totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
1467  totalK = op.getDescriptorA().getType().getTensor().getDimSize(1);
1468  LLVM_DEBUG(DBGS() << "===--- GEMM D[" << totalM << "][" << totalN
1469  << "] += A[" << totalM << "][" << totalK << "] * B["
1470  << totalK << "][" << totalN << "] ---===\n");
1471 
1472  // Find the shape for one wgmma instruction
1473  findWgmmaShape(
1474  totalM, totalN,
1475  op.getDescriptorA().getType().getTensor().getElementType());
1476 
1477  // Iteration counts needed to cover the given GEMM shape with the wgmma shape
1478  iterationM = totalM / wgmmaM;
1479  iterationN = totalN / wgmmaN;
1480  iterationK = totalK / wgmmaK;
1481  }
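  // Illustration (assumed shapes, not from the original source): for a
  // 128x128x64 f16 GEMM where findWgmmaShape picks the m64.n128.k16 wgmma
  // shape, the counts become iterationM = 128/64 = 2, iterationN = 128/128 = 1
  // and iterationK = 64/16 = 4, i.e. 2 * 1 * 4 = 8 wgmma instructions.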
1482 
1483  /// Generates WgmmaMmaAsync ops to complete the specified GEMM shape. It
1484  /// emits a fence (WgmmaFenceAlignedOp) before the instructions, then
1485  /// commits them as a group (WgmmaGroupSyncAlignedOp), and finally waits
1486  /// for that group to complete (WgmmaWaitGroupSyncOp) so the results are
1487  /// safe to read.
1488  Value generateWarpgroupMma() {
1489  b.create<NVVM::WgmmaFenceAlignedOp>();
1490  Value wgmmaResult = generateWgmmaGroup();
1491  b.create<NVVM::WgmmaGroupSyncAlignedOp>();
1492  b.create<NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup());
1493  return wgmmaResult;
1494  }
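  // Emission order produced by this method (sketch, derived from the code
  // above):
  //   WgmmaFenceAlignedOp
  //   iterationM * iterationN * iterationK WgmmaMmaAsyncOps (generateWgmmaGroup)
  //   WgmmaGroupSyncAlignedOp
  //   WgmmaWaitGroupSyncOp(op.getWaitGroup())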
1495  };
1496  LogicalResult
1497  matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
1498  ConversionPatternRewriter &rewriter) const override {
1499  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1500 
1501  // Step 1. Build a helper class
1502  WarpgroupGemm warpgroupGemm(op, b, adaptor);
1503 
1504  // Step 2. Generate wgmma instructions covering the entire GEMM shape
1505  Value wgmmaResult = warpgroupGemm.generateWarpgroupMma();
1506 
1507  // Step 3. Replace fragmented result struct with the op results
1508  rewriter.replaceOp(op, wgmmaResult);
1509  return success();
1510  }
1511 };
1512 
1513 struct NVGPUWarpgroupMmaStoreOpLowering
1514  : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaStoreOp> {
1515  using ConvertOpToLLVMPattern<
1516  nvgpu::WarpgroupMmaStoreOp>::ConvertOpToLLVMPattern;
1517 
1518  /// This function stores a fragmented register matrix owned by a warp group
1519  /// (128 threads) into a memref. Each thread holds 64 registers, which
1520  /// together form one 64-element struct.
1521  /// Here is what each thread (T) holds; each `d` is one struct element,
1522  /// identified by its number.
1523  ///
1524  /// Threads in the warp group (128 threads) and what they own in MatrixD:
1525  /// 0-31 Warp-0 -> MatrixD[0:15 ][0:N]
1526  /// 32-63 Warp-1 -> MatrixD[16:31][0:N]
1527  /// 64-95 Warp-2 -> MatrixD[32:47][0:N]
1528  /// 96-127 Warp-3 -> MatrixD[48:63][0:N]
1529  ///
1530  /// Matrix-D:
1531  /// +______________________________________________________________________+
1532  /// | 0-1 | 2-3 | 4-5 | 6-7 | 8-9 | 10-11|..|N-8,N-7 |
1533  /// 0 | T0:d0-d1 |T1:d0-d1 |T2:d0-d1 |T3:d0-d1 |T0:d4-d5| T1:d4-d5..|T0:dX-dY|
1534  /// 1 | T4:d0-d1 |T5:d0-d1 |T6:d0-d1 |T7:d0-d1 |T4:d4-d5| T5:d4-d5..|T4:dX-dY|
1535  /// ..| .........|.........|.........|.........|........|...........|........|
1536  /// 8 | T0:d2-d3 |T1:d2-d3 |T2:d2-d3 |T3:d2-d3 |T0:d6-d7|T1:d6-d7,..|T0:dZ-dW|
1537  /// 9 | T4:d2-d3 |T5:d2-d3 |T6:d2-d3 |T7:d2-d3 |T4:d6-d7| T5:d6-d7..|T4:dZ-dW|
1538  /// ..| .........|.........|.........|.........|........|...........|........|
1539  /// 15| T28:d2-d3|T29:d2-d3|T30:d2-d3|T31:d2-d3|........|...........|........|
1540  /// 16| T32:d2-d3|T33:d2-d3|T34:d2-d3|T35:d2-d3|........|...........|........|
1541  /// ..| .........|.........|.........|.........|........|...........|........|
1542  /// 32| T64:d2-d3|T65:d2-d3|T66:d2-d3|T67:d2-d3|........|...........|........|
1543  /// ..| .........|.........|.........|.........|........|...........|........|
1544  /// 48| T96:d2-d3|T97:d2-d3|T98:d2-d3|T99:d2-d3|........|...........|........|
1545  /// ..| .........|.........|.........|.........|........|...........|........|
1546  /// +______________________________________________________________________+
1547  ///
1548  /// \param b: The builder (with an implicit location) used to emit the ops.
1549  /// \param matrixD: Result of the warp-group MMA operation (fragmented
1550  /// matrix). Each thread holds it as a struct with 64 elements.
1551  /// \param dstMemref: The memref where the registers will be stored.
1552  /// \param offset: The row offset within the memref where the registers will
1553  /// be stored.
1554  void storeFragmentedMatrix(ImplicitLocOpBuilder &b, Value matrixD,
1555  TypedValue<MemRefType> dstMemref,
1556  int offset) const {
1557  Type i32 = b.getI32Type();
1558 
1559  auto makeConst = [&](int32_t index) -> Value {
1560  return b.create<LLVM::ConstantOp>(i32, b.getI32IntegerAttr(index));
1561  };
1562  Value c1 = makeConst(1);
1563  Value c2 = makeConst(2);
1564  Value c4 = makeConst(4);
1565  Value c8 = makeConst(8);
1566  Value c16 = makeConst(16);
1567  Value warpSize = makeConst(kWarpSize);
1568 
1569  auto makeMul = [&](Value lhs, Value rhs) -> Value {
1570  return b.create<LLVM::MulOp>(lhs.getType(), lhs, rhs);
1571  };
1572  auto makeAdd = [&](Value lhs, Value rhs) -> Value {
1573  return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);
1574  };
1575 
1576  auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,
1577  TypedValue<MemRefType> memref) {
1578  Type it = b.getIndexType();
1579  Value idx = b.create<arith::IndexCastOp>(it, x);
1580  Value idy0 = b.create<arith::IndexCastOp>(it, y);
1581  Value idy1 = b.create<arith::IndexCastOp>(it, makeAdd(y, c1));
1582  Value d0 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i);
1583  Value d1 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i + 1);
1584  b.create<memref::StoreOp>(d0, memref, ValueRange{idx, idy0});
1585  b.create<memref::StoreOp>(d1, memref, ValueRange{idx, idy1});
1586  };
1587 
1588  Value tidx = b.create<NVVM::ThreadIdXOp>(i32);
1589  Value laneId = b.create<LLVM::URemOp>(i32, tidx, warpSize);
1590  Value warpId = b.create<LLVM::UDivOp>(i32, tidx, warpSize);
1591  Value lane4Id = b.create<LLVM::UDivOp>(i32, laneId, c4);
1592  Value lane4modId = b.create<LLVM::URemOp>(i32, laneId, c4);
1593 
1594  Value tj = makeMul(lane4modId, c2);
1595  Value ti = makeAdd(lane4Id, makeMul(warpId, c16));
1596  if (offset)
1597  ti = makeAdd(ti, makeConst(offset));
1598 
1599  auto structType = cast<LLVM::LLVMStructType>(matrixD.getType());
1600 
1601  // Number of adjacent 32-bit registers owned per thread
1602  constexpr unsigned numAdjacentRegisters = 2;
1603  // Number of 8x8 matrices stacked one below the other per warp
1604  constexpr unsigned numStackedMatrices = 2;
1605 
1606  size_t storeCount = (structType.getBody().size() /
1607  (numStackedMatrices * numAdjacentRegisters));
1608 
1609  for (size_t i = 0; i < numStackedMatrices; ++i) {
1610  Value idx = makeAdd(ti, makeMul(makeConst(i), c8));
1611  for (size_t j = 0; j < storeCount; ++j) {
1612  Value idy = makeAdd(tj, makeMul(makeConst(j), c8));
1613  size_t structIndex = (i * numAdjacentRegisters) +
1614  (j * (numStackedMatrices * numAdjacentRegisters));
1615  makeExtractAndStore(structIndex, matrixD, idx, idy, dstMemref);
1616  }
1617  }
1618  }
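  // Worked example (illustrative, not from the original source): thread 5 has
  // laneId = 5 and warpId = 0, so lane4Id = 1, lane4modId = 1, ti = 1, tj = 2.
  // With i = 0, j = 0 it stores d0/d1 to dstMemref[1][2] and [1][3]; with
  // i = 1 the row becomes 1 + 8 = 9 and structIndex = 2, so d2/d3 go to
  // [9][2] and [9][3], matching the Matrix-D layout drawn above.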
1619 
1620  LogicalResult
1621  matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor,
1622  ConversionPatternRewriter &rewriter) const override {
1623  int offset = 0;
1624  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1625  Value matrixDValue = adaptor.getMatrixD();
1626  auto stype = cast<LLVM::LLVMStructType>(matrixDValue.getType());
1627  for (auto [idx, matrixD] : llvm::enumerate(stype.getBody())) {
1628  auto structType = cast<LLVM::LLVMStructType>(matrixD);
1629  Value innerStructValue = b.create<LLVM::ExtractValueOp>(matrixDValue, idx);
1630  storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset);
1631  offset += structType.getBody().size();
1632  }
1633  rewriter.eraseOp(op);
1634  return success();
1635  }
1636 };
1637 
1638 struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
1639  : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaInitAccumulatorOp> {
1640  using ConvertOpToLLVMPattern<
1641  nvgpu::WarpgroupMmaInitAccumulatorOp>::ConvertOpToLLVMPattern;
1642  LogicalResult
1643  matchAndRewrite(nvgpu::WarpgroupMmaInitAccumulatorOp op, OpAdaptor adaptor,
1644  ConversionPatternRewriter &rewriter) const override {
1645  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1646  LLVM::LLVMStructType packStructType = cast<LLVM::LLVMStructType>(
1647  getTypeConverter()->convertType(op.getMatrixC().getType()));
1648  Type elemType = cast<LLVM::LLVMStructType>(packStructType.getBody().front())
1649  .getBody()
1650  .front();
1651  Value zero = b.create<LLVM::ConstantOp>(elemType, b.getZeroAttr(elemType));
1652  Value packStruct = b.create<LLVM::PoisonOp>(packStructType);
1653  SmallVector<Value> innerStructs;
1654  // Unpack the structs and set all values to zero
1655  for (auto [idx, s] : llvm::enumerate(packStructType.getBody())) {
1656  auto structType = cast<LLVM::LLVMStructType>(s);
1657  Value structValue = b.create<LLVM::ExtractValueOp>(packStruct, idx);
1658  for (unsigned i = 0; i < structType.getBody().size(); ++i) {
1659  structValue = b.create<LLVM::InsertValueOp>(
1660  structType, structValue, zero, ArrayRef<int64_t>({i}));
1661  }
1662  innerStructs.push_back(structValue);
1663  }
1664  // Pack the inner structs into a single struct
1665  for (auto [idx, matrix] : llvm::enumerate(innerStructs)) {
1666  packStruct = b.create<LLVM::InsertValueOp>(packStruct.getType(),
1667  packStruct, matrix, idx);
1668  }
1669  rewriter.replaceOp(op, packStruct);
1670  return success();
1671  }
1672 };
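// Illustration (hypothetical accumulator type, not from the original source):
// for a converted type !llvm.struct<(struct<(f32, f32, f32, f32)>,
// struct<(f32, f32, f32, f32)>)>, the pattern above materializes one f32 zero
// constant, fills each 4-element inner struct with it via insertvalue, and
// packs both inner structs into the outer struct that replaces the op.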
1673 
1674 struct NVGPUTmaFenceOpLowering
1675  : public ConvertOpToLLVMPattern<nvgpu::TmaFenceOp> {
1676  using ConvertOpToLLVMPattern<nvgpu::TmaFenceOp>::ConvertOpToLLVMPattern;
1677  LogicalResult
1678  matchAndRewrite(nvgpu::TmaFenceOp op, OpAdaptor adaptor,
1679  ConversionPatternRewriter &rewriter) const override {
1680  MLIRContext *ctx = op.getContext();
1681  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1682  auto i32Ty = b.getI32Type();
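  // The 128 below is the size in bytes of a TMA tensor map descriptor
  // (CUtensorMap), which is the region the acquire fence must cover.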
1683  Value tensormapSize =
1684  b.create<LLVM::ConstantOp>(i32Ty, rewriter.getI32IntegerAttr(128));
1685 
1686  auto memscope =
1687  NVVM::MemScopeKindAttr::get(ctx, ::mlir::NVVM::MemScopeKind::SYS);
1688 
1689  rewriter.replaceOpWithNewOp<NVVM::FenceProxyAcquireOp>(
1690  op, memscope, adaptor.getTensorMapDescriptor(), tensormapSize);
1691 
1692  return success();
1693  }
1694 };
1695 
1696 struct NVGPUTmaPrefetchOpLowering
1697  : public ConvertOpToLLVMPattern<nvgpu::TmaPrefetchOp> {
1698  using ConvertOpToLLVMPattern<nvgpu::TmaPrefetchOp>::ConvertOpToLLVMPattern;
1699  LogicalResult
1700  matchAndRewrite(nvgpu::TmaPrefetchOp op, OpAdaptor adaptor,
1701  ConversionPatternRewriter &rewriter) const override {
1702  rewriter.replaceOpWithNewOp<NVVM::PrefetchTensorMapOp>(
1703  op, adaptor.getTensorMapDescriptor(), adaptor.getPredicate());
1704  return success();
1705  }
1706 };
1707 
1708 struct NVGPURcpOpLowering : public ConvertOpToLLVMPattern<nvgpu::RcpOp> {
1709  using ConvertOpToLLVMPattern<nvgpu::RcpOp>::ConvertOpToLLVMPattern;
1710  LogicalResult
1711  matchAndRewrite(nvgpu::RcpOp op, OpAdaptor adaptor,
1712  ConversionPatternRewriter &rewriter) const override {
1713  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1714  auto i64Ty = b.getI64Type();
1715  auto f32Ty = b.getF32Type();
1716  VectorType inTy = op.getIn().getType();
1717  // Apply rcp.approx.ftz.f32 to each element of the vector.
1718  auto convert1DVec = [&](Type llvm1DVectorTy, Value inVec) {
1719  Value ret1DVec = b.create<LLVM::PoisonOp>(llvm1DVectorTy);
1720  int numElems = llvm::cast<VectorType>(llvm1DVectorTy).getNumElements();
1721  for (int i = 0; i < numElems; i++) {
1722  Value idx = b.create<LLVM::ConstantOp>(i64Ty, b.getI64IntegerAttr(i));
1723  Value elem = b.create<LLVM::ExtractElementOp>(inVec, idx);
1724  Value dst = b.create<NVVM::RcpApproxFtzF32Op>(f32Ty, elem);
1725  ret1DVec = b.create<LLVM::InsertElementOp>(ret1DVec, dst, idx);
1726  }
1727  return ret1DVec;
1728  };
1729  if (inTy.getRank() == 1) {
1730  rewriter.replaceOp(op, convert1DVec(inTy, adaptor.getIn()));
1731  return success();
1732  }
1733  return LLVM::detail::handleMultidimensionalVectors(
1734  op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()),
1735  [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
1736  OpAdaptor adaptor(operands);
1737  return convert1DVec(llvm1DVectorTy, adaptor.getIn());
1738  },
1739  rewriter);
1740  }
1741 };
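// Illustration (not from the original source): an nvgpu.rcp on vector<4xf32>
// is unrolled by convert1DVec into four extractelement / RcpApproxFtzF32Op /
// insertelement triples over an initially-poison result vector; higher-rank
// vectors are first decomposed into 1-D pieces by
// handleMultidimensionalVectors, and each piece goes through the same lambda.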
1742 } // namespace
1743 
1744 void mlir::populateNVGPUToNVVMConversionPatterns(
1745  const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
1746  patterns.add<
1747  NVGPUMBarrierCreateLowering, // nvgpu.mbarrier.create
1748  NVGPUMBarrierInitLowering, // nvgpu.mbarrier.init
1749  NVGPUMBarrierGetLowering, // nvgpu.mbarrier.get
1750  NVGPUMBarrierArriveLowering, // nvgpu.mbarrier.arrive
1751  NVGPUMBarrierArriveNoCompleteLowering, // nvgpu.mbarrier.arrive.no_complete
1752  NVGPUMBarrierTestWaitLowering, // nvgpu.mbarrier.test_wait_parity
1753  NVGPUMBarrierTryWaitParityLowering, // nvgpu.mbarrier.try_wait_parity
1754  NVGPUTmaAsyncLoadOpLowering, // nvgpu.tma.async.load
1755  NVGPUTmaAsyncStoreOpLowering, // nvgpu.tma.async.store
1756  NVGPUTmaCreateDescriptorOpLowering, // nvgpu.tma.create.descriptor
1757  NVGPUTmaPrefetchOpLowering, // nvgpu.tma.prefetch.descriptor
1758  NVGPUTmaFenceOpLowering, // nvgpu.tma.fence.descriptor
1759  NVGPUMBarrierArriveExpectTxLowering, // nvgpu.mbarrier.arrive.expect_tx
1760  NVGPUGenerateWarpgroupDescriptorLowering, // nvgpu.warpgroup.generate.descriptor
1761  NVGPUWarpgroupMmaOpLowering, // nvgpu.warpgroup.mma
1762  NVGPUWarpgroupMmaStoreOpLowering, // nvgpu.warpgroup.mma.store
1763  NVGPUWarpgroupMmaInitAccumulatorOpLowering, // nvgpu.warpgroup.mma.init.accumulator
1764  MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
1765  NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
1766  NVGPUMmaSparseSyncLowering, NVGPURcpOpLowering>(converter);
1767 }
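// A minimal sketch (an assumption, not part of this file) of how these
// patterns are typically wired into a conversion pass; the function name and
// the exact target configuration below are illustrative only.
static LogicalResult runNVGPUToNVVMSketch(ModuleOp module) {
  MLIRContext *ctx = module.getContext();
  LowerToLLVMOptions options(ctx);
  LLVMTypeConverter converter(ctx, options);
  RewritePatternSet patterns(ctx);
  populateNVGPUToNVVMConversionPatterns(converter, patterns);
  // NVVM and LLVM ops are the legal targets; any nvgpu op left over after
  // conversion is treated as an error.
  LLVMConversionTarget target(*ctx);
  target.addLegalDialect<NVVM::NVVMDialect>();
  target.addIllegalDialect<nvgpu::NVGPUDialect>();
  return applyPartialConversion(module, target, std::move(patterns));
}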