//===- LowerGpuOpsToNVVMOps.cpp - MLIR GPU to NVVM lowering passes --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to generate NVVMIR operations for higher-level
// GPU operations.
//
//===----------------------------------------------------------------------===//

#include "../GPUCommon/GPUOpsLowering.h"
#include "../GPUCommon/IndexIntrinsicsOpLowering.h"
#include "../GPUCommon/OpToFuncCallLowering.h"
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_CONVERTGPUOPSTONVVMOPS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {

/// Convert gpu dialect shfl mode enum to the equivalent nvvm one.
static NVVM::ShflKind convertShflKind(gpu::ShuffleMode mode) {
  switch (mode) {
  case gpu::ShuffleMode::XOR:
    return NVVM::ShflKind::bfly;
  case gpu::ShuffleMode::UP:
    return NVVM::ShflKind::up;
  case gpu::ShuffleMode::DOWN:
    return NVVM::ShflKind::down;
  case gpu::ShuffleMode::IDX:
    return NVVM::ShflKind::idx;
  }
  llvm_unreachable("unknown shuffle mode");
}

static std::optional<NVVM::ReduxKind>
convertReduxKind(gpu::AllReduceOperation mode) {
  switch (mode) {
  case gpu::AllReduceOperation::ADD:
    return NVVM::ReduxKind::ADD;
  case gpu::AllReduceOperation::MUL:
    return std::nullopt;
  case gpu::AllReduceOperation::MINSI:
    return NVVM::ReduxKind::MIN;
  case gpu::AllReduceOperation::MINUI:
    return std::nullopt;
  case gpu::AllReduceOperation::MINNUMF:
    return NVVM::ReduxKind::MIN;
  case gpu::AllReduceOperation::MAXSI:
    return NVVM::ReduxKind::MAX;
  case gpu::AllReduceOperation::MAXUI:
    return std::nullopt;
  case gpu::AllReduceOperation::MAXNUMF:
    return NVVM::ReduxKind::MAX;
  case gpu::AllReduceOperation::AND:
    return NVVM::ReduxKind::AND;
  case gpu::AllReduceOperation::OR:
    return NVVM::ReduxKind::OR;
  case gpu::AllReduceOperation::XOR:
    return NVVM::ReduxKind::XOR;
  case gpu::AllReduceOperation::MINIMUMF:
  case gpu::AllReduceOperation::MAXIMUMF:
    return std::nullopt;
  }
  return std::nullopt;
}

/// This pattern lowers the gpu.subgroup_reduce op into the nvvm.redux op. The
/// op must be run by the entire subgroup, otherwise it is undefined behaviour.
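///
/// As a rough sketch (the textual form of the ops below is approximate and not
/// output captured from this pattern), a uniform 32-bit integer reduction such
/// as
///
///   %sum = gpu.subgroup_reduce add %x uniform : (i32) -> i32
///
/// becomes a redux intrinsic whose membermask is the all-lanes constant -1,
/// matching the `offset` constant created below:
///
///   %mask = llvm.mlir.constant(-1 : i32) : i32
///   %sum  = nvvm.redux.sync add %x, %mask : i32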
struct GPUSubgroupReduceOpLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupReduceOp> {
  using ConvertOpToLLVMPattern<gpu::SubgroupReduceOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getClusterSize())
      return rewriter.notifyMatchFailure(
          op, "lowering for clustered reduce not implemented");

    if (!op.getUniform())
      return rewriter.notifyMatchFailure(
          op, "cannot be lowered to redux as the op must be run "
              "uniformly (entire subgroup).");
    if (!op.getValue().getType().isInteger(32))
      return rewriter.notifyMatchFailure(op, "unsupported data type");

    std::optional<NVVM::ReduxKind> mode = convertReduxKind(op.getOp());
    if (!mode.has_value())
      return rewriter.notifyMatchFailure(
          op, "unsupported reduction mode for redux");

    Location loc = op->getLoc();
    auto int32Type = IntegerType::get(rewriter.getContext(), 32);
    Value offset = LLVM::ConstantOp::create(rewriter, loc, int32Type, -1);

    auto reduxOp = NVVM::ReduxOp::create(rewriter, loc, int32Type,
                                         op.getValue(), mode.value(), offset);

    rewriter.replaceOp(op, reduxOp->getResult(0));
    return success();
  }
};

struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
  using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern;

  /// Lowers a shuffle to the corresponding NVVM op.
  ///
  /// Convert the `width` argument into an activeMask (a bitmask which specifies
  /// which threads participate in the shuffle) and a maskAndClamp (specifying
  /// the highest lane which participates in the shuffle).
  ///
  ///   %one = llvm.constant(1 : i32) : i32
  ///   %minus_one = llvm.constant(-1 : i32) : i32
  ///   %thirty_two = llvm.constant(32 : i32) : i32
  ///   %num_lanes = llvm.sub %thirty_two, %width : i32
  ///   %active_mask = llvm.lshr %minus_one, %num_lanes : i32
  ///   %mask_and_clamp = llvm.sub %width, %one : i32
  ///   %shfl = nvvm.shfl.sync.bfly %active_mask, %value, %offset,
  ///       %mask_and_clamp : !llvm<"{ float, i1 }">
  ///   %shfl_value = llvm.extractvalue %shfl[0] : !llvm<"{ float, i1 }">
  ///   %shfl_pred = llvm.extractvalue %shfl[1] : !llvm<"{ float, i1 }">
  LogicalResult
  matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();

    auto valueTy = adaptor.getValue().getType();
    auto int32Type = IntegerType::get(rewriter.getContext(), 32);
    auto predTy = IntegerType::get(rewriter.getContext(), 1);

    Value one = LLVM::ConstantOp::create(rewriter, loc, int32Type, 1);
    Value minusOne = LLVM::ConstantOp::create(rewriter, loc, int32Type, -1);
    Value thirtyTwo = LLVM::ConstantOp::create(rewriter, loc, int32Type, 32);
    Value numLeadInactiveLane = LLVM::SubOp::create(
        rewriter, loc, int32Type, thirtyTwo, adaptor.getWidth());
    // Bit mask of active lanes: `(-1) >> (32 - activeWidth)`.
    Value activeMask = LLVM::LShrOp::create(rewriter, loc, int32Type, minusOne,
                                            numLeadInactiveLane);
    Value maskAndClamp;
    if (op.getMode() == gpu::ShuffleMode::UP) {
      // Clamp lane: `32 - activeWidth`
      maskAndClamp = numLeadInactiveLane;
    } else {
      // Clamp lane: `activeWidth - 1`
      maskAndClamp = LLVM::SubOp::create(rewriter, loc, int32Type,
                                         adaptor.getWidth(), one);
    }

    bool predIsUsed = !op->getResult(1).use_empty();
    UnitAttr returnValueAndIsValidAttr = nullptr;
    Type resultTy = valueTy;
    if (predIsUsed) {
      returnValueAndIsValidAttr = rewriter.getUnitAttr();
      resultTy = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
                                                  {valueTy, predTy});
    }
    Value shfl = NVVM::ShflOp::create(
        rewriter, loc, resultTy, activeMask, adaptor.getValue(),
        adaptor.getOffset(), maskAndClamp, convertShflKind(op.getMode()),
        returnValueAndIsValidAttr);
    if (predIsUsed) {
      Value shflValue = LLVM::ExtractValueOp::create(rewriter, loc, shfl, 0);
      Value isActiveSrcLane =
          LLVM::ExtractValueOp::create(rewriter, loc, shfl, 1);
      rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
    } else {
      rewriter.replaceOp(op, {shfl, nullptr});
    }
    return success();
  }
};

struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
  using ConvertOpToLLVMPattern<gpu::LaneIdOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::LaneIdOp op, gpu::LaneIdOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    MLIRContext *context = rewriter.getContext();
    LLVM::ConstantRangeAttr bounds = nullptr;
    if (std::optional<APInt> upperBound = op.getUpperBound())
      bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
          /*bitWidth=*/32, /*lower=*/0, upperBound->getZExtValue());
    else
      bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
          /*bitWidth=*/32, /*lower=*/0, /*upper=*/kWarpSize);
    Value newOp =
        NVVM::LaneIdOp::create(rewriter, loc, rewriter.getI32Type(), bounds);
    // Truncate or extend the result depending on the index bitwidth specified
    // by the LLVMTypeConverter options.
    const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
    if (indexBitwidth > 32) {
      newOp = LLVM::SExtOp::create(
          rewriter, loc, IntegerType::get(context, indexBitwidth), newOp);
    } else if (indexBitwidth < 32) {
      newOp = LLVM::TruncOp::create(
          rewriter, loc, IntegerType::get(context, indexBitwidth), newOp);
    }
    rewriter.replaceOp(op, {newOp});
    return success();
  }
};

/// Lowering of cf.assert into a conditional __assertfail.
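///
/// For orientation, the call that is eventually emitted has the shape of the
/// CUDA device-side assert handler (the argument names here are illustrative,
/// taken from the values built below):
///
///   __assertfail(message, file, line, function, /*charSize=*/1)
///
/// and it is guarded by a conditional branch so the noreturn call is only
/// reached when the assert condition is false.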
struct AssertOpToAssertfailLowering
    : public ConvertOpToLLVMPattern<cf::AssertOp> {
  using ConvertOpToLLVMPattern<cf::AssertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(cf::AssertOp assertOp, cf::AssertOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MLIRContext *ctx = rewriter.getContext();
    Location loc = assertOp.getLoc();
    Type i8Type = typeConverter->convertType(rewriter.getIntegerType(8));
    Type i32Type = typeConverter->convertType(rewriter.getIntegerType(32));
    Type i64Type = typeConverter->convertType(rewriter.getIntegerType(64));
    Type ptrType = LLVM::LLVMPointerType::get(ctx);
    Type voidType = LLVM::LLVMVoidType::get(ctx);

    // Find or create __assertfail function declaration.
    auto moduleOp = assertOp->getParentOfType<gpu::GPUModuleOp>();
    auto assertfailType = LLVM::LLVMFunctionType::get(
        voidType, {ptrType, ptrType, i32Type, ptrType, i64Type});
    LLVM::LLVMFuncOp assertfailDecl = getOrDefineFunction(
        moduleOp, loc, rewriter, "__assertfail", assertfailType);
    assertfailDecl.setPassthroughAttr(
        ArrayAttr::get(ctx, StringAttr::get(ctx, "noreturn")));

    // Split blocks and insert conditional branch.
    // ^before:
    //   ...
    //   cf.cond_br %condition, ^after, ^assert
    // ^assert:
    //   cf.assert
    //   cf.br ^after
    // ^after:
    //   ...
    Block *beforeBlock = assertOp->getBlock();
    Block *assertBlock =
        rewriter.splitBlock(beforeBlock, assertOp->getIterator());
    Block *afterBlock =
        rewriter.splitBlock(assertBlock, ++assertOp->getIterator());
    rewriter.setInsertionPointToEnd(beforeBlock);
    cf::CondBranchOp::create(rewriter, loc, adaptor.getArg(), afterBlock,
                             assertBlock);
    rewriter.setInsertionPointToEnd(assertBlock);
    cf::BranchOp::create(rewriter, loc, afterBlock);

    // Continue cf.assert lowering.
    rewriter.setInsertionPoint(assertOp);

    // Populate file name, file line and function name from the location of
    // the AssertOp.
    StringRef fileName = "(unknown)";
    StringRef funcName = "(unknown)";
    int32_t fileLine = 0;
    while (auto callSiteLoc = dyn_cast<CallSiteLoc>(loc))
      loc = callSiteLoc.getCallee();
    if (auto fileLineColLoc = dyn_cast<FileLineColRange>(loc)) {
      fileName = fileLineColLoc.getFilename().strref();
      fileLine = fileLineColLoc.getStartLine();
    } else if (auto nameLoc = dyn_cast<NameLoc>(loc)) {
      funcName = nameLoc.getName().strref();
      if (auto fileLineColLoc =
              dyn_cast<FileLineColRange>(nameLoc.getChildLoc())) {
        fileName = fileLineColLoc.getFilename().strref();
        fileLine = fileLineColLoc.getStartLine();
      }
    }

    // Create constants.
    auto getGlobal = [&](LLVM::GlobalOp global) {
      // Get a pointer to the format string's first element.
      Value globalPtr = LLVM::AddressOfOp::create(
          rewriter, loc,
          LLVM::LLVMPointerType::get(ctx, global.getAddrSpace()),
          global.getSymNameAttr());
      Value start =
          LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(),
                              globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
      return start;
    };
    Value assertMessage = getGlobal(getOrCreateStringConstant(
        rewriter, loc, moduleOp, i8Type, "assert_message_", assertOp.getMsg()));
    Value assertFile = getGlobal(getOrCreateStringConstant(
        rewriter, loc, moduleOp, i8Type, "assert_file_", fileName));
    Value assertFunc = getGlobal(getOrCreateStringConstant(
        rewriter, loc, moduleOp, i8Type, "assert_func_", funcName));
    Value assertLine =
        LLVM::ConstantOp::create(rewriter, loc, i32Type, fileLine);
    Value c1 = LLVM::ConstantOp::create(rewriter, loc, i64Type, 1);

    // Insert function call to __assertfail.
    SmallVector<Value> arguments{assertMessage, assertFile, assertLine,
                                 assertFunc, c1};
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(assertOp, assertfailDecl,
                                              arguments);
    return success();
  }
};

/// Import the GPU Ops to NVVM Patterns.
#include "GPUToNVVM.cpp.inc"

/// A pass that replaces all occurrences of GPU device operations with their
/// corresponding NVVM equivalent.
///
/// This pass only handles device code and is not meant to be run on GPU host
/// code.
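///
/// A typical invocation from the command line (illustrative; the option names
/// correspond to the pass options used in runOnOperation below, e.g.
/// indexBitwidth, hasRedux, useBarePtrCallConv, allowedDialects) would be:
///
///   mlir-opt --convert-gpu-to-nvvm='index-bitwidth=32 has-redux=1' in.mlir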
struct LowerGpuOpsToNVVMOpsPass final
    : public impl::ConvertGpuOpsToNVVMOpsBase<LowerGpuOpsToNVVMOpsPass> {
  using Base::Base;

  void getDependentDialects(DialectRegistry &registry) const override {
    Base::getDependentDialects(registry);
    registerConvertToLLVMDependentDialectLoading(registry);
  }

  void runOnOperation() override {
    gpu::GPUModuleOp m = getOperation();

    // Request C wrapper emission.
    for (auto func : m.getOps<func::FuncOp>()) {
      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                    UnitAttr::get(&getContext()));
    }

    // Customize the bitwidth used for the device side index computations.
    LowerToLLVMOptions options(
        m.getContext(),
        DataLayout(cast<DataLayoutOpInterface>(m.getOperation())));
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);
    options.useBarePtrCallConv = useBarePtrCallConv;

    // Apply in-dialect lowering first. In-dialect lowering will replace ops
    // which need to be lowered further, which is not supported by a single
    // conversion pass.
    {
      RewritePatternSet patterns(m.getContext());
      populateGpuRewritePatterns(patterns);
      // Transform N-D vector.from_elements to 1-D vector.from_elements before
      // conversion.
      vector::populateVectorFromElementsUnrollPatterns(patterns);
      if (failed(applyPatternsGreedily(m, std::move(patterns))))
        return signalPassFailure();
    }

    LLVMTypeConverter converter(m.getContext(), options);
    configureGpuToNVVMTypeConverter(converter);
    RewritePatternSet llvmPatterns(m.getContext());
    LLVMConversionTarget target(getContext());

    // Set higher benefit, so patterns will run before generic LLVM lowering.
    populateGpuToNVVMConversionPatterns(converter, llvmPatterns,
                                        /*benefit=*/10);

    llvm::SmallDenseSet<StringRef> allowedDialectsSet(allowedDialects.begin(),
                                                      allowedDialects.end());
    for (Dialect *dialect : getContext().getLoadedDialects()) {
      // Skip math patterns as nvvm needs custom math lowering.
      if (isa<math::MathDialect>(dialect))
        continue;

      bool allowed = allowedDialectsSet.contains(dialect->getNamespace());
      // Empty `allowedDialectsSet` means all dialects are allowed.
      if (!allowedDialectsSet.empty() && !allowed)
        continue;

      auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
      if (!iface) {
        // Error out if the dialect was explicitly specified but doesn't
        // implement the conversion interface.
        if (allowed) {
          m.emitError()
              << "dialect does not implement ConvertToLLVMPatternInterface: "
              << dialect->getNamespace();
          return signalPassFailure();
        }
        continue;
      }

      iface->populateConvertToLLVMConversionPatterns(target, converter,
                                                     llvmPatterns);
    }

    populateGpuWMMAToNVVMConversionPatterns(converter, llvmPatterns);
    if (this->hasRedux)
      populateGpuSubgroupReduceOpLoweringPattern(converter, llvmPatterns);
    configureGpuToNVVMConversionLegality(target);
    if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
      signalPassFailure();
  }
};

} // namespace

void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
  target.addIllegalOp<func::FuncOp>();
  target.addIllegalOp<cf::AssertOp>();
  target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
  target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
  target.addIllegalDialect<gpu::GPUDialect>();
  target.addIllegalOp<LLVM::CopySignOp, LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op,
                      LLVM::FAbsOp, LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp,
                      LLVM::LogOp, LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp,
                      LLVM::RoundEvenOp, LLVM::RoundOp, LLVM::SinOp,
                      LLVM::SqrtOp>();

  // TODO: Remove once we support replacing non-root ops.
  target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp>();
}

void mlir::configureGpuToNVVMTypeConverter(LLVMTypeConverter &converter) {
  // NVVM uses alloca in the default address space to represent private
  // memory allocations, so drop private annotations. NVVM uses address
  // space 3 for shared memory. NVVM uses the default address space to
  // represent global memory.
  populateGpuMemorySpaceAttributeConversions(
      converter, [](gpu::AddressSpace space) -> unsigned {
        switch (space) {
        case gpu::AddressSpace::Global:
          return static_cast<unsigned>(NVVM::NVVMMemorySpace::Global);
        case gpu::AddressSpace::Workgroup:
          return static_cast<unsigned>(NVVM::NVVMMemorySpace::Shared);
        case gpu::AddressSpace::Private:
          return 0;
        }
        llvm_unreachable("unknown address space enum value");
        return static_cast<unsigned>(NVVM::NVVMMemorySpace::Generic);
      });
  // Lowering for MMAMatrixType.
  converter.addConversion([&](gpu::MMAMatrixType type) -> Type {
    return convertMMAToLLVMType(type);
  });
}

template <typename OpTy>
static void populateOpPatterns(const LLVMTypeConverter &converter,
                               RewritePatternSet &patterns,
                               PatternBenefit benefit, StringRef f32Func,
                               StringRef f64Func, StringRef f32ApproxFunc = "",
                               StringRef f16Func = "") {
  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit);
  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
                                           f32ApproxFunc, f16Func,
                                           /*i32Func=*/"", benefit);
}
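
// For illustration only (the printed IR shapes below are approximate and not
// captured from these patterns): with f32Func = "__nv_expf" and
// f64Func = "__nv_exp", a scalar op such as
//
//   %y = math.exp %x : f32
//
// is rewritten by OpToFuncCallLowering into a call to the libdevice function
//
//   %y = llvm.call @__nv_expf(%x) : (f32) -> f32
//
// while ScalarizeVectorOpLowering first unrolls vector operands so that each
// element can take this scalar path.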

template <typename OpTy>
static void populateIntOpPatterns(const LLVMTypeConverter &converter,
                                  RewritePatternSet &patterns,
                                  PatternBenefit benefit, StringRef i32Func) {
  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit);
  patterns.add<OpToFuncCallLowering<OpTy>>(converter, "", "", "", "", i32Func,
                                           benefit);
}

template <typename OpTy>
static void populateFloatIntOpPatterns(const LLVMTypeConverter &converter,
                                       RewritePatternSet &patterns,
                                       PatternBenefit benefit,
                                       StringRef f32Func, StringRef f64Func) {
  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit);
  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func, "", "",
                                           /*i32Func=*/"", benefit);
}

void mlir::populateGpuSubgroupReduceOpLoweringPattern(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
    PatternBenefit benefit) {
  patterns.add<GPUSubgroupReduceOpLowering>(converter, benefit);
}

void mlir::populateLibDeviceConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
    PatternBenefit benefit) {
  populateOpPatterns<arith::RemFOp>(converter, patterns, benefit, "__nv_fmodf",
                                    "__nv_fmod");
  populateOpPatterns<arith::MaxNumFOp>(converter, patterns, benefit,
                                       "__nv_fmaxf", "__nv_fmax");
  populateOpPatterns<arith::MinNumFOp>(converter, patterns, benefit,
                                       "__nv_fminf", "__nv_fmin");

  populateIntOpPatterns<math::AbsIOp>(converter, patterns, benefit, "__nv_abs");
  populateOpPatterns<math::AbsFOp>(converter, patterns, benefit, "__nv_fabsf",
                                   "__nv_fabs");
  populateOpPatterns<math::AcosOp>(converter, patterns, benefit, "__nv_acosf",
                                   "__nv_acos");
  populateOpPatterns<math::AcoshOp>(converter, patterns, benefit, "__nv_acoshf",
                                    "__nv_acosh");
  populateOpPatterns<math::AsinOp>(converter, patterns, benefit, "__nv_asinf",
                                   "__nv_asin");
  populateOpPatterns<math::AsinhOp>(converter, patterns, benefit, "__nv_asinhf",
                                    "__nv_asinh");
  populateOpPatterns<math::AtanOp>(converter, patterns, benefit, "__nv_atanf",
                                   "__nv_atan");
  populateOpPatterns<math::Atan2Op>(converter, patterns, benefit, "__nv_atan2f",
                                    "__nv_atan2");
  populateOpPatterns<math::AtanhOp>(converter, patterns, benefit, "__nv_atanhf",
                                    "__nv_atanh");
  populateOpPatterns<math::CbrtOp>(converter, patterns, benefit, "__nv_cbrtf",
                                   "__nv_cbrt");
  populateOpPatterns<math::CeilOp>(converter, patterns, benefit, "__nv_ceilf",
                                   "__nv_ceil");
  populateOpPatterns<math::CopySignOp>(converter, patterns, benefit,
                                       "__nv_copysignf", "__nv_copysign");
  populateOpPatterns<math::CosOp>(converter, patterns, benefit, "__nv_cosf",
                                  "__nv_cos", "__nv_fast_cosf");
  populateOpPatterns<math::CoshOp>(converter, patterns, benefit, "__nv_coshf",
                                   "__nv_cosh");
  populateOpPatterns<math::ErfOp>(converter, patterns, benefit, "__nv_erff",
                                  "__nv_erf");
  populateOpPatterns<math::ErfcOp>(converter, patterns, benefit, "__nv_erfcf",
                                   "__nv_erfc");
  populateOpPatterns<math::ExpOp>(converter, patterns, benefit, "__nv_expf",
                                  "__nv_exp", "__nv_fast_expf");
  populateOpPatterns<math::Exp2Op>(converter, patterns, benefit, "__nv_exp2f",
                                   "__nv_exp2");
  populateOpPatterns<math::ExpM1Op>(converter, patterns, benefit, "__nv_expm1f",
                                    "__nv_expm1");
  populateOpPatterns<math::FloorOp>(converter, patterns, benefit, "__nv_floorf",
                                    "__nv_floor");
  populateOpPatterns<math::FmaOp>(converter, patterns, benefit, "__nv_fmaf",
                                  "__nv_fma");
  // Note: libdevice uses a different name for 32-bit finite checking.
  populateOpPatterns<math::IsFiniteOp>(converter, patterns, benefit,
                                       "__nv_finitef", "__nv_isfinited");
  populateOpPatterns<math::IsInfOp>(converter, patterns, benefit, "__nv_isinff",
                                    "__nv_isinfd");
  populateOpPatterns<math::IsNaNOp>(converter, patterns, benefit, "__nv_isnanf",
                                    "__nv_isnand");
  populateOpPatterns<math::LogOp>(converter, patterns, benefit, "__nv_logf",
                                  "__nv_log", "__nv_fast_logf");
  populateOpPatterns<math::Log10Op>(converter, patterns, benefit, "__nv_log10f",
                                    "__nv_log10", "__nv_fast_log10f");
  populateOpPatterns<math::Log1pOp>(converter, patterns, benefit, "__nv_log1pf",
                                    "__nv_log1p");
  populateOpPatterns<math::Log2Op>(converter, patterns, benefit, "__nv_log2f",
                                   "__nv_log2", "__nv_fast_log2f");
  populateOpPatterns<math::PowFOp>(converter, patterns, benefit, "__nv_powf",
                                   "__nv_pow", "__nv_fast_powf");
  populateFloatIntOpPatterns<math::FPowIOp>(converter, patterns, benefit,
                                            "__nv_powif", "__nv_powi");
  populateOpPatterns<math::RoundOp>(converter, patterns, benefit, "__nv_roundf",
                                    "__nv_round");
  populateOpPatterns<math::RoundEvenOp>(converter, patterns, benefit,
                                        "__nv_rintf", "__nv_rint");
  populateOpPatterns<math::RsqrtOp>(converter, patterns, benefit, "__nv_rsqrtf",
                                    "__nv_rsqrt");
  populateOpPatterns<math::SinOp>(converter, patterns, benefit, "__nv_sinf",
                                  "__nv_sin", "__nv_fast_sinf");
  populateOpPatterns<math::SinhOp>(converter, patterns, benefit, "__nv_sinhf",
                                   "__nv_sinh");
  populateOpPatterns<math::SqrtOp>(converter, patterns, benefit, "__nv_sqrtf",
                                   "__nv_sqrt");
  populateOpPatterns<math::TanOp>(converter, patterns, benefit, "__nv_tanf",
                                  "__nv_tan", "__nv_fast_tanf");
  populateOpPatterns<math::TanhOp>(converter, patterns, benefit, "__nv_tanhf",
                                   "__nv_tanh");
}

void mlir::populateGpuToNVVMConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
    PatternBenefit benefit) {
  using gpu::index_lowering::IndexKind;
  using gpu::index_lowering::IntrType;

  // TODO: Pass benefit to generated patterns.
  populateWithGenerated(patterns);

  patterns.add<GPUPrintfOpToVPrintfLowering, AssertOpToAssertfailLowering>(
      converter, benefit);
  patterns.add<
      gpu::index_lowering::OpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp,
                                      NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>>(
      converter, IndexKind::Block, IntrType::Id, benefit);
  patterns.add<
      gpu::index_lowering::OpLowering<gpu::BlockDimOp, NVVM::BlockDimXOp,
                                      NVVM::BlockDimYOp, NVVM::BlockDimZOp>>(
      converter, IndexKind::Block, IntrType::Dim, benefit);
  patterns.add<
      gpu::index_lowering::OpLowering<gpu::ClusterIdOp, NVVM::ClusterIdXOp,
                                      NVVM::ClusterIdYOp, NVVM::ClusterIdZOp>>(
      converter, IndexKind::Other, IntrType::Id, benefit);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::ClusterDimOp, NVVM::ClusterDimXOp, NVVM::ClusterDimYOp,
      NVVM::ClusterDimZOp>>(converter, IndexKind::Other, IntrType::Dim,
                            benefit);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::ClusterBlockIdOp, NVVM::BlockInClusterIdXOp,
      NVVM::BlockInClusterIdYOp, NVVM::BlockInClusterIdZOp>>(
      converter, IndexKind::Other, IntrType::Id, benefit);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::ClusterDimBlocksOp, NVVM::ClusterDimBlocksXOp,
      NVVM::ClusterDimBlocksYOp, NVVM::ClusterDimBlocksZOp>>(
      converter, IndexKind::Other, IntrType::Dim, benefit);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::BlockIdOp, NVVM::BlockIdXOp, NVVM::BlockIdYOp, NVVM::BlockIdZOp>>(
      converter, IndexKind::Grid, IntrType::Id, benefit);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::GridDimOp, NVVM::GridDimXOp, NVVM::GridDimYOp, NVVM::GridDimZOp>>(
      converter, IndexKind::Grid, IntrType::Dim, benefit);
  patterns.add<GPULaneIdOpToNVVM, GPUShuffleOpLowering, GPUReturnOpLowering>(
      converter, benefit);

  patterns.add<GPUDynamicSharedMemoryOpLowering>(
      converter, NVVM::kSharedMemoryAlignmentBit, benefit);

  // Explicitly drop memory space when lowering private memory
  // attributions since NVVM models it as `alloca`s in the default
  // memory space and does not support `alloca`s with addrspace(5).
  patterns.add<GPUFuncOpLowering>(
      converter,
      GPUFuncOpLoweringOptions{
          /*allocaAddrSpace=*/0,
          /*workgroupAddrSpace=*/
          static_cast<unsigned>(NVVM::NVVMMemorySpace::Shared),
          StringAttr::get(&converter.getContext(),
                          NVVM::NVVMDialect::getKernelFuncAttrName()),
          StringAttr::get(&converter.getContext(),
                          NVVM::NVVMDialect::getMaxntidAttrName())},
      benefit);

  populateLibDeviceConversionPatterns(converter, patterns, benefit);
}

//===----------------------------------------------------------------------===//
// NVVMTargetAttr convert to LLVM attr interface
//===----------------------------------------------------------------------===//

namespace {
struct NVVMTargetConvertToLLVMAttrInterface
    : public ConvertToLLVMAttrInterface::ExternalModel<
          NVVMTargetConvertToLLVMAttrInterface, NVVM::NVVMTargetAttr> {
  /// Configure GPU to NVVM.
  void populateConvertToLLVMConversionPatterns(
      Attribute attr, ConversionTarget &target,
      LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) const;
};
} // namespace

void NVVMTargetConvertToLLVMAttrInterface::
    populateConvertToLLVMConversionPatterns(Attribute attr,
                                            ConversionTarget &target,
                                            LLVMTypeConverter &typeConverter,
                                            RewritePatternSet &patterns) const {
  configureGpuToNVVMConversionLegality(target);
  configureGpuToNVVMTypeConverter(typeConverter);
  populateGpuToNVVMConversionPatterns(typeConverter, patterns);
}

void mlir::NVVM::registerConvertGpuToNVVMInterface(DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, NVVMDialect *dialect) {
    NVVMTargetAttr::attachInterface<NVVMTargetConvertToLLVMAttrInterface>(*ctx);
  });
}