44 #include "llvm/Support/FormatVariadic.h"
46 #include "../GPUCommon/GPUOpsLowering.h"
47 #include "../GPUCommon/IndexIntrinsicsOpLowering.h"
48 #include "../GPUCommon/OpToFuncCallLowering.h"
51 #define GEN_PASS_DEF_CONVERTGPUOPSTOROCDLOPS
52 #include "mlir/Conversion/Passes.h.inc"
60 bool canBeBare =
true;
61 for (
Type type : func.getArgumentTypes())
62 if (
auto memrefTy = dyn_cast<BaseMemRefType>(type))
68 const unsigned indexBitwidth) {
70 Value zero = rewriter.
create<arith::ConstantIntOp>(loc, 0, 32);
71 Value minus1 = rewriter.
create<arith::ConstantIntOp>(loc, -1, 32);
72 Value mbcntLo = rewriter.
create<ROCDL::MbcntLoOp>(loc, int32Type,
74 Value laneId = rewriter.
create<ROCDL::MbcntHiOp>(loc, int32Type,
79 "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32"
80 "-p7:160:256:256:32-p8:128:128-i64:64-v16:16-v24:32-v32:32-v48:64-v96:"
81 "128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-"
89 matchAndRewrite(gpu::LaneIdOp op, gpu::LaneIdOp::Adaptor adaptor,
97 Value zero = rewriter.
create<arith::ConstantIntOp>(loc, 0, 32);
98 Value minus1 = rewriter.
create<arith::ConstantIntOp>(loc, -1, 32);
105 const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
106 if (indexBitwidth > 32) {
107 laneId = rewriter.
create<LLVM::SExtOp>(
109 }
else if (indexBitwidth < 32) {
110 laneId = rewriter.
create<LLVM::TruncOp>(
138 matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
142 if (adaptor.getValue().getType().getIntOrFloatBitWidth() != 32)
144 const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
148 Value width = adaptor.getWidth();
149 Value zero = rewriter.
create<LLVM::ConstantOp>(loc, int32Type, 0);
150 Value negwidth = rewriter.
create<LLVM::SubOp>(loc, int32Type, zero, width);
151 Value add = rewriter.
create<LLVM::AddOp>(loc, int32Type, srcLaneId, width);
152 Value widthOrZeroIfOutside =
153 rewriter.
create<LLVM::AndOp>(loc, int32Type, add, negwidth);
158 switch (op.getMode()) {
159 case gpu::ShuffleMode::XOR:
160 dstLane = rewriter.
create<LLVM::XOrOp>(loc, int32Type, srcLaneId,
161 adaptor.getOffset());
163 case gpu::ShuffleMode::IDX:
164 dstLane = adaptor.getOffset();
169 Value isActiveSrcLane = rewriter.
create<LLVM::ICmpOp>(
170 loc, LLVM::ICmpPredicate::slt, dstLane, widthOrZeroIfOutside);
171 Value selectDstLane = rewriter.
create<LLVM::SelectOp>(loc, isActiveSrcLane,
173 Value two = rewriter.
create<LLVM::ConstantOp>(loc, int32Type, 2);
174 Value dwordAlignedDstLane =
175 rewriter.
create<LLVM::ShlOp>(loc, int32Type, selectDstLane, two);
176 Value initShflValue = adaptor.getValue();
177 if (adaptor.getValue().getType().isF32()) {
179 rewriter.
create<LLVM::BitcastOp>(loc, int32Type, initShflValue);
181 Value shflValue = rewriter.
create<ROCDL::DsBpermuteOp>(
182 loc, int32Type, dwordAlignedDstLane, initShflValue);
183 if (adaptor.getValue().getType().isF32()) {
184 shflValue = rewriter.
create<LLVM::BitcastOp>(
185 loc, adaptor.getValue().getType(), shflValue);
187 rewriter.
replaceOp(op, {shflValue, isActiveSrcLane});
193 #include "GPUToROCDL.cpp.inc"
200 struct LowerGpuOpsToROCDLOpsPass
201 :
public impl::ConvertGpuOpsToROCDLOpsBase<LowerGpuOpsToROCDLOpsPass> {
202 LowerGpuOpsToROCDLOpsPass() =
default;
203 LowerGpuOpsToROCDLOpsPass(
const std::string &chipset,
unsigned indexBitwidth,
204 bool useBarePtrCallConv,
206 if (this->chipset.getNumOccurrences() == 0)
207 this->chipset = chipset;
208 if (this->indexBitwidth.getNumOccurrences() == 0)
209 this->indexBitwidth = indexBitwidth;
210 if (this->useBarePtrCallConv.getNumOccurrences() == 0)
211 this->useBarePtrCallConv = useBarePtrCallConv;
212 if (this->runtime.getNumOccurrences() == 0)
213 this->runtime = runtime;
216 void runOnOperation()
override {
217 gpu::GPUModuleOp m = getOperation();
220 auto llvmDataLayout = m->getAttrOfType<StringAttr>(
221 LLVM::LLVMDialect::getDataLayoutAttrName());
222 if (!llvmDataLayout) {
224 m->setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(), llvmDataLayout);
227 for (
auto func : m.getOps<func::FuncOp>()) {
228 func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
233 if (
failed(maybeChipset)) {
235 return signalPassFailure();
240 ctx,
DataLayout(cast<DataLayoutOpInterface>(m.getOperation())));
241 options.dataLayout = llvm::DataLayout(llvmDataLayout.getValue());
243 options.overrideIndexBitwidth(indexBitwidth);
245 if (useBarePtrCallConv) {
246 options.useBarePtrCallConv =
true;
248 m.walk([](gpu::GPUFuncOp func) ->
WalkResult {
255 "bare pointer calling convention requires all memrefs to "
256 "have static shape and use the identity map");
257 return signalPassFailure();
273 converter, [](gpu::AddressSpace space) {
275 case gpu::AddressSpace::Global:
277 case gpu::AddressSpace::Workgroup:
279 case gpu::AddressSpace::Private:
282 llvm_unreachable(
"unknown address space enum value");
301 auto reqdWorkGroupSizeAttrHelper =
302 rocdlDialect->getReqdWorkGroupSizeAttrHelper();
303 auto flatWorkGroupSizeAttrHelper =
304 rocdlDialect->getFlatWorkGroupSizeAttrHelper();
307 m.walk([&](LLVM::LLVMFuncOp op) {
308 if (
auto blockSizes = dyn_cast_or_null<DenseI32ArrayAttr>(
309 op->
removeAttr(gpu::GPUFuncOp::getKnownBlockSizeAttrName()))) {
310 reqdWorkGroupSizeAttrHelper.setAttr(op, blockSizes);
313 uint32_t flatSize = 1;
314 for (uint32_t size : blockSizes.asArrayRef()) {
317 StringAttr flatSizeAttr =
319 flatWorkGroupSizeAttrHelper.setAttr(op, flatSizeAttr);
332 target.
addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
333 LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
334 LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
338 target.
addLegalOp<gpu::YieldOp, gpu::GPUModuleOp, gpu::ModuleEndOp>();
341 template <
typename OpTy>
354 populateWithGenerated(patterns);
357 ROCDL::ThreadIdYOp, ROCDL::ThreadIdZOp>>(
358 converter, gpu::GPUFuncOp::getKnownBlockSizeAttrName());
360 gpu::BlockIdOp, ROCDL::BlockIdXOp, ROCDL::BlockIdYOp, ROCDL::BlockIdZOp>>(
361 converter, gpu::GPUFuncOp::getKnownGridSizeAttrName());
364 ROCDL::BlockDimYOp, ROCDL::BlockDimZOp>,
366 ROCDL::GridDimYOp, ROCDL::GridDimZOp>,
370 ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace,
371 ROCDL::ROCDLDialect::kSharedMemoryAddressSpace,
372 ROCDL::ROCDLDialect::KernelAttrHelper(&converter.
getContext()).getName());
382 patterns.
add<GPUShuffleOpLowering, GPULaneIdOpToROCDL>(converter);
384 populateOpPatterns<math::AbsFOp>(converter, patterns,
"__ocml_fabs_f32",
386 populateOpPatterns<math::AtanOp>(converter, patterns,
"__ocml_atan_f32",
388 populateOpPatterns<math::Atan2Op>(converter, patterns,
"__ocml_atan2_f32",
390 populateOpPatterns<math::CbrtOp>(converter, patterns,
"__ocml_cbrt_f32",
392 populateOpPatterns<math::CeilOp>(converter, patterns,
"__ocml_ceil_f32",
394 populateOpPatterns<math::CosOp>(converter, patterns,
"__ocml_cos_f32",
396 populateOpPatterns<math::ExpOp>(converter, patterns,
"__ocml_exp_f32",
398 populateOpPatterns<math::Exp2Op>(converter, patterns,
"__ocml_exp2_f32",
400 populateOpPatterns<math::ExpM1Op>(converter, patterns,
"__ocml_expm1_f32",
402 populateOpPatterns<math::FloorOp>(converter, patterns,
"__ocml_floor_f32",
404 populateOpPatterns<arith::RemFOp>(converter, patterns,
"__ocml_fmod_f32",
406 populateOpPatterns<math::LogOp>(converter, patterns,
"__ocml_log_f32",
408 populateOpPatterns<math::Log10Op>(converter, patterns,
"__ocml_log10_f32",
410 populateOpPatterns<math::Log1pOp>(converter, patterns,
"__ocml_log1p_f32",
412 populateOpPatterns<math::Log2Op>(converter, patterns,
"__ocml_log2_f32",
414 populateOpPatterns<math::PowFOp>(converter, patterns,
"__ocml_pow_f32",
416 populateOpPatterns<math::RsqrtOp>(converter, patterns,
"__ocml_rsqrt_f32",
418 populateOpPatterns<math::SinOp>(converter, patterns,
"__ocml_sin_f32",
420 populateOpPatterns<math::SqrtOp>(converter, patterns,
"__ocml_sqrt_f32",
422 populateOpPatterns<math::TanhOp>(converter, patterns,
"__ocml_tanh_f32",
424 populateOpPatterns<math::TanOp>(converter, patterns,
"__ocml_tan_f32",
426 populateOpPatterns<math::ErfOp>(converter, patterns,
"__ocml_erf_f32",
430 std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
432 unsigned indexBitwidth,
433 bool useBarePtrCallConv,
435 return std::make_unique<LowerGpuOpsToROCDLOpsPass>(
436 chipset, indexBitwidth, useBarePtrCallConv, runtime);
static MLIRContext * getContext(OpFoldResult val)
static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func)
Returns true if the given gpu.func can be safely called using the bare pointer calling convention.
static constexpr StringLiteral amdgcnDataLayout
static void populateOpPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, StringRef f32Func, StringRef f64Func)
Value getLaneId(ConversionPatternRewriter &rewriter, Location loc, const unsigned indexBitwidth)
static llvm::ManagedStatic< PassManagerOptions > options
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
This class describes a specific conversion target.
void addLegalOp(OperationName op)
Register the given operations as legal.
void addLegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as legal.
void addIllegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as illegal, i.e.
void addIllegalOp(OperationName op)
Register the given operation as illegal, i.e.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
The main mechanism for performing data layout queries.
This class provides support for representing a failure result, or a valid value of type T.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
static bool canConvertToBarePtr(BaseMemRefType type)
Check if a memref type can be converted to a bare pointer.
MLIRContext & getContext() const
Returns the MLIR context.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Location getLoc()
The source location the operation was defined or derived from.
Attribute removeAttr(StringAttr name)
Remove the attribute with the specified name if it exists.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkResult interrupt()
void populateExpandBFloat16Patterns(RewritePatternSet &patterns)
Add patterns to expand Arith bf16 patterns to lower level bitcasts/shifts.
void populateArithToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
void populateControlFlowToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect the patterns to convert from the ControlFlow dialect to LLVM.
Runtime
Potential runtimes for AMD GPU kernels.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
void populateFuncToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, const SymbolTable *symbolTable=nullptr)
Collect the patterns to convert from the Func dialect to LLVM.
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth from the data layout.
void populateGpuRewritePatterns(RewritePatternSet &patterns)
Collect all patterns to rewrite ops within the GPU dialect.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
void populateFinalizeMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to convert memory-related operations from the MemRef dialect to the LLVM dialect.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
void configureGpuToROCDLConversionLegality(ConversionTarget &target)
Configure target to convert from the GPU dialect to ROCDL.
std::unique_ptr< OperationPass< gpu::GPUModuleOp > > createLowerGpuOpsToROCDLOpsPass(const std::string &chipset="gfx900", unsigned indexBitwidth=kDeriveIndexBitwidthFromDataLayout, bool useBarePtrCallConv=false, gpu::amd::Runtime runtime=gpu::amd::Runtime::Unknown)
Creates a pass that lowers GPU dialect operations to ROCDL counterparts.
void populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, bool reassociateFPReductions=false, bool force32BitVectorIndices=false)
Collect a set of patterns to convert from the Vector dialect to LLVM.
LogicalResult applyPatternsAndFoldGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest benefit patterns in a greedy worklist-driven manner until a fixpoint is reached.
void populateGpuMemorySpaceAttributeConversions(TypeConverter &typeConverter, const MemorySpaceMapping &mapping)
Populates memory space attribute conversion rules for lowering gpu.address_space to integer values.
void populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, amdgpu::Chipset chipset)
Note: The ROCDL target does not support the LLVM bfloat type at this time and so this function will a...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.
void populateGpuToROCDLConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, gpu::amd::Runtime runtime)
Collect a set of patterns to convert from the GPU dialect to ROCDL.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Lowering for gpu.dynamic.shared.memory to LLVM dialect.
The lowering of gpu.printf to a call to HIP hostcalls.
The lowering of gpu.printf to a call to an external printf() function.
This class represents an efficient way to signal success or failure.
Rewriting that replaces SourceOp with a CallOp to f32Func or f64Func depending on the element type that the op operates upon.
Rewriting that unrolls SourceOp to scalars if it's operating on vectors.
static FailureOr< Chipset > parse(StringRef name)