//===- LowerGpuOpsToNVVMOps.cpp - MLIR GPU to NVVM lowering passes --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to generate NVVM IR operations for higher-level
// GPU operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"

// Note: the remaining MLIR #include directives of the original file (LLVM
// conversion, dialect, IR, and transform headers) are omitted from this
// listing.

#include "../GPUCommon/GPUOpsLowering.h"
#include "../GPUCommon/IndexIntrinsicsOpLowering.h"
#include "../GPUCommon/OpToFuncCallLowering.h"
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_CONVERTGPUOPSTONVVMOPS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {

/// Convert gpu dialect shfl mode enum to the equivalent nvvm one.
static NVVM::ShflKind convertShflKind(gpu::ShuffleMode mode) {
  switch (mode) {
  case gpu::ShuffleMode::XOR:
    return NVVM::ShflKind::bfly;
  case gpu::ShuffleMode::UP:
    return NVVM::ShflKind::up;
  case gpu::ShuffleMode::DOWN:
    return NVVM::ShflKind::down;
  case gpu::ShuffleMode::IDX:
    return NVVM::ShflKind::idx;
  }
  llvm_unreachable("unknown shuffle mode");
}

static std::optional<NVVM::ReduxKind>
convertReduxKind(gpu::AllReduceOperation mode) {
  switch (mode) {
  case gpu::AllReduceOperation::ADD:
    return NVVM::ReduxKind::ADD;
  case gpu::AllReduceOperation::MUL:
    return std::nullopt;
  case gpu::AllReduceOperation::MINSI:
    return NVVM::ReduxKind::MIN;
  case gpu::AllReduceOperation::MINUI:
    return std::nullopt;
  case gpu::AllReduceOperation::MINNUMF:
    return NVVM::ReduxKind::MIN;
  case gpu::AllReduceOperation::MAXSI:
    return NVVM::ReduxKind::MAX;
  case gpu::AllReduceOperation::MAXUI:
    return std::nullopt;
  case gpu::AllReduceOperation::MAXNUMF:
    return NVVM::ReduxKind::MAX;
  case gpu::AllReduceOperation::AND:
    return NVVM::ReduxKind::AND;
  case gpu::AllReduceOperation::OR:
    return NVVM::ReduxKind::OR;
  case gpu::AllReduceOperation::XOR:
    return NVVM::ReduxKind::XOR;
  case gpu::AllReduceOperation::MINIMUMF:
  case gpu::AllReduceOperation::MAXIMUMF:
    return std::nullopt;
  }
  return std::nullopt;
}

/// Lowers gpu.subgroup_reduce into the nvvm.redux op. The op must be run by
/// the entire subgroup, otherwise the behaviour is undefined.
struct GPUSubgroupReduceOpLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupReduceOp> {
  using ConvertOpToLLVMPattern<gpu::SubgroupReduceOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getClusterSize())
      return rewriter.notifyMatchFailure(
          op, "lowering for clustered reduce not implemented");

    if (!op.getUniform())
      return rewriter.notifyMatchFailure(
          op, "cannot be lowered to redux as the op must be run "
              "uniformly (entire subgroup).");
    if (!op.getValue().getType().isInteger(32))
      return rewriter.notifyMatchFailure(op, "unsupported data type");

    std::optional<NVVM::ReduxKind> mode = convertReduxKind(op.getOp());
    if (!mode.has_value())
      return rewriter.notifyMatchFailure(
          op, "unsupported reduction mode for redux");

    Location loc = op->getLoc();
    auto int32Type = IntegerType::get(rewriter.getContext(), 32);
    Value offset = rewriter.create<LLVM::ConstantOp>(loc, int32Type, -1);

    auto reduxOp = rewriter.create<NVVM::ReduxOp>(
        loc, int32Type, op.getValue(), mode.value(), offset);

    rewriter.replaceOp(op, reduxOp->getResult(0));
    return success();
  }
};
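
// Illustrative example (schematic IR, exact printed syntax may differ): a
// uniform 32-bit integer subgroup reduction such as
//   %sum = gpu.subgroup_reduce add %val uniform : (i32) -> i32
// is rewritten into an nvvm.redux.sync op with kind `add`, the original
// value, and an i32 constant -1 as the membermask (all lanes participate).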

struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
  using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern;

  /// Lowers a shuffle to the corresponding NVVM op.
  ///
  /// Convert the `width` argument into an activeMask (a bitmask which
  /// specifies which threads participate in the shuffle) and a maskAndClamp
  /// (specifying the highest lane which participates in the shuffle).
  ///
  ///     %one = llvm.constant(1 : i32) : i32
  ///     %minus_one = llvm.constant(-1 : i32) : i32
  ///     %thirty_two = llvm.constant(32 : i32) : i32
  ///     %num_lanes = llvm.sub %thirty_two, %width : i32
  ///     %active_mask = llvm.lshr %minus_one, %num_lanes : i32
  ///     %mask_and_clamp = llvm.sub %width, %one : i32
  ///     %shfl = nvvm.shfl.sync.bfly %active_mask, %value, %offset,
  ///         %mask_and_clamp : !llvm<"{ float, i1 }">
  ///     %shfl_value = llvm.extractvalue %shfl[0] : !llvm<"{ float, i1 }">
  ///     %shfl_pred = llvm.extractvalue %shfl[1] : !llvm<"{ float, i1 }">
  LogicalResult
  matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();

    auto valueTy = adaptor.getValue().getType();
    auto int32Type = IntegerType::get(rewriter.getContext(), 32);
    auto predTy = IntegerType::get(rewriter.getContext(), 1);

    Value one = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 1);
    Value minusOne = rewriter.create<LLVM::ConstantOp>(loc, int32Type, -1);
    Value thirtyTwo = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 32);
    Value numLeadInactiveLane = rewriter.create<LLVM::SubOp>(
        loc, int32Type, thirtyTwo, adaptor.getWidth());
    // Bit mask of active lanes: `(-1) >> (32 - activeWidth)`.
    Value activeMask = rewriter.create<LLVM::LShrOp>(loc, int32Type, minusOne,
                                                     numLeadInactiveLane);
    Value maskAndClamp;
    if (op.getMode() == gpu::ShuffleMode::UP) {
      // Clamp lane: `32 - activeWidth`
      maskAndClamp = numLeadInactiveLane;
    } else {
      // Clamp lane: `activeWidth - 1`
      maskAndClamp =
          rewriter.create<LLVM::SubOp>(loc, int32Type, adaptor.getWidth(), one);
    }

    bool predIsUsed = !op->getResult(1).use_empty();
    UnitAttr returnValueAndIsValidAttr = nullptr;
    Type resultTy = valueTy;
    if (predIsUsed) {
      returnValueAndIsValidAttr = rewriter.getUnitAttr();
      resultTy = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
                                                  {valueTy, predTy});
    }
    Value shfl = rewriter.create<NVVM::ShflOp>(
        loc, resultTy, activeMask, adaptor.getValue(), adaptor.getOffset(),
        maskAndClamp, convertShflKind(op.getMode()), returnValueAndIsValidAttr);
    if (predIsUsed) {
      Value shflValue = rewriter.create<LLVM::ExtractValueOp>(loc, shfl, 0);
      Value isActiveSrcLane =
          rewriter.create<LLVM::ExtractValueOp>(loc, shfl, 1);
      rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
    } else {
      rewriter.replaceOp(op, {shfl, nullptr});
    }
    return success();
  }
};
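
// Illustrative example (schematic IR): for
//   %res, %valid = gpu.shuffle xor %val, %offset, %width : f32
// the pattern emits the mask computation shown in the doc comment above and a
// single nvvm.shfl.sync op in `bfly` mode; the i1 "valid" result is only
// materialized (via a {f32, i1} struct and extractvalue) when it has uses.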

struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
  using ConvertOpToLLVMPattern<gpu::LaneIdOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::LaneIdOp op, gpu::LaneIdOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    MLIRContext *context = rewriter.getContext();
    LLVM::ConstantRangeAttr bounds = nullptr;
    if (std::optional<APInt> upperBound = op.getUpperBound())
      bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
          /*bitWidth=*/32, /*lower=*/0, upperBound->getZExtValue());
    else
      bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
          /*bitWidth=*/32, /*lower=*/0, /*upper=*/kWarpSize);
    Value newOp =
        rewriter.create<NVVM::LaneIdOp>(loc, rewriter.getI32Type(), bounds);
    // Truncate or extend the result depending on the index bitwidth specified
    // by the LLVMTypeConverter options.
    const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
    if (indexBitwidth > 32) {
      newOp = rewriter.create<LLVM::SExtOp>(
          loc, IntegerType::get(context, indexBitwidth), newOp);
    } else if (indexBitwidth < 32) {
      newOp = rewriter.create<LLVM::TruncOp>(
          loc, IntegerType::get(context, indexBitwidth), newOp);
    }
    rewriter.replaceOp(op, {newOp});
    return success();
  }
};
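
// Illustrative example (schematic IR): `%id = gpu.lane_id` becomes an
// nvvm.read.ptx.sreg.laneid annotated with a [0, 32) range attribute (or the
// user-provided upper bound), followed by a sext/trunc to the converter's
// index bitwidth when that bitwidth is not 32.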

/// Lowering of cf.assert into a conditional __assertfail.
struct AssertOpToAssertfailLowering
    : public ConvertOpToLLVMPattern<cf::AssertOp> {
  using ConvertOpToLLVMPattern<cf::AssertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(cf::AssertOp assertOp, cf::AssertOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MLIRContext *ctx = rewriter.getContext();
    Location loc = assertOp.getLoc();
    Type i8Type = typeConverter->convertType(rewriter.getIntegerType(8));
    Type i32Type = typeConverter->convertType(rewriter.getIntegerType(32));
    Type i64Type = typeConverter->convertType(rewriter.getIntegerType(64));
    Type ptrType = LLVM::LLVMPointerType::get(ctx);
    Type voidType = LLVM::LLVMVoidType::get(ctx);

    // Find or create __assertfail function declaration.
    auto moduleOp = assertOp->getParentOfType<gpu::GPUModuleOp>();
    auto assertfailType = LLVM::LLVMFunctionType::get(
        voidType, {ptrType, ptrType, i32Type, ptrType, i64Type});
    LLVM::LLVMFuncOp assertfailDecl = getOrDefineFunction(
        moduleOp, loc, rewriter, "__assertfail", assertfailType);
    assertfailDecl.setPassthroughAttr(
        ArrayAttr::get(ctx, StringAttr::get(ctx, "noreturn")));

    // Split blocks and insert conditional branch.
    // ^before:
    //   ...
    //   cf.cond_br %condition, ^after, ^assert
    // ^assert:
    //   cf.assert
    //   cf.br ^after
    // ^after:
    //   ...
    Block *beforeBlock = assertOp->getBlock();
    Block *assertBlock =
        rewriter.splitBlock(beforeBlock, assertOp->getIterator());
    Block *afterBlock =
        rewriter.splitBlock(assertBlock, ++assertOp->getIterator());
    rewriter.setInsertionPointToEnd(beforeBlock);
    rewriter.create<cf::CondBranchOp>(loc, adaptor.getArg(), afterBlock,
                                      assertBlock);
    rewriter.setInsertionPointToEnd(assertBlock);
    rewriter.create<cf::BranchOp>(loc, afterBlock);

    // Continue cf.assert lowering.
    rewriter.setInsertionPoint(assertOp);

    // Populate file name, line number and function name from the location of
    // the AssertOp.
    StringRef fileName = "(unknown)";
    StringRef funcName = "(unknown)";
    int32_t fileLine = 0;
    while (auto callSiteLoc = dyn_cast<CallSiteLoc>(loc))
      loc = callSiteLoc.getCallee();
    if (auto fileLineColLoc = dyn_cast<FileLineColRange>(loc)) {
      fileName = fileLineColLoc.getFilename().strref();
      fileLine = fileLineColLoc.getStartLine();
    } else if (auto nameLoc = dyn_cast<NameLoc>(loc)) {
      funcName = nameLoc.getName().strref();
      if (auto fileLineColLoc =
              dyn_cast<FileLineColRange>(nameLoc.getChildLoc())) {
        fileName = fileLineColLoc.getFilename().strref();
        fileLine = fileLineColLoc.getStartLine();
      }
    }

    // Create constants.
    auto getGlobal = [&](LLVM::GlobalOp global) {
      // Get a pointer to the format string's first element.
      Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
          loc, LLVM::LLVMPointerType::get(ctx, global.getAddrSpace()),
          global.getSymNameAttr());
      Value start =
          rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
                                       globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
      return start;
    };
    Value assertMessage = getGlobal(getOrCreateStringConstant(
        rewriter, loc, moduleOp, i8Type, "assert_message_",
        assertOp.getMsg()));
    Value assertFile = getGlobal(getOrCreateStringConstant(
        rewriter, loc, moduleOp, i8Type, "assert_file_", fileName));
    Value assertFunc = getGlobal(getOrCreateStringConstant(
        rewriter, loc, moduleOp, i8Type, "assert_func_", funcName));
    Value assertLine =
        rewriter.create<LLVM::ConstantOp>(loc, i32Type, fileLine);
    Value c1 = rewriter.create<LLVM::ConstantOp>(loc, i64Type, 1);

    // Insert function call to __assertfail.
    SmallVector<Value> arguments{assertMessage, assertFile, assertLine,
                                 assertFunc, c1};
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(assertOp, assertfailDecl,
                                              arguments);
    return success();
  }
};
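
// For reference, the device-side entry point targeted above has roughly the
// C signature declared by the CUDA device runtime:
//   void __assertfail(const char *message, const char *file, unsigned line,
//                     const char *function, size_t charSize);
// The pattern passes the assert message, source file, line number, enclosing
// function name, and a character size of 1.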

/// Import the GPU Ops to NVVM Patterns.
#include "GPUToNVVM.cpp.inc"

/// A pass that replaces all occurrences of GPU device operations with their
/// corresponding NVVM equivalent.
///
/// This pass only handles device code and is not meant to be run on GPU host
/// code.
struct LowerGpuOpsToNVVMOpsPass
    : public impl::ConvertGpuOpsToNVVMOpsBase<LowerGpuOpsToNVVMOpsPass> {
  using Base::Base;

  void runOnOperation() override {
    gpu::GPUModuleOp m = getOperation();

    // Request C wrapper emission.
    for (auto func : m.getOps<func::FuncOp>()) {
      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                    UnitAttr::get(&getContext()));
    }

    // Customize the bitwidth used for the device side index computations.
    LowerToLLVMOptions options(
        m.getContext(),
        DataLayout(cast<DataLayoutOpInterface>(m.getOperation())));
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);
    options.useBarePtrCallConv = useBarePtrCallConv;

    // Apply in-dialect lowering. In-dialect lowering will replace
    // ops which need to be lowered further, which is not supported by a
    // single conversion pass.
    {
      RewritePatternSet patterns(m.getContext());
      populateGpuRewritePatterns(patterns);
      if (failed(applyPatternsGreedily(m, std::move(patterns))))
        return signalPassFailure();
    }

    LLVMTypeConverter converter(m.getContext(), options);
    configureGpuToNVVMTypeConverter(converter);
    RewritePatternSet llvmPatterns(m.getContext());

    arith::populateArithToLLVMConversionPatterns(converter, llvmPatterns);
    cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns);
    populateFuncToLLVMConversionPatterns(converter, llvmPatterns);
    populateFinalizeMemRefToLLVMConversionPatterns(converter, llvmPatterns);
    populateGpuToNVVMConversionPatterns(converter, llvmPatterns);
    populateGpuWMMAToNVVMConversionPatterns(converter, llvmPatterns);
    populateVectorToLLVMConversionPatterns(converter, llvmPatterns);
    if (this->hasRedux)
      populateGpuSubgroupReduceOpLoweringPattern(converter, llvmPatterns);
    LLVMConversionTarget target(getContext());
    configureGpuToNVVMConversionLegality(target);
    if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
      signalPassFailure();
  }
};
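
// Usage sketch (assuming the registered pass name and option spelling from
// Passes.td): the pass runs on gpu.module ops nested in the top-level module,
// e.g.
//   mlir-opt kernel.mlir \
//     -pass-pipeline='builtin.module(gpu.module(convert-gpu-to-nvvm))'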

} // namespace

void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
  target.addIllegalOp<func::FuncOp>();
  target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
  target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
  target.addIllegalDialect<gpu::GPUDialect>();
  target.addIllegalOp<LLVM::CopySignOp, LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op,
                      LLVM::FAbsOp, LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FMAOp,
                      LLVM::FRemOp, LLVM::LogOp, LLVM::Log10Op, LLVM::Log2Op,
                      LLVM::PowOp, LLVM::RoundEvenOp, LLVM::RoundOp,
                      LLVM::SinOp, LLVM::SqrtOp>();

  // TODO: Remove once we support replacing non-root ops.
  target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp>();
}
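
// Note: the LLVM math ops marked illegal above correspond to intrinsics the
// NVPTX backend generally cannot lower; keeping them illegal steers GPU math
// through the libdevice call lowerings registered by populateOpPatterns
// below (e.g. __nv_expf) rather than through LLVM intrinsic-based lowerings.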

void mlir::configureGpuToNVVMTypeConverter(LLVMTypeConverter &converter) {
  // NVVM uses alloca in the default address space to represent private
  // memory allocations, so drop private annotations. NVVM uses address
  // space 3 for shared memory. NVVM uses the default address space to
  // represent global memory.
  populateGpuMemorySpaceAttributeConversions(
      converter, [](gpu::AddressSpace space) -> unsigned {
        switch (space) {
        case gpu::AddressSpace::Global:
          return static_cast<unsigned>(
              NVVM::NVVMMemorySpace::kGlobalMemorySpace);
        case gpu::AddressSpace::Workgroup:
          return static_cast<unsigned>(
              NVVM::NVVMMemorySpace::kSharedMemorySpace);
        case gpu::AddressSpace::Private:
          return 0;
        }
        llvm_unreachable("unknown address space enum value");
        return 0;
      });
  // Lowering for MMAMatrixType.
  converter.addConversion([&](gpu::MMAMatrixType type) -> Type {
    return convertMMAToLLVMType(type);
  });
}
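
// Illustrative effect of the address-space mapping above: a type such as
//   memref<16xf32, #gpu.address_space<workgroup>>
// converts to an LLVM pointer in addrspace(3), #gpu.address_space<global>
// maps to the NVVM global memory space, and #gpu.address_space<private> is
// dropped to the default addrspace(0).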

template <typename OpTy>
static void populateOpPatterns(const LLVMTypeConverter &converter,
                               RewritePatternSet &patterns, StringRef f32Func,
                               StringRef f64Func, StringRef f32ApproxFunc = "",
                               StringRef f16Func = "") {
  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
                                           f32ApproxFunc, f16Func);
}
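
// Illustrative example: populateOpPatterns<math::ExpOp>(..., "__nv_expf",
// "__nv_exp", "__nv_fast_expf") rewrites
//   %y = math.exp %x : f32
// into a libdevice call, roughly
//   %y = llvm.call @__nv_expf(%x) : (f32) -> f32
// with f64 operands dispatching to @__nv_exp; vector operands are unrolled to
// scalars first by the ScalarizeVectorOpLowering pattern added above.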

void mlir::populateGpuSubgroupReduceOpLoweringPattern(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  patterns.add<GPUSubgroupReduceOpLowering>(converter);
}

void mlir::populateGpuToNVVMConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  using gpu::index_lowering::IndexKind;
  using gpu::index_lowering::IntrType;
  populateWithGenerated(patterns);
  patterns.add<GPUPrintfOpToVPrintfLowering, AssertOpToAssertfailLowering>(
      converter);
  patterns.add<
      gpu::index_lowering::OpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp,
                                      NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>>(
      converter, IndexKind::Block, IntrType::Id);
  patterns.add<
      gpu::index_lowering::OpLowering<gpu::BlockDimOp, NVVM::BlockDimXOp,
                                      NVVM::BlockDimYOp, NVVM::BlockDimZOp>>(
      converter, IndexKind::Block, IntrType::Dim);
  patterns.add<
      gpu::index_lowering::OpLowering<gpu::ClusterIdOp, NVVM::ClusterIdXOp,
                                      NVVM::ClusterIdYOp, NVVM::ClusterIdZOp>>(
      converter, IndexKind::Other, IntrType::Id);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::ClusterDimOp, NVVM::ClusterDimXOp, NVVM::ClusterDimYOp,
      NVVM::ClusterDimZOp>>(converter, IndexKind::Other, IntrType::Dim);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::ClusterBlockIdOp, NVVM::BlockInClusterIdXOp,
      NVVM::BlockInClusterIdYOp, NVVM::BlockInClusterIdZOp>>(
      converter, IndexKind::Other, IntrType::Id);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::ClusterDimBlocksOp, NVVM::ClusterDimBlocksXOp,
      NVVM::ClusterDimBlocksYOp, NVVM::ClusterDimBlocksZOp>>(
      converter, IndexKind::Other, IntrType::Dim);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::BlockIdOp, NVVM::BlockIdXOp, NVVM::BlockIdYOp, NVVM::BlockIdZOp>>(
      converter, IndexKind::Grid, IntrType::Id);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::GridDimOp, NVVM::GridDimXOp, NVVM::GridDimYOp, NVVM::GridDimZOp>>(
      converter, IndexKind::Grid, IntrType::Dim);
  patterns.add<GPULaneIdOpToNVVM, GPUShuffleOpLowering, GPUReturnOpLowering>(
      converter);

  patterns.add<GPUDynamicSharedMemoryOpLowering>(
      converter, NVVM::kSharedMemoryAlignmentBit);

  // Explicitly drop memory space when lowering private memory
  // attributions since NVVM models it as `alloca`s in the default
  // memory space and does not support `alloca`s with addrspace(5).
  patterns.add<GPUFuncOpLowering>(
      converter,
      GPUFuncOpLoweringOptions{
          /*allocaAddrSpace=*/0,
          /*workgroupAddrSpace=*/
          static_cast<unsigned>(NVVM::NVVMMemorySpace::kSharedMemorySpace),
          StringAttr::get(&converter.getContext(),
                          NVVM::NVVMDialect::getKernelFuncAttrName()),
          StringAttr::get(&converter.getContext(),
                          NVVM::NVVMDialect::getMaxntidAttrName())});

  populateOpPatterns<arith::RemFOp>(converter, patterns, "__nv_fmodf",
                                    "__nv_fmod");
  populateOpPatterns<math::AbsFOp>(converter, patterns, "__nv_fabsf",
                                   "__nv_fabs");
  populateOpPatterns<math::AcosOp>(converter, patterns, "__nv_acosf",
                                   "__nv_acos");
  populateOpPatterns<math::AcoshOp>(converter, patterns, "__nv_acoshf",
                                    "__nv_acosh");
  populateOpPatterns<math::AsinOp>(converter, patterns, "__nv_asinf",
                                   "__nv_asin");
  populateOpPatterns<math::AsinhOp>(converter, patterns, "__nv_asinhf",
                                    "__nv_asinh");
  populateOpPatterns<math::AtanOp>(converter, patterns, "__nv_atanf",
                                   "__nv_atan");
  populateOpPatterns<math::Atan2Op>(converter, patterns, "__nv_atan2f",
                                    "__nv_atan2");
  populateOpPatterns<math::AtanhOp>(converter, patterns, "__nv_atanhf",
                                    "__nv_atanh");
  populateOpPatterns<math::CbrtOp>(converter, patterns, "__nv_cbrtf",
                                   "__nv_cbrt");
  populateOpPatterns<math::CeilOp>(converter, patterns, "__nv_ceilf",
                                   "__nv_ceil");
  populateOpPatterns<math::CopySignOp>(converter, patterns, "__nv_copysignf",
                                       "__nv_copysign");
  populateOpPatterns<math::CosOp>(converter, patterns, "__nv_cosf", "__nv_cos",
                                  "__nv_fast_cosf");
  populateOpPatterns<math::CoshOp>(converter, patterns, "__nv_coshf",
                                   "__nv_cosh");
  populateOpPatterns<math::ErfOp>(converter, patterns, "__nv_erff", "__nv_erf");
  populateOpPatterns<math::ExpOp>(converter, patterns, "__nv_expf", "__nv_exp",
                                  "__nv_fast_expf");
  populateOpPatterns<math::Exp2Op>(converter, patterns, "__nv_exp2f",
                                   "__nv_exp2");
  populateOpPatterns<math::ExpM1Op>(converter, patterns, "__nv_expm1f",
                                    "__nv_expm1");
  populateOpPatterns<math::FloorOp>(converter, patterns, "__nv_floorf",
                                    "__nv_floor");
  populateOpPatterns<math::FmaOp>(converter, patterns, "__nv_fmaf", "__nv_fma");
  populateOpPatterns<math::LogOp>(converter, patterns, "__nv_logf", "__nv_log",
                                  "__nv_fast_logf");
  populateOpPatterns<math::Log10Op>(converter, patterns, "__nv_log10f",
                                    "__nv_log10", "__nv_fast_log10f");
  populateOpPatterns<math::Log1pOp>(converter, patterns, "__nv_log1pf",
                                    "__nv_log1p");
  populateOpPatterns<math::Log2Op>(converter, patterns, "__nv_log2f",
                                   "__nv_log2", "__nv_fast_log2f");
  populateOpPatterns<math::PowFOp>(converter, patterns, "__nv_powf", "__nv_pow",
                                   "__nv_fast_powf");
  populateOpPatterns<math::RoundOp>(converter, patterns, "__nv_roundf",
                                    "__nv_round");
  populateOpPatterns<math::RoundEvenOp>(converter, patterns, "__nv_rintf",
                                        "__nv_rint");
  populateOpPatterns<math::RsqrtOp>(converter, patterns, "__nv_rsqrtf",
                                    "__nv_rsqrt");
  populateOpPatterns<math::SinOp>(converter, patterns, "__nv_sinf", "__nv_sin",
                                  "__nv_fast_sinf");
  populateOpPatterns<math::SinhOp>(converter, patterns, "__nv_sinhf",
                                   "__nv_sinh");
  populateOpPatterns<math::SqrtOp>(converter, patterns, "__nv_sqrtf",
                                   "__nv_sqrt");
  populateOpPatterns<math::TanOp>(converter, patterns, "__nv_tanf", "__nv_tan",
                                  "__nv_fast_tanf");
  populateOpPatterns<math::TanhOp>(converter, patterns, "__nv_tanhf",
                                   "__nv_tanh");
}
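
// Note: where a third `f32ApproxFunc` argument is supplied above (e.g.
// "__nv_fast_expf"), it is intended for the f32 lowering when the source op's
// fast-math flags permit approximate functions, as handled by
// OpToFuncCallLowering; otherwise the precise f32 function is called.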

//===----------------------------------------------------------------------===//
// NVVMTargetAttr convert to LLVM attr interface
//===----------------------------------------------------------------------===//

namespace {
struct NVVMTargetConvertToLLVMAttrInterface
    : public ConvertToLLVMAttrInterface::ExternalModel<
          NVVMTargetConvertToLLVMAttrInterface, NVVM::NVVMTargetAttr> {
  /// Configure GPU to NVVM.
  void populateConvertToLLVMConversionPatterns(
      Attribute attr, ConversionTarget &target,
      LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) const;
};
} // namespace

void NVVMTargetConvertToLLVMAttrInterface::
    populateConvertToLLVMConversionPatterns(Attribute attr,
                                            ConversionTarget &target,
                                            LLVMTypeConverter &typeConverter,
                                            RewritePatternSet &patterns) const {
  configureGpuToNVVMConversionLegality(target);
  configureGpuToNVVMTypeConverter(typeConverter);
  populateGpuToNVVMConversionPatterns(typeConverter, patterns);
}

void mlir::NVVM::registerConvertGpuToNVVMInterface(DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, NVVMDialect *dialect) {
    NVVMTargetAttr::attachInterface<NVVMTargetConvertToLLVMAttrInterface>(*ctx);
  });
}
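
// Usage sketch (hypothetical client code): attach the interface before using
// the generic ConvertToLLVM infrastructure so that an NVVMTargetAttr can
// drive the GPU-to-NVVM configuration.
//   DialectRegistry registry;
//   NVVM::registerConvertGpuToNVVMInterface(registry);
//   context.appendDialectRegistry(registry);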