35 #include "../GPUCommon/GPUOpsLowering.h"
36 #include "../GPUCommon/IndexIntrinsicsOpLowering.h"
37 #include "../GPUCommon/OpToFuncCallLowering.h"
41 #define GEN_PASS_DEF_CONVERTGPUOPSTONVVMOPS
42 #include "mlir/Conversion/Passes.h.inc"
50 static NVVM::ShflKind convertShflKind(gpu::ShuffleMode mode) {
52 case gpu::ShuffleMode::XOR:
53 return NVVM::ShflKind::bfly;
54 case gpu::ShuffleMode::UP:
55 return NVVM::ShflKind::up;
56 case gpu::ShuffleMode::DOWN:
57 return NVVM::ShflKind::down;
58 case gpu::ShuffleMode::IDX:
59 return NVVM::ShflKind::idx;
61 llvm_unreachable(
"unknown shuffle mode");
64 static std::optional<NVVM::ReduxKind>
65 convertReduxKind(gpu::AllReduceOperation mode) {
67 case gpu::AllReduceOperation::ADD:
68 return NVVM::ReduxKind::ADD;
69 case gpu::AllReduceOperation::MUL:
71 case gpu::AllReduceOperation::MINSI:
72 return NVVM::ReduxKind::MIN;
75 case gpu::AllReduceOperation::MINNUMF:
76 return NVVM::ReduxKind::MIN;
77 case gpu::AllReduceOperation::MAXSI:
78 return NVVM::ReduxKind::MAX;
79 case gpu::AllReduceOperation::MAXUI:
81 case gpu::AllReduceOperation::MAXNUMF:
82 return NVVM::ReduxKind::MAX;
83 case gpu::AllReduceOperation::AND:
84 return NVVM::ReduxKind::AND;
85 case gpu::AllReduceOperation::OR:
86 return NVVM::ReduxKind::OR;
87 case gpu::AllReduceOperation::XOR:
88 return NVVM::ReduxKind::XOR;
89 case gpu::AllReduceOperation::MINIMUMF:
90 case gpu::AllReduceOperation::MAXIMUMF:
98 struct GPUSubgroupReduceOpLowering
103 matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
105 if (!op.getUniform())
107 op,
"cannot be lowered to redux as the op must be run "
108 "uniformly (entire subgroup).");
109 if (!op.getValue().getType().isInteger(32))
112 std::optional<NVVM::ReduxKind> mode = convertReduxKind(op.getOp());
113 if (!mode.has_value())
115 op,
"unsupported reduction mode for redux");
119 Value offset = rewriter.
create<LLVM::ConstantOp>(loc, int32Type, -1);
121 auto reduxOp = rewriter.
create<NVVM::ReduxOp>(loc, int32Type, op.getValue(),
122 mode.value(), offset);
124 rewriter.
replaceOp(op, reduxOp->getResult(0));
151 matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
155 auto valueTy = adaptor.getValue().getType();
159 Value one = rewriter.
create<LLVM::ConstantOp>(loc, int32Type, 1);
160 Value minusOne = rewriter.
create<LLVM::ConstantOp>(loc, int32Type, -1);
161 Value thirtyTwo = rewriter.
create<LLVM::ConstantOp>(loc, int32Type, 32);
162 Value numLeadInactiveLane = rewriter.
create<LLVM::SubOp>(
163 loc, int32Type, thirtyTwo, adaptor.getWidth());
165 Value activeMask = rewriter.
create<LLVM::LShrOp>(loc, int32Type, minusOne,
166 numLeadInactiveLane);
168 if (op.getMode() == gpu::ShuffleMode::UP) {
170 maskAndClamp = numLeadInactiveLane;
174 rewriter.
create<LLVM::SubOp>(loc, int32Type, adaptor.getWidth(), one);
178 UnitAttr returnValueAndIsValidAttr =
nullptr;
179 Type resultTy = valueTy;
181 returnValueAndIsValidAttr = rewriter.
getUnitAttr();
186 loc, resultTy, activeMask, adaptor.getValue(), adaptor.getOffset(),
187 maskAndClamp, convertShflKind(op.getMode()), returnValueAndIsValidAttr);
189 Value shflValue = rewriter.
create<LLVM::ExtractValueOp>(loc, shfl, 0);
190 Value isActiveSrcLane =
191 rewriter.
create<LLVM::ExtractValueOp>(loc, shfl, 1);
192 rewriter.
replaceOp(op, {shflValue, isActiveSrcLane});
// Lowering of `gpu.lane_id`: creates the NVVM lane-id intrinsic (elided in
// this garbled fragment) and then adjusts its 32-bit result to the index
// bitwidth chosen by the LLVMTypeConverter.
// NOTE(review): several interior lines are missing here (including the op
// creation that defines `newOp`), and stray numeric tokens from the
// extraction remain embedded in the text — restore from upstream.
204 matchAndRewrite(gpu::LaneIdOp op, gpu::LaneIdOp::Adaptor adaptor,
// Index bitwidth is dictated by the converter's options, not by the op.
211 const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
212 if (indexBitwidth > 32) {
// Widen: sign-extend the 32-bit lane id to the wider index type.
213 newOp = rewriter.
create<LLVM::SExtOp>(
215 }
else if (indexBitwidth < 32) {
// Narrow: truncate the 32-bit lane id to the narrower index type.
216 newOp = rewriter.
create<LLVM::TruncOp>(
225 #include "GPUToNVVM.cpp.inc"
// Pass implementation generated from the CONVERTGPUOPSTONVVMOPS tablegen
// definition (see the GEN_PASS_DEF include at the top of the file): lowers
// the body of a gpu.module to the NVVM dialect.
// NOTE(review): large parts of runOnOperation are elided in this garbled
// fragment (type converter setup, pattern population, applyPartialConversion)
// and stray numeric tokens from the extraction remain in the text.
232 struct LowerGpuOpsToNVVMOpsPass
233 :
public impl::ConvertGpuOpsToNVVMOpsBase<LowerGpuOpsToNVVMOpsPass> {
236 void runOnOperation()
override {
237 gpu::GPUModuleOp m = getOperation();
// Request C-interface wrappers for every func.func in the module.
240 for (
auto func : m.getOps<func::FuncOp>()) {
241 func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
// Lowering options: index bitwidth and calling convention come from the
// pass options; the data layout is taken from the module itself.
248 DataLayout(cast<DataLayoutOpInterface>(m.getOperation())));
250 options.overrideIndexBitwidth(indexBitwidth);
251 options.useBarePtrCallConv = useBarePtrCallConv;
260 return signalPassFailure();
// Map gpu.address_space attributes onto NVVM numeric memory spaces.
269 converter, [](gpu::AddressSpace space) ->
unsigned {
271 case gpu::AddressSpace::Global:
272 return static_cast<unsigned>(
274 case gpu::AddressSpace::Workgroup:
275 return static_cast<unsigned>(
// Private memory lowers to the default (generic) address space — the
// elided line presumably returns 0; confirm against upstream.
277 case gpu::AddressSpace::Private:
280 llvm_unreachable(
"unknown address space enum value");
312 target.
addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
313 LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
314 LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
318 target.
addLegalOp<gpu::YieldOp, gpu::GPUModuleOp, gpu::ModuleEndOp>();
321 template <
typename OpTy>
331 patterns.
add<GPUSubgroupReduceOpLowering>(converter);
// Registers all GPU-to-NVVM conversion patterns: the tablegen-generated
// patterns, the launch-grid index intrinsic lowerings, kernel attribute
// propagation, and libdevice (__nv_*) call rewrites for math ops.
// NOTE(review): this fragment is garbled — the pattern class names, the
// X-dimension op of each index triple, and many f64 function-name argument
// lines are elided, and stray numeric tokens from the extraction remain.
336 populateWithGenerated(patterns);
// Each gpu index op lowers onto an NVVM X/Y/Z intrinsic triple; only the
// Y/Z entries survived the extraction.
340 NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>,
342 NVVM::BlockDimYOp, NVVM::BlockDimZOp>,
344 NVVM::ClusterIdYOp, NVVM::ClusterIdZOp>,
346 NVVM::ClusterDimYOp, NVVM::ClusterDimZOp>,
348 NVVM::BlockIdYOp, NVVM::BlockIdZOp>,
350 NVVM::GridDimYOp, NVVM::GridDimZOp>,
// Kernel metadata attribute names forwarded to the NVVM dialect.
364 NVVM::NVVMDialect::getKernelFuncAttrName()),
366 NVVM::NVVMDialect::getMaxntidAttrName()));
// Math ops lower to CUDA libdevice calls: the f32 entry point is listed
// first; the f64 entry point follows (elided for some ops).
368 populateOpPatterns<math::AbsFOp>(converter, patterns,
"__nv_fabsf",
370 populateOpPatterns<math::AtanOp>(converter, patterns,
"__nv_atanf",
372 populateOpPatterns<math::Atan2Op>(converter, patterns,
"__nv_atan2f",
374 populateOpPatterns<math::CbrtOp>(converter, patterns,
"__nv_cbrtf",
376 populateOpPatterns<math::CeilOp>(converter, patterns,
"__nv_ceilf",
378 populateOpPatterns<math::CosOp>(converter, patterns,
"__nv_cosf",
"__nv_cos");
379 populateOpPatterns<math::ErfOp>(converter, patterns,
"__nv_erff",
"__nv_erf");
380 populateOpPatterns<math::ExpOp>(converter, patterns,
"__nv_expf",
"__nv_exp");
381 populateOpPatterns<math::Exp2Op>(converter, patterns,
"__nv_exp2f",
383 populateOpPatterns<math::ExpM1Op>(converter, patterns,
"__nv_expm1f",
385 populateOpPatterns<math::FloorOp>(converter, patterns,
"__nv_floorf",
// arith.remf maps onto fmod, the only non-math-dialect op in this list.
387 populateOpPatterns<arith::RemFOp>(converter, patterns,
"__nv_fmodf",
389 populateOpPatterns<math::LogOp>(converter, patterns,
"__nv_logf",
"__nv_log");
390 populateOpPatterns<math::Log1pOp>(converter, patterns,
"__nv_log1pf",
392 populateOpPatterns<math::Log10Op>(converter, patterns,
"__nv_log10f",
394 populateOpPatterns<math::Log2Op>(converter, patterns,
"__nv_log2f",
396 populateOpPatterns<math::PowFOp>(converter, patterns,
"__nv_powf",
398 populateOpPatterns<math::RsqrtOp>(converter, patterns,
"__nv_rsqrtf",
400 populateOpPatterns<math::SinOp>(converter, patterns,
"__nv_sinf",
"__nv_sin");
401 populateOpPatterns<math::SqrtOp>(converter, patterns,
"__nv_sqrtf",
403 populateOpPatterns<math::TanhOp>(converter, patterns,
"__nv_tanhf",
405 populateOpPatterns<math::TanOp>(converter, patterns,
"__nv_tanf",
"__nv_tan");
static constexpr int64_t kSharedMemorySpace
static MLIRContext * getContext(OpFoldResult val)
static void populateOpPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, StringRef f32Func, StringRef f64Func)
static llvm::ManagedStatic< PassManagerOptions > options
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
This class describes a specific conversion target.
void addLegalOp(OperationName op)
Register the given operations as legal.
void addLegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as legal.
void addIllegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as illegal, i.e.
void addIllegalOp(OperationName op)
Register the given operation as illegal, i.e.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
The main mechanism for performing data layout queries.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
MLIRContext & getContext() const
Returns the MLIR context.
static LLVMStructType getLiteral(MLIRContext *context, ArrayRef< Type > types, bool isPacked=false)
Gets or creates a literal struct with the given body in the provided context.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
constexpr int kSharedMemoryAlignmentBit
@ kGlobalMemorySpace
Global memory space identifier.
void populateArithToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
void populateControlFlowToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect the patterns to convert from the ControlFlow dialect to LLVM.
Include the generated interface declarations.
void populateFuncToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, const SymbolTable *symbolTable=nullptr)
Collect the patterns to convert from the Func dialect to LLVM.
LLVM::LLVMStructType convertMMAToLLVMType(gpu::MMAMatrixType type)
Return the LLVMStructureType corresponding to the MMAMatrixType type.
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth from the data layout.
void populateGpuRewritePatterns(RewritePatternSet &patterns)
Collect all patterns to rewrite ops within the GPU dialect.
void populateFinalizeMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to convert memory-related operations from the MemRef dialect to the LLVM di...
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
void populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, bool reassociateFPReductions=false, bool force32BitVectorIndices=false)
Collect a set of patterns to convert from the Vector dialect to LLVM.
void configureGpuToNVVMConversionLegality(ConversionTarget &target)
Configure target to convert from the GPU dialect to NVVM.
LogicalResult applyPatternsAndFoldGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
void populateGpuMemorySpaceAttributeConversions(TypeConverter &typeConverter, const MemorySpaceMapping &mapping)
Populates memory space attribute conversion rules for lowering gpu.address_space to integer values.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateGpuWMMAToNVVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to convert WMMA ops from GPU dialect to NVVM.
void populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to convert from the GPU dialect to NVVM.
void populateGpuSubgroupReduceOpLoweringPattern(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Populate GpuSubgroupReduce pattern to NVVM.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Lowering for gpu.dynamic.shared.memory to LLVM dialect.
Lowering of gpu.printf to a vprintf standard library.
This class represents an efficient way to signal success or failure.
Rewriting that replaces SourceOp with a CallOp to f32Func or f64Func depending on the element type that the op operates on.
Rewriting that unrolls SourceOp to scalars if it's operating on vectors.