1 //===- NVGPUToNVVM.cpp - NVGPU to NVVM dialect conversion -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
10 
11 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
12 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
13 #include "mlir/Conversion/LLVMCommon/Pattern.h"
14 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
15 #include "mlir/Dialect/Arith/IR/Arith.h"
16 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
17 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
18 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
19 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
20 #include "mlir/Dialect/MemRef/IR/MemRef.h"
21 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
22 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
23 #include "mlir/IR/BuiltinTypes.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/IR/TypeUtilities.h"
26 #include "mlir/IR/Value.h"
27 #include "mlir/Pass/Pass.h"
28 #include "llvm/Support/Debug.h"
29 #include "llvm/Support/DebugLog.h"
30 #include "llvm/Support/ErrorHandling.h"
31 #include "llvm/Support/raw_ostream.h"
32 #include <optional>
33 
34 #define DEBUG_TYPE "nvgpu-to-nvvm"
35 
36 namespace mlir {
37 #define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS
38 #include "mlir/Conversion/Passes.h.inc"
39 } // namespace mlir
40 
41 using namespace mlir;
42 
43 /// Number of bits that need to be excluded when building the matrix descriptor
44 /// for wgmma operations.
45 constexpr int exclude4LSB = 4;
46 
47 /// GPUs have 32-bit registers; this function truncates values when a larger
48 /// width is not needed.
49 static Value truncToI32(ImplicitLocOpBuilder &b, Value value) {
50  Type type = value.getType();
51  assert(llvm::isa<IntegerType>(type) && "expected an integer Value");
52  if (type.getIntOrFloatBitWidth() <= 32)
53  return value;
54  return LLVM::TruncOp::create(b, b.getI32Type(), value);
55 }
56 
57 /// Returns the type for the intrinsic given the vectorResultType of the
58 /// `gpu.mma.sync` operation.
59 static Type inferIntrinsicResultType(Type vectorResultType) {
60  MLIRContext *ctx = vectorResultType.getContext();
61  auto a = cast<LLVM::LLVMArrayType>(vectorResultType);
62  auto f16x2Ty = VectorType::get(2, Float16Type::get(ctx));
63  auto i32Ty = IntegerType::get(ctx, 32);
64  auto i32x2Ty = VectorType::get(2, i32Ty);
65  Type f64Ty = Float64Type::get(ctx);
66  Type f64x2Ty = VectorType::get(2, f64Ty);
67  Type f32Ty = Float32Type::get(ctx);
68  Type f32x2Ty = VectorType::get(2, f32Ty);
69  if (a.getElementType() == f16x2Ty) {
70  return LLVM::LLVMStructType::getLiteral(
71  ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty));
72  }
73  if (a.getElementType() == i32x2Ty) {
74  return LLVM::LLVMStructType::getLiteral(
75  ctx,
76  SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, i32Ty));
77  }
78  if (a.getElementType() == f64x2Ty) {
79  return LLVM::LLVMStructType::getLiteral(ctx, {f64Ty, f64Ty});
80  }
81  if (a.getElementType() == f32x2Ty) {
82  return LLVM::LLVMStructType::getLiteral(
83  ctx,
84  SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, f32Ty));
85  }
86  if (a.getElementType() == VectorType::get(1, f32Ty)) {
87  return LLVM::LLVMStructType::getLiteral(
88  ctx, SmallVector<Type>(static_cast<size_t>(a.getNumElements()), f32Ty));
89  }
90  return vectorResultType;
91 }
92 
93 /// Convert the SSA result of the NVVM intrinsic `nvvm.mma.sync` (which is
94 /// always an LLVM struct) into a fragment that is compatible with the vector
95 /// type of this operation. This involves extracting elements from the struct
96 /// and inserting them into an LLVM array. These extra data-movement
97 /// operations should be canonicalized away by the LLVM backend.
98 static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
99  Type resultType, Value intrinsicResult,
100  RewriterBase &rewriter) {
101  MLIRContext *ctx = rewriter.getContext();
102  auto structType = dyn_cast<LLVM::LLVMStructType>(intrinsicResultType);
103  auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType);
104  Type i32Ty = rewriter.getI32Type();
105  Type f32Ty = rewriter.getF32Type();
106  Type f64Ty = rewriter.getF64Type();
107  Type f16x2Ty = VectorType::get(2, rewriter.getF16Type());
108  Type i32x2Ty = VectorType::get(2, i32Ty);
109  Type f64x2Ty = VectorType::get(2, f64Ty);
110  Type f32x2Ty = VectorType::get(2, f32Ty);
111  Type f32x1Ty = VectorType::get(1, f32Ty);
112 
113  auto makeConst = [&](int32_t index) -> Value {
114  return LLVM::ConstantOp::create(rewriter, loc, IntegerType::get(ctx, 32),
115  rewriter.getI32IntegerAttr(index));
116  };
117 
118  if (arrayType) {
119  SmallVector<Value, 4> elements;
120 
121  // The intrinsic returns 32-bit wide elements in a form which can be
122  // directly bitcasted and inserted into the result vector.
123  if (arrayType.getElementType() == f16x2Ty ||
124  arrayType.getElementType() == f32x1Ty) {
125  for (unsigned i = 0; i < structType.getBody().size(); i++) {
126  Value el =
127  LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult, i);
128  el = rewriter.createOrFold<LLVM::BitcastOp>(
129  loc, arrayType.getElementType(), el);
130  elements.push_back(el);
131  }
132  }
133 
134  // The intrinsic returns i32, f64, and f32 values as individual scalars,
135  // even when the result is notionally a 64-bit wide element (e.g. f32x2). We
136  // need to extract them from the struct and pack them into the 64-bit wide
137  // rows of the vector result.
138  if (arrayType.getElementType() == i32x2Ty ||
139  arrayType.getElementType() == f64x2Ty ||
140  arrayType.getElementType() == f32x2Ty) {
141 
142  for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) {
143  Value vec =
144  LLVM::PoisonOp::create(rewriter, loc, arrayType.getElementType());
145  Value x1 =
146  LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult, i * 2);
147  Value x2 = LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult,
148  i * 2 + 1);
149  vec = LLVM::InsertElementOp::create(rewriter, loc, vec.getType(), vec,
150  x1, makeConst(0));
151  vec = LLVM::InsertElementOp::create(rewriter, loc, vec.getType(), vec,
152  x2, makeConst(1));
153  elements.push_back(vec);
154  }
155  }
156 
157  // Create the final vectorized result.
158  Value result = LLVM::PoisonOp::create(rewriter, loc, arrayType);
159  for (const auto &el : llvm::enumerate(elements)) {
160  result = LLVM::InsertValueOp::create(rewriter, loc, result, el.value(),
161  el.index());
162  }
163  return result;
164  }
165 
166  return intrinsicResult;
167 }
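// Illustrative sketch (shapes chosen for illustration, not taken from the
// original source): for an f16 16x8x16 mma.sync, the converted result type is
// `!llvm.array<2 x vector<2xf16>>` while the intrinsic returns
// `!llvm.struct<(vector<2xf16>, vector<2xf16>)>`; the helper above simply
// extracts both struct members and re-inserts them into the array. For an f64
// accumulator, the struct holds scalar f64 values that are packed back into
// `vector<2xf64>` rows with insertelement.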
168 
169 /// The `gpu.mma.sync` converter below expects matrix fragment operands to be
170 /// given as 2D `vectors` where the rows are 32b or 64b wide. The
171 /// `nvvm.mma.sync` op expects these arguments to be given as a long list of
172 /// scalars of certain types. This function helps unpack the `vector` arguments
173 /// and cast them to the types expected by `nvvm.mma.sync`.
174 static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b,
175  Value operand,
176  NVVM::MMATypes operandPtxType) {
177  SmallVector<Value> result;
178  Type i32Ty = b.getI32Type();
179  Type f64Ty = b.getF64Type();
180  Type f32Ty = b.getF32Type();
181  Type i64Ty = b.getI64Type();
182  Type i8x4Ty = VectorType::get(4, b.getI8Type());
183  Type i4x8Ty = VectorType::get(8, b.getIntegerType(4));
184  Type f32x1Ty = VectorType::get(1, f32Ty);
185  auto arrayTy = cast<LLVM::LLVMArrayType>(operand.getType());
186 
187  for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
188  Value toUse = LLVM::ExtractValueOp::create(b, operand, i);
189 
190  // For 4xi8 and 8xi4 vectors (and 1xf32 vectors carrying tf32 data), the
191  // intrinsic expects the values to be provided as i32 scalars.
192  if (arrayTy.getElementType() == i8x4Ty ||
193  arrayTy.getElementType() == i4x8Ty ||
194  (arrayTy.getElementType() == f32x1Ty &&
195  operandPtxType == NVVM::MMATypes::tf32)) {
196  result.push_back(LLVM::BitcastOp::create(b, i32Ty, toUse));
197  continue;
198  }
199 
200  // For some element types (i32, f32, f64), we need to unpack the inner
201  // vector/array type as well because the intrinsic expects individual
202  // scalars to be provided.
203  VectorType innerArrayTy = dyn_cast<VectorType>(arrayTy.getElementType());
204  if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty ||
205  innerArrayTy.getElementType() == f64Ty ||
206  innerArrayTy.getElementType() == f32Ty)) {
207  for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
208  idx < innerSize; idx++) {
209  result.push_back(LLVM::ExtractElementOp::create(
210  b, toUse,
211  LLVM::ConstantOp::create(b, i64Ty, b.getI64IntegerAttr(idx))));
212  }
213  continue;
214  }
215  result.push_back(toUse);
216  }
217  return result;
218 }
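// Illustrative sketch (fragment shapes are for illustration only): an A
// operand converted to `!llvm.array<4 x vector<2xf16>>` is unpacked into four
// `vector<2xf16>` values that are passed through unchanged, while an operand
// of type `!llvm.array<2 x vector<1xf32>>` with `operandPtxType == tf32` is
// unpacked into two values that are bitcast to `i32`, matching what
// `nvvm.mma.sync` expects for tf32 multiplicands.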
219 
220 /// Returns whether the mbarrier object lives in the shared memory address space.
221 static bool isMbarrierShared(nvgpu::MBarrierGroupType barrierType) {
222  return (mlir::nvgpu::NVGPUDialect::isSharedMemoryAddressSpace(
223  barrierType.getMemorySpace()));
224 }
225 
226 /// Returns the memory space attribute of the mbarrier object.
227 Attribute nvgpu::getMbarrierMemorySpace(MLIRContext *context,
228  nvgpu::MBarrierGroupType barrierType) {
229  Attribute memorySpace = {};
230  if (isMbarrierShared(barrierType)) {
231  memorySpace =
232  IntegerAttr::get(IntegerType::get(context, 64),
233  nvgpu::NVGPUDialect::kSharedMemoryAddressSpace);
234  }
235  return memorySpace;
236 }
237 
238 /// Returns memref type of the mbarrier object. The type is defined in the
239 /// MBarrierGroupType.
240 MemRefType nvgpu::getMBarrierMemrefType(MLIRContext *context,
241  nvgpu::MBarrierGroupType barrierType) {
242  Attribute memorySpace = nvgpu::getMbarrierMemorySpace(context, barrierType);
243  MemRefLayoutAttrInterface layout;
244  return MemRefType::get({barrierType.getNumBarriers()},
245  IntegerType::get(context, 64), layout, memorySpace);
246 }
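// Illustrative sketch (numbers chosen for illustration): a barrier group with
// 4 barriers placed in workgroup (shared) memory is modeled as
// `memref<4xi64, 3>`, i.e. one 64-bit mbarrier word per barrier, with the
// integer address-space attribute produced by getMbarrierMemorySpace above.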
247 
248 namespace {
249 
250 struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
251  using ConvertOpToLLVMPattern<nvgpu::LdMatrixOp>::ConvertOpToLLVMPattern;
252 
253  LogicalResult
254  matchAndRewrite(nvgpu::LdMatrixOp op, OpAdaptor adaptor,
255  ConversionPatternRewriter &rewriter) const override {
256  MLIRContext *ctx = getContext();
257  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
258 
259  // The result type of ldmatrix will always be a struct of 32bit integer
260  // registers if more than one 32bit value is returned. Otherwise, the result
261  // is a single i32. The result type of the GPU operation is always a vector
262  // of shape (NumRegisters, VectorRegister) where VectorRegister is the
263  // vector type of the result and always 32 bits long. We bitcast the result
264  // of the NVVM::LdMatrix to this vector type.
265  auto vectorResultType = dyn_cast<VectorType>(op->getResultTypes()[0]);
266  if (!vectorResultType) {
267  return failure();
268  }
269  Type innerVectorType = VectorType::get(vectorResultType.getDimSize(1),
270  vectorResultType.getElementType());
271 
272  int64_t num32BitRegs = vectorResultType.getDimSize(0);
273 
274  Type ldMatrixResultType;
275  if (num32BitRegs > 1) {
276  ldMatrixResultType = LLVM::LLVMStructType::getLiteral(
277  ctx, SmallVector<Type>(num32BitRegs, rewriter.getI32Type()));
278  } else {
279  ldMatrixResultType = rewriter.getI32Type();
280  }
281 
282  auto srcMemrefType = cast<MemRefType>(op.getSrcMemref().getType());
283  Value srcPtr =
284  getStridedElementPtr(rewriter, b.getLoc(), srcMemrefType,
285  adaptor.getSrcMemref(), adaptor.getIndices());
286  auto shape = NVVM::LdStMatrixShapeAttr::get(rewriter.getContext(), 8, 8);
287  Value ldMatrixResult = NVVM::LdMatrixOp::create(
288  b, ldMatrixResultType, srcPtr,
289  /*num=*/op.getNumTiles(),
290  /*layout=*/op.getTranspose() ? NVVM::MMALayout::col
291  : NVVM::MMALayout::row,
292  /*shape=*/shape, /*eltType=*/NVVM::LdStMatrixEltType::B16);
293 
294  // The ldmatrix operation returns either a single i32 value or a struct of
295  // i32 values. Here we unpack those values and cast them back to their
296  // actual vector type (still of width 32b) and repack them into a result
297  // struct.
298  Type finalResultType = typeConverter->convertType(vectorResultType);
299  Value result = LLVM::PoisonOp::create(b, finalResultType);
300  for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) {
301  Value i32Register =
302  num32BitRegs > 1 ? LLVM::ExtractValueOp::create(b, ldMatrixResult, i)
303  : ldMatrixResult;
304  Value casted = LLVM::BitcastOp::create(b, innerVectorType, i32Register);
305  result = LLVM::InsertValueOp::create(b, result, casted, i);
306  }
307 
308  rewriter.replaceOp(op, result);
309  return success();
310  }
311 };
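// Conversion sketch (operand shapes chosen for illustration): an op roughly of
// the form
//   %r = nvgpu.ldmatrix %src[%x, %y] {numTiles = 4 : i32, transpose = false}
//        : memref<128x128xf16, 3> -> vector<4x2xf16>
// becomes an `nvvm.ldmatrix` producing `!llvm.struct<(i32, i32, i32, i32)>`,
// followed by one bitcast per register back to `vector<2xf16>` and an
// insertvalue into `!llvm.array<4 x vector<2xf16>>`.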
312 
313 /// Convert the given type into the corresponding PTX type (NVVM::MMATypes
314 /// enum).
315 static FailureOr<NVVM::MMATypes> getNvvmMmaType(Type t) {
316  Type elType = getElementTypeOrSelf(t);
317  if (elType.isInteger(8))
318  return NVVM::MMATypes::s8;
319  if (elType.isInteger(4))
320  return NVVM::MMATypes::s4;
321  if (elType.isF16())
322  return NVVM::MMATypes::f16;
323  if (elType.isF64())
324  return NVVM::MMATypes::f64;
325  if (elType.isF32())
326  return NVVM::MMATypes::tf32;
327  return failure();
328 }
329 
330 struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
331  using ConvertOpToLLVMPattern<nvgpu::MmaSyncOp>::ConvertOpToLLVMPattern;
332 
333  LogicalResult
334  matchAndRewrite(nvgpu::MmaSyncOp op, OpAdaptor adaptor,
335  ConversionPatternRewriter &rewriter) const override {
336  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
337  // Get the shapes of the MMAMatrix type being used. The shapes will
338  // determine which intrinsic this op will be lowered to.
339  VectorType aType = op.getMatrixA().getType();
340  VectorType bType = op.getMatrixB().getType();
341  VectorType cType = op.getMatrixC().getType();
342 
343  std::array<int64_t, 3> gemmShape = op.getMmaShapeAsArray();
344 
345  // Tensor Cores (mma.sync) on F32 work only with TensorFloat32 (TF32).
346  bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
347  if (aType.getElementType().isF32() && !tf32Enabled)
348  return failure();
349 
350  FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
351  if (failed(ptxTypeA))
352  return op->emitOpError("failed to deduce operand PTX types");
353  FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
354  if (failed(ptxTypeB))
355  return op->emitOpError("failed to deduce operand PTX types");
356  std::optional<NVVM::MMATypes> ptxTypeC =
357  NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
358  /*isAccumulator=*/true);
359  if (!ptxTypeC)
360  return op->emitError(
361  "could not infer the PTX type for the accumulator/result");
362 
363  // TODO: add an attribute to the op to customize this behavior.
364  std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
365  if (isa<IntegerType>(aType.getElementType()))
366  overflow = NVVM::MMAIntOverflow::satfinite;
367 
368  SmallVector<Value> matA =
369  unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
370  SmallVector<Value> matB =
371  unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
372  SmallVector<Value> matC =
373  unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);
374 
375  Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
376  Type intrinsicResTy = inferIntrinsicResultType(
377  typeConverter->convertType(op->getResultTypes()[0]));
378  Value intrinsicResult =
379  NVVM::MmaOp::create(b, intrinsicResTy, matA, matB, matC,
380  /*shape=*/gemmShape,
381  /*b1Op=*/std::nullopt,
382  /*intOverflow=*/overflow,
383  /*multiplicandPtxTypes=*/
384  std::array<NVVM::MMATypes, 2>{*ptxTypeA, *ptxTypeB},
385  /*multiplicandLayouts=*/
386  std::array<NVVM::MMALayout, 2>{
387  NVVM::MMALayout::row, NVVM::MMALayout::col});
388  rewriter.replaceOp(op, convertIntrinsicResult(op.getLoc(), intrinsicResTy,
389  desiredRetTy, intrinsicResult,
390  rewriter));
391  return success();
392  }
393 };
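// Conversion sketch (shapes for illustration only): an op like
//   %d = nvgpu.mma.sync (%a, %b, %c) {mmaShape = [16, 8, 16]}
//        : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>)
//          -> vector<2x2xf16>
// is lowered to a single `nvvm.mma.sync` with row/col layouts, after the
// operand arrays are unpacked by unpackOperandVector and before the result
// struct is repacked by convertIntrinsicResult.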
394 
395 struct ConvertNVGPUToNVVMPass
396  : public impl::ConvertNVGPUToNVVMPassBase<ConvertNVGPUToNVVMPass> {
397  using Base::Base;
398 
399  void runOnOperation() override {
400  LowerToLLVMOptions options(&getContext());
401  RewritePatternSet patterns(&getContext());
402  LLVMTypeConverter converter(&getContext(), options);
403  IRRewriter rewriter(&getContext());
404  populateGpuMemorySpaceAttributeConversions(
405  converter, [](gpu::AddressSpace space) -> unsigned {
406  switch (space) {
407  case gpu::AddressSpace::Global:
408  return static_cast<unsigned>(NVVM::NVVMMemorySpace::Global);
409  case gpu::AddressSpace::Workgroup:
410  return static_cast<unsigned>(NVVM::NVVMMemorySpace::Shared);
411  case gpu::AddressSpace::Private:
412  return 0;
413  }
414  llvm_unreachable("unknown address space enum value");
415  return static_cast<unsigned>(NVVM::NVVMMemorySpace::Generic);
416  });
417  /// device-side async tokens cannot be materialized in nvvm. We just
418  /// convert them to a dummy i32 type in order to easily drop them during
419  /// conversion.
420  converter.addConversion([&](nvgpu::DeviceAsyncTokenType type) -> Type {
421  return converter.convertType(IntegerType::get(type.getContext(), 32));
422  });
423  converter.addConversion([&](nvgpu::WarpgroupAccumulatorType type) -> Type {
424  Type elemType = type.getFragmented().getElementType();
425  int64_t sizeM = type.getFragmented().getDimSize(0);
426  int64_t sizeN = type.getFragmented().getDimSize(1);
427 
428  unsigned numMembers;
429  if (elemType.isF32() || elemType.isInteger(32))
430  numMembers = sizeN / 2;
431  else if (elemType.isF16())
432  numMembers = sizeN / 4;
433  else
434  llvm_unreachable("unsupported type for warpgroup accumulator");
435 
436  SmallVector<Type> innerStructBody;
437  for (unsigned i = 0; i < numMembers; i++)
438  innerStructBody.push_back(elemType);
439  auto innerStructType =
440  LLVM::LLVMStructType::getLiteral(type.getContext(), innerStructBody);
441 
442  SmallVector<Type> structBody;
443  for (int i = 0; i < sizeM; i += kWgmmaSizeM)
444  structBody.push_back(innerStructType);
445 
446  auto convertedType =
447  LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
448  return converter.convertType(convertedType);
449  });
450  converter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {
451  return converter.convertType(IntegerType::get(type.getContext(), 64));
452  });
453  converter.addConversion(
454  [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type {
455  return converter.convertType(IntegerType::get(type.getContext(), 64));
456  });
457  converter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type {
458  return converter.convertType(
459  nvgpu::getMBarrierMemrefType(rewriter.getContext(), type));
460  });
461  converter.addConversion([&](nvgpu::TensorMapDescriptorType type) -> Type {
462  return LLVM::LLVMPointerType::get(type.getContext());
463  });
464  populateNVGPUToNVVMConversionPatterns(converter, patterns);
465  LLVMConversionTarget target(getContext());
466  target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
467  target.addLegalDialect<::mlir::arith::ArithDialect>();
468  target.addLegalDialect<::mlir::memref::MemRefDialect>();
469  target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
470  mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
471  converter, patterns, target);
472  if (failed(applyPartialConversion(getOperation(), target,
473  std::move(patterns))))
474  signalPassFailure();
475  }
476 };
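// Usage sketch (assuming the pass name registered by the generated
// ConvertNVGPUToNVVMPass definition): the conversion is typically driven as
//   mlir-opt --convert-nvgpu-to-nvvm input.mlir
// or added programmatically with pm.addPass(createConvertNVGPUToNVVMPass())
// ahead of the NVVM-to-LLVM translation pipeline.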
477 
478 /// Returns the constraints for the sparse MMA inline assembly instruction.
479 static std::string buildMmaSparseAsmConstraintString(unsigned matASize,
480  unsigned matBSize,
481  unsigned matCSize) {
482  std::string str;
483  llvm::raw_string_ostream ss(str);
484  for (unsigned i = 0; i < matCSize; i++)
485  ss << "=r,";
486  for (unsigned i = 0; i < matASize + matBSize + matCSize; i++)
487  ss << "r,";
488  // The final operand is for the sparsity metadata.
489  // The sparsity selector appears as direct literal.
490  ss << "r";
491  return str;
492 }
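// Sketch of the produced constraint string (sizes are hypothetical): with
// matASize = 4, matBSize = 4 and matCSize = 2 the helper returns
//   "=r,=r,r,r,r,r,r,r,r,r,r,r,r"
// i.e. two write-only result registers, ten read registers for the A/B/C
// fragments, and a final "r" for the sparsity-metadata operand.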
493 
494 /// Returns the string for the `mma.sp.sync` instruction that corresponds to
495 /// the given parameters. Note that this function doesn't do any validation;
496 /// it's expected that the provided parameters correspond to a valid
497 /// instruction.
498 static std::string buildMmaSparseAsmString(
499  const std::array<int64_t, 3> &shape, unsigned matASize, unsigned matBSize,
500  unsigned matCSize, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
501  NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
502  std::optional<NVVM::MMAIntOverflow> overflow, unsigned metaDataSelector) {
503  auto ptxTypeStr = [](NVVM::MMATypes ptxType) {
504  return NVVM::stringifyMMATypes(ptxType);
505  };
506 
507  std::string asmStr;
508  llvm::raw_string_ostream ss(asmStr);
509  ss << "mma.sp.sync.aligned.m" << shape[0] << "n" << shape[1] << "k"
510  << shape[2] << ".row.col.";
511 
512  if (overflow)
513  ss << NVVM::stringifyMMAIntOverflow(*overflow) << ".";
514 
515  ss << ptxTypeStr(ptxTypeD) << "." << ptxTypeStr(ptxTypeA) << "."
516  << ptxTypeStr(ptxTypeB) << "." << ptxTypeStr(ptxTypeC) << " ";
517  unsigned asmArgIdx = 0;
518 
519  // The operand string is structured into sections `{matC elements...},
520  // {matA elements...}, {matB elements...}, {matC elements}`.
521  for (const auto arrSize : {matCSize, matASize, matBSize, matCSize}) {
522  ss << "{";
523  for (unsigned i = 0; i < arrSize; i++)
524  ss << "$" << asmArgIdx++ << (i < arrSize - 1 ? "," : "");
525  ss << "},";
526  }
527  ss << "$" << asmArgIdx++ << ",";
528  assert(metaDataSelector <= 1);
529  ss << "0x" << metaDataSelector << ";";
530  return asmStr;
531 }
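// Sketch of a generated string (operand counts are hypothetical): for shape
// m16n8k32 with f16 operands, no overflow attribute, matASize = 4,
// matBSize = 4, matCSize = 2 and selector 0, the helper produces
//   mma.sp.sync.aligned.m16n8k32.row.col.f16.f16.f16.f16
//     {$0,$1},{$2,$3,$4,$5},{$6,$7,$8,$9},{$10,$11},$12,0x0;
// where the trailing $12 is the sparsity metadata and 0x0 is the selector.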
532 
533 /// Builds an inline assembly operation corresponding to the specified MMA
534 /// sparse sync operation.
535 static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
536  ImplicitLocOpBuilder &b, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
537  NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
538  std::optional<NVVM::MMAIntOverflow> overflow, ArrayRef<Value> unpackedAData,
539  ArrayRef<Value> unpackedB, ArrayRef<Value> unpackedC, Value indexData,
540  int64_t metadataSelector, const std::array<int64_t, 3> &shape,
541  Type intrinsicResultType) {
542  auto asmDialectAttr =
543  LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_ATT);
544 
545  const unsigned matASize = unpackedAData.size();
546  const unsigned matBSize = unpackedB.size();
547  const unsigned matCSize = unpackedC.size();
548 
549  std::string asmStr = buildMmaSparseAsmString(
550  shape, matASize, matBSize, matCSize, ptxTypeA, ptxTypeB, ptxTypeC,
551  ptxTypeD, overflow, metadataSelector);
552  std::string constraintStr =
553  buildMmaSparseAsmConstraintString(matASize, matBSize, matCSize);
554 
555  SmallVector<Value> asmVals;
556  asmVals.reserve(matASize + matBSize + matCSize + 1);
557  for (ArrayRef<Value> args : {unpackedAData, unpackedB, unpackedC})
558  llvm::append_range(asmVals, args);
559  asmVals.push_back(indexData);
560 
561  return LLVM::InlineAsmOp::create(b,
562  /*resultTypes=*/intrinsicResultType,
563  /*operands=*/asmVals,
564  /*asm_string=*/asmStr,
565  /*constraints=*/constraintStr,
566  /*has_side_effects=*/true,
567  /*is_align_stack=*/false,
569  /*asm_dialect=*/asmDialectAttr,
570  /*operand_attrs=*/ArrayAttr());
571 }
572 
573 /// Lowers `nvgpu.mma.sp.sync` to inline assembly.
574 struct NVGPUMmaSparseSyncLowering
575  : public ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp> {
576  using ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp>::ConvertOpToLLVMPattern;
577 
578  LogicalResult
579  matchAndRewrite(nvgpu::MmaSparseSyncOp op, OpAdaptor adaptor,
580  ConversionPatternRewriter &rewriter) const override {
581  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
582  // Get the shapes of the MMAMatrix type being used. The shapes will
583  // determine which intrinsic this op will be lowered to.
584  VectorType aType = op.getMatrixA().getType();
585  VectorType bType = op.getMatrixB().getType();
586  VectorType cType = op.getMatrixC().getType();
587 
588  FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
589  if (failed(ptxTypeA))
590  return op->emitOpError("failed to deduce operand PTX types");
591  FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
592  if (failed(ptxTypeB))
593  return op->emitOpError("failed to deduce operand PTX types");
594  std::optional<NVVM::MMATypes> ptxTypeC =
595  NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
596  /*isAccumulator=*/true);
597  if (!ptxTypeC)
598  return op->emitError(
599  "could not infer the PTX type for the accumulator/result");
600 
601  // Same as `mma.sync`, F32 works only with TensorFloat32 (TF32).
602  bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
603  if (aType.getElementType().isF32() && !tf32Enabled)
604  return failure();
605 
606  // TODO: add an attribute to the op to customize this behavior.
607  std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
608  if (isa<IntegerType>(aType.getElementType()))
609  overflow = NVVM::MMAIntOverflow::satfinite;
610 
611  SmallVector<Value> matA =
612  unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
613  SmallVector<Value> matB =
614  unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
615  SmallVector<Value> matC =
616  unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);
617 
618  Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
619  Type intrinsicResTy = inferIntrinsicResultType(
620  typeConverter->convertType(op->getResultTypes()[0]));
621 
622  // Bitcast the sparse metadata from vector<2xi16> to an i32.
623  Value sparseMetadata = adaptor.getSparseMetadata();
624  if (sparseMetadata.getType() != VectorType::get(2, rewriter.getI16Type()))
625  return op->emitOpError() << "Expected metadata type to be LLVM "
626  "VectorType of 2 i16 elements";
627  sparseMetadata =
628  LLVM::BitcastOp::create(b, rewriter.getI32Type(), sparseMetadata);
629 
630  FailureOr<LLVM::InlineAsmOp> intrinsicResult = emitMmaSparseSyncOpAsm(
631  b, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB,
632  matC, sparseMetadata, op.getSparsitySelector(), op.getMmaShapeAsArray(),
633  intrinsicResTy);
634  if (failed(intrinsicResult))
635  return failure();
636 
637  assert((*intrinsicResult).getNumResults() == 1 &&
638  "expected inline asm op to return a single LLVM struct type");
639  rewriter.replaceOp(
640  op, convertIntrinsicResult(op.getLoc(), intrinsicResTy, desiredRetTy,
641  (*intrinsicResult)->getResult(0), rewriter));
642  return success();
643  }
644 };
645 
646 struct NVGPUAsyncCopyLowering
647  : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCopyOp> {
648  using ConvertOpToLLVMPattern<
649  nvgpu::DeviceAsyncCopyOp>::ConvertOpToLLVMPattern;
650 
651  LogicalResult
652  matchAndRewrite(nvgpu::DeviceAsyncCopyOp op, OpAdaptor adaptor,
653  ConversionPatternRewriter &rewriter) const override {
654  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
655  Location loc = op.getLoc();
656  auto dstMemrefType = cast<MemRefType>(op.getDst().getType());
657  Value dstPtr =
658  getStridedElementPtr(rewriter, b.getLoc(), dstMemrefType,
659  adaptor.getDst(), adaptor.getDstIndices());
660  FailureOr<unsigned> dstAddressSpace =
661  getTypeConverter()->getMemRefAddressSpace(dstMemrefType);
662  if (failed(dstAddressSpace))
663  return rewriter.notifyMatchFailure(
664  loc, "destination memref address space not convertible to integer");
665 
666  auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
667  FailureOr<unsigned> srcAddressSpace =
668  getTypeConverter()->getMemRefAddressSpace(srcMemrefType);
669  if (failed(srcAddressSpace))
670  return rewriter.notifyMatchFailure(
671  loc, "source memref address space not convertible to integer");
672 
673  Value scrPtr =
674  getStridedElementPtr(rewriter, loc, srcMemrefType, adaptor.getSrc(),
675  adaptor.getSrcIndices());
676  // The intrinsic takes a global pointer, so we need an address space cast.
677  auto srcPointerGlobalType = LLVM::LLVMPointerType::get(
678  op->getContext(), static_cast<unsigned>(NVVM::NVVMMemorySpace::Global));
679  scrPtr = LLVM::AddrSpaceCastOp::create(b, srcPointerGlobalType, scrPtr);
680  int64_t dstElements = adaptor.getDstElements().getZExtValue();
681  int64_t sizeInBytes =
682  (dstMemrefType.getElementTypeBitWidth() * dstElements) / 8;
683  // When the optional SrcElements argument is *not* present, the regular
684  // CpAsyncOp is generated. CpAsyncOp reads bytes from the source (global
685  // memory) to fill DstElements number of elements in the destination
686  // (shared memory).
687  Value srcBytes = adaptor.getSrcElements();
688  if (srcBytes) {
689  // When the optional SrcElements argument is present, the source (global
690  // memory) of CpAsyncOp is read only for SrcElements number of elements.
691  // The rest of the DstElements in the destination (shared memory) are
692  // filled with zeros.
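// Compute the source size in bytes as (element bit-width * srcElements) / 8;
// the division by 8 is expressed below as a right shift by 3 (c3I32).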
693  Value c3I32 =
694  LLVM::ConstantOp::create(b, b.getI32Type(), b.getI32IntegerAttr(3));
695  Value bitwidth = LLVM::ConstantOp::create(
696  b, b.getI32Type(),
697  b.getI32IntegerAttr(srcMemrefType.getElementTypeBitWidth()));
698  Value srcElementsI32 = LLVM::TruncOp::create(b, b.getI32Type(), srcBytes);
699  srcBytes = LLVM::LShrOp::create(
700  b, LLVM::MulOp::create(b, bitwidth, srcElementsI32), c3I32);
701  }
702  // Cache global (.cg) for 16 dst bytes, Cache all (.ca) for sizes other than
703  // 16 dst bytes.
704  NVVM::LoadCacheModifierKind cacheModifier =
705  (op.getBypassL1().value_or(false) && sizeInBytes == 16)
706  ? NVVM::LoadCacheModifierKind::CG
707  : NVVM::LoadCacheModifierKind::CA;
708 
709  NVVM::CpAsyncOp::create(
710  b, dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes),
711  NVVM::LoadCacheModifierKindAttr::get(op->getContext(), cacheModifier),
712  srcBytes);
713 
714  // Drop the result token.
715  Value zero =
716  LLVM::ConstantOp::create(b, IntegerType::get(op.getContext(), 32),
717  rewriter.getI32IntegerAttr(0));
718  rewriter.replaceOp(op, zero);
719  return success();
720  }
721 };
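// Conversion sketch (types for illustration): a copy such as
//   %t = nvgpu.device_async_copy %src[%i, %j], %dst[%k, %l], 8
//        : memref<128x128xf16> to memref<64x128xf16, 3>
// copies 8 f16 elements (16 bytes) and becomes an
// `nvvm.cp.async.shared.global` with the `.cg` modifier when `bypassL1` is
// set (`.ca` otherwise); a dummy i32 constant replaces the async token
// produced by the nvgpu op.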
722 
723 struct NVGPUAsyncCreateGroupLowering
724  : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCreateGroupOp> {
725  using ConvertOpToLLVMPattern<
726  nvgpu::DeviceAsyncCreateGroupOp>::ConvertOpToLLVMPattern;
727 
728  LogicalResult
729  matchAndRewrite(nvgpu::DeviceAsyncCreateGroupOp op, OpAdaptor adaptor,
730  ConversionPatternRewriter &rewriter) const override {
731  NVVM::CpAsyncCommitGroupOp::create(rewriter, op.getLoc());
732  // Drop the result token.
733  Value zero = LLVM::ConstantOp::create(rewriter, op->getLoc(),
734  IntegerType::get(op.getContext(), 32),
735  rewriter.getI32IntegerAttr(0));
736  rewriter.replaceOp(op, zero);
737  return success();
738  }
739 };
740 
741 struct NVGPUAsyncWaitLowering
742  : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncWaitOp> {
743  using ConvertOpToLLVMPattern<
744  nvgpu::DeviceAsyncWaitOp>::ConvertOpToLLVMPattern;
745 
746  LogicalResult
747  matchAndRewrite(nvgpu::DeviceAsyncWaitOp op, OpAdaptor adaptor,
748  ConversionPatternRewriter &rewriter) const override {
749  // If numGroups is not present, pick 0 as a conservatively correct value.
750  int32_t numGroups = adaptor.getNumGroups().value_or(0);
751  NVVM::CpAsyncWaitGroupOp::create(rewriter, op.getLoc(), numGroups);
752  rewriter.eraseOp(op);
753  return success();
754  }
755 };
756 
757 /// Creates an mbarrier object in shared memory.
758 struct NVGPUMBarrierCreateLowering
759  : public ConvertOpToLLVMPattern<nvgpu::MBarrierCreateOp> {
760  using ConvertOpToLLVMPattern<nvgpu::MBarrierCreateOp>::ConvertOpToLLVMPattern;
761 
762  template <typename moduleT>
763  memref::GlobalOp generateGlobalBarrier(ConversionPatternRewriter &rewriter,
764  Operation *funcOp, moduleT moduleOp,
765  MemRefType barrierType) const {
766  SymbolTable symbolTable(moduleOp);
767  OpBuilder::InsertionGuard guard(rewriter);
768  rewriter.setInsertionPoint(&moduleOp.front());
769  auto global = memref::GlobalOp::create(
770  rewriter, funcOp->getLoc(), "__mbarrier",
771  /*sym_visibility=*/rewriter.getStringAttr("private"),
772  /*type=*/barrierType,
773  /*initial_value=*/ElementsAttr(),
774  /*constant=*/false,
775  /*alignment=*/rewriter.getI64IntegerAttr(8));
776  symbolTable.insert(global);
777  return global;
778  }
779 
780  LogicalResult
781  matchAndRewrite(nvgpu::MBarrierCreateOp op, OpAdaptor adaptor,
782  ConversionPatternRewriter &rewriter) const override {
783  Operation *funcOp = op->getParentOp();
784  MemRefType barrierType = nvgpu::getMBarrierMemrefType(
785  rewriter.getContext(), op.getBarriers().getType());
786 
787  memref::GlobalOp global;
788  if (auto moduleOp = funcOp->getParentOfType<gpu::GPUModuleOp>())
789  global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
790  else if (auto moduleOp = funcOp->getParentOfType<ModuleOp>())
791  global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
792 
793  rewriter.setInsertionPoint(op);
794  rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(op, barrierType,
795  global.getName());
796  return success();
797  }
798 };
799 
800 /// Base class for lowering mbarrier operations to nvvm intrinsics.
801 template <typename SourceOp>
802 struct MBarrierBasePattern : public ConvertOpToLLVMPattern<SourceOp> {
803 public:
804  using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
805  /// Returns the base pointer of the mbarrier object.
806  Value getMbarrierPtr(ImplicitLocOpBuilder &b,
807  nvgpu::MBarrierGroupType mbarType, Value memrefDesc,
808  Value mbarId,
809  ConversionPatternRewriter &rewriter) const {
810  MemRefType mbarrierMemrefType =
811  nvgpu::getMBarrierMemrefType(rewriter.getContext(), mbarType);
812  return ConvertToLLVMPattern::getStridedElementPtr(
813  rewriter, b.getLoc(), mbarrierMemrefType, memrefDesc, {mbarId});
814  }
815 };
816 
817 struct NVGPUMBarrierGetLowering
818  : public MBarrierBasePattern<nvgpu::MBarrierGetOp> {
819  using MBarrierBasePattern<nvgpu::MBarrierGetOp>::MBarrierBasePattern;
820 
821  LogicalResult
822  matchAndRewrite(nvgpu::MBarrierGetOp op, OpAdaptor adaptor,
823  ConversionPatternRewriter &rewriter) const override {
824  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
825  nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType();
826  rewriter.setInsertionPoint(op);
827  Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(),
828  adaptor.getMbarId(), rewriter);
829  Type resType = op.getMbarrierPointer().getType();
830  rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(op, resType, barrier);
831  return success();
832  }
833 };
834 
835 /// Lowers `nvgpu.mbarrier.init` to `nvvm.mbarrier.init`
836 struct NVGPUMBarrierInitLowering
837  : public MBarrierBasePattern<nvgpu::MBarrierInitOp> {
838  using MBarrierBasePattern<nvgpu::MBarrierInitOp>::MBarrierBasePattern;
839 
840  LogicalResult
841  matchAndRewrite(nvgpu::MBarrierInitOp op, OpAdaptor adaptor,
842  ConversionPatternRewriter &rewriter) const override {
843  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
844  nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType();
845  rewriter.setInsertionPoint(op);
846  Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(),
847  adaptor.getMbarId(), rewriter);
848  Value count = truncToI32(b, adaptor.getCount());
849  if (isMbarrierShared(mbarrierType)) {
850  rewriter.replaceOpWithNewOp<NVVM::MBarrierInitSharedOp>(
851  op, barrier, count, adaptor.getPredicate());
852  } else {
853  rewriter.replaceOpWithNewOp<NVVM::MBarrierInitOp>(op, barrier, count,
854  adaptor.getPredicate());
855  }
856  return success();
857  }
858 };
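// Conversion sketch (operands and syntax abbreviated, for illustration only):
// an `nvgpu.mbarrier.init %barriers[%mbarId], %count` on a shared-memory
// barrier group maps to `nvvm.mbarrier.init.shared`, with the barrier pointer
// computed from the lowered memref descriptor and the count truncated to i32.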
859 
860 /// Lowers `nvgpu.mbarrier.arrive` to `nvvm.mbarrier.arrive`
861 struct NVGPUMBarrierArriveLowering
862  : public MBarrierBasePattern<nvgpu::MBarrierArriveOp> {
863  using MBarrierBasePattern<nvgpu::MBarrierArriveOp>::MBarrierBasePattern;
864  LogicalResult
865  matchAndRewrite(nvgpu::MBarrierArriveOp op, OpAdaptor adaptor,
866  ConversionPatternRewriter &rewriter) const override {
867  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
868  Value barrier =
869  getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
870  adaptor.getMbarId(), rewriter);
871  Type tokenType = getTypeConverter()->convertType(
872  nvgpu::MBarrierTokenType::get(op->getContext()));
873  if (isMbarrierShared(op.getBarriers().getType())) {
874  rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveSharedOp>(op, tokenType,
875  barrier);
876  } else {
877  rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveOp>(op, tokenType,
878  barrier);
879  }
880  return success();
881  }
882 };
883 
884 /// Lowers `nvgpu.mbarrier.arrive.nocomplete` to
885 /// `nvvm.mbarrier.arrive.nocomplete`
886 struct NVGPUMBarrierArriveNoCompleteLowering
887  : public MBarrierBasePattern<nvgpu::MBarrierArriveNoCompleteOp> {
888  using MBarrierBasePattern<
889  nvgpu::MBarrierArriveNoCompleteOp>::MBarrierBasePattern;
890  LogicalResult
891  matchAndRewrite(nvgpu::MBarrierArriveNoCompleteOp op, OpAdaptor adaptor,
892  ConversionPatternRewriter &rewriter) const override {
893  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
894  Value barrier =
895  getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
896  adaptor.getMbarId(), rewriter);
897  Type tokenType = getTypeConverter()->convertType(
898  nvgpu::MBarrierTokenType::get(op->getContext()));
899  Value count = truncToI32(b, adaptor.getCount());
900  if (isMbarrierShared(op.getBarriers().getType())) {
901  rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteSharedOp>(
902  op, tokenType, barrier, count);
903  } else {
904  rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteOp>(
905  op, tokenType, barrier, count);
906  }
907  return success();
908  }
909 };
910 
911 /// Lowers `nvgpu.mbarrier.test.wait` to `nvvm.mbarrier.test.wait`
912 struct NVGPUMBarrierTestWaitLowering
913  : public MBarrierBasePattern<nvgpu::MBarrierTestWaitOp> {
914  using MBarrierBasePattern<nvgpu::MBarrierTestWaitOp>::MBarrierBasePattern;
915  LogicalResult
916  matchAndRewrite(nvgpu::MBarrierTestWaitOp op, OpAdaptor adaptor,
917  ConversionPatternRewriter &rewriter) const override {
918  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
919  Value barrier =
920  getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
921  adaptor.getMbarId(), rewriter);
922  Type retType = rewriter.getI1Type();
923  if (isMbarrierShared(op.getBarriers().getType())) {
924  rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitSharedOp>(
925  op, retType, barrier, adaptor.getToken());
926  } else {
927  rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitOp>(
928  op, retType, barrier, adaptor.getToken());
929  }
930  return success();
931  }
932 };
933 
934 struct NVGPUMBarrierArriveExpectTxLowering
935  : public MBarrierBasePattern<nvgpu::MBarrierArriveExpectTxOp> {
936  using MBarrierBasePattern<
937  nvgpu::MBarrierArriveExpectTxOp>::MBarrierBasePattern;
938  LogicalResult
939  matchAndRewrite(nvgpu::MBarrierArriveExpectTxOp op, OpAdaptor adaptor,
940  ConversionPatternRewriter &rewriter) const override {
941  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
942  Value barrier =
943  getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
944  adaptor.getMbarId(), rewriter);
945  Value txcount = truncToI32(b, adaptor.getTxcount());
946 
947  if (isMbarrierShared(op.getBarriers().getType())) {
948  rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxSharedOp>(
949  op, barrier, txcount, adaptor.getPredicate());
950  return success();
951  }
952 
953  rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>(
954  op, barrier, txcount, adaptor.getPredicate());
955  return success();
956  }
957 };
958 
959 struct NVGPUMBarrierTryWaitParityLowering
960  : public MBarrierBasePattern<nvgpu::MBarrierTryWaitParityOp> {
961  using MBarrierBasePattern<
962  nvgpu::MBarrierTryWaitParityOp>::MBarrierBasePattern;
963  LogicalResult
964  matchAndRewrite(nvgpu::MBarrierTryWaitParityOp op, OpAdaptor adaptor,
965  ConversionPatternRewriter &rewriter) const override {
966  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
967  Value barrier =
968  getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
969  adaptor.getMbarId(), rewriter);
970  Value ticks = truncToI32(b, adaptor.getTicks());
971  Value phase =
972  LLVM::ZExtOp::create(b, b.getI32Type(), adaptor.getPhaseParity());
973 
974  if (isMbarrierShared(op.getBarriers().getType())) {
975  rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParitySharedOp>(
976  op, barrier, phase, ticks);
977  return success();
978  }
979 
980  rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParityOp>(op, barrier,
981  phase, ticks);
982  return success();
983  }
984 };
985 
986 struct NVGPUTmaAsyncLoadOpLowering
987  : public MBarrierBasePattern<nvgpu::TmaAsyncLoadOp> {
988  using MBarrierBasePattern<nvgpu::TmaAsyncLoadOp>::MBarrierBasePattern;
989  LogicalResult
990  matchAndRewrite(nvgpu::TmaAsyncLoadOp op, OpAdaptor adaptor,
991  ConversionPatternRewriter &rewriter) const override {
992  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
993  auto srcMemrefType = cast<MemRefType>(op.getDst().getType());
994  Value dest = getStridedElementPtr(rewriter, op->getLoc(), srcMemrefType,
995  adaptor.getDst(), {});
996  // The intrinsic takes a shared-cluster pointer, so we need an
997  // address space cast from 3 to 7.
998  // TODO: Introduce AS(7) in NVGPU.
999  auto ptrSharedClusterType = LLVM::LLVMPointerType::get(
1000  op->getContext(),
1001  static_cast<unsigned>(NVVM::NVVMMemorySpace::SharedCluster));
1002  dest = LLVM::AddrSpaceCastOp::create(b, ptrSharedClusterType, dest);
1003 
1004  Value barrier =
1005  getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
1006  adaptor.getMbarId(), rewriter);
1007 
1008  SmallVector<Value> coords = adaptor.getCoordinates();
1009  for (auto [index, value] : llvm::enumerate(coords)) {
1010  coords[index] = truncToI32(b, value);
1011  }
1012 
1013  // TODO: Enhance the NVGPU Op for other modes too
1014  rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(
1015  op, dest, adaptor.getTensorMapDescriptor(), coords, barrier,
1016  ValueRange{}, adaptor.getMulticastMask(), Value{},
1017  NVVM::TMALoadMode::TILE, // default is TILE mode
1018  false, // default is cluster-scope
1019  nullptr, // default is no cta-group
1020  adaptor.getPredicate());
1021  return success();
1022  }
1023 };
1024 
1025 struct NVGPUTmaAsyncStoreOpLowering
1026  : public MBarrierBasePattern<nvgpu::TmaAsyncStoreOp> {
1027  using MBarrierBasePattern<nvgpu::TmaAsyncStoreOp>::MBarrierBasePattern;
1028  LogicalResult
1029  matchAndRewrite(nvgpu::TmaAsyncStoreOp op, OpAdaptor adaptor,
1030  ConversionPatternRewriter &rewriter) const override {
1031  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1032  auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
1033  Value dest = getStridedElementPtr(rewriter, op->getLoc(), srcMemrefType,
1034  adaptor.getSrc(), {});
1035  SmallVector<Value> coords = adaptor.getCoordinates();
1036  for (auto [index, value] : llvm::enumerate(coords)) {
1037  coords[index] = truncToI32(b, value);
1038  }
1039 
1040  // TODO: Enhance the NVGPU Op for other modes too
1041  rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(
1042  op, adaptor.getTensorMapDescriptor(), dest, coords, Value{},
1043  NVVM::TMAStoreMode::TILE, // default is TILE mode
1044  adaptor.getPredicate());
1045  return success();
1046  }
1047 };
1048 
1049 struct NVGPUGenerateWarpgroupDescriptorLowering
1050  : public ConvertOpToLLVMPattern<nvgpu::WarpgroupGenerateDescriptorOp> {
1051  using ConvertOpToLLVMPattern<
1052  nvgpu::WarpgroupGenerateDescriptorOp>::ConvertOpToLLVMPattern;
1053 
1054  LogicalResult
1055  matchAndRewrite(nvgpu::WarpgroupGenerateDescriptorOp op, OpAdaptor adaptor,
1056  ConversionPatternRewriter &rewriter) const override {
1057 
1058  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1059 
1060  nvgpu::TensorMapSwizzleKind swizzleKind =
1061  op.getTensorMap().getType().getSwizzle();
1062 
1063  unsigned layout =
1064  (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 128
1065  : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 64
1066  : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 32
1067  : 1;
1068  unsigned swizzle =
1069  (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 1
1070  : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 2
1071  : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 3
1072  : 0;
1073 
1074  auto ti64 = b.getIntegerType(64);
1075  auto makeConst = [&](uint64_t index) -> Value {
1076  return LLVM::ConstantOp::create(b, ti64, b.getI64IntegerAttr(index));
1077  };
1078  auto shiftLeft = [&](Value value, unsigned shift) -> Value {
1079  return LLVM::ShlOp::create(b, ti64, value, makeConst(shift));
1080  };
1081  auto shiftRight = [&](Value value, unsigned shift) -> Value {
1082  return LLVM::LShrOp::create(b, ti64, value, makeConst(shift));
1083  };
1084  auto insertBit = [&](Value desc, Value val, int startBit) {
1085  return LLVM::OrOp::create(b, ti64, desc, shiftLeft(val, startBit));
1086  };
1087 
1088  int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0);
1089  uint64_t strideDimVal = (layout << 3) >> exclude4LSB;
1090  uint64_t leadDimVal = (sizeN * layout) >> exclude4LSB;
1091  uint64_t offsetVal = 0;
1092 
1093  Value strideDim = makeConst(strideDimVal);
1094  Value leadDim = makeConst(leadDimVal);
1095 
1096  Value baseAddr = getStridedElementPtr(
1097  rewriter, op->getLoc(), cast<MemRefType>(op.getTensor().getType()),
1098  adaptor.getTensor(), {});
1099  Value basePtr = LLVM::PtrToIntOp::create(b, ti64, baseAddr);
1100  // Just use 14 bits for base address
1101  Value basePtr14bit = shiftRight(shiftLeft(basePtr, 46), 50);
1102 
1103  int startSwizzleBit = 62, startOffsetBit = 49, startStrideBit = 32,
1104  startLeadBit = 16, startBaseAddrBit = 0;
1105  Value dsc = makeConst(0);
1106  // [62,64) swizzle type
1107  dsc = insertBit(dsc, makeConst(swizzle), startSwizzleBit);
1108  // [49,52) base_offset
1109  dsc = insertBit(dsc, makeConst(offsetVal), startOffsetBit);
1110  // [32,46) stride
1111  dsc = insertBit(dsc, strideDim, startStrideBit);
1112  // [16,30) leading dimension
1113  dsc = insertBit(dsc, leadDim, startLeadBit);
1114  // [0,14) start_address
1115  dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit);
1116 
1117  LDBG() << "Generating warpgroup.descriptor: " << "leading_off:"
1118  << leadDimVal << "\t" << "stride_off :" << strideDimVal << "\t"
1119  << "base_offset:" << offsetVal << "\t" << "layout_type:" << swizzle
1120  << " (" << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
1121  << ")\n start_addr : " << baseAddr;
1122 
1123  rewriter.replaceOp(op, dsc);
1124  return success();
1125  }
1126 };
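// Worked example (numbers chosen for illustration only): for a 64x128 f16
// tile with 128B swizzling, layout = 128 and swizzle = 1, so
//   strideDimVal = (128 << 3) >> 4 = 64
//   leadDimVal   = (64 * 128)  >> 4 = 512
// and the final descriptor packs swizzle/offset/stride/lead/base-address into
// bit ranges [62,64), [49,52), [32,46), [16,30) and [0,14) of a single i64.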
1127 
1128 static Value makeI64Const(ImplicitLocOpBuilder &b, int32_t index) {
1129  return LLVM::ConstantOp::create(b, b.getIntegerType(64),
1130  b.getI32IntegerAttr(index));
1131 }
1132 
1133 /// Returns a Value that holds the data type enum expected by the CUDA driver.
1134 static Value elementTypeAsLLVMConstant(ImplicitLocOpBuilder &b, Type type) {
1135  // Enum is from CUDA driver API
1136  // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html
1137  enum CUtensorMapDataTypeEnum {
1138  CU_TENSOR_MAP_DATA_TYPE_UINT8 = 0,
1139  CU_TENSOR_MAP_DATA_TYPE_UINT16,
1140  CU_TENSOR_MAP_DATA_TYPE_UINT32,
1141  CU_TENSOR_MAP_DATA_TYPE_INT32,
1142  CU_TENSOR_MAP_DATA_TYPE_UINT64,
1143  CU_TENSOR_MAP_DATA_TYPE_INT64,
1144  CU_TENSOR_MAP_DATA_TYPE_FLOAT16,
1145  CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
1146  CU_TENSOR_MAP_DATA_TYPE_FLOAT64,
1147  CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
1148  CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ,
1149  CU_TENSOR_MAP_DATA_TYPE_TFLOAT32,
1150  CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ
1151  };
1152 
1153  if (type.isUnsignedInteger(8))
1154  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT8);
1155  if (type.isUnsignedInteger(16))
1156  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT16);
1157  if (type.isUnsignedInteger(32))
1158  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT32);
1159  if (type.isUnsignedInteger(64))
1160  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT64);
1161  if (type.isSignlessInteger(32))
1162  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT32);
1163  if (type.isSignlessInteger(64))
1164  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT64);
1165  if (type.isF16())
1166  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT16);
1167  if (type.isF32())
1168  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT32);
1169  if (type.isF64())
1170  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT64);
1171  if (type.isBF16())
1172  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16);
1173 
1174  llvm_unreachable("Not supported data type");
1175 }
1176 
1177 struct NVGPUTmaCreateDescriptorOpLowering
1178  : public ConvertOpToLLVMPattern<nvgpu::TmaCreateDescriptorOp> {
1179  using ConvertOpToLLVMPattern<
1180  nvgpu::TmaCreateDescriptorOp>::ConvertOpToLLVMPattern;
1181  LogicalResult
1182  matchAndRewrite(nvgpu::TmaCreateDescriptorOp op, OpAdaptor adaptor,
1183  ConversionPatternRewriter &rewriter) const override {
1184  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1185  auto llvmPointerType = LLVM::LLVMPointerType::get(op->getContext());
1186  Type llvmInt64Type = IntegerType::get(op->getContext(), 64);
1187 
1188  Value tensorElementType =
1189  elementTypeAsLLVMConstant(b, op.getTensor().getType().getElementType());
1190  auto promotedOperands = getTypeConverter()->promoteOperands(
1191  b.getLoc(), op->getOperands(), adaptor.getOperands(), b);
1192 
1193  Value boxArrayPtr = LLVM::AllocaOp::create(
1194  b, llvmPointerType, llvmInt64Type, makeI64Const(b, 5));
1195  for (auto [index, value] : llvm::enumerate(adaptor.getBoxDimensions())) {
1196  Value gep = LLVM::GEPOp::create(b, llvmPointerType, llvmPointerType,
1197  boxArrayPtr, makeI64Const(b, index));
1198  LLVM::StoreOp::create(b, value, gep);
1199  }
1200 
1201  nvgpu::TensorMapDescriptorType desc = op.getTensorMap().getType();
1202  // Set Arguments for the function call
1203  SmallVector<Value> arguments;
1204  arguments.push_back(promotedOperands[0]); // rank
1205  arguments.push_back(promotedOperands[1]); // descriptor
1206  arguments.push_back(tensorElementType); // data type
1207  arguments.push_back(
1208  makeI64Const(b, (int)desc.getInterleave())); // interleave
1209  arguments.push_back(makeI64Const(b, (int)desc.getSwizzle())); // swizzle
1210  arguments.push_back(makeI64Const(b, (int)desc.getL2promo())); // l2promo
1211  arguments.push_back(makeI64Const(b, (int)desc.getOob())); // oob
1212  arguments.push_back(boxArrayPtr); // box dimensions
1213 
1214  // Set data types of the arguments
1215  SmallVector<Type> argTypes = {
1216  llvmInt64Type, /* int64_t tensorRank */
1217  llvmPointerType, /* ptr */
1218  llvmInt64Type, /* int64_t */
1219  llvmInt64Type, /* int64_t */
1220  llvmInt64Type, /* int64_t */
1221  llvmInt64Type, /* int64_t */
1222  llvmInt64Type, /* int64_t */
1223  llvmPointerType /* ptr */
1224  };
1225  FunctionCallBuilder hostRegisterCallBuilder = {
1226  "mgpuTensorMapEncodeTiledMemref", llvmPointerType, argTypes};
1227  Value tensorMap =
1228  hostRegisterCallBuilder.create(b.getLoc(), b, arguments).getResult();
1229 
1230  rewriter.replaceOp(op, tensorMap);
1231  return success();
1232  }
1233 };
1234 
1235 struct NVGPUWarpgroupMmaOpLowering
1236  : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp> {
1237  using ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp>::ConvertOpToLLVMPattern;
1238 
1239  /// This is a helper class to generate required NVVM Ops for warp-group level
1240  /// matrix multiplication.
1241  /// When the given GEMM shape is larger than the shape of
1242  /// a wgmma instruction in PTX, it generates multiple NVVM::WgmmaMmaAsyncOp
1243  /// ops, groups them, and executes them asynchronously. The class also handles
1244  /// waiting for completion and iterates through WarpgroupMatrixDescriptor to
1245  /// create descriptors for each instruction.
1246  ///
1247  /// For example this is the case when the shape of GEMM is 128x128x128
1248  ///
1249  /// nvvm.wgmma.fence.aligned
1250  ///
1251  /// nvvm.wgmma.mma.async descA, descB
1252  /// iterate(descA, descB)
1253  /// nvvm.wgmma.mma.async descA, descB
1254  /// [6x times more]
1255  ///
1256  /// nvvm.wgmma.group.sync.aligned
1257  /// nvvm.wgmma.wait.group.sync [groupId]
1258  ///
1259  class WarpgroupGemm {
1260  nvgpu::WarpgroupMmaOp op;
1261  ImplicitLocOpBuilder &b;
1262  OpAdaptor adaptor;
1263 
1264  // Entire shape of the given Op
1265  int64_t totalM, totalN, totalK;
1266 
1267  // Shape of one wgmma instruction
1268  int wgmmaM = 0, wgmmaN = 0, wgmmaK = 0;
1269 
1270  // Iteration counts for GEMM
1271  int iterationM = 0, iterationN = 0, iterationK = 0;
1272 
1273  /// Determines the shape of a single wgmma instruction, as defined in the
1274  /// PTX programming guide.
1275  /// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shape
1276  void findWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType) {
1277  wgmmaM = 64;
1278  wgmmaN = sizeN;
1279  if (inputElemType.isTF32()) {
1280  wgmmaK = 8;
1281  } else if (inputElemType.isF16() || inputElemType.isBF16()) {
1282  wgmmaK = 16;
1283  } else if (isa<Float8E4M3FNType, Float8E5M2Type>(inputElemType) ||
1284  inputElemType.isInteger(8)) {
1285  wgmmaK = 32;
1286  } else if (inputElemType.isInteger(1)) {
1287  wgmmaK = 256;
1288  } else {
1289  llvm_unreachable("unsupported element type for wgmma K shape");
1290  }
1291  LDBG() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
1292  << ", n = " << wgmmaN << ", k = " << wgmmaK << "]";
1293  }
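// Sketch (shape chosen for illustration): for f16 inputs and a 128x128x128
// GEMM, this yields wgmmaM = 64, wgmmaN = 128 and wgmmaK = 16, so covering
// the full shape takes 2 (M chunks) x 8 (K steps) = 16 WgmmaMmaAsyncOp ops.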
1294 
1295  /// Generates WGMMATypesAttr from MLIR Type
1296  NVVM::WGMMATypesAttr generateWgmmaType(Type type,
1297  bool useF32 = false) const {
1298  auto getWgmmaType = [=](Type elemType) {
1299  if (elemType.isF32() || elemType.isTF32())
1300  return useF32 ? NVVM::WGMMATypes::f32 : NVVM::WGMMATypes::tf32;
1301  if (elemType.isF16())
1302  return NVVM::WGMMATypes::f16;
1303  if (elemType.isBF16())
1304  return NVVM::WGMMATypes::bf16;
1305  if (isa<Float8E4M3FNType>(elemType))
1306  return NVVM::WGMMATypes::e4m3;
1307  if (isa<Float8E5M2Type>(elemType))
1308  return NVVM::WGMMATypes::e5m2;
1309  if (elemType.isInteger(1))
1310  return NVVM::WGMMATypes::b1;
1311  if (elemType.isInteger(8))
1312  return NVVM::WGMMATypes::s8;
1313  if (elemType.isUnsignedInteger(8))
1314  return NVVM::WGMMATypes::u8;
1315  if (elemType.isInteger(32))
1316  return NVVM::WGMMATypes::s32;
1317  llvm_unreachable("unsupported type");
1318  };
1319  return NVVM::WGMMATypesAttr::get(op->getContext(), getWgmmaType(type));
1320  }
1321 
1322  /// Generates layout attribute for the input matrix for wgmma instruction
1323  NVVM::MMALayoutAttr
1324  generateWgmmaLayout(std::optional<bool> transpose) const {
1325  if (transpose.value_or(false))
1326  return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::col);
1327  return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::row);
1328  }
1329 
1330  /// Generates shape attribute for wgmma instruction
1331  NVVM::MMAShapeAttr generateWgmmaShape() const {
1332  return NVVM::MMAShapeAttr::get(op->getContext(), wgmmaM, wgmmaN, wgmmaK);
1333  }
1334 
1335  /// Generates scale attributes of output matrix for wgmma instruction
1336  NVVM::WGMMAScaleOutAttr generateScaleOut() const {
1337  return NVVM::WGMMAScaleOutAttr::get(op->getContext(),
1338  NVVM::WGMMAScaleOut::one);
1339  }
1340  /// Generates scale attributes of input matrix for wgmma instruction
1341  NVVM::WGMMAScaleInAttr generateScaleIn() const {
1342  return NVVM::WGMMAScaleInAttr::get(op->getContext(),
1343  NVVM::WGMMAScaleIn::one);
1344  }
1345 
1346  /// Basic function to generate Add
1347  Value makeAdd(Value lhs, Value rhs) {
1348  return LLVM::AddOp::create(b, lhs.getType(), lhs, rhs);
1349  };
1350 
1351  /// Moves the descriptor pointer of matrix-A for the next wgmma instruction.
1352  /// Currently, it only handles row-major.
1353  ///
1354  /// It moves the pointer like below for [128][64] size:
1355  /// +2 +4 +6
1356  /// ↓ ↓ ↓
1357  /// descA ---> +--+--+--+--+
1358  /// |->|->|->|->|
1359  /// | | | | |
1360  /// | | | | |
1361  /// | | | | |
1362  /// descA+512---> +-----------+
1363  /// | | | | |
1364  /// | | | | |
1365  /// | | | | |
1366  /// | | | | |
1367  /// +-----------+
1368  ///
1369  Value iterateDescriptorA(Value desc, int i, int j, int k) {
1370  MemRefType matrixTypeA = op.getDescriptorA().getType().getTensor();
1371  Type elemA = matrixTypeA.getElementType();
1372  int byte = elemA.getIntOrFloatBitWidth() / 8;
1373  int tileShapeA = matrixTypeA.getDimSize(1);
1374  int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte;
1375  incrementVal = incrementVal >> exclude4LSB;
1376  LDBG() << "\t\t[m: " << i << " n: " << j << " k: " << k
1377  << "] [wgmma descriptors] Descriptor A + " << incrementVal
1378  << " | \t ";
1379  if (!incrementVal)
1380  return desc;
1381  return makeAdd(desc, makeI64Const(b, incrementVal));
1382  }
1383 
1384  /// Moves the descriptor pointer of matrix-B for the next wgmma instruction.
1385  /// Currently, it only handles column-major.
1386  ///
1387  /// It moves the pointer like below for [128][64] size:
1388  /// descB ---> +--+--+--+--+--+--+--+--+
1389  /// |↓ | | | | | | | |
1390  /// |↓ | | | | | | | |
1391  /// |↓ | | | | | | | |
1392  /// |↓ | | | | | | | |
1393  /// +--+--+--+--+--+--+--+--+
1394  ///
1395  Value iterateDescriptorB(Value desc, int i, int j, int k) {
1396  MemRefType matrixTypeB = op.getDescriptorB().getType().getTensor();
1397  Type elemB = matrixTypeB.getElementType();
1398  int byte = elemB.getIntOrFloatBitWidth() / 8;
1399  int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte;
1400  incrementVal = incrementVal >> exclude4LSB;
1401  LDBG() << "Descriptor B + " << incrementVal;
1402  if (!incrementVal)
1403  return desc;
1404  return makeAdd(desc, makeI64Const(b, incrementVal));
1405  }
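// Worked example (illustrative; assumes f16 elements so byte = 2, a matrix-B
// tile with getDimSize(0) = 64, and wgmmaK = 16): for k = 1 the increment is
// 64 * 16 * 1 * 2 = 2048 bytes, i.e. 2048 >> 4 = 128 is added to the
// descriptor; for k = 0 the descriptor is returned unchanged.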
1406 
1407  /// Generates a WgmmaMmaAsyncOp using the provided GMMA matrix descriptors,
1408  /// offsetting them according to the induction variables i, j, and k.
1409  Value generateWgmma(int i, int j, int k, Value matrixC) {
1410  LDBG() << "\t wgmma." << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
1411  << "(A[" << (iterationM * wgmmaM) << ":"
1412  << (iterationM * wgmmaM) + wgmmaM << "][" << (iterationK * wgmmaK)
1413  << ":" << (iterationK * wgmmaK + wgmmaK) << "] * " << " B["
1414  << (iterationK * wgmmaK) << ":" << (iterationK * wgmmaK + wgmmaK)
1415  << "][" << 0 << ":" << wgmmaN << "])";
1416 
1417  Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k);
1418  Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k);
1419 
1420  Type elemA = op.getDescriptorA().getType().getTensor().getElementType();
1421  NVVM::WGMMATypesAttr itypeA = generateWgmmaType(elemA);
1422 
1423  Type elemB = op.getDescriptorB().getType().getTensor().getElementType();
1424  NVVM::WGMMATypesAttr itypeB = generateWgmmaType(elemB);
1425 
1426  Type elemD = op.getMatrixC().getType().getFragmented().getElementType();
1427  NVVM::WGMMATypesAttr itypeD = generateWgmmaType(elemD, true);
1428 
1429  NVVM::MMAShapeAttr shape = generateWgmmaShape();
1430  NVVM::WGMMAScaleOutAttr scaleOut = generateScaleOut();
1431  NVVM::WGMMAScaleInAttr scaleIn = generateScaleIn();
1432  NVVM::MMALayoutAttr layoutA = generateWgmmaLayout(op.getTransposeA());
1433  NVVM::MMALayoutAttr layoutB = generateWgmmaLayout(!op.getTransposeB());
1434 
1435  auto overflow = NVVM::MMAIntOverflowAttr::get(
1436  op->getContext(), NVVM::MMAIntOverflow::wrapped);
1437 
1438  return NVVM::WgmmaMmaAsyncOp::create(
1439  b, matrixC.getType(), matrixC, descriptorA, descriptorB, shape,
1440  itypeA, itypeB, itypeD, scaleOut, scaleIn, scaleIn, layoutA, layoutB,
1441  overflow);
1442  }
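// For reference, each WgmmaMmaAsyncOp built here eventually maps to a PTX
// instruction of roughly the following shape (illustrative only; the exact
// modifiers depend on the element types and attributes computed above):
//
//   wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16
//       {d0, ..., d63}, descA, descB, scaleD, scaleA, scaleB, transA, transB;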
1443 
1444  /// Generates multiple wgmma instructions to complete the given GEMM shape
1445  Value generateWgmmaGroup() {
1446  Value wgmmaResult =
1447  LLVM::PoisonOp::create(b, adaptor.getMatrixC().getType());
1448 
1449  // Perform GEMM
1450  SmallVector<Value> wgmmaResults;
1451  for (int i = 0; i < iterationM; ++i) {
1452  Value matrixC =
1453  LLVM::ExtractValueOp::create(b, adaptor.getMatrixC(), i);
1454  for (int j = 0; j < iterationN; ++j)
1455  for (int k = 0; k < iterationK; ++k)
1456  matrixC = generateWgmma(i, j, k, matrixC);
1457  wgmmaResults.push_back(matrixC);
1458  }
1459  for (auto [idx, matrix] : llvm::enumerate(wgmmaResults)) {
1460  wgmmaResult = LLVM::InsertValueOp::create(b, wgmmaResult.getType(),
1461  wgmmaResult, matrix, idx);
1462  }
1463  return wgmmaResult;
1464  }
1465 
1466  public:
1467  WarpgroupGemm(nvgpu::WarpgroupMmaOp op, ImplicitLocOpBuilder &b,
1468  OpAdaptor adaptor)
1469  : op(op), b(b), adaptor(adaptor) {
1470  // Find the entire GEMM Shape
1471  totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
1472  totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
1473  totalK = op.getDescriptorA().getType().getTensor().getDimSize(1);
1474  LDBG() << "===--- GEMM D[" << totalM << "][" << totalN << "] += A["
1475  << totalM << "][" << totalK << "] * B[" << totalK << "][" << totalN
1476  << "] ---===";
1477 
1478  // Find the shape for one wgmma instruction
1479  findWgmmaShape(
1480  totalM, totalN,
1481  op.getDescriptorA().getType().getTensor().getElementType());
1482 
1483  // Iteration counts needed to cover the full GEMM shape with the wgmma shape
1484  iterationM = totalM / wgmmaM;
1485  iterationN = totalN / wgmmaN;
1486  iterationK = totalK / wgmmaK;
1487  }
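// Worked example (illustrative): for an f16 GEMM with A = 128x64, B = 64x128
// and D = 128x128, findWgmmaShape currently selects wgmmaM = 64, wgmmaN = 128
// and wgmmaK = 16, giving iterationM = 2, iterationN = 1 and iterationK = 4,
// so generateWgmmaGroup() emits 2 * 1 * 4 = 8 WgmmaMmaAsyncOp, producing one
// accumulator fragment per m-iteration.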
1488 
1489  /// Generates WgmmaMmaAsync Ops to complete the specified GEMM shape. It
1490  /// also emits a fence (WgmmaFenceAlignedOp) before the wgmma
1491  /// instructions, and after them a group commit
1492  /// (WgmmaGroupSyncAlignedOp) followed by a wait for that group to
1493  /// complete (WgmmaWaitGroupSyncOp).
1494  Value generateWarpgroupMma() {
1495  NVVM::WgmmaFenceAlignedOp::create(b);
1496  Value wgmmaResult = generateWgmmaGroup();
1497  NVVM::WgmmaGroupSyncAlignedOp::create(b);
1498  NVVM::WgmmaWaitGroupSyncOp::create(b, op.getWaitGroup());
1499  return wgmmaResult;
1500  }
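// The emitted op sequence corresponds roughly to the following PTX
// (illustrative only):
//
//   wgmma.fence.sync.aligned;
//   wgmma.mma_async.sync.aligned...;    // one per generateWgmma() call
//   wgmma.commit_group.sync.aligned;
//   wgmma.wait_group.sync.aligned N;    // N = op.getWaitGroup()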
1501  };
1502  LogicalResult
1503  matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
1504  ConversionPatternRewriter &rewriter) const override {
1505  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1506 
1507  // Step 1. Build a helper class
1508  WarpgroupGemm warpgroupGemm(op, b, adaptor);
1509 
1510  // Step 2. Generate wgmma instructions covering the entire GEMM shape
1511  Value wgmmaResult = warpgroupGemm.generateWarpgroupMma();
1512 
1513  // Step 3. Replace fragmented result struct with the op results
1514  rewriter.replaceOp(op, wgmmaResult);
1515  return success();
1516  }
1517 };
1518 
1519 struct NVGPUWarpgroupMmaStoreOpLowering
1520  : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaStoreOp> {
1521  using ConvertOpToLLVMPattern<
1522  nvgpu::WarpgroupMmaStoreOp>::ConvertOpToLLVMPattern;
1523 
1524  /// This function stores a fragmented register matrix owned by a warp group
1525  /// (128 threads) into a memref. Each thread holds 64 registers, packed as
1526  /// the elements of a struct.
1527  /// Below is what each thread (T) holds; each `d` is a struct value with a
1528  /// number.
1529  ///
1530  /// Threads in the warp-group (128 threads) and what they own in matrixD:
1531  /// 0-31 Warp-0 -> MatrixD[0:15 ][0:N]
1532  /// 32-63 Warp-1 -> MatrixD[16:31][0:N]
1533  /// 64-95 Warp-2 -> MatrixD[32:47][0:N]
1534  /// 96-127 Warp-3 -> MatrixD[48:63][0:N]
1535  ///
1536  /// Matrix-D:
1537  /// +______________________________________________________________________+
1538  /// | 0-1 | 2-3 | 4-5 | 6-7 | 8-9 | 10-11|..|N-8,N-7 |
1539  /// 0 | T0:d0-d1 |T1:d0-d1 |T2:d0-d1 |T3:d0-d1 |T0:d4-d5| T1:d4-d5..|T0:dX-dY|
1540  /// 1 | T4:d0-d1 |T5:d0-d1 |T6:d0-d1 |T7:d0-d1 |T4:d4-d5| T5:d4-d5..|T4:dX-dY|
1541  /// ..| .........|.........|.........|.........|........|...........|........|
1542  /// 8 | T0:d2-d3 |T1:d2-d3 |T2:d2-d3 |T3:d2-d3 |T0:d6-d7|T1:d6-d7,..|T0:dZ-dW|
1543  /// 9 | T4:d2-d3 |T5:d2-d3 |T6:d2-d3 |T7:d2-d3 |T4:d6-d7| T5:d6-d7..|T4:dZ-dW|
1544  /// ..| .........|.........|.........|.........|........|...........|........|
1545  /// 15| T28:d2-d3|T29:d2-d3|T30:d2-d3|T31:d2-d3|........|...........|........|
1546  /// 16| T32:d2-d3|T33:d2-d3|T34:d2-d3|T35:d2-d3|........|...........|........|
1547  /// ..| .........|.........|.........|.........|........|...........|........|
1548  /// 32| T64:d2-d3|T65:d2-d3|T66:d2-d3|T67:d2-d3|........|...........|........|
1549  /// ..| .........|.........|.........|.........|........|...........|........|
1550  /// 48| T96:d2-d3|T97:d2-d3|T98:d2-d3|T99:d2-d3|........|...........|........|
1551  /// ..| .........|.........|.........|.........|........|...........|........|
1552  /// +______________________________________________________________________+
1553  ///
1554  /// \param b: The builder used to emit the store operations.
1555  /// \param matrixD: Result of the warp-group MMA operation (fragmented
1556  /// matrix). Each thread holds it as a struct with 64 elements.
1557  /// \param dstMemref: The memref where the registers will be stored.
1558  /// \param offset: The offset within the memref where the registers will be
1559  /// stored.
1560  void storeFragmentedMatrix(ImplicitLocOpBuilder &b, Value matrixD,
1561  TypedValue<MemRefType> dstMemref,
1562  int offset) const {
1563  Type i32 = b.getI32Type();
1564 
1565  auto makeConst = [&](int32_t index) -> Value {
1566  return LLVM::ConstantOp::create(b, i32, b.getI32IntegerAttr(index));
1567  };
1568  Value c1 = makeConst(1);
1569  Value c2 = makeConst(2);
1570  Value c4 = makeConst(4);
1571  Value c8 = makeConst(8);
1572  Value c16 = makeConst(16);
1573  Value warpSize = makeConst(kWarpSize);
1574 
1575  auto makeMul = [&](Value lhs, Value rhs) -> Value {
1576  return LLVM::MulOp::create(b, lhs.getType(), lhs, rhs);
1577  };
1578  auto makeAdd = [&](Value lhs, Value rhs) -> Value {
1579  return LLVM::AddOp::create(b, lhs.getType(), lhs, rhs);
1580  };
1581 
1582  auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,
1583  TypedValue<MemRefType> memref) {
1584  Type it = b.getIndexType();
1585  Value idx = arith::IndexCastOp::create(b, it, x);
1586  Value idy0 = arith::IndexCastOp::create(b, it, y);
1587  Value idy1 = arith::IndexCastOp::create(b, it, makeAdd(y, c1));
1588  Value d0 = LLVM::ExtractValueOp::create(b, wgmmaResult, i);
1589  Value d1 = LLVM::ExtractValueOp::create(b, wgmmaResult, i + 1);
1590  memref::StoreOp::create(b, d0, memref, ValueRange{idx, idy0});
1591  memref::StoreOp::create(b, d1, memref, ValueRange{idx, idy1});
1592  };
1593 
1594  Value tidx = NVVM::ThreadIdXOp::create(b, i32);
1595  Value laneId = LLVM::URemOp::create(b, i32, tidx, warpSize);
1596  Value warpId = LLVM::UDivOp::create(b, i32, tidx, warpSize);
1597  Value lane4Id = LLVM::UDivOp::create(b, i32, laneId, c4);
1598  Value lane4modId = LLVM::URemOp::create(b, i32, laneId, c4);
1599 
1600  Value tj = makeMul(lane4modId, c2);
1601  Value ti = makeAdd(lane4Id, makeMul(warpId, c16));
1602  if (offset)
1603  ti = makeAdd(ti, makeConst(offset));
1604 
1605  auto structType = cast<LLVM::LLVMStructType>(matrixD.getType());
1606 
1607  // Number of adjacent 32-bit registers owned by each thread
1608  constexpr unsigned numAdjacentRegisters = 2;
1609  // Number of 8x8 matrices stacked one below another per warp
1610  constexpr unsigned numStackedMatrices = 2;
1611 
1612  size_t storeCount = (structType.getBody().size() /
1613  (numStackedMatrices * numAdjacentRegisters));
1614 
1615  for (size_t i = 0; i < numStackedMatrices; ++i) {
1616  Value idx = makeAdd(ti, makeMul(makeConst(i), c8));
1617  for (size_t j = 0; j < storeCount; ++j) {
1618  Value idy = makeAdd(tj, makeMul(makeConst(j), c8));
1619  size_t structIndex = (i * numAdjacentRegisters) +
1620  (j * (numStackedMatrices * numAdjacentRegisters));
1621  makeExtractAndStore(structIndex, matrixD, idx, idy, dstMemref);
1622  }
1623  }
1624  }
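// Worked example (illustrative): thread 5 (warp 0, lane 5) gets lane4Id = 1
// and lane4modId = 1, so ti = 1 and tj = 2. With offset = 0 it stores d0-d1
// to dstMemref[1][2..3], d2-d3 to dstMemref[9][2..3], and d4-d5 to
// dstMemref[1][10..11], matching the matrix-D sketch above.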
1625 
1626  LogicalResult
1627  matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor,
1628  ConversionPatternRewriter &rewriter) const override {
1629  int offset = 0;
1630  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1631  Value matrixDValue = adaptor.getMatrixD();
1632  auto stype = cast<LLVM::LLVMStructType>(matrixDValue.getType());
1633  for (auto [idx, matrixD] : llvm::enumerate(stype.getBody())) {
1634  auto structType = cast<LLVM::LLVMStructType>(matrixD);
1635  Value innerStructValue =
1636  LLVM::ExtractValueOp::create(b, matrixDValue, idx);
1637  storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset);
1638  offset += structType.getBody().size();
1639  }
1640  rewriter.eraseOp(op);
1641  return success();
1642  }
1643 };
1644 
1645 struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
1646  : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaInitAccumulatorOp> {
1647  using ConvertOpToLLVMPattern<
1648  nvgpu::WarpgroupMmaInitAccumulatorOp>::ConvertOpToLLVMPattern;
1649  LogicalResult
1650  matchAndRewrite(nvgpu::WarpgroupMmaInitAccumulatorOp op, OpAdaptor adaptor,
1651  ConversionPatternRewriter &rewriter) const override {
1652  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1653  LLVM::LLVMStructType packStructType = cast<LLVM::LLVMStructType>(
1654  getTypeConverter()->convertType(op.getMatrixC().getType()));
1655  Type elemType = cast<LLVM::LLVMStructType>(packStructType.getBody().front())
1656  .getBody()
1657  .front();
1658  Value zero = LLVM::ConstantOp::create(b, elemType, b.getZeroAttr(elemType));
1659  Value packStruct = LLVM::PoisonOp::create(b, packStructType);
1660  SmallVector<Value> innerStructs;
1661  // Unpack the structs and set all values to zero
1662  for (auto [idx, s] : llvm::enumerate(packStructType.getBody())) {
1663  auto structType = cast<LLVM::LLVMStructType>(s);
1664  Value structValue = LLVM::ExtractValueOp::create(b, packStruct, idx);
1665  for (unsigned i = 0; i < structType.getBody().size(); ++i) {
1666  structValue = LLVM::InsertValueOp::create(b, structType, structValue,
1667  zero, ArrayRef<int64_t>({i}));
1668  }
1669  innerStructs.push_back(structValue);
1670  }
1671  // Pack the inner structs into a single struct
1672  for (auto [idx, matrix] : llvm::enumerate(innerStructs)) {
1673  packStruct = LLVM::InsertValueOp::create(b, packStruct.getType(),
1674  packStruct, matrix, idx);
1675  }
1676  rewriter.replaceOp(op, packStruct);
1677  return success();
1678  }
1679 };
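// Illustrative example (assumed shapes): for a 128x128 f32 accumulator the
// converted packStructType is a struct of two inner structs, one per 64-row
// fragment, each holding 64 f32 elements; the loop above writes 0.0 into
// every one of those elements.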
1680 
1681 struct NVGPUTmaFenceOpLowering
1682  : public ConvertOpToLLVMPattern<nvgpu::TmaFenceOp> {
1683  using ConvertOpToLLVMPattern<nvgpu::TmaFenceOp>::ConvertOpToLLVMPattern;
1684  LogicalResult
1685  matchAndRewrite(nvgpu::TmaFenceOp op, OpAdaptor adaptor,
1686  ConversionPatternRewriter &rewriter) const override {
1687  MLIRContext *ctx = op.getContext();
1688  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1689  auto i32Ty = b.getI32Type();
1690  Value tensormapSize =
1691  LLVM::ConstantOp::create(b, i32Ty, rewriter.getI32IntegerAttr(128));
1692 
1693  auto memscope =
1694  NVVM::MemScopeKindAttr::get(ctx, ::mlir::NVVM::MemScopeKind::SYS);
1695 
1696  rewriter.replaceOpWithNewOp<NVVM::FenceProxyAcquireOp>(
1697  op, memscope, adaptor.getTensorMapDescriptor(), tensormapSize);
1698 
1699  return success();
1700  }
1701 };
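// With the SYS scope chosen above, this lowers to PTX roughly of the form
// (illustrative):
//
//   fence.proxy.tensormap::generic.acquire.sys [tensorMapPtr], 128;
//
// where 128 is the tensormap size in bytes passed as `tensormapSize`.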
1702 
1703 struct NVGPUTmaPrefetchOpLowering
1704  : public ConvertOpToLLVMPattern<nvgpu::TmaPrefetchOp> {
1705  using ConvertOpToLLVMPattern<nvgpu::TmaPrefetchOp>::ConvertOpToLLVMPattern;
1706  LogicalResult
1707  matchAndRewrite(nvgpu::TmaPrefetchOp op, OpAdaptor adaptor,
1708  ConversionPatternRewriter &rewriter) const override {
1709  rewriter.replaceOpWithNewOp<NVVM::PrefetchOp>(
1710  op, /* CacheLevel */ nullptr, /* Cache Eviction Priority */ nullptr,
1711  adaptor.getTensorMapDescriptor(), adaptor.getPredicate(),
1712  /* Tensormap UnitAttr */ mlir::UnitAttr::get(op.getContext()));
1713  return success();
1714  }
1715 };
1716 
1717 struct NVGPURcpOpLowering : public ConvertOpToLLVMPattern<nvgpu::RcpOp> {
1718  using ConvertOpToLLVMPattern<nvgpu::RcpOp>::ConvertOpToLLVMPattern;
1719  LogicalResult
1720  matchAndRewrite(nvgpu::RcpOp op, OpAdaptor adaptor,
1721  ConversionPatternRewriter &rewriter) const override {
1722  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1723  auto i64Ty = b.getI64Type();
1724  auto f32Ty = b.getF32Type();
1725  VectorType inTy = op.getIn().getType();
1726  // Apply rcp.approx.ftz.f32 to each element of the vector.
1727  auto convert1DVec = [&](Type llvm1DVectorTy, Value inVec) {
1728  Value ret1DVec = LLVM::PoisonOp::create(b, llvm1DVectorTy);
1729  int numElems = llvm::cast<VectorType>(llvm1DVectorTy).getNumElements();
1730  for (int i = 0; i < numElems; i++) {
1731  Value idx = LLVM::ConstantOp::create(b, i64Ty, b.getI64IntegerAttr(i));
1732  Value elem = LLVM::ExtractElementOp::create(b, inVec, idx);
1733  Value dst = NVVM::RcpApproxFtzF32Op::create(b, f32Ty, elem);
1734  ret1DVec = LLVM::InsertElementOp::create(b, ret1DVec, dst, idx);
1735  }
1736  return ret1DVec;
1737  };
1738  if (inTy.getRank() == 1) {
1739  rewriter.replaceOp(op, convert1DVec(inTy, adaptor.getIn()));
1740  return success();
1741  }
1742  return LLVM::detail::handleMultidimensionalVectors(
1743  op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()),
1744  [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
1745  OpAdaptor adaptor(operands);
1746  return convert1DVec(llvm1DVectorTy, adaptor.getIn());
1747  },
1748  rewriter);
1749  }
1750 };
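// Each NVVM::RcpApproxFtzF32Op created above corresponds to the PTX
// instruction `rcp.approx.ftz.f32 d, a;` (illustrative), so lowering an
// nvgpu.rcp on vector<4xf32> yields four extract / rcp / insert triples.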
1751 } // namespace
1752 
1753  void mlir::populateNVGPUToNVVMConversionPatterns(
1754  const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
1755  patterns.add<
1756  NVGPUMBarrierCreateLowering, // nvgpu.mbarrier.create
1757  NVGPUMBarrierInitLowering, // nvgpu.mbarrier.init
1758  NVGPUMBarrierGetLowering, // nvgpu.mbarrier.get
1759  NVGPUMBarrierArriveLowering, // nvgpu.mbarrier.arrive
1760  NVGPUMBarrierArriveNoCompleteLowering, // nvgpu.mbarrier.arrive.no_complete
1761  NVGPUMBarrierTestWaitLowering, // nvgpu.mbarrier.test_wait_parity
1762  NVGPUMBarrierTryWaitParityLowering, // nvgpu.mbarrier.try_wait_parity
1763  NVGPUTmaAsyncLoadOpLowering, // nvgpu.tma.async.load
1764  NVGPUTmaAsyncStoreOpLowering, // nvgpu.tma.async.store
1765  NVGPUTmaCreateDescriptorOpLowering, // nvgpu.tma.create.descriptor
1766  NVGPUTmaPrefetchOpLowering, // nvgpu.tma.prefetch.descriptor
1767  NVGPUTmaFenceOpLowering, // nvgpu.tma.fence.descriptor
1768  NVGPUMBarrierArriveExpectTxLowering, // nvgpu.mbarrier.arrive.expect_tx
1769  NVGPUGenerateWarpgroupDescriptorLowering, // nvgpu.warpgroup.generate.descriptor
1770  NVGPUWarpgroupMmaOpLowering, // nvgpu.warpgroup.mma
1771  NVGPUWarpgroupMmaStoreOpLowering, // nvgpu.warpgroup.mma.store
1772  NVGPUWarpgroupMmaInitAccumulatorOpLowering, // nvgpu.warpgroup.mma.init.accumulator
1773  MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
1774  NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
1775  NVGPUMmaSparseSyncLowering, NVGPURcpOpLowering>(converter);
1776 }
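// Minimal usage sketch (illustrative only; the function name and the
// legality choices are hypothetical, assumed for this example): how a client
// might wire these patterns into a partial conversion.
//
//   static LogicalResult runExampleNVGPUToNVVM(Operation *root) {
//     MLIRContext *ctx = root->getContext();
//     LowerToLLVMOptions options(ctx);
//     LLVMTypeConverter converter(ctx, options);
//     RewritePatternSet patterns(ctx);
//     populateNVGPUToNVVMConversionPatterns(converter, patterns);
//     ConversionTarget target(*ctx);
//     target.addLegalDialect<LLVM::LLVMDialect, NVVM::NVVMDialect>();
//     target.addIllegalDialect<nvgpu::NVGPUDialect>();
//     return applyPartialConversion(root, target, std::move(patterns));
//   }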