MLIR  21.0.0git
LowerGpuOpsToROCDLOps.cpp
Go to the documentation of this file.
1 //===- LowerGpuOpsToROCDLOps.cpp - MLIR GPU to ROCDL lowering passes ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to generate ROCDLIR operations for higher-level
10 // GPU operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
16 #include "mlir/Pass/Pass.h"
17 #include "mlir/Pass/PassManager.h"
18 
39 #include "mlir/Pass/Pass.h"
42 
43 #include "../GPUCommon/GPUOpsLowering.h"
44 #include "../GPUCommon/IndexIntrinsicsOpLowering.h"
45 
46 namespace mlir {
47 #define GEN_PASS_DEF_CONVERTGPUOPSTOROCDLOPS
48 #include "mlir/Conversion/Passes.h.inc"
49 } // namespace mlir
50 
51 using namespace mlir;
52 
53 // Truncate or extend the result depending on the index bitwidth specified
54 // by the LLVMTypeConverter options.
56  Location loc, Value value,
57  const LLVMTypeConverter &converter) {
58  int64_t intWidth = cast<IntegerType>(value.getType()).getWidth();
59  int64_t indexBitwidth = converter.getIndexTypeBitwidth();
60  auto indexBitwidthType =
61  IntegerType::get(rewriter.getContext(), converter.getIndexTypeBitwidth());
62  // TODO: use <=> in C++20.
63  if (indexBitwidth > intWidth) {
64  return rewriter.create<LLVM::SExtOp>(loc, indexBitwidthType, value);
65  }
66  if (indexBitwidth < intWidth) {
67  return rewriter.create<LLVM::TruncOp>(loc, indexBitwidthType, value);
68  }
69  return value;
70 }
71 
72 /// Returns true if the given `gpu.func` can be safely called using the bare
73 /// pointer calling convention.
74 static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func) {
75  bool canBeBare = true;
76  for (Type type : func.getArgumentTypes())
77  if (auto memrefTy = dyn_cast<BaseMemRefType>(type))
78  canBeBare &= LLVMTypeConverter::canConvertToBarePtr(memrefTy);
79  return canBeBare;
80 }
81 
83  const unsigned indexBitwidth) {
84  auto int32Type = IntegerType::get(rewriter.getContext(), 32);
85  Value zero = rewriter.create<arith::ConstantIntOp>(loc, 0, 32);
86  Value minus1 = rewriter.create<arith::ConstantIntOp>(loc, -1, 32);
87  Value mbcntLo = rewriter.create<ROCDL::MbcntLoOp>(loc, int32Type,
88  ValueRange{minus1, zero});
89  Value laneId = rewriter.create<ROCDL::MbcntHiOp>(loc, int32Type,
90  ValueRange{minus1, mbcntLo});
91  return laneId;
92 }
// Default LLVM data layout string for amdgcn modules. Installed on the
// gpu.module as the llvm.data_layout attribute when none is present
// (see LowerGpuOpsToROCDLOpsPass::runOnOperation). Presumably mirrors the
// AMDGPU backend's data layout — confirm against LLVM's AMDGPUTargetMachine.
static constexpr StringLiteral amdgcnDataLayout =
    "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32"
    "-p7:160:256:256:32-p8:128:128:128:48-p9:192:256:256:32-i64:64-v16:16-v24:"
    "32-v32:"
    "32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:"
    "64-S32-A5-G1-ni:7:8:9";
