#include "../GPUCommon/GPUOpsLowering.h"
#include "../GPUCommon/IndexIntrinsicsOpLowering.h"
#include "../GPUCommon/OpToFuncCallLowering.h"

namespace mlir {
#define GEN_PASS_DEF_CONVERTGPUOPSTONVVMOPS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {
/// Convert gpu dialect shfl mode enum to the equivalent nvvm one.
static NVVM::ShflKind convertShflKind(gpu::ShuffleMode mode) {
  switch (mode) {
  case gpu::ShuffleMode::XOR:
    return NVVM::ShflKind::bfly;
  case gpu::ShuffleMode::UP:
    return NVVM::ShflKind::up;
  case gpu::ShuffleMode::DOWN:
    return NVVM::ShflKind::down;
  case gpu::ShuffleMode::IDX:
    return NVVM::ShflKind::idx;
  }
  llvm_unreachable("unknown shuffle mode");
}
/// Map a gpu.all_reduce / gpu.subgroup_reduce operation to the equivalent
/// NVVM redux kind, if one exists.
static std::optional<NVVM::ReduxKind>
convertReduxKind(gpu::AllReduceOperation mode) {
  switch (mode) {
  case gpu::AllReduceOperation::ADD:
    return NVVM::ReduxKind::ADD;
  case gpu::AllReduceOperation::MUL:
    return std::nullopt;
  case gpu::AllReduceOperation::MINSI:
    return NVVM::ReduxKind::MIN;
  case gpu::AllReduceOperation::MINUI:
    return std::nullopt;
  case gpu::AllReduceOperation::MINNUMF:
    return NVVM::ReduxKind::MIN;
  case gpu::AllReduceOperation::MAXSI:
    return NVVM::ReduxKind::MAX;
  case gpu::AllReduceOperation::MAXUI:
    return std::nullopt;
  case gpu::AllReduceOperation::MAXNUMF:
    return NVVM::ReduxKind::MAX;
  case gpu::AllReduceOperation::AND:
    return NVVM::ReduxKind::AND;
  case gpu::AllReduceOperation::OR:
    return NVVM::ReduxKind::OR;
  case gpu::AllReduceOperation::XOR:
    return NVVM::ReduxKind::XOR;
  case gpu::AllReduceOperation::MINIMUMF:
  case gpu::AllReduceOperation::MAXIMUMF:
    return std::nullopt;
  }
  return std::nullopt;
}
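// Note the partial coverage (a summary of the mapping above): redux handles
// add/min/max and the bitwise reductions, so e.g.
//   convertReduxKind(gpu::AllReduceOperation::ADD) -> NVVM::ReduxKind::ADD
//   convertReduxKind(gpu::AllReduceOperation::MUL) -> std::nullopt
// and a std::nullopt result makes GPUSubgroupReduceOpLowering below bail out
// with a match failure.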
/// Lowers a uniform, subgroup-wide gpu.subgroup_reduce on 32-bit integers to
/// nvvm.redux.sync.
struct GPUSubgroupReduceOpLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupReduceOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getClusterSize())
      return rewriter.notifyMatchFailure(
          op, "lowering for clustered reduce not implemented");

    if (!op.getUniform())
      return rewriter.notifyMatchFailure(
          op, "cannot be lowered to redux as the op must be run "
              "uniformly (entire subgroup).");
    if (!op.getValue().getType().isInteger(32))
      return rewriter.notifyMatchFailure(op, "unsupported data type");

    std::optional<NVVM::ReduxKind> mode = convertReduxKind(op.getOp());
    if (!mode.has_value())
      return rewriter.notifyMatchFailure(
          op, "unsupported reduction mode for redux");

    Location loc = op->getLoc();
    auto int32Type = IntegerType::get(rewriter.getContext(), 32);
    Value offset = rewriter.create<LLVM::ConstantOp>(loc, int32Type, -1);

    auto reduxOp = rewriter.create<NVVM::ReduxOp>(
        loc, int32Type, op.getValue(), mode.value(), offset);

    rewriter.replaceOp(op, reduxOp->getResult(0));
    return success();
  }
};
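// Illustrative result (a sketch; the exact assembly form may differ): a
// reduction like
//   %sum = gpu.subgroup_reduce add %val uniform : (i32) -> i32
// becomes roughly
//   %mask = llvm.mlir.constant(-1 : i32) : i32
//   %sum  = nvvm.redux.sync add %val, %mask : i32
// where the all-ones mask means every lane of the subgroup participates.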
/// Lowers gpu.shuffle to the corresponding NVVM shfl op, deriving the active
/// mask and the clamp value from the shuffle width.
struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
  using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();

    auto valueTy = adaptor.getValue().getType();
    auto int32Type = IntegerType::get(rewriter.getContext(), 32);
    auto predTy = IntegerType::get(rewriter.getContext(), 1);

    Value one = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 1);
    Value minusOne = rewriter.create<LLVM::ConstantOp>(loc, int32Type, -1);
    Value thirtyTwo = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 32);
    Value numLeadInactiveLane = rewriter.create<LLVM::SubOp>(
        loc, int32Type, thirtyTwo, adaptor.getWidth());
    // Bit mask of active lanes: `(-1) >> (32 - activeWidth)`.
    Value activeMask = rewriter.create<LLVM::LShrOp>(loc, int32Type, minusOne,
                                                     numLeadInactiveLane);
    Value maskAndClamp;
    if (op.getMode() == gpu::ShuffleMode::UP) {
      // Clamp lane: `32 - activeWidth`.
      maskAndClamp = numLeadInactiveLane;
    } else {
      // Clamp lane: `activeWidth - 1`.
      maskAndClamp =
          rewriter.create<LLVM::SubOp>(loc, int32Type, adaptor.getWidth(), one);
    }

    bool predIsUsed = !op->getResult(1).use_empty();
    UnitAttr returnValueAndIsValidAttr = nullptr;
    Type resultTy = valueTy;
    if (predIsUsed) {
      returnValueAndIsValidAttr = rewriter.getUnitAttr();
      resultTy = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
                                                  {valueTy, predTy});
    }
    Value shfl = rewriter.create<NVVM::ShflOp>(
        loc, resultTy, activeMask, adaptor.getValue(), adaptor.getOffset(),
        maskAndClamp, convertShflKind(op.getMode()), returnValueAndIsValidAttr);
    if (predIsUsed) {
      Value shflValue = rewriter.create<LLVM::ExtractValueOp>(loc, shfl, 0);
      Value isActiveSrcLane =
          rewriter.create<LLVM::ExtractValueOp>(loc, shfl, 1);
      rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
    } else {
      rewriter.replaceOp(op, {shfl, nullptr});
    }
    return success();
  }
};
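// Illustrative expansion (a sketch; the printed syntax may differ slightly):
// for
//   %shfl, %pred = gpu.shuffle xor %value, %offset, %width : f32
// the pattern emits roughly
//   %one = llvm.mlir.constant(1 : i32) : i32
//   %minus_one = llvm.mlir.constant(-1 : i32) : i32
//   %thirty_two = llvm.mlir.constant(32 : i32) : i32
//   %num_lanes = llvm.sub %thirty_two, %width : i32
//   %active_mask = llvm.lshr %minus_one, %num_lanes : i32
//   %mask_and_clamp = llvm.sub %width, %one : i32
//   %shfl_struct = nvvm.shfl.sync bfly %active_mask, %value, %offset,
//       %mask_and_clamp {return_value_and_is_valid}
//       : f32 -> !llvm.struct<(f32, i1)>
//   %shfl = llvm.extractvalue %shfl_struct[0] : !llvm.struct<(f32, i1)>
//   %pred = llvm.extractvalue %shfl_struct[1] : !llvm.struct<(f32, i1)>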
/// Lowers gpu.lane_id to nvvm.read.ptx.sreg.laneid, attaching a range
/// attribute and adjusting the result to the configured index bitwidth.
struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
  using ConvertOpToLLVMPattern<gpu::LaneIdOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::LaneIdOp op, gpu::LaneIdOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    MLIRContext *context = rewriter.getContext();
    LLVM::ConstantRangeAttr bounds = nullptr;
    if (std::optional<APInt> upperBound = op.getUpperBound())
      bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
          /*bitWidth=*/32, /*lower=*/0, upperBound->getZExtValue());
    else
      bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
          /*bitWidth=*/32, /*lower=*/0, /*upper=*/kWarpSize);
    Value newOp =
        rewriter.create<NVVM::LaneIdOp>(loc, rewriter.getI32Type(), bounds);
    // Truncate or extend the result depending on the index bitwidth specified
    // by the LLVMTypeConverter options.
    const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
    if (indexBitwidth > 32) {
      newOp = rewriter.create<LLVM::SExtOp>(
          loc, IntegerType::get(context, indexBitwidth), newOp);
    } else if (indexBitwidth < 32) {
      newOp = rewriter.create<LLVM::TruncOp>(
          loc, IntegerType::get(context, indexBitwidth), newOp);
    }
    rewriter.replaceOp(op, {newOp});
    return success();
  }
};
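// Illustrative expansion (a sketch, assuming a 64-bit index type):
//   %laneid = nvvm.read.ptx.sreg.laneid range <i32, 0, 32> : i32
//   %res = llvm.sext %laneid : i32 to i64
// The attached range lets LLVM assume the lane id lies in [0, 32).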
} // namespace

/// Import the GPU Ops to NVVM Patterns.
#include "GPUToNVVM.cpp.inc"

namespace {
/// A pass that replaces all occurrences of GPU device operations with their
/// corresponding NVVM equivalent.
struct LowerGpuOpsToNVVMOpsPass
    : public impl::ConvertGpuOpsToNVVMOpsBase<LowerGpuOpsToNVVMOpsPass> {
  using Base::Base;

  void runOnOperation() override {
    gpu::GPUModuleOp m = getOperation();

    // Request C wrapper emission.
    for (auto func : m.getOps<func::FuncOp>()) {
      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                    UnitAttr::get(&getContext()));
    }

    // Customize the bitwidth used for the device side index computations.
    LowerToLLVMOptions options(
        m.getContext(),
        DataLayout(cast<DataLayoutOpInterface>(m.getOperation())));
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);
    options.useBarePtrCallConv = useBarePtrCallConv;

    // Apply in-dialect lowering first; it may create ops that themselves
    // still need to be lowered by the conversion below.
    RewritePatternSet patterns(m.getContext());
    populateGpuRewritePatterns(patterns);
    if (failed(applyPatternsAndFoldGreedily(m, std::move(patterns))))
      return signalPassFailure();

    // ... (type converter setup, LLVM pattern population, and the final
    // applyPartialConversion elided) ...
  }
};
} // namespace
void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
  target.addIllegalOp<func::FuncOp>();
  target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
  target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
  target.addIllegalDialect<gpu::GPUDialect>();
  target.addIllegalOp<LLVM::CopySignOp, LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op,
                      LLVM::FAbsOp, LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FMAOp,
                      LLVM::FRemOp, LLVM::LogOp, LLVM::Log10Op, LLVM::Log2Op,
                      LLVM::PowOp, LLVM::RoundEvenOp, LLVM::RoundOp,
                      LLVM::SinOp, LLVM::SqrtOp>();

  // TODO: Remove once we support replacing non-root ops.
  target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp>();
}
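// Consequence (a sketch): with this configuration, an `llvm.sin` produced by
// an earlier lowering remains illegal inside a gpu.module, which forces the
// OpToFuncCallLowering patterns registered below to rewrite it into a
// libdevice call, e.g.
//   %r = llvm.call @__nv_sinf(%x) : (f32) -> f32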
void mlir::configureGpuToNVVMTypeConverter(LLVMTypeConverter &converter) {
  // Map gpu.address_space attributes to NVVM numeric address spaces. NVVM
  // models private memory as allocas in the default address space.
  populateGpuMemorySpaceAttributeConversions(
      converter, [](gpu::AddressSpace space) -> unsigned {
        switch (space) {
        case gpu::AddressSpace::Global:
          return static_cast<unsigned>(
              NVVM::NVVMMemorySpace::kGlobalMemorySpace);
        case gpu::AddressSpace::Workgroup:
          return static_cast<unsigned>(
              NVVM::NVVMMemorySpace::kSharedMemorySpace);
        case gpu::AddressSpace::Private:
          return 0;
        }
        llvm_unreachable("unknown address space enum value");
        return 0;
      });
  // ... (further conversions, e.g. for gpu.mma_matrix types, elided) ...
}
template <typename OpTy>
static void populateOpPatterns(const LLVMTypeConverter &converter,
                               RewritePatternSet &patterns, StringRef f32Func,
                               StringRef f64Func, StringRef f32ApproxFunc = "",
                               StringRef f16Func = "") {
  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
                                           f32ApproxFunc, f16Func);
}
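// Usage sketch: a call such as
//   populateOpPatterns<math::CosOp>(converter, patterns, "__nv_cosf",
//                                   "__nv_cos", "__nv_fast_cosf");
// registers two patterns: one that unrolls vector operands to scalars, and
// one that replaces `math.cos` with a call to the libdevice function matching
// the operand type (f32, f64, or the fast f32 approximation).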
void mlir::populateGpuSubgroupReduceOpLoweringPattern(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  patterns.add<GPUSubgroupReduceOpLowering>(converter);
}
void mlir::populateGpuToNVVMConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  using gpu::index_lowering::IndexKind;
  using gpu::index_lowering::IntrType;
  populateWithGenerated(patterns);
  patterns.add<GPUPrintfOpToVPrintfLowering>(converter);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::ThreadIdOp, NVVM::ThreadIdXOp, NVVM::ThreadIdYOp,
      NVVM::ThreadIdZOp>>(converter, IndexKind::Block, IntrType::Id);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::BlockDimOp, NVVM::BlockDimXOp, NVVM::BlockDimYOp,
      NVVM::BlockDimZOp>>(converter, IndexKind::Block, IntrType::Dim);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::ClusterIdOp, NVVM::ClusterIdXOp, NVVM::ClusterIdYOp,
      NVVM::ClusterIdZOp>>(converter, IndexKind::Other, IntrType::Id);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::ClusterDimOp, NVVM::ClusterDimXOp, NVVM::ClusterDimYOp,
      NVVM::ClusterDimZOp>>(converter, IndexKind::Other, IntrType::Dim);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::ClusterBlockIdOp, NVVM::BlockInClusterIdXOp,
      NVVM::BlockInClusterIdYOp, NVVM::BlockInClusterIdZOp>>(
      converter, IndexKind::Other, IntrType::Id);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::ClusterDimBlocksOp, NVVM::ClusterDimBlocksXOp,
      NVVM::ClusterDimBlocksYOp, NVVM::ClusterDimBlocksZOp>>(
      converter, IndexKind::Other, IntrType::Dim);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::BlockIdOp, NVVM::BlockIdXOp, NVVM::BlockIdYOp, NVVM::BlockIdZOp>>(
      converter, IndexKind::Grid, IntrType::Id);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::GridDimOp, NVVM::GridDimXOp, NVVM::GridDimYOp, NVVM::GridDimZOp>>(
      converter, IndexKind::Grid, IntrType::Dim);
  patterns.add<GPULaneIdOpToNVVM, GPUShuffleOpLowering, GPUReturnOpLowering>(
      converter);

  // NVVM models private memory as allocas in the default address space, so
  // only the workgroup attribution address space is set explicitly.
  patterns.add<GPUFuncOpLowering>(
      converter,
      GPUFuncOpLoweringOptions{
          /*allocaAddrSpace=*/0,
          /*workgroupAddrSpace=*/
          static_cast<unsigned>(NVVM::NVVMMemorySpace::kSharedMemorySpace),
          StringAttr::get(&converter.getContext(),
                          NVVM::NVVMDialect::getKernelFuncAttrName()),
          StringAttr::get(&converter.getContext(),
                          NVVM::NVVMDialect::getMaxntidAttrName())});
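  // Illustrative effect (a sketch): inside a kernel, the index intrinsics
  // registered above lower to PTX special-register reads, e.g.
  //   %0 = gpu.thread_id x   ->   %0 = nvvm.read.ptx.sreg.tid.x : i32
  //   %1 = gpu.block_dim y   ->   %1 = nvvm.read.ptx.sreg.ntid.y : i32
  // followed by an extension or truncation to the configured index bitwidth.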
  populateOpPatterns<arith::RemFOp>(converter, patterns, "__nv_fmodf",
                                    "__nv_fmod");
  populateOpPatterns<math::AbsFOp>(converter, patterns, "__nv_fabsf",
                                   "__nv_fabs");
  populateOpPatterns<math::AcosOp>(converter, patterns, "__nv_acosf",
                                   "__nv_acos");
  populateOpPatterns<math::AcoshOp>(converter, patterns, "__nv_acoshf",
                                    "__nv_acosh");
  populateOpPatterns<math::AsinOp>(converter, patterns, "__nv_asinf",
                                   "__nv_asin");
  populateOpPatterns<math::AsinhOp>(converter, patterns, "__nv_asinhf",
                                    "__nv_asinh");
  populateOpPatterns<math::AtanOp>(converter, patterns, "__nv_atanf",
                                   "__nv_atan");
  populateOpPatterns<math::Atan2Op>(converter, patterns, "__nv_atan2f",
                                    "__nv_atan2");
  populateOpPatterns<math::AtanhOp>(converter, patterns, "__nv_atanhf",
                                    "__nv_atanh");
  populateOpPatterns<math::CbrtOp>(converter, patterns, "__nv_cbrtf",
                                   "__nv_cbrt");
  populateOpPatterns<math::CeilOp>(converter, patterns, "__nv_ceilf",
                                   "__nv_ceil");
  populateOpPatterns<math::CopySignOp>(converter, patterns, "__nv_copysignf",
                                       "__nv_copysign");
  populateOpPatterns<math::CosOp>(converter, patterns, "__nv_cosf", "__nv_cos",
                                  "__nv_fast_cosf");
  populateOpPatterns<math::CoshOp>(converter, patterns, "__nv_coshf",
                                   "__nv_cosh");
  populateOpPatterns<math::ErfOp>(converter, patterns, "__nv_erff", "__nv_erf");
  populateOpPatterns<math::ExpOp>(converter, patterns, "__nv_expf", "__nv_exp",
                                  "__nv_fast_expf");
  populateOpPatterns<math::Exp2Op>(converter, patterns, "__nv_exp2f",
                                   "__nv_exp2");
  populateOpPatterns<math::ExpM1Op>(converter, patterns, "__nv_expm1f",
                                    "__nv_expm1");
  populateOpPatterns<math::FloorOp>(converter, patterns, "__nv_floorf",
                                    "__nv_floor");
  populateOpPatterns<math::FmaOp>(converter, patterns, "__nv_fmaf", "__nv_fma");
  populateOpPatterns<math::LogOp>(converter, patterns, "__nv_logf", "__nv_log",
                                  "__nv_fast_logf");
  populateOpPatterns<math::Log10Op>(converter, patterns, "__nv_log10f",
                                    "__nv_log10", "__nv_fast_log10f");
  populateOpPatterns<math::Log1pOp>(converter, patterns, "__nv_log1pf",
                                    "__nv_log1p");
  populateOpPatterns<math::Log2Op>(converter, patterns, "__nv_log2f",
                                   "__nv_log2", "__nv_fast_log2f");
  populateOpPatterns<math::PowFOp>(converter, patterns, "__nv_powf", "__nv_pow",
                                   "__nv_fast_powf");
  populateOpPatterns<math::RoundOp>(converter, patterns, "__nv_roundf",
                                    "__nv_round");
  populateOpPatterns<math::RoundEvenOp>(converter, patterns, "__nv_rintf",
                                        "__nv_rint");
  populateOpPatterns<math::RsqrtOp>(converter, patterns, "__nv_rsqrtf",
                                    "__nv_rsqrt");
  populateOpPatterns<math::SinOp>(converter, patterns, "__nv_sinf", "__nv_sin",
                                  "__nv_fast_sinf");
  populateOpPatterns<math::SinhOp>(converter, patterns, "__nv_sinhf",
                                   "__nv_sinh");
  populateOpPatterns<math::SqrtOp>(converter, patterns, "__nv_sqrtf",
                                   "__nv_sqrt");
  populateOpPatterns<math::TanOp>(converter, patterns, "__nv_tanf", "__nv_tan",
                                  "__nv_fast_tanf");
  populateOpPatterns<math::TanhOp>(converter, patterns, "__nv_tanhf",
                                   "__nv_tanh");
}
//===----------------------------------------------------------------------===//
// NVVMTargetAttr convert to LLVM attr interface
//===----------------------------------------------------------------------===//

namespace {
struct NVVMTargetConvertToLLVMAttrInterface
    : public ConvertToLLVMAttrInterface::ExternalModel<
          NVVMTargetConvertToLLVMAttrInterface, NVVM::NVVMTargetAttr> {
  /// Configure GPU to NVVM lowering.
  void populateConvertToLLVMConversionPatterns(
      Attribute attr, ConversionTarget &target,
      LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) const;
};
} // namespace

void NVVMTargetConvertToLLVMAttrInterface::
    populateConvertToLLVMConversionPatterns(
        Attribute attr, ConversionTarget &target,
        LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) const {
  configureGpuToNVVMConversionLegality(target);
  configureGpuToNVVMTypeConverter(typeConverter);
  populateGpuToNVVMConversionPatterns(typeConverter, patterns);
}

void mlir::registerConvertGpuToNVVMInterface(DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, NVVM::NVVMDialect *dialect) {
    NVVMTargetAttr::attachInterface<NVVMTargetConvertToLLVMAttrInterface>(*ctx);
  });
}
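// Usage sketch (assuming a typical tool setup): registering this interface
// lets the generic ConvertToLLVM infrastructure pull in the GPU-to-NVVM
// legality, type-converter, and pattern configuration whenever an NVVM
// target attribute is present:
//   DialectRegistry registry;
//   registerConvertGpuToNVVMInterface(registry);
//   MLIRContext context(registry);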