1 //===- NVGPUToNVVM.cpp - NVGPU to NVVM dialect conversion -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
23 #include "mlir/IR/BuiltinTypes.h"
25 #include "mlir/IR/PatternMatch.h"
26 #include "mlir/IR/TypeUtilities.h"
27 #include "mlir/IR/Value.h"
28 #include "mlir/Pass/Pass.h"
29 #include "llvm/Support/Debug.h"
30 #include "llvm/Support/ErrorHandling.h"
31 #include "llvm/Support/raw_ostream.h"
32 #include <optional>
33 
34 #define DEBUG_TYPE "nvgpu-to-nvvm"
35 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
36 #define DBGSE() (llvm::dbgs())
37 
38 namespace mlir {
39 #define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS
40 #include "mlir/Conversion/Passes.h.inc"
41 } // namespace mlir
42 
43 using namespace mlir;
44 
45 /// Number of bits that need to be excluded when building the matrix
46 /// descriptor for wgmma operations.
47 constexpr int exclude4LSB = 4;
48 
49 /// GPUs have 32-bit registers; this function truncates values when a larger
50 /// width is not needed.
51 static Value truncToI32(ImplicitLocOpBuilder &b, Value value) {
52  Type type = value.getType();
53  assert(llvm::isa<IntegerType>(type) && "expected an integer Value");
54  if (type.getIntOrFloatBitWidth() <= 32)
55  return value;
56  return b.create<LLVM::TruncOp>(b.getI32Type(), value);
57 }
58 
59 /// Returns the type for the intrinsic given the vectorResultType of the
60 /// `gpu.mma.sync` operation.
61 static Type inferIntrinsicResultType(Type vectorResultType) {
62  MLIRContext *ctx = vectorResultType.getContext();
63  auto a = cast<LLVM::LLVMArrayType>(vectorResultType);
64  auto f16x2Ty = LLVM::getFixedVectorType(Float16Type::get(ctx), 2);
65  auto i32Ty = IntegerType::get(ctx, 32);
66  auto i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
67  Type f64Ty = Float64Type::get(ctx);
68  Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
69  Type f32Ty = Float32Type::get(ctx);
70  Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
71  if (a.getElementType() == f16x2Ty) {
72  return LLVM::LLVMStructType::getLiteral(
73  ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty));
74  }
75  if (a.getElementType() == i32x2Ty) {
76  return LLVM::LLVMStructType::getLiteral(
77  ctx,
78  SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, i32Ty));
79  }
80  if (a.getElementType() == f64x2Ty) {
81  return LLVM::LLVMStructType::getLiteral(ctx, {f64Ty, f64Ty});
82  }
83  if (a.getElementType() == f32x2Ty) {
84  return LLVM::LLVMStructType::getLiteral(
85  ctx,
86  SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, f32Ty));
87  }
88  if (a.getElementType() == LLVM::getFixedVectorType(f32Ty, 1)) {
89  return LLVM::LLVMStructType::getLiteral(
90  ctx, SmallVector<Type>(static_cast<size_t>(a.getNumElements()), f32Ty));
91  }
92  return vectorResultType;
93 }
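// Illustrative examples of the mapping above (derived from the cases handled
// in inferIntrinsicResultType): a vectorResultType of
// !llvm.array<2 x vector<2xf16>> maps to
// !llvm.struct<(vector<2xf16>, vector<2xf16>)>, and
// !llvm.array<2 x vector<2xi32>> maps to !llvm.struct<(i32, i32, i32, i32)>
// because each 2xi32 row is flattened into two i32 scalars.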
94 
95 /// Convert the SSA result of the NVVM intrinsic `nvvm.mma.sync` (which is
96 /// always an LLVM struct) into a fragment that is compatible with the vector
97 /// type of this operation. This involves extracting elements from the struct
98 /// and inserting them into an LLVM array. These extra data-movement
99 /// operations should be canonicalized away by the LLVM backend.
100 static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
101  Type resultType, Value intrinsicResult,
102  RewriterBase &rewriter) {
103  MLIRContext *ctx = rewriter.getContext();
104  auto structType = dyn_cast<LLVM::LLVMStructType>(intrinsicResultType);
105  auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType);
106  Type i32Ty = rewriter.getI32Type();
107  Type f32Ty = rewriter.getF32Type();
108  Type f64Ty = rewriter.getF64Type();
109  Type f16x2Ty = LLVM::getFixedVectorType(rewriter.getF16Type(), 2);
110  Type i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
111  Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
112  Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
113  Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
114 
115  auto makeConst = [&](int32_t index) -> Value {
116  return rewriter.create<LLVM::ConstantOp>(loc, IntegerType::get(ctx, 32),
117  rewriter.getI32IntegerAttr(index));
118  };
119 
120  if (arrayType) {
121  SmallVector<Value, 4> elements;
122 
123  // The intrinsic returns 32-bit wide elements in a form which can be
124  // directly bitcasted and inserted into the result vector.
125  if (arrayType.getElementType() == f16x2Ty ||
126  arrayType.getElementType() == f32x1Ty) {
127  for (unsigned i = 0; i < structType.getBody().size(); i++) {
128  Value el =
129  rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, i);
130  el = rewriter.createOrFold<LLVM::BitcastOp>(
131  loc, arrayType.getElementType(), el);
132  elements.push_back(el);
133  }
134  }
135 
136  // The intrinsic returns i32, f64, and f32 values as individual scalars,
137  // even when the result is notionally a 64-bit wide element (e.g. f32x2). We
138  // need to extract them from the struct and pack them into the 64-bit wide
139  // rows of the vector result.
140  if (arrayType.getElementType() == i32x2Ty ||
141  arrayType.getElementType() == f64x2Ty ||
142  arrayType.getElementType() == f32x2Ty) {
143 
144  for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) {
145  Value vec =
146  rewriter.create<LLVM::UndefOp>(loc, arrayType.getElementType());
147  Value x1 =
148  rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, i * 2);
149  Value x2 = rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult,
150  i * 2 + 1);
151  vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
152  x1, makeConst(0));
153  vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
154  x2, makeConst(1));
155  elements.push_back(vec);
156  }
157  }
158 
159  // Create the final vectorized result.
160  Value result = rewriter.create<LLVM::UndefOp>(loc, arrayType);
161  for (const auto &el : llvm::enumerate(elements)) {
162  result = rewriter.create<LLVM::InsertValueOp>(loc, result, el.value(),
163  el.index());
164  }
165  return result;
166  }
167 
168  return intrinsicResult;
169 }
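// Illustrative example of the repacking above: an intrinsic result of type
// !llvm.struct<(i32, i32, i32, i32)> with a desired result type of
// !llvm.array<2 x vector<2xi32>> is rebuilt by extracting the four scalars
// and inserting them pairwise into two vector<2xi32> rows.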
170 
171 /// The `gpu.mma.sync` converter below expects matrix fragment operands to be
172 /// given as 2D `vectors` where the rows are 32b or 64b wide. The
173 /// `nvvm.mma.sync` op expects these arguments to be given as a long list of
174 /// scalars of certain types. This function helps unpack the `vector` arguments
175 /// and cast them to the types expected by `nvvm.mma.sync`.
176 static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b,
177  Value operand,
178  NVVM::MMATypes operandPtxType) {
179  SmallVector<Value> result;
180  Type i32Ty = b.getI32Type();
181  Type f64Ty = b.getF64Type();
182  Type f32Ty = b.getF32Type();
183  Type i64Ty = b.getI64Type();
184  Type i8x4Ty = LLVM::getFixedVectorType(b.getI8Type(), 4);
185  Type i4x8Ty = LLVM::getFixedVectorType(b.getIntegerType(4), 8);
186  Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
187  auto arrayTy = cast<LLVM::LLVMArrayType>(operand.getType());
188 
189  for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
190  Value toUse = b.create<LLVM::ExtractValueOp>(operand, i);
191 
192  // For 4xi8 vectors, the intrinsic expects these to be provided as i32
193  // scalar types.
194  if (arrayTy.getElementType() == i8x4Ty ||
195  arrayTy.getElementType() == i4x8Ty ||
196  (arrayTy.getElementType() == f32x1Ty &&
197  operandPtxType == NVVM::MMATypes::tf32)) {
198  result.push_back(b.create<LLVM::BitcastOp>(i32Ty, toUse));
199  continue;
200  }
201 
202  // For some element types (i32, f32, f64), we need to unpack the inner
203  // vector/array type as well because the intrinsic expects individual
204  // scalars to be provided.
205  VectorType innerArrayTy = dyn_cast<VectorType>(arrayTy.getElementType());
206  if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty ||
207  innerArrayTy.getElementType() == f64Ty ||
208  innerArrayTy.getElementType() == f32Ty)) {
209  for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
210  idx < innerSize; idx++) {
211  result.push_back(b.create<LLVM::ExtractElementOp>(
212  toUse,
213  b.create<LLVM::ConstantOp>(i64Ty, b.getI64IntegerAttr(idx))));
214  }
215  continue;
216  }
217  result.push_back(toUse);
218  }
219  return result;
220 }
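// Illustrative example: an operand of type !llvm.array<4 x vector<2xf16>>
// unpacks into four vector<2xf16> values, while
// !llvm.array<2 x vector<1xf32>> with operandPtxType == tf32 unpacks into
// two values bitcast to i32, as expected by nvvm.mma.sync.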
221 
222 /// Returns whether the mbarrier object has a shared memory address space.
223 static bool isMbarrierShared(nvgpu::MBarrierGroupType barrierType) {
224  return (mlir::nvgpu::NVGPUDialect::isSharedMemoryAddressSpace(
225  barrierType.getMemorySpace()));
226 }
227 
228 /// Returns the memory space attribute of the mbarrier object.
229 Attribute nvgpu::getMbarrierMemorySpace(MLIRContext *context,
230  nvgpu::MBarrierGroupType barrierType) {
231  Attribute memorySpace = {};
232  if (isMbarrierShared(barrierType)) {
233  memorySpace =
234  IntegerAttr::get(IntegerType::get(context, 64),
235  nvgpu::NVGPUDialect::kSharedMemoryAddressSpace);
236  }
237  return memorySpace;
238 }
239 
240 /// Returns memref type of the mbarrier object. The type is defined in the
241 /// MBarrierGroupType.
242 MemRefType nvgpu::getMBarrierMemrefType(MLIRContext *context,
243  nvgpu::MBarrierGroupType barrierType) {
244  Attribute memorySpace = nvgpu::getMbarrierMemorySpace(context, barrierType);
245  MemRefLayoutAttrInterface layout;
246  return MemRefType::get({barrierType.getNumBarriers()},
247  IntegerType::get(context, 64), layout, memorySpace);
248 }
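// Illustrative example (assuming kSharedMemoryAddressSpace is 3): a barrier
// group in shared memory with num_barriers = 4 maps to memref<4xi64, 3>.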
249 
250 namespace {
251 
252 struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
253  using ConvertOpToLLVMPattern<nvgpu::LdMatrixOp>::ConvertOpToLLVMPattern;
254 
255  LogicalResult
256  matchAndRewrite(nvgpu::LdMatrixOp op, OpAdaptor adaptor,
257  ConversionPatternRewriter &rewriter) const override {
258  MLIRContext *ctx = getContext();
259  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
260 
261  // The result type of ldmatrix will always be a struct of 32-bit integer
262  // registers if more than one 32-bit value is returned. Otherwise, the result
263  // is a single i32. The result type of the GPU operation is always a vector
264  // of shape (NumRegisters, VectorRegister), where VectorRegister is the
265  // 32-bit-wide vector type of the result. We bitcast the result of the
266  // NVVM::LdMatrix op to this vector type.
267  auto vectorResultType = dyn_cast<VectorType>(op->getResultTypes()[0]);
268  if (!vectorResultType) {
269  return failure();
270  }
271  Type innerVectorType = LLVM::getFixedVectorType(
272  vectorResultType.getElementType(), vectorResultType.getDimSize(1));
273 
274  int64_t num32BitRegs = vectorResultType.getDimSize(0);
275 
276  Type ldMatrixResultType;
277  if (num32BitRegs > 1) {
278  ldMatrixResultType = LLVM::LLVMStructType::getLiteral(
279  ctx, SmallVector<Type>(num32BitRegs, rewriter.getI32Type()));
280  } else {
281  ldMatrixResultType = rewriter.getI32Type();
282  }
283 
284  auto srcMemrefType = cast<MemRefType>(op.getSrcMemref().getType());
285  Value srcPtr =
286  getStridedElementPtr(b.getLoc(), srcMemrefType, adaptor.getSrcMemref(),
287  adaptor.getIndices(), rewriter);
288  Value ldMatrixResult = b.create<NVVM::LdMatrixOp>(
289  ldMatrixResultType, srcPtr,
290  /*num=*/op.getNumTiles(),
291  /*layout=*/op.getTranspose() ? NVVM::MMALayout::col
292  : NVVM::MMALayout::row);
293 
294  // The ldmatrix operation returns either a single i32 value or a struct of
295  // i32 values. Here we unpack those values and cast them back to their
296  // actual vector type (still of width 32b) and repack them into a result
297  // struct.
298  Type finalResultType = typeConverter->convertType(vectorResultType);
299  Value result = b.create<LLVM::UndefOp>(finalResultType);
300  for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) {
301  Value i32Register =
302  num32BitRegs > 1 ? b.create<LLVM::ExtractValueOp>(ldMatrixResult, i)
303  : ldMatrixResult;
304  Value casted = b.create<LLVM::BitcastOp>(innerVectorType, i32Register);
305  result = b.create<LLVM::InsertValueOp>(result, casted, i);
306  }
307 
308  rewriter.replaceOp(op, result);
309  return success();
310  }
311 };
312 
313 /// Convert the given type into the corresponding PTX type (NVVM::MMATypes
314 /// enum).
315 static FailureOr<NVVM::MMATypes> getNvvmMmaType(Type t) {
316  Type elType = getElementTypeOrSelf(t);
317  if (elType.isInteger(8))
318  return NVVM::MMATypes::s8;
319  if (elType.isInteger(4))
320  return NVVM::MMATypes::s4;
321  if (elType.isF16())
322  return NVVM::MMATypes::f16;
323  if (elType.isF64())
324  return NVVM::MMATypes::f64;
325  if (elType.isF32())
326  return NVVM::MMATypes::tf32;
327  return failure();
328 }
329 
330 struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
331  using ConvertOpToLLVMPattern<nvgpu::MmaSyncOp>::ConvertOpToLLVMPattern;
332 
333  LogicalResult
334  matchAndRewrite(nvgpu::MmaSyncOp op, OpAdaptor adaptor,
335  ConversionPatternRewriter &rewriter) const override {
336  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
337  // Get the shapes of the MMAMatrix type being used. The shapes will
338  // choose which intrinsic this op will be lowered to.
339  VectorType aType = op.getMatrixA().getType();
340  VectorType bType = op.getMatrixB().getType();
341  VectorType cType = op.getMatrixC().getType();
342 
343  std::array<int64_t, 3> gemmShape = op.getMmaShapeAsArray();
344 
345  // Tensor Cores (mma.sync) on F32 works only with TensorFloat32 (TF32).
346  bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
347  if (aType.getElementType().isF32() && !tf32Enabled)
348  return failure();
349 
350  FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
351  if (failed(ptxTypeA))
352  return op->emitOpError("failed to deduce operand PTX types");
353  FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
354  if (failed(ptxTypeB))
355  return op->emitOpError("failed to deduce operand PTX types");
356  std::optional<NVVM::MMATypes> ptxTypeC =
357  NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
358  /*isAccumulator=*/true);
359  if (!ptxTypeC)
360  return op->emitError(
361  "could not infer the PTX type for the accumulator/result");
362 
363  // TODO: add an attribute to the op to customize this behavior.
364  std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
365  if (isa<IntegerType>(aType.getElementType()))
366  overflow = NVVM::MMAIntOverflow::satfinite;
367 
368  SmallVector<Value> matA =
369  unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
370  SmallVector<Value> matB =
371  unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
372  SmallVector<Value> matC =
373  unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);
374 
375  Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
376  Type intrinsicResTy = inferIntrinsicResultType(
377  typeConverter->convertType(op->getResultTypes()[0]));
378  Value intrinsicResult = b.create<NVVM::MmaOp>(
379  intrinsicResTy, matA, matB, matC,
380  /*shape=*/gemmShape,
381  /*b1Op=*/std::nullopt,
382  /*intOverflow=*/overflow,
383  /*multiplicandPtxTypes=*/
384  std::array<NVVM::MMATypes, 2>{*ptxTypeA, *ptxTypeB},
385  /*multiplicandLayouts=*/
386  std::array<NVVM::MMALayout, 2>{NVVM::MMALayout::row,
387  NVVM::MMALayout::col});
388  rewriter.replaceOp(op, convertIntrinsicResult(op.getLoc(), intrinsicResTy,
389  desiredRetTy, intrinsicResult,
390  rewriter));
391  return success();
392  }
393 };
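// Illustrative sketch of the data flow above; the fragment shapes are
// assumptions for the m16n8k16 f16 case: matA unpacks to 4 x vector<2xf16>,
// matB to 2 x vector<2xf16>, and matC to 2 x vector<2xf16>; the NVVM::MmaOp
// result, an !llvm.struct<(vector<2xf16>, vector<2xf16>)>, is then repacked
// into the converted nvgpu result type by convertIntrinsicResult().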
394 
395 struct ConvertNVGPUToNVVMPass
396  : public impl::ConvertNVGPUToNVVMPassBase<ConvertNVGPUToNVVMPass> {
397  using Base::Base;
398 
399  void getDependentDialects(DialectRegistry &registry) const override {
400  registry.insert<memref::MemRefDialect, LLVM::LLVMDialect, NVVM::NVVMDialect,
401  arith::ArithDialect>();
402  }
403 
404  void runOnOperation() override {
405  LowerToLLVMOptions options(&getContext());
406  RewritePatternSet patterns(&getContext());
407  LLVMTypeConverter converter(&getContext(), options);
408  IRRewriter rewriter(&getContext());
409  populateGpuMemorySpaceAttributeConversions(
410  converter, [](gpu::AddressSpace space) -> unsigned {
411  switch (space) {
412  case gpu::AddressSpace::Global:
413  return static_cast<unsigned>(
414  NVVM::NVVMMemorySpace::kGlobalMemorySpace);
415  case gpu::AddressSpace::Workgroup:
416  return static_cast<unsigned>(
417  NVVM::NVVMMemorySpace::kSharedMemorySpace);
418  case gpu::AddressSpace::Private:
419  return 0;
420  }
421  llvm_unreachable("unknown address space enum value");
422  return 0;
423  });
424  /// Device-side async tokens cannot be materialized in NVVM. We just
425  /// convert them to a dummy i32 type in order to easily drop them during
426  /// conversion.
427  converter.addConversion([&](nvgpu::DeviceAsyncTokenType type) -> Type {
428  return converter.convertType(IntegerType::get(type.getContext(), 32));
429  });
430  converter.addConversion([&](nvgpu::WarpgroupAccumulatorType type) -> Type {
431  Type elemType = type.getFragmented().getElementType();
432  int64_t sizeM = type.getFragmented().getDimSize(0);
433  int64_t sizeN = type.getFragmented().getDimSize(1);
434 
435  unsigned numMembers;
436  if (elemType.isF32() || elemType.isInteger(32))
437  numMembers = sizeN / 2;
438  else if (elemType.isF16())
439  numMembers = sizeN / 4;
440  else
441  llvm_unreachable("unsupported type for warpgroup accumulator");
442 
443  SmallVector<Type> innerStructBody;
444  for (unsigned i = 0; i < numMembers; i++)
445  innerStructBody.push_back(elemType);
446  auto innerStructType =
447  LLVM::LLVMStructType::getLiteral(type.getContext(), innerStructBody);
448 
449  SmallVector<Type> structBody;
450  for (int i = 0; i < sizeM; i += kWgmmaSizeM)
451  structBody.push_back(innerStructType);
452 
453  auto convertedType =
454  LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
455  return converter.convertType(convertedType);
456  });
457  converter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {
458  return converter.convertType(IntegerType::get(type.getContext(), 64));
459  });
460  converter.addConversion(
461  [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type {
462  return converter.convertType(IntegerType::get(type.getContext(), 64));
463  });
464  converter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type {
465  return converter.convertType(
466  nvgpu::getMBarrierMemrefType(rewriter.getContext(), type));
467  });
468  converter.addConversion([&](nvgpu::TensorMapDescriptorType type) -> Type {
469  return LLVM::LLVMPointerType::get(type.getContext());
470  });
471  populateNVGPUToNVVMConversionPatterns(converter, patterns);
472  ConversionTarget target(getContext());
473  target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
474  target.addLegalDialect<::mlir::arith::ArithDialect>();
475  target.addLegalDialect<::mlir::memref::MemRefDialect>();
476  target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
477  mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
478  converter, patterns, target);
479  if (failed(applyPartialConversion(getOperation(), target,
480  std::move(patterns))))
481  signalPassFailure();
482  }
483 };
484 
485 /// Returns the constraints for the sparse MMA inline assembly instruction.
486 static std::string buildMmaSparseAsmConstraintString(unsigned matASize,
487  unsigned matBSize,
488  unsigned matCSize) {
489  std::string str;
490  llvm::raw_string_ostream ss(str);
491  for (unsigned i = 0; i < matCSize; i++)
492  ss << "=r,";
493  for (unsigned i = 0; i < matASize + matBSize + matCSize; i++)
494  ss << "r,";
495  // The final operand is for the sparsity metadata.
496  // The sparsity selector appears as a direct literal.
497  ss << "r";
498  ss.flush();
499  return str;
500 }
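// Illustrative example: matASize = 4, matBSize = 2, matCSize = 2 yields
// "=r,=r,r,r,r,r,r,r,r,r,r" -- two "=r" outputs for matC, eight "r" inputs
// for matA/matB/matC, and a final "r" for the sparsity metadata.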
501 
502 /// Returns the string for the `mma.sp.sync` instruction that corresponds to
503 /// the given parameters. Note that this function doesn't do any validation,
504 /// it's expected that the provided parameters correspond to a valid
505 /// instruction.
506 static std::string buildMmaSparseAsmString(
507  const std::array<int64_t, 3> &shape, unsigned matASize, unsigned matBSize,
508  unsigned matCSize, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
509  NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
510  std::optional<NVVM::MMAIntOverflow> overflow, unsigned metaDataSelector) {
511  auto ptxTypeStr = [](NVVM::MMATypes ptxType) {
512  return NVVM::stringifyMMATypes(ptxType);
513  };
514 
515  std::string asmStr;
516  llvm::raw_string_ostream ss(asmStr);
517  ss << "mma.sp.sync.aligned.m" << shape[0] << "n" << shape[1] << "k"
518  << shape[2] << ".row.col.";
519 
520  if (overflow)
521  ss << NVVM::stringifyMMAIntOverflow(*overflow) << ".";
522 
523  ss << ptxTypeStr(ptxTypeD) << "." << ptxTypeStr(ptxTypeA) << "."
524  << ptxTypeStr(ptxTypeB) << "." << ptxTypeStr(ptxTypeC) << " ";
525  unsigned asmArgIdx = 0;
526 
527  // The operand string is structured into sections `{matC elements...},
528  // {matA elements...}, {matB elements...}, {matC elements}`.
529  for (const auto arrSize : {matCSize, matASize, matBSize, matCSize}) {
530  ss << "{";
531  for (unsigned i = 0; i < arrSize; i++)
532  ss << "$" << asmArgIdx++ << (i < arrSize - 1 ? "," : "");
533  ss << "},";
534  }
535  ss << "$" << asmArgIdx++ << ",";
536  assert(metaDataSelector <= 1);
537  ss << "0x" << metaDataSelector << ";";
538  ss.flush();
539  return asmStr;
540 }
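// Illustrative example (the parameter values are hypothetical): shape
// {16, 8, 32}, matASize = 4, matBSize = 2, matCSize = 4, s8 inputs with an s32
// accumulator, satfinite overflow, and selector 0 produce (modulo whitespace):
//   mma.sp.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32
//   {$0,$1,$2,$3},{$4,$5,$6,$7},{$8,$9},{$10,$11,$12,$13},$14,0x0;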
541 
542 /// Builds an inline assembly operation corresponding to the specified MMA
543 /// sparse sync operation.
544 static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
545  ImplicitLocOpBuilder &b, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
546  NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
547  std::optional<NVVM::MMAIntOverflow> overflow, ArrayRef<Value> unpackedAData,
548  ArrayRef<Value> unpackedB, ArrayRef<Value> unpackedC, Value indexData,
549  int64_t metadataSelector, const std::array<int64_t, 3> &shape,
550  Type intrinsicResultType) {
551  auto asmDialectAttr =
552  LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_ATT);
553 
554  const unsigned matASize = unpackedAData.size();
555  const unsigned matBSize = unpackedB.size();
556  const unsigned matCSize = unpackedC.size();
557 
558  std::string asmStr = buildMmaSparseAsmString(
559  shape, matASize, matBSize, matCSize, ptxTypeA, ptxTypeB, ptxTypeC,
560  ptxTypeD, overflow, metadataSelector);
561  std::string constraintStr =
562  buildMmaSparseAsmConstraintString(matASize, matBSize, matCSize);
563 
564  SmallVector<Value> asmVals;
565  asmVals.reserve(matASize + matBSize + matCSize + 1);
566  for (ArrayRef<Value> args : {unpackedAData, unpackedB, unpackedC})
567  llvm::append_range(asmVals, args);
568  asmVals.push_back(indexData);
569 
570  return b.create<LLVM::InlineAsmOp>(
571  /*resultTypes=*/intrinsicResultType,
572  /*operands=*/asmVals,
573  /*asm_string=*/asmStr,
574  /*constraints=*/constraintStr,
575  /*has_side_effects=*/true,
576  /*is_align_stack=*/false,
577  /*asm_dialect=*/asmDialectAttr,
578  /*operand_attrs=*/ArrayAttr());
579 }
580 
581 /// Lowers `nvgpu.mma.sp.sync` to inline assembly.
582 struct NVGPUMmaSparseSyncLowering
583  : public ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp> {
584  using ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp>::ConvertOpToLLVMPattern;
585 
586  LogicalResult
587  matchAndRewrite(nvgpu::MmaSparseSyncOp op, OpAdaptor adaptor,
588  ConversionPatternRewriter &rewriter) const override {
589  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
590  // Get the shapes of the MMAMatrix type being used. The shapes will
591  // choose which intrinsic this op will be lowered to.
592  VectorType aType = op.getMatrixA().getType();
593  VectorType bType = op.getMatrixB().getType();
594  VectorType cType = op.getMatrixC().getType();
595 
596  FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
597  if (failed(ptxTypeA))
598  return op->emitOpError("failed to deduce operand PTX types");
599  FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
600  if (failed(ptxTypeB))
601  return op->emitOpError("failed to deduce operand PTX types");
602  std::optional<NVVM::MMATypes> ptxTypeC =
603  NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
604  /*isAccumulator=*/true);
605  if (!ptxTypeC)
606  return op->emitError(
607  "could not infer the PTX type for the accumulator/result");
608 
609  // Same as `mma.sync`, F32 works only with TensorFloat32 (TF32).
610  bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
611  if (aType.getElementType().isF32() && !tf32Enabled)
612  return failure();
613 
614  // TODO: add an attribute to the op to customize this behavior.
615  std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
616  if (isa<IntegerType>(aType.getElementType()))
617  overflow = NVVM::MMAIntOverflow::satfinite;
618 
619  SmallVector<Value> matA =
620  unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
621  SmallVector<Value> matB =
622  unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
623  SmallVector<Value> matC =
624  unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);
625 
626  Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
627  Type intrinsicResTy = inferIntrinsicResultType(
628  typeConverter->convertType(op->getResultTypes()[0]));
629 
630  // Bitcast the sparse metadata from vector<2xi16> to an i32.
631  Value sparseMetadata = adaptor.getSparseMetadata();
632  if (sparseMetadata.getType() !=
633  LLVM::getFixedVectorType(rewriter.getI16Type(), 2))
634  return op->emitOpError() << "Expected metadata type to be LLVM "
635  "VectorType of 2 i16 elements";
636  sparseMetadata =
637  b.create<LLVM::BitcastOp>(rewriter.getI32Type(), sparseMetadata);
638 
639  FailureOr<LLVM::InlineAsmOp> intrinsicResult = emitMmaSparseSyncOpAsm(
640  b, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB,
641  matC, sparseMetadata, op.getSparsitySelector(), op.getMmaShapeAsArray(),
642  intrinsicResTy);
643  if (failed(intrinsicResult))
644  return failure();
645 
646  assert((*intrinsicResult).getNumResults() == 1 &&
647  "expected inline asm op returns a single LLVM struct type");
648  rewriter.replaceOp(
649  op, convertIntrinsicResult(op.getLoc(), intrinsicResTy, desiredRetTy,
650  (*intrinsicResult)->getResult(0), rewriter));
651  return success();
652  }
653 };
654 
655 struct NVGPUAsyncCopyLowering
656  : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCopyOp> {
657  using ConvertOpToLLVMPattern<
658  nvgpu::DeviceAsyncCopyOp>::ConvertOpToLLVMPattern;
659 
660  LogicalResult
661  matchAndRewrite(nvgpu::DeviceAsyncCopyOp op, OpAdaptor adaptor,
662  ConversionPatternRewriter &rewriter) const override {
663  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
664  Location loc = op.getLoc();
665  auto dstMemrefType = cast<MemRefType>(op.getDst().getType());
666  Value dstPtr =
667  getStridedElementPtr(b.getLoc(), dstMemrefType, adaptor.getDst(),
668  adaptor.getDstIndices(), rewriter);
669  FailureOr<unsigned> dstAddressSpace =
670  getTypeConverter()->getMemRefAddressSpace(dstMemrefType);
671  if (failed(dstAddressSpace))
672  return rewriter.notifyMatchFailure(
673  loc, "destination memref address space not convertible to integer");
674 
675  auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
676  FailureOr<unsigned> srcAddressSpace =
677  getTypeConverter()->getMemRefAddressSpace(srcMemrefType);
678  if (failed(srcAddressSpace))
679  return rewriter.notifyMatchFailure(
680  loc, "source memref address space not convertible to integer");
681 
682  Value srcPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.getSrc(),
683  adaptor.getSrcIndices(), rewriter);
684  // The intrinsic takes a global pointer, so we need an address space cast.
685  auto srcPointerGlobalType = LLVM::LLVMPointerType::get(
686  op->getContext(), NVVM::NVVMMemorySpace::kGlobalMemorySpace);
687  srcPtr = b.create<LLVM::AddrSpaceCastOp>(srcPointerGlobalType, srcPtr);
688  int64_t dstElements = adaptor.getDstElements().getZExtValue();
689  int64_t sizeInBytes =
690  (dstMemrefType.getElementTypeBitWidth() * dstElements) / 8;
691  // When the optional SrcElements argument is *not* present, the regular
692  // CpAsyncOp is generated. CpAsyncOp reads bytes from the source (global
693  // memory) to fill DstElements number of elements in the destination
694  // (shared memory).
695  Value srcBytes = adaptor.getSrcElements();
696  if (srcBytes) {
697  // When the optional SrcElements argument is present, the source (global
698  // memory) of CpAsyncOp is read only for SrcElements number of elements.
699  // The rest of the DstElements in the destination (shared memory) are
700  // filled with zeros.
701  Value c3I32 =
702  b.create<LLVM::ConstantOp>(b.getI32Type(), b.getI32IntegerAttr(3));
703  Value bitwidth = b.create<LLVM::ConstantOp>(
704  b.getI32Type(),
705  b.getI32IntegerAttr(srcMemrefType.getElementTypeBitWidth()));
706  Value srcElementsI32 = b.create<LLVM::TruncOp>(b.getI32Type(), srcBytes);
707  srcBytes = b.create<LLVM::LShrOp>(
708  b.create<LLVM::MulOp>(bitwidth, srcElementsI32), c3I32);
709  }
710  // Cache global (.cg) for 16 dst bytes, Cache all (.ca) for sizes other than
711  // 16 dst bytes.
712  NVVM::LoadCacheModifierKind cacheModifier =
713  (op.getBypassL1().value_or(false) && sizeInBytes == 16)
714  ? NVVM::LoadCacheModifierKind::CG
715  : NVVM::LoadCacheModifierKind::CA;
716 
717  b.create<NVVM::CpAsyncOp>(
718  dstPtr, srcPtr, rewriter.getI32IntegerAttr(sizeInBytes),
719  NVVM::LoadCacheModifierKindAttr::get(op->getContext(), cacheModifier),
720  srcBytes);
721 
722  // Drop the result token.
723  Value zero = b.create<LLVM::ConstantOp>(
724  IntegerType::get(op.getContext(), 32), rewriter.getI32IntegerAttr(0));
725  rewriter.replaceOp(op, zero);
726  return success();
727  }
728 };
729 
730 struct NVGPUAsyncCreateGroupLowering
731  : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCreateGroupOp> {
732  using ConvertOpToLLVMPattern<
733  nvgpu::DeviceAsyncCreateGroupOp>::ConvertOpToLLVMPattern;
734 
735  LogicalResult
736  matchAndRewrite(nvgpu::DeviceAsyncCreateGroupOp op, OpAdaptor adaptor,
737  ConversionPatternRewriter &rewriter) const override {
738  rewriter.create<NVVM::CpAsyncCommitGroupOp>(op.getLoc());
739  // Drop the result token.
740  Value zero = rewriter.create<LLVM::ConstantOp>(
741  op->getLoc(), IntegerType::get(op.getContext(), 32),
742  rewriter.getI32IntegerAttr(0));
743  rewriter.replaceOp(op, zero);
744  return success();
745  }
746 };
747 
748 struct NVGPUAsyncWaitLowering
749  : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncWaitOp> {
750  using ConvertOpToLLVMPattern<
751  nvgpu::DeviceAsyncWaitOp>::ConvertOpToLLVMPattern;
752 
753  LogicalResult
754  matchAndRewrite(nvgpu::DeviceAsyncWaitOp op, OpAdaptor adaptor,
755  ConversionPatternRewriter &rewriter) const override {
756  // If numGroups is not present, pick 0 as a conservatively correct value.
757  int32_t numGroups = adaptor.getNumGroups().value_or(0);
758  rewriter.create<NVVM::CpAsyncWaitGroupOp>(op.getLoc(), numGroups);
759  rewriter.eraseOp(op);
760  return success();
761  }
762 };
763 
764 /// Creates an mbarrier object in shared memory.
765 struct NVGPUMBarrierCreateLowering
766  : public ConvertOpToLLVMPattern<nvgpu::MBarrierCreateOp> {
767  using ConvertOpToLLVMPattern<nvgpu::MBarrierCreateOp>::ConvertOpToLLVMPattern;
768 
769  template <typename moduleT>
770  memref::GlobalOp generateGlobalBarrier(ConversionPatternRewriter &rewriter,
771  Operation *funcOp, moduleT moduleOp,
772  MemRefType barrierType) const {
773  SymbolTable symbolTable(moduleOp);
774  OpBuilder::InsertionGuard guard(rewriter);
775  rewriter.setInsertionPoint(&moduleOp.front());
776  auto global = rewriter.create<memref::GlobalOp>(
777  funcOp->getLoc(), "__mbarrier",
778  /*sym_visibility=*/rewriter.getStringAttr("private"),
779  /*type=*/barrierType,
780  /*initial_value=*/ElementsAttr(),
781  /*constant=*/false,
782  /*alignment=*/rewriter.getI64IntegerAttr(8));
783  symbolTable.insert(global);
784  return global;
785  }
786 
787  LogicalResult
788  matchAndRewrite(nvgpu::MBarrierCreateOp op, OpAdaptor adaptor,
789  ConversionPatternRewriter &rewriter) const override {
790  Operation *funcOp = op->getParentOp();
791  MemRefType barrierType = nvgpu::getMBarrierMemrefType(
792  rewriter.getContext(), op.getBarriers().getType());
793 
794  memref::GlobalOp global;
795  if (auto moduleOp = funcOp->getParentOfType<gpu::GPUModuleOp>())
796  global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
797  else if (auto moduleOp = funcOp->getParentOfType<ModuleOp>())
798  global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
799 
800  rewriter.setInsertionPoint(op);
801  rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(op, barrierType,
802  global.getName());
803  return success();
804  }
805 };
806 
807 /// Base class for lowering mbarrier operations to nvvm intrinsics.
808 template <typename SourceOp>
809 struct MBarrierBasePattern : public ConvertOpToLLVMPattern<SourceOp> {
810 public:
811  using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
812  /// Returns the base pointer of the mbarrier object.
813  Value getMbarrierPtr(ImplicitLocOpBuilder &b,
814  nvgpu::MBarrierGroupType mbarType, Value memrefDesc,
815  Value mbarId,
816  ConversionPatternRewriter &rewriter) const {
817  MemRefType mbarrierMemrefType =
818  nvgpu::getMBarrierMemrefType(rewriter.getContext(), mbarType);
819  return ConvertToLLVMPattern::getStridedElementPtr(
820  b.getLoc(), mbarrierMemrefType, memrefDesc, {mbarId}, rewriter);
821  }
822 };
823 
824 /// Lowers `nvgpu.mbarrier.init` to `nvvm.mbarrier.init`
825 struct NVGPUMBarrierInitLowering
826  : public MBarrierBasePattern<nvgpu::MBarrierInitOp> {
827  using MBarrierBasePattern<nvgpu::MBarrierInitOp>::MBarrierBasePattern;
828 
829  LogicalResult
830  matchAndRewrite(nvgpu::MBarrierInitOp op, OpAdaptor adaptor,
831  ConversionPatternRewriter &rewriter) const override {
832  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
833  nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType();
834  rewriter.setInsertionPoint(op);
835  Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(),
836  adaptor.getMbarId(), rewriter);
837  Value count = truncToI32(b, adaptor.getCount());
838  if (isMbarrierShared(mbarrierType)) {
839  rewriter.replaceOpWithNewOp<NVVM::MBarrierInitSharedOp>(
840  op, barrier, count, adaptor.getPredicate());
841  } else {
842  rewriter.replaceOpWithNewOp<NVVM::MBarrierInitOp>(op, barrier, count,
843  adaptor.getPredicate());
844  }
845  return success();
846  }
847 };
848 
849 /// Lowers `nvgpu.mbarrier.arrive` to `nvvm.mbarrier.arrive`
850 struct NVGPUMBarrierArriveLowering
851  : public MBarrierBasePattern<nvgpu::MBarrierArriveOp> {
852  using MBarrierBasePattern<nvgpu::MBarrierArriveOp>::MBarrierBasePattern;
853  LogicalResult
854  matchAndRewrite(nvgpu::MBarrierArriveOp op, OpAdaptor adaptor,
855  ConversionPatternRewriter &rewriter) const override {
856  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
857  Value barrier =
858  getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
859  adaptor.getMbarId(), rewriter);
860  Type tokenType = getTypeConverter()->convertType(
861  nvgpu::MBarrierTokenType::get(op->getContext()));
862  if (isMbarrierShared(op.getBarriers().getType())) {
863  rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveSharedOp>(op, tokenType,
864  barrier);
865  } else {
866  rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveOp>(op, tokenType,
867  barrier);
868  }
869  return success();
870  }
871 };
872 
873 /// Lowers `nvgpu.mbarrier.arrive.nocomplete` to
874 /// `nvvm.mbarrier.arrive.nocomplete`
875 struct NVGPUMBarrierArriveNoCompleteLowering
876  : public MBarrierBasePattern<nvgpu::MBarrierArriveNoCompleteOp> {
877  using MBarrierBasePattern<
878  nvgpu::MBarrierArriveNoCompleteOp>::MBarrierBasePattern;
879  LogicalResult
880  matchAndRewrite(nvgpu::MBarrierArriveNoCompleteOp op, OpAdaptor adaptor,
881  ConversionPatternRewriter &rewriter) const override {
882  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
883  Value barrier =
884  getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
885  adaptor.getMbarId(), rewriter);
886  Type tokenType = getTypeConverter()->convertType(
887  nvgpu::MBarrierTokenType::get(op->getContext()));
888  Value count = truncToI32(b, adaptor.getCount());
889  if (isMbarrierShared(op.getBarriers().getType())) {
890  rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteSharedOp>(
891  op, tokenType, barrier, count);
892  } else {
893  rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteOp>(
894  op, tokenType, barrier, count);
895  }
896  return success();
897  }
898 };
899 
900 /// Lowers `nvgpu.mbarrier.test.wait` to `nvvm.mbarrier.test.wait`
901 struct NVGPUMBarrierTestWaitLowering
902  : public MBarrierBasePattern<nvgpu::MBarrierTestWaitOp> {
903  using MBarrierBasePattern<nvgpu::MBarrierTestWaitOp>::MBarrierBasePattern;
904  LogicalResult
905  matchAndRewrite(nvgpu::MBarrierTestWaitOp op, OpAdaptor adaptor,
906  ConversionPatternRewriter &rewriter) const override {
907  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
908  Value barrier =
909  getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
910  adaptor.getMbarId(), rewriter);
911  Type retType = rewriter.getI1Type();
912  if (isMbarrierShared(op.getBarriers().getType())) {
913  rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitSharedOp>(
914  op, retType, barrier, adaptor.getToken());
915  } else {
916  rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitOp>(
917  op, retType, barrier, adaptor.getToken());
918  }
919  return success();
920  }
921 };
922 
923 struct NVGPUMBarrierArriveExpectTxLowering
924  : public MBarrierBasePattern<nvgpu::MBarrierArriveExpectTxOp> {
925  using MBarrierBasePattern<
926  nvgpu::MBarrierArriveExpectTxOp>::MBarrierBasePattern;
927  LogicalResult
928  matchAndRewrite(nvgpu::MBarrierArriveExpectTxOp op, OpAdaptor adaptor,
929  ConversionPatternRewriter &rewriter) const override {
930  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
931  Value barrier =
932  getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
933  adaptor.getMbarId(), rewriter);
934  Value txcount = truncToI32(b, adaptor.getTxcount());
935 
936  if (isMbarrierShared(op.getBarriers().getType())) {
937  rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxSharedOp>(
938  op, barrier, txcount, adaptor.getPredicate());
939  return success();
940  }
941 
942  rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>(
943  op, barrier, txcount, adaptor.getPredicate());
944  return success();
945  }
946 };
947 
948 struct NVGPUMBarrierTryWaitParityLowering
949  : public MBarrierBasePattern<nvgpu::MBarrierTryWaitParityOp> {
950  using MBarrierBasePattern<
951  nvgpu::MBarrierTryWaitParityOp>::MBarrierBasePattern;
952  LogicalResult
953  matchAndRewrite(nvgpu::MBarrierTryWaitParityOp op, OpAdaptor adaptor,
954  ConversionPatternRewriter &rewriter) const override {
955  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
956  Value barrier =
957  getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
958  adaptor.getMbarId(), rewriter);
959  Value ticks = truncToI32(b, adaptor.getTicks());
960  Value phase =
961  b.create<LLVM::ZExtOp>(b.getI32Type(), adaptor.getPhaseParity());
962 
963  if (isMbarrierShared(op.getBarriers().getType())) {
964  rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParitySharedOp>(
965  op, barrier, phase, ticks);
966  return success();
967  }
968 
969  rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParityOp>(op, barrier,
970  phase, ticks);
971  return success();
972  }
973 };
974 
975 struct NVGPUTmaAsyncLoadOpLowering
976  : public MBarrierBasePattern<nvgpu::TmaAsyncLoadOp> {
977  using MBarrierBasePattern<nvgpu::TmaAsyncLoadOp>::MBarrierBasePattern;
978  LogicalResult
979  matchAndRewrite(nvgpu::TmaAsyncLoadOp op, OpAdaptor adaptor,
980  ConversionPatternRewriter &rewriter) const override {
981  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
982  auto srcMemrefType = cast<MemRefType>(op.getDst().getType());
983  Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType,
984  adaptor.getDst(), {}, rewriter);
985  Value barrier =
986  getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
987  adaptor.getMbarId(), rewriter);
988 
989  SmallVector<Value> coords = adaptor.getCoordinates();
990  for (auto [index, value] : llvm::enumerate(coords)) {
991  coords[index] = truncToI32(b, value);
992  }
993  rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(
994  op, dest, adaptor.getTensorMapDescriptor(), coords, barrier,
995  ValueRange{}, adaptor.getMulticastMask(), Value{},
996  adaptor.getPredicate());
997  return success();
998  }
999 };
1000 
1001 struct NVGPUTmaAsyncStoreOpLowering
1002  : public MBarrierBasePattern<nvgpu::TmaAsyncStoreOp> {
1003  using MBarrierBasePattern<nvgpu::TmaAsyncStoreOp>::MBarrierBasePattern;
1004  LogicalResult
1005  matchAndRewrite(nvgpu::TmaAsyncStoreOp op, OpAdaptor adaptor,
1006  ConversionPatternRewriter &rewriter) const override {
1007  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1008  auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
1009  Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType,
1010  adaptor.getSrc(), {}, rewriter);
1011  SmallVector<Value> coords = adaptor.getCoordinates();
1012  for (auto [index, value] : llvm::enumerate(coords)) {
1013  coords[index] = truncToI32(b, value);
1014  }
1015 
1016  rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(
1017  op, adaptor.getTensorMapDescriptor(), dest, coords,
1018  adaptor.getPredicate());
1019  return success();
1020  }
1021 };
1022 
1023 struct NVGPUGenerateWarpgroupDescriptorLowering
1024  : public ConvertOpToLLVMPattern<nvgpu::WarpgroupGenerateDescriptorOp> {
1025  using ConvertOpToLLVMPattern<
1026  nvgpu::WarpgroupGenerateDescriptorOp>::ConvertOpToLLVMPattern;
1027 
1028  LogicalResult
1029  matchAndRewrite(nvgpu::WarpgroupGenerateDescriptorOp op, OpAdaptor adaptor,
1030  ConversionPatternRewriter &rewriter) const override {
1031 
1032  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1033 
1034  nvgpu::TensorMapSwizzleKind swizzleKind =
1035  op.getTensorMap().getType().getSwizzle();
1036 
1037  unsigned layout =
1038  (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 128
1039  : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 64
1040  : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 32
1041  : 1;
1042  unsigned swizzle =
1043  (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 1
1044  : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 2
1045  : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 3
1046  : 0;
1047 
1048  auto ti64 = b.getIntegerType(64);
1049  auto makeConst = [&](uint64_t index) -> Value {
1050  return b.create<LLVM::ConstantOp>(ti64, b.getI64IntegerAttr(index));
1051  };
1052  auto shiftLeft = [&](Value value, unsigned shift) -> Value {
1053  return b.create<LLVM::ShlOp>(ti64, value, makeConst(shift));
1054  };
1055  auto shiftRight = [&](Value value, unsigned shift) -> Value {
1056  return b.create<LLVM::LShrOp>(ti64, value, makeConst(shift));
1057  };
1058  auto insertBit = [&](Value desc, Value val, int startBit) {
1059  return b.create<LLVM::OrOp>(ti64, desc, shiftLeft(val, startBit));
1060  };
1061 
1062  int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0);
1063  uint64_t strideDimVal = (layout << 3) >> exclude4LSB;
1064  uint64_t leadDimVal = (sizeN * layout) >> exclude4LSB;
1065  uint64_t offsetVal = 0;
1066 
1067  Value strideDim = makeConst(strideDimVal);
1068  Value leadDim = makeConst(leadDimVal);
1069 
1070  Value baseAddr = getStridedElementPtr(
1071  op->getLoc(), cast<MemRefType>(op.getTensor().getType()),
1072  adaptor.getTensor(), {}, rewriter);
1073  Value basePtr = b.create<LLVM::PtrToIntOp>(ti64, baseAddr);
1074  // Just use 14 bits for base address
1075  Value basePtr14bit = shiftRight(shiftLeft(basePtr, 46), 50);
1076 
1077  int startSwizzleBit = 62, startOffsetBit = 49, startStrideBit = 32,
1078  startLeadBit = 16, startBaseAddrBit = 0;
1079  Value dsc = makeConst(0);
1080  // [62,64) swizzle type
1081  dsc = insertBit(dsc, makeConst(swizzle), startSwizzleBit);
1082  // [49,52) base_offset
1083  dsc = insertBit(dsc, makeConst(offsetVal), startOffsetBit);
1084  // [32,46) stride
1085  dsc = insertBit(dsc, strideDim, startStrideBit);
1086  // [16,30) leading dimension
1087  dsc = insertBit(dsc, leadDim, startLeadBit);
1088  // [0,14) start_address
1089  dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit);
1090 
1091  LLVM_DEBUG(DBGS() << "Generating warpgroup.descriptor: "
1092  << "leading_off:" << leadDimVal << "\t"
1093  << "stride_off :" << strideDimVal << "\t"
1094  << "base_offset:" << offsetVal << "\t"
1095  << "layout_type:" << swizzle << " ("
1096  << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
1097  << ")\n start_addr : " << baseAddr << "\n");
1098 
1099  rewriter.replaceOp(op, dsc);
1100  return success();
1101  }
1102 };
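// Worked example of the descriptor math above (values are illustrative): for
// SWIZZLE_128B, layout = 128 and swizzle = 1, so
// strideDimVal = (128 << 3) >> 4 = 64; with sizeN = 64,
// leadDimVal = (64 * 128) >> 4 = 512.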
1103 
1104 static Value makeI64Const(ImplicitLocOpBuilder &b, int32_t index) {
1105  return b.create<LLVM::ConstantOp>(b.getIntegerType(64),
1106  b.getI32IntegerAttr(index));
1107 }
1108 
1109 /// Returns a Value that holds the data type enum expected by the CUDA driver.
1110 static Value elementTypeAsLLVMConstant(ImplicitLocOpBuilder &b, Type type) {
1111  // Enum is from CUDA driver API
1112  // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html
1113  enum CUtensorMapDataTypeEnum {
1114  CU_TENSOR_MAP_DATA_TYPE_UINT8 = 0,
1115  CU_TENSOR_MAP_DATA_TYPE_UINT16,
1116  CU_TENSOR_MAP_DATA_TYPE_UINT32,
1117  CU_TENSOR_MAP_DATA_TYPE_INT32,
1118  CU_TENSOR_MAP_DATA_TYPE_UINT64,
1119  CU_TENSOR_MAP_DATA_TYPE_INT64,
1120  CU_TENSOR_MAP_DATA_TYPE_FLOAT16,
1121  CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
1122  CU_TENSOR_MAP_DATA_TYPE_FLOAT64,
1123  CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
1124  CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ,
1125  CU_TENSOR_MAP_DATA_TYPE_TFLOAT32,
1126  CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ
1127  };
1128 
1129  if (type.isUnsignedInteger(8))
1130  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT8);
1131  if (type.isUnsignedInteger(16))
1132  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT16);
1133  if (type.isUnsignedInteger(32))
1134  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT32);
1135  if (type.isUnsignedInteger(64))
1136  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT64);
1137  if (type.isSignlessInteger(32))
1138  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT32);
1139  if (type.isSignlessInteger(64))
1140  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT64);
1141  if (type.isF16())
1142  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT16);
1143  if (type.isF32())
1144  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT32);
1145  if (type.isF64())
1146  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT64);
1147  if (type.isBF16())
1148  return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16);
1149 
1150  llvm_unreachable("Not supported data type");
1151 }
1152 
1153 struct NVGPUTmaCreateDescriptorOpLowering
1154  : public ConvertOpToLLVMPattern<nvgpu::TmaCreateDescriptorOp> {
1155  using ConvertOpToLLVMPattern<
1156  nvgpu::TmaCreateDescriptorOp>::ConvertOpToLLVMPattern;
1157  LogicalResult
1158  matchAndRewrite(nvgpu::TmaCreateDescriptorOp op, OpAdaptor adaptor,
1159  ConversionPatternRewriter &rewriter) const override {
1160  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1161  auto llvmPointerType = LLVM::LLVMPointerType::get(op->getContext());
1162  Type llvmInt64Type = IntegerType::get(op->getContext(), 64);
1163 
1164  Value tensorElementType =
1165  elementTypeAsLLVMConstant(b, op.getTensor().getType().getElementType());
1166  auto promotedOperands = getTypeConverter()->promoteOperands(
1167  b.getLoc(), op->getOperands(), adaptor.getOperands(), b);
1168 
1169  Value boxArrayPtr = b.create<LLVM::AllocaOp>(llvmPointerType, llvmInt64Type,
1170  makeI64Const(b, 5));
1171  for (auto [index, value] : llvm::enumerate(adaptor.getBoxDimensions())) {
1172  Value gep = b.create<LLVM::GEPOp>(llvmPointerType, llvmPointerType,
1173  boxArrayPtr, makeI64Const(b, index));
1174  b.create<LLVM::StoreOp>(value, gep);
1175  }
1176 
1177  nvgpu::TensorMapDescriptorType desc = op.getTensorMap().getType();
1178  // Set Arguments for the function call
1179  SmallVector<Value> arguments;
1180  arguments.push_back(promotedOperands[0]); // rank
1181  arguments.push_back(promotedOperands[1]); // descriptor
1182  arguments.push_back(tensorElementType); // data type
1183  arguments.push_back(
1184  makeI64Const(b, (int)desc.getInterleave())); // interleave
1185  arguments.push_back(makeI64Const(b, (int)desc.getSwizzle())); // swizzle
1186  arguments.push_back(makeI64Const(b, (int)desc.getL2promo())); // l2promo
1187  arguments.push_back(makeI64Const(b, (int)desc.getOob())); // oob
1188  arguments.push_back(boxArrayPtr); // box dimensions
1189 
1190  // Set data types of the arguments
1191  SmallVector<Type> argTypes = {
1192  llvmInt64Type, /* int64_t tensorRank */
1193  llvmPointerType, /* ptr */
1194  llvmInt64Type, /* int64_t */
1195  llvmInt64Type, /* int64_t */
1196  llvmInt64Type, /* int64_t */
1197  llvmInt64Type, /* int64_t */
1198  llvmInt64Type, /* int64_t */
1199  llvmPointerType /* ptr */
1200  };
1201  FunctionCallBuilder hostRegisterCallBuilder = {
1202  "mgpuTensorMapEncodeTiledMemref", llvmPointerType, argTypes};
1203  Value tensorMap =
1204  hostRegisterCallBuilder.create(b.getLoc(), b, arguments).getResult();
1205 
1206  rewriter.replaceOp(op, tensorMap);
1207  return success();
1208  }
1209 };
1210 
1211 struct NVGPUWarpgroupMmaOpLowering
1212  : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp> {
1213  using ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp>::ConvertOpToLLVMPattern;
1214 
1215  /// This is a helper class to generate required NVVM Ops for warp-group level
1216  /// matrix multiplication.
1217  /// When the given GEMM shape is larger than the shape of
1218  /// a wgmma instruction in PTX, it can generate multiple NVVM::WgmmaMmaAsyncOp
1219  /// Op(s), group them, and execute them asynchronously. The class also handles
1220  /// waiting for completion and iterates through WarpgroupMatrixDescriptor to
1221  /// create descriptors for each instruction.
1222  ///
1223  /// For example, this is the case when the shape of the GEMM is 128x128x128:
1224  ///
1225  /// nvvm.wgmma.fence.aligned
1226  ///
1227  /// nvvm.wgmma.mma.async descA, descB
1228  /// iterate(descA, descB)
1229  /// nvvm.wgmma.mma.async descA, descB
1230  /// [6x times more]
1231  ///
1232  /// nvvm.wgmma.group.sync.aligned
1233  /// nvvm.wgmma.wait.group.sync [groupId]
1234  ///
1235  class WarpgroupGemm {
1236  nvgpu::WarpgroupMmaOp op;
1237  ImplicitLocOpBuilder &b;
1238  OpAdaptor adaptor;
1239 
1240  // Entire shape of the given Op
1241  int64_t totalM, totalN, totalK;
1242 
1243  // Shape of one wgmma instruction
1244  int wgmmaM = 0, wgmmaN = 0, wgmmaK = 0;
1245 
1246  // Iteration counts for GEMM
1247  int iterationM = 0, iterationN = 0, iterationK = 0;
1248 
1249  /// Returns the shape of the wgmma instruction that is defined in the
1250  /// PTX programming guide.
1251  /// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shape
1252  void findWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType) {
1253  wgmmaM = 64;
1254  wgmmaN = sizeN;
1255  if (inputElemType.isTF32()) {
1256  wgmmaK = 8;
1257  } else if (inputElemType.isF16() || inputElemType.isBF16()) {
1258  wgmmaK = 16;
1259  } else if (inputElemType.isFloat8E4M3FN() ||
1260  inputElemType.isFloat8E5M2() || inputElemType.isInteger(16)) {
1261  wgmmaK = 32;
1262  } else if (inputElemType.isInteger(1)) {
1263  wgmmaK = 256;
1264  } else {
1265  llvm_unreachable("msg: not supported K shape");
1266  }
1267  LLVM_DEBUG(DBGS() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
1268  << ", n = " << wgmmaN << ", k = " << wgmmaK << "]\n");
1269  }
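  // Illustrative example: for an f16 GEMM with totalM = 128, totalN = 128,
  // and totalK = 64, this picks wgmmaM = 64, wgmmaN = 128, wgmmaK = 16, so
  // the iteration counts computed in the constructor below are
  // iterationM = 2, iterationN = 1, iterationK = 4.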
1270 
1271  /// Generates WGMMATypesAttr from MLIR Type
1272  NVVM::WGMMATypesAttr generateWgmmaType(Type type,
1273  bool useF32 = false) const {
1274  auto getWgmmaType = [=](Type elemType) {
1275  if (elemType.isF32() || elemType.isTF32())
1276  return useF32 ? NVVM::WGMMATypes::f32 : NVVM::WGMMATypes::tf32;
1277  if (elemType.isF16())
1278  return NVVM::WGMMATypes::f16;
1279  if (elemType.isBF16())
1280  return NVVM::WGMMATypes::bf16;
1281  if (elemType.isFloat8E4M3FN())
1282  return NVVM::WGMMATypes::e4m3;
1283  if (elemType.isFloat8E5M2())
1284  return NVVM::WGMMATypes::e5m2;
1285  if (elemType.isInteger(1))
1286  return NVVM::WGMMATypes::b1;
1287  if (elemType.isInteger(8))
1288  return NVVM::WGMMATypes::s8;
1289  if (elemType.isUnsignedInteger(8))
1290  return NVVM::WGMMATypes::u8;
1291  if (elemType.isInteger(32))
1292  return NVVM::WGMMATypes::s32;
1293  llvm_unreachable("unsupported type");
1294  };
1295  return NVVM::WGMMATypesAttr::get(op->getContext(), getWgmmaType(type));
1296  }
1297 
1298  /// Generates layout attribute for the input matrix for wgmma instruction
1299  NVVM::MMALayoutAttr
1300  generateWgmmaLayout(std::optional<bool> transpose) const {
1301  if (transpose.value_or(false))
1302  return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::col);
1303  return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::row);
1304  }
1305 
1306  /// Generates shape attribute for wgmma instruction
1307  NVVM::MMAShapeAttr generateWgmmaShape() const {
1308  return NVVM::MMAShapeAttr::get(op->getContext(), wgmmaM, wgmmaN, wgmmaK);
1309  }
1310 
1311  /// Generates scale attributes of output matrix for wgmma instruction
1312  NVVM::WGMMAScaleOutAttr generateScaleOut() const {
1313  return NVVM::WGMMAScaleOutAttr::get(op->getContext(),
1314  NVVM::WGMMAScaleOut::one);
1315  }
1316  /// Generates scale attributes of input matrix for wgmma instruction
1317  NVVM::WGMMAScaleInAttr generateScaleIn() const {
1318  return NVVM::WGMMAScaleInAttr::get(op->getContext(),
1319  NVVM::WGMMAScaleIn::one);
1320  }
1321 
1322  /// Basic function to generate Add
1323  Value makeAdd(Value lhs, Value rhs) {
1324  return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);
1325  };
1326 
1327  /// Moves the descriptor pointer of matrix-A for the next wgmma instruction.
1328  /// Currently, it only handles row-major.
1329  ///
1330  /// It moves the pointer like below for [128][64] size:
1331  /// +2 +4 +6
1332  /// ↓ ↓ ↓
1333  /// descA ---> +--+--+--+--+
1334  /// |->|->|->|->|
1335  /// | | | | |
1336  /// | | | | |
1337  /// | | | | |
1338  /// descA+512---> +-----------+
1339  /// | | | | |
1340  /// | | | | |
1341  /// | | | | |
1342  /// | | | | |
1343  /// +-----------+
1344  ///
1345  Value iterateDescriptorA(Value desc, int i, int j, int k) {
1346  MemRefType matrixTypeA = op.getDescriptorA().getType().getTensor();
1347  Type elemA = matrixTypeA.getElementType();
1348  int byte = elemA.getIntOrFloatBitWidth() / 8;
1349  int tileShapeA = matrixTypeA.getDimSize(1);
1350  int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte;
1351  incrementVal = incrementVal >> exclude4LSB;
1352  LLVM_DEBUG(DBGS() << "\t\t[m: " << i << " n: " << j << " k: " << k
1353  << "] [wgmma descriptors] Descriptor A + "
1354  << incrementVal << " | \t ");
1355  if (!incrementVal)
1356  return desc;
1357  return makeAdd(desc, makeI64Const(b, incrementVal));
1358  }
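  // Illustrative arithmetic for the sketch above, assuming f16 elements
  // (byte = 2), wgmmaK = 16, and i = 0: for k = 1 the byte offset is
  // (16 * 1) * 2 = 32, i.e. 32 >> 4 = 2 in descriptor units -- the "+2" step
  // shown in the sketch.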
1359 
1360  /// Moves the descriptor pointer of matrix-B for the next wgmma instruction.
1361  /// Currently, it only handles column-major.
1362  ///
1363  /// It moves the pointer like below for [128][64] size:
1364  /// descB ---> +--+--+--+--+--+--+--+--+
1365  /// |↓ | | | | | | | |
1366  /// |↓ | | | | | | | |
1367  /// |↓ | | | | | | | |
1368  /// |↓ | | | | | | | |
1369  /// +--+--+--+--+--+--+--+--+
1370  ///
1371  Value iterateDescriptorB(Value desc, int i, int j, int k) {
1372  MemRefType matrixTypeB = op.getDescriptorB().getType().getTensor();
1373  Type elemB = matrixTypeB.getElementType();
1374  int byte = elemB.getIntOrFloatBitWidth() / 8;
1375  int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte;
1376  incrementVal = incrementVal >> exclude4LSB;
1377  LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n");
1378  if (!incrementVal)
1379  return desc;
1380  return makeAdd(desc, makeI64Const(b, incrementVal));
1381  }
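  // Illustrative arithmetic, assuming f16 elements (byte = 2), wgmmaK = 16,
  // and matrixTypeB.getDimSize(0) = 64: for k = 1 the byte offset is
  // 64 * 16 * 1 * 2 = 2048, i.e. 2048 >> 4 = 128 in descriptor units.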
1382 
1383  /// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix
1384  /// descriptors and arranges them based on induction variables: i, j, and k.
1385  Value generateWgmma(int i, int j, int k, Value matrixC) {
1386  LLVM_DEBUG(DBGS() << "\t wgmma."
1387  << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
1388  << "(A[" << (iterationM * wgmmaM) << ":"
1389  << (iterationM * wgmmaM) + wgmmaM << "]["
1390  << (iterationK * wgmmaK) << ":"
1391  << (iterationK * wgmmaK + wgmmaK) << "] * "
1392  << " B[" << (iterationK * wgmmaK) << ":"
1393  << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":"
1394  << wgmmaN << "])\n");
1395 
1396  Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k);
1397  Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k);
1398 
1399  Type elemA = op.getDescriptorA().getType().getTensor().getElementType();
1400  NVVM::WGMMATypesAttr itypeA = generateWgmmaType(elemA);
1401 
1402  Type elemB = op.getDescriptorB().getType().getTensor().getElementType();
1403  NVVM::WGMMATypesAttr itypeB = generateWgmmaType(elemB);
1404 
1405  Type elemD = op.getMatrixC().getType().getFragmented().getElementType();
1406  NVVM::WGMMATypesAttr itypeD = generateWgmmaType(elemD, true);
1407 
1408  NVVM::MMAShapeAttr shape = generateWgmmaShape();
1409  NVVM::WGMMAScaleOutAttr scaleOut = generateScaleOut();
1410  NVVM::WGMMAScaleInAttr scaleIn = generateScaleIn();
1411  NVVM::MMALayoutAttr layoutA = generateWgmmaLayout(op.getTransposeA());
1412  NVVM::MMALayoutAttr layoutB = generateWgmmaLayout(!op.getTransposeB());
1413 
1414  auto overflow = NVVM::MMAIntOverflowAttr::get(
1415  op->getContext(), NVVM::MMAIntOverflow::wrapped);
1416 
1417  return b.create<NVVM::WgmmaMmaAsyncOp>(
1418  matrixC.getType(), matrixC, descriptorA, descriptorB, shape, itypeA,
1419  itypeB, itypeD, scaleOut, scaleIn, scaleIn, layoutA, layoutB,
1420  overflow);
1421  }
1422 
1423  /// Generates multiple wgmma instructions to complete the given GEMM shape
1424  Value generateWgmmaGroup() {
1425  Value wgmmaResult =
1426  b.create<LLVM::UndefOp>(adaptor.getMatrixC().getType());
1427 
1428  // Perform GEMM
1429  SmallVector<Value> wgmmaResults;
1430  for (int i = 0; i < iterationM; ++i) {
1431  Value matrixC = b.create<LLVM::ExtractValueOp>(adaptor.getMatrixC(), i);
1432  for (int j = 0; j < iterationN; ++j)
1433  for (int k = 0; k < iterationK; ++k)
1434  matrixC = generateWgmma(i, j, k, matrixC);
1435  wgmmaResults.push_back(matrixC);
1436  }
1437  for (auto [idx, matrix] : llvm::enumerate(wgmmaResults)) {
1438  wgmmaResult = b.create<LLVM::InsertValueOp>(wgmmaResult.getType(),
1439  wgmmaResult, matrix, idx);
1440  }
1441  return wgmmaResult;
1442  }
1443 
1444  public:
1445  WarpgroupGemm(nvgpu::WarpgroupMmaOp op, ImplicitLocOpBuilder &b,
1446  OpAdaptor adaptor)
1447  : op(op), b(b), adaptor(adaptor) {
1448  // Find the entire GEMM Shape
1449  totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
1450  totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
1451  totalK = op.getDescriptorA().getType().getTensor().getDimSize(1);
1452  LLVM_DEBUG(DBGS() << "===--- GEMM D[" << totalM << "][" << totalN
1453  << "] += A[" << totalM << "][" << totalK << "] * B["
1454  << totalK << "][" << totalN << "] ---===\n");
1455 
1456  // Find the shape for one wgmma instruction
1457  findWgmmaShape(
1458  totalM, totalN,
1459  op.getDescriptorA().getType().getTensor().getElementType());
1460 
1461  // Iteration counts needed to cover the given GEMM shape with the wgmma shape
1462  iterationM = totalM / wgmmaM;
1463  iterationN = totalN / wgmmaN;
1464  iterationK = totalK / wgmmaK;
1465  }
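// A minimal sketch of the iteration-count arithmetic above for an assumed f16
// GEMM D[128][128] += A[128][64] * B[64][128], where findWgmmaShape is expected
// to choose the wgmma shape m64 n128 k16 (illustrative numbers only).
void exampleIterationCounts() {
  constexpr int totalM = 128, totalN = 128, totalK = 64;
  constexpr int wgmmaM = 64, wgmmaN = 128, wgmmaK = 16; // assumed wgmma shape
  static_assert(totalM / wgmmaM == 2, "two iterations along M");
  static_assert(totalN / wgmmaN == 1, "one iteration along N");
  static_assert(totalK / wgmmaK == 4, "four iterations along K");
  // generateWgmmaGroup() therefore emits 2 * 1 * 4 = 8 WgmmaMmaAsyncOps.
}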
1466 
1467  /// Generates WgmmaMmaAsync Ops to complete the specified GEMM shape. It
1468  /// emits a fence (WgmmaFenceAlignedOp) before the wgmma instructions, and
1469  /// after them it emits group synchronization (WgmmaGroupSyncAlignedOp)
1470  /// followed by a wait for the group to complete
1471  /// (WgmmaWaitGroupSyncOp).
1472  Value generateWarpgroupMma() {
1473  b.create<NVVM::WgmmaFenceAlignedOp>();
1474  Value wgmmaResult = generateWgmmaGroup();
1475  b.create<NVVM::WgmmaGroupSyncAlignedOp>();
1476  b.create<NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup());
1477  return wgmmaResult;
1478  }
1479  };
1480  LogicalResult
1481  matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
1482  ConversionPatternRewriter &rewriter) const override {
1483  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1484 
1485  // Step 1. Build a helper class
1486  WarpgroupGemm warpgroupGemm(op, b, adaptor);
1487 
1488  // Step 2. Generate wgmma instructions covering the entire GEMM shape
1489  Value wgmmaResult = warpgroupGemm.generateWarpgroupMma();
1490 
1491  // Step 3. Replace fragmented result struct with the op results
1492  rewriter.replaceOp(op, wgmmaResult);
1493  return success();
1494  }
1495 };
1496 
1497 struct NVGPUWarpgroupMmaStoreOpLowering
1498  : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaStoreOp> {
1499  using ConvertOpToLLVMPattern<
1500  nvgpu::WarpgroupMmaStoreOp>::ConvertOpToLLVMPattern;
1501 
1502  /// This function stores a fragmented register matrix owned by a warp group
1503  /// (128 threads) into a memref. Each thread holds 64 registers that together
1504  /// form one struct value.
1505  /// Below is what each thread (T) holds; each `d` is one numbered element of
1506  /// that struct.
1507  ///
1508  /// Threads in the warp-group (128 threads) and what they own in matrix D:
1509  /// 0-31 Warp-0 -> MatrixD[0:15 ][0:N]
1510  /// 32-63 Warp-1 -> MatrixD[16:31][0:N]
1511  /// 64-95 Warp-2 -> MatrixD[32:47][0:N]
1512  /// 96-127 Warp-3 -> MatrixD[48:63][0:N]
1513  ///
1514  /// Matrix-D:
1515  /// +______________________________________________________________________+
1516  /// | 0-1 | 2-3 | 4-5 | 6-7 | 8-9 | 10-11|..|N-8,N-7 |
1517  /// 0 | T0:d0-d1 |T1:d0-d1 |T2:d0-d1 |T3:d0-d1 |T0:d4-d5| T1:d4-d5..|T0:dX-dY|
1518  /// 1 | T4:d0-d1 |T5:d0-d1 |T6:d0-d1 |T7:d0-d1 |T4:d4-d5| T5:d4-d5..|T4:dX-dY|
1519  /// ..| .........|.........|.........|.........|........|...........|........|
1520  /// 8 | T0:d2-d3 |T1:d2-d3 |T2:d2-d3 |T3:d2-d3 |T0:d6-d7|T1:d6-d7,..|T0:dZ-dW|
1521  /// 9 | T4:d2-d3 |T5:d2-d3 |T6:d2-d3 |T7:d2-d3 |T4:d6-d7| T5:d6-d7..|T4:dZ-dW|
1522  /// ..| .........|.........|.........|.........|........|...........|........|
1523  /// 15| T28:d2-d3|T29:d2-d3|T30:d2-d3|T31:d2-d3|........|...........|........|
1524  /// 16| T32:d2-d3|T33:d2-d3|T34:d2-d3|T35:d2-d3|........|...........|........|
1525  /// ..| .........|.........|.........|.........|........|...........|........|
1526  /// 32| T64:d2-d3|T65:d2-d3|T66:d2-d3|T67:d2-d3|........|...........|........|
1527  /// ..| .........|.........|.........|.........|........|...........|........|
1528  /// 48| T96:d2-d3|T97:d2-d3|T98:d2-d3|T99:d2-d3|........|...........|........|
1529  /// ..| .........|.........|.........|.........|........|...........|........|
1530  /// +______________________________________________________________________+
1531  ///
1532  /// \param b: The builder used to create the index arithmetic and stores.
1533  /// \param matrixD: Result of the warp-group MMA operation (fragmented
1534  /// matrix). Each thread holds it as a struct with 64 elements.
1535  /// \param dstMemref: The memref where the registers will be stored.
1536  /// \param offset: the offset within the memref where the registers will be
1537  /// stored.
1538  void storeFragmentedMatrix(ImplicitLocOpBuilder &b, Value matrixD,
1539  TypedValue<MemRefType> dstMemref,
1540  int offset) const {
1541  Type i32 = b.getI32Type();
1542 
1543  auto makeConst = [&](int32_t index) -> Value {
1544  return b.create<LLVM::ConstantOp>(i32, b.getI32IntegerAttr(index));
1545  };
1546  Value c1 = makeConst(1);
1547  Value c2 = makeConst(2);
1548  Value c4 = makeConst(4);
1549  Value c8 = makeConst(8);
1550  Value c16 = makeConst(16);
1551  Value warpSize = makeConst(kWarpSize);
1552 
1553  auto makeMul = [&](Value lhs, Value rhs) -> Value {
1554  return b.create<LLVM::MulOp>(lhs.getType(), lhs, rhs);
1555  };
1556  auto makeAdd = [&](Value lhs, Value rhs) -> Value {
1557  return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);
1558  };
1559 
1560  auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,
1561  TypedValue<MemRefType> memref) {
1562  Type it = b.getIndexType();
1563  Value idx = b.create<arith::IndexCastOp>(it, x);
1564  Value idy0 = b.create<arith::IndexCastOp>(it, y);
1565  Value idy1 = b.create<arith::IndexCastOp>(it, makeAdd(y, c1));
1566  Value d0 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i);
1567  Value d1 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i + 1);
1568  b.create<memref::StoreOp>(d0, memref, ValueRange{idx, idy0});
1569  b.create<memref::StoreOp>(d1, memref, ValueRange{idx, idy1});
1570  };
1571 
1572  Value tidx = b.create<NVVM::ThreadIdXOp>(i32);
1573  Value laneId = b.create<LLVM::URemOp>(i32, tidx, warpSize);
1574  Value warpId = b.create<LLVM::UDivOp>(i32, tidx, warpSize);
1575  Value lane4Id = b.create<LLVM::UDivOp>(i32, laneId, c4);
1576  Value lane4modId = b.create<LLVM::URemOp>(i32, laneId, c4);
1577 
1578  Value tj = makeMul(lane4modId, c2);
1579  Value ti = makeAdd(lane4Id, makeMul(warpId, c16));
1580  if (offset)
1581  ti = makeAdd(ti, makeConst(offset));
1582 
1583  auto structType = cast<LLVM::LLVMStructType>(matrixD.getType());
1584 
1585  // Number of adjacent 32-bit registers a thread owns per 8x8 tile
1586  constexpr unsigned numAdjacentRegisters = 2;
1587  // Number of 8x8 matrices stacked on top of each other per warp
1588  constexpr unsigned numStackedMatrices = 2;
1589 
1590  size_t storeCount = (structType.getBody().size() /
1591  (numStackedMatrices * numAdjacentRegisters));
1592 
1593  for (size_t i = 0; i < numStackedMatrices; ++i) {
1594  Value idx = makeAdd(ti, makeMul(makeConst(i), c8));
1595  for (size_t j = 0; j < storeCount; ++j) {
1596  Value idy = makeAdd(tj, makeMul(makeConst(j), c8));
1597  size_t structIndex = (i * numAdjacentRegisters) +
1598  (j * (numStackedMatrices * numAdjacentRegisters));
1599  makeExtractAndStore(structIndex, matrixD, idx, idy, dstMemref);
1600  }
1601  }
1602  }
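// A minimal standalone sketch (not part of the lowering above) of the index
// arithmetic in storeFragmentedMatrix: where a thread writes its first register
// pair (d0, d1) when offset == 0.
std::pair<int, int> exampleFirstStoreCoord(int tidx) const {
  int laneId = tidx % 32, warpId = tidx / 32; // kWarpSize == 32
  int lane4Id = laneId / 4, lane4modId = laneId % 4;
  int ti = lane4Id + warpId * 16; // row
  int tj = lane4modId * 2;        // column of d0; d1 lands at column tj + 1
  return {ti, tj};
}
// exampleFirstStoreCoord(37) == {17, 2}: thread 37 (lane 5 of warp 1) stores
// d0/d1 into MatrixD[17][2] and MatrixD[17][3], matching the layout drawn above.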
1603 
1604  LogicalResult
1605  matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor,
1606  ConversionPatternRewriter &rewriter) const override {
1607  int offset = 0;
1608  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1609  Value matrixDValue = adaptor.getMatrixD();
1610  auto stype = cast<LLVM::LLVMStructType>(matrixDValue.getType());
1611  for (auto [idx, matrixD] : llvm::enumerate(stype.getBody())) {
1612  auto structType = cast<LLVM::LLVMStructType>(matrixD);
1613  Value innerStructValue = b.create<LLVM::ExtractValueOp>(matrixDValue, idx);
1614  storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset);
1615  offset += structType.getBody().size();
1616  }
1617  rewriter.eraseOp(op);
1618  return success();
1619  }
1620 };
1621 
1622 struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
1623  : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaInitAccumulatorOp> {
1624  using ConvertOpToLLVMPattern<
1625  nvgpu::WarpgroupMmaInitAccumulatorOp>::ConvertOpToLLVMPattern;
1626  LogicalResult
1627  matchAndRewrite(nvgpu::WarpgroupMmaInitAccumulatorOp op, OpAdaptor adaptor,
1628  ConversionPatternRewriter &rewriter) const override {
1629  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1630  LLVM::LLVMStructType packStructType = cast<LLVM::LLVMStructType>(
1631  getTypeConverter()->convertType(op.getMatrixC().getType()));
1632  Type elemType = cast<LLVM::LLVMStructType>(packStructType.getBody().front())
1633  .getBody()
1634  .front();
1635  Value zero = b.create<LLVM::ConstantOp>(elemType, b.getZeroAttr(elemType));
1636  Value packStruct = b.create<LLVM::UndefOp>(packStructType);
1637  SmallVector<Value> innerStructs;
1638  // Unpack the structs and set all values to zero
1639  for (auto [idx, s] : llvm::enumerate(packStructType.getBody())) {
1640  auto structType = cast<LLVM::LLVMStructType>(s);
1641  Value structValue = b.create<LLVM::ExtractValueOp>(packStruct, idx);
1642  for (unsigned i = 0; i < structType.getBody().size(); ++i) {
1643  structValue = b.create<LLVM::InsertValueOp>(
1644  structType, structValue, zero, ArrayRef<int64_t>({i}));
1645  }
1646  innerStructs.push_back(structValue);
1647  }
1648  // Pack the inner structs into a single struct
1649  for (auto [idx, matrix] : llvm::enumerate(innerStructs)) {
1650  packStruct = b.create<LLVM::InsertValueOp>(packStruct.getType(),
1651  packStruct, matrix, idx);
1652  }
1653  rewriter.replaceOp(op, packStruct);
1654  return success();
1655  }
1656 };
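// A minimal plain-C++ analogy (not MLIR) of the unpack/zero/repack above: the
// accumulator is a struct of inner structs and every scalar element is set to
// zero; the 2 x 64 shape is an illustrative assumption.
inline void exampleZeroAccumulator(float (&packed)[2][64]) {
  for (auto &innerStruct : packed)  // one row per inner LLVM struct
    for (float &reg : innerStruct)  // one element per 32-bit register
      reg = 0.0f;
}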
1657 
1658 struct NVGPUTmaPrefetchOpLowering
1659  : public ConvertOpToLLVMPattern<nvgpu::TmaPrefetchOp> {
1660  using ConvertOpToLLVMPattern<nvgpu::TmaPrefetchOp>::ConvertOpToLLVMPattern;
1661  LogicalResult
1662  matchAndRewrite(nvgpu::TmaPrefetchOp op, OpAdaptor adaptor,
1663  ConversionPatternRewriter &rewriter) const override {
1664  rewriter.replaceOpWithNewOp<NVVM::PrefetchTensorMapOp>(
1665  op, adaptor.getTensorMapDescriptor(), adaptor.getPredicate());
1666  return success();
1667  }
1668 };
1669 
1670 struct NVGPURcpOpLowering : public ConvertOpToLLVMPattern<nvgpu::RcpOp> {
1671  using ConvertOpToLLVMPattern<nvgpu::RcpOp>::ConvertOpToLLVMPattern;
1672  LogicalResult
1673  matchAndRewrite(nvgpu::RcpOp op, OpAdaptor adaptor,
1674  ConversionPatternRewriter &rewriter) const override {
1675  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1676  auto i64Ty = b.getI64Type();
1677  auto f32Ty = b.getF32Type();
1678  VectorType inTy = op.getIn().getType();
1679  // Apply rcp.approx.ftz.f32 to each element of the vector.
1680  auto convert1DVec = [&](Type llvm1DVectorTy, Value inVec) {
1681  Value ret1DVec = b.create<LLVM::UndefOp>(llvm1DVectorTy);
1682  int numElems = llvm::cast<VectorType>(llvm1DVectorTy).getNumElements();
1683  for (int i = 0; i < numElems; i++) {
1684  Value idx = b.create<LLVM::ConstantOp>(i64Ty, b.getI64IntegerAttr(i));
1685  Value elem = b.create<LLVM::ExtractElementOp>(inVec, idx);
1686  Value dst = b.create<NVVM::RcpApproxFtzF32Op>(f32Ty, elem);
1687  ret1DVec = b.create<LLVM::InsertElementOp>(ret1DVec, dst, idx);
1688  }
1689  return ret1DVec;
1690  };
1691  if (inTy.getRank() == 1) {
1692  rewriter.replaceOp(op, convert1DVec(inTy, adaptor.getIn()));
1693  return success();
1694  }
1695  return LLVM::detail::handleMultidimensionalVectors(
1696  op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()),
1697  [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
1698  OpAdaptor adaptor(operands);
1699  return convert1DVec(llvm1DVectorTy, adaptor.getIn());
1700  },
1701  rewriter);
1702  }
1703 };
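// A minimal plain-C++ analogy (not the generated NVVM IR) of the element-wise
// expansion performed by convert1DVec above: every vector lane goes through
// rcp.approx.ftz.f32 independently; 1.0f / x stands in here for the hardware
// approximation.
inline void exampleRcpApprox1D(float *vec, int numElems) {
  for (int i = 0; i < numElems; ++i)
    vec[i] = 1.0f / vec[i]; // NVVM::RcpApproxFtzF32Op applied to element i
}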
1704 } // namespace
1705 
1706 void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
1707  RewritePatternSet &patterns) {
1708  patterns.add<
1709  NVGPUMBarrierCreateLowering, // nvgpu.mbarrier.create
1710  NVGPUMBarrierInitLowering, // nvgpu.mbarrier.init
1711  NVGPUMBarrierArriveLowering, // nvgpu.mbarrier.arrive
1712  NVGPUMBarrierArriveNoCompleteLowering, // nvgpu.mbarrier.arrive.no_complete
1713  NVGPUMBarrierTestWaitLowering, // nvgpu.mbarrier.test_wait_parity
1714  NVGPUMBarrierTryWaitParityLowering, // nvgpu.mbarrier.try_wait_parity
1715  NVGPUTmaAsyncLoadOpLowering, // nvgpu.tma.async.load
1716  NVGPUTmaAsyncStoreOpLowering, // nvgpu.tma.async.store
1717  NVGPUTmaCreateDescriptorOpLowering, // nvgpu.tma.create.descriptor
1718  NVGPUTmaPrefetchOpLowering, // nvgpu.tma.prefetch.descriptor
1719  NVGPUMBarrierArriveExpectTxLowering, // nvgpu.mbarrier.arrive.expect_tx
1720  NVGPUGenerateWarpgroupDescriptorLowering, // nvgpu.warpgroup.generate.descriptor
1721  NVGPUWarpgroupMmaOpLowering, // nvgpu.warpgroup.mma
1722  NVGPUWarpgroupMmaStoreOpLowering, // nvgpu.warpgroup.mma.store
1723  NVGPUWarpgroupMmaInitAccumulatorOpLowering, // nvgpu.warpgroup.mma.init.accumulator
1724  MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
1725  NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
1726  NVGPUMmaSparseSyncLowering, NVGPURcpOpLowering>(converter);
1727 }
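// A minimal sketch of how these patterns are typically consumed from a
// conversion pass; the function name and target configuration below are
// illustrative assumptions rather than the pass defined in this file.
static void exampleRunNVGPUToNVVM(Operation *root, MLIRContext *ctx) {
  LowerToLLVMOptions options(ctx);
  LLVMTypeConverter converter(ctx, options);
  RewritePatternSet patterns(ctx);
  populateNVGPUToNVVMConversionPatterns(converter, patterns);
  LLVMConversionTarget target(*ctx);
  target.addLegalDialect<NVVM::NVVMDialect>();
  if (failed(applyPartialConversion(root, target, std::move(patterns))))
    root->emitError("failed to convert nvgpu to nvvm");
}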