#include "../GPUCommon/GPUOpsLowering.h"
#include "../GPUCommon/IndexIntrinsicsOpLowering.h"
#include "../GPUCommon/OpToFuncCallLowering.h"

namespace mlir {
#define GEN_PASS_DEF_CONVERTGPUOPSTONVVMOPS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {
/// Convert a gpu.shuffle mode to the matching NVVM shuffle kind.
static NVVM::ShflKind convertShflKind(gpu::ShuffleMode mode) {
  switch (mode) {
  case gpu::ShuffleMode::XOR:
    return NVVM::ShflKind::bfly;
  case gpu::ShuffleMode::UP:
    return NVVM::ShflKind::up;
  case gpu::ShuffleMode::DOWN:
    return NVVM::ShflKind::down;
  case gpu::ShuffleMode::IDX:
    return NVVM::ShflKind::idx;
  }
  llvm_unreachable("unknown shuffle mode");
}
/// Map a gpu.all_reduce operation onto an NVVM redux kind, where one exists.
/// Returns std::nullopt for operations that redux cannot implement.
static std::optional<NVVM::ReduxKind>
convertReduxKind(gpu::AllReduceOperation mode) {
  switch (mode) {
  case gpu::AllReduceOperation::ADD:
    return NVVM::ReduxKind::ADD;
  case gpu::AllReduceOperation::MUL:
    return std::nullopt;
  case gpu::AllReduceOperation::MINSI:
    return NVVM::ReduxKind::MIN;
  case gpu::AllReduceOperation::MINUI:
    return std::nullopt;
  case gpu::AllReduceOperation::MINF:
    return NVVM::ReduxKind::MIN;
  case gpu::AllReduceOperation::MAXSI:
    return NVVM::ReduxKind::MAX;
  case gpu::AllReduceOperation::MAXUI:
    return std::nullopt;
  case gpu::AllReduceOperation::MAXF:
    return NVVM::ReduxKind::MAX;
  case gpu::AllReduceOperation::AND:
    return NVVM::ReduxKind::AND;
  case gpu::AllReduceOperation::OR:
    return NVVM::ReduxKind::OR;
  case gpu::AllReduceOperation::XOR:
    return NVVM::ReduxKind::XOR;
  case gpu::AllReduceOperation::MINIMUMF:
  case gpu::AllReduceOperation::MAXIMUMF:
    return std::nullopt;
  }
  return std::nullopt;
}
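
// For example, ADD maps to redux ADD, while MUL and the unsigned min/max
// variants yield std::nullopt: PTX redux.sync has no such variants, so the
// pattern below bails out via notifyMatchFailure and leaves those ops to
// other lowerings.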
/// Lower gpu.subgroup_reduce into nvvm.redux. The op must be run uniformly
/// by the entire subgroup, otherwise the behavior is undefined.
struct GPUSubgroupReduceOpLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupReduceOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!op.getUniform())
      return rewriter.notifyMatchFailure(
          op, "cannot be lowered to redux as the op must be run "
              "uniformly (entire subgroup).");
    if (!op.getValue().getType().isInteger(32))
      return rewriter.notifyMatchFailure(op, "unsupported data type");

    std::optional<NVVM::ReduxKind> mode = convertReduxKind(op.getOp());
    if (!mode.has_value())
      return rewriter.notifyMatchFailure(
          op, "unsupported reduction mode for redux");

    Location loc = op->getLoc();
    auto int32Type = IntegerType::get(rewriter.getContext(), 32);
    // The membermask: -1 means every lane of the subgroup participates.
    Value offset = rewriter.create<LLVM::ConstantOp>(loc, int32Type, -1);

    auto reduxOp = rewriter.create<NVVM::ReduxOp>(
        loc, int32Type, op.getValue(), mode.value(), offset);

    rewriter.replaceOp(op, reduxOp->getResult(0));
    return success();
  }
};
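
// Illustrative lowering (IR shown schematically, not verbatim): a uniform
// 32-bit integer reduction such as
//   %sum = gpu.subgroup_reduce add %val uniform : (i32) -> i32
// becomes roughly
//   %mask = llvm.mlir.constant(-1 : i32) : i32
//   %sum  = nvvm.redux.sync add %val, %mask : i32 -> i32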
struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  /// Lowers a shuffle to the corresponding NVVM op. The `width` argument is
  /// turned into an activeMask (a bitmask of the participating lanes) and a
  /// maskAndClamp (the highest lane that participates in the shuffle).
  LogicalResult
  matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();

    auto valueTy = adaptor.getValue().getType();
    auto int32Type = IntegerType::get(rewriter.getContext(), 32);
    auto predTy = IntegerType::get(rewriter.getContext(), 1);
    Type resultTy = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
                                                     {valueTy, predTy});

    Value one = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 1);
    Value minusOne = rewriter.create<LLVM::ConstantOp>(loc, int32Type, -1);
    Value thirtyTwo = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 32);
    Value numLeadInactiveLane = rewriter.create<LLVM::SubOp>(
        loc, int32Type, thirtyTwo, adaptor.getWidth());
    // Bit mask of active lanes: `(-1) >> (32 - activeWidth)`.
    Value activeMask = rewriter.create<LLVM::LShrOp>(loc, int32Type, minusOne,
                                                     numLeadInactiveLane);
    Value maskAndClamp;
    if (op.getMode() == gpu::ShuffleMode::UP) {
      // Clamp lane: `32 - activeWidth`.
      maskAndClamp = numLeadInactiveLane;
    } else {
      // Clamp lane: `activeWidth - 1`.
      maskAndClamp =
          rewriter.create<LLVM::SubOp>(loc, int32Type, adaptor.getWidth(), one);
    }

    auto returnValueAndIsValidAttr = rewriter.getUnitAttr();
    Value shfl = rewriter.create<NVVM::ShflOp>(
        loc, resultTy, activeMask, adaptor.getValue(), adaptor.getOffset(),
        maskAndClamp, convertShflKind(op.getMode()), returnValueAndIsValidAttr);
    Value shflValue = rewriter.create<LLVM::ExtractValueOp>(loc, shfl, 0);
    Value isActiveSrcLane =
        rewriter.create<LLVM::ExtractValueOp>(loc, shfl, 1);
    rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
    return success();
  }
};
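
// Illustrative lowering (schematic IR): for
//   %res, %valid = gpu.shuffle xor %val, %offset, %width : f32
// the pattern computes
//   activeMask   = (-1) >> (32 - width)   // bitmask of participating lanes
//   maskAndClamp = width - 1              // highest participating lane
// and emits an nvvm.shfl.sync with kind `bfly` returning a {f32, i1} struct,
// which llvm.extractvalue splits into the two results.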
struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
  LogicalResult
  matchAndRewrite(gpu::LaneIdOp op, gpu::LaneIdOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    MLIRContext *context = rewriter.getContext();
    Value newOp = rewriter.create<NVVM::LaneIdOp>(loc, rewriter.getI32Type());
    // Truncate or extend the result to the index bitwidth chosen by the
    // type converter.
    const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
    if (indexBitwidth > 32) {
      newOp = rewriter.create<LLVM::SExtOp>(
          loc, IntegerType::get(context, indexBitwidth), newOp);
    } else if (indexBitwidth < 32) {
      newOp = rewriter.create<LLVM::TruncOp>(
          loc, IntegerType::get(context, indexBitwidth), newOp);
    }
    rewriter.replaceOp(op, {newOp});
    return success();
  }
};
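
// Illustrative: with a 64-bit index type, `%id = gpu.lane_id` lowers to an
// i32 read of the laneid sreg (nvvm.read.ptx.sreg.laneid) plus llvm.sext to
// i64; with a 32-bit index type the intrinsic result is used directly.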
/// Import the GPU Ops to NVVM Patterns.
#include "GPUToNVVM.cpp.inc"
/// A pass that replaces all occurrences of GPU device operations with their
/// corresponding NVVM equivalent.
///
/// This pass only handles device code and is not meant to be run on GPU host
/// code.
struct LowerGpuOpsToNVVMOpsPass
    : public impl::ConvertGpuOpsToNVVMOpsBase<LowerGpuOpsToNVVMOpsPass> {
  using Base::Base;

  void runOnOperation() override {
    gpu::GPUModuleOp m = getOperation();

    // Request C wrapper emission for all functions in the module.
    for (auto func : m.getOps<func::FuncOp>()) {
      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                    UnitAttr::get(&getContext()));
    }

    // Customize the bitwidth used for the device-side index computations.
    LowerToLLVMOptions options(
        m.getContext(),
        DataLayout(cast<DataLayoutOpInterface>(m.getOperation())));
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);
    options.useBarePtrCallConv = useBarePtrCallConv;

    // Apply in-dialect lowering first; these rewrites may create ops that
    // themselves still need lowering, which a single conversion pass cannot
    // handle.
    RewritePatternSet patterns(m.getContext());
    populateGpuRewritePatterns(patterns);
    if (failed(applyPatternsAndFoldGreedily(m, std::move(patterns))))
      return signalPassFailure();
    LLVMTypeConverter converter(m.getContext(), options);
    // Map GPU address spaces onto NVVM's: global memory is address space 1,
    // workgroup (shared) memory is 3, and private memory uses the default
    // alloca address space 0.
    populateGpuMemorySpaceAttributeConversions(
        converter, [](gpu::AddressSpace space) -> unsigned {
          switch (space) {
          case gpu::AddressSpace::Global:
            return static_cast<unsigned>(
                NVVM::NVVMMemorySpace::kGlobalMemorySpace);
          case gpu::AddressSpace::Workgroup:
            return static_cast<unsigned>(
                NVVM::NVVMMemorySpace::kSharedMemorySpace);
          case gpu::AddressSpace::Private:
            return 0;
          }
          llvm_unreachable("unknown address space enum value");
        });

    // Populate the lowering patterns and run the dialect conversion.
    RewritePatternSet llvmPatterns(m.getContext());
    arith::populateArithToLLVMConversionPatterns(converter, llvmPatterns);
    cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns);
    populateFuncToLLVMConversionPatterns(converter, llvmPatterns);
    populateFinalizeMemRefToLLVMConversionPatterns(converter, llvmPatterns);
    populateGpuToNVVMConversionPatterns(converter, llvmPatterns);
    populateGpuWMMAToNVVMConversionPatterns(converter, llvmPatterns);
    if (this->hasRedux)
      populateGpuSubgroupReduceOpLoweringPattern(converter, llvmPatterns);
    LLVMConversionTarget target(getContext());
    configureGpuToNVVMConversionLegality(target);
    if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
      signalPassFailure();
  }
};

} // namespace
/// Configure `target` so that anything the NVVM lowering must rewrite is
/// illegal: the listed LLVM math ops have no NVPTX instruction selection and
/// must instead become libdevice calls.
void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
  target.addIllegalOp<func::FuncOp>();
  target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
  target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
  target.addIllegalDialect<gpu::GPUDialect>();
  target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
                      LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
                      LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
                      LLVM::SqrtOp>();

  // TODO: Remove once we support replacing non-root ops.
  target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp, gpu::ModuleEndOp>();
}
template <typename OpTy>
static void populateOpPatterns(LLVMTypeConverter &converter,
                               RewritePatternSet &patterns, StringRef f32Func,
                               StringRef f64Func) {
  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func);
}

void mlir::populateGpuSubgroupReduceOpLoweringPattern(
    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  patterns.add<GPUSubgroupReduceOpLowering>(converter);
}
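
// Usage sketch (mirrors the calls later in this file): instantiating the
// helper for one math op registers both the vector-scalarizing pattern and
// the libdevice call rewrite, e.g.
//   populateOpPatterns<math::ExpOp>(converter, patterns, "__nv_expf",
//                                   "__nv_exp");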
void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
                                               RewritePatternSet &patterns) {
  populateWithGenerated(patterns);
  patterns.add<GPUPrintfOpToVPrintfLowering>(converter);
  patterns.add<GPUDynamicSharedMemoryOpLowering>(converter);
  patterns.add<
      GPUIndexIntrinsicOpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp,
                                  NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>,
      GPUIndexIntrinsicOpLowering<gpu::BlockDimOp, NVVM::BlockDimXOp,
                                  NVVM::BlockDimYOp, NVVM::BlockDimZOp>,
      GPUIndexIntrinsicOpLowering<gpu::ClusterIdOp, NVVM::ClusterIdXOp,
                                  NVVM::ClusterIdYOp, NVVM::ClusterIdZOp>,
      GPUIndexIntrinsicOpLowering<gpu::ClusterDimOp, NVVM::ClusterDimXOp,
                                  NVVM::ClusterDimYOp, NVVM::ClusterDimZOp>,
      GPUIndexIntrinsicOpLowering<gpu::BlockIdOp, NVVM::BlockIdXOp,
                                  NVVM::BlockIdYOp, NVVM::BlockIdZOp>,
      GPUIndexIntrinsicOpLowering<gpu::GridDimOp, NVVM::GridDimXOp,
                                  NVVM::GridDimYOp, NVVM::GridDimZOp>,
      GPULaneIdOpToNVVM, GPUShuffleOpLowering, GPUReturnOpLowering>(converter);

  // Attach the NVVM kernel attribute when lowering gpu.func kernels.
  patterns.add<GPUFuncOpLowering>(
      converter, /*allocaAddrSpace=*/0,
      /*workgroupAddrSpace=*/
      static_cast<unsigned>(NVVM::NVVMMemorySpace::kSharedMemorySpace),
      StringAttr::get(&converter.getContext(),
                      NVVM::NVVMDialect::getKernelFuncAttrName()));
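
  // Illustrative effect of the index-intrinsic patterns registered above:
  // `%tid = gpu.thread_id x` becomes roughly `nvvm.read.ptx.sreg.tid.x : i32`,
  // followed by a sext/trunc to the converter's index bitwidth.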
  populateOpPatterns<math::AbsFOp>(converter, patterns, "__nv_fabsf",
                                   "__nv_fabs");
  populateOpPatterns<math::AtanOp>(converter, patterns, "__nv_atanf",
                                   "__nv_atan");
  populateOpPatterns<math::Atan2Op>(converter, patterns, "__nv_atan2f",
                                    "__nv_atan2");
  populateOpPatterns<math::CbrtOp>(converter, patterns, "__nv_cbrtf",
                                   "__nv_cbrt");
  populateOpPatterns<math::CeilOp>(converter, patterns, "__nv_ceilf",
                                   "__nv_ceil");
  populateOpPatterns<math::CosOp>(converter, patterns, "__nv_cosf", "__nv_cos");
  populateOpPatterns<math::ExpOp>(converter, patterns, "__nv_expf", "__nv_exp");
  populateOpPatterns<math::Exp2Op>(converter, patterns, "__nv_exp2f",
                                   "__nv_exp2");
  populateOpPatterns<math::ExpM1Op>(converter, patterns, "__nv_expm1f",
                                    "__nv_expm1");
  populateOpPatterns<math::FloorOp>(converter, patterns, "__nv_floorf",
                                    "__nv_floor");
  populateOpPatterns<arith::RemFOp>(converter, patterns, "__nv_fmodf",
                                    "__nv_fmod");
  populateOpPatterns<math::LogOp>(converter, patterns, "__nv_logf", "__nv_log");
  populateOpPatterns<math::Log1pOp>(converter, patterns, "__nv_log1pf",
                                    "__nv_log1p");
  populateOpPatterns<math::Log10Op>(converter, patterns, "__nv_log10f",
                                    "__nv_log10");
  populateOpPatterns<math::Log2Op>(converter, patterns, "__nv_log2f",
                                   "__nv_log2");
  populateOpPatterns<math::PowFOp>(converter, patterns, "__nv_powf",
                                   "__nv_pow");
  populateOpPatterns<math::RsqrtOp>(converter, patterns, "__nv_rsqrtf",
                                    "__nv_rsqrt");
  populateOpPatterns<math::SinOp>(converter, patterns, "__nv_sinf", "__nv_sin");
  populateOpPatterns<math::SqrtOp>(converter, patterns, "__nv_sqrtf",
                                   "__nv_sqrt");
  populateOpPatterns<math::TanhOp>(converter, patterns, "__nv_tanhf",
                                   "__nv_tanh");
  populateOpPatterns<math::TanOp>(converter, patterns, "__nv_tanf", "__nv_tan");
}