1 //===- NVGPUToNVVM.cpp - NVGPU to NVVM dialect conversion -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
10 
23 #include "mlir/IR/BuiltinTypes.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/IR/TypeUtilities.h"
26 #include "mlir/IR/Value.h"
27 #include "mlir/Pass/Pass.h"
28 #include "llvm/Support/Debug.h"
29 #include "llvm/Support/DebugLog.h"
30 #include "llvm/Support/ErrorHandling.h"
31 #include "llvm/Support/raw_ostream.h"
32 #include <optional>
33 
34 #define DEBUG_TYPE "nvgpu-to-nvvm"
35 
36 namespace mlir {
37 #define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS
38 #include "mlir/Conversion/Passes.h.inc"
39 } // namespace mlir
40 
41 using namespace mlir;
42 
43 /// Number of bits that need to be excluded when building the matrix descriptor
44 /// for wgmma operations.
45 constexpr int exclude4LSB = 4;
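// For illustration: the shared-memory matrix descriptor encodes the base
// address, leading dimension, and stride in units of 16 bytes, so the low
// 4 bits are dropped. A hypothetical 128-byte stride is therefore stored as
//   128 >> exclude4LSB == 8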
46 
47 /// GPUs have 32-bit registers, so this function truncates values when a larger
48 /// width is not needed.
49 static Value truncToI32(ImplicitLocOpBuilder &b, Value value) {
50  Type type = value.getType();
51  assert(llvm::isa<IntegerType>(type) && "expected an integer Value");
52  if (type.getIntOrFloatBitWidth() <= 32)
53  return value;
54  return LLVM::TruncOp::create(b, b.getI32Type(), value);
55 }
56 
57 /// Returns the type for the intrinsic given the vectorResultType of the
58 /// `gpu.mma.sync` operation.
59 static Type inferIntrinsicResultType(Type vectorResultType) {
60  MLIRContext *ctx = vectorResultType.getContext();
61  auto a = cast<LLVM::LLVMArrayType>(vectorResultType);
62  auto f16x2Ty = VectorType::get(2, Float16Type::get(ctx));
63  auto i32Ty = IntegerType::get(ctx, 32);
64  auto i32x2Ty = VectorType::get(2, i32Ty);
65  Type f64Ty = Float64Type::get(ctx);
66  Type f64x2Ty = VectorType::get(2, f64Ty);
67  Type f32Ty = Float32Type::get(ctx);
68  Type f32x2Ty = VectorType::get(2, f32Ty);
69  if (a.getElementType() == f16x2Ty) {
70  return LLVM::LLVMStructType::getLiteral(
71  ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty));
72  }
73  if (a.getElementType() == i32x2Ty) {
74  return LLVM::LLVMStructType::getLiteral(
75  ctx,
76  SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, i32Ty));
77  }
78  if (a.getElementType() == f64x2Ty) {
79  return LLVM::LLVMStructType::getLiteral(ctx, {f64Ty, f64Ty});
80  }
81  if (a.getElementType() == f32x2Ty) {
82  return LLVM::LLVMStructType::getLiteral(
83  ctx,
84  SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, f32Ty));
85  }
86  if (a.getElementType() == VectorType::get(1, f32Ty)) {
87  return LLVM::LLVMStructType::getLiteral(
88  ctx, SmallVector<Type>(static_cast<size_t>(a.getNumElements()), f32Ty));
89  }
90  return vectorResultType;
91 }
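// For illustration, the mapping above yields, e.g.:
//   !llvm.array<2 x vector<2xf16>> -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
//   !llvm.array<2 x vector<2xi32>> -> !llvm.struct<(i32, i32, i32, i32)>
//   !llvm.array<1 x vector<2xf64>> -> !llvm.struct<(f64, f64)>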
92 
93 /// Convert the SSA result of the NVVM intrinsic `nvvm.mma.sync` (which is
94 /// always an LLVM struct) into a fragment that is compatible with the vector
95 /// type of this operation. This involves extracting elements from the struct
96 /// and inserting them into an LLVM array. These extra data-movement
97 /// operations should be canonicalized away by the LLVM backend.
98 static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
99  Type resultType, Value intrinsicResult,
100  RewriterBase &rewriter) {
101  MLIRContext *ctx = rewriter.getContext();
102  auto structType = dyn_cast<LLVM::LLVMStructType>(intrinsicResultType);
103  auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType);
104  Type i32Ty = rewriter.getI32Type();
105  Type f32Ty = rewriter.getF32Type();
106  Type f64Ty = rewriter.getF64Type();
107  Type f16x2Ty = VectorType::get(2, rewriter.getF16Type());
108  Type i32x2Ty = VectorType::get(2, i32Ty);
109  Type f64x2Ty = VectorType::get(2, f64Ty);
110  Type f32x2Ty = VectorType::get(2, f32Ty);
111  Type f32x1Ty = VectorType::get(1, f32Ty);
112 
113  auto makeConst = [&](int32_t index) -> Value {
114  return LLVM::ConstantOp::create(rewriter, loc, IntegerType::get(ctx, 32),
115  rewriter.getI32IntegerAttr(index));
116  };
117 
118  if (arrayType) {
119  SmallVector<Value, 4> elements;
120 
121  // The intrinsic returns 32-bit wide elements in a form which can be
122  // directly bitcasted and inserted into the result vector.
123  if (arrayType.getElementType() == f16x2Ty ||
124  arrayType.getElementType() == f32x1Ty) {
125  for (unsigned i = 0; i < structType.getBody().size(); i++) {
126  Value el =
127  LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult, i);
128  el = rewriter.createOrFold<LLVM::BitcastOp>(
129  loc, arrayType.getElementType(), el);
130  elements.push_back(el);
131  }
132  }
133 
134  // The intrinsic returns i32, f64, and f32 values as individual scalars,
135  // even when the result is notionally a 64-bit wide element (e.g. f32x2). We
136  // need to extract them from the struct and pack them into the 64-bit wide
137  // rows of the vector result.
138  if (arrayType.getElementType() == i32x2Ty ||
139  arrayType.getElementType() == f64x2Ty ||
140  arrayType.getElementType() == f32x2Ty) {
141 
142  for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) {
143  Value vec =
144  LLVM::PoisonOp::create(rewriter, loc, arrayType.getElementType());
145  Value x1 =
146  LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult, i * 2);
147  Value x2 = LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult,
148  i * 2 + 1);
149  vec = LLVM::InsertElementOp::create(rewriter, loc, vec.getType(), vec,
150  x1, makeConst(0));
151  vec = LLVM::InsertElementOp::create(rewriter, loc, vec.getType(), vec,
152  x2, makeConst(1));
153  elements.push_back(vec);
154  }
155  }
156 
157  // Create the final vectorized result.
158  Value result = LLVM::PoisonOp::create(rewriter, loc, arrayType);
159  for (const auto &el : llvm::enumerate(elements)) {
160  result = LLVM::InsertValueOp::create(rewriter, loc, result, el.value(),
161  el.index());
162  }
163  return result;
164  }
165 
166  return intrinsicResult;
167 }
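// For illustration, an intrinsic result of type !llvm.struct<(f64, f64)> is
// unpacked with llvm.extractvalue, packed into a vector<2xf64> with
// llvm.insertelement, and finally inserted into !llvm.array<1 x vector<2xf64>>.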
168 
169 /// The `gpu.mma.sync` converter below expects matrix fragment operands to be
170 /// given as 2D `vectors` where the rows are 32b or 64b wide. The
171 /// `nvvm.mma.sync` op expects these arguments to be given as a long list of
172 /// scalars of certain types. This function helps unpack the `vector` arguments
173 /// and cast them to the types expected by `nvvm.mma.sync`.
174 static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b,
175  Value operand,
176  NVVM::MMATypes operandPtxType) {
177  SmallVector<Value> result;
178  Type i32Ty = b.getI32Type();
179  Type f64Ty = b.getF64Type();
180  Type f32Ty = b.getF32Type();
181  Type i64Ty = b.getI64Type();
182  Type i8x4Ty = VectorType::get(4, b.getI8Type());
183  Type i4x8Ty = VectorType::get(8, b.getIntegerType(4));
184  Type f32x1Ty = VectorType::get(1, f32Ty);
185  auto arrayTy = cast<LLVM::LLVMArrayType>(operand.getType());
186 
187  for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
188  Value toUse = LLVM::ExtractValueOp::create(b, operand, i);
189 
190  // For 4xi8 vectors, the intrinsic expects these to be provided as i32
191  // scalar types.
192  if (arrayTy.getElementType() == i8x4Ty ||
193  arrayTy.getElementType() == i4x8Ty ||
194  (arrayTy.getElementType() == f32x1Ty &&
195  operandPtxType == NVVM::MMATypes::tf32)) {
196  result.push_back(LLVM::BitcastOp::create(b, i32Ty, toUse));
197  continue;
198  }
199 
200  // For some element types (i32, f32, f64), we need to unpack the inner
201  // vector/array type as well because the intrinsic expects individual
202  // scalars to be provided.
203  VectorType innerArrayTy = dyn_cast<VectorType>(arrayTy.getElementType());
204  if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty ||
205  innerArrayTy.getElementType() == f64Ty ||
206  innerArrayTy.getElementType() == f32Ty)) {
207  for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
208  idx < innerSize; idx++) {
209  result.push_back(LLVM::ExtractElementOp::create(
210  b, toUse,
211  LLVM::ConstantOp::create(b, i64Ty, b.getI64IntegerAttr(idx))));
212  }
213  continue;
214  }
215  result.push_back(toUse);
216  }
217  return result;
218 }
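// For illustration, an operand of type !llvm.array<2 x vector<4xi8>> is
// unpacked into two i32 scalars via llvm.bitcast, while
// !llvm.array<2 x vector<2xf64>> is unpacked into four f64 scalars via
// llvm.extractelement.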
219 
220 /// Returns whether mbarrier object has shared memory address space.
221 static bool isMbarrierShared(nvgpu::MBarrierGroupType barrierType) {
222  return (mlir::nvgpu::NVGPUDialect::isSharedMemoryAddressSpace(
223  barrierType.getMemorySpace()));
224 }
225 
226 /// Returns the memory space attribute of the mbarrier object.
227 Attribute nvgpu::getMbarrierMemorySpace(MLIRContext *context,
228  nvgpu::MBarrierGroupType barrierType) {
229  Attribute memorySpace = {};
230  if (isMbarrierShared(barrierType)) {
231  memorySpace =
232  IntegerAttr::get(IntegerType::get(context, 64),
233  nvgpu::NVGPUDialect::kSharedMemoryAddressSpace);
234  }
235  return memorySpace;
236 }
237 
238 /// Returns memref type of the mbarrier object. The type is defined in the
239 /// MBarrierGroupType.
240 MemRefType nvgpu::getMBarrierMemrefType(MLIRContext *context,
241  nvgpu::MBarrierGroupType barrierType) {
242  Attribute memorySpace = nvgpu::getMbarrierMemorySpace(context, barrierType);
243  MemRefLayoutAttrInterface layout;
244  return MemRefType::get({barrierType.getNumBarriers()},
245  IntegerType::get(context, 64), layout, memorySpace);
246 }
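// For illustration, a barrier group with num_barriers = 4 that lives in
// workgroup (shared) memory converts to memref<4xi64, 3>, i.e. one 64-bit
// mbarrier word per barrier in the shared memory address space.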
247 
248 namespace {
249 
250 struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
251  using ConvertOpToLLVMPattern<nvgpu::LdMatrixOp>::ConvertOpToLLVMPattern;
252 
253  LogicalResult
254  matchAndRewrite(nvgpu::LdMatrixOp op, OpAdaptor adaptor,
255  ConversionPatternRewriter &rewriter) const override {
256  MLIRContext *ctx = getContext();
257  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
258 
259  // The result type of ldmatrix will always be a struct of 32bit integer
260  // registers if more than one 32bit value is returned. Otherwise, the result
261  // is a single i32. The result type of the GPU operation is always a vector
262  // of shape (NumRegisters, VectorRegister) where VectorRegister is the
263  // vector type of the result and always 32 bits long. We bitcast the result
264  // of the NVVM::LdMatrix to this vector type.
265  auto vectorResultType = dyn_cast<VectorType>(op->getResultTypes()[0]);
266  if (!vectorResultType) {
267  return failure();
268  }
269  Type innerVectorType = VectorType::get(vectorResultType.getDimSize(1),
270  vectorResultType.getElementType());
271 
272  int64_t num32BitRegs = vectorResultType.getDimSize(0);
273 
274  Type ldMatrixResultType;
275  if (num32BitRegs > 1) {
276  ldMatrixResultType = LLVM::LLVMStructType::getLiteral(
277  ctx, SmallVector<Type>(num32BitRegs, rewriter.getI32Type()));
278  } else {
279  ldMatrixResultType = rewriter.getI32Type();
280  }
281 
282  auto srcMemrefType = cast<MemRefType>(op.getSrcMemref().getType());
283  Value srcPtr =
284  getStridedElementPtr(rewriter, b.getLoc(), srcMemrefType,
285  adaptor.getSrcMemref(), adaptor.getIndices());
286  auto shape = NVVM::LdStMatrixShapeAttr::get(rewriter.getContext(), 8, 8);
287  Value ldMatrixResult = NVVM::LdMatrixOp::create(
288  b, ldMatrixResultType, srcPtr,
289  /*num=*/op.getNumTiles(),
290  /*layout=*/op.getTranspose() ? NVVM::MMALayout::col
291  : NVVM::MMALayout::row,
292  /*shape=*/shape, /*eltType=*/NVVM::LdStMatrixEltType::B16);
293 
294  // The ldmatrix operation returns either a single i32 value or a struct of
295  // i32 values. Here we unpack those values and cast them back to their
296  // actual vector type (still of width 32b) and repack them into a result
297  // struct.
298  Type finalResultType = typeConverter->convertType(vectorResultType);
299  Value result = LLVM::PoisonOp::create(b, finalResultType);
300  for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) {
301  Value i32Register =
302  num32BitRegs > 1 ? LLVM::ExtractValueOp::create(b, ldMatrixResult, i)
303  : ldMatrixResult;
304  Value casted = LLVM::BitcastOp::create(b, innerVectorType, i32Register);
305  result = LLVM::InsertValueOp::create(b, result, casted, i);
306  }
307 
308  rewriter.replaceOp(op, result);
309  return success();
310  }
311 };
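// Rough sketch of the rewrite above (the IR below is illustrative only):
//   %r = nvgpu.ldmatrix %sm[%c0, %c0] {numTiles = 4 : i32, transpose = false}
//        : memref<16x16xf16, 3> -> vector<4x2xf16>
// becomes an nvvm.ldmatrix returning !llvm.struct<(i32, i32, i32, i32)>,
// followed by llvm.extractvalue + llvm.bitcast to vector<2xf16> and
// llvm.insertvalue into !llvm.array<4 x vector<2xf16>>.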
312 
313 /// Convert the given type into the corresponding PTX type (NVVM::MMATypes
314 /// enum).
315 static FailureOr<NVVM::MMATypes> getNvvmMmaType(Type t) {
316  Type elType = getElementTypeOrSelf(t);
317  if (elType.isInteger(8))
318  return NVVM::MMATypes::s8;
319  if (elType.isInteger(4))
320  return NVVM::MMATypes::s4;
321  if (elType.isF16())
322  return NVVM::MMATypes::f16;
323  if (elType.isF64())
324  return NVVM::MMATypes::f64;
325  if (elType.isF32())
326  return NVVM::MMATypes::tf32;
327  return failure();
328 }
329 
330 struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
331  using ConvertOpToLLVMPattern<nvgpu::MmaSyncOp>::ConvertOpToLLVMPattern;
332 
333  LogicalResult
334  matchAndRewrite(nvgpu::MmaSyncOp op, OpAdaptor adaptor,
335  ConversionPatternRewriter &rewriter) const override {
336  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
337  // Get the shapes of the MMAMatrix type being used. The shapes will
338  // choose which intrinsic this op will be lowered to.
339  VectorType aType = op.getMatrixA().getType();
340  VectorType bType = op.getMatrixB().getType();
341  VectorType cType = op.getMatrixC().getType();
342 
343  std::array<int64_t, 3> gemmShape = op.getMmaShapeAsArray();
344 
345  // Tensor Cores (mma.sync) on F32 work only with TensorFloat32 (TF32).
346  bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
347  if (aType.getElementType().isF32() && !tf32Enabled)
348  return failure();
349 
350  FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
351  if (failed(ptxTypeA))
352  return op->emitOpError("failed to deduce operand PTX types");
353  FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
354  if (failed(ptxTypeB))
355  return op->emitOpError("failed to deduce operand PTX types");
356  std::optional<NVVM::MMATypes> ptxTypeC =
357  NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
358  /*isAccumulator=*/true);
359  if (!ptxTypeC)
360  return op->emitError(
361  "could not infer the PTX type for the accumulator/result");
362 
363  // TODO: add an attribute to the op to customize this behavior.
364  std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
365  if (isa<IntegerType>(aType.getElementType()))
366  overflow = NVVM::MMAIntOverflow::satfinite;
367 
368  SmallVector<Value> matA =
369  unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
370  SmallVector<Value> matB =
371  unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
372  SmallVector<Value> matC =
373  unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);
374 
375  Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
376  Type intrinsicResTy = inferIntrinsicResultType(
377  typeConverter->convertType(op->getResultTypes()[0]));
378  Value intrinsicResult =
379  NVVM::MmaOp::create(b, intrinsicResTy, matA, matB, matC,
380  /*shape=*/gemmShape,
381  /*b1Op=*/std::nullopt,
382  /*intOverflow=*/overflow,
383  /*multiplicandPtxTypes=*/
384  std::array<NVVM::MMATypes, 2>{*ptxTypeA, *ptxTypeB},
385  /*multiplicandLayouts=*/
386  std::array<NVVM::MMALayout, 2>{
387  NVVM::MMALayout::row, NVVM::MMALayout::col});
388  rewriter.replaceOp(op, convertIntrinsicResult(op.getLoc(), intrinsicResTy,
389  desiredRetTy, intrinsicResult,
390  rewriter));
391  return success();
392  }
393 };
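// Rough sketch of the rewrite above (the IR below is illustrative only):
//   %d = nvgpu.mma.sync (%a, %b, %c) {mmaShape = [16, 8, 16]}
//        : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
// becomes an nvvm.mma.sync that takes the unpacked vector<2xf16> fragments and
// returns an !llvm.struct, which convertIntrinsicResult() repacks into
// !llvm.array<2 x vector<2xf16>>.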
394 
395 struct ConvertNVGPUToNVVMPass
396  : public impl::ConvertNVGPUToNVVMPassBase<ConvertNVGPUToNVVMPass> {
397  using Base::Base;
398 
399  void runOnOperation() override {
402  LLVMTypeConverter converter(&getContext(), options);
403  IRRewriter rewriter(&getContext());
404  populateGpuMemorySpaceAttributeConversions(
405  converter, [](gpu::AddressSpace space) -> unsigned {
406  switch (space) {
407  case gpu::AddressSpace::Global:
408  return static_cast<unsigned>(
410  case gpu::AddressSpace::Workgroup:
411  return static_cast<unsigned>(
413  case gpu::AddressSpace::Private:
414  return 0;
415  }
416  llvm_unreachable("unknown address space enum value");
417  return 0;
418  });
419  /// device-side async tokens cannot be materialized in nvvm. We just
420  /// convert them to a dummy i32 type in order to easily drop them during
421  /// conversion.
422  converter.addConversion([&](nvgpu::DeviceAsyncTokenType type) -> Type {
423  return converter.convertType(IntegerType::get(type.getContext(), 32));
424  });
425  converter.addConversion([&](nvgpu::WarpgroupAccumulatorType type) -> Type {
426  Type elemType = type.getFragmented().getElementType();
427  int64_t sizeM = type.getFragmented().getDimSize(0);
428  int64_t sizeN = type.getFragmented().getDimSize(1);
429 
430  unsigned numMembers;
431  if (elemType.isF32() || elemType.isInteger(32))
432  numMembers = sizeN / 2;
433  else if (elemType.isF16())
434  numMembers = sizeN / 4;
435  else
436  llvm_unreachable("unsupported type for warpgroup accumulator");
437 
438  SmallVector<Type> innerStructBody;
439  for (unsigned i = 0; i < numMembers; i++)
440  innerStructBody.push_back(elemType);
441  auto innerStructType =
442  LLVM::LLVMStructType::getLiteral(type.getContext(), innerStructBody);
443 
444  SmallVector<Type> structBody;
445  for (int i = 0; i < sizeM; i += kWgmmaSizeM)
446  structBody.push_back(innerStructType);
447 
448  auto convertedType =
449  LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
450  return converter.convertType(convertedType);
451  });
452  converter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {
453  return converter.convertType(IntegerType::get(type.getContext(), 64));
454  });
455  converter.addConversion(
456  [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type {
457  return converter.convertType(IntegerType::get(type.getContext(), 64));
458  });
459  converter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type {
460  return converter.convertType(
461  nvgpu::getMBarrierMemrefType(rewriter.getContext(), type));
462  });
463  converter.addConversion([&](nvgpu::TensorMapDescriptorType type) -> Type {
464  return LLVM::LLVMPointerType::get(type.getContext());
465  });
468  target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
469  target.addLegalDialect<::mlir::arith::ArithDialect>();
470  target.addLegalDialect<::mlir::memref::MemRefDialect>();
471  target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
473  converter, patterns, target);
474  if (failed(applyPartialConversion(getOperation(), target,
475  std::move(patterns))))
476  signalPassFailure();
477  }
478 };
479 
480 /// Returns the constraints for the sparse MMA inline assembly instruction.
481 static std::string buildMmaSparseAsmConstraintString(unsigned matASize,
482  unsigned matBSize,
483  unsigned matCSize) {
484  std::string str;
485  llvm::raw_string_ostream ss(str);
486  for (unsigned i = 0; i < matCSize; i++)
487  ss << "=r,";
488  for (unsigned i = 0; i < matASize + matBSize + matCSize; i++)
489  ss << "r,";
490  // The final operand is for the sparsity metadata.
491  // The sparsity selector appears as a direct literal.
492  ss << "r";
493  return str;
494 }
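// For illustration, matASize = 4, matBSize = 2, matCSize = 2 produces
//   "=r,=r,r,r,r,r,r,r,r,r,r"
// i.e. two outputs, eight register inputs for the fragments, and one final
// input for the sparsity metadata.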
495 
496 /// Returns the string for the `mma.sp.sync` instruction that corresponds to
497 /// the given parameters. Note that this function doesn't do any validation,
498 /// it's expected that the provided parameters correspond to a valid
499 /// instruction.
500 static std::string buildMmaSparseAsmString(
501  const std::array<int64_t, 3> &shape, unsigned matASize, unsigned matBSize,
502  unsigned matCSize, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
503  NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
504  std::optional<NVVM::MMAIntOverflow> overflow, unsigned metaDataSelector) {
505  auto ptxTypeStr = [](NVVM::MMATypes ptxType) {
506  return NVVM::stringifyMMATypes(ptxType);
507  };
508 
509  std::string asmStr;
510  llvm::raw_string_ostream ss(asmStr);
511  ss << "mma.sp.sync.aligned.m" << shape[0] << "n" << shape[1] << "k"
512  << shape[2] << ".row.col.";
513 
514  if (overflow)
515  ss << NVVM::stringifyMMAIntOverflow(*overflow) << ".";
516 
517  ss << ptxTypeStr(ptxTypeD) << "." << ptxTypeStr(ptxTypeA) << "."
518  << ptxTypeStr(ptxTypeB) << "." << ptxTypeStr(ptxTypeC) << " ";
519  unsigned asmArgIdx = 0;
520 
521  // The operand string is structured into sections `{matC elements...},
522  // {matA elements...}, {matB elements...}, {matC elements}`.
523  for (const auto arrSize : {matCSize, matASize, matBSize, matCSize}) {
524  ss << "{";
525  for (unsigned i = 0; i < arrSize; i++)
526  ss << "$" << asmArgIdx++ << (i < arrSize - 1 ? "," : "");
527  ss << "},";
528  }
529  ss << "$" << asmArgIdx++ << ",";
530  assert(metaDataSelector <= 1);
531  ss << "0x" << metaDataSelector << ";";
532  return asmStr;
533 }
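// For illustration, shape = {16, 8, 32}, s8 A/B fragments of sizes 4 and 2,
// an s32 accumulator of size 4, satfinite overflow, and selector 0 produce
// roughly:
//   mma.sp.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32
//     {$0,$1,$2,$3},{$4,$5,$6,$7},{$8,$9},{$10,$11,$12,$13},$14,0x0;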
534 
535 /// Builds an inline assembly operation corresponding to the specified MMA
536 /// sparse sync operation.
537 static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
538  ImplicitLocOpBuilder &b, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
539  NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
540  std::optional<NVVM::MMAIntOverflow> overflow, ArrayRef<Value> unpackedAData,
541  ArrayRef<Value> unpackedB, ArrayRef<Value> unpackedC, Value indexData,
542  int64_t metadataSelector, const std::array<int64_t, 3> &shape,
543  Type intrinsicResultType) {
544  auto asmDialectAttr =
545  LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_ATT);
546 
547  const unsigned matASize = unpackedAData.size();
548  const unsigned matBSize = unpackedB.size();
549  const unsigned matCSize = unpackedC.size();
550 
551  std::string asmStr = buildMmaSparseAsmString(
552  shape, matASize, matBSize, matCSize, ptxTypeA, ptxTypeB, ptxTypeC,
553  ptxTypeD, overflow, metadataSelector);
554  std::string constraintStr =
555  buildMmaSparseAsmConstraintString(matASize, matBSize, matCSize);
556 
557  SmallVector<Value> asmVals;
558  asmVals.reserve(matASize + matBSize + matCSize + 1);
559  for (ArrayRef<Value> args : {unpackedAData, unpackedB, unpackedC})
560  llvm::append_range(asmVals, args);
561  asmVals.push_back(indexData);
562 
563  return LLVM::InlineAsmOp::create(b,
564  /*resultTypes=*/intrinsicResultType,
565  /*operands=*/asmVals,
566  /*asm_string=*/asmStr,
567  /*constraints=*/constraintStr,
568  /*has_side_effects=*/true,
569  /*is_align_stack=*/false,
571  /*asm_dialect=*/asmDialectAttr,
572  /*operand_attrs=*/ArrayAttr());
573 }
574 
575 /// Lowers `nvgpu.mma.sp.sync` to inline assembly.
576 struct NVGPUMmaSparseSyncLowering
577  : public ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp> {
578  using ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp>::ConvertOpToLLVMPattern;
579 
580  LogicalResult
581  matchAndRewrite(nvgpu::MmaSparseSyncOp op, OpAdaptor adaptor,
582  ConversionPatternRewriter &rewriter) const override {
583  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
584  // Get the shapes of the MMAMatrix type being used. The shapes will
585  // choose which intrinsic this op will be lowered to.
586  VectorType aType = op.getMatrixA().getType();
587  VectorType bType = op.getMatrixB().getType();
588  VectorType cType = op.getMatrixC().getType();
589 
590  FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
591  if (failed(ptxTypeA))
592  return op->emitOpError("failed to deduce operand PTX types");
593  FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
594  if (failed(ptxTypeB))
595  return op->emitOpError("failed to deduce operand PTX types");
596  std::optional<NVVM::MMATypes> ptxTypeC =
597  NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
598  /*isAccumulator=*/true);
599  if (!ptxTypeC)
600  return op->emitError(
601  "could not infer the PTX type for the accumulator/result");
602 
603  // Same as `mma.sync`, F32 works only with TensorFloat32 (TF32).
604  bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
605  if (aType.getElementType().isF32() && !tf32Enabled)
606  return failure();
607 
608  // TODO: add an attribute to the op to customize this behavior.
609  std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
610  if (isa<IntegerType>(aType.getElementType()))
611  overflow = NVVM::MMAIntOverflow::satfinite;
612 
613  SmallVector<Value> matA =
614  unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
615  SmallVector<Value> matB =
616  unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
617  SmallVector<Value> matC =
618  unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);
619 
620  Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
621  Type intrinsicResTy = inferIntrinsicResultType(
622  typeConverter->convertType(op->getResultTypes()[0]));
623 
624  // Bitcast the sparse metadata from vector<2xf16> to an i32.
625  Value sparseMetadata = adaptor.getSparseMetadata();
626  if (sparseMetadata.getType() != VectorType::get(2, rewriter.getI16Type()))
627  return op->emitOpError() << "Expected metadata type to be LLVM "
628  "VectorType of 2 i16 elements";
629  sparseMetadata =
630  LLVM::BitcastOp::create(b, rewriter.getI32Type(), sparseMetadata);
631 
632  FailureOr<LLVM::InlineAsmOp> intrinsicResult = emitMmaSparseSyncOpAsm(
633  b, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB,
634  matC, sparseMetadata, op.getSparsitySelector(), op.getMmaShapeAsArray(),
635  intrinsicResTy);
636  if (failed(intrinsicResult))
637  return failure();
638 
639  assert((*intrinsicResult).getNumResults() == 1 &&
640  "expected inline asm op returns a single LLVM struct type");
641  rewriter.replaceOp(
642  op, convertIntrinsicResult(op.getLoc(), intrinsicResTy, desiredRetTy,
643  (*intrinsicResult)->getResult(0), rewriter));
644  return success();
645  }
646 };
647 
648 struct NVGPUAsyncCopyLowering
649  : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCopyOp> {
650  using ConvertOpToLLVMPattern<
651  nvgpu::DeviceAsyncCopyOp>::ConvertOpToLLVMPattern;
652 
653  LogicalResult
654  matchAndRewrite(nvgpu::DeviceAsyncCopyOp op, OpAdaptor adaptor,
655  ConversionPatternRewriter &rewriter) const override {
656  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
657  Location loc = op.getLoc();
658  auto dstMemrefType = cast<MemRefType>(op.getDst().getType());
659  Value dstPtr =
660  getStridedElementPtr(rewriter, b.getLoc(), dstMemrefType,
661  adaptor.getDst(), adaptor.getDstIndices());
662  FailureOr<unsigned> dstAddressSpace =
663  getTypeConverter()->getMemRefAddressSpace(dstMemrefType);
664  if (failed(dstAddressSpace))
665  return rewriter.notifyMatchFailure(
666  loc, "destination memref address space not convertible to integer");
667 
668  auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
669  FailureOr<unsigned> srcAddressSpace =
670  getTypeConverter()->getMemRefAddressSpace(srcMemrefType);
671  if (failed(srcAddressSpace))
672  return rewriter.notifyMatchFailure(
673  loc, "source memref address space not convertible to integer");
674 
675  Value scrPtr =
676  getStridedElementPtr(rewriter, loc, srcMemrefType, adaptor.getSrc(),
677  adaptor.getSrcIndices());
678  // The intrinsic takes a global pointer, so we need an address space cast.
679  auto srcPointerGlobalType = LLVM::LLVMPointerType::get(
681  scrPtr = LLVM::AddrSpaceCastOp::create(b, srcPointerGlobalType, scrPtr);
682  int64_t dstElements = adaptor.getDstElements().getZExtValue();
683  int64_t sizeInBytes =
684  (dstMemrefType.getElementTypeBitWidth() * dstElements) / 8;
685  // When the optional SrcElements argument is *not* present, the regular
686  // CpAsyncOp is generated. CpAsyncOp reads bytes from the source (global
687  // memory) to fill DstElements number of elements in the destination
688  // (shared memory).
689  Value srcBytes = adaptor.getSrcElements();
690  if (srcBytes) {
691  // When the optional SrcElements argument is present, the source (global
692  // memory) of CpAsyncOp is read only for SrcElements number of elements.
693  // The rest of the DstElements in the destination (shared memory) are
694  // filled with zeros.
695  Value c3I32 =
696  LLVM::ConstantOp::create(b, b.getI32Type(), b.getI32IntegerAttr(3));
697  Value bitwidth = LLVM::ConstantOp::create(
698  b, b.getI32Type(),
699  b.getI32IntegerAttr(srcMemrefType.getElementTypeBitWidth()));
700  Value srcElementsI32 = LLVM::TruncOp::create(b, b.getI32Type(), srcBytes);
701  srcBytes = LLVM::LShrOp::create(
702  b, LLVM::MulOp::create(b, bitwidth, srcElementsI32), c3I32);
703  }
704  // Cache global (.cg) for 16 dst bytes, Cache all (.ca) for sizes other than
705  // 16 dst bytes.
706  NVVM::LoadCacheModifierKind cacheModifier =
707  (op.getBypassL1().value_or(false) && sizeInBytes == 16)
708  ? NVVM::LoadCacheModifierKind::CG
709  : NVVM::LoadCacheModifierKind::CA;
710 
711  NVVM::CpAsyncOp::create(
712  b, dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes),
713  NVVM::LoadCacheModifierKindAttr::get(op->getContext(), cacheModifier),
714  srcBytes);
715 
716  // Drop the result token.
717  Value zero =
718  LLVM::ConstantOp::create(b, IntegerType::get(op.getContext(), 32),
719  rewriter.getI32IntegerAttr(0));
720  rewriter.replaceOp(op, zero);
721  return success();
722  }
723 };
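// Rough sketch of the rewrite above (the IR below is illustrative only):
//   %t = nvgpu.device_async_copy %src[%i], %dst[%j], 8
//        : memref<128xf32> to memref<64xf32, 3>
// becomes an nvvm.cp.async copying 32 bytes (8 elements x 4 bytes); the .cg
// cache modifier is used only when bypassL1 is set and the copy is exactly 16
// bytes, and the async token is replaced by a dummy i32 constant.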
724 
725 struct NVGPUAsyncCreateGroupLowering
726  : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCreateGroupOp> {
727  using ConvertOpToLLVMPattern<
728  nvgpu::DeviceAsyncCreateGroupOp>::ConvertOpToLLVMPattern;
729 
730  LogicalResult
731  matchAndRewrite(nvgpu::DeviceAsyncCreateGroupOp op, OpAdaptor adaptor,
732  ConversionPatternRewriter &rewriter) const override {
733  NVVM::CpAsyncCommitGroupOp::create(rewriter, op.getLoc());
734  // Drop the result token.
735  Value zero = LLVM::ConstantOp::create(rewriter, op->getLoc(),
736  IntegerType::get(op.getContext(), 32),
737  rewriter.getI32IntegerAttr(0));
738  rewriter.replaceOp(op, zero);
739  return success();
740  }
741 };
742 
743 struct NVGPUAsyncWaitLowering
744  : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncWaitOp> {
745  using ConvertOpToLLVMPattern<
746  nvgpu::DeviceAsyncWaitOp>::ConvertOpToLLVMPattern;
747 
748  LogicalResult
749  matchAndRewrite(nvgpu::DeviceAsyncWaitOp op, OpAdaptor adaptor,
750  ConversionPatternRewriter &rewriter) const override {
751  // If numGroups is not present, pick 0 as a conservative, correct value.
752  int32_t numGroups = adaptor.getNumGroups().value_or(0);
753  NVVM::CpAsyncWaitGroupOp::create(rewriter, op.getLoc(), numGroups);
754  rewriter.eraseOp(op);
755  return success();
756  }
757 };
758 
759 /// Creates mbarrier object in shared memory
760 struct NVGPUMBarrierCreateLowering
761  : public ConvertOpToLLVMPattern<nvgpu::MBarrierCreateOp> {
762  using ConvertOpToLLVMPattern<nvgpu::MBarrierCreateOp>::ConvertOpToLLVMPattern;
763 
764  template <typename moduleT>
765  memref::GlobalOp generateGlobalBarrier(ConversionPatternRewriter &rewriter,
766  Operation *funcOp, moduleT moduleOp,
767  MemRefType barrierType) const {
768  SymbolTable symbolTable(moduleOp);
769  OpBuilder::InsertionGuard guard(rewriter);
770  rewriter.setInsertionPoint(&moduleOp.front());
771  auto global = memref::GlobalOp::create(
772  rewriter, funcOp->getLoc(), "__mbarrier",
773  /*sym_visibility=*/rewriter.getStringAttr("private"),
774  /*type=*/barrierType,
775  /*initial_value=*/ElementsAttr(),
776  /*constant=*/false,
777  /*alignment=*/rewriter.getI64IntegerAttr(8));
778  symbolTable.insert(global);
779  return global;
780  }
781 
782  LogicalResult
783  matchAndRewrite(nvgpu::MBarrierCreateOp op, OpAdaptor adaptor,
784  ConversionPatternRewriter &rewriter) const override {
785  Operation *funcOp = op->getParentOp();
786  MemRefType barrierType = nvgpu::getMBarrierMemrefType(
787  rewriter.getContext(), op.getBarriers().getType());
788 
789  memref::GlobalOp global;
790  if (auto moduleOp = funcOp->getParentOfType<gpu::GPUModuleOp>())
791  global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
792  else if (auto moduleOp = funcOp->getParentOfType<ModuleOp>())
793  global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
794 
795  rewriter.setInsertionPoint(op);
796  rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(op, barrierType,
797  global.getName());
798  return success();
799  }
800 };
801 
802 /// Base class for lowering mbarrier operations to nvvm intrinsics.
803 template <typename SourceOp>
804 struct MBarrierBasePattern : public ConvertOpToLLVMPattern<SourceOp> {
805 public:
806  using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
807  /// Returns the base pointer of the mbarrier object.
808  Value getMbarrierPtr(ImplicitLocOpBuilder &b,
809  nvgpu::MBarrierGroupType mbarType, Value memrefDesc,
810  Value mbarId,
811  ConversionPatternRewriter &rewriter) const {
812  MemRefType mbarrierMemrefType =
813  nvgpu::getMBarrierMemrefType(rewriter.getContext(), mbarType);
814  return ConvertToLLVMPattern::getStridedElementPtr(
815  rewriter, b.getLoc(), mbarrierMemrefType, memrefDesc, {mbarId});
816  }
817 };
818 
819 struct NVGPUMBarrierGetLowering
820  : public MBarrierBasePattern<nvgpu::MBarrierGetOp> {
821  using MBarrierBasePattern<nvgpu::MBarrierGetOp>::MBarrierBasePattern;
822 
823  LogicalResult
824  matchAndRewrite(nvgpu::MBarrierGetOp op, OpAdaptor adaptor,
825  ConversionPatternRewriter &rewriter) const override {
826  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
827  nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType();
828  rewriter.setInsertionPoint(op);
829  Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(),
830  adaptor.getMbarId(), rewriter);
831  Type resType = op.getMbarrierPointer().getType();
832  rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(op, resType, barrier);
833  return success();
834  }
835 };
836 
837 /// Lowers `nvgpu.mbarrier.init` to `nvvm.mbarrier.init`
838 struct NVGPUMBarrierInitLowering
839  : public MBarrierBasePattern<nvgpu::MBarrierInitOp> {
840  using MBarrierBasePattern<nvgpu::MBarrierInitOp>::MBarrierBasePattern;
841 
842  LogicalResult
843  matchAndRewrite(nvgpu::MBarrierInitOp op, OpAdaptor adaptor,
844  ConversionPatternRewriter &rewriter) const override {
845  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
846  nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType();
847  rewriter.setInsertionPoint(op);
848  Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(),
849  adaptor.getMbarId(), rewriter);
850  Value count = truncToI32(b, adaptor.getCount());
851  if (isMbarrierShared(mbarrierType)) {
852  rewriter.replaceOpWithNewOp<NVVM::MBarrierInitSharedOp>(
853  op, barrier, count, adaptor.getPredicate());
854  } else {
855  rewriter.replaceOpWithNewOp<NVVM::MBarrierInitOp>(op, barrier, count,
856  adaptor.getPredicate());
857  }
858  return success();
859  }
860 };
861 
862 /// Lowers `nvgpu.mbarrier.arrive` to `nvvm.mbarrier.arrive`
863 struct NVGPUMBarrierArriveLowering
864  : public MBarrierBasePattern<nvgpu::MBarrierArriveOp> {
865  using MBarrierBasePattern<nvgpu::MBarrierArriveOp>::MBarrierBasePattern;
866  LogicalResult
867  matchAndRewrite(nvgpu::MBarrierArriveOp op, OpAdaptor adaptor,
868  ConversionPatternRewriter &rewriter) const override {
869  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
870  Value barrier =
871  getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
872  adaptor.getMbarId(), rewriter);
873  Type tokenType = getTypeConverter()->convertType(
874  nvgpu::MBarrierTokenType::get(op->getContext()));
875  if (isMbarrierShared(op.getBarriers().getType())) {
876  rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveSharedOp>(op, tokenType,
877  barrier);
878  } else {
879  rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveOp>(op, tokenType,
880  barrier);
881  }
882  return success();
883  }
884 };
885 
886 /// Lowers `nvgpu.mbarrier.arrive.nocomplete` to
887 /// `nvvm.mbarrier.arrive.nocomplete`
888 struct NVGPUMBarrierArriveNoCompleteLowering
889  : public MBarrierBasePattern<nvgpu::MBarrierArriveNoCompleteOp> {
890  using MBarrierBasePattern<
891  nvgpu::MBarrierArriveNoCompleteOp>::MBarrierBasePattern;
892  LogicalResult
893  matchAndRewrite(nvgpu::MBarrierArriveNoCompleteOp op, OpAdaptor adaptor,
894  ConversionPatternRewriter &rewriter) const override {
895  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
896  Value barrier =
897  getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
898  adaptor.getMbarId(), rewriter);
899  Type tokenType = getTypeConverter()->convertType(
900  nvgpu::MBarrierTokenType::get(op->getContext()));
901  Value count = truncToI32(b, adaptor.getCount());
902  if (isMbarrierShared(op.getBarriers().getType())) {
903  rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteSharedOp>(
904  op, tokenType, barrier, count);
905  } else {
906  rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteOp>(
907  op, tokenType, barrier, count);
908  }
909  return success();
910  }
911 };
912 
913 /// Lowers `nvgpu.mbarrier.test.wait` to `nvvm.mbarrier.test.wait`
914 struct NVGPUMBarrierTestWaitLowering
915  : public MBarrierBasePattern<nvgpu::MBarrierTestWaitOp> {
916  using MBarrierBasePattern<nvgpu::MBarrierTestWaitOp>::MBarrierBasePattern;
917  LogicalResult
918  matchAndRewrite(nvgpu::MBarrierTestWaitOp op, OpAdaptor adaptor,
919  ConversionPatternRewriter &rewriter) const override {
920  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
921  Value barrier =
922  getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
923  adaptor.getMbarId(), rewriter);
924  Type retType = rewriter.getI1Type();
925  if (isMbarrierShared(op.getBarriers().getType())) {
926  rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitSharedOp>(
927  op, retType, barrier, adaptor.getToken());
928  } else {
929  rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitOp>(
930  op, retType, barrier, adaptor.getToken());
931  }
932  return success();
933  }
934 };
935 
936 struct NVGPUMBarrierArriveExpectTxLowering
937  : public MBarrierBasePattern<nvgpu::MBarrierArriveExpectTxOp> {
938  using MBarrierBasePattern<
939  nvgpu::MBarrierArriveExpectTxOp>::MBarrierBasePattern;
940  LogicalResult
941  matchAndRewrite(nvgpu::MBarrierArriveExpectTxOp op, OpAdaptor adaptor,
942  ConversionPatternRewriter &rewriter) const override {
943  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
944  Value barrier =
945  getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
946  adaptor.getMbarId(), rewriter);
947  Value txcount = truncToI32(b, adaptor.getTxcount());
948 
949  if (isMbarrierShared(op.getBarriers().getType())) {
950  rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxSharedOp>(
951  op, barrier, txcount, adaptor.getPredicate());
952  return success();
953  }
954 
955  rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>(
956  op, barrier, txcount, adaptor.getPredicate());
957  return success();
958  }
959 };
960 
961 struct NVGPUMBarrierTryWaitParityLowering
962  : public MBarrierBasePattern<nvgpu::MBarrierTryWaitParityOp> {
963  using MBarrierBasePattern<
964  nvgpu::MBarrierTryWaitParityOp>::MBarrierBasePattern;
965  LogicalResult
966  matchAndRewrite(nvgpu::MBarrierTryWaitParityOp op, OpAdaptor adaptor,
967  ConversionPatternRewriter &rewriter) const override {
968  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
969  Value barrier =
970  getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
971  adaptor.getMbarId(), rewriter);
972  Value ticks = truncToI32(b, adaptor.getTicks());
973  Value phase =
974  LLVM::ZExtOp::create(b, b.getI32Type(), adaptor.getPhaseParity());
975 
976  if (isMbarrierShared(op.getBarriers().getType())) {
977  rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParitySharedOp>(
978  op, barrier, phase, ticks);
979  return success();
980  }
981 
982  rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParityOp>(op, barrier,
983  phase, ticks);
984  return success();
985  }
986 };
987 
988 struct NVGPUTmaAsyncLoadOpLowering
989  : public MBarrierBasePattern<nvgpu::TmaAsyncLoadOp> {
990  using MBarrierBasePattern<nvgpu::TmaAsyncLoadOp>::MBarrierBasePattern;
991  LogicalResult
992  matchAndRewrite(nvgpu::TmaAsyncLoadOp op, OpAdaptor adaptor,
993  ConversionPatternRewriter &rewriter) const override {
994  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
995  auto srcMemrefType = cast<MemRefType>(op.getDst().getType());
996  Value dest = getStridedElementPtr(rewriter, op->getLoc(), srcMemrefType,
997  adaptor.getDst(), {});
998  Value barrier =
999  getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
1000  adaptor.getMbarId(), rewriter);
1001 
1002  SmallVector<Value> coords = adaptor.getCoordinates();
1003  for (auto [index, value] : llvm::enumerate(coords)) {
1004  coords[index] = truncToI32(b, value);
1005  }
1006  rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(
1007  op, dest, adaptor.getTensorMapDescriptor(), coords, barrier,
1008  ValueRange{}, adaptor.getMulticastMask(), Value{},
1009  adaptor.getPredicate());
1010  return success();
1011  }
1012 };
1013 
1014 struct NVGPUTmaAsyncStoreOpLowering
1015  : public MBarrierBasePattern<nvgpu::TmaAsyncStoreOp> {
1016  using MBarrierBasePattern<nvgpu::TmaAsyncStoreOp>::MBarrierBasePattern;
1017  LogicalResult
1018  matchAndRewrite(nvgpu::TmaAsyncStoreOp op, OpAdaptor adaptor,
1019  ConversionPatternRewriter &rewriter) const override {
1020  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1021  auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
1022  Value dest = getStridedElementPtr(rewriter, op->getLoc(), srcMemrefType,
1023  adaptor.getSrc(), {});
1024  SmallVector<Value> coords = adaptor.getCoordinates();
1025  for (auto [index, value] : llvm::enumerate(coords)) {
1026  coords[index] = truncToI32(b, value);
1027  }
1028 
1029  // TODO: Enhance the NVGPU Op for other modes too
1030  rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(
1031  op, adaptor.getTensorMapDescriptor(), dest, coords, Value{},
1032  NVVM::TMAStoreMode::TILE, // default is TILE mode
1033  adaptor.getPredicate());
1034  return success();
1035  }
1036 };
1037 
1038 struct NVGPUGenerateWarpgroupDescriptorLowering
1039  : public ConvertOpToLLVMPattern<nvgpu::WarpgroupGenerateDescriptorOp> {
1040  using ConvertOpToLLVMPattern<
1041  nvgpu::WarpgroupGenerateDescriptorOp>::ConvertOpToLLVMPattern;
1042 
1043  LogicalResult
1044  matchAndRewrite(nvgpu::WarpgroupGenerateDescriptorOp op, OpAdaptor adaptor,
1045  ConversionPatternRewriter &rewriter) const override {
1046 
1047  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1048 
1049  nvgpu::TensorMapSwizzleKind swizzleKind =
1050  op.getTensorMap().getType().getSwizzle();
1051 
1052  unsigned layout =
1053  (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 128
1054  : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 64
1055  : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 32
1056  : 1;
1057  unsigned swizzle =
1058  (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 1
1059  : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 2
1060  : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 3
1061  : 0;
1062 
1063  auto ti64 = b.getIntegerType(64);
1064  auto makeConst = [&](uint64_t index) -> Value {
1065  return LLVM::ConstantOp::create(b, ti64, b.getI64IntegerAttr(index));
1066  };
1067  auto shiftLeft = [&](Value value, unsigned shift) -> Value {
1068  return LLVM::ShlOp::create(b, ti64, value, makeConst(shift));
1069  };
1070  auto shiftRight = [&](Value value, unsigned shift) -> Value {
1071  return LLVM::LShrOp::create(b, ti64, value, makeConst(shift));
1072  };
1073  auto insertBit = [&](Value desc, Value val, int startBit) {
1074  return LLVM::OrOp::create(b, ti64, desc, shiftLeft(val, startBit));
1075  };
1076 
1077  int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0);
1078  uint64_t strideDimVal = (layout << 3) >> exclude4LSB;
1079  uint64_t leadDimVal = (sizeN * layout) >> exclude4LSB;
1080  uint64_t offsetVal = 0;
1081 
1082  Value strideDim = makeConst(strideDimVal);
1083  Value leadDim = makeConst(leadDimVal);
1084 
1085  Value baseAddr = getStridedElementPtr(
1086  rewriter, op->getLoc(), cast<MemRefType>(op.getTensor().getType()),
1087  adaptor.getTensor(), {});
1088  Value basePtr = LLVM::PtrToIntOp::create(b, ti64, baseAddr);
1089  // Just use 14 bits for base address
1090  Value basePtr14bit = shiftRight(shiftLeft(basePtr, 46), 50);
1091 
1092  int startSwizzleBit = 62, startOffsetBit = 49, startStrideBit = 32,
1093  startLeadBit = 16, startBaseAddrBit = 0;
1094  Value dsc = makeConst(0);
1095  // // [62,64) swizzle type
1096  dsc = insertBit(dsc, makeConst(swizzle), startSwizzleBit);
1097  // // [49,52) base_offset
1098  dsc = insertBit(dsc, makeConst(offsetVal), startOffsetBit);
1099  // // [32,46) stride
1100  dsc = insertBit(dsc, strideDim, startStrideBit);
1101  // // [16,30) leading dimension
1102  dsc = insertBit(dsc, leadDim, startLeadBit);
1103  // // [0,14) start_address
1104  dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit);
1105 
1106  LDBG() << "Generating warpgroup.descriptor: " << "leading_off:"
1107  << leadDimVal << "\t" << "stride_off :" << strideDimVal << "\t"
1108  << "base_offset:" << offsetVal << "\t" << "layout_type:" << swizzle
1109  << " (" << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
1110  << ")\n start_addr : " << baseAddr;
1111 
1112  rewriter.replaceOp(op, dsc);
1113  return success();
1114  }
1115 };
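// Worked example for the descriptor built above (values are illustrative,
// assuming an f16 tensor with 128-byte swizzle and sizeN = 128):
//   layout = 128, swizzle = 1
//   strideDimVal = (128 << 3) >> exclude4LSB = 64
//   leadDimVal   = (128 * 128) >> exclude4LSB = 1024
// These land in bits [32,46) and [16,30) of the 64-bit descriptor, with the
// swizzle kind in bits [62,64) and the 14-bit base address in bits [0,14).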
1116 
1117 static Value makeI64Const(ImplicitLocOpBuilder &b, int32_t index) {
1118  return LLVM::ConstantOp::create(b, b.getIntegerType(64),
1119  b.getI32IntegerAttr(index));
1120 }
1121 
1122 /// Returns a Value that holds the data type enum expected by the CUDA driver.
1123 static Value elementTypeAsLLVMConstant(ImplicitLocOpBuilder &b, Type type) {
1124  // Enum is from CUDA driver API
1125  // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html
1126  enum CUtensorMapDataTypeEnum {
1127  CU_TENSOR_MAP_DATA_TYPE_UINT8 = 0,
1128  CU_TENSOR_MAP_DATA_TYPE_UINT16,
1129  CU_TENSOR_MAP_DATA_TYPE_UINT32,
1130  CU_TENSOR_MAP_DATA_TYPE_INT32,
1131  CU_TENSOR_MAP_DATA_TYPE_UINT64,
1132  CU_TENSOR_MAP_DATA_TYPE_INT64,
1133  CU_TENSOR_MAP_DATA_TYPE_FLOAT16,
1134  CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
1135  CU_TENSOR_MAP_DATA_TYPE_FLOAT64,
1136  CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
1137  CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ,
1138  CU_TENSOR_MAP_DATA_TYPE_TFLOAT32,
1139  CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ
1140  };
1141 
1142  if (type.isUnsignedInteger(8))
1143  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT8);
1144  if (type.isUnsignedInteger(16))
1145  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT16);
1146  if (type.isUnsignedInteger(32))
1147  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT32);
1148  if (type.isUnsignedInteger(64))
1149  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT64);
1150  if (type.isSignlessInteger(32))
1151  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT32);
1152  if (type.isSignlessInteger(64))
1153  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT64);
1154  if (type.isF16())
1155  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT16);
1156  if (type.isF32())
1157  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT32);
1158  if (type.isF64())
1159  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT64);
1160  if (type.isBF16())
1161  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16);
1162 
1163  llvm_unreachable("Not supported data type");
1164 }
1165 
1166 struct NVGPUTmaCreateDescriptorOpLowering
1167  : public ConvertOpToLLVMPattern<nvgpu::TmaCreateDescriptorOp> {
1168  using ConvertOpToLLVMPattern<
1169  nvgpu::TmaCreateDescriptorOp>::ConvertOpToLLVMPattern;
1170  LogicalResult
1171  matchAndRewrite(nvgpu::TmaCreateDescriptorOp op, OpAdaptor adaptor,
1172  ConversionPatternRewriter &rewriter) const override {
1173  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1174  auto llvmPointerType = LLVM::LLVMPointerType::get(op->getContext());
1175  Type llvmInt64Type = IntegerType::get(op->getContext(), 64);
1176 
1177  Value tensorElementType =
1178  elementTypeAsLLVMConstant(b, op.getTensor().getType().getElementType());
1179  auto promotedOperands = getTypeConverter()->promoteOperands(
1180  b.getLoc(), op->getOperands(), adaptor.getOperands(), b);
1181 
1182  Value boxArrayPtr = LLVM::AllocaOp::create(
1183  b, llvmPointerType, llvmInt64Type, makeI64Const(b, 5));
1184  for (auto [index, value] : llvm::enumerate(adaptor.getBoxDimensions())) {
1185  Value gep = LLVM::GEPOp::create(b, llvmPointerType, llvmPointerType,
1186  boxArrayPtr, makeI64Const(b, index));
1187  LLVM::StoreOp::create(b, value, gep);
1188  }
1189 
1190  nvgpu::TensorMapDescriptorType desc = op.getTensorMap().getType();
1191  // Set Arguments for the function call
1192  SmallVector<Value> arguments;
1193  arguments.push_back(promotedOperands[0]); // rank
1194  arguments.push_back(promotedOperands[1]); // descriptor
1195  arguments.push_back(tensorElementType); // data type
1196  arguments.push_back(
1197  makeI64Const(b, (int)desc.getInterleave())); // interleave
1198  arguments.push_back(makeI64Const(b, (int)desc.getSwizzle())); // swizzle
1199  arguments.push_back(makeI64Const(b, (int)desc.getL2promo())); // l2promo
1200  arguments.push_back(makeI64Const(b, (int)desc.getOob())); // oob
1201  arguments.push_back(boxArrayPtr); // box dimensions
1202 
1203  // Set data types of the arguments
1204  SmallVector<Type> argTypes = {
1205  llvmInt64Type, /* int64_t tensorRank */
1206  llvmPointerType, /* ptr */
1207  llvmInt64Type, /* int64_t */
1208  llvmInt64Type, /* int64_t */
1209  llvmInt64Type, /* int64_t */
1210  llvmInt64Type, /* int64_t */
1211  llvmInt64Type, /* int64_t */
1212  llvmPointerType /* ptr */
1213  };
1214  FunctionCallBuilder hostRegisterCallBuilder = {
1215  "mgpuTensorMapEncodeTiledMemref", llvmPointerType, argTypes};
1216  Value tensorMap =
1217  hostRegisterCallBuilder.create(b.getLoc(), b, arguments).getResult();
1218 
1219  rewriter.replaceOp(op, tensorMap);
1220  return success();
1221  }
1222 };
1223 
1224 struct NVGPUWarpgroupMmaOpLowering
1225  : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp> {
1226  using ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp>::ConvertOpToLLVMPattern;
1227 
1228  /// This is a helper class to generate required NVVM Ops for warp-group level
1229  /// matrix multiplication.
1230  /// When the given GEMM shape is larger than the shape of
1231  /// a wgmma instruction in PTX, it generates multiple NVVM::WgmmaMmaAsyncOp
1232  /// ops, groups them, and executes them asynchronously. The class also handles
1233  /// waiting for completion and iterates through WarpgroupMatrixDescriptor to
1234  /// create descriptors for each instruction.
1235  ///
1236  /// For example this is the case when the shape of GEMM is 128x128x128
1237  ///
1238  /// nvvm.wgmma.fence.aligned
1239  ///
1240  /// nvvm.wgmma.mma.async descA, descB
1241  /// iterate(descA, descB)
1242  /// nvvm.wgmma.mma.async descA, descB
1243  /// [6x times more]
1244  ///
1245  /// nvvm.wgmma.group.sync.aligned
1246  /// nvvm.wgmma.wait.group.sync [groupId]
1247  ///
1248  class WarpgroupGemm {
1249  nvgpu::WarpgroupMmaOp op;
1250  ImplicitLocOpBuilder &b;
1251  OpAdaptor adaptor;
1252 
1253  // Entire shape of the given Op
1254  int64_t totalM, totalN, totalK;
1255 
1256  // Shape of one wgmma instruction
1257  int wgmmaM = 0, wgmmaN = 0, wgmmaK = 0;
1258 
1259  // Iteration counts for GEMM
1260  int iterationM = 0, iterationN = 0, iterationK = 0;
1261 
1262  /// Returns the shape of the wgmma instruction as defined in the PTX
1263  /// programming guide.
1264  /// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shape
1265  void findWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType) {
1266  wgmmaM = 64;
1267  wgmmaN = sizeN;
1268  if (inputElemType.isTF32()) {
1269  wgmmaK = 8;
1270  } else if (inputElemType.isF16() || inputElemType.isBF16()) {
1271  wgmmaK = 16;
1272  } else if (isa<Float8E4M3FNType, Float8E5M2Type>(inputElemType) ||
1273  inputElemType.isInteger(16)) {
1274  wgmmaK = 32;
1275  } else if (inputElemType.isInteger(1)) {
1276  wgmmaK = 256;
1277  } else {
1278  llvm_unreachable("msg: not supported K shape");
1279  }
1280  LDBG() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
1281  << ", n = " << wgmmaN << ", k = " << wgmmaK << "]";
1282  }
1283 
1284  /// Generates WGMMATypesAttr from MLIR Type
1285  NVVM::WGMMATypesAttr generateWgmmaType(Type type,
1286  bool useF32 = false) const {
1287  auto getWgmmaType = [=](Type elemType) {
1288  if (elemType.isF32() || elemType.isTF32())
1289  return useF32 ? NVVM::WGMMATypes::f32 : NVVM::WGMMATypes::tf32;
1290  if (elemType.isF16())
1291  return NVVM::WGMMATypes::f16;
1292  if (elemType.isBF16())
1293  return NVVM::WGMMATypes::bf16;
1294  if (isa<Float8E4M3FNType>(elemType))
1295  return NVVM::WGMMATypes::e4m3;
1296  if (isa<Float8E5M2Type>(elemType))
1297  return NVVM::WGMMATypes::e5m2;
1298  if (elemType.isInteger(1))
1299  return NVVM::WGMMATypes::b1;
1300  if (elemType.isInteger(8))
1301  return NVVM::WGMMATypes::s8;
1302  if (elemType.isUnsignedInteger(8))
1303  return NVVM::WGMMATypes::u8;
1304  if (elemType.isInteger(32))
1305  return NVVM::WGMMATypes::s32;
1306  llvm_unreachable("unsupported type");
1307  };
1308  return NVVM::WGMMATypesAttr::get(op->getContext(), getWgmmaType(type));
1309  }
1310 
1311  /// Generates layout attribute for the input matrix for wgmma instruction
1312  NVVM::MMALayoutAttr
1313  generateWgmmaLayout(std::optional<bool> transpose) const {
1314  if (transpose.value_or(false))
1315  return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::col);
1316  return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::row);
1317  }
1318 
1319  /// Generates shape attribute for wgmma instruction
1320  NVVM::MMAShapeAttr generateWgmmaShape() const {
1321  return NVVM::MMAShapeAttr::get(op->getContext(), wgmmaM, wgmmaN, wgmmaK);
1322  }
1323 
1324  /// Generates scale attributes of output matrix for wgmma instruction
1325  NVVM::WGMMAScaleOutAttr generateScaleOut() const {
1326  return NVVM::WGMMAScaleOutAttr::get(op->getContext(),
1327  NVVM::WGMMAScaleOut::one);
1328  }
1329  /// Generates scale attributes of input matrix for wgmma instruction
1330  NVVM::WGMMAScaleInAttr generateScaleIn() const {
1331  return NVVM::WGMMAScaleInAttr::get(op->getContext(),
1332  NVVM::WGMMAScaleIn::one);
1333  }
1334 
1335  /// Basic function to generate Add
1336  Value makeAdd(Value lhs, Value rhs) {
1337  return LLVM::AddOp::create(b, lhs.getType(), lhs, rhs);
1338  };
1339 
1340  /// Moves the descriptor pointer of matrix-A for the next wgmma instruction.
1341  /// Currently, it only handles row-major.
1342  ///
1343  /// It moves the pointer like below for [128][64] size:
1344  /// +2 +4 +6
1345  /// ↓ ↓ ↓
1346  /// descA ---> +--+--+--+--+
1347  /// |->|->|->|->|
1348  /// | | | | |
1349  /// | | | | |
1350  /// | | | | |
1351  /// descA+512---> +-----------+
1352  /// | | | | |
1353  /// | | | | |
1354  /// | | | | |
1355  /// | | | | |
1356  /// +-----------+
1357  ///
1358  Value iterateDescriptorA(Value desc, int i, int j, int k) {
1359  MemRefType matrixTypeA = op.getDescriptorA().getType().getTensor();
1360  Type elemA = matrixTypeA.getElementType();
1361  int byte = elemA.getIntOrFloatBitWidth() / 8;
1362  int tileShapeA = matrixTypeA.getDimSize(1);
1363  int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte;
1364  incrementVal = incrementVal >> exclude4LSB;
1365  LDBG() << "\t\t[m: " << i << " n: " << j << " k: " << k
1366  << "] [wgmma descriptors] Descriptor A + " << incrementVal
1367  << " | \t ";
1368  if (!incrementVal)
1369  return desc;
1370  return makeAdd(desc, makeI64Const(b, incrementVal));
1371  }
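    // For illustration, with f16 elements, tileShapeA = 64, totalK = 64, and
    // wgmmaK = 16, the increment is ((16 * k) + (64 * 64 * i)) * 2 bytes
    // shifted right by exclude4LSB, i.e. +2 per k-step when i = 0, matching
    // the +2/+4/+6 offsets in the sketch above.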
1372 
1373  /// Moves the descriptor pointer of matrix-B for the next wgmma instruction.
1374  /// Currently, it only handles column-major.
1375  ///
1376  /// It moves the pointer like below for [128][64] size:
1377  /// descB ---> +--+--+--+--+--+--+--+--+
1378  /// |↓ | | | | | | | |
1379  /// |↓ | | | | | | | |
1380  /// |↓ | | | | | | | |
1381  /// |↓ | | | | | | | |
1382  /// +--+--+--+--+--+--+--+--+
1383  ///
1384  Value iterateDescriptorB(Value desc, int i, int j, int k) {
1385  MemRefType matrixTypeB = op.getDescriptorB().getType().getTensor();
1386  Type elemB = matrixTypeB.getElementType();
1387  int byte = elemB.getIntOrFloatBitWidth() / 8;
1388  int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte;
1389  incrementVal = incrementVal >> exclude4LSB;
1390  LDBG() << "Descriptor B + " << incrementVal;
1391  if (!incrementVal)
1392  return desc;
1393  return makeAdd(desc, makeI64Const(b, incrementVal));
1394  }
1395 
1396  /// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix
1397  /// descriptors and arranges them based on induction variables: i, j, and k.
1398  Value generateWgmma(int i, int j, int k, Value matrixC) {
1399  LDBG() << "\t wgmma." << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
1400  << "(A[" << (iterationM * wgmmaM) << ":"
1401  << (iterationM * wgmmaM) + wgmmaM << "][" << (iterationK * wgmmaK)
1402  << ":" << (iterationK * wgmmaK + wgmmaK) << "] * " << " B["
1403  << (iterationK * wgmmaK) << ":" << (iterationK * wgmmaK + wgmmaK)
1404  << "][" << 0 << ":" << wgmmaN << "])";
1405 
1406  Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k);
1407  Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k);
1408 
1409  Type elemA = op.getDescriptorA().getType().getTensor().getElementType();
1410  NVVM::WGMMATypesAttr itypeA = generateWgmmaType(elemA);
1411 
1412  Type elemB = op.getDescriptorB().getType().getTensor().getElementType();
1413  NVVM::WGMMATypesAttr itypeB = generateWgmmaType(elemB);
1414 
1415  Type elemD = op.getMatrixC().getType().getFragmented().getElementType();
1416  NVVM::WGMMATypesAttr itypeD = generateWgmmaType(elemD, true);
1417 
1418  NVVM::MMAShapeAttr shape = generateWgmmaShape();
1419  NVVM::WGMMAScaleOutAttr scaleOut = generateScaleOut();
1420  NVVM::WGMMAScaleInAttr scaleIn = generateScaleIn();
1421  NVVM::MMALayoutAttr layoutA = generateWgmmaLayout(op.getTransposeA());
1422  NVVM::MMALayoutAttr layoutB = generateWgmmaLayout(!op.getTransposeB());
1423 
1424  auto overflow = NVVM::MMAIntOverflowAttr::get(
1425  op->getContext(), NVVM::MMAIntOverflow::wrapped);
1426 
1427  return NVVM::WgmmaMmaAsyncOp::create(
1428  b, matrixC.getType(), matrixC, descriptorA, descriptorB, shape,
1429  itypeA, itypeB, itypeD, scaleOut, scaleIn, scaleIn, layoutA, layoutB,
1430  overflow);
1431  }
1432 
1433  /// Generates multiple wgmma instructions to complete the given GEMM shape
1434  Value generateWgmmaGroup() {
1435  Value wgmmaResult =
1436  LLVM::PoisonOp::create(b, adaptor.getMatrixC().getType());
1437 
1438  // Perform GEMM
1439  SmallVector<Value> wgmmaResults;
1440  for (int i = 0; i < iterationM; ++i) {
1441  Value matrixC =
1442  LLVM::ExtractValueOp::create(b, adaptor.getMatrixC(), i);
1443  for (int j = 0; j < iterationN; ++j)
1444  for (int k = 0; k < iterationK; ++k)
1445  matrixC = generateWgmma(i, j, k, matrixC);
1446  wgmmaResults.push_back(matrixC);
1447  }
1448  for (auto [idx, matrix] : llvm::enumerate(wgmmaResults)) {
1449  wgmmaResult = LLVM::InsertValueOp::create(b, wgmmaResult.getType(),
1450  wgmmaResult, matrix, idx);
1451  }
1452  return wgmmaResult;
1453  }
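 // Sketch of the emitted chain (illustrative; the iteration counts are
 // assumed): with iterationM = 2, iterationN = 1, iterationK = 4, the two
 // m-tile accumulators extracted from adaptor.getMatrixC() are each threaded
 // through 4 consecutive WgmmaMmaAsyncOps, and the updated accumulators are
 // re-packed into the outer result struct via InsertValueOps.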
1454 
1455  public:
1456  WarpgroupGemm(nvgpu::WarpgroupMmaOp op, ImplicitLocOpBuilder &b,
1457  OpAdaptor adaptor)
1458  : op(op), b(b), adaptor(adaptor) {
1459  // Find the entire GEMM Shape
1460  totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
1461  totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
1462  totalK = op.getDescriptorA().getType().getTensor().getDimSize(1);
1463  LDBG() << "===--- GEMM D[" << totalM << "][" << totalN << "] += A["
1464  << totalM << "][" << totalK << "] * B[" << totalK << "][" << totalN
1465  << "] ---===";
1466 
1467  // Find the shape for one wgmma instruction
1468  findWgmmaShape(
1469  totalM, totalN,
1470  op.getDescriptorA().getType().getTensor().getElementType());
1471 
1472  // Iteration counts needed to cover the given GEMM shape with the wgmma shape
1473  iterationM = totalM / wgmmaM;
1474  iterationN = totalN / wgmmaN;
1475  iterationK = totalK / wgmmaK;
1476  }
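 // Example (illustrative; the wgmmaN chosen by findWgmmaShape is an
 // assumption here): for f16 operands A[128][64] and B[64][128], totalM = 128,
 // totalN = 128, totalK = 64. With wgmmaM = 64, wgmmaN = 128 and wgmmaK = 16,
 // the iteration counts become iterationM = 2, iterationN = 1, iterationK = 4.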
1477 
1478  /// Generates WgmmaMmaAsync Ops to complete the specified GEMM shape. A
1479  /// fence Op (WgmmaFenceAlignedOp) is generated before the instructions;
1480  /// after them, the issued wgmma group is committed
1481  /// (WgmmaGroupSyncAlignedOp) and then waited on
1482  /// (WgmmaWaitGroupSyncOp).
1483  Value generateWarpgroupMma() {
1484  NVVM::WgmmaFenceAlignedOp::create(b);
1485  Value wgmmaResult = generateWgmmaGroup();
1486  NVVM::WgmmaGroupSyncAlignedOp::create(b);
1487  NVVM::WgmmaWaitGroupSyncOp::create(b, op.getWaitGroup());
1488  return wgmmaResult;
1489  }
1490  };
1491  LogicalResult
1492  matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
1493  ConversionPatternRewriter &rewriter) const override {
1494  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1495 
1496  // Step 1. Build a helper class
1497  WarpgroupGemm warpgroupGemm(op, b, adaptor);
1498 
1499  // Step 2. Generate wgmma instructions covering the entire GEMM shape
1500  Value wgmmaResult = warpgroupGemm.generateWarpgroupMma();
1501 
1502  // Step 3. Replace fragmented result struct with the op results
1503  rewriter.replaceOp(op, wgmmaResult);
1504  return success();
1505  }
1506 };
1507 
1508 struct NVGPUWarpgroupMmaStoreOpLowering
1509  : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaStoreOp> {
1510  using ConvertOpToLLVMPattern<
1511  nvgpu::WarpgroupMmaStoreOp>::ConvertOpToLLVMPattern;
1512 
1513  /// This function stores a fragmented register matrix owned by a warp group
1514  /// (128 threads) into a memref. Each thread holds 64 registers, packed into
1515  /// a single struct value.
1516  /// Below is what each thread (T) holds; each `d` is one struct element,
1517  /// identified by its number.
1518  ///
1519  /// Threads in the warp-group (128 threads) and what they own in matrix-D:
1520  /// 0-31 Warp-0 -> MatrixD[0:15 ][0:N]
1521  /// 32-63 Warp-1 -> MatrixD[16:31][0:N]
1522  /// 64-95 Warp-2 -> MatrixD[32:47][0:N]
1523  /// 96-127 Warp-3 -> MatrixD[48:63][0:N]
1524  ///
1525  /// Matrix-D:
1526  /// +______________________________________________________________________+
1527  /// | 0-1 | 2-3 | 4-5 | 6-7 | 8-9 | 10-11|..|N-8,N-7 |
1528  /// 0 | T0:d0-d1 |T1:d0-d1 |T2:d0-d1 |T3:d0-d1 |T0:d4-d5| T1:d4-d5..|T0:dX-dY|
1529  /// 1 | T4:d0-d1 |T5:d0-d1 |T6:d0-d1 |T7:d0-d1 |T4:d4-d5| T5:d4-d5..|T4:dX-dY|
1530  /// ..| .........|.........|.........|.........|........|...........|........|
1531  /// 8 | T0:d2-d3 |T1:d2-d3 |T2:d2-d3 |T3:d2-d3 |T0:d6-d7|T1:d6-d7,..|T0:dZ-dW|
1532  /// 9 | T4:d2-d3 |T5:d2-d3 |T6:d2-d3 |T7:d2-d3 |T4:d6-d7| T5:d6-d7..|T4:dZ-dW|
1533  /// ..| .........|.........|.........|.........|........|...........|........|
1534  /// 15| T28:d2-d3|T29:d2-d3|T30:d2-d3|T31:d2-d3|........|...........|........|
1535  /// 16| T32:d2-d3|T33:d2-d3|T34:d2-d3|T35:d2-d3|........|...........|........|
1536  /// ..| .........|.........|.........|.........|........|...........|........|
1537  /// 32| T64:d2-d3|T65:d2-d3|T66:d2-d3|T67:d2-d3|........|...........|........|
1538  /// ..| .........|.........|.........|.........|........|...........|........|
1539  /// 48| T96:d2-d3|T97:d2-d3|T98:d2-d3|T99:d2-d3|........|...........|........|
1540  /// ..| .........|.........|.........|.........|........|...........|........|
1541  /// +______________________________________________________________________+
1542  ///
1543  /// \param b: The builder, carrying the op's location implicitly.
1544  /// \param matrixD: Result of the warp-group MMA operation (fragmented
1545  /// matrix). It is held by each thread as a struct with 64 elements.
1546  /// \param dstMemref: The memref where the registers will be stored.
1547  /// \param offset: The offset within the memref where the registers will be
1548  /// stored.
1549  void storeFragmentedMatrix(ImplicitLocOpBuilder &b, Value matrixD,
1550  TypedValue<MemRefType> dstMemref,
1551  int offset) const {
1552  Type i32 = b.getI32Type();
1553 
1554  auto makeConst = [&](int32_t index) -> Value {
1555  return LLVM::ConstantOp::create(b, i32, b.getI32IntegerAttr(index));
1556  };
1557  Value c1 = makeConst(1);
1558  Value c2 = makeConst(2);
1559  Value c4 = makeConst(4);
1560  Value c8 = makeConst(8);
1561  Value c16 = makeConst(16);
1562  Value warpSize = makeConst(kWarpSize);
1563 
1564  auto makeMul = [&](Value lhs, Value rhs) -> Value {
1565  return LLVM::MulOp::create(b, lhs.getType(), lhs, rhs);
1566  };
1567  auto makeAdd = [&](Value lhs, Value rhs) -> Value {
1568  return LLVM::AddOp::create(b, lhs.getType(), lhs, rhs);
1569  };
1570 
1571  auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,
1572  TypedValue<MemRefType> memref) {
1573  Type it = b.getIndexType();
1574  Value idx = arith::IndexCastOp::create(b, it, x);
1575  Value idy0 = arith::IndexCastOp::create(b, it, y);
1576  Value idy1 = arith::IndexCastOp::create(b, it, makeAdd(y, c1));
1577  Value d0 = LLVM::ExtractValueOp::create(b, wgmmaResult, i);
1578  Value d1 = LLVM::ExtractValueOp::create(b, wgmmaResult, i + 1);
1579  memref::StoreOp::create(b, d0, memref, ValueRange{idx, idy0});
1580  memref::StoreOp::create(b, d1, memref, ValueRange{idx, idy1});
1581  };
1582 
1583  Value tidx = NVVM::ThreadIdXOp::create(b, i32);
1584  Value laneId = LLVM::URemOp::create(b, i32, tidx, warpSize);
1585  Value warpId = LLVM::UDivOp::create(b, i32, tidx, warpSize);
1586  Value lane4Id = LLVM::UDivOp::create(b, i32, laneId, c4);
1587  Value lane4modId = LLVM::URemOp::create(b, i32, laneId, c4);
1588 
1589  Value tj = makeMul(lane4modId, c2);
1590  Value ti = makeAdd(lane4Id, makeMul(warpId, c16));
1591  if (offset)
1592  ti = makeAdd(ti, makeConst(offset));
1593 
1594  auto structType = cast<LLVM::LLVMStructType>(matrixD.getType());
1595 
1596  // Number of adjacent 32-bit registers owned per thread
1597  constexpr unsigned numAdjacentRegisters = 2;
1598  // Number of 8x8 matrices stacked one below another per warp
1599  constexpr unsigned numStackedMatrices = 2;
1600 
1601  size_t storeCount = (structType.getBody().size() /
1602  (numStackedMatrices * numAdjacentRegisters));
1603 
1604  for (size_t i = 0; i < numStackedMatrices; ++i) {
1605  Value idx = makeAdd(ti, makeMul(makeConst(i), c8));
1606  for (size_t j = 0; j < storeCount; ++j) {
1607  Value idy = makeAdd(tj, makeMul(makeConst(j), c8));
1608  size_t structIndex = (i * numAdjacentRegisters) +
1609  (j * (numStackedMatrices * numAdjacentRegisters));
1610  makeExtractAndStore(structIndex, matrixD, idx, idy, dstMemref);
1611  }
1612  }
1613  }
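 // Worked index example (illustrative only): thread T5 in warp-0 (tidx = 5)
 // gets laneId = 5, warpId = 0, lane4Id = 1, lane4modId = 1, hence ti = 1 and
 // tj = 2. For i = 0, j = 0 it stores struct elements d0/d1 to
 // dstMemref[1][2] and dstMemref[1][3], matching row 1, columns 2-3 of the
 // matrix-D layout sketched above.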
1614 
1615  LogicalResult
1616  matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor,
1617  ConversionPatternRewriter &rewriter) const override {
1618  int offset = 0;
1619  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1620  Value matrixDValue = adaptor.getMatrixD();
1621  auto stype = cast<LLVM::LLVMStructType>(matrixDValue.getType());
1622  for (auto [idx, matrixD] : llvm::enumerate(stype.getBody())) {
1623  auto structType = cast<LLVM::LLVMStructType>(matrixD);
1624  Value innerStructValue =
1625  LLVM::ExtractValueOp::create(b, matrixDValue, idx);
1626  storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset);
1627  offset += structType.getBody().size();
1628  }
1629  rewriter.eraseOp(op);
1630  return success();
1631  }
1632 };
1633 
1634 struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
1635  : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaInitAccumulatorOp> {
1636  using ConvertOpToLLVMPattern<
1637  nvgpu::WarpgroupMmaInitAccumulatorOp>::ConvertOpToLLVMPattern;
1638  LogicalResult
1639  matchAndRewrite(nvgpu::WarpgroupMmaInitAccumulatorOp op, OpAdaptor adaptor,
1640  ConversionPatternRewriter &rewriter) const override {
1641  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1642  LLVM::LLVMStructType packStructType = cast<LLVM::LLVMStructType>(
1643  getTypeConverter()->convertType(op.getMatrixC().getType()));
1644  Type elemType = cast<LLVM::LLVMStructType>(packStructType.getBody().front())
1645  .getBody()
1646  .front();
1647  Value zero = LLVM::ConstantOp::create(b, elemType, b.getZeroAttr(elemType));
1648  Value packStruct = LLVM::PoisonOp::create(b, packStructType);
1649  SmallVector<Value> innerStructs;
1650  // Unpack the structs and set all values to zero
1651  for (auto [idx, s] : llvm::enumerate(packStructType.getBody())) {
1652  auto structType = cast<LLVM::LLVMStructType>(s);
1653  Value structValue = LLVM::ExtractValueOp::create(b, packStruct, idx);
1654  for (unsigned i = 0; i < structType.getBody().size(); ++i) {
1655  structValue = LLVM::InsertValueOp::create(b, structType, structValue,
1656  zero, ArrayRef<int64_t>({i}));
1657  }
1658  innerStructs.push_back(structValue);
1659  }
1660  // Pack the inner structs into a single struct
1661  for (auto [idx, matrix] : llvm::enumerate(innerStructs)) {
1662  packStruct = LLVM::InsertValueOp::create(b, packStruct.getType(),
1663  packStruct, matrix, idx);
1664  }
1665  rewriter.replaceOp(op, packStruct);
1666  return success();
1667  }
1668 };
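// Sketch of what the accumulator-init pattern above produces (illustrative;
// the struct shape is assumed): for a result type converted to a struct of
// two inner structs of 64 x f32 each, one f32 zero constant is inserted into
// every element of both inner structs, and the zero-filled inner structs are
// then packed back into the outer struct that replaces the op.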
1669 
1670 struct NVGPUTmaFenceOpLowering
1671  : public ConvertOpToLLVMPattern<nvgpu::TmaFenceOp> {
1672  using ConvertOpToLLVMPattern<nvgpu::TmaFenceOp>::ConvertOpToLLVMPattern;
1673  LogicalResult
1674  matchAndRewrite(nvgpu::TmaFenceOp op, OpAdaptor adaptor,
1675  ConversionPatternRewriter &rewriter) const override {
1676  MLIRContext *ctx = op.getContext();
1677  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1678  auto i32Ty = b.getI32Type();
1679  Value tensormapSize =
1680  LLVM::ConstantOp::create(b, i32Ty, rewriter.getI32IntegerAttr(128));
1681 
1682  auto memscope =
1683  NVVM::MemScopeKindAttr::get(ctx, ::mlir::NVVM::MemScopeKind::SYS);
1684 
1685  rewriter.replaceOpWithNewOp<NVVM::FenceProxyAcquireOp>(
1686  op, memscope, adaptor.getTensorMapDescriptor(), tensormapSize);
1687 
1688  return success();
1689  }
1690 };
1691 
1692 struct NVGPUTmaPrefetchOpLowering
1693  : public ConvertOpToLLVMPattern<nvgpu::TmaPrefetchOp> {
1694  using ConvertOpToLLVMPattern<nvgpu::TmaPrefetchOp>::ConvertOpToLLVMPattern;
1695  LogicalResult
1696  matchAndRewrite(nvgpu::TmaPrefetchOp op, OpAdaptor adaptor,
1697  ConversionPatternRewriter &rewriter) const override {
1698  rewriter.replaceOpWithNewOp<NVVM::PrefetchOp>(
1699  op, /* CacheLevel */ nullptr, /* Cache Eviction Priority */ nullptr,
1700  adaptor.getTensorMapDescriptor(), adaptor.getPredicate(),
1701  /* Tensormap UnitAttr */ mlir::UnitAttr::get(op.getContext()));
1702  return success();
1703  }
1704 };
1705 
1706 struct NVGPURcpOpLowering : public ConvertOpToLLVMPattern<nvgpu::RcpOp> {
1707  using ConvertOpToLLVMPattern<nvgpu::RcpOp>::ConvertOpToLLVMPattern;
1708  LogicalResult
1709  matchAndRewrite(nvgpu::RcpOp op, OpAdaptor adaptor,
1710  ConversionPatternRewriter &rewriter) const override {
1711  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1712  auto i64Ty = b.getI64Type();
1713  auto f32Ty = b.getF32Type();
1714  VectorType inTy = op.getIn().getType();
1715  // Apply rcp.approx.ftz.f32 on each element of the vector.
1716  auto convert1DVec = [&](Type llvm1DVectorTy, Value inVec) {
1717  Value ret1DVec = LLVM::PoisonOp::create(b, llvm1DVectorTy);
1718  int numElems = llvm::cast<VectorType>(llvm1DVectorTy).getNumElements();
1719  for (int i = 0; i < numElems; i++) {
1720  Value idx = LLVM::ConstantOp::create(b, i64Ty, b.getI64IntegerAttr(i));
1721  Value elem = LLVM::ExtractElementOp::create(b, inVec, idx);
1722  Value dst = NVVM::RcpApproxFtzF32Op::create(b, f32Ty, elem);
1723  ret1DVec = LLVM::InsertElementOp::create(b, ret1DVec, dst, idx);
1724  }
1725  return ret1DVec;
1726  };
1727  if (inTy.getRank() == 1) {
1728  rewriter.replaceOp(op, convert1DVec(inTy, adaptor.getIn()));
1729  return success();
1730  }
1731  return LLVM::detail::handleMultidimensionalVectors(
1732  op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()),
1733  [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
1734  OpAdaptor adaptor(operands);
1735  return convert1DVec(llvm1DVectorTy, adaptor.getIn());
1736  },
1737  rewriter);
1738  }
1739 };
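// Sketch of the expansion above (illustrative; the input type is assumed):
// for a rank-1 vector<4xf32> input, the pattern emits four
// extractelement / rcp.approx.ftz.f32 / insertelement triples into a poison
// vector and replaces nvgpu.rcp with the rebuilt vector; higher-rank inputs
// are first decomposed into 1-D vectors via handleMultidimensionalVectors.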
1740 } // namespace
1741 
1742 void mlir::populateNVGPUToNVVMConversionPatterns(
1743  const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
1744  patterns.add<
1745  NVGPUMBarrierCreateLowering, // nvgpu.mbarrier.create
1746  NVGPUMBarrierInitLowering, // nvgpu.mbarrier.init
1747  NVGPUMBarrierGetLowering, // nvgpu.mbarrier.get
1748  NVGPUMBarrierArriveLowering, // nvgpu.mbarrier.arrive
1749  NVGPUMBarrierArriveNoCompleteLowering, // nvgpu.mbarrier.arrive.no_complete
1750  NVGPUMBarrierTestWaitLowering, // nvgpu.mbarrier.test_wait_parity
1751  NVGPUMBarrierTryWaitParityLowering, // nvgpu.mbarrier.try_wait_parity
1752  NVGPUTmaAsyncLoadOpLowering, // nvgpu.tma.async.load
1753  NVGPUTmaAsyncStoreOpLowering, // nvgpu.tma.async.store
1754  NVGPUTmaCreateDescriptorOpLowering, // nvgpu.tma.create.descriptor
1755  NVGPUTmaPrefetchOpLowering, // nvgpu.tma.prefetch.descriptor
1756  NVGPUTmaFenceOpLowering, // nvgpu.tma.fence.descriptor
1757  NVGPUMBarrierArriveExpectTxLowering, // nvgpu.mbarrier.arrive.expect_tx
1758  NVGPUGenerateWarpgroupDescriptorLowering, // nvgpu.warpgroup.generate.descriptor
1759  NVGPUWarpgroupMmaOpLowering, // nvgpu.warpgroup.mma
1760  NVGPUWarpgroupMmaStoreOpLowering, // nvgpu.warpgroup.mma.store
1761  NVGPUWarpgroupMmaInitAccumulatorOpLowering, // nvgpu.warpgroup.mma.init.accumulator
1762  MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
1763  NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
1764  NVGPUMmaSparseSyncLowering, NVGPURcpOpLowering>(converter);
1765 }