#define GEN_PASS_DEF_CONVERTGPUOPSTOROCDLOPS
#include "mlir/Conversion/Passes.h.inc"
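// The GEN_PASS_DEF macro pulls in the tablegen-generated base class
// (impl::ConvertGpuOpsToROCDLOpsBase), which declares the pass options
// (chipset, indexBitwidth, useBarePtrCallConv, runtime, allowedDialects)
// read by the pass body below; the option list itself lives in Passes.td.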
/// Truncate or extend `value` to the index bitwidth expected by the type
/// converter.
static Value truncOrExtToLLVMType(ConversionPatternRewriter &rewriter,
                                  Location loc, Value value,
                                  const LLVMTypeConverter &converter) {
  unsigned intWidth = cast<IntegerType>(value.getType()).getWidth();
  unsigned indexBitwidth = converter.getIndexTypeBitwidth();
  auto indexBitwidthType =
      IntegerType::get(rewriter.getContext(), indexBitwidth);
  if (indexBitwidth > intWidth) {
    return LLVM::SExtOp::create(rewriter, loc, indexBitwidthType, value);
  }
  if (indexBitwidth < intWidth) {
    return LLVM::TruncOp::create(rewriter, loc, indexBitwidthType, value);
  }
  return value;
}
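// Example: with index-bitwidth=64 this sign-extends the i32 results of the
// ROCDL id/dim intrinsics to i64; with index-bitwidth=32 they pass through
// unchanged.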
/// Returns true if the given gpu.func can be safely called using the bare
/// pointer calling convention.
static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func) {
  bool canBeBare = true;
  for (Type type : func.getArgumentTypes())
    if (auto memrefTy = dyn_cast<BaseMemRefType>(type))
      canBeBare &= LLVMTypeConverter::canConvertToBarePtr(memrefTy);
  return canBeBare;
}
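// A memref argument qualifies for the bare-pointer convention only when it
// has a static shape and an identity layout (see
// LLVMTypeConverter::canConvertToBarePtr); otherwise the full memref
// descriptor is required, and the pass rejects useBarePtrCallConv below.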
static Value getLaneId(RewriterBase &rewriter, Location loc) {
  auto int32Type = IntegerType::get(rewriter.getContext(), 32);
  Value zero = arith::ConstantIntOp::create(rewriter, loc, 0, 32);
  Value minus1 = arith::ConstantIntOp::create(rewriter, loc, -1, 32);
  // Result attributes for the mbcnt ops: noundef plus a [0, N) range.
  NamedAttribute noundef = rewriter.getNamedAttr(
      LLVM::LLVMDialect::getNoUndefAttrName(), rewriter.getUnitAttr());
  NamedAttribute loBound = rewriter.getNamedAttr(
      LLVM::LLVMDialect::getRangeAttrName(),
      LLVM::ConstantRangeAttr::get(rewriter.getContext(), APInt::getZero(32),
                                   ...)); // upper bound elided in excerpt
  NamedAttribute hiBound = rewriter.getNamedAttr(
      LLVM::LLVMDialect::getRangeAttrName(),
      LLVM::ConstantRangeAttr::get(rewriter.getContext(), APInt::getZero(32),
                                   ...)); // upper bound elided in excerpt
  Value mbcntLo = ROCDL::MbcntLoOp::create(
      rewriter, loc, int32Type, minus1, zero, {},
      rewriter.getArrayAttr({rewriter.getDictionaryAttr({noundef, loBound})}));
  Value laneId = ROCDL::MbcntHiOp::create(
      rewriter, loc, int32Type, minus1, mbcntLo, {},
      rewriter.getArrayAttr({rewriter.getDictionaryAttr({noundef, hiBound})}));
  return laneId;
}
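// Resulting IR (a sketch): `gpu.lane_id` lowers to
//   %0 = rocdl.mbcnt.lo %c-1_i32, %c0_i32 : (i32, i32) -> i32
//   %1 = rocdl.mbcnt.hi %c-1_i32, %0 : (i32, i32) -> i32
// i.e. the lane id is the population count of the all-ones mask below the
// current lane, which equals the lane index: the standard AMDGPU idiom.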
106 "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32"
107 "-p7:160:256:256:32-p8:128:128:128:48-p9:192:256:256:32-i64:64-v16:16-v24:"
109 "32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:"
110 "64-S32-A5-G1-ni:7:8:9";
struct GPULaneIdOpToROCDL : public ConvertOpToLLVMPattern<gpu::LaneIdOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::LaneIdOp op, gpu::LaneIdOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    MLIRContext *context = rewriter.getContext();
    // ... (`laneId` is computed with the mbcnt idiom of getLaneId above) ...
    // Truncate or extend the result depending on the index bitwidth
    // specified by the LLVMTypeConverter options.
    const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
    if (indexBitwidth > 32) {
      laneId = LLVM::SExtOp::create(
          rewriter, loc, IntegerType::get(context, indexBitwidth), laneId);
    } else if (indexBitwidth < 32) {
      laneId = LLVM::TruncOp::create(
          rewriter, loc, IntegerType::get(context, indexBitwidth), laneId);
    }
    rewriter.replaceOp(op, {laneId});
    return success();
  }
};
struct GPUSubgroupSizeOpToROCDL
    : public ConvertOpToLLVMPattern<gpu::SubgroupSizeOp> {
  // ... (constructed with the target amdgpu::Chipset, stored in `chipset`)

  LogicalResult
  matchAndRewrite(gpu::SubgroupSizeOp op, gpu::SubgroupSizeOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    LLVM::ConstantRangeAttr bounds = nullptr;
    bool isBeforeGfx10 = chipset.majorVersion < 10;
    if (auto upperBoundAttr = op.getUpperBoundAttr()) {
      bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
          /*bitWidth=*/32, /*lower=*/isBeforeGfx10 ? 64 : 32,
          /*upper=*/op.getUpperBoundAttr().getInt() + 1);
    }
    Value wavefrontOp = ROCDL::WavefrontSizeOp::create(
        rewriter, op.getLoc(), rewriter.getI32Type(), bounds);
    wavefrontOp = truncOrExtToLLVMType(rewriter, op.getLoc(), wavefrontOp,
                                       *getTypeConverter());
    rewriter.replaceOp(op, {wavefrontOp});
    return success();
  }
};
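// Pre-gfx10 targets always run wave64, while gfx10+ supports wave32 and
// wave64, hence the 64-vs-32 lower bound on the rocdl.wavefrontsize range;
// an upper_bound on the gpu.subgroup_size op tightens the range further.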
struct GPUSubgroupIdOpToROCDL
    : public ConvertOpToLLVMPattern<gpu::SubgroupIdOp> {
  // ... (constructed with the target amdgpu::Chipset, stored in `chipset`)

  LogicalResult
  matchAndRewrite(gpu::SubgroupIdOp op, gpu::SubgroupIdOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto int32Type = rewriter.getI32Type();
    Value subgroupId;
    // ... (on targets exposing a wave-id intrinsic:)
    LLVM::ConstantRangeAttr bounds;
    if (auto upperBoundAttr = op.getUpperBoundAttr())
      bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
          /*bitWidth=*/32, /*lower=*/0, upperBoundAttr.getInt());
    subgroupId = ROCDL::WaveId::create(rewriter, loc, int32Type, bounds);
    // ... (fallback path: linearize the workgroup-local thread id and
    //      divide it by the wavefront size:)
    Value tidX = ROCDL::ThreadIdXOp::create(rewriter, loc, int32Type);
    Value tidY = ROCDL::ThreadIdYOp::create(rewriter, loc, int32Type);
    Value tidZ = ROCDL::ThreadIdZOp::create(rewriter, loc, int32Type);
    Value dimX = ROCDL::BlockDimXOp::create(rewriter, loc, int32Type);
    Value dimY = ROCDL::BlockDimYOp::create(rewriter, loc, int32Type);
    LLVM::IntegerOverflowFlags flags =
        LLVM::IntegerOverflowFlags::nsw | LLVM::IntegerOverflowFlags::nuw;
    Value dimYxTidZ =
        LLVM::MulOp::create(rewriter, loc, int32Type, dimY, tidZ, flags);
    Value tidYPlusDimYxTidZ =
        LLVM::AddOp::create(rewriter, loc, int32Type, tidY, dimYxTidZ, flags);
    Value dimXxInner = LLVM::MulOp::create(rewriter, loc, int32Type, dimX,
                                           tidYPlusDimYxTidZ, flags);
    Value linearized = LLVM::AddOp::create(rewriter, loc, int32Type, tidX,
                                           dimXxInner, flags);
    Value wavefrontSize =
        ROCDL::WavefrontSizeOp::create(rewriter, loc, int32Type);
    subgroupId = LLVM::UDivOp::create(rewriter, loc, int32Type, linearized,
                                      wavefrontSize);
    // ...
    subgroupId = truncOrExtToLLVMType(rewriter, loc, subgroupId,
                                      *getTypeConverter());
    rewriter.replaceOp(op, subgroupId);
    return success();
  }
};
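// Net effect of the fallback path:
//   subgroup_id = (tidX + dimX * (tidY + dimY * tidZ)) / wavefront_size
// The nsw/nuw flags are sound because workgroup dimensions are far below
// the i32 overflow threshold.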
/// Returns true if the given type is directly supported by the
/// readlane/readfirstlane intrinsics.
static bool isSupportedReadLaneType(Type type) {
  if (isa<Float16Type, BFloat16Type, Float32Type, Float64Type,
          LLVM::LLVMPointerType>(type))
    return true;

  if (auto intType = dyn_cast<IntegerType>(type))
    return llvm::is_contained({16, 32, 64},
                              static_cast<int>(intType.getWidth()));

  if (auto vecType = dyn_cast<VectorType>(type)) {
    Type elementType = vecType.getElementType();
    // ... (2-element vectors of 16-bit scalars fit in one 32-bit register)
    if (vecType.getNumElements() == 2 &&
        (isa<Float16Type, BFloat16Type>(elementType) ||
         elementType.isInteger(16)))
      return true;
  }
  return false;
}
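// In short: scalars that fit a 16/32/64-bit register (plus pointers) can be
// broadcast with a single readlane, and 2-element vectors of 16-bit scalars
// pack into one 32-bit register; everything else is decomposed into i32
// pieces by the pattern below.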
struct GPUSubgroupBroadcastOpToROCDL
    : public ConvertOpToLLVMPattern<gpu::SubgroupBroadcastOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupBroadcastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value src = adaptor.getSrc();
    if (isSupportedReadLaneType(src.getType())) {
      Value result = createReadlaneOp(op, adaptor, rewriter, src);
      rewriter.replaceOp(op, result);
      return success();
    }
    // Unsupported types: decompose into i32 pieces, broadcast each piece,
    // then recompose.
    Type i32 = rewriter.getI32Type();
    SmallVector<Value> decomposed =
        LLVM::decomposeValue(rewriter, op.getLoc(), src, i32);
    SmallVector<Value> results;
    results.reserve(decomposed.size());
    for (Value v : decomposed)
      results.emplace_back(createReadlaneOp(op, adaptor, rewriter, v));
    Value result =
        LLVM::composeValue(rewriter, op.getLoc(), results, src.getType());
    rewriter.replaceOp(op, result);
    return success();
  }

private:
  static Value createReadlaneOp(gpu::SubgroupBroadcastOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter,
                                Value src) {
    if (adaptor.getBroadcastType() == gpu::BroadcastType::specific_lane) {
      return ROCDL::ReadlaneOp::create(rewriter, op.getLoc(), src.getType(),
                                       src, adaptor.getLane());
    }
    return ROCDL::ReadfirstlaneOp::create(rewriter, op.getLoc(), src.getType(),
                                          src);
  }
};
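// Example: broadcasting a vector<4xf32> fails isSupportedReadLaneType, so
// the value is decomposed into four i32 pieces, each piece is broadcast
// with rocdl.readlane (or rocdl.readfirstlane when no specific lane is
// requested), and the pieces are recomposed into vector<4xf32>.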
struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value initShflValue = adaptor.getValue();
    Value srcLaneId = getLaneId(rewriter, loc);
    auto int32Type = IntegerType::get(rewriter.getContext(), 32);
    Value width = adaptor.getWidth();
    Value zero = LLVM::ConstantOp::create(rewriter, loc, int32Type, 0);
    Value negwidth = LLVM::SubOp::create(rewriter, loc, int32Type, zero, width);
    Value add = LLVM::AddOp::create(rewriter, loc, int32Type, srcLaneId, width);
    Value widthOrZeroIfOutside =
        LLVM::AndOp::create(rewriter, loc, int32Type, add, negwidth);
    Value dstLane;

    switch (op.getMode()) {
    case gpu::ShuffleMode::UP:
      dstLane = LLVM::SubOp::create(rewriter, loc, int32Type, srcLaneId,
                                    adaptor.getOffset());
      break;
    case gpu::ShuffleMode::DOWN:
      dstLane = LLVM::AddOp::create(rewriter, loc, int32Type, srcLaneId,
                                    adaptor.getOffset());
      break;
    case gpu::ShuffleMode::XOR:
      dstLane = LLVM::XOrOp::create(rewriter, loc, int32Type, srcLaneId,
                                    adaptor.getOffset());
      break;
    case gpu::ShuffleMode::IDX:
      dstLane = adaptor.getOffset();
      break;
    }
    Value isActiveSrcLane = LLVM::ICmpOp::create(
        rewriter, loc, LLVM::ICmpPredicate::slt, dstLane, widthOrZeroIfOutside);
    Value selectDstLane = LLVM::SelectOp::create(rewriter, loc, isActiveSrcLane,
                                                 dstLane, srcLaneId);
    Value two = LLVM::ConstantOp::create(rewriter, loc, int32Type, 2);
    Value dwordAlignedDstLane =
        LLVM::ShlOp::create(rewriter, loc, int32Type, selectDstLane, two);

    SmallVector<Value> decomposed =
        LLVM::decomposeValue(rewriter, loc, initShflValue, int32Type);
    SmallVector<Value> swizzled;
    for (Value v : decomposed) {
      Value res = ROCDL::DsBpermuteOp::create(rewriter, loc, int32Type,
                                              dwordAlignedDstLane, v);
      swizzled.emplace_back(res);
    }
    Value shflValue =
        LLVM::composeValue(rewriter, loc, swizzled, initShflValue.getType());
    rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
    return success();
  }
};
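// ds_bpermute addresses lanes in bytes (each lane owns a 4-byte slot),
// hence the shift by two. For gpu.shuffle xor the destination lane is
// src_lane ^ offset; lanes whose destination falls outside the shuffle
// width keep their own value and report invalidity through the second
// result (isActiveSrcLane).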
struct GPUBarrierOpLowering : public ConvertOpToLLVMPattern<gpu::BarrierOp> {
  // ... (constructed with the target amdgpu::Chipset, stored in `chipset`)

  LogicalResult
  matchAndRewrite(gpu::BarrierOp op, gpu::BarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // Decide which address spaces the barrier must fence.
    bool fenceGlobal = false;
    bool fenceLDS = false;
    std::optional<ArrayAttr> addrSpacesToFence = op.getAddressSpaces();
    if (addrSpacesToFence) {
      for (auto spaceAttr :
           addrSpacesToFence->getAsRange<gpu::AddressSpaceAttr>()) {
        switch (spaceAttr.getValue()) {
        case gpu::AddressSpace::Global:
          fenceGlobal = true;
          break;
        case gpu::AddressSpace::Workgroup:
          fenceLDS = true;
          break;
        case gpu::AddressSpace::Private:
          break;
        }
      }
    } else {
      // No explicit list: fence both global memory and LDS.
      fenceGlobal = true;
      fenceLDS = true;
    }
    // When only one space is fenced, narrow the fences with an MMRA tag.
    Attribute mmra;
    if (fenceLDS && !fenceGlobal) {
      mmra = rewriter.getAttr<LLVM::MMRATagAttr>("amdgpu-synchronize-as",
                                                 "local");
    } else if (fenceGlobal && !fenceLDS) {
      mmra = rewriter.getAttr<LLVM::MMRATagAttr>("amdgpu-synchronize-as",
                                                 "global");
    }
    constexpr llvm::StringLiteral scope = "workgroup";
    bool emitFences = fenceGlobal || fenceLDS;
    if (emitFences) {
      auto relFence = LLVM::FenceOp::create(
          rewriter, loc, LLVM::AtomicOrdering::release, scope);
      if (mmra)
        relFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(),
                                     mmra);
    }
    if (chipset.majorVersion < 12) {
      ROCDL::SBarrierOp::create(rewriter, loc);
    } else {
      // gfx12+ splits the barrier into signal + wait.
      ROCDL::BarrierSignalOp::create(rewriter, loc, -1);
      ROCDL::BarrierWaitOp::create(rewriter, loc, -1);
    }
    if (emitFences) {
      auto acqFence = LLVM::FenceOp::create(
          rewriter, loc, LLVM::AtomicOrdering::acquire, scope);
      if (mmra)
        acqFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(),
                                     mmra);
    }
    rewriter.eraseOp(op);
    return success();
  }
};
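// The emitted sequence is release-fence, hardware barrier, acquire-fence,
// all at workgroup scope. The "amdgpu-synchronize-as" MMRA narrows the
// fences to just LDS or just global memory when gpu.barrier names a single
// address space, which lets the backend elide the unneeded waits.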
#include "GPUToROCDL.cpp.inc"
struct LowerGpuOpsToROCDLOpsPass final
    : public impl::ConvertGpuOpsToROCDLOpsBase<LowerGpuOpsToROCDLOpsPass> {
  // ...
  void getDependentDialects(DialectRegistry &registry) const override {
    Base::getDependentDialects(registry);
    registerConvertToLLVMDependentDialectLoading(registry);
  }
  void runOnOperation() override {
    gpu::GPUModuleOp m = getOperation();
    MLIRContext *ctx = m.getContext();

    // Set an LLVM data layout on the module if none is present.
    auto llvmDataLayout = m->getAttrOfType<StringAttr>(
        LLVM::LLVMDialect::getDataLayoutAttrName());
    if (!llvmDataLayout) {
      llvmDataLayout = StringAttr::get(ctx, amdgcnDataLayout);
      m->setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(), llvmDataLayout);
    }
    // Request C wrapper emission.
    for (auto func : m.getOps<func::FuncOp>()) {
      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                    UnitAttr::get(ctx));
    }
    FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(chipset);
    if (failed(maybeChipset)) {
      emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
      return signalPassFailure();
    }
    // Customize the bitwidth used for the device side index computations.
    LowerToLLVMOptions options(
        ctx, DataLayout(cast<DataLayoutOpInterface>(m.getOperation())));
    options.dataLayout = llvm::DataLayout(llvmDataLayout.getValue());
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);
    if (useBarePtrCallConv) {
      options.useBarePtrCallConv = true;
      WalkResult canUseBarePointers = m.walk([](gpu::GPUFuncOp func) {
        return canBeCalledWithBarePointers(func) ? WalkResult::advance()
                                                 : WalkResult::interrupt();
      });
      if (canUseBarePointers.wasInterrupted()) {
        emitError(UnknownLoc::get(ctx),
                  "bare pointer calling convention requires all memrefs to "
                  "have static shape and use the identity map");
        return signalPassFailure();
      }
    }
    // ... (the LLVMTypeConverter `converter`, ConversionTarget `target`,
    //      and RewritePatternSet `llvmPatterns` are set up here) ...
    llvm::SmallDenseSet<StringRef> allowedDialectsSet(allowedDialects.begin(),
                                                      allowedDialects.end());
    for (Dialect *dialect : ctx->getLoadedDialects()) {
      bool allowed = allowedDialectsSet.contains(dialect->getNamespace());
      // An empty `allowedDialectsSet` means all dialects are allowed.
      if (!allowedDialectsSet.empty() && !allowed)
        continue;
      auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
      if (!iface) {
        // Only fail if the dialect was explicitly requested.
        if (allowed) {
          emitError(UnknownLoc::get(ctx))
              << "dialect does not implement ConvertToLLVMPatternInterface: "
              << dialect->getNamespace();
          return signalPassFailure();
        }
        continue;
      }
      iface->populateConvertToLLVMConversionPatterns(target, converter,
                                                     llvmPatterns);
    }
    if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
      signalPassFailure();
    auto *rocdlDialect = getContext().getLoadedDialect<ROCDL::ROCDLDialect>();
    auto reqdWorkGroupSizeAttrHelper =
        rocdlDialect->getReqdWorkGroupSizeAttrHelper();
    auto flatWorkGroupSizeAttrHelper =
        rocdlDialect->getFlatWorkGroupSizeAttrHelper();
    // Manually rewrite known block size attributes so the LLVM IR
    // translation infrastructure can pick them up.
    m.walk([&](LLVM::LLVMFuncOp op) {
      if (reqdWorkGroupSizeAttrHelper.isAttrPresent(op)) {
        auto blockSizes = reqdWorkGroupSizeAttrHelper.getAttr(op);
        // Also set the flat work-group size to the product of the block
        // sizes, giving the backend a tight occupancy bound.
        uint32_t flatSize = 1;
        for (uint32_t size : blockSizes.asArrayRef()) {
          flatSize *= size;
        }
        StringAttr flatSizeAttr =
            StringAttr::get(ctx, Twine(flatSize) + "," + Twine(flatSize));
        flatWorkGroupSizeAttrHelper.setAttr(op, flatSizeAttr);
      }
    });
  }
};
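// Example: rocdl.reqd_work_group_size = [64, 2, 1] yields flatSize = 128,
// so the function gets the flat work-group size attribute "128,128", which
// the LLVM IR translation turns into a tight amdgpu-flat-work-group-size.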
void mlir::configureGpuToROCDLConversionLegality(ConversionTarget &target) {
  target.addIllegalOp<func::FuncOp>();
  target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
  target.addLegalDialect<ROCDL::ROCDLDialect>();
  target.addIllegalDialect<gpu::GPUDialect>();
  target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FCeilOp,
                      LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp, LLVM::Log10Op,
                      LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp>();
  // These ops are legal for f32 operands.
  target.addDynamicallyLegalOp<LLVM::ExpOp, LLVM::LogOp>([](Operation *op) {
    return any_of(op->getOperandTypes(), llvm::IsaPred<Float32Type>);
  });
  target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp>();
}
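// llvm.exp and llvm.log have native f32 lowerings on AMDGPU, so they remain
// legal for f32 operands; the remaining math ops above must be rewritten
// first, e.g. by populateMathToROCDLConversionPatterns into device-library
// (OCML) calls.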
void mlir::populateGpuToROCDLConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
    gpu::amd::Runtime runtime, amdgpu::Chipset chipset) {
  using gpu::index_lowering::IndexKind;
  using gpu::index_lowering::IntrType;
  // ...
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::ThreadIdOp, ROCDL::ThreadIdXOp, ROCDL::ThreadIdYOp,
      ROCDL::ThreadIdZOp>>(converter, IndexKind::Block, IntrType::Id);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::BlockIdOp, ROCDL::BlockIdXOp, ROCDL::BlockIdYOp,
      ROCDL::BlockIdZOp>>(converter, IndexKind::Grid, IntrType::Id);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::BlockDimOp, ROCDL::BlockDimXOp, ROCDL::BlockDimYOp,
      ROCDL::BlockDimZOp>>(converter, IndexKind::Block, IntrType::Dim);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::GridDimOp, ROCDL::GridDimXOp, ROCDL::GridDimYOp,
      ROCDL::GridDimZOp>>(converter, IndexKind::Grid, IntrType::Dim);
  // ... (`rocdlDialect` is looked up from the converter's context) ...
  patterns.add<GPUFuncOpLowering>(
      converter,
      GPUFuncOpLoweringOptions{
          ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace,
          ROCDL::ROCDLDialect::kSharedMemoryAddressSpace,
          rocdlDialect->getKernelAttrHelper().getName(),
          rocdlDialect->getReqdWorkGroupSizeAttrHelper().getName()});
  if (Runtime::HIP == runtime) {
    patterns.add<GPUPrintfOpToHIPLowering>(converter);
  } else if (Runtime::OpenCL == runtime) {
    // Use address space 4 to match the OpenCL definition of printf().
    patterns.add<GPUPrintfOpToLLVMCallLowering>(converter, /*addressSpace=*/4);
  }
  patterns.add<GPUShuffleOpLowering, GPULaneIdOpToROCDL,
               GPUSubgroupBroadcastOpToROCDL>(converter);
  patterns.add<GPUSubgroupIdOpToROCDL, GPUSubgroupSizeOpToROCDL,
               GPUBarrierOpLowering>(converter, chipset);
  populateMathToROCDLConversionPatterns(converter, patterns, chipset);
}
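// Typical driver usage (a sketch; Passes.td holds the authoritative option
// list and spellings):
//
//   mlir-opt kernels.mlir \
//     --convert-gpu-to-rocdl="chipset=gfx942 index-bitwidth=32 runtime=HIP"
//
// or, programmatically, build an LLVMTypeConverter, call
// populateGpuToROCDLConversionPatterns and
// configureGpuToROCDLConversionLegality, then run applyPartialConversion,
// mirroring runOnOperation above.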