//===- LowerGpuOpsToROCDLOps.cpp - MLIR GPU to ROCDL lowering passes ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to generate ROCDLIR operations for higher-level
// GPU operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"

#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/Support/FormatVariadic.h"

#include "../GPUCommon/GPUOpsLowering.h"
#include "../GPUCommon/IndexIntrinsicsOpLowering.h"
#include "../GPUCommon/OpToFuncCallLowering.h"

namespace mlir {
#define GEN_PASS_DEF_CONVERTGPUOPSTOROCDLOPS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

/// Returns true if the given `gpu.func` can be safely called using the bare
/// pointer calling convention.
static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func) {
  bool canBeBare = true;
  for (Type type : func.getArgumentTypes())
    if (auto memrefTy = dyn_cast<BaseMemRefType>(type))
      canBeBare &= LLVMTypeConverter::canConvertToBarePtr(memrefTy);
  return canBeBare;
}
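// Illustrative note (not part of the original source): under the bare-pointer
// convention, a statically shaped, identity-layout argument such as
// memref<16x16xf32> lowers to a single !llvm.ptr instead of the full memref
// descriptor, whereas memref<?xf32> or a memref with a non-identity layout
// map does not qualify and would make this function return false.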

/// Computes the lane id of the current thread using the mbcnt.lo and
/// mbcnt.hi ROCDL intrinsics.
static Value getLaneId(ConversionPatternRewriter &rewriter, Location loc,
                       const unsigned indexBitwidth) {
  auto int32Type = IntegerType::get(rewriter.getContext(), 32);
  Value zero = rewriter.create<arith::ConstantIntOp>(loc, 0, 32);
  Value minus1 = rewriter.create<arith::ConstantIntOp>(loc, -1, 32);
  Value mbcntLo = rewriter.create<ROCDL::MbcntLoOp>(loc, int32Type,
                                                    ValueRange{minus1, zero});
  Value laneId = rewriter.create<ROCDL::MbcntHiOp>(loc, int32Type,
                                                   ValueRange{minus1, mbcntLo});
  return laneId;
}
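// Data layout string for amdgcn targets (kept in sync with the LLVM AMDGPU
// backend). In particular, p5:32:32 makes private (scratch) pointers 32 bits
// wide, A5 places allocas in address space 5, and G1 defaults globals to
// address space 1, matching the gpu-to-ROCDL address space mapping set up
// further below.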
static constexpr StringLiteral amdgcnDataLayout =
    "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32"
    "-p7:160:256:256:32-p8:128:128-i64:64-v16:16-v24:32-v32:32-v48:64-v96:"
    "128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-"
    "G1-ni:7:8";

namespace {
struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
  using ConvertOpToLLVMPattern<gpu::LaneIdOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::LaneIdOp op, gpu::LaneIdOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    MLIRContext *context = rewriter.getContext();
    // Convert to: %mlo = call @llvm.amdgcn.mbcnt.lo(-1, 0)
    // followed by: %lid = call @llvm.amdgcn.mbcnt.hi(-1, %mlo)

    Type intTy = IntegerType::get(context, 32);
    Value zero = rewriter.create<arith::ConstantIntOp>(loc, 0, 32);
    Value minus1 = rewriter.create<arith::ConstantIntOp>(loc, -1, 32);
    Value mbcntLo =
        rewriter.create<ROCDL::MbcntLoOp>(loc, intTy, ValueRange{minus1, zero});
    Value laneId = rewriter.create<ROCDL::MbcntHiOp>(
        loc, intTy, ValueRange{minus1, mbcntLo});
    // Truncate or extend the result depending on the index bitwidth specified
    // by the LLVMTypeConverter options.
    const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
    if (indexBitwidth > 32) {
      laneId = rewriter.create<LLVM::SExtOp>(
          loc, IntegerType::get(context, indexBitwidth), laneId);
    } else if (indexBitwidth < 32) {
      laneId = rewriter.create<LLVM::TruncOp>(
          loc, IntegerType::get(context, indexBitwidth), laneId);
    }
    rewriter.replaceOp(op, {laneId});
    return success();
  }
};
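// Illustrative example of the rewrite above (not part of the original
// source), assuming a 64-bit index type:
//
//   %0 = gpu.lane_id
//
// becomes, roughly, after full lowering of the constants:
//
//   %c0   = llvm.mlir.constant(0 : i32) : i32
//   %cm1  = llvm.mlir.constant(-1 : i32) : i32
//   %mlo  = rocdl.mbcnt.lo %cm1, %c0 : (i32, i32) -> i32
//   %lid32 = rocdl.mbcnt.hi %cm1, %mlo : (i32, i32) -> i32
//   %lid  = llvm.sext %lid32 : i32 to i64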

struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
  using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern;

  /// Lowers a shuffle to the corresponding ROCDL ops.
  ///
  /// Use the `width` argument to determine whether the source lane
  /// participates; if it does not, the lane shuffles from itself.
  ///
  /// Shuffle with DS Bpermute:
  ///   let shflMode = [xor, up, down, idx]
  ///   let width = 32 (usually the warp size), step = [1, 2, 4, 8, 16, ..., width]
  ///   1. curLaneId = mbcnt.lo + mbcnt.hi
  ///   2. widthOrZeroIfOutside = (curLaneId + width) & -width
  ///   3. dstLane = shflMode(curLaneId, step)
  ///   4. isActiveSrcLane = dstLane < widthOrZeroIfOutside
  ///   5. dstLane = isActiveSrcLane ? dstLane : curLaneId
  ///   6. dwordAlignedDstLane = dstLane * 4, i.e. dstLane << 2
  ///   7. bpermute(dwordAlignedDstLane, shfl_value)
  LogicalResult
  matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    // TODO: Add support for non 32-bit shuffle values.
    if (adaptor.getValue().getType().getIntOrFloatBitWidth() != 32)
      return failure();
    const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
    Value srcLaneId = getLaneId(rewriter, loc, indexBitwidth);

    auto int32Type = IntegerType::get(rewriter.getContext(), 32);
    Value width = adaptor.getWidth();
    Value zero = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 0);
    Value negwidth = rewriter.create<LLVM::SubOp>(loc, int32Type, zero, width);
    Value add = rewriter.create<LLVM::AddOp>(loc, int32Type, srcLaneId, width);
    Value widthOrZeroIfOutside =
        rewriter.create<LLVM::AndOp>(loc, int32Type, add, negwidth);
    Value dstLane;
    // TODO: Add support for gpu::ShuffleMode::UP and gpu::ShuffleMode::DOWN.
    // TODO: Use ds_swizzle for XOR when step/offsets are constants for better
    // perf.
    switch (op.getMode()) {
    case gpu::ShuffleMode::XOR:
      dstLane = rewriter.create<LLVM::XOrOp>(loc, int32Type, srcLaneId,
                                             adaptor.getOffset());
      break;
    case gpu::ShuffleMode::IDX:
      dstLane = adaptor.getOffset();
      break;
    default:
      return failure();
    }
    Value isActiveSrcLane = rewriter.create<LLVM::ICmpOp>(
        loc, LLVM::ICmpPredicate::slt, dstLane, widthOrZeroIfOutside);
    Value selectDstLane = rewriter.create<LLVM::SelectOp>(loc, isActiveSrcLane,
                                                          dstLane, srcLaneId);
    Value two = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 2);
    Value dwordAlignedDstLane =
        rewriter.create<LLVM::ShlOp>(loc, int32Type, selectDstLane, two);
    Value initShflValue = adaptor.getValue();
    if (adaptor.getValue().getType().isF32()) {
      initShflValue =
          rewriter.create<LLVM::BitcastOp>(loc, int32Type, initShflValue);
    }
    Value shflValue = rewriter.create<ROCDL::DsBpermuteOp>(
        loc, int32Type, dwordAlignedDstLane, initShflValue);
    if (adaptor.getValue().getType().isF32()) {
      shflValue = rewriter.create<LLVM::BitcastOp>(
          loc, adaptor.getValue().getType(), shflValue);
    }
    rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
    return success();
  }
};
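// Illustrative example of the rewrite above (not part of the original
// source):
//
//   %shfl, %pred = gpu.shuffle xor %val, %offset, %width : f32
//
// becomes, roughly, the lane-id computation followed by:
//
//   %dst   = llvm.xor %laneid, %offset
//   %valid = llvm.icmp "slt" %dst, %widthOrZeroIfOutside
//   %sel   = llvm.select %valid, %dst, %laneid
//   %addr  = llvm.shl %sel, %c2        // dword-align the lane index
//   %res   = rocdl.ds_bpermute %addr, %val_i32
//
// with bitcasts around the ds_bpermute when the shuffled value is f32.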

/// Import the GPU Ops to ROCDL Patterns.
#include "GPUToROCDL.cpp.inc"

// A pass that replaces all occurrences of GPU device operations with their
// corresponding ROCDL equivalents.
//
// This pass only handles device code and is not meant to be run on GPU host
// code.
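//
// An illustrative invocation (assuming the standard mlir-opt registration of
// this pass as `convert-gpu-to-rocdl`; the option names below come from the
// pass definition):
//
//   mlir-opt --convert-gpu-to-rocdl='chipset=gfx90a index-bitwidth=32' in.mlir
//
// which runs over the `gpu.module` bodies produced by kernel outlining.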
struct LowerGpuOpsToROCDLOpsPass
    : public impl::ConvertGpuOpsToROCDLOpsBase<LowerGpuOpsToROCDLOpsPass> {
  LowerGpuOpsToROCDLOpsPass() = default;
  LowerGpuOpsToROCDLOpsPass(const std::string &chipset, unsigned indexBitwidth,
                            bool useBarePtrCallConv,
                            gpu::amd::Runtime runtime) {
    if (this->chipset.getNumOccurrences() == 0)
      this->chipset = chipset;
    if (this->indexBitwidth.getNumOccurrences() == 0)
      this->indexBitwidth = indexBitwidth;
    if (this->useBarePtrCallConv.getNumOccurrences() == 0)
      this->useBarePtrCallConv = useBarePtrCallConv;
    if (this->runtime.getNumOccurrences() == 0)
      this->runtime = runtime;
  }

  void runOnOperation() override {
    gpu::GPUModuleOp m = getOperation();
    MLIRContext *ctx = m.getContext();

    auto llvmDataLayout = m->getAttrOfType<StringAttr>(
        LLVM::LLVMDialect::getDataLayoutAttrName());
    if (!llvmDataLayout) {
      llvmDataLayout = StringAttr::get(ctx, amdgcnDataLayout);
      m->setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(), llvmDataLayout);
    }
    // Request C wrapper emission.
    for (auto func : m.getOps<func::FuncOp>()) {
      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                    UnitAttr::get(ctx));
    }

    FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(chipset);
    if (failed(maybeChipset)) {
      emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
      return signalPassFailure();
    }

    // Customize the bitwidth used for the device side index computations.
    LowerToLLVMOptions options(
        ctx, DataLayout(cast<DataLayoutOpInterface>(m.getOperation())));
    options.dataLayout = llvm::DataLayout(llvmDataLayout.getValue());
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);

    if (useBarePtrCallConv) {
      options.useBarePtrCallConv = true;
      WalkResult canUseBarePointers =
          m.walk([](gpu::GPUFuncOp func) -> WalkResult {
            if (canBeCalledWithBarePointers(func))
              return WalkResult::advance();
            return WalkResult::interrupt();
          });
      if (canUseBarePointers.wasInterrupted()) {
        emitError(UnknownLoc::get(ctx),
                  "bare pointer calling convention requires all memrefs to "
                  "have static shape and use the identity map");
        return signalPassFailure();
      }
    }

    // Apply in-dialect lowering first. It replaces ops that need to be
    // lowered in several steps, which a single conversion pass cannot do.
    {
      RewritePatternSet patterns(ctx);
      populateGpuRewritePatterns(patterns);
      arith::populateExpandBFloat16Patterns(patterns);
      (void)applyPatternsAndFoldGreedily(m, std::move(patterns));
    }

    LLVMTypeConverter converter(ctx, options);
    populateGpuMemorySpaceAttributeConversions(
        converter, [](gpu::AddressSpace space) {
          switch (space) {
          case gpu::AddressSpace::Global:
            return 1;
          case gpu::AddressSpace::Workgroup:
            return 3;
          case gpu::AddressSpace::Private:
            return 5;
          }
          llvm_unreachable("unknown address space enum value");
          return 0;
        });

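    // Illustrative note (not part of the original source): with this mapping,
    // a memref in workgroup memory, e.g.
    //   memref<16xf32, #gpu.address_space<workgroup>>
    // is converted using LLVM address space 3, so its pointers become
    // !llvm.ptr<3>.
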
    RewritePatternSet llvmPatterns(ctx);

    arith::populateArithToLLVMConversionPatterns(converter, llvmPatterns);
    populateAMDGPUToROCDLConversionPatterns(converter, llvmPatterns,
                                            *maybeChipset);
    populateVectorToLLVMConversionPatterns(converter, llvmPatterns);
    cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns);
    populateFuncToLLVMConversionPatterns(converter, llvmPatterns);
    populateFinalizeMemRefToLLVMConversionPatterns(converter, llvmPatterns);
    populateGpuToROCDLConversionPatterns(converter, llvmPatterns, runtime);
    LLVMConversionTarget target(getContext());
    configureGpuToROCDLConversionLegality(target);
    if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
      signalPassFailure();
    auto *rocdlDialect = getContext().getLoadedDialect<ROCDL::ROCDLDialect>();
    auto reqdWorkGroupSizeAttrHelper =
        rocdlDialect->getReqdWorkGroupSizeAttrHelper();
    auto flatWorkGroupSizeAttrHelper =
        rocdlDialect->getFlatWorkGroupSizeAttrHelper();
    // Manually rewrite known block size attributes so the LLVMIR translation
    // infrastructure can pick them up.
    m.walk([&](LLVM::LLVMFuncOp op) {
      if (auto blockSizes = dyn_cast_or_null<DenseI32ArrayAttr>(
              op->removeAttr(gpu::GPUFuncOp::getKnownBlockSizeAttrName()))) {
        reqdWorkGroupSizeAttrHelper.setAttr(op, blockSizes);
        // Also set up the rocdl.flat_work_group_size attribute to prevent
        // conflicting metadata.
        uint32_t flatSize = 1;
        for (uint32_t size : blockSizes.asArrayRef()) {
          flatSize *= size;
        }
        StringAttr flatSizeAttr =
            StringAttr::get(ctx, Twine(flatSize) + "," + Twine(flatSize));
        flatWorkGroupSizeAttrHelper.setAttr(op, flatSizeAttr);
      }
    });
  }
};

} // namespace

void mlir::configureGpuToROCDLConversionLegality(ConversionTarget &target) {
  target.addIllegalOp<func::FuncOp>();
  target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
  target.addLegalDialect<ROCDL::ROCDLDialect>();
  target.addIllegalDialect<gpu::GPUDialect>();
  target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
                      LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
                      LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
                      LLVM::SqrtOp>();

  // TODO: Remove once we support replacing non-root ops.
  target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp, gpu::ModuleEndOp>();
}

template <typename OpTy>
static void populateOpPatterns(LLVMTypeConverter &converter,
                               RewritePatternSet &patterns, StringRef f32Func,
                               StringRef f64Func) {
  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func);
}
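// Illustrative example (not part of the original source): once
// populateOpPatterns registers OpToFuncCallLowering for math.sqrt, device
// code such as
//
//   %0 = math.sqrt %x : f32
//
// lowers, roughly, to a call into the ROCm device library:
//
//   %0 = llvm.call @__ocml_sqrt_f32(%x) : (f32) -> f32
//
// ScalarizeVectorOpLowering first unrolls vector operands into scalar ops so
// that the f32/f64 wrappers apply element by element.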

void mlir::populateGpuToROCDLConversionPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns,
    mlir::gpu::amd::Runtime runtime) {
  using mlir::gpu::amd::Runtime;

  populateWithGenerated(patterns);
  patterns
      .add<GPUIndexIntrinsicOpLowering<gpu::ThreadIdOp, ROCDL::ThreadIdXOp,
                                       ROCDL::ThreadIdYOp, ROCDL::ThreadIdZOp>>(
          converter, gpu::GPUFuncOp::getKnownBlockSizeAttrName());
  patterns.add<GPUIndexIntrinsicOpLowering<
      gpu::BlockIdOp, ROCDL::BlockIdXOp, ROCDL::BlockIdYOp, ROCDL::BlockIdZOp>>(
      converter, gpu::GPUFuncOp::getKnownGridSizeAttrName());
  patterns
      .add<GPUIndexIntrinsicOpLowering<gpu::BlockDimOp, ROCDL::BlockDimXOp,
                                       ROCDL::BlockDimYOp, ROCDL::BlockDimZOp>,
           GPUIndexIntrinsicOpLowering<gpu::GridDimOp, ROCDL::GridDimXOp,
                                       ROCDL::GridDimYOp, ROCDL::GridDimZOp>,
           GPUReturnOpLowering>(converter);
  patterns.add<GPUFuncOpLowering>(
      converter,
      /*allocaAddrSpace=*/ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace,
      /*workgroupAddrSpace=*/ROCDL::ROCDLDialect::kSharedMemoryAddressSpace,
      ROCDL::ROCDLDialect::KernelAttrHelper(&converter.getContext()).getName());
  if (Runtime::HIP == runtime) {
    patterns.add<GPUPrintfOpToHIPLowering>(converter);
  } else if (Runtime::OpenCL == runtime) {
    // Use address space = 4 to match the OpenCL definition of printf().
    patterns.add<GPUPrintfOpToLLVMCallLowering>(converter, /*addressSpace=*/4);
  }
  // TODO: Add alignment for workgroup memory.
  patterns.add<GPUDynamicSharedMemoryOpLowering>(converter);

  patterns.add<GPUShuffleOpLowering, GPULaneIdOpToROCDL>(converter);

  populateOpPatterns<math::AbsFOp>(converter, patterns, "__ocml_fabs_f32",
                                   "__ocml_fabs_f64");
  populateOpPatterns<math::AtanOp>(converter, patterns, "__ocml_atan_f32",
                                   "__ocml_atan_f64");
  populateOpPatterns<math::Atan2Op>(converter, patterns, "__ocml_atan2_f32",
                                    "__ocml_atan2_f64");
  populateOpPatterns<math::CbrtOp>(converter, patterns, "__ocml_cbrt_f32",
                                   "__ocml_cbrt_f64");
  populateOpPatterns<math::CeilOp>(converter, patterns, "__ocml_ceil_f32",
                                   "__ocml_ceil_f64");
  populateOpPatterns<math::CosOp>(converter, patterns, "__ocml_cos_f32",
                                  "__ocml_cos_f64");
  populateOpPatterns<math::ExpOp>(converter, patterns, "__ocml_exp_f32",
                                  "__ocml_exp_f64");
  populateOpPatterns<math::Exp2Op>(converter, patterns, "__ocml_exp2_f32",
                                   "__ocml_exp2_f64");
  populateOpPatterns<math::ExpM1Op>(converter, patterns, "__ocml_expm1_f32",
                                    "__ocml_expm1_f64");
  populateOpPatterns<math::FloorOp>(converter, patterns, "__ocml_floor_f32",
                                    "__ocml_floor_f64");
  populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32",
                                    "__ocml_fmod_f64");
  populateOpPatterns<math::LogOp>(converter, patterns, "__ocml_log_f32",
                                  "__ocml_log_f64");
  populateOpPatterns<math::Log10Op>(converter, patterns, "__ocml_log10_f32",
                                    "__ocml_log10_f64");
  populateOpPatterns<math::Log1pOp>(converter, patterns, "__ocml_log1p_f32",
                                    "__ocml_log1p_f64");
  populateOpPatterns<math::Log2Op>(converter, patterns, "__ocml_log2_f32",
                                   "__ocml_log2_f64");
  populateOpPatterns<math::PowFOp>(converter, patterns, "__ocml_pow_f32",
                                   "__ocml_pow_f64");
  populateOpPatterns<math::RsqrtOp>(converter, patterns, "__ocml_rsqrt_f32",
                                    "__ocml_rsqrt_f64");
  populateOpPatterns<math::SinOp>(converter, patterns, "__ocml_sin_f32",
                                  "__ocml_sin_f64");
  populateOpPatterns<math::SqrtOp>(converter, patterns, "__ocml_sqrt_f32",
                                   "__ocml_sqrt_f64");
  populateOpPatterns<math::TanhOp>(converter, patterns, "__ocml_tanh_f32",
                                   "__ocml_tanh_f64");
  populateOpPatterns<math::TanOp>(converter, patterns, "__ocml_tan_f32",
                                  "__ocml_tan_f64");
  populateOpPatterns<math::ErfOp>(converter, patterns, "__ocml_erf_f32",
                                  "__ocml_erf_f64");
}

std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
mlir::createLowerGpuOpsToROCDLOpsPass(const std::string &chipset,
                                      unsigned indexBitwidth,
                                      bool useBarePtrCallConv,
                                      gpu::amd::Runtime runtime) {
  return std::make_unique<LowerGpuOpsToROCDLOpsPass>(
      chipset, indexBitwidth, useBarePtrCallConv, runtime);
}
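
// Illustrative usage sketch (not part of the original source): since this is
// an OperationPass<gpu::GPUModuleOp>, it is typically nested under the GPU
// modules when building a pipeline, e.g. given an MLIRContext *ctx:
//
//   PassManager pm(ctx);
//   pm.addNestedPass<gpu::GPUModuleOp>(
//       createLowerGpuOpsToROCDLOpsPass("gfx90a", /*indexBitwidth=*/32,
//                                       /*useBarePtrCallConv=*/false,
//                                       gpu::amd::Runtime::HIP));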