//===- LowerGpuOpsToNVVMOps.cpp - MLIR GPU to NVVM lowering passes --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to generate NVVMIR operations for higher-level
// GPU operations.
//
//===----------------------------------------------------------------------===//

#include "../GPUCommon/GPUOpsLowering.h"
#include "../GPUCommon/IndexIntrinsicsOpLowering.h"
#include "../GPUCommon/OpToFuncCallLowering.h"
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_CONVERTGPUOPSTONVVMOPS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {

/// Convert gpu dialect shfl mode enum to the equivalent nvvm one.
static NVVM::ShflKind convertShflKind(gpu::ShuffleMode mode) {
  switch (mode) {
  case gpu::ShuffleMode::XOR:
    return NVVM::ShflKind::bfly;
  case gpu::ShuffleMode::UP:
    return NVVM::ShflKind::up;
  case gpu::ShuffleMode::DOWN:
    return NVVM::ShflKind::down;
  case gpu::ShuffleMode::IDX:
    return NVVM::ShflKind::idx;
  }
  llvm_unreachable("unknown shuffle mode");
}

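/// Map a gpu reduction operation to the NVVM redux kind it lowers to, or
/// std::nullopt when this lowering does not map the operation to a redux
/// (e.g. mul, unsigned min/max, and the NaN-propagating minimumf/maximumf).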
static std::optional<NVVM::ReduxKind>
convertReduxKind(gpu::AllReduceOperation mode) {
  switch (mode) {
  case gpu::AllReduceOperation::ADD:
    return NVVM::ReduxKind::ADD;
  case gpu::AllReduceOperation::MUL:
    return std::nullopt;
  case gpu::AllReduceOperation::MINSI:
    return NVVM::ReduxKind::MIN;
  case gpu::AllReduceOperation::MINUI:
    return std::nullopt;
  case gpu::AllReduceOperation::MINNUMF:
    return NVVM::ReduxKind::MIN;
  case gpu::AllReduceOperation::MAXSI:
    return NVVM::ReduxKind::MAX;
  case gpu::AllReduceOperation::MAXUI:
    return std::nullopt;
  case gpu::AllReduceOperation::MAXNUMF:
    return NVVM::ReduxKind::MAX;
  case gpu::AllReduceOperation::AND:
    return NVVM::ReduxKind::AND;
  case gpu::AllReduceOperation::OR:
    return NVVM::ReduxKind::OR;
  case gpu::AllReduceOperation::XOR:
    return NVVM::ReduxKind::XOR;
  case gpu::AllReduceOperation::MINIMUMF:
  case gpu::AllReduceOperation::MAXIMUMF:
    return std::nullopt;
  }
  return std::nullopt;
}

/// This pattern lowers the gpu.subgroup_reduce op to the nvvm.redux op. The op
/// must be run by the entire subgroup, otherwise it is undefined behaviour.
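///
/// For illustration (IR syntax sketched, not verbatim), a uniform 32-bit add
///   %r = gpu.subgroup_reduce add %x uniform : (i32) -> (i32)
/// becomes a redux across the full warp mask:
///   %mask = llvm.mlir.constant(-1 : i32) : i32
///   %r = nvvm.redux.sync add %x, %mask : i32 -> i32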
struct GPUSubgroupReduceOpLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupReduceOp> {
  using ConvertOpToLLVMPattern<gpu::SubgroupReduceOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getClusterSize())
      return rewriter.notifyMatchFailure(
          op, "lowering for clustered reduce not implemented");

    if (!op.getUniform())
      return rewriter.notifyMatchFailure(
          op, "cannot be lowered to redux as the op must be run "
              "uniformly (entire subgroup).");
    if (!op.getValue().getType().isInteger(32))
      return rewriter.notifyMatchFailure(op, "unsupported data type");

    std::optional<NVVM::ReduxKind> mode = convertReduxKind(op.getOp());
    if (!mode.has_value())
      return rewriter.notifyMatchFailure(
          op, "unsupported reduction mode for redux");

    Location loc = op->getLoc();
    auto int32Type = IntegerType::get(rewriter.getContext(), 32);
    Value offset = rewriter.create<LLVM::ConstantOp>(loc, int32Type, -1);

    auto reduxOp = rewriter.create<NVVM::ReduxOp>(loc, int32Type, op.getValue(),
                                                  mode.value(), offset);

    rewriter.replaceOp(op, reduxOp->getResult(0));
    return success();
  }
};

struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
  using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern;

  /// Lowers a shuffle to the corresponding NVVM op.
  ///
  /// Convert the `width` argument into an activeMask (a bitmask which specifies
  /// which threads participate in the shuffle) and a maskAndClamp (specifying
  /// the highest lane which participates in the shuffle).
  ///
  ///     %one = llvm.constant(1 : i32) : i32
  ///     %minus_one = llvm.constant(-1 : i32) : i32
  ///     %thirty_two = llvm.constant(32 : i32) : i32
  ///     %num_lanes = llvm.sub %thirty_two, %width : i32
  ///     %active_mask = llvm.lshr %minus_one, %num_lanes : i32
  ///     %mask_and_clamp = llvm.sub %width, %one : i32
  ///     %shfl = nvvm.shfl.sync.bfly %active_mask, %value, %offset,
  ///         %mask_and_clamp : !llvm<"{ float, i1 }">
  ///     %shfl_value = llvm.extractvalue %shfl[0] : !llvm<"{ float, i1 }">
  ///     %shfl_pred = llvm.extractvalue %shfl[1] : !llvm<"{ float, i1 }">
  LogicalResult
  matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();

    auto valueTy = adaptor.getValue().getType();
    auto int32Type = IntegerType::get(rewriter.getContext(), 32);
    auto predTy = IntegerType::get(rewriter.getContext(), 1);

    Value one = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 1);
    Value minusOne = rewriter.create<LLVM::ConstantOp>(loc, int32Type, -1);
    Value thirtyTwo = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 32);
    Value numLeadInactiveLane = rewriter.create<LLVM::SubOp>(
        loc, int32Type, thirtyTwo, adaptor.getWidth());
    // Bit mask of active lanes: `(-1) >> (32 - activeWidth)`.
    Value activeMask = rewriter.create<LLVM::LShrOp>(loc, int32Type, minusOne,
                                                     numLeadInactiveLane);
    Value maskAndClamp;
    if (op.getMode() == gpu::ShuffleMode::UP) {
      // Clamp lane: `32 - activeWidth`.
      maskAndClamp = numLeadInactiveLane;
    } else {
      // Clamp lane: `activeWidth - 1`.
      maskAndClamp =
          rewriter.create<LLVM::SubOp>(loc, int32Type, adaptor.getWidth(), one);
    }

    bool predIsUsed = !op->getResult(1).use_empty();
    UnitAttr returnValueAndIsValidAttr = nullptr;
    Type resultTy = valueTy;
    if (predIsUsed) {
      returnValueAndIsValidAttr = rewriter.getUnitAttr();
      resultTy = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
                                                  {valueTy, predTy});
    }
    Value shfl = rewriter.create<NVVM::ShflOp>(
        loc, resultTy, activeMask, adaptor.getValue(), adaptor.getOffset(),
        maskAndClamp, convertShflKind(op.getMode()), returnValueAndIsValidAttr);
    if (predIsUsed) {
      Value shflValue = rewriter.create<LLVM::ExtractValueOp>(loc, shfl, 0);
      Value isActiveSrcLane =
          rewriter.create<LLVM::ExtractValueOp>(loc, shfl, 1);
      rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
    } else {
      rewriter.replaceOp(op, {shfl, nullptr});
    }
    return success();
  }
};

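/// Lowers gpu.lane_id to NVVM's lane-id intrinsic (NVVM::LaneIdOp), attaching
/// a [0, 32) range (or the op's upper bound when present) and then widening or
/// truncating the i32 result to the converter's index bitwidth.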
struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
  using ConvertOpToLLVMPattern<gpu::LaneIdOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::LaneIdOp op, gpu::LaneIdOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    MLIRContext *context = rewriter.getContext();
    LLVM::ConstantRangeAttr bounds = nullptr;
    if (std::optional<APInt> upperBound = op.getUpperBound())
      bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
          /*bitWidth=*/32, /*lower=*/0, upperBound->getZExtValue());
    else
      bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
          /*bitWidth=*/32, /*lower=*/0, /*upper=*/kWarpSize);
    Value newOp =
        rewriter.create<NVVM::LaneIdOp>(loc, rewriter.getI32Type(), bounds);
    // Truncate or extend the result depending on the index bitwidth specified
    // by the LLVMTypeConverter options.
    const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
    if (indexBitwidth > 32) {
      newOp = rewriter.create<LLVM::SExtOp>(
          loc, IntegerType::get(context, indexBitwidth), newOp);
    } else if (indexBitwidth < 32) {
      newOp = rewriter.create<LLVM::TruncOp>(
          loc, IntegerType::get(context, indexBitwidth), newOp);
    }
    rewriter.replaceOp(op, {newOp});
    return success();
  }
};

/// Lowering of cf.assert into a conditional __assertfail.
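///
/// The pattern splits the block around the assert, branches on the condition,
/// and in the failing path calls the device-side
/// __assertfail(message, file, line, function, charSize) declaration that is
/// created (or reused) in the surrounding gpu.module.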
struct AssertOpToAssertfailLowering
    : public ConvertOpToLLVMPattern<cf::AssertOp> {
  using ConvertOpToLLVMPattern<cf::AssertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(cf::AssertOp assertOp, cf::AssertOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MLIRContext *ctx = rewriter.getContext();
    Location loc = assertOp.getLoc();
    Type i8Type = typeConverter->convertType(rewriter.getIntegerType(8));
    Type i32Type = typeConverter->convertType(rewriter.getIntegerType(32));
    Type i64Type = typeConverter->convertType(rewriter.getIntegerType(64));
    Type ptrType = LLVM::LLVMPointerType::get(ctx);
    Type voidType = LLVM::LLVMVoidType::get(ctx);

    // Find or create __assertfail function declaration.
    auto moduleOp = assertOp->getParentOfType<gpu::GPUModuleOp>();
    auto assertfailType = LLVM::LLVMFunctionType::get(
        voidType, {ptrType, ptrType, i32Type, ptrType, i64Type});
    LLVM::LLVMFuncOp assertfailDecl = getOrDefineFunction(
        moduleOp, loc, rewriter, "__assertfail", assertfailType);
    assertfailDecl.setPassthroughAttr(
        ArrayAttr::get(ctx, StringAttr::get(ctx, "noreturn")));

    // Split blocks and insert conditional branch.
    // ^before:
    //   ...
    //   cf.cond_br %condition, ^after, ^assert
    // ^assert:
    //   cf.assert
    //   cf.br ^after
    // ^after:
    //   ...
    Block *beforeBlock = assertOp->getBlock();
    Block *assertBlock =
        rewriter.splitBlock(beforeBlock, assertOp->getIterator());
    Block *afterBlock =
        rewriter.splitBlock(assertBlock, ++assertOp->getIterator());
    rewriter.setInsertionPointToEnd(beforeBlock);
    rewriter.create<cf::CondBranchOp>(loc, adaptor.getArg(), afterBlock,
                                      assertBlock);
    rewriter.setInsertionPointToEnd(assertBlock);
    rewriter.create<cf::BranchOp>(loc, afterBlock);

    // Continue cf.assert lowering.
    rewriter.setInsertionPoint(assertOp);

    // Populate file name, line number and function name from the location of
    // the AssertOp.
    StringRef fileName = "(unknown)";
    StringRef funcName = "(unknown)";
    int32_t fileLine = 0;
    while (auto callSiteLoc = dyn_cast<CallSiteLoc>(loc))
      loc = callSiteLoc.getCallee();
    if (auto fileLineColLoc = dyn_cast<FileLineColRange>(loc)) {
      fileName = fileLineColLoc.getFilename().strref();
      fileLine = fileLineColLoc.getStartLine();
    } else if (auto nameLoc = dyn_cast<NameLoc>(loc)) {
      funcName = nameLoc.getName().strref();
      if (auto fileLineColLoc =
              dyn_cast<FileLineColRange>(nameLoc.getChildLoc())) {
        fileName = fileLineColLoc.getFilename().strref();
        fileLine = fileLineColLoc.getStartLine();
      }
    }

    // Create constants.
    auto getGlobal = [&](LLVM::GlobalOp global) {
      // Get a pointer to the format string's first element.
      Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
          loc, LLVM::LLVMPointerType::get(ctx, global.getAddrSpace()),
          global.getSymNameAttr());
      Value start =
          rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
                                       globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
      return start;
    };
    Value assertMessage = getGlobal(getOrCreateStringConstant(
        rewriter, loc, moduleOp, i8Type, "assert_message_", assertOp.getMsg()));
    Value assertFile = getGlobal(getOrCreateStringConstant(
        rewriter, loc, moduleOp, i8Type, "assert_file_", fileName));
    Value assertFunc = getGlobal(getOrCreateStringConstant(
        rewriter, loc, moduleOp, i8Type, "assert_func_", funcName));
    Value assertLine =
        rewriter.create<LLVM::ConstantOp>(loc, i32Type, fileLine);
    Value c1 = rewriter.create<LLVM::ConstantOp>(loc, i64Type, 1);

    // Insert function call to __assertfail.
    SmallVector<Value> arguments{assertMessage, assertFile, assertLine,
                                 assertFunc, c1};
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(assertOp, assertfailDecl,
                                              arguments);
    return success();
  }
};

/// Import the GPU Ops to NVVM Patterns.
#include "GPUToNVVM.cpp.inc"

/// A pass that replaces all occurrences of GPU device operations with their
/// corresponding NVVM equivalent.
///
/// This pass only handles device code and is not meant to be run on GPU host
/// code.
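///
/// The pass operates on gpu.module ops and is registered as
/// `convert-gpu-to-nvvm` (illustrative mlir-opt invocation:
/// `mlir-opt --convert-gpu-to-nvvm`).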
struct LowerGpuOpsToNVVMOpsPass final
    : public impl::ConvertGpuOpsToNVVMOpsBase<LowerGpuOpsToNVVMOpsPass> {
  using Base::Base;

  void getDependentDialects(DialectRegistry &registry) const override {
    Base::getDependentDialects(registry);
    registerConvertToLLVMDependentDialectLoading(registry);
  }

  void runOnOperation() override {
    gpu::GPUModuleOp m = getOperation();

    // Request C wrapper emission.
    for (auto func : m.getOps<func::FuncOp>()) {
      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                    UnitAttr::get(&getContext()));
    }

    // Customize the bitwidth used for the device side index computations.
    LowerToLLVMOptions options(
        m.getContext(),
        DataLayout(cast<DataLayoutOpInterface>(m.getOperation())));
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);
    options.useBarePtrCallConv = useBarePtrCallConv;

    // Apply in-dialect lowering. In-dialect lowering will replace
    // ops which need to be lowered further, which is not supported by a
    // single conversion pass.
    {
      RewritePatternSet patterns(m.getContext());
      populateGpuRewritePatterns(patterns);
      if (failed(applyPatternsGreedily(m, std::move(patterns))))
        return signalPassFailure();
    }

    LLVMTypeConverter converter(m.getContext(), options);
    configureGpuToNVVMTypeConverter(converter);
    RewritePatternSet llvmPatterns(m.getContext());
    LLVMConversionTarget target(getContext());

    // Set a higher benefit so these patterns run before the generic LLVM
    // lowering.
    populateGpuToNVVMConversionPatterns(converter, llvmPatterns,
                                        /*benefit=*/10);

    llvm::SmallDenseSet<StringRef> allowedDialectsSet(allowedDialects.begin(),
                                                      allowedDialects.end());
    for (Dialect *dialect : getContext().getLoadedDialects()) {
      // Skip math patterns as nvvm needs custom math lowering.
      if (isa<math::MathDialect>(dialect))
        continue;

      bool allowed = allowedDialectsSet.contains(dialect->getNamespace());
      // Empty `allowedDialectsSet` means all dialects are allowed.
      if (!allowedDialectsSet.empty() && !allowed)
        continue;

      auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
      if (!iface) {
        // Error out if the dialect was explicitly specified but doesn't
        // implement the conversion interface.
        if (allowed) {
          m.emitError()
              << "dialect does not implement ConvertToLLVMPatternInterface: "
              << dialect->getNamespace();
          return signalPassFailure();
        }
        continue;
      }

      iface->populateConvertToLLVMConversionPatterns(target, converter,
                                                     llvmPatterns);
    }

    populateGpuWMMAToNVVMConversionPatterns(converter, llvmPatterns);
    if (this->hasRedux)
      populateGpuSubgroupReduceOpLoweringPattern(converter, llvmPatterns);
    configureGpuToNVVMConversionLegality(target);
    if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
      signalPassFailure();
  }
};

} // namespace

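/// Configure `target` to convert from the GPU dialect to NVVM: the LLVM and
/// NVVM dialects are legal, while GPU ops, func.func, cf.assert, and a set of
/// LLVM math-flavoured ops (exp, sin, pow, ...) are illegal and must be
/// rewritten by the registered patterns.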
void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
  target.addIllegalOp<func::FuncOp>();
  target.addIllegalOp<cf::AssertOp>();
  target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
  target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
  target.addIllegalDialect<gpu::GPUDialect>();
  target.addIllegalOp<LLVM::CopySignOp, LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op,
                      LLVM::FAbsOp, LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FMAOp,
                      LLVM::FRemOp, LLVM::LogOp, LLVM::Log10Op, LLVM::Log2Op,
                      LLVM::PowOp, LLVM::RoundEvenOp, LLVM::RoundOp,
                      LLVM::SinOp, LLVM::SqrtOp>();

  // TODO: Remove once we support replacing non-root ops.
  target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp>();
}

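/// Configure the LLVM type converter to convert types and address spaces from
/// the GPU dialect to NVVM: gpu.address_space global/workgroup/private map to
/// NVVM address spaces 1/3/0, and gpu MMA matrix types map to LLVM struct
/// types.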
void mlir::configureGpuToNVVMTypeConverter(LLVMTypeConverter &converter) {
  // NVVM uses alloca in the default address space to represent private
  // memory allocations, so drop private annotations. NVVM uses address
  // space 3 for shared memory. NVVM uses the default address space to
  // represent global memory.
  populateGpuMemorySpaceAttributeConversions(
      converter, [](gpu::AddressSpace space) -> unsigned {
        switch (space) {
        case gpu::AddressSpace::Global:
          return static_cast<unsigned>(
              NVVM::NVVMMemorySpace::kGlobalMemorySpace);
        case gpu::AddressSpace::Workgroup:
          return static_cast<unsigned>(
              NVVM::NVVMMemorySpace::kSharedMemorySpace);
        case gpu::AddressSpace::Private:
          return 0;
        }
        llvm_unreachable("unknown address space enum value");
        return 0;
      });
  // Lowering for MMAMatrixType.
  converter.addConversion([&](gpu::MMAMatrixType type) -> Type {
    return convertMMAToLLVMType(type);
  });
}

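// The helpers below register, for a given op type, a vector-scalarization
// pattern plus an OpToFuncCallLowering that rewrites the scalar op into a call
// to the named libdevice function (f32/f64, optional approximate-f32 and f16
// variants, or an i32 variant).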
template <typename OpTy>
static void populateOpPatterns(const LLVMTypeConverter &converter,
                               RewritePatternSet &patterns,
                               PatternBenefit benefit, StringRef f32Func,
                               StringRef f64Func, StringRef f32ApproxFunc = "",
                               StringRef f16Func = "") {
  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit);
  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
                                           f32ApproxFunc, f16Func,
                                           /*i32Func=*/"", benefit);
}

template <typename OpTy>
static void populateIntOpPatterns(const LLVMTypeConverter &converter,
                                  RewritePatternSet &patterns,
                                  PatternBenefit benefit, StringRef i32Func) {
  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit);
  patterns.add<OpToFuncCallLowering<OpTy>>(converter, "", "", "", "", i32Func,
                                           benefit);
}

template <typename OpTy>
static void populateFloatIntOpPatterns(const LLVMTypeConverter &converter,
                                       RewritePatternSet &patterns,
                                       PatternBenefit benefit,
                                       StringRef f32Func, StringRef f64Func) {
  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit);
  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func, "", "",
                                           /*i32Func=*/"", benefit);
}

void mlir::populateGpuSubgroupReduceOpLoweringPattern(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
    PatternBenefit benefit) {
  patterns.add<GPUSubgroupReduceOpLowering>(converter, benefit);
}

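// Illustrative effect of the table below (exact IR shapes depend on the
// surrounding lowering): a scalar `math.exp %x : f32` becomes a call
// `llvm.call @__nv_expf(%x)`, its f64 counterpart calls `__nv_exp`, and the
// approximate-f32 variant `__nv_fast_expf` is used when the pattern selects
// it (e.g. under fast-math).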
void mlir::populateLibDeviceConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
    PatternBenefit benefit) {
  populateOpPatterns<arith::RemFOp>(converter, patterns, benefit, "__nv_fmodf",
                                    "__nv_fmod");
  populateOpPatterns<arith::MaxNumFOp>(converter, patterns, benefit,
                                       "__nv_fmaxf", "__nv_fmax");
  populateOpPatterns<arith::MinNumFOp>(converter, patterns, benefit,
                                       "__nv_fminf", "__nv_fmin");

  populateIntOpPatterns<math::AbsIOp>(converter, patterns, benefit, "__nv_abs");
  populateOpPatterns<math::AbsFOp>(converter, patterns, benefit, "__nv_fabsf",
                                   "__nv_fabs");
  populateOpPatterns<math::AcosOp>(converter, patterns, benefit, "__nv_acosf",
                                   "__nv_acos");
  populateOpPatterns<math::AcoshOp>(converter, patterns, benefit, "__nv_acoshf",
                                    "__nv_acosh");
  populateOpPatterns<math::AsinOp>(converter, patterns, benefit, "__nv_asinf",
                                   "__nv_asin");
  populateOpPatterns<math::AsinhOp>(converter, patterns, benefit, "__nv_asinhf",
                                    "__nv_asinh");
  populateOpPatterns<math::AtanOp>(converter, patterns, benefit, "__nv_atanf",
                                   "__nv_atan");
  populateOpPatterns<math::Atan2Op>(converter, patterns, benefit, "__nv_atan2f",
                                    "__nv_atan2");
  populateOpPatterns<math::AtanhOp>(converter, patterns, benefit, "__nv_atanhf",
                                    "__nv_atanh");
  populateOpPatterns<math::CbrtOp>(converter, patterns, benefit, "__nv_cbrtf",
                                   "__nv_cbrt");
  populateOpPatterns<math::CeilOp>(converter, patterns, benefit, "__nv_ceilf",
                                   "__nv_ceil");
  populateOpPatterns<math::CopySignOp>(converter, patterns, benefit,
                                       "__nv_copysignf", "__nv_copysign");
  populateOpPatterns<math::CosOp>(converter, patterns, benefit, "__nv_cosf",
                                  "__nv_cos", "__nv_fast_cosf");
  populateOpPatterns<math::CoshOp>(converter, patterns, benefit, "__nv_coshf",
                                   "__nv_cosh");
  populateOpPatterns<math::ErfOp>(converter, patterns, benefit, "__nv_erff",
                                  "__nv_erf");
  populateOpPatterns<math::ErfcOp>(converter, patterns, benefit, "__nv_erfcf",
                                   "__nv_erfc");
  populateOpPatterns<math::ExpOp>(converter, patterns, benefit, "__nv_expf",
                                  "__nv_exp", "__nv_fast_expf");
  populateOpPatterns<math::Exp2Op>(converter, patterns, benefit, "__nv_exp2f",
                                   "__nv_exp2");
  populateOpPatterns<math::ExpM1Op>(converter, patterns, benefit, "__nv_expm1f",
                                    "__nv_expm1");
  populateOpPatterns<math::FloorOp>(converter, patterns, benefit, "__nv_floorf",
                                    "__nv_floor");
  populateOpPatterns<math::FmaOp>(converter, patterns, benefit, "__nv_fmaf",
                                  "__nv_fma");
  // Note: libdevice uses a different name for 32-bit finite checking.
  populateOpPatterns<math::IsFiniteOp>(converter, patterns, benefit,
                                       "__nv_finitef", "__nv_isfinited");
  populateOpPatterns<math::IsInfOp>(converter, patterns, benefit, "__nv_isinff",
                                    "__nv_isinfd");
  populateOpPatterns<math::IsNaNOp>(converter, patterns, benefit, "__nv_isnanf",
                                    "__nv_isnand");
  populateOpPatterns<math::LogOp>(converter, patterns, benefit, "__nv_logf",
                                  "__nv_log", "__nv_fast_logf");
  populateOpPatterns<math::Log10Op>(converter, patterns, benefit, "__nv_log10f",
                                    "__nv_log10", "__nv_fast_log10f");
  populateOpPatterns<math::Log1pOp>(converter, patterns, benefit, "__nv_log1pf",
                                    "__nv_log1p");
  populateOpPatterns<math::Log2Op>(converter, patterns, benefit, "__nv_log2f",
                                   "__nv_log2", "__nv_fast_log2f");
  populateOpPatterns<math::PowFOp>(converter, patterns, benefit, "__nv_powf",
                                   "__nv_pow", "__nv_fast_powf");
  populateFloatIntOpPatterns<math::FPowIOp>(converter, patterns, benefit,
                                            "__nv_powif", "__nv_powi");
  populateOpPatterns<math::RoundOp>(converter, patterns, benefit, "__nv_roundf",
                                    "__nv_round");
  populateOpPatterns<math::RoundEvenOp>(converter, patterns, benefit,
                                        "__nv_rintf", "__nv_rint");
  populateOpPatterns<math::RsqrtOp>(converter, patterns, benefit, "__nv_rsqrtf",
                                    "__nv_rsqrt");
  populateOpPatterns<math::SinOp>(converter, patterns, benefit, "__nv_sinf",
                                  "__nv_sin", "__nv_fast_sinf");
  populateOpPatterns<math::SinhOp>(converter, patterns, benefit, "__nv_sinhf",
                                   "__nv_sinh");
  populateOpPatterns<math::SqrtOp>(converter, patterns, benefit, "__nv_sqrtf",
                                   "__nv_sqrt");
  populateOpPatterns<math::TanOp>(converter, patterns, benefit, "__nv_tanf",
                                  "__nv_tan", "__nv_fast_tanf");
  populateOpPatterns<math::TanhOp>(converter, patterns, benefit, "__nv_tanhf",
                                   "__nv_tanh");
}

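/// Collect a set of patterns to convert from the GPU dialect to NVVM:
/// index/dimension intrinsics, shuffle, lane id, printf, cf.assert, dynamic
/// shared memory, and gpu.func lowering, plus the libdevice patterns above.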
void mlir::populateGpuToNVVMConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
    PatternBenefit benefit) {
  using gpu::index_lowering::IndexKind;
  using gpu::index_lowering::IntrType;

  // TODO: Pass benefit to generated patterns.
  populateWithGenerated(patterns);

  patterns.add<GPUPrintfOpToVPrintfLowering, AssertOpToAssertfailLowering>(
      converter, benefit);
  patterns.add<
      gpu::index_lowering::OpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp,
                                      NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>>(
      converter, IndexKind::Block, IntrType::Id, benefit);
  patterns.add<
      gpu::index_lowering::OpLowering<gpu::BlockDimOp, NVVM::BlockDimXOp,
                                      NVVM::BlockDimYOp, NVVM::BlockDimZOp>>(
      converter, IndexKind::Block, IntrType::Dim, benefit);
  patterns.add<
      gpu::index_lowering::OpLowering<gpu::ClusterIdOp, NVVM::ClusterIdXOp,
                                      NVVM::ClusterIdYOp, NVVM::ClusterIdZOp>>(
      converter, IndexKind::Other, IntrType::Id, benefit);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::ClusterDimOp, NVVM::ClusterDimXOp, NVVM::ClusterDimYOp,
      NVVM::ClusterDimZOp>>(converter, IndexKind::Other, IntrType::Dim,
                            benefit);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::ClusterBlockIdOp, NVVM::BlockInClusterIdXOp,
      NVVM::BlockInClusterIdYOp, NVVM::BlockInClusterIdZOp>>(
      converter, IndexKind::Other, IntrType::Id, benefit);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::ClusterDimBlocksOp, NVVM::ClusterDimBlocksXOp,
      NVVM::ClusterDimBlocksYOp, NVVM::ClusterDimBlocksZOp>>(
      converter, IndexKind::Other, IntrType::Dim, benefit);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::BlockIdOp, NVVM::BlockIdXOp, NVVM::BlockIdYOp, NVVM::BlockIdZOp>>(
      converter, IndexKind::Grid, IntrType::Id, benefit);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::GridDimOp, NVVM::GridDimXOp, NVVM::GridDimYOp, NVVM::GridDimZOp>>(
      converter, IndexKind::Grid, IntrType::Dim, benefit);
  patterns.add<GPULaneIdOpToNVVM, GPUShuffleOpLowering, GPUReturnOpLowering>(
      converter, benefit);

  patterns.add<GPUDynamicSharedMemoryOpLowering>(
      converter, NVVM::kSharedMemoryAlignmentBit, benefit);

  // Explicitly drop memory space when lowering private memory
  // attributions since NVVM models it as `alloca`s in the default
  // memory space and does not support `alloca`s with addrspace(5).
  patterns.add<GPUFuncOpLowering>(
      converter,
      GPUFuncOpLoweringOptions{
          /*allocaAddrSpace=*/0,
          /*workgroupAddrSpace=*/
          static_cast<unsigned>(NVVM::NVVMMemorySpace::kSharedMemorySpace),
          StringAttr::get(&converter.getContext(),
                          NVVM::NVVMDialect::getKernelFuncAttrName()),
          StringAttr::get(&converter.getContext(),
                          NVVM::NVVMDialect::getMaxntidAttrName())},
      benefit);

  populateLibDeviceConversionPatterns(converter, patterns, benefit);
}

656 
657 //===----------------------------------------------------------------------===//
658 // NVVMTargetAttr convert to LLVM attr interface
659 //===----------------------------------------------------------------------===//
660 
661 namespace {
662 struct NVVMTargetConvertToLLVMAttrInterface
663  : public ConvertToLLVMAttrInterface::ExternalModel<
664  NVVMTargetConvertToLLVMAttrInterface, NVVM::NVVMTargetAttr> {
665  /// Configure GPU to NVVM.
666  void populateConvertToLLVMConversionPatterns(
667  Attribute attr, ConversionTarget &target,
668  LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) const;
669 };
670 } // namespace
671 
void NVVMTargetConvertToLLVMAttrInterface::
    populateConvertToLLVMConversionPatterns(Attribute attr,
                                            ConversionTarget &target,
                                            LLVMTypeConverter &typeConverter,
                                            RewritePatternSet &patterns) const {
  configureGpuToNVVMConversionLegality(target);
  configureGpuToNVVMTypeConverter(typeConverter);
  populateGpuToNVVMConversionPatterns(typeConverter, patterns);
}

void mlir::registerConvertGpuToNVVMInterface(DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, NVVM::NVVMDialect *dialect) {
    NVVM::NVVMTargetAttr::attachInterface<NVVMTargetConvertToLLVMAttrInterface>(
        *ctx);
  });
}
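
// Sketch of typical client wiring (illustrative; the surrounding setup is not
// part of this file):
//   DialectRegistry registry;
//   registerConvertGpuToNVVMInterface(registry);
//   MLIRContext context(registry);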