46#define GEN_PASS_DEF_CONVERTGPUOPSTOROCDLOPS
47#include "mlir/Conversion/Passes.h.inc"
59 auto indexBitwidthType =
62 if (indexBitwidth > intWidth) {
63 return LLVM::SExtOp::create(rewriter, loc, indexBitwidthType, value);
65 if (indexBitwidth < intWidth) {
66 return LLVM::TruncOp::create(rewriter, loc, indexBitwidthType, value);
74 bool canBeBare =
true;
75 for (
Type type :
func.getArgumentTypes())
76 if (
auto memrefTy = dyn_cast<BaseMemRefType>(type))
82 auto int32Type = IntegerType::get(rewriter.
getContext(), 32);
86 LLVM::LLVMDialect::getNoUndefAttrName(), rewriter.
getUnitAttr());
88 LLVM::LLVMDialect::getRangeAttrName(),
89 LLVM::ConstantRangeAttr::get(rewriter.
getContext(), APInt::getZero(32),
92 LLVM::LLVMDialect::getRangeAttrName(),
93 LLVM::ConstantRangeAttr::get(rewriter.
getContext(), APInt::getZero(32),
95 Value mbcntLo = ROCDL::MbcntLoOp::create(
96 rewriter, loc, int32Type, minus1, zero, {},
99 Value laneId = ROCDL::MbcntHiOp::create(
100 rewriter, loc, int32Type, minus1, mbcntLo, {},
106 "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32"
107 "-p7:160:256:256:32-p8:128:128:128:48-p9:192:256:256:32-i64:64-v16:16-v24:"
109 "32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:"
110 "64-S32-A5-G1-ni:7:8:9";
117 matchAndRewrite(gpu::LaneIdOp op, gpu::LaneIdOp::Adaptor adaptor,
118 ConversionPatternRewriter &rewriter)
const override {
131 const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
132 if (indexBitwidth > 32) {
133 laneId = LLVM::SExtOp::create(
134 rewriter, loc, IntegerType::get(context, indexBitwidth), laneId);
135 }
else if (indexBitwidth < 32) {
136 laneId = LLVM::TruncOp::create(
137 rewriter, loc, IntegerType::get(context, indexBitwidth), laneId);
139 rewriter.replaceOp(op, {laneId});
153 matchAndRewrite(gpu::SubgroupSizeOp op, gpu::SubgroupSizeOp::Adaptor adaptor,
154 ConversionPatternRewriter &rewriter)
const override {
155 LLVM::ConstantRangeAttr bounds =
nullptr;
157 if (
auto upperBoundAttr = op.getUpperBoundAttr()) {
158 bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
159 32, isBeforeGfx10 ? 64 : 32,
160 op.getUpperBoundAttr().getInt() + 1);
162 Value wavefrontOp = ROCDL::WavefrontSizeOp::create(
163 rewriter, op.getLoc(), rewriter.getI32Type(), bounds);
165 *getTypeConverter());
166 rewriter.replaceOp(op, {wavefrontOp});
173static bool isSupportedReadLaneType(
Type type) {
175 if (isa<Float16Type, BFloat16Type, Float32Type, Float64Type,
176 LLVM::LLVMPointerType>(type))
179 if (
auto intType = dyn_cast<IntegerType>(type))
180 return llvm::is_contained({16, 32, 64},
181 static_cast<int>(intType.getWidth()));
183 if (
auto vecType = dyn_cast<VectorType>(type)) {
184 Type elementType = vecType.getElementType();
188 if (vecType.getNumElements() == 2 &&
189 (isa<Float16Type, BFloat16Type>(elementType) ||
197struct GPUSubgroupBroadcastOpToROCDL
202 matchAndRewrite(gpu::SubgroupBroadcastOp op, OpAdaptor adaptor,
203 ConversionPatternRewriter &rewriter)
const override {
204 Value src = adaptor.getSrc();
205 if (isSupportedReadLaneType(src.
getType())) {
206 Value result = createReadlaneOp(op, adaptor, rewriter, src);
207 rewriter.replaceOp(op,
result);
211 Type i32 = rewriter.getI32Type();
217 results.reserve(decomposed.size());
218 for (
Value v : decomposed)
219 results.emplace_back(createReadlaneOp(op, adaptor, rewriter, v));
222 rewriter.replaceOp(op,
result);
227 static Value createReadlaneOp(gpu::SubgroupBroadcastOp op, OpAdaptor adaptor,
228 ConversionPatternRewriter &rewriter,
230 if (adaptor.getBroadcastType() == gpu::BroadcastType::specific_lane) {
231 return ROCDL::ReadlaneOp::create(rewriter, op.getLoc(), src.
getType(),
232 src, adaptor.getLane());
234 return ROCDL::ReadfirstlaneOp::create(rewriter, op.getLoc(),
260 matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
261 ConversionPatternRewriter &rewriter)
const override {
263 Value initShflValue = adaptor.getValue();
267 auto int32Type = IntegerType::get(rewriter.getContext(), 32);
268 Value width = adaptor.getWidth();
269 Value zero = LLVM::ConstantOp::create(rewriter, loc, int32Type, 0);
270 Value negwidth = LLVM::SubOp::create(rewriter, loc, int32Type, zero, width);
271 Value add = LLVM::AddOp::create(rewriter, loc, int32Type, srcLaneId, width);
272 Value widthOrZeroIfOutside =
273 LLVM::AndOp::create(rewriter, loc, int32Type,
add, negwidth);
276 switch (op.getMode()) {
277 case gpu::ShuffleMode::UP:
278 dstLane = LLVM::SubOp::create(rewriter, loc, int32Type, srcLaneId,
279 adaptor.getOffset());
281 case gpu::ShuffleMode::DOWN:
282 dstLane = LLVM::AddOp::create(rewriter, loc, int32Type, srcLaneId,
283 adaptor.getOffset());
285 case gpu::ShuffleMode::XOR:
286 dstLane = LLVM::XOrOp::create(rewriter, loc, int32Type, srcLaneId,
287 adaptor.getOffset());
289 case gpu::ShuffleMode::IDX:
290 dstLane = adaptor.getOffset();
293 Value isActiveSrcLane = LLVM::ICmpOp::create(
294 rewriter, loc, LLVM::ICmpPredicate::slt, dstLane, widthOrZeroIfOutside);
295 Value selectDstLane = LLVM::SelectOp::create(rewriter, loc, isActiveSrcLane,
297 Value two = LLVM::ConstantOp::create(rewriter, loc, int32Type, 2);
298 Value dwordAlignedDstLane =
299 LLVM::ShlOp::create(rewriter, loc, int32Type, selectDstLane, two);
304 for (
Value v : decomposed) {
305 Value res = ROCDL::DsBpermuteOp::create(rewriter, loc, int32Type,
306 dwordAlignedDstLane, v);
307 swizzled.emplace_back(res);
311 rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
317#include "GPUToROCDL.cpp.inc"
324struct LowerGpuOpsToROCDLOpsPass final
325 :
public impl::ConvertGpuOpsToROCDLOpsBase<LowerGpuOpsToROCDLOpsPass> {
329 Base::getDependentDialects(registry);
333 void runOnOperation()
override {
334 gpu::GPUModuleOp m = getOperation();
337 auto llvmDataLayout = m->getAttrOfType<StringAttr>(
338 LLVM::LLVMDialect::getDataLayoutAttrName());
339 if (!llvmDataLayout) {
341 m->setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(), llvmDataLayout);
344 for (
auto func : m.getOps<func::FuncOp>()) {
345 func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
350 if (failed(maybeChipset)) {
351 emitError(UnknownLoc::get(ctx),
"Invalid chipset name: " + chipset);
352 return signalPassFailure();
357 ctx,
DataLayout(cast<DataLayoutOpInterface>(m.getOperation())));
358 options.dataLayout = llvm::DataLayout(llvmDataLayout.getValue());
360 options.overrideIndexBitwidth(indexBitwidth);
362 if (useBarePtrCallConv) {
363 options.useBarePtrCallConv =
true;
372 "bare pointer calling convention requires all memrefs to "
373 "have static shape and use the identity map");
374 return signalPassFailure();
394 llvm::SmallDenseSet<StringRef> allowedDialectsSet(allowedDialects.begin(),
395 allowedDialects.end());
397 bool allowed = allowedDialectsSet.contains(dialect->getNamespace());
399 if (!allowedDialectsSet.empty() && !allowed)
402 auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
408 <<
"dialect does not implement ConvertToLLVMPatternInterface: "
409 << dialect->getNamespace();
410 return signalPassFailure();
415 iface->populateConvertToLLVMConversionPatterns(
target, converter,
424 if (failed(applyPartialConversion(m,
target, std::move(llvmPatterns))))
426 auto *rocdlDialect =
getContext().getLoadedDialect<ROCDL::ROCDLDialect>();
427 auto reqdWorkGroupSizeAttrHelper =
428 rocdlDialect->getReqdWorkGroupSizeAttrHelper();
429 auto flatWorkGroupSizeAttrHelper =
430 rocdlDialect->getFlatWorkGroupSizeAttrHelper();
433 m.walk([&](LLVM::LLVMFuncOp op) {
434 if (reqdWorkGroupSizeAttrHelper.isAttrPresent(op)) {
435 auto blockSizes = reqdWorkGroupSizeAttrHelper.getAttr(op);
438 uint32_t flatSize = 1;
439 for (uint32_t size : blockSizes.asArrayRef()) {
442 StringAttr flatSizeAttr =
443 StringAttr::get(ctx, Twine(flatSize) +
"," + Twine(flatSize));
444 flatWorkGroupSizeAttrHelper.setAttr(op, flatSizeAttr);
453 target.addIllegalOp<func::FuncOp>();
454 target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
455 target.addLegalDialect<ROCDL::ROCDLDialect>();
456 target.addIllegalDialect<gpu::GPUDialect>();
457 target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FCeilOp,
458 LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp, LLVM::Log10Op,
459 LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp>();
461 target.addDynamicallyLegalOp<LLVM::ExpOp, LLVM::LogOp>([](
Operation *op) {
462 return any_of(op->getOperandTypes(), llvm::IsaPred<Float32Type>);
465 target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp>();
479 ROCDL::ThreadIdYOp, ROCDL::ThreadIdZOp>>(
480 converter, IndexKind::Block, IntrType::Id);
482 gpu::BlockIdOp, ROCDL::BlockIdXOp, ROCDL::BlockIdYOp, ROCDL::BlockIdZOp>>(
483 converter, IndexKind::Grid, IntrType::Id);
486 ROCDL::BlockDimYOp, ROCDL::BlockDimZOp>>(
487 converter, IndexKind::Block, IntrType::Dim);
489 gpu::GridDimOp, ROCDL::GridDimXOp, ROCDL::GridDimYOp, ROCDL::GridDimZOp>>(
490 converter, IndexKind::Grid, IntrType::Dim);
495 ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace,
496 ROCDL::ROCDLDialect::kSharedMemoryAddressSpace,
497 rocdlDialect->getKernelAttrHelper().getName(),
498 rocdlDialect->getReqdWorkGroupSizeAttrHelper().getName(),
502 }
else if (Runtime::OpenCL ==
runtime) {
509 patterns.add<GPUShuffleOpLowering, GPULaneIdOpToROCDL,
510 GPUSubgroupBroadcastOpToROCDL>(converter);
511 patterns.add<GPUSubgroupSizeOpToROCDL>(converter, chipset);
static Value getLaneId(RewriterBase &rewriter, Location loc)
static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func)
Returns true if the given gpu.func can be safely called using the bare pointer calling convention.
static constexpr StringLiteral amdgcnDataLayout
static Value truncOrExtToLLVMType(ConversionPatternRewriter &rewriter, Location loc, Value value, const LLVMTypeConverter &converter)
static llvm::ManagedStatic< PassManagerOptions > options
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
MLIRContext * getContext() const
DictionaryAttr getDictionaryAttr(ArrayRef< NamedAttribute > value)
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
The main mechanism for performing data layout queries.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the entire group.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
static bool canConvertToBarePtr(BaseMemRefType type)
Check if a memref type can be converted to a bare pointer.
MLIRContext & getContext() const
Returns the MLIR context.
unsigned getIndexTypeBitwidth() const
Gets the bitwidth of the index type when converted to LLVM.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
NamedAttribute represents a combination of a name and an Attribute value.
Operation is the basic unit of execution within MLIR.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkResult interrupt()
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
Value composeValue(OpBuilder &builder, Location loc, ValueRange src, Type dstType)
Composes a set of src values into a single value of type dstType through series of bitcasts and vecto...
SmallVector< Value > decomposeValue(OpBuilder &builder, Location loc, Value src, Type dstType)
Decomposes a src value into a set of values of type dstType through series of bitcasts and vector ops...
void populateCommonGPUTypeAndAttributeConversions(TypeConverter &typeConverter)
Remap common GPU memory spaces (Workgroup, Private, etc) to LLVM address spaces.
Runtime
Potential runtimes for AMD GPU kernels.
Include the generated interface declarations.
void populateGpuToROCDLConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, gpu::amd::Runtime runtime, amdgpu::Chipset chipset)
Collect a set of patterns to convert from the GPU dialect to ROCDL.
void populateMathToROCDLConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, std::optional< amdgpu::Chipset > chipset)
Populate the given list with patterns that convert from Math to ROCDL calls.
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth from the data layout.
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
void populateGpuRewritePatterns(RewritePatternSet &patterns)
Collect all patterns to rewrite ops within the GPU dialect.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
void configureGpuToROCDLConversionLegality(ConversionTarget &target)
Configure target to convert from the GPU dialect to ROCDL.
const FrozenRewritePatternSet & patterns
void registerConvertToLLVMDependentDialectLoading(DialectRegistry ®istry)
Register the extension that will load dependent dialects for LLVM conversion.
void populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, amdgpu::Chipset chipset)
Note: This function will also add conversions for the AMDGPU-specific address spaces and types,...
void populateGpuPromoteShuffleToAMDGPUPatterns(RewritePatternSet &patterns, std::optional< amdgpu::Chipset > maybeChipset)
Tries to promote gpu.shuffles to specialized AMDGPU intrinsics.
Lowering for gpu.dynamic.shared.memory to LLVM dialect.
The lowering of gpu.printf to a call to HIP hostcalls.
The lowering of gpu.printf to a call to an external printf() function.
Represents the amdgpu gfx chipset version, e.g., gfx90a, gfx942, gfx1103.
static FailureOr< Chipset > parse(StringRef name)
Parses the chipset version string and returns the chipset on success, and failure otherwise.