36 #include "../GPUCommon/GPUOpsLowering.h"
37 #include "../GPUCommon/IndexIntrinsicsOpLowering.h"
38 #include "../GPUCommon/OpToFuncCallLowering.h"
42 #define GEN_PASS_DEF_CONVERTGPUOPSTONVVMOPS
43 #include "mlir/Conversion/Passes.h.inc"
/// Convert gpu::ShuffleMode to the matching NVVM::ShflKind.
static NVVM::ShflKind convertShflKind(gpu::ShuffleMode mode) {
  switch (mode) {
  case gpu::ShuffleMode::XOR:
    return NVVM::ShflKind::bfly;
  case gpu::ShuffleMode::UP:
    return NVVM::ShflKind::up;
  case gpu::ShuffleMode::DOWN:
    return NVVM::ShflKind::down;
  case gpu::ShuffleMode::IDX:
    return NVVM::ShflKind::idx;
  }
  llvm_unreachable("unknown shuffle mode");
}
/// Map a subgroup reduction operation to an NVVM redux kind, or std::nullopt
/// if redux cannot express that reduction.
static std::optional<NVVM::ReduxKind>
convertReduxKind(gpu::AllReduceOperation mode) {
  switch (mode) {
  case gpu::AllReduceOperation::ADD:
    return NVVM::ReduxKind::ADD;
  case gpu::AllReduceOperation::MUL:
    return std::nullopt;
  case gpu::AllReduceOperation::MINSI:
    return NVVM::ReduxKind::MIN;
  case gpu::AllReduceOperation::MINUI:
    return std::nullopt;
  case gpu::AllReduceOperation::MINNUMF:
    return NVVM::ReduxKind::MIN;
  case gpu::AllReduceOperation::MAXSI:
    return NVVM::ReduxKind::MAX;
  case gpu::AllReduceOperation::MAXUI:
    return std::nullopt;
  case gpu::AllReduceOperation::MAXNUMF:
    return NVVM::ReduxKind::MAX;
  case gpu::AllReduceOperation::AND:
    return NVVM::ReduxKind::AND;
  case gpu::AllReduceOperation::OR:
    return NVVM::ReduxKind::OR;
  case gpu::AllReduceOperation::XOR:
    return NVVM::ReduxKind::XOR;
  case gpu::AllReduceOperation::MINIMUMF:
  case gpu::AllReduceOperation::MAXIMUMF:
    return std::nullopt;
  }
  return std::nullopt;
}
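// Note: the modes mapped to std::nullopt above (MUL, MINUI, MAXUI, MINIMUMF,
// MAXIMUMF) are not expressible through this table (NVVM::ReduxKind's
// MIN/MAX model the signed/NaN-propagating-free variants), so the lowering
// below bails out on them and leaves those reductions to other patterns.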
/// Lowers a uniform, whole-subgroup gpu.subgroup_reduce op to nvvm.redux.
struct GPUSubgroupReduceOpLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupReduceOp> {
  using ConvertOpToLLVMPattern<gpu::SubgroupReduceOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getClusterSize())
      return rewriter.notifyMatchFailure(
          op, "lowering for clustered reduce not implemented");
    if (!op.getUniform())
      return rewriter.notifyMatchFailure(
          op, "cannot be lowered to redux as the op must be run "
              "uniformly (entire subgroup).");
    if (!op.getValue().getType().isInteger(32))
      return rewriter.notifyMatchFailure(op, "unsupported data type");

    std::optional<NVVM::ReduxKind> mode = convertReduxKind(op.getOp());
    if (!mode.has_value())
      return rewriter.notifyMatchFailure(
          op, "unsupported reduction mode for redux");

    Location loc = op->getLoc();
    auto int32Type = IntegerType::get(rewriter.getContext(), 32);
    // All-ones membermask: every lane in the warp participates.
    Value offset = rewriter.create<LLVM::ConstantOp>(loc, int32Type, -1);
    auto reduxOp = rewriter.create<NVVM::ReduxOp>(
        loc, int32Type, op.getValue(), mode.value(), offset);
    rewriter.replaceOp(op, reduxOp->getResult(0));
    return success();
  }
};
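// Illustrative effect (a sketch; exact assembly syntax may vary by MLIR
// version):
//   %r = gpu.subgroup_reduce add %x uniform : (i32) -> i32
// lowers to roughly
//   %mask = llvm.mlir.constant(-1 : i32) : i32
//   %r = nvvm.redux.sync add %x, %mask : i32 -> i32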
/// Lowers gpu.shuffle to the corresponding nvvm.shfl.sync op, deriving the
/// active-lane mask and clamp value from the shuffle width.
struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
  using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    auto valueTy = adaptor.getValue().getType();
    auto int32Type = IntegerType::get(rewriter.getContext(), 32);
    auto predTy = IntegerType::get(rewriter.getContext(), 1);

    Value one = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 1);
    Value minusOne = rewriter.create<LLVM::ConstantOp>(loc, int32Type, -1);
    Value thirtyTwo = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 32);
    Value numLeadInactiveLane = rewriter.create<LLVM::SubOp>(
        loc, int32Type, thirtyTwo, adaptor.getWidth());
    // Bit mask of active lanes: `(-1) >> (32 - activeWidth)`.
    Value activeMask = rewriter.create<LLVM::LShrOp>(loc, int32Type, minusOne,
                                                     numLeadInactiveLane);
    Value maskAndClamp;
    if (op.getMode() == gpu::ShuffleMode::UP) {
      // Clamp lane: `32 - activeWidth`.
      maskAndClamp = numLeadInactiveLane;
    } else {
      // Clamp lane: `activeWidth - 1`.
      maskAndClamp =
          rewriter.create<LLVM::SubOp>(loc, int32Type, adaptor.getWidth(), one);
    }

    bool predIsUsed = !op->getResult(1).use_empty();
    UnitAttr returnValueAndIsValidAttr = nullptr;
    Type resultTy = valueTy;
    if (predIsUsed) {
      returnValueAndIsValidAttr = rewriter.getUnitAttr();
      resultTy = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
                                                  {valueTy, predTy});
    }
    Value shfl = rewriter.create<NVVM::ShflOp>(
        loc, resultTy, activeMask, adaptor.getValue(), adaptor.getOffset(),
        maskAndClamp, convertShflKind(op.getMode()), returnValueAndIsValidAttr);
    if (predIsUsed) {
      Value shflValue = rewriter.create<LLVM::ExtractValueOp>(loc, shfl, 0);
      Value isActiveSrcLane =
          rewriter.create<LLVM::ExtractValueOp>(loc, shfl, 1);
      rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
    } else {
      rewriter.replaceOp(op, {shfl, nullptr});
    }
    return success();
  }
};
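// For reference, the sequence built above corresponds to this IR sketch
// (the f32 XOR/bfly case, with the validity predicate used):
//   %one = llvm.mlir.constant(1 : i32) : i32
//   %minus_one = llvm.mlir.constant(-1 : i32) : i32
//   %thirty_two = llvm.mlir.constant(32 : i32) : i32
//   %num_lanes = llvm.sub %thirty_two, %width : i32
//   %active_mask = llvm.lshr %minus_one, %num_lanes : i32
//   %mask_and_clamp = llvm.sub %width, %one : i32
//   %shfl = nvvm.shfl.sync bfly %active_mask, %value, %offset,
//       %mask_and_clamp : !llvm.struct<(f32, i1)>
//   %shfl_value = llvm.extractvalue %shfl[0] : !llvm.struct<(f32, i1)>
//   %shfl_pred = llvm.extractvalue %shfl[1] : !llvm.struct<(f32, i1)>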
/// Lowers gpu.lane_id to the NVVM laneid sreg read, attaching range metadata.
struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
  using ConvertOpToLLVMPattern<gpu::LaneIdOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::LaneIdOp op, gpu::LaneIdOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    MLIRContext *context = rewriter.getContext();
    LLVM::ConstantRangeAttr bounds = nullptr;
    if (std::optional<APInt> upperBound = op.getUpperBound())
      bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
          /*bitWidth=*/32, /*lower=*/0, upperBound->getZExtValue());
    else
      bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
          /*bitWidth=*/32, /*lower=*/0, /*upper=*/kWarpSize);
    Value newOp =
        rewriter.create<NVVM::LaneIdOp>(loc, rewriter.getI32Type(), bounds);
    // Truncate or extend the result depending on the index bitwidth specified
    // by the LLVMTypeConverter options.
    const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
    if (indexBitwidth > 32) {
      newOp = rewriter.create<LLVM::SExtOp>(
          loc, IntegerType::get(context, indexBitwidth), newOp);
    } else if (indexBitwidth < 32) {
      newOp = rewriter.create<LLVM::TruncOp>(
          loc, IntegerType::get(context, indexBitwidth), newOp);
    }
    rewriter.replaceOp(op, {newOp});
    return success();
  }
};
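// Example: with the common 64-bit index configuration, the lane id is read
// as an i32 from the laneid special register and sign-extended to i64; a
// narrower index bitwidth would instead truncate the i32 result.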
238 #include "GPUToNVVM.cpp.inc"
/// A pass that replaces all occurrences of GPU device operations with their
/// corresponding NVVM equivalent. It only handles device code.
struct LowerGpuOpsToNVVMOpsPass
    : public impl::ConvertGpuOpsToNVVMOpsBase<LowerGpuOpsToNVVMOpsPass> {
  using Base::Base;

  void runOnOperation() override {
    gpu::GPUModuleOp m = getOperation();
    // Request C wrapper emission for the module's functions.
    for (auto func : m.getOps<func::FuncOp>()) {
      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                    UnitAttr::get(&getContext()));
    }
    // Customize the bitwidth used for the device side index computations.
    LowerToLLVMOptions options(
        m.getContext(),
        DataLayout(cast<DataLayoutOpInterface>(m.getOperation())));
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);
    options.useBarePtrCallConv = useBarePtrCallConv;
    // Apply in-dialect GPU rewrites before the dialect conversion proper.
    {
      RewritePatternSet patterns(m.getContext());
      populateGpuRewritePatterns(patterns);
      if (failed(applyPatternsAndFoldGreedily(m, std::move(patterns))))
        return signalPassFailure();
    }
    LLVMTypeConverter converter(m.getContext(), options);
    // Map gpu.address_space attributes to NVVM numeric address spaces.
    populateGpuMemorySpaceAttributeConversions(
        converter, [](gpu::AddressSpace space) -> unsigned {
          switch (space) {
          case gpu::AddressSpace::Global:
            return static_cast<unsigned>(
                NVVM::NVVMMemorySpace::kGlobalMemorySpace);
          case gpu::AddressSpace::Workgroup:
            return static_cast<unsigned>(
                NVVM::NVVMMemorySpace::kSharedMemorySpace);
          case gpu::AddressSpace::Private:
            return 0;
          }
          llvm_unreachable("unknown address space enum value");
        });

    RewritePatternSet llvmPatterns(m.getContext());
    populateArithToLLVMConversionPatterns(converter, llvmPatterns);
    populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns);
    populateFuncToLLVMConversionPatterns(converter, llvmPatterns);
    populateFinalizeMemRefToLLVMConversionPatterns(converter, llvmPatterns);
    populateVectorToLLVMConversionPatterns(converter, llvmPatterns);
    populateGpuToNVVMConversionPatterns(converter, llvmPatterns);
    populateGpuWMMAToNVVMConversionPatterns(converter, llvmPatterns);
    if (this->hasRedux)
      populateGpuSubgroupReduceOpLoweringPattern(converter, llvmPatterns);
    ConversionTarget target(getContext());
    configureGpuToNVVMConversionLegality(target);
    if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
      signalPassFailure();
  }
};

} // namespace
void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
  target.addIllegalOp<func::FuncOp>();
  target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
  target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
  target.addIllegalDialect<gpu::GPUDialect>();
  // These math-like LLVM ops must be rewritten to libdevice calls, so they
  // stay illegal even though the LLVM dialect as a whole is legal.
  target.addIllegalOp<LLVM::CopySignOp, LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op,
                      LLVM::FAbsOp, LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FMAOp,
                      LLVM::FRemOp, LLVM::LogOp, LLVM::Log10Op, LLVM::Log2Op,
                      LLVM::PowOp, LLVM::RoundEvenOp, LLVM::RoundOp,
                      LLVM::SinOp, LLVM::SqrtOp>();

  // TODO: Remove once we support replacing non-root ops.
  target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp>();
}
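// Hypothetical usage sketch (only the two entry points are from this file;
// the surrounding names are illustrative): a custom pass can combine the
// legality rules and patterns along these lines:
//   LLVMTypeConverter converter(module.getContext());
//   RewritePatternSet patterns(module.getContext());
//   populateGpuToNVVMConversionPatterns(converter, patterns);
//   ConversionTarget target(*module.getContext());
//   configureGpuToNVVMConversionLegality(target);
//   if (failed(applyPartialConversion(module, target, std::move(patterns))))
//     signalPassFailure();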
template <typename OpTy>
static void populateOpPatterns(const LLVMTypeConverter &converter,
                               RewritePatternSet &patterns, StringRef f32Func,
                               StringRef f64Func, StringRef f32ApproxFunc = "",
                               StringRef f16Func = "") {
  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
                                           f32ApproxFunc, f16Func);
}
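// For example, populateOpPatterns<math::SinOp>(converter, patterns,
// "__nv_sinf", "__nv_sin", "__nv_fast_sinf") lowers math.sin to a call to
// __nv_sinf on f32, __nv_sin on f64, and the approximate __nv_fast_sinf when
// fast-math approximation is allowed; vector operands are unrolled to
// scalars first by ScalarizeVectorOpLowering.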
void mlir::populateGpuSubgroupReduceOpLoweringPattern(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  patterns.add<GPUSubgroupReduceOpLowering>(converter);
}

void mlir::populateGpuToNVVMConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  using gpu::index_lowering::IndexKind;
  using gpu::index_lowering::IntrType;
  populateWithGenerated(patterns);
  patterns.add<GPUPrintfOpToVPrintfLowering>(converter);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::ThreadIdOp, NVVM::ThreadIdXOp, NVVM::ThreadIdYOp,
      NVVM::ThreadIdZOp>>(converter, IndexKind::Block, IntrType::Id);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::BlockDimOp, NVVM::BlockDimXOp, NVVM::BlockDimYOp,
      NVVM::BlockDimZOp>>(converter, IndexKind::Block, IntrType::Dim);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::ClusterIdOp, NVVM::ClusterIdXOp, NVVM::ClusterIdYOp,
      NVVM::ClusterIdZOp>>(converter, IndexKind::Other, IntrType::Id);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::ClusterDimOp, NVVM::ClusterDimXOp, NVVM::ClusterDimYOp,
      NVVM::ClusterDimZOp>>(converter, IndexKind::Other, IntrType::Dim);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::ClusterBlockIdOp, NVVM::BlockInClusterIdXOp,
      NVVM::BlockInClusterIdYOp, NVVM::BlockInClusterIdZOp>>(
      converter, IndexKind::Other, IntrType::Id);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::ClusterDimBlocksOp, NVVM::ClusterDimBlocksXOp,
      NVVM::ClusterDimBlocksYOp, NVVM::ClusterDimBlocksZOp>>(
      converter, IndexKind::Other, IntrType::Dim);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::BlockIdOp, NVVM::BlockIdXOp, NVVM::BlockIdYOp, NVVM::BlockIdZOp>>(
      converter, IndexKind::Grid, IntrType::Id);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::GridDimOp, NVVM::GridDimXOp, NVVM::GridDimYOp, NVVM::GridDimZOp>>(
      converter, IndexKind::Grid, IntrType::Dim);
  patterns.add<GPULaneIdOpToNVVM, GPUShuffleOpLowering, GPUReturnOpLowering>(
      converter);
  patterns.add<GPUDynamicSharedMemoryOpLowering>(
      converter, NVVM::kSharedMemoryAlignmentBit);

  // Explicitly drop memory space when lowering private memory attributions
  // since NVVM models it as `alloca`s in the default memory space and does
  // not support `alloca`s with addrspace(5).
  patterns.add<GPUFuncOpLowering>(
      converter,
      GPUFuncOpLoweringOptions{
          /*allocaAddrSpace=*/0,
          /*workgroupAddrSpace=*/
          static_cast<unsigned>(NVVM::NVVMMemorySpace::kSharedMemorySpace),
          StringAttr::get(&converter.getContext(),
                          NVVM::NVVMDialect::getKernelFuncAttrName()),
          StringAttr::get(&converter.getContext(),
                          NVVM::NVVMDialect::getMaxntidAttrName())});
  populateOpPatterns<arith::RemFOp>(converter, patterns, "__nv_fmodf",
                                    "__nv_fmod");
  populateOpPatterns<math::AbsFOp>(converter, patterns, "__nv_fabsf",
                                   "__nv_fabs");
  populateOpPatterns<math::AcosOp>(converter, patterns, "__nv_acosf",
                                   "__nv_acos");
  populateOpPatterns<math::AcoshOp>(converter, patterns, "__nv_acoshf",
                                    "__nv_acosh");
  populateOpPatterns<math::AsinOp>(converter, patterns, "__nv_asinf",
                                   "__nv_asin");
  populateOpPatterns<math::AsinhOp>(converter, patterns, "__nv_asinhf",
                                    "__nv_asinh");
  populateOpPatterns<math::AtanOp>(converter, patterns, "__nv_atanf",
                                   "__nv_atan");
  populateOpPatterns<math::Atan2Op>(converter, patterns, "__nv_atan2f",
                                    "__nv_atan2");
  populateOpPatterns<math::AtanhOp>(converter, patterns, "__nv_atanhf",
                                    "__nv_atanh");
  populateOpPatterns<math::CbrtOp>(converter, patterns, "__nv_cbrtf",
                                   "__nv_cbrt");
  populateOpPatterns<math::CeilOp>(converter, patterns, "__nv_ceilf",
                                   "__nv_ceil");
  populateOpPatterns<math::CopySignOp>(converter, patterns, "__nv_copysignf",
                                       "__nv_copysign");
  populateOpPatterns<math::CosOp>(converter, patterns, "__nv_cosf", "__nv_cos",
                                  "__nv_fast_cosf");
  populateOpPatterns<math::CoshOp>(converter, patterns, "__nv_coshf",
                                   "__nv_cosh");
  populateOpPatterns<math::ErfOp>(converter, patterns, "__nv_erff", "__nv_erf");
  populateOpPatterns<math::ExpOp>(converter, patterns, "__nv_expf", "__nv_exp",
                                  "__nv_fast_expf");
  populateOpPatterns<math::Exp2Op>(converter, patterns, "__nv_exp2f",
                                   "__nv_exp2");
  populateOpPatterns<math::ExpM1Op>(converter, patterns, "__nv_expm1f",
                                    "__nv_expm1");
  populateOpPatterns<math::FloorOp>(converter, patterns, "__nv_floorf",
                                    "__nv_floor");
  populateOpPatterns<math::FmaOp>(converter, patterns, "__nv_fmaf", "__nv_fma");
  populateOpPatterns<math::LogOp>(converter, patterns, "__nv_logf", "__nv_log",
                                  "__nv_fast_logf");
  populateOpPatterns<math::Log10Op>(converter, patterns, "__nv_log10f",
                                    "__nv_log10", "__nv_fast_log10f");
  populateOpPatterns<math::Log1pOp>(converter, patterns, "__nv_log1pf",
                                    "__nv_log1p");
  populateOpPatterns<math::Log2Op>(converter, patterns, "__nv_log2f",
                                   "__nv_log2", "__nv_fast_log2f");
  populateOpPatterns<math::PowFOp>(converter, patterns, "__nv_powf", "__nv_pow",
                                   "__nv_fast_powf");
  populateOpPatterns<math::RoundOp>(converter, patterns, "__nv_roundf",
                                    "__nv_round");
  populateOpPatterns<math::RoundEvenOp>(converter, patterns, "__nv_rintf",
                                        "__nv_rint");
  populateOpPatterns<math::RsqrtOp>(converter, patterns, "__nv_rsqrtf",
                                    "__nv_rsqrt");
  populateOpPatterns<math::SinOp>(converter, patterns, "__nv_sinf", "__nv_sin",
                                  "__nv_fast_sinf");
  populateOpPatterns<math::SinhOp>(converter, patterns, "__nv_sinhf",
                                   "__nv_sinh");
  populateOpPatterns<math::SqrtOp>(converter, patterns, "__nv_sqrtf",
                                   "__nv_sqrt");
  populateOpPatterns<math::TanOp>(converter, patterns, "__nv_tanf", "__nv_tan",
                                  "__nv_fast_tanf");
  populateOpPatterns<math::TanhOp>(converter, patterns, "__nv_tanhf",
                                   "__nv_tanh");
}