44 #include "llvm/Support/FormatVariadic.h"
46 #include "../GPUCommon/GPUOpsLowering.h"
47 #include "../GPUCommon/IndexIntrinsicsOpLowering.h"
48 #include "../GPUCommon/OpToFuncCallLowering.h"
51 #define GEN_PASS_DEF_CONVERTGPUOPSTOROCDLOPS
52 #include "mlir/Conversion/Passes.h.inc"
60 bool canBeBare =
true;
61 for (
Type type : func.getArgumentTypes())
62 if (
auto memrefTy = dyn_cast<BaseMemRefType>(type))
68 const unsigned indexBitwidth) {
72 Value mbcntLo = rewriter.
create<ROCDL::MbcntLoOp>(loc, int32Type,
74 Value laneId = rewriter.
create<ROCDL::MbcntHiOp>(loc, int32Type,
84 matchAndRewrite(gpu::LaneIdOp op, gpu::LaneIdOp::Adaptor adaptor,
100 const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
101 if (indexBitwidth > 32) {
102 laneId = rewriter.
create<LLVM::SExtOp>(
104 }
else if (indexBitwidth < 32) {
105 laneId = rewriter.
create<LLVM::TruncOp>(
133 matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
137 if (adaptor.getValue().getType().getIntOrFloatBitWidth() != 32)
139 const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
143 Value width = adaptor.getWidth();
144 Value zero = rewriter.
create<LLVM::ConstantOp>(loc, int32Type, 0);
145 Value negwidth = rewriter.
create<LLVM::SubOp>(loc, int32Type, zero, width);
146 Value add = rewriter.
create<LLVM::AddOp>(loc, int32Type, srcLaneId, width);
147 Value widthOrZeroIfOutside =
148 rewriter.
create<LLVM::AndOp>(loc, int32Type, add, negwidth);
153 switch (op.getMode()) {
154 case gpu::ShuffleMode::XOR:
155 dstLane = rewriter.
create<LLVM::XOrOp>(loc, int32Type, srcLaneId,
156 adaptor.getOffset());
158 case gpu::ShuffleMode::IDX:
159 dstLane = adaptor.getOffset();
164 Value isActiveSrcLane = rewriter.
create<LLVM::ICmpOp>(
165 loc, LLVM::ICmpPredicate::slt, dstLane, widthOrZeroIfOutside);
166 Value selectDstLane = rewriter.
create<LLVM::SelectOp>(loc, isActiveSrcLane,
168 Value two = rewriter.
create<LLVM::ConstantOp>(loc, int32Type, 2);
169 Value dwordAlignedDstLane =
170 rewriter.
create<LLVM::ShlOp>(loc, int32Type, selectDstLane, two);
171 Value initShflValue = adaptor.getValue();
172 if (adaptor.getValue().getType().isF32()) {
174 rewriter.
create<LLVM::BitcastOp>(loc, int32Type, initShflValue);
176 Value shflValue = rewriter.
create<ROCDL::DsBpermuteOp>(
177 loc, int32Type, dwordAlignedDstLane, initShflValue);
178 if (adaptor.getValue().getType().isF32()) {
179 shflValue = rewriter.
create<LLVM::BitcastOp>(
180 loc, adaptor.getValue().getType(), shflValue);
182 rewriter.
replaceOp(op, {shflValue, isActiveSrcLane});
188 #include "GPUToROCDL.cpp.inc"
195 struct LowerGpuOpsToROCDLOpsPass
196 :
public impl::ConvertGpuOpsToROCDLOpsBase<LowerGpuOpsToROCDLOpsPass> {
197 LowerGpuOpsToROCDLOpsPass() =
default;
198 LowerGpuOpsToROCDLOpsPass(
const std::string &chipset,
unsigned indexBitwidth,
199 bool useBarePtrCallConv,
201 if (this->chipset.getNumOccurrences() == 0)
202 this->chipset = chipset;
203 if (this->indexBitwidth.getNumOccurrences() == 0)
204 this->indexBitwidth = indexBitwidth;
205 if (this->useBarePtrCallConv.getNumOccurrences() == 0)
206 this->useBarePtrCallConv = useBarePtrCallConv;
207 if (this->runtime.getNumOccurrences() == 0)
208 this->runtime = runtime;
211 void runOnOperation()
override {
212 gpu::GPUModuleOp m = getOperation();
216 for (
auto func : m.getOps<func::FuncOp>()) {
217 func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
222 if (
failed(maybeChipset)) {
224 return signalPassFailure();
229 ctx,
DataLayout(cast<DataLayoutOpInterface>(m.getOperation())));
231 options.overrideIndexBitwidth(indexBitwidth);
233 if (useBarePtrCallConv) {
234 options.useBarePtrCallConv =
true;
236 m.walk([](gpu::GPUFuncOp func) ->
WalkResult {
243 "bare pointer calling convention requires all memrefs to "
244 "have static shape and use the identity map");
245 return signalPassFailure();
261 converter, [](gpu::AddressSpace space) {
263 case gpu::AddressSpace::Global:
265 case gpu::AddressSpace::Workgroup:
267 case gpu::AddressSpace::Private:
270 llvm_unreachable(
"unknown address space enum value");
291 m.walk([ctx](LLVM::LLVMFuncOp op) {
292 if (
auto blockSizes = dyn_cast_or_null<DenseI32ArrayAttr>(
293 op->
removeAttr(gpu::GPUFuncOp::getKnownBlockSizeAttrName()))) {
294 op->setAttr(ROCDL::ROCDLDialect::getReqdWorkGroupSizeAttrName(),
298 uint32_t flatSize = 1;
299 for (uint32_t size : blockSizes.asArrayRef()) {
302 StringAttr flatSizeAttr =
304 op->
setAttr(ROCDL::ROCDLDialect::getFlatWorkGroupSizeAttrName(),
318 target.
addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
319 LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
320 LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
324 target.
addLegalOp<gpu::YieldOp, gpu::GPUModuleOp, gpu::ModuleEndOp>();
327 template <
typename OpTy>
340 populateWithGenerated(patterns);
343 ROCDL::ThreadIdYOp, ROCDL::ThreadIdZOp>>(
344 converter, gpu::GPUFuncOp::getKnownBlockSizeAttrName());
346 gpu::BlockIdOp, ROCDL::BlockIdXOp, ROCDL::BlockIdYOp, ROCDL::BlockIdZOp>>(
347 converter, gpu::GPUFuncOp::getKnownGridSizeAttrName());
350 ROCDL::BlockDimYOp, ROCDL::BlockDimZOp>,
352 ROCDL::GridDimYOp, ROCDL::GridDimZOp>,
356 ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace,
357 ROCDL::ROCDLDialect::kSharedMemoryAddressSpace,
359 ROCDL::ROCDLDialect::getKernelFuncAttrName()));
367 patterns.
add<GPUShuffleOpLowering, GPULaneIdOpToROCDL>(converter);
369 populateOpPatterns<math::AbsFOp>(converter, patterns,
"__ocml_fabs_f32",
371 populateOpPatterns<math::AtanOp>(converter, patterns,
"__ocml_atan_f32",
373 populateOpPatterns<math::Atan2Op>(converter, patterns,
"__ocml_atan2_f32",
375 populateOpPatterns<math::CbrtOp>(converter, patterns,
"__ocml_cbrt_f32",
377 populateOpPatterns<math::CeilOp>(converter, patterns,
"__ocml_ceil_f32",
379 populateOpPatterns<math::CosOp>(converter, patterns,
"__ocml_cos_f32",
381 populateOpPatterns<math::ExpOp>(converter, patterns,
"__ocml_exp_f32",
383 populateOpPatterns<math::Exp2Op>(converter, patterns,
"__ocml_exp2_f32",
385 populateOpPatterns<math::ExpM1Op>(converter, patterns,
"__ocml_expm1_f32",
387 populateOpPatterns<math::FloorOp>(converter, patterns,
"__ocml_floor_f32",
389 populateOpPatterns<arith::RemFOp>(converter, patterns,
"__ocml_fmod_f32",
391 populateOpPatterns<math::LogOp>(converter, patterns,
"__ocml_log_f32",
393 populateOpPatterns<math::Log10Op>(converter, patterns,
"__ocml_log10_f32",
395 populateOpPatterns<math::Log1pOp>(converter, patterns,
"__ocml_log1p_f32",
397 populateOpPatterns<math::Log2Op>(converter, patterns,
"__ocml_log2_f32",
399 populateOpPatterns<math::PowFOp>(converter, patterns,
"__ocml_pow_f32",
401 populateOpPatterns<math::RsqrtOp>(converter, patterns,
"__ocml_rsqrt_f32",
403 populateOpPatterns<math::SinOp>(converter, patterns,
"__ocml_sin_f32",
405 populateOpPatterns<math::SqrtOp>(converter, patterns,
"__ocml_sqrt_f32",
407 populateOpPatterns<math::TanhOp>(converter, patterns,
"__ocml_tanh_f32",
409 populateOpPatterns<math::TanOp>(converter, patterns,
"__ocml_tan_f32",
411 populateOpPatterns<math::ErfOp>(converter, patterns,
"__ocml_erf_f32",
415 std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
417 unsigned indexBitwidth,
418 bool useBarePtrCallConv,
420 return std::make_unique<LowerGpuOpsToROCDLOpsPass>(
421 chipset, indexBitwidth, useBarePtrCallConv, runtime);
static MLIRContext * getContext(OpFoldResult val)
static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func)
Returns true if the given gpu.func can be safely called using the bare pointer calling convention.
static void populateOpPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, StringRef f32Func, StringRef f64Func)
Value getLaneId(ConversionPatternRewriter &rewriter, Location loc, const unsigned indexBitwidth)
static llvm::ManagedStatic< PassManagerOptions > options
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
This class describes a specific conversion target.
void addLegalOp(OperationName op)
Register the given operations as legal.
void addLegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as legal.
void addIllegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as illegal, i.e.
void addIllegalOp(OperationName op)
Register the given operation as illegal, i.e.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
The main mechanism for performing data layout queries.
This class provides support for representing a failure result, or a valid value of type T.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
static bool canConvertToBarePtr(BaseMemRefType type)
Check if a memref type can be converted to a bare pointer.
MLIRContext & getContext() const
Returns the MLIR context.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Location getLoc()
The source location the operation was defined or derived from.
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Attribute removeAttr(StringAttr name)
Remove the attribute with the specified name if it exists.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkResult interrupt()
void populateExpandBFloat16Patterns(RewritePatternSet &patterns)
Add patterns to expand Arith bf16 patterns to lower level bitcasts/shifts.
void populateArithToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
void populateControlFlowToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect the patterns to convert from the ControlFlow dialect to LLVM.
Runtime
Potential runtimes for AMD GPU kernels.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
void populateFuncToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, const SymbolTable *symbolTable=nullptr)
Collect the patterns to convert from the Func dialect to LLVM.
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth from the data layout.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation * > *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
void populateGpuRewritePatterns(RewritePatternSet &patterns)
Collect all patterns to rewrite ops within the GPU dialect.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
void populateFinalizeMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to convert memory-related operations from the MemRef dialect to the LLVM dialect.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
void configureGpuToROCDLConversionLegality(ConversionTarget &target)
Configure target to convert from the GPU dialect to ROCDL.
std::unique_ptr< OperationPass< gpu::GPUModuleOp > > createLowerGpuOpsToROCDLOpsPass(const std::string &chipset="gfx900", unsigned indexBitwidth=kDeriveIndexBitwidthFromDataLayout, bool useBarePtrCallConv=false, gpu::amd::Runtime runtime=gpu::amd::Runtime::Unknown)
Creates a pass that lowers GPU dialect operations to ROCDL counterparts.
void populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, bool reassociateFPReductions=false, bool force32BitVectorIndices=false)
Collect a set of patterns to convert from the Vector dialect to LLVM.
LogicalResult applyPatternsAndFoldGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest-benefit patterns in a greedy worklist-driven manner.
void populateGpuMemorySpaceAttributeConversions(TypeConverter &typeConverter, const MemorySpaceMapping &mapping)
Populates memory space attribute conversion rules for lowering gpu.address_space to integer values.
void populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, amdgpu::Chipset chipset)
Note: The ROCDL target does not support the LLVM bfloat type at this time and so this function will a...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateGpuToROCDLConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, gpu::amd::Runtime runtime)
Collect a set of patterns to convert from the GPU dialect to ROCDL.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
The lowering of gpu.printf to a call to HIP hostcalls.
The lowering of gpu.printf to a call to an external printf() function.
This class represents an efficient way to signal success or failure.
Rewriting that replaces SourceOp with a CallOp to f32Func or f64Func depending on the element type that the op operates upon.
Rewriting that unrolls SourceOp to scalars if it's operating on vectors.
static FailureOr< Chipset > parse(StringRef name)