99 
100 namespace {
101 struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
103 
104  LogicalResult
105  matchAndRewrite(gpu::LaneIdOp op, gpu::LaneIdOp::Adaptor adaptor,
106  ConversionPatternRewriter &rewriter) const override {
107  auto loc = op->getLoc();
108  MLIRContext *context = rewriter.getContext();
109  // convert to: %mlo = call @llvm.amdgcn.mbcnt.lo(-1, 0)
110  // followed by: %lid = call @llvm.amdgcn.mbcnt.hi(-1, %mlo)
111 
112  Type intTy = IntegerType::get(context, 32);
113  Value zero = rewriter.create<arith::ConstantIntOp>(loc, 0, 32);
114  Value minus1 = rewriter.create<arith::ConstantIntOp>(loc, -1, 32);
115  Value mbcntLo =
116  rewriter.create<ROCDL::MbcntLoOp>(loc, intTy, ValueRange{minus1, zero});
117  Value laneId = rewriter.create<ROCDL::MbcntHiOp>(
118  loc, intTy, ValueRange{minus1, mbcntLo});
119  // Truncate or extend the result depending on the index bitwidth specified
120  // by the LLVMTypeConverter options.
121  const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
122  if (indexBitwidth > 32) {
123  laneId = rewriter.create<LLVM::SExtOp>(
124  loc, IntegerType::get(context, indexBitwidth), laneId);
125  } else if (indexBitwidth < 32) {
126  laneId = rewriter.create<LLVM::TruncOp>(
127  loc, IntegerType::get(context, indexBitwidth), laneId);
128  }
129  rewriter.replaceOp(op, {laneId});
130  return success();
131  }
132 };
133 
134 struct GPUSubgroupSizeOpToROCDL : ConvertOpToLLVMPattern<gpu::SubgroupSizeOp> {
136 
137  GPUSubgroupSizeOpToROCDL(const LLVMTypeConverter &converter,
138  amdgpu::Chipset chipset)
140  chipset(chipset) {}
141 
142  LogicalResult
143  matchAndRewrite(gpu::SubgroupSizeOp op, gpu::SubgroupSizeOp::Adaptor adaptor,
144  ConversionPatternRewriter &rewriter) const override {
145  LLVM::ConstantRangeAttr bounds = nullptr;
146  bool isBeforeGfx10 = chipset.majorVersion < 10;
147  if (auto upperBoundAttr = op.getUpperBoundAttr()) {
148  bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
149  /*bitWidth=*/32, /*lower=*/isBeforeGfx10 ? 64 : 32,
150  /*upper=*/op.getUpperBoundAttr().getInt() + 1);
151  }
152  Value wavefrontOp = rewriter.create<ROCDL::WavefrontSizeOp>(
153  op.getLoc(), rewriter.getI32Type(), bounds);
154  wavefrontOp = truncOrExtToLLVMType(rewriter, op.getLoc(), wavefrontOp,
155  *getTypeConverter());
156  rewriter.replaceOp(op, {wavefrontOp});
157  return success();
158  }
159 
160  const amdgpu::Chipset chipset;
161 };
162 
163 struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
165 
166  /// Lowers a shuffle to the corresponding ROCDL ops.
167  ///
168  /// Use the `width` argument to see if src lane is participating.
169  /// If not the dstLane would be itself.
170  ///
171  /// Shuffle with DS Bpermute:
172  /// let shflMode = [xor, up, down, idx]
173  /// let width = 32(usually warpsize), step = [1, 2, 4, 8, 16, ... , width].
174  /// 1. curLaneId = using mbcnt.lo + mbcnt.hi
175  /// 2. widthOrZeroIfOutside = (curLaneId + width) & -width
176  /// 3. dstLane = shflMode(curLaneId, step)
177  /// 4. isActiveSrcLane = dstLane < isActiveSrcLane
178  /// 5. dstLane = isActiveSrcLane ? dstLane : curLaneId
179  /// 6. dwordAlignedDstLane = dstLane * 4 or dstLane << 2.
180  /// 7. bpermute(dwordAlignedDstLane, shfl_value).
181  ///
182  LogicalResult
183  matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
184  ConversionPatternRewriter &rewriter) const override {
185  Location loc = op->getLoc();
186  Value initShflValue = adaptor.getValue();
187 
188  const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
189  Value srcLaneId = getLaneId(rewriter, loc, indexBitwidth);
190 
191  auto int32Type = IntegerType::get(rewriter.getContext(), 32);
192  Value width = adaptor.getWidth();
193  Value zero = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 0);
194  Value negwidth = rewriter.create<LLVM::SubOp>(loc, int32Type, zero, width);
195  Value add = rewriter.create<LLVM::AddOp>(loc, int32Type, srcLaneId, width);
196  Value widthOrZeroIfOutside =
197  rewriter.create<LLVM::AndOp>(loc, int32Type, add, negwidth);
198  Value dstLane;
199 
200  switch (op.getMode()) {
201  case gpu::ShuffleMode::UP:
202  dstLane = rewriter.create<LLVM::SubOp>(loc, int32Type, srcLaneId,
203  adaptor.getOffset());
204  break;
205  case gpu::ShuffleMode::DOWN:
206  dstLane = rewriter.create<LLVM::AddOp>(loc, int32Type, srcLaneId,
207  adaptor.getOffset());
208  break;
209  case gpu::ShuffleMode::XOR:
210  dstLane = rewriter.create<LLVM::XOrOp>(loc, int32Type, srcLaneId,
211  adaptor.getOffset());
212  break;
213  case gpu::ShuffleMode::IDX:
214  dstLane = adaptor.getOffset();
215  break;
216  }
217  Value isActiveSrcLane = rewriter.create<LLVM::ICmpOp>(
218  loc, LLVM::ICmpPredicate::slt, dstLane, widthOrZeroIfOutside);
219  Value selectDstLane = rewriter.create<LLVM::SelectOp>(loc, isActiveSrcLane,
220  dstLane, srcLaneId);
221  Value two = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 2);
222  Value dwordAlignedDstLane =
223  rewriter.create<LLVM::ShlOp>(loc, int32Type, selectDstLane, two);
224 
225  SmallVector<Value> decomposed =
226  LLVM::decomposeValue(rewriter, loc, initShflValue, int32Type);
227  SmallVector<Value> swizzled;
228  for (Value v : decomposed) {
229  Value res = rewriter.create<ROCDL::DsBpermuteOp>(loc, int32Type,
230  dwordAlignedDstLane, v);
231  swizzled.emplace_back(res);
232  }
233  Value shflValue =
234  LLVM::composeValue(rewriter, loc, swizzled, initShflValue.getType());
235  rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
236  return success();
237  }
238 };
239 
240 /// Import the GPU Ops to ROCDL Patterns.
241 #include "GPUToROCDL.cpp.inc"
242 
243 // A pass that replaces all occurrences of GPU device operations with their
244 // corresponding ROCDL equivalent.
245 //
246 // This pass only handles device code and is not meant to be run on GPU host
247 // code.
248 struct LowerGpuOpsToROCDLOpsPass final
249  : public impl::ConvertGpuOpsToROCDLOpsBase<LowerGpuOpsToROCDLOpsPass> {
250  LowerGpuOpsToROCDLOpsPass() = default;
251  LowerGpuOpsToROCDLOpsPass(const std::string &chipset, unsigned indexBitwidth,
252  bool useBarePtrCallConv,
253  gpu::amd::Runtime runtime) {
254  if (this->chipset.getNumOccurrences() == 0)
255  this->chipset = chipset;
256  if (this->indexBitwidth.getNumOccurrences() == 0)
257  this->indexBitwidth = indexBitwidth;
258  if (this->useBarePtrCallConv.getNumOccurrences() == 0)
259  this->useBarePtrCallConv = useBarePtrCallConv;
260  if (this->runtime.getNumOccurrences() == 0)
261  this->runtime = runtime;
262  }
263 
264  void getDependentDialects(DialectRegistry &registry) const override {
265  Base::getDependentDialects(registry);
267  }
268 
269  void runOnOperation() override {
270  gpu::GPUModuleOp m = getOperation();
271  MLIRContext *ctx = m.getContext();
272 
273  auto llvmDataLayout = m->getAttrOfType<StringAttr>(
274  LLVM::LLVMDialect::getDataLayoutAttrName());
275  if (!llvmDataLayout) {
276  llvmDataLayout = StringAttr::get(ctx, amdgcnDataLayout);
277  m->setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(), llvmDataLayout);
278  }
279  // Request C wrapper emission.
280  for (auto func : m.getOps<func::FuncOp>()) {
281  func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
282  UnitAttr::get(ctx));
283  }
284 
285  FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(chipset);
286  if (failed(maybeChipset)) {
287  emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
288  return signalPassFailure();
289  }
290 
291  /// Customize the bitwidth used for the device side index computations.
293  ctx, DataLayout(cast<DataLayoutOpInterface>(m.getOperation())));
294  options.dataLayout = llvm::DataLayout(llvmDataLayout.getValue());
295  if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
296  options.overrideIndexBitwidth(indexBitwidth);
297 
298  if (useBarePtrCallConv) {
299  options.useBarePtrCallConv = true;
300  WalkResult canUseBarePointers =
301  m.walk([](gpu::GPUFuncOp func) -> WalkResult {
302  if (canBeCalledWithBarePointers(func))
303  return WalkResult::advance();
304  return WalkResult::interrupt();
305  });
306  if (canUseBarePointers.wasInterrupted()) {
308  "bare pointer calling convention requires all memrefs to "
309  "have static shape and use the identity map");
310  return signalPassFailure();
311  }
312  }
313 
314  // Apply in-dialect lowering. In-dialect lowering will replace
315  // ops which need to be lowered further, which is not supported by a
316  // single conversion pass.
317  {
321  (void)applyPatternsGreedily(m, std::move(patterns));
322  }
323 
324  LLVMTypeConverter converter(ctx, options);
326  converter, [](gpu::AddressSpace space) {
327  switch (space) {
328  case gpu::AddressSpace::Global:
329  return 1;
330  case gpu::AddressSpace::Workgroup:
331  return 3;
332  case gpu::AddressSpace::Private:
333  return 5;
334  }
335  llvm_unreachable("unknown address space enum value");
336  return 0;
337  });
338 
339  RewritePatternSet llvmPatterns(ctx);
341 
342  llvm::SmallDenseSet<StringRef> allowedDialectsSet(allowedDialects.begin(),
343  allowedDialects.end());
344  for (Dialect *dialect : ctx->getLoadedDialects()) {
345  bool allowed = allowedDialectsSet.contains(dialect->getNamespace());
346  // Empty `allowedDialectsSet` means all dialects are allowed.
347  if (!allowedDialectsSet.empty() && !allowed)
348  continue;
349 
350  auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
351  if (!iface) {
352  // Error out if dialect was explicily specified but doesn't implement
353  // conversion interface.
354  if (allowed) {
355  m.emitError()
356  << "dialect does not implement ConvertToLLVMPatternInterface: "
357  << dialect->getNamespace();
358  return signalPassFailure();
359  }
360  continue;
361  }
362 
363  iface->populateConvertToLLVMConversionPatterns(target, converter,
364  llvmPatterns);
365  }
366 
367  populateAMDGPUToROCDLConversionPatterns(converter, llvmPatterns,
368  *maybeChipset);
369  populateGpuToROCDLConversionPatterns(converter, llvmPatterns, runtime,
370  *maybeChipset);
372  if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
373  signalPassFailure();
374  auto *rocdlDialect = getContext().getLoadedDialect<ROCDL::ROCDLDialect>();
375  auto reqdWorkGroupSizeAttrHelper =
376  rocdlDialect->getReqdWorkGroupSizeAttrHelper();
377  auto flatWorkGroupSizeAttrHelper =
378  rocdlDialect->getFlatWorkGroupSizeAttrHelper();
379  // Manually rewrite known block size attributes so the LLVMIR translation
380  // infrastructure can pick them up.
381  m.walk([&](LLVM::LLVMFuncOp op) {
382  if (reqdWorkGroupSizeAttrHelper.isAttrPresent(op)) {
383  auto blockSizes = reqdWorkGroupSizeAttrHelper.getAttr(op);
384  // Also set up the rocdl.flat_work_group_size attribute to prevent
385  // conflicting metadata.
386  uint32_t flatSize = 1;
387  for (uint32_t size : blockSizes.asArrayRef()) {
388  flatSize *= size;
389  }
390  StringAttr flatSizeAttr =
391  StringAttr::get(ctx, Twine(flatSize) + "," + Twine(flatSize));
392  flatWorkGroupSizeAttrHelper.setAttr(op, flatSizeAttr);
393  }
394  });
395  }
396 };
397 
398 } // namespace
399 
401  target.addIllegalOp<func::FuncOp>();
402  target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
403  target.addLegalDialect<ROCDL::ROCDLDialect>();
404  target.addIllegalDialect<gpu::GPUDialect>();
405  target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FCeilOp,
406  LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp, LLVM::Log10Op,
407  LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp>();
408  // These ops are legal for f32 type.
409  target.addDynamicallyLegalOp<LLVM::ExpOp, LLVM::LogOp>([](Operation *op) {
410  return any_of(op->getOperandTypes(), llvm::IsaPred<Float32Type>);
411  });
412  // TODO: Remove once we support replacing non-root ops.
413  target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp>();
414 }
415 
417  const LLVMTypeConverter &converter, RewritePatternSet &patterns,
418  mlir::gpu::amd::Runtime runtime, amdgpu::Chipset chipset) {
422  auto *rocdlDialect =
423  converter.getContext().getLoadedDialect<ROCDL::ROCDLDialect>();
424  populateWithGenerated(patterns);
425  patterns.add<
426  gpu::index_lowering::OpLowering<gpu::ThreadIdOp, ROCDL::ThreadIdXOp,
427  ROCDL::ThreadIdYOp, ROCDL::ThreadIdZOp>>(
428  converter, IndexKind::Block, IntrType::Id);
430  gpu::BlockIdOp, ROCDL::BlockIdXOp, ROCDL::BlockIdYOp, ROCDL::BlockIdZOp>>(
431  converter, IndexKind::Grid, IntrType::Id);
432  patterns.add<
433  gpu::index_lowering::OpLowering<gpu::BlockDimOp, ROCDL::BlockDimXOp,
434  ROCDL::BlockDimYOp, ROCDL::BlockDimZOp>>(
435  converter, IndexKind::Block, IntrType::Dim);
437  gpu::GridDimOp, ROCDL::GridDimXOp, ROCDL::GridDimYOp, ROCDL::GridDimZOp>>(
438  converter, IndexKind::Grid, IntrType::Dim);
439  patterns.add<GPUReturnOpLowering>(converter);
441  converter,
443  /*allocaAddrSpace=*/ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace,
444  /*workgroupAddrSpace=*/ROCDL::ROCDLDialect::kSharedMemoryAddressSpace,
445  rocdlDialect->getKernelAttrHelper().getName(),
446  rocdlDialect->getReqdWorkGroupSizeAttrHelper().getName()});
447  if (Runtime::HIP == runtime) {
448  patterns.add<GPUPrintfOpToHIPLowering>(converter);
449  } else if (Runtime::OpenCL == runtime) {
450  // Use address space = 4 to match the OpenCL definition of printf()
451  patterns.add<GPUPrintfOpToLLVMCallLowering>(converter, /*addressSpace=*/4);
452  }
453  // TODO: Add alignment for workgroup memory
455 
456  patterns.add<GPUShuffleOpLowering, GPULaneIdOpToROCDL>(converter);
457  patterns.add<GPUSubgroupSizeOpToROCDL>(converter, chipset);
458 
460 }
461 
462 std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
463 mlir::createLowerGpuOpsToROCDLOpsPass(const std::string &chipset,
464  unsigned indexBitwidth,
465  bool useBarePtrCallConv,
466  gpu::amd::Runtime runtime) {
467  return std::make_unique<LowerGpuOpsToROCDLOpsPass>(
468  chipset, indexBitwidth, useBarePtrCallConv, runtime);
469 }
static MLIRContext * getContext(OpFoldResult val)
static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func)
Returns true if the given gpu.func can be safely called using the bare pointer calling convention.
static constexpr StringLiteral amdgcnDataLayout
static Value getLaneId(ConversionPatternRewriter &rewriter, Location loc, const unsigned indexBitwidth)
static Value truncOrExtToLLVMType(ConversionPatternRewriter &rewriter, Location loc, Value value, const LLVMTypeConverter &converter)
static llvm::ManagedStatic< PassManagerOptions > options
IntegerType getI32Type()
Definition: Builders.cpp:62
MLIRContext * getContext() const
Definition: Builders.h:55
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition: Builders.h:96
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class describes a specific conversion target.
void addLegalOp(OperationName op)
Register the given operations as legal.
void addLegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as legal.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the on...
void addIllegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as illegal, i.e.
void addIllegalOp(OperationName op)
Register the given operation as illegal, i.e.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:199
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition: Pattern.h:205
The main mechanism for performing data layout queries.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
static bool canConvertToBarePtr(BaseMemRefType type)
Check if a memref type can be converted to a bare pointer.
MLIRContext & getContext() const
Returns the MLIR context.
unsigned getIndexTypeBitwidth() const
Gets the bitwidth of the index type when converted to LLVM.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:452
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: WalkResult.h:29
static WalkResult advance()
Definition: WalkResult.h:47
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition: WalkResult.h:51
static WalkResult interrupt()
Definition: WalkResult.h:46
Value composeValue(OpBuilder &builder, Location loc, ValueRange src, Type dstType)
Composes a set of src values into a single value of type dstType through series of bitcasts and vecto...
Definition: Pattern.cpp:451
SmallVector< Value > decomposeValue(OpBuilder &builder, Location loc, Value src, Type dstType)
Decomposes a src value into a set of values of type dstType through series of bitcasts and vector ops...
Definition: Pattern.cpp:412
Runtime
Potential runtimes for AMD GPU kernels.
Definition: Runtimes.h:15
Include the generated interface declarations.
void populateGpuToROCDLConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, gpu::amd::Runtime runtime, amdgpu::Chipset chipset)
Collect a set of patterns to convert from the GPU dialect to ROCDL.
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
void populateGpuRewritePatterns(RewritePatternSet &patterns)
Collect all patterns to rewrite ops within the GPU dialect.
Definition: Passes.h:91
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
void configureGpuToROCDLConversionLegality(ConversionTarget &target)
Configure target to convert from the GPU dialect to ROCDL.
std::unique_ptr< OperationPass< gpu::GPUModuleOp > > createLowerGpuOpsToROCDLOpsPass(const std::string &chipset="gfx900", unsigned indexBitwidth=kDeriveIndexBitwidthFromDataLayout, bool useBarePtrCallConv=false, gpu::amd::Runtime runtime=gpu::amd::Runtime::Unknown)
Creates a pass that lowers GPU dialect operations to ROCDL counterparts.
const FrozenRewritePatternSet & patterns
void registerConvertToLLVMDependentDialectLoading(DialectRegistry &registry)
Register the extension that will load dependent dialects for LLVM conversion.
void populateGpuMemorySpaceAttributeConversions(TypeConverter &typeConverter, const MemorySpaceMapping &mapping)
Populates memory space attribute conversion rules for lowering gpu.address_space to integer values.
void populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, amdgpu::Chipset chipset)
Note: This function will also add conversions for the AMDGPU-specific address spaces,...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateGpuPromoteShuffleToAMDGPUPatterns(RewritePatternSet &patterns)
Tries to promote gpu.shuffles to specialized AMDGPU intrinsics.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
void populateMathToROCDLConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
Populate the given list with patterns that convert from Math to ROCDL calls.
Definition: MathToROCDL.cpp:46
Lowering for gpu.dynamic.shared.memory to LLVM dialect.
The lowering of gpu.printf to a call to HIP hostcalls.
The lowering of gpu.printf to a call to an external printf() function.
Represents the amdgpu gfx chipset version, e.g., gfx90a, gfx942, gfx1103.
Definition: Chipset.h:22
static FailureOr< Chipset > parse(StringRef name)
Parses the chipset version string and returns the chipset on success, and failure otherwise.
Definition: Chipset.cpp:14