1 //===- NVGPUToNVVM.cpp - NVGPU to NVVM dialect conversion -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
10 
11 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
12 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
13 #include "mlir/Conversion/LLVMCommon/Pattern.h"
14 #include "mlir/Dialect/Arith/IR/Arith.h"
15 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
16 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
17 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
18 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
19 #include "mlir/Dialect/MemRef/IR/MemRef.h"
20 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
21 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
23 #include "mlir/IR/BuiltinTypes.h"
24 #include "mlir/IR/ImplicitLocOpBuilder.h"
25 #include "mlir/IR/PatternMatch.h"
26 #include "mlir/IR/TypeUtilities.h"
27 #include "mlir/IR/Value.h"
28 #include "mlir/Pass/Pass.h"
29 #include "llvm/Support/Debug.h"
30 #include "llvm/Support/ErrorHandling.h"
31 #include "llvm/Support/raw_ostream.h"
32 #include <optional>
33 
34 #define DEBUG_TYPE "nvgpu-to-nvvm"
35 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
36 #define DBGSE() (llvm::dbgs())
37 
38 namespace mlir {
39 #define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS
40 #include "mlir/Conversion/Passes.h.inc"
41 } // namespace mlir
42 
43 using namespace mlir;
44 
45 /// Number of bits that need to be excluded when building the matrix
46 /// descriptor for wgmma operations.
47 constexpr int exclude4LSB = 4;
48 
49 /// GPUs have 32-bit registers; this function truncates values that are wider
50 /// than 32 bits to i32.
51 static Value truncToI32(ImplicitLocOpBuilder &b, Value value) {
52  Type type = value.getType();
53  assert(llvm::isa<IntegerType>(type) && "expected an integer Value");
54  if (type.getIntOrFloatBitWidth() <= 32)
55  return value;
56  return b.create<LLVM::TruncOp>(b.getI32Type(), value);
57 }
58 
59 /// Returns the type for the intrinsic given the vectorResultType of the
60 /// `nvgpu.mma.sync` operation.
61 static Type inferIntrinsicResultType(Type vectorResultType) {
62  MLIRContext *ctx = vectorResultType.getContext();
63  auto a = cast<LLVM::LLVMArrayType>(vectorResultType);
64  auto f16x2Ty = VectorType::get(2, Float16Type::get(ctx));
65  auto i32Ty = IntegerType::get(ctx, 32);
66  auto i32x2Ty = VectorType::get(2, i32Ty);
67  Type f64Ty = Float64Type::get(ctx);
68  Type f64x2Ty = VectorType::get(2, f64Ty);
69  Type f32Ty = Float32Type::get(ctx);
70  Type f32x2Ty = VectorType::get(2, f32Ty);
71  if (a.getElementType() == f16x2Ty) {
72  return LLVM::LLVMStructType::getLiteral(
73  ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty));
74  }
75  if (a.getElementType() == i32x2Ty) {
76  return LLVM::LLVMStructType::getLiteral(
77  ctx,
78  SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, i32Ty));
79  }
80  if (a.getElementType() == f64x2Ty) {
81  return LLVM::LLVMStructType::getLiteral(ctx, {f64Ty, f64Ty});
82  }
83  if (a.getElementType() == f32x2Ty) {
84  return LLVM::LLVMStructType::getLiteral(
85  ctx,
86  SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, f32Ty));
87  }
88  if (a.getElementType() == VectorType::get(1, f32Ty)) {
89  return LLVM::LLVMStructType::getLiteral(
90  ctx, SmallVector<Type>(static_cast<size_t>(a.getNumElements()), f32Ty));
91  }
92  return vectorResultType;
93 }
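// For illustration of the mapping above: a result lowered to
// `!llvm.array<2 x vector<2xf16>>` maps to the intrinsic result type
// `!llvm.struct<(vector<2xf16>, vector<2xf16>)>`, while an
// `!llvm.array<2 x vector<2xf32>>` result maps to a struct of four `f32`
// scalars, since the intrinsic returns i32/f32/f64 values as individual
// scalars.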
94 
95 /// Convert the SSA result of the NVVM intrinsic `nvvm.mma.sync` (which is
96 /// always an LLVM struct) into a fragment that is compatible with the vector
97 /// type of this operation. This involves extracting elements from the struct
98 /// and inserting them into an LLVM array. These extra data-movement
99 /// operations should be canonicalized away by the LLVM backend.
100 static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
101  Type resultType, Value intrinsicResult,
102  RewriterBase &rewriter) {
103  MLIRContext *ctx = rewriter.getContext();
104  auto structType = dyn_cast<LLVM::LLVMStructType>(intrinsicResultType);
105  auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType);
106  Type i32Ty = rewriter.getI32Type();
107  Type f32Ty = rewriter.getF32Type();
108  Type f64Ty = rewriter.getF64Type();
109  Type f16x2Ty = VectorType::get(2, rewriter.getF16Type());
110  Type i32x2Ty = VectorType::get(2, i32Ty);
111  Type f64x2Ty = VectorType::get(2, f64Ty);
112  Type f32x2Ty = VectorType::get(2, f32Ty);
113  Type f32x1Ty = VectorType::get(1, f32Ty);
114 
115  auto makeConst = [&](int32_t index) -> Value {
116  return rewriter.create<LLVM::ConstantOp>(loc, IntegerType::get(ctx, 32),
117  rewriter.getI32IntegerAttr(index));
118  };
119 
120  if (arrayType) {
121  SmallVector<Value, 4> elements;
122 
123  // The intrinsic returns 32-bit wide elements in a form which can be
124  // directly bitcasted and inserted into the result vector.
125  if (arrayType.getElementType() == f16x2Ty ||
126  arrayType.getElementType() == f32x1Ty) {
127  for (unsigned i = 0; i < structType.getBody().size(); i++) {
128  Value el =
129  rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, i);
130  el = rewriter.createOrFold<LLVM::BitcastOp>(
131  loc, arrayType.getElementType(), el);
132  elements.push_back(el);
133  }
134  }
135 
136  // The intrinsic returns i32, f64, and f32 values as individual scalars,
137  // even when the result is notionally a 64-bit wide element (e.g. f32x2). We
138  // need to extract them from the struct and pack them into the 64-bit wide
139  // rows of the vector result.
140  if (arrayType.getElementType() == i32x2Ty ||
141  arrayType.getElementType() == f64x2Ty ||
142  arrayType.getElementType() == f32x2Ty) {
143 
144  for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) {
145  Value vec =
146  rewriter.create<LLVM::PoisonOp>(loc, arrayType.getElementType());
147  Value x1 =
148  rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, i * 2);
149  Value x2 = rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult,
150  i * 2 + 1);
151  vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
152  x1, makeConst(0));
153  vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
154  x2, makeConst(1));
155  elements.push_back(vec);
156  }
157  }
158 
159  // Create the final vectorized result.
160  Value result = rewriter.create<LLVM::PoisonOp>(loc, arrayType);
161  for (const auto &el : llvm::enumerate(elements)) {
162  result = rewriter.create<LLVM::InsertValueOp>(loc, result, el.value(),
163  el.index());
164  }
165  return result;
166  }
167 
168  return intrinsicResult;
169 }
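// For example, an intrinsic result `!llvm.struct<(f64, f64)>` with desired
// result type `!llvm.array<1 x vector<2xf64>>` is repacked by inserting the
// two extracted scalars into a `vector<2xf64>` at indices 0 and 1 and then
// inserting that vector into position 0 of the array.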
170 
171 /// The `nvgpu.mma.sync` converter below expects matrix fragment operands to be
172 /// given as 2D `vector`s where the rows are 32b or 64b wide. The
173 /// `nvvm.mma.sync` op expects these arguments to be given as a long list of
174 /// scalars of certain types. This function helps unpack the `vector` arguments
175 /// and cast them to the types expected by `nvvm.mma.sync`.
176 static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b,
177  Value operand,
178  NVVM::MMATypes operandPtxType) {
179  SmallVector<Value> result;
180  Type i32Ty = b.getI32Type();
181  Type f64Ty = b.getF64Type();
182  Type f32Ty = b.getF32Type();
183  Type i64Ty = b.getI64Type();
184  Type i8x4Ty = VectorType::get(4, b.getI8Type());
185  Type i4x8Ty = VectorType::get(8, b.getIntegerType(4));
186  Type f32x1Ty = VectorType::get(1, f32Ty);
187  auto arrayTy = cast<LLVM::LLVMArrayType>(operand.getType());
188 
189  for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
190  Value toUse = b.create<LLVM::ExtractValueOp>(operand, i);
191 
192  // For 4xi8 vectors, the intrinsic expects these to be provided as i32
193  // scalar types.
194  if (arrayTy.getElementType() == i8x4Ty ||
195  arrayTy.getElementType() == i4x8Ty ||
196  (arrayTy.getElementType() == f32x1Ty &&
197  operandPtxType == NVVM::MMATypes::tf32)) {
198  result.push_back(b.create<LLVM::BitcastOp>(i32Ty, toUse));
199  continue;
200  }
201 
202  // For some element types (i32, f32, f64), we need to unpack the inner
203  // vector/array type as well because the intrinsic expects individual
204  // scalars to be provided.
205  VectorType innerArrayTy = dyn_cast<VectorType>(arrayTy.getElementType());
206  if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty ||
207  innerArrayTy.getElementType() == f64Ty ||
208  innerArrayTy.getElementType() == f32Ty)) {
209  for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
210  idx < innerSize; idx++) {
211  result.push_back(b.create<LLVM::ExtractElementOp>(
212  toUse,
213  b.create<LLVM::ConstantOp>(i64Ty, b.getI64IntegerAttr(idx))));
214  }
215  continue;
216  }
217  result.push_back(toUse);
218  }
219  return result;
220 }
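// For example, `!llvm.array<2 x vector<2xf16>>` fragments are forwarded as two
// `vector<2xf16>` values, a `vector<1xf32>` fragment with a tf32 PTX type is
// bitcast to `i32`, and `vector<2xf64>` fragments are split into individual
// `f64` scalars, matching what `nvvm.mma.sync` expects.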
221 
222 /// Returns whether the mbarrier object has a shared memory address space.
223 static bool isMbarrierShared(nvgpu::MBarrierGroupType barrierType) {
224  return (mlir::nvgpu::NVGPUDialect::isSharedMemoryAddressSpace(
225  barrierType.getMemorySpace()));
226 }
227 
228 /// Returns the memory space attribute of the mbarrier object.
229 Attribute nvgpu::getMbarrierMemorySpace(MLIRContext *context,
230  nvgpu::MBarrierGroupType barrierType) {
231  Attribute memorySpace = {};
232  if (isMbarrierShared(barrierType)) {
233  memorySpace =
234  IntegerAttr::get(IntegerType::get(context, 64),
235  nvgpu::NVGPUDialect::kSharedMemoryAddressSpace);
236  }
237  return memorySpace;
238 }
239 
240 /// Returns memref type of the mbarrier object. The type is defined in the
241 /// MBarrierGroupType.
242 MemRefType nvgpu::getMBarrierMemrefType(MLIRContext *context,
243  nvgpu::MBarrierGroupType barrierType) {
244  Attribute memorySpace = nvgpu::getMbarrierMemorySpace(context, barrierType);
245  MemRefLayoutAttrInterface layout;
246  return MemRefType::get({barrierType.getNumBarriers()},
247  IntegerType::get(context, 64), layout, memorySpace);
248 }
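// For example, a barrier group with `num_barriers = 4` that lives in GPU
// shared (workgroup) memory converts to a memref of four `i64` mbarrier words
// in address space 3 (nvgpu::NVGPUDialect::kSharedMemoryAddressSpace).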
249 
250 namespace {
251 
252 struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
253  using ConvertOpToLLVMPattern<nvgpu::LdMatrixOp>::ConvertOpToLLVMPattern;
254 
255  LogicalResult
256  matchAndRewrite(nvgpu::LdMatrixOp op, OpAdaptor adaptor,
257  ConversionPatternRewriter &rewriter) const override {
258  MLIRContext *ctx = getContext();
259  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
260 
261  // The result type of ldmatrix will always be a struct of 32-bit integer
262  // registers if more than one 32-bit value is returned. Otherwise, the result
263  // is a single i32. The result type of the GPU operation is always a vector
264  // of shape (NumRegisters, VectorRegister) where VectorRegister is the
265  // vector type of the result and always 32 bits long. We bitcast the result
266  // of the NVVM::LdMatrix to this vector type.
267  auto vectorResultType = dyn_cast<VectorType>(op->getResultTypes()[0]);
268  if (!vectorResultType) {
269  return failure();
270  }
271  Type innerVectorType = VectorType::get(vectorResultType.getDimSize(1),
272  vectorResultType.getElementType());
273 
274  int64_t num32BitRegs = vectorResultType.getDimSize(0);
275 
276  Type ldMatrixResultType;
277  if (num32BitRegs > 1) {
278  ldMatrixResultType = LLVM::LLVMStructType::getLiteral(
279  ctx, SmallVector<Type>(num32BitRegs, rewriter.getI32Type()));
280  } else {
281  ldMatrixResultType = rewriter.getI32Type();
282  }
283 
284  auto srcMemrefType = cast<MemRefType>(op.getSrcMemref().getType());
285  Value srcPtr =
286  getStridedElementPtr(b.getLoc(), srcMemrefType, adaptor.getSrcMemref(),
287  adaptor.getIndices(), rewriter);
288  Value ldMatrixResult = b.create<NVVM::LdMatrixOp>(
289  ldMatrixResultType, srcPtr,
290  /*num=*/op.getNumTiles(),
291  /*layout=*/op.getTranspose() ? NVVM::MMALayout::col
292  : NVVM::MMALayout::row);
293 
294  // The ldmatrix operation returns either a single i32 value or a struct of
295  // i32 values. Here we unpack those values and cast them back to their
296  // actual vector type (still of width 32b) and repack them into a result
297  // struct.
298  Type finalResultType = typeConverter->convertType(vectorResultType);
299  Value result = b.create<LLVM::PoisonOp>(finalResultType);
300  for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) {
301  Value i32Register =
302  num32BitRegs > 1 ? b.create<LLVM::ExtractValueOp>(ldMatrixResult, i)
303  : ldMatrixResult;
304  Value casted = b.create<LLVM::BitcastOp>(innerVectorType, i32Register);
305  result = b.create<LLVM::InsertValueOp>(result, casted, i);
306  }
307 
308  rewriter.replaceOp(op, result);
309  return success();
310  }
311 };
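// For example, an `nvgpu.ldmatrix` op producing `vector<4x2xf16>` lowers to an
// `nvvm.ldmatrix` that returns a struct of four `i32` registers; each register
// is bitcast back to `vector<2xf16>` and inserted into the final
// `!llvm.array<4 x vector<2xf16>>` result.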
312 
313 /// Convert the given type into the corresponding PTX type (NVVM::MMATypes
314 /// enum).
315 static FailureOr<NVVM::MMATypes> getNvvmMmaType(Type t) {
316  Type elType = getElementTypeOrSelf(t);
317  if (elType.isInteger(8))
318  return NVVM::MMATypes::s8;
319  if (elType.isInteger(4))
320  return NVVM::MMATypes::s4;
321  if (elType.isF16())
322  return NVVM::MMATypes::f16;
323  if (elType.isF64())
324  return NVVM::MMATypes::f64;
325  if (elType.isF32())
326  return NVVM::MMATypes::tf32;
327  return failure();
328 }
329 
330 struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
331  using ConvertOpToLLVMPattern<nvgpu::MmaSyncOp>::ConvertOpToLLVMPattern;
332 
333  LogicalResult
334  matchAndRewrite(nvgpu::MmaSyncOp op, OpAdaptor adaptor,
335  ConversionPatternRewriter &rewriter) const override {
336  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
337  // Get the shapes of the MMAMatrix type being used. The shapes determine
338  // which intrinsic this op will be lowered to.
339  VectorType aType = op.getMatrixA().getType();
340  VectorType bType = op.getMatrixB().getType();
341  VectorType cType = op.getMatrixC().getType();
342 
343  std::array<int64_t, 3> gemmShape = op.getMmaShapeAsArray();
344 
345  // Tensor Cores (mma.sync) on F32 work only with TensorFloat32 (TF32).
346  bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
347  if (aType.getElementType().isF32() && !tf32Enabled)
348  return failure();
349 
350  FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
351  if (failed(ptxTypeA))
352  return op->emitOpError("failed to deduce operand PTX types");
353  FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
354  if (failed(ptxTypeB))
355  return op->emitOpError("failed to deduce operand PTX types");
356  std::optional<NVVM::MMATypes> ptxTypeC =
357  NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
358  /*isAccumulator=*/true);
359  if (!ptxTypeC)
360  return op->emitError(
361  "could not infer the PTX type for the accumulator/result");
362 
363  // TODO: add an attribute to the op to customize this behavior.
364  std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
365  if (isa<IntegerType>(aType.getElementType()))
366  overflow = NVVM::MMAIntOverflow::satfinite;
367 
368  SmallVector<Value> matA =
369  unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
370  SmallVector<Value> matB =
371  unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
372  SmallVector<Value> matC =
373  unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);
374 
375  Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
376  Type intrinsicResTy = inferIntrinsicResultType(
377  typeConverter->convertType(op->getResultTypes()[0]));
378  Value intrinsicResult = b.create<NVVM::MmaOp>(
379  intrinsicResTy, matA, matB, matC,
380  /*shape=*/gemmShape,
381  /*b1Op=*/std::nullopt,
382  /*intOverflow=*/overflow,
383  /*multiplicandPtxTypes=*/
384  std::array<NVVM::MMATypes, 2>{*ptxTypeA, *ptxTypeB},
385  /*multiplicandLayouts=*/
386  std::array<NVVM::MMALayout, 2>{NVVM::MMALayout::row,
387  NVVM::MMALayout::col});
388  rewriter.replaceOp(op, convertIntrinsicResult(op.getLoc(), intrinsicResTy,
389  desiredRetTy, intrinsicResult,
390  rewriter));
391  return success();
392  }
393 };
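// As a sketch of the lowering above: an f16 `nvgpu.mma.sync` with
// `mmaShape = [16, 8, 16]` forwards its `vector<2xf16>` fragments unchanged,
// emits one `nvvm.mma.sync` with row-major A and column-major B, and repacks
// the returned struct through convertIntrinsicResult.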
394 
395 struct ConvertNVGPUToNVVMPass
396  : public impl::ConvertNVGPUToNVVMPassBase<ConvertNVGPUToNVVMPass> {
397  using Base::Base;
398 
399  void getDependentDialects(DialectRegistry &registry) const override {
400  registry.insert<memref::MemRefDialect, LLVM::LLVMDialect, NVVM::NVVMDialect,
401  arith::ArithDialect>();
402  }
403 
404  void runOnOperation() override {
405  LowerToLLVMOptions options(&getContext());
406  RewritePatternSet patterns(&getContext());
407  LLVMTypeConverter converter(&getContext(), options);
408  IRRewriter rewriter(&getContext());
409  populateGpuMemorySpaceAttributeConversions(
410  converter, [](gpu::AddressSpace space) -> unsigned {
411  switch (space) {
412  case gpu::AddressSpace::Global:
413  return static_cast<unsigned>(
414  NVVM::NVVMMemorySpace::kGlobalMemorySpace);
415  case gpu::AddressSpace::Workgroup:
416  return static_cast<unsigned>(
417  NVVM::NVVMMemorySpace::kSharedMemorySpace);
418  case gpu::AddressSpace::Private:
419  return 0;
420  }
421  llvm_unreachable("unknown address space enum value");
422  return 0;
423  });
424  /// device-side async tokens cannot be materialized in nvvm. We just
425  /// convert them to a dummy i32 type in order to easily drop them during
426  /// conversion.
427  converter.addConversion([&](nvgpu::DeviceAsyncTokenType type) -> Type {
428  return converter.convertType(IntegerType::get(type.getContext(), 32));
429  });
430  converter.addConversion([&](nvgpu::WarpgroupAccumulatorType type) -> Type {
431  Type elemType = type.getFragmented().getElementType();
432  int64_t sizeM = type.getFragmented().getDimSize(0);
433  int64_t sizeN = type.getFragmented().getDimSize(1);
434 
435  unsigned numMembers;
436  if (elemType.isF32() || elemType.isInteger(32))
437  numMembers = sizeN / 2;
438  else if (elemType.isF16())
439  numMembers = sizeN / 4;
440  else
441  llvm_unreachable("unsupported type for warpgroup accumulator");
442 
443  SmallVector<Type> innerStructBody;
444  for (unsigned i = 0; i < numMembers; i++)
445  innerStructBody.push_back(elemType);
446  auto innerStructType =
447  LLVM::LLVMStructType::getLiteral(type.getContext(), innerStructBody);
448 
449  SmallVector<Type> structBody;
450  for (int i = 0; i < sizeM; i += kWgmmaSizeM)
451  structBody.push_back(innerStructType);
452 
453  auto convertedType =
454  LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
455  return converter.convertType(convertedType);
456  });
457  converter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {
458  return converter.convertType(IntegerType::get(type.getContext(), 64));
459  });
460  converter.addConversion(
461  [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type {
462  return converter.convertType(IntegerType::get(type.getContext(), 64));
463  });
464  converter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type {
465  return converter.convertType(
466  nvgpu::getMBarrierMemrefType(rewriter.getContext(), type));
467  });
468  converter.addConversion([&](nvgpu::TensorMapDescriptorType type) -> Type {
469  return LLVM::LLVMPointerType::get(type.getContext());
470  });
471  populateNVGPUToNVVMConversionPatterns(converter, patterns);
472  LLVMConversionTarget target(getContext());
473  target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
474  target.addLegalDialect<::mlir::arith::ArithDialect>();
475  target.addLegalDialect<::mlir::memref::MemRefDialect>();
476  target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
477  mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
478  converter, patterns, target);
479  if (failed(applyPartialConversion(getOperation(), target,
480  std::move(patterns))))
481  signalPassFailure();
482  }
483 };
484 
485 /// Returns the constraints for the sparse MMA inline assembly instruction.
486 static std::string buildMmaSparseAsmConstraintString(unsigned matASize,
487  unsigned matBSize,
488  unsigned matCSize) {
489  std::string str;
490  llvm::raw_string_ostream ss(str);
491  for (unsigned i = 0; i < matCSize; i++)
492  ss << "=r,";
493  for (unsigned i = 0; i < matASize + matBSize + matCSize; i++)
494  ss << "r,";
495  // The final operand is for the sparsity metadata.
496  // The sparsity selector appears as a direct literal.
497  ss << "r";
498  return str;
499 }
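// For example, matASize = 4, matBSize = 2, matCSize = 2 produces
// "=r,=r,r,r,r,r,r,r,r,r,r": two register outputs for the accumulator, eight
// register inputs for A/B/C, and a final "r" for the sparsity metadata.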
500 
501 /// Returns the string for the `mma.sp.sync` instruction that corresponds to
502 /// the given parameters. Note that this function doesn't do any validation;
503 /// it's expected that the provided parameters correspond to a valid
504 /// instruction.
505 static std::string buildMmaSparseAsmString(
506  const std::array<int64_t, 3> &shape, unsigned matASize, unsigned matBSize,
507  unsigned matCSize, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
508  NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
509  std::optional<NVVM::MMAIntOverflow> overflow, unsigned metaDataSelector) {
510  auto ptxTypeStr = [](NVVM::MMATypes ptxType) {
511  return NVVM::stringifyMMATypes(ptxType);
512  };
513 
514  std::string asmStr;
515  llvm::raw_string_ostream ss(asmStr);
516  ss << "mma.sp.sync.aligned.m" << shape[0] << "n" << shape[1] << "k"
517  << shape[2] << ".row.col.";
518 
519  if (overflow)
520  ss << NVVM::stringifyMMAIntOverflow(*overflow) << ".";
521 
522  ss << ptxTypeStr(ptxTypeD) << "." << ptxTypeStr(ptxTypeA) << "."
523  << ptxTypeStr(ptxTypeB) << "." << ptxTypeStr(ptxTypeC) << " ";
524  unsigned asmArgIdx = 0;
525 
526  // The operand string is structured into sections `{matC elements...},
527  // {matA elements...}, {matB elements...}, {matC elements}`.
528  for (const auto arrSize : {matCSize, matASize, matBSize, matCSize}) {
529  ss << "{";
530  for (unsigned i = 0; i < arrSize; i++)
531  ss << "$" << asmArgIdx++ << (i < arrSize - 1 ? "," : "");
532  ss << "},";
533  }
534  ss << "$" << asmArgIdx++ << ",";
535  assert(metaDataSelector <= 1);
536  ss << "0x" << metaDataSelector << ";";
537  return asmStr;
538 }
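// For example, shape = {16, 8, 32} with two-register A/B/C fragments, s8
// inputs, an s32 accumulator, satfinite overflow, and selector 0 yields:
// mma.sp.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32 {$0,$1},{$2,$3},{$4,$5},{$6,$7},$8,0x0;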
539 
540 /// Builds an inline assembly operation corresponding to the specified MMA
541 /// sparse sync operation.
542 static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
543  ImplicitLocOpBuilder &b, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
544  NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
545  std::optional<NVVM::MMAIntOverflow> overflow, ArrayRef<Value> unpackedAData,
546  ArrayRef<Value> unpackedB, ArrayRef<Value> unpackedC, Value indexData,
547  int64_t metadataSelector, const std::array<int64_t, 3> &shape,
548  Type intrinsicResultType) {
549  auto asmDialectAttr =
550  LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_ATT);
551 
552  const unsigned matASize = unpackedAData.size();
553  const unsigned matBSize = unpackedB.size();
554  const unsigned matCSize = unpackedC.size();
555 
556  std::string asmStr = buildMmaSparseAsmString(
557  shape, matASize, matBSize, matCSize, ptxTypeA, ptxTypeB, ptxTypeC,
558  ptxTypeD, overflow, metadataSelector);
559  std::string constraintStr =
560  buildMmaSparseAsmConstraintString(matASize, matBSize, matCSize);
561 
562  SmallVector<Value> asmVals;
563  asmVals.reserve(matASize + matBSize + matCSize + 1);
564  for (ArrayRef<Value> args : {unpackedAData, unpackedB, unpackedC})
565  llvm::append_range(asmVals, args);
566  asmVals.push_back(indexData);
567 
568  return b.create<LLVM::InlineAsmOp>(
569  /*resultTypes=*/intrinsicResultType,
570  /*operands=*/asmVals,
571  /*asm_string=*/asmStr,
572  /*constraints=*/constraintStr,
573  /*has_side_effects=*/true,
574  /*is_align_stack=*/false,
575  /*asm_dialect=*/asmDialectAttr,
576  /*operand_attrs=*/ArrayAttr());
577 }
578 
579 /// Lowers `nvgpu.mma.sp.sync` to inline assembly.
580 struct NVGPUMmaSparseSyncLowering
581  : public ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp> {
582  using ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp>::ConvertOpToLLVMPattern;
583 
584  LogicalResult
585  matchAndRewrite(nvgpu::MmaSparseSyncOp op, OpAdaptor adaptor,
586  ConversionPatternRewriter &rewriter) const override {
587  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
588  // Get the shapes of the MMAMatrix type being used. The shapes determine
589  // which intrinsic this op will be lowered to.
590  VectorType aType = op.getMatrixA().getType();
591  VectorType bType = op.getMatrixB().getType();
592  VectorType cType = op.getMatrixC().getType();
593 
594  FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
595  if (failed(ptxTypeA))
596  return op->emitOpError("failed to deduce operand PTX types");
597  FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
598  if (failed(ptxTypeB))
599  return op->emitOpError("failed to deduce operand PTX types");
600  std::optional<NVVM::MMATypes> ptxTypeC =
601  NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
602  /*isAccumulator=*/true);
603  if (!ptxTypeC)
604  return op->emitError(
605  "could not infer the PTX type for the accumulator/result");
606 
607  // Same as `mma.sync`, F32 works only with TensorFloat32 (TF32).
608  bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
609  if (aType.getElementType().isF32() && !tf32Enabled)
610  return failure();
611 
612  // TODO: add an attribute to the op to customize this behavior.
613  std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
614  if (isa<IntegerType>(aType.getElementType()))
615  overflow = NVVM::MMAIntOverflow::satfinite;
616 
617  SmallVector<Value> matA =
618  unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
619  SmallVector<Value> matB =
620  unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
621  SmallVector<Value> matC =
622  unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);
623 
624  Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
625  Type intrinsicResTy = inferIntrinsicResultType(
626  typeConverter->convertType(op->getResultTypes()[0]));
627 
628  // Bitcast the sparse metadata from vector<2xi16> to an i32.
629  Value sparseMetadata = adaptor.getSparseMetadata();
630  if (sparseMetadata.getType() != VectorType::get(2, rewriter.getI16Type()))
631  return op->emitOpError() << "Expected metadata type to be LLVM "
632  "VectorType of 2 i16 elements";
633  sparseMetadata =
634  b.create<LLVM::BitcastOp>(rewriter.getI32Type(), sparseMetadata);
635 
636  FailureOr<LLVM::InlineAsmOp> intrinsicResult = emitMmaSparseSyncOpAsm(
637  b, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB,
638  matC, sparseMetadata, op.getSparsitySelector(), op.getMmaShapeAsArray(),
639  intrinsicResTy);
640  if (failed(intrinsicResult))
641  return failure();
642 
643  assert((*intrinsicResult).getNumResults() == 1 &&
644  "expected inline asm op returns a single LLVM struct type");
645  rewriter.replaceOp(
646  op, convertIntrinsicResult(op.getLoc(), intrinsicResTy, desiredRetTy,
647  (*intrinsicResult)->getResult(0), rewriter));
648  return success();
649  }
650 };
651 
652 struct NVGPUAsyncCopyLowering
653  : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCopyOp> {
654  using ConvertOpToLLVMPattern<
655  nvgpu::DeviceAsyncCopyOp>::ConvertOpToLLVMPattern;
656 
657  LogicalResult
658  matchAndRewrite(nvgpu::DeviceAsyncCopyOp op, OpAdaptor adaptor,
659  ConversionPatternRewriter &rewriter) const override {
660  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
661  Location loc = op.getLoc();
662  auto dstMemrefType = cast<MemRefType>(op.getDst().getType());
663  Value dstPtr =
664  getStridedElementPtr(b.getLoc(), dstMemrefType, adaptor.getDst(),
665  adaptor.getDstIndices(), rewriter);
666  FailureOr<unsigned> dstAddressSpace =
667  getTypeConverter()->getMemRefAddressSpace(dstMemrefType);
668  if (failed(dstAddressSpace))
669  return rewriter.notifyMatchFailure(
670  loc, "destination memref address space not convertible to integer");
671 
672  auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
673  FailureOr<unsigned> srcAddressSpace =
674  getTypeConverter()->getMemRefAddressSpace(srcMemrefType);
675  if (failed(srcAddressSpace))
676  return rewriter.notifyMatchFailure(
677  loc, "source memref address space not convertible to integer");
678 
679  Value scrPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.getSrc(),
680  adaptor.getSrcIndices(), rewriter);
681  // The intrinsic takes a global pointer, so we need an address space cast.
682  auto srcPointerGlobalType = LLVM::LLVMPointerType::get(
683  op->getContext(), NVVM::NVVMMemorySpace::kGlobalMemorySpace);
684  scrPtr = b.create<LLVM::AddrSpaceCastOp>(srcPointerGlobalType, scrPtr);
685  int64_t dstElements = adaptor.getDstElements().getZExtValue();
686  int64_t sizeInBytes =
687  (dstMemrefType.getElementTypeBitWidth() * dstElements) / 8;
688  // When the optional SrcElements argument is *not* present, the regular
689  // CpAsyncOp is generated. CpAsyncOp reads bytes from the source (global
690  // memory) to fill DstElements number of elements in the destination
691  // (shared memory).
692  Value srcBytes = adaptor.getSrcElements();
693  if (srcBytes) {
694  // When the optional SrcElements argument is present, the source (global
695  // memory) of CpAsyncOp is read only for SrcElements number of elements.
696  // The rest of the DstElements in the destination (shared memory) are
697  // filled with zeros.
698  Value c3I32 =
699  b.create<LLVM::ConstantOp>(b.getI32Type(), b.getI32IntegerAttr(3));
700  Value bitwidth = b.create<LLVM::ConstantOp>(
701  b.getI32Type(),
702  b.getI32IntegerAttr(srcMemrefType.getElementTypeBitWidth()));
703  Value srcElementsI32 = b.create<LLVM::TruncOp>(b.getI32Type(), srcBytes);
704  srcBytes = b.create<LLVM::LShrOp>(
705  b.create<LLVM::MulOp>(bitwidth, srcElementsI32), c3I32);
706  }
707  // Cache global (.cg) for 16 dst bytes; cache all (.ca) for sizes other than
708  // 16 dst bytes.
709  NVVM::LoadCacheModifierKind cacheModifier =
710  (op.getBypassL1().value_or(false) && sizeInBytes == 16)
711  ? NVVM::LoadCacheModifierKind::CG
712  : NVVM::LoadCacheModifierKind::CA;
713 
714  b.create<NVVM::CpAsyncOp>(
715  dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes),
716  NVVM::LoadCacheModifierKindAttr::get(op->getContext(), cacheModifier),
717  srcBytes);
718 
719  // Drop the result token.
720  Value zero = b.create<LLVM::ConstantOp>(
721  IntegerType::get(op.getContext(), 32), rewriter.getI32IntegerAttr(0));
722  rewriter.replaceOp(op, zero);
723  return success();
724  }
725 };
726 
727 struct NVGPUAsyncCreateGroupLowering
728  : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCreateGroupOp> {
729  using ConvertOpToLLVMPattern<
730  nvgpu::DeviceAsyncCreateGroupOp>::ConvertOpToLLVMPattern;
731 
732  LogicalResult
733  matchAndRewrite(nvgpu::DeviceAsyncCreateGroupOp op, OpAdaptor adaptor,
734  ConversionPatternRewriter &rewriter) const override {
735  rewriter.create<NVVM::CpAsyncCommitGroupOp>(op.getLoc());
736  // Drop the result token.
737  Value zero = rewriter.create<LLVM::ConstantOp>(
738  op->getLoc(), IntegerType::get(op.getContext(), 32),
739  rewriter.getI32IntegerAttr(0));
740  rewriter.replaceOp(op, zero);
741  return success();
742  }
743 };
744 
745 struct NVGPUAsyncWaitLowering
746  : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncWaitOp> {
747  using ConvertOpToLLVMPattern<
748  nvgpu::DeviceAsyncWaitOp>::ConvertOpToLLVMPattern;
749 
750  LogicalResult
751  matchAndRewrite(nvgpu::DeviceAsyncWaitOp op, OpAdaptor adaptor,
752  ConversionPatternRewriter &rewriter) const override {
753  // If numGroups is not present, pick 0 as a conservative correct value.
754  int32_t numGroups = adaptor.getNumGroups().value_or(0);
755  rewriter.create<NVVM::CpAsyncWaitGroupOp>(op.getLoc(), numGroups);
756  rewriter.eraseOp(op);
757  return success();
758  }
759 };
760 
761 /// Creates an mbarrier object in shared memory.
762 struct NVGPUMBarrierCreateLowering
763  : public ConvertOpToLLVMPattern<nvgpu::MBarrierCreateOp> {
764  using ConvertOpToLLVMPattern<nvgpu::MBarrierCreateOp>::ConvertOpToLLVMPattern;
765 
766  template <typename moduleT>
767  memref::GlobalOp generateGlobalBarrier(ConversionPatternRewriter &rewriter,
768  Operation *funcOp, moduleT moduleOp,
769  MemRefType barrierType) const {
770  SymbolTable symbolTable(moduleOp);
771  OpBuilder::InsertionGuard guard(rewriter);
772  rewriter.setInsertionPoint(&moduleOp.front());
773  auto global = rewriter.create<memref::GlobalOp>(
774  funcOp->getLoc(), "__mbarrier",
775  /*sym_visibility=*/rewriter.getStringAttr("private"),
776  /*type=*/barrierType,
777  /*initial_value=*/ElementsAttr(),
778  /*constant=*/false,
779  /*alignment=*/rewriter.getI64IntegerAttr(8));
780  symbolTable.insert(global);
781  return global;
782  }
783 
784  LogicalResult
785  matchAndRewrite(nvgpu::MBarrierCreateOp op, OpAdaptor adaptor,
786  ConversionPatternRewriter &rewriter) const override {
787  Operation *funcOp = op->getParentOp();
788  MemRefType barrierType = nvgpu::getMBarrierMemrefType(
789  rewriter.getContext(), op.getBarriers().getType());
790 
791  memref::GlobalOp global;
792  if (auto moduleOp = funcOp->getParentOfType<gpu::GPUModuleOp>())
793  global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
794  else if (auto moduleOp = funcOp->getParentOfType<ModuleOp>())
795  global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
796 
797  rewriter.setInsertionPoint(op);
798  rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(op, barrierType,
799  global.getName());
800  return success();
801  }
802 };
803 
804 /// Base class for lowering mbarrier operations to nvvm intrinsics.
805 template <typename SourceOp>
806 struct MBarrierBasePattern : public ConvertOpToLLVMPattern<SourceOp> {
807 public:
808  using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
809  /// Returns the base pointer of the mbarrier object.
810  Value getMbarrierPtr(ImplicitLocOpBuilder &b,
811  nvgpu::MBarrierGroupType mbarType, Value memrefDesc,
812  Value mbarId,
813  ConversionPatternRewriter &rewriter) const {
814  MemRefType mbarrierMemrefType =
815  nvgpu::getMBarrierMemrefType(rewriter.getContext(), mbarType);
816  return ConvertToLLVMPattern::getStridedElementPtr(
817  b.getLoc(), mbarrierMemrefType, memrefDesc, {mbarId}, rewriter);
818  }
819 };
820 
821 struct NVGPUMBarrierGetLowering
822  : public MBarrierBasePattern<nvgpu::MBarrierGetOp> {
823  using MBarrierBasePattern<nvgpu::MBarrierGetOp>::MBarrierBasePattern;
824 
825  LogicalResult
826  matchAndRewrite(nvgpu::MBarrierGetOp op, OpAdaptor adaptor,
827  ConversionPatternRewriter &rewriter) const override {
828  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
829  nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType();
830  rewriter.setInsertionPoint(op);
831  Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(),
832  adaptor.getMbarId(), rewriter);
833  Type resType = op.getMbarrierPointer().getType();
834  rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(op, resType, barrier);
835  return success();
836  }
837 };
838 
839 /// Lowers `nvgpu.mbarrier.init` to `nvvm.mbarrier.init`
840 struct NVGPUMBarrierInitLowering
841  : public MBarrierBasePattern<nvgpu::MBarrierInitOp> {
842  using MBarrierBasePattern<nvgpu::MBarrierInitOp>::MBarrierBasePattern;
843 
844  LogicalResult
845  matchAndRewrite(nvgpu::MBarrierInitOp op, OpAdaptor adaptor,
846  ConversionPatternRewriter &rewriter) const override {
847  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
848  nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType();
849  rewriter.setInsertionPoint(op);
850  Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(),
851  adaptor.getMbarId(), rewriter);
852  Value count = truncToI32(b, adaptor.getCount());
853  if (isMbarrierShared(mbarrierType)) {
854  rewriter.replaceOpWithNewOp<NVVM::MBarrierInitSharedOp>(
855  op, barrier, count, adaptor.getPredicate());
856  } else {
857  rewriter.replaceOpWithNewOp<NVVM::MBarrierInitOp>(op, barrier, count,
858  adaptor.getPredicate());
859  }
860  return success();
861  }
862 };
863 
864 /// Lowers `nvgpu.mbarrier.arrive` to `nvvm.mbarrier.arrive`
865 struct NVGPUMBarrierArriveLowering
866  : public MBarrierBasePattern<nvgpu::MBarrierArriveOp> {
867  using MBarrierBasePattern<nvgpu::MBarrierArriveOp>::MBarrierBasePattern;
868  LogicalResult
869  matchAndRewrite(nvgpu::MBarrierArriveOp op, OpAdaptor adaptor,
870  ConversionPatternRewriter &rewriter) const override {
871  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
872  Value barrier =
873  getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
874  adaptor.getMbarId(), rewriter);
875  Type tokenType = getTypeConverter()->convertType(
876  nvgpu::MBarrierTokenType::get(op->getContext()));
877  if (isMbarrierShared(op.getBarriers().getType())) {
878  rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveSharedOp>(op, tokenType,
879  barrier);
880  } else {
881  rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveOp>(op, tokenType,
882  barrier);
883  }
884  return success();
885  }
886 };
887 
888 /// Lowers `nvgpu.mbarrier.arrive.nocomplete` to
889 /// `nvvm.mbarrier.arrive.nocomplete`
890 struct NVGPUMBarrierArriveNoCompleteLowering
891  : public MBarrierBasePattern<nvgpu::MBarrierArriveNoCompleteOp> {
892  using MBarrierBasePattern<
893  nvgpu::MBarrierArriveNoCompleteOp>::MBarrierBasePattern;
894  LogicalResult
895  matchAndRewrite(nvgpu::MBarrierArriveNoCompleteOp op, OpAdaptor adaptor,
896  ConversionPatternRewriter &rewriter) const override {
897  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
898  Value barrier =
899  getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
900  adaptor.getMbarId(), rewriter);
901  Type tokenType = getTypeConverter()->convertType(
902  nvgpu::MBarrierTokenType::get(op->getContext()));
903  Value count = truncToI32(b, adaptor.getCount());
904  if (isMbarrierShared(op.getBarriers().getType())) {
905  rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteSharedOp>(
906  op, tokenType, barrier, count);
907  } else {
908  rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteOp>(
909  op, tokenType, barrier, count);
910  }
911  return success();
912  }
913 };
914 
915 /// Lowers `nvgpu.mbarrier.test.wait` to `nvvm.mbarrier.test.wait`
916 struct NVGPUMBarrierTestWaitLowering
917  : public MBarrierBasePattern<nvgpu::MBarrierTestWaitOp> {
918  using MBarrierBasePattern<nvgpu::MBarrierTestWaitOp>::MBarrierBasePattern;
919  LogicalResult
920  matchAndRewrite(nvgpu::MBarrierTestWaitOp op, OpAdaptor adaptor,
921  ConversionPatternRewriter &rewriter) const override {
922  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
923  Value barrier =
924  getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
925  adaptor.getMbarId(), rewriter);
926  Type retType = rewriter.getI1Type();
927  if (isMbarrierShared(op.getBarriers().getType())) {
928  rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitSharedOp>(
929  op, retType, barrier, adaptor.getToken());
930  } else {
931  rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitOp>(
932  op, retType, barrier, adaptor.getToken());
933  }
934  return success();
935  }
936 };
937 
938 struct NVGPUMBarrierArriveExpectTxLowering
939  : public MBarrierBasePattern<nvgpu::MBarrierArriveExpectTxOp> {
940  using MBarrierBasePattern<
941  nvgpu::MBarrierArriveExpectTxOp>::MBarrierBasePattern;
942  LogicalResult
943  matchAndRewrite(nvgpu::MBarrierArriveExpectTxOp op, OpAdaptor adaptor,
944  ConversionPatternRewriter &rewriter) const override {
945  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
946  Value barrier =
947  getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
948  adaptor.getMbarId(), rewriter);
949  Value txcount = truncToI32(b, adaptor.getTxcount());
950 
951  if (isMbarrierShared(op.getBarriers().getType())) {
952  rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxSharedOp>(
953  op, barrier, txcount, adaptor.getPredicate());
954  return success();
955  }
956 
957  rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>(
958  op, barrier, txcount, adaptor.getPredicate());
959  return success();
960  }
961 };
962 
963 struct NVGPUMBarrierTryWaitParityLowering
964  : public MBarrierBasePattern<nvgpu::MBarrierTryWaitParityOp> {
965  using MBarrierBasePattern<
966  nvgpu::MBarrierTryWaitParityOp>::MBarrierBasePattern;
967  LogicalResult
968  matchAndRewrite(nvgpu::MBarrierTryWaitParityOp op, OpAdaptor adaptor,
969  ConversionPatternRewriter &rewriter) const override {
970  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
971  Value barrier =
972  getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
973  adaptor.getMbarId(), rewriter);
974  Value ticks = truncToI32(b, adaptor.getTicks());
975  Value phase =
976  b.create<LLVM::ZExtOp>(b.getI32Type(), adaptor.getPhaseParity());
977 
978  if (isMbarrierShared(op.getBarriers().getType())) {
979  rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParitySharedOp>(
980  op, barrier, phase, ticks);
981  return success();
982  }
983 
984  rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParityOp>(op, barrier,
985  phase, ticks);
986  return success();
987  }
988 };
989 
990 struct NVGPUTmaAsyncLoadOpLowering
991  : public MBarrierBasePattern<nvgpu::TmaAsyncLoadOp> {
992  using MBarrierBasePattern<nvgpu::TmaAsyncLoadOp>::MBarrierBasePattern;
993  LogicalResult
994  matchAndRewrite(nvgpu::TmaAsyncLoadOp op, OpAdaptor adaptor,
995  ConversionPatternRewriter &rewriter) const override {
996  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
997  auto srcMemrefType = cast<MemRefType>(op.getDst().getType());
998  Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType,
999  adaptor.getDst(), {}, rewriter);
1000  Value barrier =
1001  getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
1002  adaptor.getMbarId(), rewriter);
1003 
1004  SmallVector<Value> coords = adaptor.getCoordinates();
1005  for (auto [index, value] : llvm::enumerate(coords)) {
1006  coords[index] = truncToI32(b, value);
1007  }
1008  rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(
1009  op, dest, adaptor.getTensorMapDescriptor(), coords, barrier,
1010  ValueRange{}, adaptor.getMulticastMask(), Value{},
1011  adaptor.getPredicate());
1012  return success();
1013  }
1014 };
1015 
1016 struct NVGPUTmaAsyncStoreOpLowering
1017  : public MBarrierBasePattern<nvgpu::TmaAsyncStoreOp> {
1018  using MBarrierBasePattern<nvgpu::TmaAsyncStoreOp>::MBarrierBasePattern;
1019  LogicalResult
1020  matchAndRewrite(nvgpu::TmaAsyncStoreOp op, OpAdaptor adaptor,
1021  ConversionPatternRewriter &rewriter) const override {
1022  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1023  auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
1024  Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType,
1025  adaptor.getSrc(), {}, rewriter);
1026  SmallVector<Value> coords = adaptor.getCoordinates();
1027  for (auto [index, value] : llvm::enumerate(coords)) {
1028  coords[index] = truncToI32(b, value);
1029  }
1030 
1031  rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(
1032  op, adaptor.getTensorMapDescriptor(), dest, coords,
1033  adaptor.getPredicate());
1034  return success();
1035  }
1036 };
1037 
1038 struct NVGPUGenerateWarpgroupDescriptorLowering
1039  : public ConvertOpToLLVMPattern<nvgpu::WarpgroupGenerateDescriptorOp> {
1040  using ConvertOpToLLVMPattern<
1041  nvgpu::WarpgroupGenerateDescriptorOp>::ConvertOpToLLVMPattern;
1042 
1043  LogicalResult
1044  matchAndRewrite(nvgpu::WarpgroupGenerateDescriptorOp op, OpAdaptor adaptor,
1045  ConversionPatternRewriter &rewriter) const override {
1046 
1047  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1048 
1049  nvgpu::TensorMapSwizzleKind swizzleKind =
1050  op.getTensorMap().getType().getSwizzle();
1051 
1052  unsigned layout =
1053  (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 128
1054  : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 64
1055  : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 32
1056  : 1;
1057  unsigned swizzle =
1058  (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 1
1059  : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 2
1060  : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 3
1061  : 0;
1062 
1063  auto ti64 = b.getIntegerType(64);
1064  auto makeConst = [&](uint64_t index) -> Value {
1065  return b.create<LLVM::ConstantOp>(ti64, b.getI64IntegerAttr(index));
1066  };
1067  auto shiftLeft = [&](Value value, unsigned shift) -> Value {
1068  return b.create<LLVM::ShlOp>(ti64, value, makeConst(shift));
1069  };
1070  auto shiftRight = [&](Value value, unsigned shift) -> Value {
1071  return b.create<LLVM::LShrOp>(ti64, value, makeConst(shift));
1072  };
1073  auto insertBit = [&](Value desc, Value val, int startBit) {
1074  return b.create<LLVM::OrOp>(ti64, desc, shiftLeft(val, startBit));
1075  };
1076 
1077  int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0);
1078  uint64_t strideDimVal = (layout << 3) >> exclude4LSB;
1079  uint64_t leadDimVal = (sizeN * layout) >> exclude4LSB;
1080  uint64_t offsetVal = 0;
1081 
1082  Value strideDim = makeConst(strideDimVal);
1083  Value leadDim = makeConst(leadDimVal);
1084 
1085  Value baseAddr = getStridedElementPtr(
1086  op->getLoc(), cast<MemRefType>(op.getTensor().getType()),
1087  adaptor.getTensor(), {}, rewriter);
1088  Value basePtr = b.create<LLVM::PtrToIntOp>(ti64, baseAddr);
1089  // Keep only 14 bits of the base address, i.e. (addr >> 4) & 0x3fff.
1090  Value basePtr14bit = shiftRight(shiftLeft(basePtr, 46), 50);
1091 
1092  int startSwizzleBit = 62, startOffsetBit = 49, startStrideBit = 32,
1093  startLeadBit = 16, startBaseAddrBit = 0;
1094  Value dsc = makeConst(0);
1095  // [62,64) swizzle type
1096  dsc = insertBit(dsc, makeConst(swizzle), startSwizzleBit);
1097  // [49,52) base_offset
1098  dsc = insertBit(dsc, makeConst(offsetVal), startOffsetBit);
1099  // [32,46) stride
1100  dsc = insertBit(dsc, strideDim, startStrideBit);
1101  // [16,30) leading dimension
1102  dsc = insertBit(dsc, leadDim, startLeadBit);
1103  // [0,14) start_address
1104  dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit);
1105 
1106  LLVM_DEBUG(DBGS() << "Generating warpgroup.descriptor: "
1107  << "leading_off:" << leadDimVal << "\t"
1108  << "stride_off :" << strideDimVal << "\t"
1109  << "base_offset:" << offsetVal << "\t"
1110  << "layout_type:" << swizzle << " ("
1111  << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
1112  << ")\n start_addr : " << baseAddr << "\n");
1113 
1114  rewriter.replaceOp(op, dsc);
1115  return success();
1116  }
1117 };
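// As a worked example of the descriptor encoding above: SWIZZLE_128B selects
// layout = 128 and swizzle code 1, so strideDimVal = (128 << 3) >> 4 = 64 and,
// assuming dim 0 of the tensor map's tensor is 64, leadDimVal =
// (64 * 128) >> 4 = 512. These land in bits [32,46) and [16,30) of the
// descriptor, with the swizzle code in [62,64) and the 14-bit
// (base address >> 4) in [0,14).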
1118 
1119 static Value makeI64Const(ImplicitLocOpBuilder &b, int32_t index) {
1120  return b.create<LLVM::ConstantOp>(b.getIntegerType(64),
1121  b.getI32IntegerAttr(index));
1122 }
1123 
1124 /// Returns a Value that holds the data type enum expected by the CUDA driver.
1125 static Value elementTypeAsLLVMConstant(ImplicitLocOpBuilder &b, Type type) {
1126  // Enum is from CUDA driver API
1127  // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html
1128  enum CUtensorMapDataTypeEnum {
1129  CU_TENSOR_MAP_DATA_TYPE_UINT8 = 0,
1130  CU_TENSOR_MAP_DATA_TYPE_UINT16,
1131  CU_TENSOR_MAP_DATA_TYPE_UINT32,
1132  CU_TENSOR_MAP_DATA_TYPE_INT32,
1133  CU_TENSOR_MAP_DATA_TYPE_UINT64,
1134  CU_TENSOR_MAP_DATA_TYPE_INT64,
1135  CU_TENSOR_MAP_DATA_TYPE_FLOAT16,
1136  CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
1137  CU_TENSOR_MAP_DATA_TYPE_FLOAT64,
1138  CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
1139  CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ,
1140  CU_TENSOR_MAP_DATA_TYPE_TFLOAT32,
1141  CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ
1142  };
1143 
1144  if (type.isUnsignedInteger(8))
1145  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT8);
1146  if (type.isUnsignedInteger(16))
1147  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT16);
1148  if (type.isUnsignedInteger(32))
1149  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT32);
1150  if (type.isUnsignedInteger(64))
1151  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT64);
1152  if (type.isSignlessInteger(32))
1153  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT32);
1154  if (type.isSignlessInteger(64))
1155  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT64);
1156  if (type.isF16())
1157  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT16);
1158  if (type.isF32())
1159  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT32);
1160  if (type.isF64())
1161  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT64);
1162  if (type.isBF16())
1163  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16);
1164 
1165  llvm_unreachable("Not supported data type");
1166 }
1167 
1168 struct NVGPUTmaCreateDescriptorOpLowering
1169  : public ConvertOpToLLVMPattern<nvgpu::TmaCreateDescriptorOp> {
1170  using ConvertOpToLLVMPattern<
1171  nvgpu::TmaCreateDescriptorOp>::ConvertOpToLLVMPattern;
1172  LogicalResult
1173  matchAndRewrite(nvgpu::TmaCreateDescriptorOp op, OpAdaptor adaptor,
1174  ConversionPatternRewriter &rewriter) const override {
1175  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1176  auto llvmPointerType = LLVM::LLVMPointerType::get(op->getContext());
1177  Type llvmInt64Type = IntegerType::get(op->getContext(), 64);
1178 
1179  Value tensorElementType =
1180  elementTypeAsLLVMConstant(b, op.getTensor().getType().getElementType());
1181  auto promotedOperands = getTypeConverter()->promoteOperands(
1182  b.getLoc(), op->getOperands(), adaptor.getOperands(), b);
1183 
1184  Value boxArrayPtr = b.create<LLVM::AllocaOp>(llvmPointerType, llvmInt64Type,
1185  makeI64Const(b, 5));
1186  for (auto [index, value] : llvm::enumerate(adaptor.getBoxDimensions())) {
1187  Value gep = b.create<LLVM::GEPOp>(llvmPointerType, llvmPointerType,
1188  boxArrayPtr, makeI64Const(b, index));
1189  b.create<LLVM::StoreOp>(value, gep);
1190  }
1191 
1192  nvgpu::TensorMapDescriptorType desc = op.getTensorMap().getType();
1193  // Set Arguments for the function call
1194  SmallVector<Value> arguments;
1195  arguments.push_back(promotedOperands[0]); // rank
1196  arguments.push_back(promotedOperands[1]); // descriptor
1197  arguments.push_back(tensorElementType); // data type
1198  arguments.push_back(
1199  makeI64Const(b, (int)desc.getInterleave())); // interleave
1200  arguments.push_back(makeI64Const(b, (int)desc.getSwizzle())); // swizzle
1201  arguments.push_back(makeI64Const(b, (int)desc.getL2promo())); // l2promo
1202  arguments.push_back(makeI64Const(b, (int)desc.getOob())); // oob
1203  arguments.push_back(boxArrayPtr); // box dimensions
1204 
1205  // Set data types of the arguments
1206  SmallVector<Type> argTypes = {
1207  llvmInt64Type, /* int64_t tensorRank */
1208  llvmPointerType, /* ptr */
1209  llvmInt64Type, /* int64_t */
1210  llvmInt64Type, /* int64_t */
1211  llvmInt64Type, /* int64_t */
1212  llvmInt64Type, /* int64_t */
1213  llvmInt64Type, /* int64_t */
1214  llvmPointerType /* ptr */
1215  };
1216  FunctionCallBuilder hostRegisterCallBuilder = {
1217  "mgpuTensorMapEncodeTiledMemref", llvmPointerType, argTypes};
1218  Value tensorMap =
1219  hostRegisterCallBuilder.create(b.getLoc(), b, arguments).getResult();
1220 
1221  rewriter.replaceOp(op, tensorMap);
1222  return success();
1223  }
1224 };
1225 
1226 struct NVGPUWarpgroupMmaOpLowering
1227  : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp> {
1228  using ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp>::ConvertOpToLLVMPattern;
1229 
1230  /// This is a helper class to generate required NVVM Ops for warp-group level
1231  /// matrix multiplication.
1232  /// When the given GEMM shape is larger than the shape of
1233  /// a single wgmma instruction in PTX, it generates multiple
1234  /// NVVM::WgmmaMmaAsyncOp ops, groups them, and executes them asynchronously.
1235  /// The class also handles waiting for completion and iterates through
1236  /// WarpgroupMatrixDescriptor to create descriptors for each instruction.
1237  ///
1238  /// For example, this is the case when the GEMM shape is 128x128x128:
1239  ///
1240  /// nvvm.wgmma.fence.aligned
1241  ///
1242  /// nvvm.wgmma.mma.async descA, descB
1243  /// iterate(descA, descB)
1244  /// nvvm.wgmma.mma.async descA, descB
1245  /// [6 more times]
1246  ///
1247  /// nvvm.wgmma.group.sync.aligned
1248  /// nvvm.wgmma.wait.group.sync [groupId]
1249  ///
1250  class WarpgroupGemm {
1251  nvgpu::WarpgroupMmaOp op;
1252  ImplicitLocOpBuilder &b;
1253  OpAdaptor adaptor;
1254 
1255  // Entire shape of the given Op
1256  int64_t totalM, totalN, totalK;
1257 
1258  // Shape of one wgmma instruction
1259  int wgmmaM = 0, wgmmaN = 0, wgmmaK = 0;
1260 
1261  // Iteration counts for GEMM
1262  int iterationM = 0, iterationN = 0, iterationK = 0;
1263 
1264  /// Computes the shape of the wgmma instruction that is defined in the
1265  /// PTX programming guide.
1266  /// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shape
1267  void findWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType) {
1268  wgmmaM = 64;
1269  wgmmaN = sizeN;
1270  if (inputElemType.isTF32()) {
1271  wgmmaK = 8;
1272  } else if (inputElemType.isF16() || inputElemType.isBF16()) {
1273  wgmmaK = 16;
1274  } else if (isa<Float8E4M3FNType, Float8E5M2Type>(inputElemType) ||
1275  inputElemType.isInteger(8)) {
1276  wgmmaK = 32;
1277  } else if (inputElemType.isInteger(1)) {
1278  wgmmaK = 256;
1279  } else {
1280  llvm_unreachable("unsupported input element type for wgmma");
1281  }
1282  LLVM_DEBUG(DBGS() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
1283  << ", n = " << wgmmaN << ", k = " << wgmmaK << "]\n");
1284  }
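// For example, a 128x128x64 GEMM with f16 inputs selects the wgmma shape
// [m = 64, n = 128, k = 16]; the constructor below then derives
// iterationM = 2, iterationN = 1, iterationK = 4.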
1285 
1286  /// Generates WGMMATypesAttr from MLIR Type
1287  NVVM::WGMMATypesAttr generateWgmmaType(Type type,
1288  bool useF32 = false) const {
1289  auto getWgmmaType = [=](Type elemType) {
1290  if (elemType.isF32() || elemType.isTF32())
1291  return useF32 ? NVVM::WGMMATypes::f32 : NVVM::WGMMATypes::tf32;
1292  if (elemType.isF16())
1293  return NVVM::WGMMATypes::f16;
1294  if (elemType.isBF16())
1295  return NVVM::WGMMATypes::bf16;
1296  if (isa<Float8E4M3FNType>(elemType))
1297  return NVVM::WGMMATypes::e4m3;
1298  if (isa<Float8E5M2Type>(elemType))
1299  return NVVM::WGMMATypes::e5m2;
1300  if (elemType.isInteger(1))
1301  return NVVM::WGMMATypes::b1;
1302  if (elemType.isInteger(8))
1303  return NVVM::WGMMATypes::s8;
1304  if (elemType.isUnsignedInteger(8))
1305  return NVVM::WGMMATypes::u8;
1306  if (elemType.isInteger(32))
1307  return NVVM::WGMMATypes::s32;
1308  llvm_unreachable("unsupported type");
1309  };
1310  return NVVM::WGMMATypesAttr::get(op->getContext(), getWgmmaType(type));
1311  }
1312 
1313  /// Generates layout attribute for the input matrix for wgmma instruction
1314  NVVM::MMALayoutAttr
1315  generateWgmmaLayout(std::optional<bool> transpose) const {
1316  if (transpose.value_or(false))
1317  return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::col);
1318  return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::row);
1319  }
1320 
1321  /// Generates shape attribute for wgmma instruction
1322  NVVM::MMAShapeAttr generateWgmmaShape() const {
1323  return NVVM::MMAShapeAttr::get(op->getContext(), wgmmaM, wgmmaN, wgmmaK);
1324  }
1325 
1326  /// Generates scale attributes of output matrix for wgmma instruction
1327  NVVM::WGMMAScaleOutAttr generateScaleOut() const {
1328  return NVVM::WGMMAScaleOutAttr::get(op->getContext(),
1329  NVVM::WGMMAScaleOut::one);
1330  }
1331  /// Generates scale attributes of input matrix for wgmma instruction
1332  NVVM::WGMMAScaleInAttr generateScaleIn() const {
1333  return NVVM::WGMMAScaleInAttr::get(op->getContext(),
1334  NVVM::WGMMAScaleIn::one);
1335  }
1336 
1337  /// Basic function to generate Add
1338  Value makeAdd(Value lhs, Value rhs) {
1339  return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);
1340  };
1341 
1342  /// Moves the descriptor pointer of matrix-A for the next wgmma instruction.
1343  /// Currently, it only handles row-major.
1344  ///
1345  /// It moves the pointer like below for [128][64] size:
1346  /// +2 +4 +6
1347  /// ↓ ↓ ↓
1348  /// descA ---> +--+--+--+--+
1349  /// |->|->|->|->|
1350  /// | | | | |
1351  /// | | | | |
1352  /// | | | | |
1353  /// descA+512---> +-----------+
1354  /// | | | | |
1355  /// | | | | |
1356  /// | | | | |
1357  /// | | | | |
1358  /// +-----------+
1359  ///
1360  Value iterateDescriptorA(Value desc, int i, int j, int k) {
1361  MemRefType matrixTypeA = op.getDescriptorA().getType().getTensor();
1362  Type elemA = matrixTypeA.getElementType();
1363  int byte = elemA.getIntOrFloatBitWidth() / 8;
1364  int tileShapeA = matrixTypeA.getDimSize(1);
1365  int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte;
1366  incrementVal = incrementVal >> exclude4LSB;
1367  LLVM_DEBUG(DBGS() << "\t\t[m: " << i << " n: " << j << " k: " << k
1368  << "] [wgmma descriptors] Descriptor A + "
1369  << incrementVal << " | \t ");
1370  if (!incrementVal)
1371  return desc;
1372  return makeAdd(desc, makeI64Const(b, incrementVal));
1373  }
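// For instance, assuming f16 elements (2 bytes), tileShapeA = 64, totalK = 128
// and wgmmaK = 16, the increment is ((16 * k) + (128 * 64 * i)) * 2 >> 4,
// i.e. +2 per k-step and +1024 per m-step, which matches the +2/+4/+6 strides
// sketched in the diagram above.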
1374 
1375  /// Moves the descriptor pointer of matrix-B for the next wgmma instruction.
1376  /// Currently, it only handles column-major.
1377  ///
1378  /// It moves the pointer like below for [128][64] size:
1379  /// descB ---> +--+--+--+--+--+--+--+--+
1380  /// |↓ | | | | | | | |
1381  /// |↓ | | | | | | | |
1382  /// |↓ | | | | | | | |
1383  /// |↓ | | | | | | | |
1384  /// +--+--+--+--+--+--+--+--+
1385  ///
1386  Value iterateDescriptorB(Value desc, int i, int j, int k) {
1387  MemRefType matrixTypeB = op.getDescriptorB().getType().getTensor();
1388  Type elemB = matrixTypeB.getElementType();
1389  int byte = elemB.getIntOrFloatBitWidth() / 8;
1390  int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte;
1391  incrementVal = incrementVal >> exclude4LSB;
1392  LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n");
1393  if (!incrementVal)
1394  return desc;
1395  return makeAdd(desc, makeI64Const(b, incrementVal));
1396  }
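// Likewise for matrix-B: assuming f16 elements and matrixTypeB.getDimSize(0)
// equal to 64, the increment is (64 * 16 * k * 2) >> 4, i.e. +128 per k-step.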
1397 
1398  /// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix
1399  /// descriptors and arranges them based on induction variables: i, j, and k.
1400  Value generateWgmma(int i, int j, int k, Value matrixC) {
1401  LLVM_DEBUG(DBGS() << "\t wgmma."
1402  << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
1403  << "(A[" << (iterationM * wgmmaM) << ":"
1404  << (iterationM * wgmmaM) + wgmmaM << "]["
1405  << (iterationK * wgmmaK) << ":"
1406  << (iterationK * wgmmaK + wgmmaK) << "] * "
1407  << " B[" << (iterationK * wgmmaK) << ":"
1408  << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":"
1409  << wgmmaN << "])\n");
1410 
1411  Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k);
1412  Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k);
1413 
1414  Type elemA = op.getDescriptorA().getType().getTensor().getElementType();
1415  NVVM::WGMMATypesAttr itypeA = generateWgmmaType(elemA);
1416 
1417  Type elemB = op.getDescriptorB().getType().getTensor().getElementType();
1418  NVVM::WGMMATypesAttr itypeB = generateWgmmaType(elemB);
1419 
1420  Type elemD = op.getMatrixC().getType().getFragmented().getElementType();
1421  NVVM::WGMMATypesAttr itypeD = generateWgmmaType(elemD, true);
1422 
1423  NVVM::MMAShapeAttr shape = generateWgmmaShape();
1424  NVVM::WGMMAScaleOutAttr scaleOut = generateScaleOut();
1425  NVVM::WGMMAScaleInAttr scaleIn = generateScaleIn();
1426  NVVM::MMALayoutAttr layoutA = generateWgmmaLayout(op.getTransposeA());
1427  NVVM::MMALayoutAttr layoutB = generateWgmmaLayout(!op.getTransposeB());
1428 
1429  auto overflow = NVVM::MMAIntOverflowAttr::get(
1430  op->getContext(), NVVM::MMAIntOverflow::wrapped);
1431 
1432  return b.create<NVVM::WgmmaMmaAsyncOp>(
1433  matrixC.getType(), matrixC, descriptorA, descriptorB, shape, itypeA,
1434  itypeB, itypeD, scaleOut, scaleIn, scaleIn, layoutA, layoutB,
1435  overflow);
1436  }
1437 
1438  /// Generates multiple wgmma instructions to complete the given GEMM shape
1439  Value generateWgmmaGroup() {
1440  Value wgmmaResult =
1441  b.create<LLVM::PoisonOp>(adaptor.getMatrixC().getType());
1442 
1443  // Perform GEMM
1444  SmallVector<Value> wgmmaResults;
1445  for (int i = 0; i < iterationM; ++i) {
1446  Value matrixC = b.create<LLVM::ExtractValueOp>(adaptor.getMatrixC(), i);
1447  for (int j = 0; j < iterationN; ++j)
1448  for (int k = 0; k < iterationK; ++k)
1449  matrixC = generateWgmma(i, j, k, matrixC);
1450  wgmmaResults.push_back(matrixC);
1451  }
1452  for (auto [idx, matrix] : llvm::enumerate(wgmmaResults)) {
1453  wgmmaResult = b.create<LLVM::InsertValueOp>(wgmmaResult.getType(),
1454  wgmmaResult, matrix, idx);
1455  }
1456  return wgmmaResult;
1457  }
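  // For example (sizes are illustrative), with iterationM = 2, iterationN = 1
  // and iterationK = 4, the loop nest above extracts two inner accumulator
  // structs from adaptor.getMatrixC(), threads each one through 4 chained
  // wgmma ops (k innermost), and re-inserts the results into the packed
  // struct.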
1458 
1459  public:
1460  WarpgroupGemm(nvgpu::WarpgroupMmaOp op, ImplicitLocOpBuilder &b,
1461  OpAdaptor adaptor)
1462  : op(op), b(b), adaptor(adaptor) {
1463  // Find the entire GEMM Shape
1464  totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
1465  totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
1466  totalK = op.getDescriptorA().getType().getTensor().getDimSize(1);
1467  LLVM_DEBUG(DBGS() << "===--- GEMM D[" << totalM << "][" << totalN
1468  << "] += A[" << totalM << "][" << totalK << "] * B["
1469  << totalK << "][" << totalN << "] ---===\n");
1470 
1471  // Find the shape for one wgmma instruction
1472  findWgmmaShape(
1473  totalM, totalN,
1474  op.getDescriptorA().getType().getTensor().getElementType());
1475 
1476  // Iteration counts needed to cover the given GEMM shape with the wgmma shape
1477  iterationM = totalM / wgmmaM;
1478  iterationN = totalN / wgmmaN;
1479  iterationK = totalK / wgmmaK;
1480  }
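  // A worked example of the shape derivation (the concrete sizes and the
  // chosen wgmma shape are illustrative assumptions): for
  //   D[128][128] += A[128][64] * B[64][128]
  // with f16 operands, totalM = 128, totalN = 128 and totalK = 64. Assuming
  // findWgmmaShape selects m64n128k16, the iteration counts become
  // iterationM = 2, iterationN = 1 and iterationK = 4, i.e. 8 wgmma
  // instructions per warpgroup GEMM.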
1481 
1482  /// Generates WgmmaMmaAsync ops to complete the specified GEMM shape. It
1483  /// emits a fence op (WgmmaFenceAlignedOp) before the wgmma instructions,
1484  /// then commits them as a group (WgmmaGroupSyncAlignedOp) and waits for
1485  /// the group to complete (WgmmaWaitGroupSyncOp) after the wgmma
1486  /// instructions have been issued.
1487  Value generateWarpgroupMma() {
1488  b.create<NVVM::WgmmaFenceAlignedOp>();
1489  Value wgmmaResult = generateWgmmaGroup();
1490  b.create<NVVM::WgmmaGroupSyncAlignedOp>();
1491  b.create<NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup());
1492  return wgmmaResult;
1493  }
1494  };
1495  LogicalResult
1496  matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
1497  ConversionPatternRewriter &rewriter) const override {
1498  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1499 
1500  // Step 1. Build a helper class
1501  WarpgroupGemm warpgroupGemm(op, b, adaptor);
1502 
1503  // Step 2. Generate wgmma instructions to cover the entire GEMM shape
1504  Value wgmmaResult = warpgroupGemm.generateWarpgroupMma();
1505 
1506  // Step 3. Replace fragmented result struct with the op results
1507  rewriter.replaceOp(op, wgmmaResult);
1508  return success();
1509  }
1510 };
1511 
1512 struct NVGPUWarpgroupMmaStoreOpLowering
1513  : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaStoreOp> {
1514  using ConvertOpToLLVMPattern<
1515  nvgpu::WarpgroupMmaStoreOp>::ConvertOpToLLVMPattern;
1516 
1517  /// This function stores a fragmented register matrix owned by a warp group
1518  /// (128 threads) into a memref. Each thread holds 64 registers, which are
1519  /// passed in as a single struct value.
1520  /// Below is what each thread (T) holds; each `d` is an element of that
1521  /// struct, identified by its index.
1522  ///
1523  /// Threads in the warp-group (128 threads) and what they own in matrix D:
1524  /// 0-31 Warp-0 -> MatrixD[0:15 ][0:N]
1525  /// 32-63 Warp-1 -> MatrixD[16:31][0:N]
1526  /// 64-95 Warp-2 -> MatrixD[32:47][0:N]
1527  /// 96-127 Warp-3 -> MatrixD[48:63][0:N]
1528  ///
1529  /// Matrix-D:
1530  /// +______________________________________________________________________+
1531  /// | 0-1 | 2-3 | 4-5 | 6-7 | 8-9 | 10-11|..|N-8,N-7 |
1532  /// 0 | T0:d0-d1 |T1:d0-d1 |T2:d0-d1 |T3:d0-d1 |T0:d4-d5| T1:d4-d5..|T0:dX-dY|
1533  /// 1 | T4:d0-d1 |T5:d0-d1 |T6:d0-d1 |T7:d0-d1 |T4:d4-d5| T5:d4-d5..|T4:dX-dY|
1534  /// ..| .........|.........|.........|.........|........|...........|........|
1535  /// 8 | T0:d2-d3 |T1:d2-d3 |T2:d2-d3 |T3:d2-d3 |T0:d6-d7|T1:d6-d7,..|T0:dZ-dW|
1536  /// 9 | T4:d2-d3 |T5:d2-d3 |T6:d2-d3 |T7:d2-d3 |T4:d6-d7| T5:d6-d7..|T4:dZ-dW|
1537  /// ..| .........|.........|.........|.........|........|...........|........|
1538  /// 15| T28:d2-d3|T29:d2-d3|T30:d2-d3|T31:d2-d3|........|...........|........|
1539  /// 16| T32:d2-d3|T33:d2-d3|T34:d2-d3|T35:d2-d3|........|...........|........|
1540  /// ..| .........|.........|.........|.........|........|...........|........|
1541  /// 32| T64:d2-d3|T65:d2-d3|T66:d2-d3|T67:d2-d3|........|...........|........|
1542  /// ..| .........|.........|.........|.........|........|...........|........|
1543  /// 48| T96:d2-d3|T97:d2-d3|T98:d2-d3|T99:d2-d3|........|...........|........|
1544  /// ..| .........|.........|.........|.........|........|...........|........|
1545  /// +______________________________________________________________________+
1546  ///
1547  /// \param b: The builder used to create the store operations.
1548  /// \param matrixD: Result of the warp-group MMA operation (fragmented
1549  /// matrix). Each thread holds it as a struct with 64 elements.
1550  /// \param dstMemref: The memref where the registers will be stored.
1551  /// \param offset: The row offset within the memref where the registers will
1552  /// be stored.
1553  void storeFragmentedMatrix(ImplicitLocOpBuilder &b, Value matrixD,
1554  TypedValue<MemRefType> dstMemref,
1555  int offset) const {
1556  Type i32 = b.getI32Type();
1557 
1558  auto makeConst = [&](int32_t index) -> Value {
1559  return b.create<LLVM::ConstantOp>(i32, b.getI32IntegerAttr(index));
1560  };
1561  Value c1 = makeConst(1);
1562  Value c2 = makeConst(2);
1563  Value c4 = makeConst(4);
1564  Value c8 = makeConst(8);
1565  Value c16 = makeConst(16);
1566  Value warpSize = makeConst(kWarpSize);
1567 
1568  auto makeMul = [&](Value lhs, Value rhs) -> Value {
1569  return b.create<LLVM::MulOp>(lhs.getType(), lhs, rhs);
1570  };
1571  auto makeAdd = [&](Value lhs, Value rhs) -> Value {
1572  return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);
1573  };
1574 
1575  auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,
1576  TypedValue<MemRefType> memref) {
1577  Type it = b.getIndexType();
1578  Value idx = b.create<arith::IndexCastOp>(it, x);
1579  Value idy0 = b.create<arith::IndexCastOp>(it, y);
1580  Value idy1 = b.create<arith::IndexCastOp>(it, makeAdd(y, c1));
1581  Value d0 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i);
1582  Value d1 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i + 1);
1583  b.create<memref::StoreOp>(d0, memref, ValueRange{idx, idy0});
1584  b.create<memref::StoreOp>(d1, memref, ValueRange{idx, idy1});
1585  };
1586 
1587  Value tidx = b.create<NVVM::ThreadIdXOp>(i32);
1588  Value laneId = b.create<LLVM::URemOp>(i32, tidx, warpSize);
1589  Value warpId = b.create<LLVM::UDivOp>(i32, tidx, warpSize);
1590  Value lane4Id = b.create<LLVM::UDivOp>(i32, laneId, c4);
1591  Value lane4modId = b.create<LLVM::URemOp>(i32, laneId, c4);
1592 
1593  Value tj = makeMul(lane4modId, c2);
1594  Value ti = makeAdd(lane4Id, makeMul(warpId, c16));
1595  if (offset)
1596  ti = makeAdd(ti, makeConst(offset));
1597 
1598  auto structType = cast<LLVM::LLVMStructType>(matrixD.getType());
1599 
1600  // Number of adjacent 32-bit registers owned per thread
1601  constexpr unsigned numAdjacentRegisters = 2;
1602  // Number of 8x8 matrices stacked one below another per warp
1603  constexpr unsigned numStackedMatrices = 2;
1604 
1605  size_t storeCount = (structType.getBody().size() /
1606  (numStackedMatrices * numAdjacentRegisters));
1607 
1608  for (size_t i = 0; i < numStackedMatrices; ++i) {
1609  Value idx = makeAdd(ti, makeMul(makeConst(i), c8));
1610  for (size_t j = 0; j < storeCount; ++j) {
1611  Value idy = makeAdd(tj, makeMul(makeConst(j), c8));
1612  size_t structIndex = (i * numAdjacentRegisters) +
1613  (j * (numStackedMatrices * numAdjacentRegisters));
1614  makeExtractAndStore(structIndex, matrixD, idx, idy, dstMemref);
1615  }
1616  }
1617  }
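  // A worked example of the index math above (the thread id and struct size
  // are illustrative): thread 5 (warpId = 0, laneId = 5) gets lane4Id = 1,
  // lane4modId = 1, hence ti = 1 and tj = 2 (with offset = 0). For a
  // 64-element struct, storeCount = 16, so this thread stores d0,d1 to
  // [1][2..3], d2,d3 to [9][2..3], d4,d5 to [1][10..11], d6,d7 to
  // [9][10..11], and so on, matching rows 1 and 9 of the sketch in the
  // comment above.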
1618 
1619  LogicalResult
1620  matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor,
1621  ConversionPatternRewriter &rewriter) const override {
1622  int offset = 0;
1623  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1624  Value matrixDValue = adaptor.getMatrixD();
1625  auto stype = cast<LLVM::LLVMStructType>(matrixDValue.getType());
1626  for (auto [idx, matrixD] : llvm::enumerate(stype.getBody())) {
1627  auto structType = cast<LLVM::LLVMStructType>(matrixD);
1628  Value innerStructValue = b.create<LLVM::ExtractValueOp>(matrixDValue, idx);
1629  storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset);
1630  offset += structType.getBody().size();
1631  }
1632  rewriter.eraseOp(op);
1633  return success();
1634  }
1635 };
1636 
1637 struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
1638  : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaInitAccumulatorOp> {
1639  using ConvertOpToLLVMPattern<
1640  nvgpu::WarpgroupMmaInitAccumulatorOp>::ConvertOpToLLVMPattern;
1641  LogicalResult
1642  matchAndRewrite(nvgpu::WarpgroupMmaInitAccumulatorOp op, OpAdaptor adaptor,
1643  ConversionPatternRewriter &rewriter) const override {
1644  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1645  LLVM::LLVMStructType packStructType = cast<LLVM::LLVMStructType>(
1646  getTypeConverter()->convertType(op.getMatrixC().getType()));
1647  Type elemType = cast<LLVM::LLVMStructType>(packStructType.getBody().front())
1648  .getBody()
1649  .front();
1650  Value zero = b.create<LLVM::ConstantOp>(elemType, b.getZeroAttr(elemType));
1651  Value packStruct = b.create<LLVM::PoisonOp>(packStructType);
1652  SmallVector<Value> innerStructs;
1653  // Unpack the structs and set all values to zero
1654  for (auto [idx, s] : llvm::enumerate(packStructType.getBody())) {
1655  auto structType = cast<LLVM::LLVMStructType>(s);
1656  Value structValue = b.create<LLVM::ExtractValueOp>(packStruct, idx);
1657  for (unsigned i = 0; i < structType.getBody().size(); ++i) {
1658  structValue = b.create<LLVM::InsertValueOp>(
1659  structType, structValue, zero, ArrayRef<int64_t>({i}));
1660  }
1661  innerStructs.push_back(structValue);
1662  }
1663  // Pack the inner structs into a single struct
1664  for (auto [idx, matrix] : llvm::enumerate(innerStructs)) {
1665  packStruct = b.create<LLVM::InsertValueOp>(packStruct.getType(),
1666  packStruct, matrix, idx);
1667  }
1668  rewriter.replaceOp(op, packStruct);
1669  return success();
1670  }
1671 };
1672 
1673 struct NVGPUTmaFenceOpLowering
1674  : public ConvertOpToLLVMPattern<nvgpu::TmaFenceOp> {
1675  using ConvertOpToLLVMPattern<nvgpu::TmaFenceOp>::ConvertOpToLLVMPattern;
1676  LogicalResult
1677  matchAndRewrite(nvgpu::TmaFenceOp op, OpAdaptor adaptor,
1678  ConversionPatternRewriter &rewriter) const override {
1679  MLIRContext *ctx = op.getContext();
1680  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1681  auto i32Ty = b.getI32Type();
1682  Value tensormapSize =
1683  b.create<LLVM::ConstantOp>(i32Ty, rewriter.getI32IntegerAttr(128));
1684 
1685  auto memscope =
1686  NVVM::MemScopeKindAttr::get(ctx, ::mlir::NVVM::MemScopeKind::SYS);
1687 
1688  rewriter.replaceOpWithNewOp<NVVM::FenceProxyAcquireOp>(
1689  op, memscope, adaptor.getTensorMapDescriptor(), tensormapSize);
1690 
1691  return success();
1692  }
1693 };
1694 
1695 struct NVGPUTmaPrefetchOpLowering
1696  : public ConvertOpToLLVMPattern<nvgpu::TmaPrefetchOp> {
1697  using ConvertOpToLLVMPattern<nvgpu::TmaPrefetchOp>::ConvertOpToLLVMPattern;
1698  LogicalResult
1699  matchAndRewrite(nvgpu::TmaPrefetchOp op, OpAdaptor adaptor,
1700  ConversionPatternRewriter &rewriter) const override {
1701  rewriter.replaceOpWithNewOp<NVVM::PrefetchTensorMapOp>(
1702  op, adaptor.getTensorMapDescriptor(), adaptor.getPredicate());
1703  return success();
1704  }
1705 };
1706 
1707 struct NVGPURcpOpLowering : public ConvertOpToLLVMPattern<nvgpu::RcpOp> {
1708  using ConvertOpToLLVMPattern<nvgpu::RcpOp>::ConvertOpToLLVMPattern;
1709  LogicalResult
1710  matchAndRewrite(nvgpu::RcpOp op, OpAdaptor adaptor,
1711  ConversionPatternRewriter &rewriter) const override {
1712  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1713  auto i64Ty = b.getI64Type();
1714  auto f32Ty = b.getF32Type();
1715  VectorType inTy = op.getIn().getType();
1716  // Apply rcp.approx.ftz.f32 to each element of the vector.
1717  auto convert1DVec = [&](Type llvm1DVectorTy, Value inVec) {
1718  Value ret1DVec = b.create<LLVM::PoisonOp>(llvm1DVectorTy);
1719  int numElems = llvm::cast<VectorType>(llvm1DVectorTy).getNumElements();
1720  for (int i = 0; i < numElems; i++) {
1721  Value idx = b.create<LLVM::ConstantOp>(i64Ty, b.getI64IntegerAttr(i));
1722  Value elem = b.create<LLVM::ExtractElementOp>(inVec, idx);
1723  Value dst = b.create<NVVM::RcpApproxFtzF32Op>(f32Ty, elem);
1724  ret1DVec = b.create<LLVM::InsertElementOp>(ret1DVec, dst, idx);
1725  }
1726  return ret1DVec;
1727  };
1728  if (inTy.getRank() == 1) {
1729  rewriter.replaceOp(op, convert1DVec(inTy, adaptor.getIn()));
1730  return success();
1731  }
1732  return LLVM::detail::handleMultidimensionalVectors(
1733  op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()),
1734  [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
1735  OpAdaptor adaptor(operands);
1736  return convert1DVec(llvm1DVectorTy, adaptor.getIn());
1737  },
1738  rewriter);
1739  }
1740 };
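// As an illustration (the vector type is an assumption): lowering nvgpu.rcp
// on a vector<4xf32> seeds the result with an LLVM::PoisonOp and then emits
// four LLVM::ExtractElementOp / NVVM::RcpApproxFtzF32Op /
// LLVM::InsertElementOp triples; higher-rank vectors are first split into
// 1-D pieces via LLVM::detail::handleMultidimensionalVectors.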
1741 } // namespace
1742 
1743 void mlir::populateNVGPUToNVVMConversionPatterns(
1744  const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
1745  patterns.add<
1746  NVGPUMBarrierCreateLowering, // nvgpu.mbarrier.create
1747  NVGPUMBarrierInitLowering, // nvgpu.mbarrier.init
1748  NVGPUMBarrierGetLowering, // nvgpu.mbarrier.get
1749  NVGPUMBarrierArriveLowering, // nvgpu.mbarrier.arrive
1750  NVGPUMBarrierArriveNoCompleteLowering, // nvgpu.mbarrier.arrive.no_complete
1751  NVGPUMBarrierTestWaitLowering, // nvgpu.mbarrier.test_wait_parity
1752  NVGPUMBarrierTryWaitParityLowering, // nvgpu.mbarrier.try_wait_parity
1753  NVGPUTmaAsyncLoadOpLowering, // nvgpu.tma.async.load
1754  NVGPUTmaAsyncStoreOpLowering, // nvgpu.tma.async.store
1755  NVGPUTmaCreateDescriptorOpLowering, // nvgpu.tma.create.descriptor
1756  NVGPUTmaPrefetchOpLowering, // nvgpu.tma.prefetch.descriptor
1757  NVGPUTmaFenceOpLowering, // nvgpu.tma.fence.descriptor
1758  NVGPUMBarrierArriveExpectTxLowering, // nvgpu.mbarrier.arrive.expect_tx
1759  NVGPUGenerateWarpgroupDescriptorLowering, // nvgpu.warpgroup.generate.descriptor
1760  NVGPUWarpgroupMmaOpLowering, // nvgpu.warpgroup.mma
1761  NVGPUWarpgroupMmaStoreOpLowering, // nvgpu.warpgroup.mma.store
1762  NVGPUWarpgroupMmaInitAccumulatorOpLowering, // nvgpu.warpgroup.mma.init.accumulator
1763  MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
1764  NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
1765  NVGPUMmaSparseSyncLowering, NVGPURcpOpLowering>(converter);
1766 }
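// A minimal usage sketch (illustrative; the function name
// `convertNVGPUToNVVM` and the choice of legal dialects are assumptions, not
// part of this file), showing how a client pass might wire these patterns
// into a dialect conversion:
//
//   void convertNVGPUToNVVM(ModuleOp module) {
//     MLIRContext *ctx = module.getContext();
//     LowerToLLVMOptions options(ctx);
//     LLVMTypeConverter converter(ctx, options);
//     RewritePatternSet patterns(ctx);
//     populateNVGPUToNVVMConversionPatterns(converter, patterns);
//     LLVMConversionTarget target(*ctx);
//     target.addLegalDialect<NVVM::NVVMDialect>();
//     if (failed(applyPartialConversion(module, target, std::move(patterns))))
//       module.emitError("failed to lower nvgpu to nvvm");
//   }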