#include "../GPUCommon/GPUOpsLowering.h"
#include "../GPUCommon/IndexIntrinsicsOpLowering.h"
#include "../GPUCommon/OpToFuncCallLowering.h"

#define GEN_PASS_DEF_CONVERTGPUOPSTONVVMOPS
#include "mlir/Conversion/Passes.h.inc"
/// Maps a gpu.shuffle mode to the equivalent NVVM shfl kind.
static NVVM::ShflKind convertShflKind(gpu::ShuffleMode mode) {
  switch (mode) {
  case gpu::ShuffleMode::XOR:
    return NVVM::ShflKind::bfly;
  case gpu::ShuffleMode::UP:
    return NVVM::ShflKind::up;
  case gpu::ShuffleMode::DOWN:
    return NVVM::ShflKind::down;
  case gpu::ShuffleMode::IDX:
    return NVVM::ShflKind::idx;
  }
  llvm_unreachable("unknown shuffle mode");
}
/// Maps a gpu reduction operation to the equivalent NVVM redux kind, if one
/// exists; reductions without a redux counterpart return std::nullopt.
static std::optional<NVVM::ReduxKind>
convertReduxKind(gpu::AllReduceOperation mode) {
  switch (mode) {
  case gpu::AllReduceOperation::ADD:
    return NVVM::ReduxKind::ADD;
  case gpu::AllReduceOperation::MUL:
    return std::nullopt;
  case gpu::AllReduceOperation::MINSI:
    return NVVM::ReduxKind::MIN;
  case gpu::AllReduceOperation::MINUI:
    return std::nullopt;
  case gpu::AllReduceOperation::MINNUMF:
    return NVVM::ReduxKind::MIN;
  case gpu::AllReduceOperation::MAXSI:
    return NVVM::ReduxKind::MAX;
  case gpu::AllReduceOperation::MAXUI:
    return std::nullopt;
  case gpu::AllReduceOperation::MAXNUMF:
    return NVVM::ReduxKind::MAX;
  case gpu::AllReduceOperation::AND:
    return NVVM::ReduxKind::AND;
  case gpu::AllReduceOperation::OR:
    return NVVM::ReduxKind::OR;
  case gpu::AllReduceOperation::XOR:
    return NVVM::ReduxKind::XOR;
  case gpu::AllReduceOperation::MINIMUMF:
  case gpu::AllReduceOperation::MAXIMUMF:
    return std::nullopt;
  }
  return std::nullopt;
}
struct GPUSubgroupReduceOpLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupReduceOp> {
  using ConvertOpToLLVMPattern<gpu::SubgroupReduceOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getClusterSize())
      return rewriter.notifyMatchFailure(
          op, "lowering for clustered reduce not implemented");

    if (!op.getUniform())
      return rewriter.notifyMatchFailure(
          op, "cannot be lowered to redux as the op must be run "
              "uniformly (entire subgroup).");
    if (!op.getValue().getType().isInteger(32))
      return rewriter.notifyMatchFailure(op, "unsupported data type");

    std::optional<NVVM::ReduxKind> mode = convertReduxKind(op.getOp());
    if (!mode.has_value())
      return rewriter.notifyMatchFailure(
          op, "unsupported reduction mode for redux");

    Location loc = op->getLoc();
    auto int32Type = IntegerType::get(rewriter.getContext(), 32);
    // All lanes in the warp participate (mask = 0xffffffff).
    Value offset = rewriter.create<LLVM::ConstantOp>(loc, int32Type, -1);

    auto reduxOp = rewriter.create<NVVM::ReduxOp>(
        loc, int32Type, op.getValue(), mode.value(), offset);

    rewriter.replaceOp(op, reduxOp->getResult(0));
    return success();
  }
};
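
// Illustrative example (not from the upstream source): with this pattern, a
// uniform 32-bit integer subgroup reduction such as
//
//   %sum = gpu.subgroup_reduce add %val uniform : (i32) -> i32
//
// is expected to become roughly
//
//   %mask = llvm.mlir.constant(-1 : i32) : i32
//   %sum  = nvvm.redux.sync add %val, %mask : i32 -> i32
//
// Reductions that convertReduxKind cannot map (e.g. mul, minimumf/maximumf)
// simply fail to match here and are left for other lowering patterns.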
struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
  using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern;

  /// Lowers a shuffle to the corresponding NVVM op.
  LogicalResult
  matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();

    auto valueTy = adaptor.getValue().getType();
    auto int32Type = IntegerType::get(rewriter.getContext(), 32);
    auto predTy = IntegerType::get(rewriter.getContext(), 1);

    Value one = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 1);
    Value minusOne = rewriter.create<LLVM::ConstantOp>(loc, int32Type, -1);
    Value thirtyTwo = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 32);
    Value numLeadInactiveLane = rewriter.create<LLVM::SubOp>(
        loc, int32Type, thirtyTwo, adaptor.getWidth());
    // Bit mask of active lanes: `(-1) >> (32 - activeWidth)`.
    Value activeMask = rewriter.create<LLVM::LShrOp>(loc, int32Type, minusOne,
                                                     numLeadInactiveLane);
    Value maskAndClamp;
    if (op.getMode() == gpu::ShuffleMode::UP) {
      // Clamp lane: `32 - activeWidth`.
      maskAndClamp = numLeadInactiveLane;
    } else {
      // Clamp lane: `activeWidth - 1`.
      maskAndClamp =
          rewriter.create<LLVM::SubOp>(loc, int32Type, adaptor.getWidth(), one);
    }

    bool predIsUsed = !op->getResult(1).use_empty();
    UnitAttr returnValueAndIsValidAttr = nullptr;
    Type resultTy = valueTy;
    if (predIsUsed) {
      returnValueAndIsValidAttr = rewriter.getUnitAttr();
      resultTy = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
                                                  {valueTy, predTy});
    }
    Value shfl = rewriter.create<NVVM::ShflOp>(
        loc, resultTy, activeMask, adaptor.getValue(), adaptor.getOffset(),
        maskAndClamp, convertShflKind(op.getMode()), returnValueAndIsValidAttr);
    if (predIsUsed) {
      Value shflValue = rewriter.create<LLVM::ExtractValueOp>(loc, shfl, 0);
      Value isActiveSrcLane =
          rewriter.create<LLVM::ExtractValueOp>(loc, shfl, 1);
      rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
    } else {
      rewriter.replaceOp(op, {shfl, nullptr});
    }
    return success();
  }
};
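
// Illustrative expansion (not from the upstream source): for a shuffle whose
// validity bit is used, e.g.
//
//   %shfl, %pred = gpu.shuffle xor %value, %offset, %width : f32
//
// this pattern is expected to emit roughly
//
//   %one            = llvm.mlir.constant(1 : i32) : i32
//   %minus_one      = llvm.mlir.constant(-1 : i32) : i32
//   %thirty_two     = llvm.mlir.constant(32 : i32) : i32
//   %num_lanes      = llvm.sub %thirty_two, %width : i32
//   %active_mask    = llvm.lshr %minus_one, %num_lanes : i32
//   %mask_and_clamp = llvm.sub %width, %one : i32
//   %shfl = nvvm.shfl.sync bfly %active_mask, %value, %offset, %mask_and_clamp
//             {return_value_and_is_valid} : f32 -> !llvm.struct<(f32, i1)>
//   %shfl_value = llvm.extractvalue %shfl[0] : !llvm.struct<(f32, i1)>
//   %shfl_pred  = llvm.extractvalue %shfl[1] : !llvm.struct<(f32, i1)>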
struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
  using ConvertOpToLLVMPattern<gpu::LaneIdOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::LaneIdOp op, gpu::LaneIdOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    MLIRContext *context = rewriter.getContext();
    Value newOp = rewriter.create<NVVM::LaneIdOp>(loc, rewriter.getI32Type());
    // Truncate or extend the i32 laneid to the configured index bitwidth.
    const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
    if (indexBitwidth > 32) {
      newOp = rewriter.create<LLVM::SExtOp>(
          loc, IntegerType::get(context, indexBitwidth), newOp);
    } else if (indexBitwidth < 32) {
      newOp = rewriter.create<LLVM::TruncOp>(
          loc, IntegerType::get(context, indexBitwidth), newOp);
    }
    rewriter.replaceOp(op, {newOp});
    return success();
  }
};
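
// Illustrative example (not from the upstream source): with a 64-bit index
// bitwidth, gpu.lane_id lowers roughly to
//
//   %0 = nvvm.read.ptx.sreg.laneid : i32
//   %1 = llvm.sext %0 : i32 to i64
//
// With a 32-bit index the value is used as-is, and for narrower indexes an
// llvm.trunc is emitted instead of the llvm.sext.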
#include "GPUToNVVM.cpp.inc"
struct LowerGpuOpsToNVVMOpsPass
    : public impl::ConvertGpuOpsToNVVMOpsBase<LowerGpuOpsToNVVMOpsPass> {
  void runOnOperation() override {
    gpu::GPUModuleOp m = getOperation();

    // Request C wrapper emission for every function in the GPU module.
    for (auto func : m.getOps<func::FuncOp>()) {
      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                    UnitAttr::get(&getContext()));
    }

    // Customize the bitwidth used for the device-side index computations.
    LowerToLLVMOptions options(
        m.getContext(),
        DataLayout(cast<DataLayoutOpInterface>(m.getOperation())));
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);
    options.useBarePtrCallConv = useBarePtrCallConv;

    // Apply in-dialect GPU rewrites first; they replace ops that would
    // otherwise need further lowering within a single conversion pass.
    {
      RewritePatternSet patterns(m.getContext());
      populateGpuRewritePatterns(patterns);
      if (failed(applyPatternsAndFoldGreedily(m, std::move(patterns))))
        return signalPassFailure();
    }
    LLVMTypeConverter converter(m.getContext(), options);
    // Map gpu.address_space attributes to NVVM address-space numbers.
    populateGpuMemorySpaceAttributeConversions(
        converter, [](gpu::AddressSpace space) -> unsigned {
          switch (space) {
          case gpu::AddressSpace::Global:
            return static_cast<unsigned>(
                NVVM::NVVMMemorySpace::kGlobalMemorySpace);
          case gpu::AddressSpace::Workgroup:
            return static_cast<unsigned>(
                NVVM::NVVMMemorySpace::kSharedMemorySpace);
          case gpu::AddressSpace::Private:
            return 0;
          }
          llvm_unreachable("unknown address space enum value");
        });

    // The rest of the body collects the dialect-to-LLVM patterns together
    // with populateGpuToNVVMConversionPatterns, configures legality via
    // configureGpuToNVVMConversionLegality, and runs applyPartialConversion,
    // calling signalPassFailure() on error.
  }
};
void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
  target.addIllegalOp<func::FuncOp>();
  target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
  target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
  target.addIllegalDialect<gpu::GPUDialect>();
  target.addIllegalOp<LLVM::CopySignOp, LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op,
                      LLVM::FAbsOp, LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FMAOp,
                      LLVM::FRemOp, LLVM::LogOp, LLVM::Log10Op, LLVM::Log2Op,
                      LLVM::PowOp, LLVM::RoundEvenOp, LLVM::RoundOp,
                      LLVM::SinOp, LLVM::SqrtOp>();

  target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp>();
}
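
// Usage sketch (illustrative, not part of the upstream source): a client that
// wants this lowering without running LowerGpuOpsToNVVMOpsPass can combine the
// public entry points with a partial conversion, along these lines:
//
//   LLVMTypeConverter converter(ctx, options);
//   RewritePatternSet patterns(ctx);
//   populateGpuToNVVMConversionPatterns(converter, patterns);
//   LLVMConversionTarget target(ctx);
//   configureGpuToNVVMConversionLegality(target);
//   if (failed(applyPartialConversion(gpuModule, target, std::move(patterns))))
//     signalPassFailure();
//
// Here `ctx` and `gpuModule` stand for the MLIRContext and the gpu.module
// being converted; the sequence mirrors what runOnOperation() does above.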
template <typename OpTy>
static void populateOpPatterns(LLVMTypeConverter &converter,
                               RewritePatternSet &patterns, StringRef f32Func,
                               StringRef f64Func,
                               StringRef f32ApproxFunc = "") {
  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
                                           f32ApproxFunc);
}

void mlir::populateGpuSubgroupReduceOpLoweringPattern(
    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  patterns.add<GPUSubgroupReduceOpLowering>(converter);
}
void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
                                               RewritePatternSet &patterns) {
  populateWithGenerated(patterns);
  patterns.add<
      GPUIndexIntrinsicOpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp,
                                  NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>,
      GPUIndexIntrinsicOpLowering<gpu::BlockDimOp, NVVM::BlockDimXOp,
                                  NVVM::BlockDimYOp, NVVM::BlockDimZOp>,
      GPUIndexIntrinsicOpLowering<gpu::ClusterIdOp, NVVM::ClusterIdXOp,
                                  NVVM::ClusterIdYOp, NVVM::ClusterIdZOp>,
      GPUIndexIntrinsicOpLowering<gpu::ClusterDimOp, NVVM::ClusterDimXOp,
                                  NVVM::ClusterDimYOp, NVVM::ClusterDimZOp>,
      GPUIndexIntrinsicOpLowering<
          gpu::ClusterBlockIdOp, NVVM::BlockInClusterIdXOp,
          NVVM::BlockInClusterIdYOp, NVVM::BlockInClusterIdZOp>,
      GPUIndexIntrinsicOpLowering<gpu::ClusterDimOp, NVVM::ClusterDimXOp,
                                  NVVM::ClusterDimYOp, NVVM::ClusterDimZOp>,
      GPUIndexIntrinsicOpLowering<gpu::BlockIdOp, NVVM::BlockIdXOp,
                                  NVVM::BlockIdYOp, NVVM::BlockIdZOp>,
      GPUIndexIntrinsicOpLowering<gpu::GridDimOp, NVVM::GridDimXOp,
                                  NVVM::GridDimYOp, NVVM::GridDimZOp>,
      GPULaneIdOpToNVVM, GPUShuffleOpLowering, GPUReturnOpLowering>(converter);
  patterns.add<GPUFuncOpLowering>(
      converter,
      GPUFuncOpLoweringOptions{
          /*allocaAddrSpace=*/0,
          /*workgroupAddrSpace=*/
          static_cast<unsigned>(NVVM::NVVMMemorySpace::kSharedMemorySpace),
          StringAttr::get(&converter.getContext(),
                          NVVM::NVVMDialect::getKernelFuncAttrName()),
          StringAttr::get(&converter.getContext(),
                          NVVM::NVVMDialect::getMaxntidAttrName())});
  populateOpPatterns<arith::RemFOp>(converter, patterns, "__nv_fmodf",
                                    "__nv_fmod");
  populateOpPatterns<math::AbsFOp>(converter, patterns, "__nv_fabsf",
                                   "__nv_fabs");
  populateOpPatterns<math::AcosOp>(converter, patterns, "__nv_acosf",
                                   "__nv_acos");
  populateOpPatterns<math::AcoshOp>(converter, patterns, "__nv_acoshf",
                                    "__nv_acosh");
  populateOpPatterns<math::AsinOp>(converter, patterns, "__nv_asinf",
                                   "__nv_asin");
  populateOpPatterns<math::AsinhOp>(converter, patterns, "__nv_asinhf",
                                    "__nv_asinh");
  populateOpPatterns<math::AtanOp>(converter, patterns, "__nv_atanf",
                                   "__nv_atan");
  populateOpPatterns<math::Atan2Op>(converter, patterns, "__nv_atan2f",
                                    "__nv_atan2");
  populateOpPatterns<math::AtanhOp>(converter, patterns, "__nv_atanhf",
                                    "__nv_atanh");
  populateOpPatterns<math::CbrtOp>(converter, patterns, "__nv_cbrtf",
                                   "__nv_cbrt");
  populateOpPatterns<math::CeilOp>(converter, patterns, "__nv_ceilf",
                                   "__nv_ceil");
  populateOpPatterns<math::CopySignOp>(converter, patterns, "__nv_copysignf",
                                       "__nv_copysign");
  populateOpPatterns<math::CosOp>(converter, patterns, "__nv_cosf", "__nv_cos",
                                  "__nv_fast_cosf");
  populateOpPatterns<math::CoshOp>(converter, patterns, "__nv_coshf",
                                   "__nv_cosh");
  populateOpPatterns<math::ErfOp>(converter, patterns, "__nv_erff", "__nv_erf");
  populateOpPatterns<math::ExpOp>(converter, patterns, "__nv_expf", "__nv_exp",
                                  "__nv_fast_expf");
  populateOpPatterns<math::Exp2Op>(converter, patterns, "__nv_exp2f",
                                   "__nv_exp2");
  populateOpPatterns<math::ExpM1Op>(converter, patterns, "__nv_expm1f",
                                    "__nv_expm1");
  populateOpPatterns<math::FloorOp>(converter, patterns, "__nv_floorf",
                                    "__nv_floor");
  populateOpPatterns<math::FmaOp>(converter, patterns, "__nv_fmaf", "__nv_fma");
  populateOpPatterns<math::LogOp>(converter, patterns, "__nv_logf", "__nv_log",
                                  "__nv_fast_logf");
  populateOpPatterns<math::Log10Op>(converter, patterns, "__nv_log10f",
                                    "__nv_log10", "__nv_fast_log10f");
  populateOpPatterns<math::Log1pOp>(converter, patterns, "__nv_log1pf",
                                    "__nv_log1p");
  populateOpPatterns<math::Log2Op>(converter, patterns, "__nv_log2f",
                                   "__nv_log2", "__nv_fast_log2f");
  populateOpPatterns<math::PowFOp>(converter, patterns, "__nv_powf", "__nv_pow",
                                   "__nv_fast_powf");
  populateOpPatterns<math::RoundOp>(converter, patterns, "__nv_roundf",
                                    "__nv_round");
  populateOpPatterns<math::RoundEvenOp>(converter, patterns, "__nv_rintf",
                                        "__nv_rint");
  populateOpPatterns<math::RsqrtOp>(converter, patterns, "__nv_rsqrtf",
                                    "__nv_rsqrt");
  populateOpPatterns<math::SinOp>(converter, patterns, "__nv_sinf", "__nv_sin",
                                  "__nv_fast_sinf");
  populateOpPatterns<math::SinhOp>(converter, patterns, "__nv_sinhf",
                                   "__nv_sinh");
  populateOpPatterns<math::SqrtOp>(converter, patterns, "__nv_sqrtf",
                                   "__nv_sqrt");
  populateOpPatterns<math::TanOp>(converter, patterns, "__nv_tanf", "__nv_tan",
                                  "__nv_fast_tanf");
  populateOpPatterns<math::TanhOp>(converter, patterns, "__nv_tanhf",
                                   "__nv_tanh");
}