//===- LowerGpuOpsToNVVMOps.cpp - MLIR GPU to NVVM lowering passes --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to generate NVVMIR operations for higher-level
// GPU operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"

#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "../GPUCommon/GPUOpsLowering.h"
#include "../GPUCommon/IndexIntrinsicsOpLowering.h"
#include "../GPUCommon/OpToFuncCallLowering.h"
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_CONVERTGPUOPSTONVVMOPS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {

/// Convert gpu dialect shfl mode enum to the equivalent nvvm one.
static NVVM::ShflKind convertShflKind(gpu::ShuffleMode mode) {
  switch (mode) {
  case gpu::ShuffleMode::XOR:
    return NVVM::ShflKind::bfly;
  case gpu::ShuffleMode::UP:
    return NVVM::ShflKind::up;
  case gpu::ShuffleMode::DOWN:
    return NVVM::ShflKind::down;
  case gpu::ShuffleMode::IDX:
    return NVVM::ShflKind::idx;
  }
  llvm_unreachable("unknown shuffle mode");
}

static std::optional<NVVM::ReduxKind>
convertReduxKind(gpu::AllReduceOperation mode) {
  switch (mode) {
  case gpu::AllReduceOperation::ADD:
    return NVVM::ReduxKind::ADD;
  case gpu::AllReduceOperation::MUL:
    return std::nullopt;
  case gpu::AllReduceOperation::MINSI:
    return NVVM::ReduxKind::MIN;
  case gpu::AllReduceOperation::MINUI:
    return std::nullopt;
  case gpu::AllReduceOperation::MINNUMF:
    return NVVM::ReduxKind::MIN;
  case gpu::AllReduceOperation::MAXSI:
    return NVVM::ReduxKind::MAX;
  case gpu::AllReduceOperation::MAXUI:
    return std::nullopt;
  case gpu::AllReduceOperation::MAXNUMF:
    return NVVM::ReduxKind::MAX;
  case gpu::AllReduceOperation::AND:
    return NVVM::ReduxKind::AND;
  case gpu::AllReduceOperation::OR:
    return NVVM::ReduxKind::OR;
  case gpu::AllReduceOperation::XOR:
    return NVVM::ReduxKind::XOR;
  case gpu::AllReduceOperation::MINIMUMF:
  case gpu::AllReduceOperation::MAXIMUMF:
    return std::nullopt;
  }
  return std::nullopt;
}
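// Reductions with no redux.sync counterpart (mul, the unsigned min/max
// variants, and minimumf/maximumf) map to std::nullopt above;
// GPUSubgroupReduceOpLowering treats that as a match failure so another
// lowering can handle the op.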

/// Lowers the gpu.subgroup_reduce op to the nvvm.redux op. The op must be run
/// by the entire subgroup, otherwise the behaviour is undefined.
struct GPUSubgroupReduceOpLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupReduceOp> {
  using ConvertOpToLLVMPattern<gpu::SubgroupReduceOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getClusterSize())
      return rewriter.notifyMatchFailure(
          op, "lowering for clustered reduce not implemented");

    if (!op.getUniform())
      return rewriter.notifyMatchFailure(
          op, "cannot be lowered to redux as the op must be run "
              "uniformly (entire subgroup).");
    if (!op.getValue().getType().isInteger(32))
      return rewriter.notifyMatchFailure(op, "unsupported data type");

    std::optional<NVVM::ReduxKind> mode = convertReduxKind(op.getOp());
    if (!mode.has_value())
      return rewriter.notifyMatchFailure(
          op, "unsupported reduction mode for redux");

    Location loc = op->getLoc();
    auto int32Type = IntegerType::get(rewriter.getContext(), 32);
    Value offset = LLVM::ConstantOp::create(rewriter, loc, int32Type, -1);

    auto reduxOp = NVVM::ReduxOp::create(rewriter, loc, int32Type,
                                         op.getValue(), mode.value(), offset);

    rewriter.replaceOp(op, reduxOp->getResult(0));
    return success();
  }
};
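// For illustration, the pattern above turns a uniform 32-bit integer
// `gpu.subgroup_reduce add` into roughly the following (exact assembly syntax
// may differ):
//   %mask = llvm.mlir.constant(-1 : i32) : i32
//   %red  = nvvm.redux.sync add %value, %mask : i32 -> i32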

struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
  using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern;

  /// Lowers a shuffle to the corresponding NVVM op.
  ///
  /// Convert the `width` argument into an activeMask (a bitmask which specifies
  /// which threads participate in the shuffle) and a maskAndClamp (specifying
  /// the highest lane which participates in the shuffle).
  ///
  ///     %one = llvm.constant(1 : i32) : i32
  ///     %minus_one = llvm.constant(-1 : i32) : i32
  ///     %thirty_two = llvm.constant(32 : i32) : i32
  ///     %num_lanes = llvm.sub %thirty_two, %width : i32
  ///     %active_mask = llvm.lshr %minus_one, %num_lanes : i32
  ///     %mask_and_clamp = llvm.sub %width, %one : i32
  ///     %shfl = nvvm.shfl.sync.bfly %active_mask, %value, %offset,
  ///         %mask_and_clamp : !llvm<"{ float, i1 }">
  ///     %shfl_value = llvm.extractvalue %shfl[0] : !llvm<"{ float, i1 }">
  ///     %shfl_pred = llvm.extractvalue %shfl[1] : !llvm<"{ float, i1 }">
  LogicalResult
  matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();

    auto valueTy = adaptor.getValue().getType();
    auto int32Type = IntegerType::get(rewriter.getContext(), 32);
    auto predTy = IntegerType::get(rewriter.getContext(), 1);

    Value one = LLVM::ConstantOp::create(rewriter, loc, int32Type, 1);
    Value minusOne = LLVM::ConstantOp::create(rewriter, loc, int32Type, -1);
    Value thirtyTwo = LLVM::ConstantOp::create(rewriter, loc, int32Type, 32);
    Value numLeadInactiveLane = LLVM::SubOp::create(
        rewriter, loc, int32Type, thirtyTwo, adaptor.getWidth());
    // Bit mask of active lanes: `(-1) >> (32 - activeWidth)`.
    Value activeMask = LLVM::LShrOp::create(rewriter, loc, int32Type, minusOne,
                                            numLeadInactiveLane);
    Value maskAndClamp;
    if (op.getMode() == gpu::ShuffleMode::UP) {
      // Clamp lane: `32 - activeWidth`
      maskAndClamp = numLeadInactiveLane;
    } else {
      // Clamp lane: `activeWidth - 1`
      maskAndClamp = LLVM::SubOp::create(rewriter, loc, int32Type,
                                         adaptor.getWidth(), one);
    }

    bool predIsUsed = !op->getResult(1).use_empty();
    UnitAttr returnValueAndIsValidAttr = nullptr;
    Type resultTy = valueTy;
    if (predIsUsed) {
      returnValueAndIsValidAttr = rewriter.getUnitAttr();
      resultTy = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
                                                  {valueTy, predTy});
    }
    Value shfl = NVVM::ShflOp::create(
        rewriter, loc, resultTy, activeMask, adaptor.getValue(),
        adaptor.getOffset(), maskAndClamp, convertShflKind(op.getMode()),
        returnValueAndIsValidAttr);
    if (predIsUsed) {
      Value shflValue = LLVM::ExtractValueOp::create(rewriter, loc, shfl, 0);
      Value isActiveSrcLane =
          LLVM::ExtractValueOp::create(rewriter, loc, shfl, 1);
      rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
    } else {
      rewriter.replaceOp(op, {shfl, nullptr});
    }
    return success();
  }
};

struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
  using ConvertOpToLLVMPattern<gpu::LaneIdOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::LaneIdOp op, gpu::LaneIdOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    MLIRContext *context = rewriter.getContext();
    LLVM::ConstantRangeAttr bounds = nullptr;
    if (std::optional<APInt> upperBound = op.getUpperBound())
      bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
          /*bitWidth=*/32, /*lower=*/0, upperBound->getZExtValue());
    else
      bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
          /*bitWidth=*/32, /*lower=*/0, /*upper=*/kWarpSize);
    Value newOp =
        NVVM::LaneIdOp::create(rewriter, loc, rewriter.getI32Type(), bounds);
    // Truncate or extend the result depending on the index bitwidth specified
    // by the LLVMTypeConverter options.
    const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
    if (indexBitwidth > 32) {
      newOp = LLVM::SExtOp::create(
          rewriter, loc, IntegerType::get(context, indexBitwidth), newOp);
    } else if (indexBitwidth < 32) {
      newOp = LLVM::TruncOp::create(
          rewriter, loc, IntegerType::get(context, indexBitwidth), newOp);
    }
    rewriter.replaceOp(op, {newOp});
    return success();
  }
};
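// Sketch of the rewrite above for a 64-bit index bitwidth (syntax
// approximate):
//   %lane = nvvm.read.ptx.sreg.laneid range <i32, 0, 32> : i32
//   %idx  = llvm.sext %lane : i32 to i64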

/// Lowering of cf.assert into a conditional __assertfail.
struct AssertOpToAssertfailLowering
    : public ConvertOpToLLVMPattern<cf::AssertOp> {
  using ConvertOpToLLVMPattern<cf::AssertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(cf::AssertOp assertOp, cf::AssertOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MLIRContext *ctx = rewriter.getContext();
    Location loc = assertOp.getLoc();
    Type i8Type = typeConverter->convertType(rewriter.getIntegerType(8));
    Type i32Type = typeConverter->convertType(rewriter.getIntegerType(32));
    Type i64Type = typeConverter->convertType(rewriter.getIntegerType(64));
    Type ptrType = LLVM::LLVMPointerType::get(ctx);
    Type voidType = LLVM::LLVMVoidType::get(ctx);

    // Find or create the __assertfail function declaration.
    auto moduleOp = assertOp->getParentOfType<gpu::GPUModuleOp>();
    auto assertfailType = LLVM::LLVMFunctionType::get(
        voidType, {ptrType, ptrType, i32Type, ptrType, i64Type});
    LLVM::LLVMFuncOp assertfailDecl = getOrDefineFunction(
        moduleOp, loc, rewriter, "__assertfail", assertfailType);
    assertfailDecl.setPassthroughAttr(
        ArrayAttr::get(ctx, StringAttr::get(ctx, "noreturn")));

    // Split blocks and insert conditional branch.
    // ^before:
    //   ...
    //   cf.cond_br %condition, ^after, ^assert
    // ^assert:
    //   cf.assert
    //   cf.br ^after
    // ^after:
    //   ...
    Block *beforeBlock = assertOp->getBlock();
    Block *assertBlock =
        rewriter.splitBlock(beforeBlock, assertOp->getIterator());
    Block *afterBlock =
        rewriter.splitBlock(assertBlock, ++assertOp->getIterator());
    rewriter.setInsertionPointToEnd(beforeBlock);
    cf::CondBranchOp::create(rewriter, loc, adaptor.getArg(), afterBlock,
                             assertBlock);
    rewriter.setInsertionPointToEnd(assertBlock);
    cf::BranchOp::create(rewriter, loc, afterBlock);

    // Continue cf.assert lowering.
    rewriter.setInsertionPoint(assertOp);

    // Populate the file name, line number, and function name from the location
    // of the AssertOp.
    StringRef fileName = "(unknown)";
    StringRef funcName = "(unknown)";
    int32_t fileLine = 0;
    while (auto callSiteLoc = dyn_cast<CallSiteLoc>(loc))
      loc = callSiteLoc.getCallee();
    if (auto fileLineColLoc = dyn_cast<FileLineColRange>(loc)) {
      fileName = fileLineColLoc.getFilename().strref();
      fileLine = fileLineColLoc.getStartLine();
    } else if (auto nameLoc = dyn_cast<NameLoc>(loc)) {
      funcName = nameLoc.getName().strref();
      if (auto fileLineColLoc =
              dyn_cast<FileLineColRange>(nameLoc.getChildLoc())) {
        fileName = fileLineColLoc.getFilename().strref();
        fileLine = fileLineColLoc.getStartLine();
      }
    }

    // Create constants.
    auto getGlobal = [&](LLVM::GlobalOp global) {
      // Get a pointer to the format string's first element.
      Value globalPtr = LLVM::AddressOfOp::create(
          rewriter, loc,
          LLVM::LLVMPointerType::get(ctx, global.getAddrSpace()),
          global.getSymNameAttr());
      Value start =
          LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(),
                              globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
      return start;
    };
    Value assertMessage = getGlobal(getOrCreateStringConstant(
        rewriter, loc, moduleOp, i8Type, "assert_message_", assertOp.getMsg()));
    Value assertFile = getGlobal(getOrCreateStringConstant(
        rewriter, loc, moduleOp, i8Type, "assert_file_", fileName));
    Value assertFunc = getGlobal(getOrCreateStringConstant(
        rewriter, loc, moduleOp, i8Type, "assert_func_", funcName));
    Value assertLine =
        LLVM::ConstantOp::create(rewriter, loc, i32Type, fileLine);
    Value c1 = LLVM::ConstantOp::create(rewriter, loc, i64Type, 1);

    // Insert function call to __assertfail.
    SmallVector<Value> arguments{assertMessage, assertFile, assertLine,
                                 assertFunc, c1};
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(assertOp, assertfailDecl,
                                              arguments);
    return success();
  }
};
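// The assert block produced above then ends in a call along the lines of
// (illustrative):
//   llvm.call @__assertfail(%msg, %file, %line, %func, %char_size)
//       : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, i64) -> ()
// where %char_size is the i64 constant 1 created by the pattern.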

/// Import the GPU Ops to NVVM Patterns.
#include "GPUToNVVM.cpp.inc"

/// A pass that replaces all occurrences of GPU device operations with their
/// corresponding NVVM equivalent.
///
/// This pass only handles device code and is not meant to be run on GPU host
/// code.
struct LowerGpuOpsToNVVMOpsPass final
    : public impl::ConvertGpuOpsToNVVMOpsBase<LowerGpuOpsToNVVMOpsPass> {
  using Base::Base;

  void getDependentDialects(DialectRegistry &registry) const override {
    Base::getDependentDialects(registry);
    registerConvertToLLVMDependentDialectLoading(registry);
  }

  void runOnOperation() override {
    gpu::GPUModuleOp m = getOperation();

    // Request C wrapper emission.
    for (auto func : m.getOps<func::FuncOp>()) {
      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                    UnitAttr::get(&getContext()));
    }

    // Customize the bitwidth used for the device side index computations.
    LowerToLLVMOptions options(
        m.getContext(),
        DataLayout(cast<DataLayoutOpInterface>(m.getOperation())));
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);
    options.useBarePtrCallConv = useBarePtrCallConv;

    // Apply in-dialect lowering. In-dialect lowering will replace
    // ops which need to be lowered further, which is not supported by a
    // single conversion pass.
    {
      RewritePatternSet patterns(m.getContext());
      populateGpuRewritePatterns(patterns);
      // Transform N-D vector.from_elements to 1-D vector.from_elements before
      // conversion.
      vector::populateVectorFromElementsLoweringPatterns(patterns);
      if (failed(applyPatternsGreedily(m, std::move(patterns))))
        return signalPassFailure();
    }

    LLVMTypeConverter converter(m.getContext(), options);
    configureGpuToNVVMTypeConverter(converter);
    RewritePatternSet llvmPatterns(m.getContext());
    LLVMConversionTarget target(getContext());

    // Set higher benefit, so patterns will run before generic LLVM lowering.
    populateGpuToNVVMConversionPatterns(converter, llvmPatterns,
                                        /*benefit=*/10);

    llvm::SmallDenseSet<StringRef> allowedDialectsSet(allowedDialects.begin(),
                                                      allowedDialects.end());
    for (Dialect *dialect : getContext().getLoadedDialects()) {
      // Skip math patterns as nvvm needs custom math lowering.
      if (isa<math::MathDialect>(dialect))
        continue;

      bool allowed = allowedDialectsSet.contains(dialect->getNamespace());
      // Empty `allowedDialectsSet` means all dialects are allowed.
      if (!allowedDialectsSet.empty() && !allowed)
        continue;

      auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
      if (!iface) {
        // Error out if the dialect was explicitly specified but doesn't
        // implement the conversion interface.
        if (allowed) {
          m.emitError()
              << "dialect does not implement ConvertToLLVMPatternInterface: "
              << dialect->getNamespace();
          return signalPassFailure();
        }
        continue;
      }

      iface->populateConvertToLLVMConversionPatterns(target, converter,
                                                     llvmPatterns);
    }

    populateGpuWMMAToNVVMConversionPatterns(converter, llvmPatterns);
    if (this->hasRedux)
      populateGpuSubgroupReduceOpLoweringPattern(converter, llvmPatterns);
    configureGpuToNVVMConversionLegality(target);
    if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
      signalPassFailure();
  }
};
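// The pass is exposed as `convert-gpu-to-nvvm` and is scheduled on gpu.module
// ops, e.g. (illustrative invocation, input file name hypothetical):
//   mlir-opt --pass-pipeline='builtin.module(gpu.module(convert-gpu-to-nvvm))' kernel.mlir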

} // namespace

void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
  target.addIllegalOp<func::FuncOp>();
  target.addIllegalOp<cf::AssertOp>();
  target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
  target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
  target.addIllegalDialect<gpu::GPUDialect>();
  target.addIllegalOp<LLVM::CopySignOp, LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op,
                      LLVM::FAbsOp, LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp,
                      LLVM::LogOp, LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp,
                      LLVM::RoundEvenOp, LLVM::RoundOp, LLVM::SinOp,
                      LLVM::SqrtOp>();

  // TODO: Remove once we support replacing non-root ops.
  target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp>();
}
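// Marking the LLVM math ops above as illegal routes them through the libdevice
// call lowering registered below, rather than leaving llvm.* math intrinsics
// that the NVPTX backend may not be able to lower.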

void mlir::configureGpuToNVVMTypeConverter(LLVMTypeConverter &converter) {
  // NVVM uses alloca in the default address space to represent private
  // memory allocations, so drop private annotations. NVVM uses address
  // space 3 for shared memory. NVVM uses the default address space to
  // represent global memory.
  populateGpuMemorySpaceAttributeConversions(
      converter, [](gpu::AddressSpace space) -> unsigned {
        switch (space) {
        case gpu::AddressSpace::Global:
          return static_cast<unsigned>(
              NVVM::NVVMMemorySpace::kGlobalMemorySpace);
        case gpu::AddressSpace::Workgroup:
          return static_cast<unsigned>(
              NVVM::NVVMMemorySpace::kSharedMemorySpace);
        case gpu::AddressSpace::Private:
          return 0;
        }
        llvm_unreachable("unknown address space enum value");
        return 0;
      });
  // Lowering for MMAMatrixType.
  converter.addConversion([&](gpu::MMAMatrixType type) -> Type {
    return convertMMAToLLVMType(type);
  });
}
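// With this mapping, for example, a memref in #gpu.address_space<workgroup>
// lowers to an LLVM pointer in addrspace(3), while private attributions end up
// in the default address space (0).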

template <typename OpTy>
static void populateOpPatterns(const LLVMTypeConverter &converter,
                               RewritePatternSet &patterns,
                               PatternBenefit benefit, StringRef f32Func,
                               StringRef f64Func, StringRef f32ApproxFunc = "",
                               StringRef f16Func = "") {
  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit);
  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
                                           f32ApproxFunc, f16Func,
                                           /*i32Func=*/"", benefit);
}
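// For instance, instantiating this helper with math::ExpOp and
// ("__nv_expf", "__nv_exp") rewrites a scalar `math.exp %x : f32` into a call
// to `@__nv_expf`, with vector operands first unrolled element-wise by
// ScalarizeVectorOpLowering.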

template <typename OpTy>
static void populateIntOpPatterns(const LLVMTypeConverter &converter,
                                  RewritePatternSet &patterns,
                                  PatternBenefit benefit, StringRef i32Func) {
  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit);
  patterns.add<OpToFuncCallLowering<OpTy>>(converter, "", "", "", "", i32Func,
                                           benefit);
}

template <typename OpTy>
static void populateFloatIntOpPatterns(const LLVMTypeConverter &converter,
                                       RewritePatternSet &patterns,
                                       PatternBenefit benefit,
                                       StringRef f32Func, StringRef f64Func) {
  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit);
  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func, "", "",
                                           /*i32Func=*/"", benefit);
}

void mlir::populateGpuSubgroupReduceOpLoweringPattern(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
    PatternBenefit benefit) {
  patterns.add<GPUSubgroupReduceOpLowering>(converter, benefit);
}

void mlir::populateLibDeviceConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
    PatternBenefit benefit) {
  populateOpPatterns<arith::RemFOp>(converter, patterns, benefit, "__nv_fmodf",
                                    "__nv_fmod");
  populateOpPatterns<arith::MaxNumFOp>(converter, patterns, benefit,
                                       "__nv_fmaxf", "__nv_fmax");
  populateOpPatterns<arith::MinNumFOp>(converter, patterns, benefit,
                                       "__nv_fminf", "__nv_fmin");

  populateIntOpPatterns<math::AbsIOp>(converter, patterns, benefit, "__nv_abs");
  populateOpPatterns<math::AbsFOp>(converter, patterns, benefit, "__nv_fabsf",
                                   "__nv_fabs");
  populateOpPatterns<math::AcosOp>(converter, patterns, benefit, "__nv_acosf",
                                   "__nv_acos");
  populateOpPatterns<math::AcoshOp>(converter, patterns, benefit, "__nv_acoshf",
                                    "__nv_acosh");
  populateOpPatterns<math::AsinOp>(converter, patterns, benefit, "__nv_asinf",
                                   "__nv_asin");
  populateOpPatterns<math::AsinhOp>(converter, patterns, benefit, "__nv_asinhf",
                                    "__nv_asinh");
  populateOpPatterns<math::AtanOp>(converter, patterns, benefit, "__nv_atanf",
                                   "__nv_atan");
  populateOpPatterns<math::Atan2Op>(converter, patterns, benefit, "__nv_atan2f",
                                    "__nv_atan2");
  populateOpPatterns<math::AtanhOp>(converter, patterns, benefit, "__nv_atanhf",
                                    "__nv_atanh");
  populateOpPatterns<math::CbrtOp>(converter, patterns, benefit, "__nv_cbrtf",
                                   "__nv_cbrt");
  populateOpPatterns<math::CeilOp>(converter, patterns, benefit, "__nv_ceilf",
                                   "__nv_ceil");
  populateOpPatterns<math::CopySignOp>(converter, patterns, benefit,
                                       "__nv_copysignf", "__nv_copysign");
  populateOpPatterns<math::CosOp>(converter, patterns, benefit, "__nv_cosf",
                                  "__nv_cos", "__nv_fast_cosf");
  populateOpPatterns<math::CoshOp>(converter, patterns, benefit, "__nv_coshf",
                                   "__nv_cosh");
  populateOpPatterns<math::ErfOp>(converter, patterns, benefit, "__nv_erff",
                                  "__nv_erf");
  populateOpPatterns<math::ErfcOp>(converter, patterns, benefit, "__nv_erfcf",
                                   "__nv_erfc");
  populateOpPatterns<math::ExpOp>(converter, patterns, benefit, "__nv_expf",
                                  "__nv_exp", "__nv_fast_expf");
  populateOpPatterns<math::Exp2Op>(converter, patterns, benefit, "__nv_exp2f",
                                   "__nv_exp2");
  populateOpPatterns<math::ExpM1Op>(converter, patterns, benefit, "__nv_expm1f",
                                    "__nv_expm1");
  populateOpPatterns<math::FloorOp>(converter, patterns, benefit, "__nv_floorf",
                                    "__nv_floor");
  populateOpPatterns<math::FmaOp>(converter, patterns, benefit, "__nv_fmaf",
                                  "__nv_fma");
  // Note: libdevice uses a different name for 32-bit finite checking.
  populateOpPatterns<math::IsFiniteOp>(converter, patterns, benefit,
                                       "__nv_finitef", "__nv_isfinited");
  populateOpPatterns<math::IsInfOp>(converter, patterns, benefit, "__nv_isinff",
                                    "__nv_isinfd");
  populateOpPatterns<math::IsNaNOp>(converter, patterns, benefit, "__nv_isnanf",
                                    "__nv_isnand");
  populateOpPatterns<math::LogOp>(converter, patterns, benefit, "__nv_logf",
                                  "__nv_log", "__nv_fast_logf");
  populateOpPatterns<math::Log10Op>(converter, patterns, benefit, "__nv_log10f",
                                    "__nv_log10", "__nv_fast_log10f");
  populateOpPatterns<math::Log1pOp>(converter, patterns, benefit, "__nv_log1pf",
                                    "__nv_log1p");
  populateOpPatterns<math::Log2Op>(converter, patterns, benefit, "__nv_log2f",
                                   "__nv_log2", "__nv_fast_log2f");
  populateOpPatterns<math::PowFOp>(converter, patterns, benefit, "__nv_powf",
                                   "__nv_pow", "__nv_fast_powf");
  populateFloatIntOpPatterns<math::FPowIOp>(converter, patterns, benefit,
                                            "__nv_powif", "__nv_powi");
  populateOpPatterns<math::RoundOp>(converter, patterns, benefit, "__nv_roundf",
                                    "__nv_round");
  populateOpPatterns<math::RoundEvenOp>(converter, patterns, benefit,
                                        "__nv_rintf", "__nv_rint");
  populateOpPatterns<math::RsqrtOp>(converter, patterns, benefit, "__nv_rsqrtf",
                                    "__nv_rsqrt");
  populateOpPatterns<math::SinOp>(converter, patterns, benefit, "__nv_sinf",
                                  "__nv_sin", "__nv_fast_sinf");
  populateOpPatterns<math::SinhOp>(converter, patterns, benefit, "__nv_sinhf",
                                   "__nv_sinh");
  populateOpPatterns<math::SqrtOp>(converter, patterns, benefit, "__nv_sqrtf",
                                   "__nv_sqrt");
  populateOpPatterns<math::TanOp>(converter, patterns, benefit, "__nv_tanf",
                                  "__nv_tan", "__nv_fast_tanf");
  populateOpPatterns<math::TanhOp>(converter, patterns, benefit, "__nv_tanhf",
                                   "__nv_tanh");
}
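// Where a third name such as "__nv_fast_sinf" is supplied, OpToFuncCallLowering
// can pick it for f32 operands when the op carries an approximation-friendly
// fast-math flag; otherwise the precise __nv_* variant is called.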

void mlir::populateGpuToNVVMConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
    PatternBenefit benefit) {
  using gpu::index_lowering::IndexKind;
  using gpu::index_lowering::IntrType;

  // TODO: Pass benefit to generated patterns.
  populateWithGenerated(patterns);

  patterns.add<GPUPrintfOpToVPrintfLowering, AssertOpToAssertfailLowering>(
      converter, benefit);
  patterns.add<
      gpu::index_lowering::OpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp,
                                      NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>>(
      converter, IndexKind::Block, IntrType::Id, benefit);
  patterns.add<
      gpu::index_lowering::OpLowering<gpu::BlockDimOp, NVVM::BlockDimXOp,
                                      NVVM::BlockDimYOp, NVVM::BlockDimZOp>>(
      converter, IndexKind::Block, IntrType::Dim, benefit);
  patterns.add<
      gpu::index_lowering::OpLowering<gpu::ClusterIdOp, NVVM::ClusterIdXOp,
                                      NVVM::ClusterIdYOp, NVVM::ClusterIdZOp>>(
      converter, IndexKind::Other, IntrType::Id, benefit);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::ClusterDimOp, NVVM::ClusterDimXOp, NVVM::ClusterDimYOp,
      NVVM::ClusterDimZOp>>(converter, IndexKind::Other, IntrType::Dim,
                            benefit);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::ClusterBlockIdOp, NVVM::BlockInClusterIdXOp,
      NVVM::BlockInClusterIdYOp, NVVM::BlockInClusterIdZOp>>(
      converter, IndexKind::Other, IntrType::Id, benefit);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::ClusterDimBlocksOp, NVVM::ClusterDimBlocksXOp,
      NVVM::ClusterDimBlocksYOp, NVVM::ClusterDimBlocksZOp>>(
      converter, IndexKind::Other, IntrType::Dim, benefit);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::BlockIdOp, NVVM::BlockIdXOp, NVVM::BlockIdYOp, NVVM::BlockIdZOp>>(
      converter, IndexKind::Grid, IntrType::Id, benefit);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::GridDimOp, NVVM::GridDimXOp, NVVM::GridDimYOp, NVVM::GridDimZOp>>(
      converter, IndexKind::Grid, IntrType::Dim, benefit);
  patterns.add<GPULaneIdOpToNVVM, GPUShuffleOpLowering, GPUReturnOpLowering>(
      converter, benefit);

  patterns.add<GPUDynamicSharedMemoryOpLowering>(
      converter, NVVM::kSharedMemoryAlignmentBit, benefit);

  // Explicitly drop memory space when lowering private memory
  // attributions since NVVM models it as `alloca`s in the default
  // memory space and does not support `alloca`s with addrspace(5).
  patterns.add<GPUFuncOpLowering>(
      converter,
      GPUFuncOpLoweringOptions{
          /*allocaAddrSpace=*/0,
          /*workgroupAddrSpace=*/
          static_cast<unsigned>(NVVM::NVVMMemorySpace::kSharedMemorySpace),
          StringAttr::get(&converter.getContext(),
                          NVVM::NVVMDialect::getKernelFuncAttrName()),
          StringAttr::get(&converter.getContext(),
                          NVVM::NVVMDialect::getMaxntidAttrName())},
      benefit);

  populateLibDeviceConversionPatterns(converter, patterns, benefit);
}

//===----------------------------------------------------------------------===//
// NVVMTargetAttr convert to LLVM attr interface
//===----------------------------------------------------------------------===//

namespace {
struct NVVMTargetConvertToLLVMAttrInterface
    : public ConvertToLLVMAttrInterface::ExternalModel<
          NVVMTargetConvertToLLVMAttrInterface, NVVM::NVVMTargetAttr> {
  /// Configure GPU to NVVM.
  void populateConvertToLLVMConversionPatterns(
      Attribute attr, ConversionTarget &target,
      LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) const;
};
} // namespace

void NVVMTargetConvertToLLVMAttrInterface::
    populateConvertToLLVMConversionPatterns(Attribute attr,
                                            ConversionTarget &target,
                                            LLVMTypeConverter &typeConverter,
                                            RewritePatternSet &patterns) const {
  configureGpuToNVVMConversionLegality(target);
  configureGpuToNVVMTypeConverter(typeConverter);
  populateGpuToNVVMConversionPatterns(typeConverter, patterns);
}

void mlir::NVVM::registerConvertGpuToNVVMInterface(DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, NVVMDialect *dialect) {
    NVVMTargetAttr::attachInterface<NVVMTargetConvertToLLVMAttrInterface>(*ctx);
  });
}
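// Illustrative use from a tool's setup code, given an MLIRContext `context`:
//   DialectRegistry registry;
//   NVVM::registerConvertGpuToNVVMInterface(registry);
//   context.appendDialectRegistry(registry);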