1 //===- LowerGpuOpsToNVVMOps.cpp - MLIR GPU to NVVM lowering passes --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to generate NVVMIR operations for higher-level
10 // GPU operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
33 
34 #include "../GPUCommon/GPUOpsLowering.h"
35 #include "../GPUCommon/IndexIntrinsicsOpLowering.h"
36 #include "../GPUCommon/OpToFuncCallLowering.h"
37 #include <optional>
38 
39 namespace mlir {
40 #define GEN_PASS_DEF_CONVERTGPUOPSTONVVMOPS
41 #include "mlir/Conversion/Passes.h.inc"
42 } // namespace mlir
43 
44 using namespace mlir;
45 
46 namespace {
47 
48 /// Convert gpu dialect shfl mode enum to the equivalent nvvm one.
49 static NVVM::ShflKind convertShflKind(gpu::ShuffleMode mode) {
50  switch (mode) {
51  case gpu::ShuffleMode::XOR:
52  return NVVM::ShflKind::bfly;
53  case gpu::ShuffleMode::UP:
54  return NVVM::ShflKind::up;
55  case gpu::ShuffleMode::DOWN:
56  return NVVM::ShflKind::down;
57  case gpu::ShuffleMode::IDX:
58  return NVVM::ShflKind::idx;
59  }
60  llvm_unreachable("unknown shuffle mode");
61 }
62 
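/// Maps a gpu reduction kind to the matching NVVM redux kind, or std::nullopt
/// for kinds this lowering does not handle (e.g. mul, the unsigned min/max
/// variants, and the NaN-propagating minimumf/maximumf variants).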
63 static std::optional<NVVM::ReduxKind>
64 convertReduxKind(gpu::AllReduceOperation mode) {
65  switch (mode) {
66  case gpu::AllReduceOperation::ADD:
67  return NVVM::ReduxKind::ADD;
68  case gpu::AllReduceOperation::MUL:
69  return std::nullopt;
70  case gpu::AllReduceOperation::MINSI:
 71  return NVVM::ReduxKind::MIN;
 72  case gpu::AllReduceOperation::MINUI:
 73  return std::nullopt;
74  case gpu::AllReduceOperation::MINNUMF:
75  return NVVM::ReduxKind::MIN;
76  case gpu::AllReduceOperation::MAXSI:
77  return NVVM::ReduxKind::MAX;
78  case gpu::AllReduceOperation::MAXUI:
79  return std::nullopt;
80  case gpu::AllReduceOperation::MAXNUMF:
81  return NVVM::ReduxKind::MAX;
82  case gpu::AllReduceOperation::AND:
83  return NVVM::ReduxKind::AND;
84  case gpu::AllReduceOperation::OR:
85  return NVVM::ReduxKind::OR;
86  case gpu::AllReduceOperation::XOR:
87  return NVVM::ReduxKind::XOR;
88  case gpu::AllReduceOperation::MINIMUMF:
89  case gpu::AllReduceOperation::MAXIMUMF:
90  return std::nullopt;
91  }
92  return std::nullopt;
93 }
94 
 95 /// Lowers gpu.subgroup_reduce into the nvvm.redux op. The op must be run by
 96 /// the entire subgroup, otherwise the behaviour is undefined.
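///
/// Illustrative rewrite (IR syntax shown informally, for exposition only):
///   %sum = gpu.subgroup_reduce add %val uniform : (i32) -> i32
/// becomes an nvvm.redux.sync over the full warp, roughly:
///   %all = llvm.mlir.constant(-1 : i32) : i32
///   %sum = nvvm.redux.sync add %val, %all : i32 -> i32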
97 struct GPUSubgroupReduceOpLowering
 98  : public ConvertOpToLLVMPattern<gpu::SubgroupReduceOp> {
 99  using ConvertOpToLLVMPattern<gpu::SubgroupReduceOp>::ConvertOpToLLVMPattern;
100  LogicalResult
101 
102  matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
103  ConversionPatternRewriter &rewriter) const override {
104  if (op.getClusterSize())
105  return rewriter.notifyMatchFailure(
106  op, "lowering for clustered reduce not implemented");
107 
108  if (!op.getUniform())
109  return rewriter.notifyMatchFailure(
110  op, "cannot be lowered to redux as the op must be run "
111  "uniformly (entire subgroup).");
112  if (!op.getValue().getType().isInteger(32))
113  return rewriter.notifyMatchFailure(op, "unsupported data type");
114 
115  std::optional<NVVM::ReduxKind> mode = convertReduxKind(op.getOp());
116  if (!mode.has_value())
117  return rewriter.notifyMatchFailure(
118  op, "unsupported reduction mode for redux");
119 
120  Location loc = op->getLoc();
121  auto int32Type = IntegerType::get(rewriter.getContext(), 32);
122  Value offset = LLVM::ConstantOp::create(rewriter, loc, int32Type, -1);
123 
124  auto reduxOp = NVVM::ReduxOp::create(rewriter, loc, int32Type,
125  op.getValue(), mode.value(), offset);
126 
127  rewriter.replaceOp(op, reduxOp->getResult(0));
128  return success();
129  }
130 };
131 
 132 struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
 133  using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern;
134 
135  /// Lowers a shuffle to the corresponding NVVM op.
136  ///
137  /// Convert the `width` argument into an activeMask (a bitmask which specifies
138  /// which threads participate in the shuffle) and a maskAndClamp (specifying
139  /// the highest lane which participates in the shuffle).
140  ///
141  /// %one = llvm.constant(1 : i32) : i32
142  /// %minus_one = llvm.constant(-1 : i32) : i32
143  /// %thirty_two = llvm.constant(32 : i32) : i32
144  /// %num_lanes = llvm.sub %thirty_two, %width : i32
145  /// %active_mask = llvm.lshr %minus_one, %num_lanes : i32
146  /// %mask_and_clamp = llvm.sub %width, %one : i32
147  /// %shfl = nvvm.shfl.sync.bfly %active_mask, %value, %offset,
148  /// %mask_and_clamp : !llvm<"{ float, i1 }">
149  /// %shfl_value = llvm.extractvalue %shfl[0] :
150  /// !llvm<"{ float, i1 }">
151  /// %shfl_pred = llvm.extractvalue %shfl[1] :
152  /// !llvm<"{ float, i1 }">
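 ///
 /// For example, with %width = 16 the computation above produces
 /// %active_mask = 0x0000ffff (i.e. -1 lshr 16) and %mask_and_clamp = 15;
 /// for gpu::ShuffleMode::UP the clamp is instead 32 - %width = 16.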
153  LogicalResult
154  matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
155  ConversionPatternRewriter &rewriter) const override {
156  Location loc = op->getLoc();
157 
158  auto valueTy = adaptor.getValue().getType();
159  auto int32Type = IntegerType::get(rewriter.getContext(), 32);
160  auto predTy = IntegerType::get(rewriter.getContext(), 1);
161 
162  Value one = LLVM::ConstantOp::create(rewriter, loc, int32Type, 1);
163  Value minusOne = LLVM::ConstantOp::create(rewriter, loc, int32Type, -1);
164  Value thirtyTwo = LLVM::ConstantOp::create(rewriter, loc, int32Type, 32);
165  Value numLeadInactiveLane = LLVM::SubOp::create(
166  rewriter, loc, int32Type, thirtyTwo, adaptor.getWidth());
167  // Bit mask of active lanes: `(-1) >> (32 - activeWidth)`.
168  Value activeMask = LLVM::LShrOp::create(rewriter, loc, int32Type, minusOne,
169  numLeadInactiveLane);
170  Value maskAndClamp;
171  if (op.getMode() == gpu::ShuffleMode::UP) {
172  // Clamp lane: `32 - activeWidth`
173  maskAndClamp = numLeadInactiveLane;
174  } else {
175  // Clamp lane: `activeWidth - 1`
176  maskAndClamp = LLVM::SubOp::create(rewriter, loc, int32Type,
177  adaptor.getWidth(), one);
178  }
179 
180  bool predIsUsed = !op->getResult(1).use_empty();
181  UnitAttr returnValueAndIsValidAttr = nullptr;
182  Type resultTy = valueTy;
183  if (predIsUsed) {
184  returnValueAndIsValidAttr = rewriter.getUnitAttr();
185  resultTy = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
186  {valueTy, predTy});
187  }
188  Value shfl = NVVM::ShflOp::create(
189  rewriter, loc, resultTy, activeMask, adaptor.getValue(),
190  adaptor.getOffset(), maskAndClamp, convertShflKind(op.getMode()),
191  returnValueAndIsValidAttr);
192  if (predIsUsed) {
193  Value shflValue = LLVM::ExtractValueOp::create(rewriter, loc, shfl, 0);
194  Value isActiveSrcLane =
195  LLVM::ExtractValueOp::create(rewriter, loc, shfl, 1);
196  rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
197  } else {
198  rewriter.replaceOp(op, {shfl, nullptr});
199  }
200  return success();
201  }
202 };
203 
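/// Lowers gpu.lane_id to NVVM's lane-id intrinsic, attaching a [0, kWarpSize)
/// result range when the op carries no tighter upper bound, and then
/// truncating or sign-extending the i32 result to the configured index
/// bitwidth.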
 204 struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
 205  using ConvertOpToLLVMPattern<gpu::LaneIdOp>::ConvertOpToLLVMPattern;
206 
207  LogicalResult
208  matchAndRewrite(gpu::LaneIdOp op, gpu::LaneIdOp::Adaptor adaptor,
209  ConversionPatternRewriter &rewriter) const override {
210  auto loc = op->getLoc();
211  MLIRContext *context = rewriter.getContext();
212  LLVM::ConstantRangeAttr bounds = nullptr;
213  if (std::optional<APInt> upperBound = op.getUpperBound())
214  bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
215  /*bitWidth=*/32, /*lower=*/0, upperBound->getZExtValue());
216  else
217  bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
218  /*bitWidth=*/32, /*lower=*/0, /*upper=*/kWarpSize);
219  Value newOp =
220  NVVM::LaneIdOp::create(rewriter, loc, rewriter.getI32Type(), bounds);
221  // Truncate or extend the result depending on the index bitwidth specified
222  // by the LLVMTypeConverter options.
223  const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
224  if (indexBitwidth > 32) {
225  newOp = LLVM::SExtOp::create(
226  rewriter, loc, IntegerType::get(context, indexBitwidth), newOp);
227  } else if (indexBitwidth < 32) {
228  newOp = LLVM::TruncOp::create(
229  rewriter, loc, IntegerType::get(context, indexBitwidth), newOp);
230  }
231  rewriter.replaceOp(op, {newOp});
232  return success();
233  }
234 };
235 
236 /// Lowering of cf.assert into a conditional __assertfail.
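///
/// The callee declared below mirrors the CUDA device-side prototype,
/// approximately (parameter names here are illustrative):
///   void __assertfail(const char *message, const char *file, unsigned line,
///                     const char *function, size_t charSize);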
237 struct AssertOpToAssertfailLowering
 238  : public ConvertOpToLLVMPattern<cf::AssertOp> {
 239  using ConvertOpToLLVMPattern<cf::AssertOp>::ConvertOpToLLVMPattern;
240 
241  LogicalResult
242  matchAndRewrite(cf::AssertOp assertOp, cf::AssertOpAdaptor adaptor,
243  ConversionPatternRewriter &rewriter) const override {
244  MLIRContext *ctx = rewriter.getContext();
245  Location loc = assertOp.getLoc();
246  Type i8Type = typeConverter->convertType(rewriter.getIntegerType(8));
247  Type i32Type = typeConverter->convertType(rewriter.getIntegerType(32));
248  Type i64Type = typeConverter->convertType(rewriter.getIntegerType(64));
249  Type ptrType = LLVM::LLVMPointerType::get(ctx);
250  Type voidType = LLVM::LLVMVoidType::get(ctx);
251 
252  // Find or create __assertfail function declaration.
253  auto moduleOp = assertOp->getParentOfType<gpu::GPUModuleOp>();
254  auto assertfailType = LLVM::LLVMFunctionType::get(
255  voidType, {ptrType, ptrType, i32Type, ptrType, i64Type});
256  LLVM::LLVMFuncOp assertfailDecl = getOrDefineFunction(
257  moduleOp, loc, rewriter, "__assertfail", assertfailType);
258  assertfailDecl.setPassthroughAttr(
259  ArrayAttr::get(ctx, StringAttr::get(ctx, "noreturn")));
260 
261  // Split blocks and insert conditional branch.
262  // ^before:
263  // ...
264  // cf.cond_br %condition, ^after, ^assert
265  // ^assert:
266  // cf.assert
267  // cf.br ^after
268  // ^after:
269  // ...
270  Block *beforeBlock = assertOp->getBlock();
271  Block *assertBlock =
272  rewriter.splitBlock(beforeBlock, assertOp->getIterator());
273  Block *afterBlock =
274  rewriter.splitBlock(assertBlock, ++assertOp->getIterator());
275  rewriter.setInsertionPointToEnd(beforeBlock);
276  cf::CondBranchOp::create(rewriter, loc, adaptor.getArg(), afterBlock,
277  assertBlock);
278  rewriter.setInsertionPointToEnd(assertBlock);
279  cf::BranchOp::create(rewriter, loc, afterBlock);
280 
281  // Continue cf.assert lowering.
282  rewriter.setInsertionPoint(assertOp);
283 
 284  // Populate the file name, line number and function name from the location
 285  // of the AssertOp.
286  StringRef fileName = "(unknown)";
287  StringRef funcName = "(unknown)";
288  int32_t fileLine = 0;
289  while (auto callSiteLoc = dyn_cast<CallSiteLoc>(loc))
290  loc = callSiteLoc.getCallee();
291  if (auto fileLineColLoc = dyn_cast<FileLineColRange>(loc)) {
292  fileName = fileLineColLoc.getFilename().strref();
293  fileLine = fileLineColLoc.getStartLine();
294  } else if (auto nameLoc = dyn_cast<NameLoc>(loc)) {
295  funcName = nameLoc.getName().strref();
296  if (auto fileLineColLoc =
297  dyn_cast<FileLineColRange>(nameLoc.getChildLoc())) {
298  fileName = fileLineColLoc.getFilename().strref();
299  fileLine = fileLineColLoc.getStartLine();
300  }
301  }
302 
303  // Create constants.
304  auto getGlobal = [&](LLVM::GlobalOp global) {
305  // Get a pointer to the format string's first element.
306  Value globalPtr = LLVM::AddressOfOp::create(
307  rewriter, loc, LLVM::LLVMPointerType::get(ctx, global.getAddrSpace()),
308  global.getSymNameAttr());
309  Value start =
310  LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(),
311  globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
312  return start;
313  };
314  Value assertMessage = getGlobal(getOrCreateStringConstant(
315  rewriter, loc, moduleOp, i8Type, "assert_message_", assertOp.getMsg()));
316  Value assertFile = getGlobal(getOrCreateStringConstant(
317  rewriter, loc, moduleOp, i8Type, "assert_file_", fileName));
318  Value assertFunc = getGlobal(getOrCreateStringConstant(
319  rewriter, loc, moduleOp, i8Type, "assert_func_", funcName));
320  Value assertLine =
321  LLVM::ConstantOp::create(rewriter, loc, i32Type, fileLine);
322  Value c1 = LLVM::ConstantOp::create(rewriter, loc, i64Type, 1);
323 
324  // Insert function call to __assertfail.
325  SmallVector<Value> arguments{assertMessage, assertFile, assertLine,
326  assertFunc, c1};
327  rewriter.replaceOpWithNewOp<LLVM::CallOp>(assertOp, assertfailDecl,
328  arguments);
329  return success();
330  }
331 };
332 
333 /// Import the GPU Ops to NVVM Patterns.
334 #include "GPUToNVVM.cpp.inc"
335 
336 /// A pass that replaces all occurrences of GPU device operations with their
337 /// corresponding NVVM equivalent.
338 ///
339 /// This pass only handles device code and is not meant to be run on GPU host
340 /// code.
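///
/// Illustrative usage (the command-line flag is assumed from the
/// ConvertGpuOpsToNVVMOps pass definition):
///   mlir-opt --convert-gpu-to-nvvm kernel.mlir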
341 struct LowerGpuOpsToNVVMOpsPass final
342  : public impl::ConvertGpuOpsToNVVMOpsBase<LowerGpuOpsToNVVMOpsPass> {
343  using Base::Base;
344 
345  void getDependentDialects(DialectRegistry &registry) const override {
 346  Base::getDependentDialects(registry);
 347  registerConvertToLLVMDependentDialectLoading(registry);
348  }
349 
350  void runOnOperation() override {
351  gpu::GPUModuleOp m = getOperation();
352 
353  // Request C wrapper emission.
354  for (auto func : m.getOps<func::FuncOp>()) {
 355  func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
 356  UnitAttr::get(&getContext()));
357  }
358 
 359  // Customize the bitwidth used for the device side index computations.
 360  LowerToLLVMOptions options(
361  m.getContext(),
362  DataLayout(cast<DataLayoutOpInterface>(m.getOperation())));
363  if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
364  options.overrideIndexBitwidth(indexBitwidth);
365  options.useBarePtrCallConv = useBarePtrCallConv;
366 
367  // Apply in-dialect lowering. In-dialect lowering will replace
368  // ops which need to be lowered further, which is not supported by a
369  // single conversion pass.
370  {
 371  RewritePatternSet patterns(m.getContext());
 372  populateGpuRewritePatterns(patterns);
373  // Transform N-D vector.from_elements to 1-D vector.from_elements before
374  // conversion.
375  vector::populateVectorFromElementsUnrollPatterns(patterns);
376  if (failed(applyPatternsGreedily(m, std::move(patterns))))
377  return signalPassFailure();
378  }
379 
 380  LLVMTypeConverter converter(m.getContext(), options);
 381  configureGpuToNVVMTypeConverter(converter);
 382  RewritePatternSet llvmPatterns(m.getContext());
 383  LLVMConversionTarget target(getContext());
 384 
 385  // Set a higher benefit so these patterns run before generic LLVM lowering.
386  populateGpuToNVVMConversionPatterns(converter, llvmPatterns,
387  /*benefit=*/10);
388 
389  llvm::SmallDenseSet<StringRef> allowedDialectsSet(allowedDialects.begin(),
390  allowedDialects.end());
391  for (Dialect *dialect : getContext().getLoadedDialects()) {
392  // Skip math patterns as nvvm needs custom math lowering.
393  if (isa<math::MathDialect>(dialect))
394  continue;
395 
396  bool allowed = allowedDialectsSet.contains(dialect->getNamespace());
397  // Empty `allowedDialectsSet` means all dialects are allowed.
398  if (!allowedDialectsSet.empty() && !allowed)
399  continue;
400 
401  auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
402  if (!iface) {
 403  // Error out if the dialect was explicitly specified but doesn't implement
 404  // the conversion interface.
405  if (allowed) {
406  m.emitError()
407  << "dialect does not implement ConvertToLLVMPatternInterface: "
408  << dialect->getNamespace();
409  return signalPassFailure();
410  }
411  continue;
412  }
413 
414  iface->populateConvertToLLVMConversionPatterns(target, converter,
415  llvmPatterns);
416  }
417 
418  populateGpuWMMAToNVVMConversionPatterns(converter, llvmPatterns);
419  if (this->hasRedux)
 420  populateGpuSubgroupReduceOpLoweringPattern(converter, llvmPatterns);
 421  configureGpuToNVVMConversionLegality(target);
422  if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
423  signalPassFailure();
424  }
425 };
426 
427 } // namespace
428 
 429 void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
 430  target.addIllegalOp<func::FuncOp>();
431  target.addIllegalOp<cf::AssertOp>();
432  target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
433  target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
434  target.addIllegalDialect<gpu::GPUDialect>();
435  target.addIllegalOp<LLVM::CopySignOp, LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op,
436  LLVM::FAbsOp, LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp,
437  LLVM::LogOp, LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp,
438  LLVM::RoundEvenOp, LLVM::RoundOp, LLVM::SinOp,
439  LLVM::SincosOp, LLVM::SqrtOp>();
440 
441  // TODO: Remove once we support replacing non-root ops.
442  target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp>();
443 }
 444 
 445 void mlir::configureGpuToNVVMTypeConverter(LLVMTypeConverter &converter) {
446  // NVVM uses alloca in the default address space to represent private
447  // memory allocations, so drop private annotations. NVVM uses address
448  // space 3 for shared memory. NVVM uses the default address space to
449  // represent global memory.
 450  populateGpuMemorySpaceAttributeConversions(
 451  converter, [](gpu::AddressSpace space) -> unsigned {
452  switch (space) {
453  case gpu::AddressSpace::Global:
454  return static_cast<unsigned>(NVVM::NVVMMemorySpace::Global);
455  case gpu::AddressSpace::Workgroup:
456  return static_cast<unsigned>(NVVM::NVVMMemorySpace::Shared);
457  case gpu::AddressSpace::Private:
458  return 0;
459  }
460  llvm_unreachable("unknown address space enum value");
461  return static_cast<unsigned>(NVVM::NVVMMemorySpace::Generic);
462  });
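 // Illustrative effect of the mapping above (shown for exposition only): a
 // memref in #gpu.address_space<workgroup> becomes an LLVM pointer in
 // addrspace(3), while #gpu.address_space<private> maps to the default
 // address space 0.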
463  // Lowering for MMAMatrixType.
464  converter.addConversion([&](gpu::MMAMatrixType type) -> Type {
465  return convertMMAToLLVMType(type);
466  });
467 }
468 
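/// Lowers math.sincos to a single libdevice call. Assumed libdevice prototype
/// for the f32 case (the f64 variant is __nv_sincos):
///   void __nv_sincosf(float x, float *sin, float *cos);
/// The pattern allocates two stack slots in the enclosing allocation scope,
/// passes their addresses to the call, and reloads both results.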
 469 struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
 470  using ConvertOpToLLVMPattern<math::SincosOp>::ConvertOpToLLVMPattern;
471 
472  LogicalResult
473  matchAndRewrite(math::SincosOp op, OpAdaptor adaptor,
474  ConversionPatternRewriter &rewriter) const override {
475  Location loc = op.getLoc();
476  Value input = adaptor.getOperand();
477  Type inputType = input.getType();
478  auto convertedInput = maybeExt(input, rewriter);
479  auto computeType = convertedInput.getType();
480 
481  StringRef sincosFunc;
482  if (isa<Float32Type>(computeType)) {
483  const arith::FastMathFlags flag = op.getFastmath();
484  const bool useApprox =
485  mlir::arith::bitEnumContainsAny(flag, arith::FastMathFlags::afn);
486  sincosFunc = useApprox ? "__nv_fast_sincosf" : "__nv_sincosf";
487  } else if (isa<Float64Type>(computeType)) {
488  sincosFunc = "__nv_sincos";
489  } else {
490  return rewriter.notifyMatchFailure(op,
491  "unsupported operand type for sincos");
492  }
493 
494  auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
495 
496  Value sinPtr, cosPtr;
497  {
498  OpBuilder::InsertionGuard guard(rewriter);
499  auto *scope =
500  op->getParentWithTrait<mlir::OpTrait::AutomaticAllocationScope>();
501  assert(scope && "Expected op to be inside automatic allocation scope");
502  rewriter.setInsertionPointToStart(&scope->getRegion(0).front());
503  auto one = rewriter.create<LLVM::ConstantOp>(
504  loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));
505  sinPtr =
506  rewriter.create<LLVM::AllocaOp>(loc, ptrType, computeType, one, 0);
507  cosPtr =
508  rewriter.create<LLVM::AllocaOp>(loc, ptrType, computeType, one, 0);
509  }
510 
511  createSincosCall(rewriter, loc, sincosFunc, convertedInput, sinPtr, cosPtr,
512  op);
513 
514  auto sinResult = rewriter.create<LLVM::LoadOp>(loc, computeType, sinPtr);
515  auto cosResult = rewriter.create<LLVM::LoadOp>(loc, computeType, cosPtr);
516 
517  rewriter.replaceOp(op, {maybeTrunc(sinResult, inputType, rewriter),
518  maybeTrunc(cosResult, inputType, rewriter)});
519  return success();
520  }
521 
522 private:
523  Value maybeExt(Value operand, PatternRewriter &rewriter) const {
524  if (isa<Float16Type, BFloat16Type>(operand.getType()))
525  return rewriter.create<LLVM::FPExtOp>(
526  operand.getLoc(), Float32Type::get(rewriter.getContext()), operand);
527  return operand;
528  }
529 
530  Value maybeTrunc(Value operand, Type type, PatternRewriter &rewriter) const {
531  if (operand.getType() != type)
532  return rewriter.create<LLVM::FPTruncOp>(operand.getLoc(), type, operand);
533  return operand;
534  }
535 
536  void createSincosCall(ConversionPatternRewriter &rewriter, Location loc,
537  StringRef funcName, Value input, Value sinPtr,
538  Value cosPtr, Operation *op) const {
539  auto voidType = LLVM::LLVMVoidType::get(rewriter.getContext());
540  auto ptrType = sinPtr.getType();
541 
542  SmallVector<Type> operandTypes = {input.getType(), ptrType, ptrType};
543  auto funcType = LLVM::LLVMFunctionType::get(voidType, operandTypes);
544 
545  auto funcAttr = StringAttr::get(op->getContext(), funcName);
546  auto funcOp =
547  SymbolTable::lookupNearestSymbolFrom<LLVM::LLVMFuncOp>(op, funcAttr);
548 
549  if (!funcOp) {
550  auto parentFunc = op->getParentOfType<FunctionOpInterface>();
551  assert(parentFunc && "expected there to be a parent function");
552  OpBuilder b(parentFunc);
553 
554  auto globalloc = loc->findInstanceOfOrUnknown<FileLineColLoc>();
555  funcOp = LLVM::LLVMFuncOp::create(b, globalloc, funcName, funcType);
556  }
557 
558  SmallVector<Value> callOperands = {input, sinPtr, cosPtr};
559  rewriter.create<LLVM::CallOp>(loc, funcOp, callOperands);
560  }
561 };
562 
563 template <typename OpTy>
 564 static void populateOpPatterns(const LLVMTypeConverter &converter,
 565  RewritePatternSet &patterns,
566  PatternBenefit benefit, StringRef f32Func,
567  StringRef f64Func, StringRef f32ApproxFunc = "",
568  StringRef f16Func = "") {
569  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit);
570  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
571  f32ApproxFunc, f16Func,
572  /*i32Func=*/"", benefit);
573 }
574 
575 template <typename OpTy>
 576 static void populateIntOpPatterns(const LLVMTypeConverter &converter,
 577  RewritePatternSet &patterns,
578  PatternBenefit benefit, StringRef i32Func) {
579  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit);
580  patterns.add<OpToFuncCallLowering<OpTy>>(converter, "", "", "", "", i32Func,
581  benefit);
582 }
583 
584 template <typename OpTy>
 585 static void populateFloatIntOpPatterns(const LLVMTypeConverter &converter,
 586  RewritePatternSet &patterns,
587  PatternBenefit benefit,
588  StringRef f32Func, StringRef f64Func) {
589  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit);
590  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func, "", "",
591  /*i32Func=*/"", benefit);
592 }
593 
 594 void mlir::populateGpuSubgroupReduceOpLoweringPattern(
 595  const LLVMTypeConverter &converter, RewritePatternSet &patterns,
596  PatternBenefit benefit) {
597  patterns.add<GPUSubgroupReduceOpLowering>(converter, benefit);
598 }
599 
 600 void mlir::populateLibDeviceConversionPatterns(
 601  const LLVMTypeConverter &converter, RewritePatternSet &patterns,
602  PatternBenefit benefit) {
603  populateOpPatterns<arith::RemFOp>(converter, patterns, benefit, "__nv_fmodf",
604  "__nv_fmod");
605  populateOpPatterns<arith::MaxNumFOp>(converter, patterns, benefit,
606  "__nv_fmaxf", "__nv_fmax");
607  populateOpPatterns<arith::MinNumFOp>(converter, patterns, benefit,
608  "__nv_fminf", "__nv_fmin");
609 
610  populateIntOpPatterns<math::AbsIOp>(converter, patterns, benefit, "__nv_abs");
611  populateOpPatterns<math::AbsFOp>(converter, patterns, benefit, "__nv_fabsf",
612  "__nv_fabs");
613  populateOpPatterns<math::AcosOp>(converter, patterns, benefit, "__nv_acosf",
614  "__nv_acos");
615  populateOpPatterns<math::AcoshOp>(converter, patterns, benefit, "__nv_acoshf",
616  "__nv_acosh");
617  populateOpPatterns<math::AsinOp>(converter, patterns, benefit, "__nv_asinf",
618  "__nv_asin");
619  populateOpPatterns<math::AsinhOp>(converter, patterns, benefit, "__nv_asinhf",
620  "__nv_asinh");
621  populateOpPatterns<math::AtanOp>(converter, patterns, benefit, "__nv_atanf",
622  "__nv_atan");
623  populateOpPatterns<math::Atan2Op>(converter, patterns, benefit, "__nv_atan2f",
624  "__nv_atan2");
625  populateOpPatterns<math::AtanhOp>(converter, patterns, benefit, "__nv_atanhf",
626  "__nv_atanh");
627  populateOpPatterns<math::CbrtOp>(converter, patterns, benefit, "__nv_cbrtf",
628  "__nv_cbrt");
629  populateOpPatterns<math::CeilOp>(converter, patterns, benefit, "__nv_ceilf",
630  "__nv_ceil");
631  populateOpPatterns<math::CopySignOp>(converter, patterns, benefit,
632  "__nv_copysignf", "__nv_copysign");
633  populateOpPatterns<math::CosOp>(converter, patterns, benefit, "__nv_cosf",
634  "__nv_cos", "__nv_fast_cosf");
635  populateOpPatterns<math::CoshOp>(converter, patterns, benefit, "__nv_coshf",
636  "__nv_cosh");
637  populateOpPatterns<math::ErfOp>(converter, patterns, benefit, "__nv_erff",
638  "__nv_erf");
639  populateOpPatterns<math::ErfcOp>(converter, patterns, benefit, "__nv_erfcf",
640  "__nv_erfc");
641  populateOpPatterns<math::ExpOp>(converter, patterns, benefit, "__nv_expf",
642  "__nv_exp", "__nv_fast_expf");
643  populateOpPatterns<math::Exp2Op>(converter, patterns, benefit, "__nv_exp2f",
644  "__nv_exp2");
645  populateOpPatterns<math::ExpM1Op>(converter, patterns, benefit, "__nv_expm1f",
646  "__nv_expm1");
647  populateOpPatterns<math::FloorOp>(converter, patterns, benefit, "__nv_floorf",
648  "__nv_floor");
649  populateOpPatterns<math::FmaOp>(converter, patterns, benefit, "__nv_fmaf",
650  "__nv_fma");
651  // Note: libdevice uses a different name for 32-bit finite checking
652  populateOpPatterns<math::IsFiniteOp>(converter, patterns, benefit,
653  "__nv_finitef", "__nv_isfinited");
654  populateOpPatterns<math::IsInfOp>(converter, patterns, benefit, "__nv_isinff",
655  "__nv_isinfd");
656  populateOpPatterns<math::IsNaNOp>(converter, patterns, benefit, "__nv_isnanf",
657  "__nv_isnand");
658  populateOpPatterns<math::LogOp>(converter, patterns, benefit, "__nv_logf",
659  "__nv_log", "__nv_fast_logf");
660  populateOpPatterns<math::Log10Op>(converter, patterns, benefit, "__nv_log10f",
661  "__nv_log10", "__nv_fast_log10f");
662  populateOpPatterns<math::Log1pOp>(converter, patterns, benefit, "__nv_log1pf",
663  "__nv_log1p");
664  populateOpPatterns<math::Log2Op>(converter, patterns, benefit, "__nv_log2f",
665  "__nv_log2", "__nv_fast_log2f");
666  populateOpPatterns<math::PowFOp>(converter, patterns, benefit, "__nv_powf",
667  "__nv_pow", "__nv_fast_powf");
668  populateFloatIntOpPatterns<math::FPowIOp>(converter, patterns, benefit,
669  "__nv_powif", "__nv_powi");
670  populateOpPatterns<math::RoundOp>(converter, patterns, benefit, "__nv_roundf",
671  "__nv_round");
672  populateOpPatterns<math::RoundEvenOp>(converter, patterns, benefit,
673  "__nv_rintf", "__nv_rint");
674  populateOpPatterns<math::RsqrtOp>(converter, patterns, benefit, "__nv_rsqrtf",
675  "__nv_rsqrt");
676  populateOpPatterns<math::SinOp>(converter, patterns, benefit, "__nv_sinf",
677  "__nv_sin", "__nv_fast_sinf");
678  populateOpPatterns<math::SinhOp>(converter, patterns, benefit, "__nv_sinhf",
679  "__nv_sinh");
680  populateOpPatterns<math::SqrtOp>(converter, patterns, benefit, "__nv_sqrtf",
681  "__nv_sqrt");
682  populateOpPatterns<math::TanOp>(converter, patterns, benefit, "__nv_tanf",
683  "__nv_tan", "__nv_fast_tanf");
684  populateOpPatterns<math::TanhOp>(converter, patterns, benefit, "__nv_tanhf",
685  "__nv_tanh");
686 
687  // Custom pattern for sincos since it returns two values
688  patterns.add<SincosOpLowering>(converter, benefit);
689 }
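// Illustrative result of the patterns registered above (syntax shown
// informally): a scalar `math.exp %x : f32` lowers to a call to @__nv_expf,
// the f64 form calls @__nv_exp, and under fastmath<afn> the approximate
// @__nv_fast_expf is chosen where a fast variant is registered; vector
// operands are scalarized first by ScalarizeVectorOpLowering.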
690 
 691 void mlir::populateGpuToNVVMConversionPatterns(
 692  const LLVMTypeConverter &converter, RewritePatternSet &patterns,
693  PatternBenefit benefit) {
 694  using gpu::index_lowering::IndexKind;
 695  using gpu::index_lowering::IntrType;
 696 
697  // TODO: Pass benefit to generated patterns.
698  populateWithGenerated(patterns);
699 
700  patterns.add<GPUPrintfOpToVPrintfLowering, AssertOpToAssertfailLowering>(
701  converter, benefit);
702  patterns.add<
703  gpu::index_lowering::OpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp,
704  NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>>(
705  converter, IndexKind::Block, IntrType::Id, benefit);
706  patterns.add<
707  gpu::index_lowering::OpLowering<gpu::BlockDimOp, NVVM::BlockDimXOp,
708  NVVM::BlockDimYOp, NVVM::BlockDimZOp>>(
709  converter, IndexKind::Block, IntrType::Dim, benefit);
710  patterns.add<
711  gpu::index_lowering::OpLowering<gpu::ClusterIdOp, NVVM::ClusterIdXOp,
712  NVVM::ClusterIdYOp, NVVM::ClusterIdZOp>>(
713  converter, IndexKind::Other, IntrType::Id, benefit);
 714  patterns.add<gpu::index_lowering::OpLowering<
 715  gpu::ClusterDimOp, NVVM::ClusterDimXOp, NVVM::ClusterDimYOp,
716  NVVM::ClusterDimZOp>>(converter, IndexKind::Other, IntrType::Dim,
717  benefit);
 718  patterns.add<gpu::index_lowering::OpLowering<
 719  gpu::ClusterBlockIdOp, NVVM::BlockInClusterIdXOp,
720  NVVM::BlockInClusterIdYOp, NVVM::BlockInClusterIdZOp>>(
721  converter, IndexKind::Other, IntrType::Id, benefit);
 722  patterns.add<gpu::index_lowering::OpLowering<
 723  gpu::ClusterDimBlocksOp, NVVM::ClusterDimBlocksXOp,
724  NVVM::ClusterDimBlocksYOp, NVVM::ClusterDimBlocksZOp>>(
725  converter, IndexKind::Other, IntrType::Dim, benefit);
 726  patterns.add<gpu::index_lowering::OpLowering<
 727  gpu::BlockIdOp, NVVM::BlockIdXOp, NVVM::BlockIdYOp, NVVM::BlockIdZOp>>(
728  converter, IndexKind::Grid, IntrType::Id, benefit);
 729  patterns.add<gpu::index_lowering::OpLowering<
 730  gpu::GridDimOp, NVVM::GridDimXOp, NVVM::GridDimYOp, NVVM::GridDimZOp>>(
731  converter, IndexKind::Grid, IntrType::Dim, benefit);
732  patterns.add<GPULaneIdOpToNVVM, GPUShuffleOpLowering, GPUReturnOpLowering>(
733  converter, benefit);
734 
 735  patterns.add<GPUDynamicSharedMemoryOpLowering>(
 736  converter, NVVM::kSharedMemoryAlignmentBit, benefit);
737 
738  // Explicitly drop memory space when lowering private memory
739  // attributions since NVVM models it as `alloca`s in the default
740  // memory space and does not support `alloca`s with addrspace(5).
 741  patterns.add<GPUFuncOpLowering>(
 742  converter,
 743  GPUFuncOpLoweringOptions{
 744  /*allocaAddrSpace=*/0,
745  /*workgroupAddrSpace=*/
746  static_cast<unsigned>(NVVM::NVVMMemorySpace::Shared),
747  StringAttr::get(&converter.getContext(),
748  NVVM::NVVMDialect::getKernelFuncAttrName()),
749  StringAttr::get(&converter.getContext(),
750  NVVM::NVVMDialect::getMaxntidAttrName())},
751  benefit);
752 
753  populateLibDeviceConversionPatterns(converter, patterns, benefit);
754 }
755 
756 //===----------------------------------------------------------------------===//
757 // NVVMTargetAttr convert to LLVM attr interface
758 //===----------------------------------------------------------------------===//
759 
760 namespace {
761 struct NVVMTargetConvertToLLVMAttrInterface
762  : public ConvertToLLVMAttrInterface::ExternalModel<
763  NVVMTargetConvertToLLVMAttrInterface, NVVM::NVVMTargetAttr> {
764  /// Configure GPU to NVVM.
765  void populateConvertToLLVMConversionPatterns(
766  Attribute attr, ConversionTarget &target,
767  LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) const;
768 };
769 } // namespace
770 
771 void NVVMTargetConvertToLLVMAttrInterface::
772  populateConvertToLLVMConversionPatterns(Attribute attr,
773  ConversionTarget &target,
774  LLVMTypeConverter &typeConverter,
775  RewritePatternSet &patterns) const {
 776  configureGpuToNVVMConversionLegality(target);
 777  configureGpuToNVVMTypeConverter(typeConverter);
 778  populateGpuToNVVMConversionPatterns(typeConverter, patterns);
 779 }
780 
 781 void mlir::NVVM::registerConvertGpuToNVVMInterface(DialectRegistry &registry) {
 782  registry.addExtension(+[](MLIRContext *ctx, NVVMDialect *dialect) {
783  NVVMTargetAttr::attachInterface<NVVMTargetConvertToLLVMAttrInterface>(*ctx);
784  });
785 }