#include "../GPUCommon/GPUOpsLowering.h"
#include "../GPUCommon/IndexIntrinsicsOpLowering.h"
#include "../GPUCommon/OpToFuncCallLowering.h"

#define GEN_PASS_DEF_CONVERTGPUOPSTONVVMOPS
#include "mlir/Conversion/Passes.h.inc"
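
/// Converts a gpu.shuffle mode to the equivalent NVVM shuffle kind. Note the
/// naming mismatch: gpu.shuffle's XOR mode corresponds to NVVM's butterfly
/// ("bfly") shuffle.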
static NVVM::ShflKind convertShflKind(gpu::ShuffleMode mode) {
  switch (mode) {
  case gpu::ShuffleMode::XOR:
    return NVVM::ShflKind::bfly;
  case gpu::ShuffleMode::UP:
    return NVVM::ShflKind::up;
  case gpu::ShuffleMode::DOWN:
    return NVVM::ShflKind::down;
  case gpu::ShuffleMode::IDX:
    return NVVM::ShflKind::idx;
  }
  llvm_unreachable("unknown shuffle mode");
}
static std::optional<NVVM::ReduxKind>
convertReduxKind(gpu::AllReduceOperation mode) {
  switch (mode) {
  case gpu::AllReduceOperation::ADD:
    return NVVM::ReduxKind::ADD;
  case gpu::AllReduceOperation::MUL:
    return std::nullopt;
  case gpu::AllReduceOperation::MINSI:
    return NVVM::ReduxKind::MIN;
  case gpu::AllReduceOperation::MINUI:
    return std::nullopt;
  case gpu::AllReduceOperation::MINNUMF:
    return NVVM::ReduxKind::MIN;
  case gpu::AllReduceOperation::MAXSI:
    return NVVM::ReduxKind::MAX;
  case gpu::AllReduceOperation::MAXUI:
    return std::nullopt;
  case gpu::AllReduceOperation::MAXNUMF:
    return NVVM::ReduxKind::MAX;
  case gpu::AllReduceOperation::AND:
    return NVVM::ReduxKind::AND;
  case gpu::AllReduceOperation::OR:
    return NVVM::ReduxKind::OR;
  case gpu::AllReduceOperation::XOR:
    return NVVM::ReduxKind::XOR;
  case gpu::AllReduceOperation::MINIMUMF:
  case gpu::AllReduceOperation::MAXIMUMF:
    return std::nullopt;
  }
  return std::nullopt;
}
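
/// Lowers a uniform, 32-bit-integer gpu.subgroup_reduce to a single
/// nvvm.redux.sync across the full warp (membermask = -1).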
struct GPUSubgroupReduceOpLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupReduceOp> {
  using ConvertOpToLLVMPattern<gpu::SubgroupReduceOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getClusterSize())
      return rewriter.notifyMatchFailure(
          op, "lowering for clustered reduce not implemented");

    if (!op.getUniform())
      return rewriter.notifyMatchFailure(
          op, "cannot be lowered to redux as the op must be run "
              "uniformly (entire subgroup).");
    if (!op.getValue().getType().isInteger(32))
      return rewriter.notifyMatchFailure(
          op, "cannot be lowered to redux as the op must be a 32 bit int");

    std::optional<NVVM::ReduxKind> mode = convertReduxKind(op.getOp());
    if (!mode.has_value())
      return rewriter.notifyMatchFailure(
          op, "unsupported reduction mode for redux");

    Location loc = op->getLoc();
    Type int32Type = IntegerType::get(rewriter.getContext(), 32);
    Value offset = rewriter.create<LLVM::ConstantOp>(loc, int32Type, -1);

    auto reduxOp = rewriter.create<NVVM::ReduxOp>(
        loc, int32Type, op.getValue(), mode.value(), offset);

    rewriter.replaceOp(op, reduxOp->getResult(0));
    return success();
  }
};
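
// Illustrative before/after (hand-written sketch, not copied from a test):
//   %sum = gpu.subgroup_reduce add %x uniform : (i32) -> i32
// becomes roughly
//   %mask = llvm.mlir.constant(-1 : i32) : i32
//   %sum  = nvvm.redux.sync add %x, %mask : i32 -> i32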
struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
  using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();

    auto valueTy = adaptor.getValue().getType();
    auto int32Type = IntegerType::get(rewriter.getContext(), 32);
    auto predTy = IntegerType::get(rewriter.getContext(), 1);

    Value one = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 1);
    Value minusOne = rewriter.create<LLVM::ConstantOp>(loc, int32Type, -1);
    Value thirtyTwo = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 32);
    Value numLeadInactiveLane = rewriter.create<LLVM::SubOp>(
        loc, int32Type, thirtyTwo, adaptor.getWidth());
    // Bit mask of active lanes: `(-1) >> (32 - activeWidth)`.
    Value activeMask = rewriter.create<LLVM::LShrOp>(loc, int32Type, minusOne,
                                                     numLeadInactiveLane);
    Value maskAndClamp;
    if (op.getMode() == gpu::ShuffleMode::UP) {
      // Clamp lane: `32 - activeWidth`.
      maskAndClamp = numLeadInactiveLane;
    } else {
      // Clamp lane: `activeWidth - 1`.
      maskAndClamp =
          rewriter.create<LLVM::SubOp>(loc, int32Type, adaptor.getWidth(), one);
    }

    bool predIsUsed = !op->getResult(1).use_empty();
    UnitAttr returnValueAndIsValidAttr = nullptr;
    Type resultTy = valueTy;
    if (predIsUsed) {
      returnValueAndIsValidAttr = rewriter.getUnitAttr();
      resultTy = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
                                                  {valueTy, predTy});
    }
    Value shfl = rewriter.create<NVVM::ShflOp>(
        loc, resultTy, activeMask, adaptor.getValue(), adaptor.getOffset(),
        maskAndClamp, convertShflKind(op.getMode()), returnValueAndIsValidAttr);
    if (predIsUsed) {
      Value shflValue = rewriter.create<LLVM::ExtractValueOp>(loc, shfl, 0);
      Value isActiveSrcLane =
          rewriter.create<LLVM::ExtractValueOp>(loc, shfl, 1);
      rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
    } else {
      rewriter.replaceOp(op, {shfl, nullptr});
    }
    return success();
  }
};
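
// Worked example of the mask arithmetic above: for a shuffle width of 16,
// numLeadInactiveLane = 32 - 16 = 16, so
// activeMask = 0xffffffff >> 16 = 0x0000ffff (only the low 16 lanes take
// part), and for a non-UP mode maskAndClamp = 16 - 1 = 15.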
struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
  using ConvertOpToLLVMPattern<gpu::LaneIdOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::LaneIdOp op, gpu::LaneIdOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    MLIRContext *context = rewriter.getContext();
    LLVM::ConstantRangeAttr bounds = nullptr;
    if (std::optional<APInt> upperBound = op.getUpperBound())
      bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
          /*bitWidth=*/32, /*lower=*/0, upperBound->getZExtValue());
    else
      bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
          /*bitWidth=*/32, /*lower=*/0, /*upper=*/32); // lane id < warp size
    Value newOp =
        rewriter.create<NVVM::LaneIdOp>(loc, rewriter.getI32Type(), bounds);
    // Truncate or extend the result depending on the index bitwidth specified
    // by the LLVMTypeConverter options.
    const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
    if (indexBitwidth > 32) {
      newOp = rewriter.create<LLVM::SExtOp>(
          loc, IntegerType::get(context, indexBitwidth), newOp);
    } else if (indexBitwidth < 32) {
      newOp = rewriter.create<LLVM::TruncOp>(
          loc, IntegerType::get(context, indexBitwidth), newOp);
    }
    rewriter.replaceOp(op, {newOp});
    return success();
  }
};
/// Lowers cf.assert to a conditional call of the device-side __assertfail.
struct AssertOpToAssertfailLowering
    : public ConvertOpToLLVMPattern<cf::AssertOp> {
  using ConvertOpToLLVMPattern<cf::AssertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(cf::AssertOp assertOp, cf::AssertOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MLIRContext *ctx = rewriter.getContext();
    Location loc = assertOp.getLoc();
    Type i8Type = rewriter.getIntegerType(8);
    Type i32Type = rewriter.getIntegerType(32);
    Type i64Type = rewriter.getIntegerType(64);
    Type ptrType = LLVM::LLVMPointerType::get(ctx);
    Type voidType = LLVM::LLVMVoidType::get(ctx);

    // Find or declare __assertfail in the surrounding GPU module.
    auto moduleOp = assertOp->getParentOfType<gpu::GPUModuleOp>();
    auto assertfailType = LLVM::LLVMFunctionType::get(
        voidType, {ptrType, ptrType, i32Type, ptrType, i64Type});
    LLVM::LLVMFuncOp assertfailDecl = getOrDefineFunction(
        moduleOp, loc, rewriter, "__assertfail", assertfailType);
    assertfailDecl.setPassthroughAttr(
        ArrayAttr::get(ctx, StringAttr::get(ctx, "noreturn")));

    // Split the block so the assert body only runs when the condition fails.
    Block *beforeBlock = assertOp->getBlock();
    Block *assertBlock =
        rewriter.splitBlock(beforeBlock, assertOp->getIterator());
    Block *afterBlock =
        rewriter.splitBlock(assertBlock, ++assertOp->getIterator());
    rewriter.setInsertionPointToEnd(beforeBlock);
    rewriter.create<cf::CondBranchOp>(loc, adaptor.getArg(), afterBlock,
                                      assertBlock);
    rewriter.setInsertionPointToEnd(assertBlock);
    rewriter.create<cf::BranchOp>(loc, afterBlock);
    rewriter.setInsertionPoint(assertOp);

    // Extract file name, line number, and function name from the location.
    StringRef fileName = "(unknown)";
    StringRef funcName = "(unknown)";
    int32_t fileLine = 0;
    while (auto callSiteLoc = dyn_cast<CallSiteLoc>(loc))
      loc = callSiteLoc.getCallee();
    if (auto fileLineColLoc = dyn_cast<FileLineColRange>(loc)) {
      fileName = fileLineColLoc.getFilename().strref();
      fileLine = fileLineColLoc.getStartLine();
    } else if (auto nameLoc = dyn_cast<NameLoc>(loc)) {
      funcName = nameLoc.getName().strref();
      if (auto fileLineColLoc =
              dyn_cast<FileLineColRange>(nameLoc.getChildLoc())) {
        fileName = fileLineColLoc.getFilename().strref();
        fileLine = fileLineColLoc.getStartLine();
      }
    }

    // Materialize the string arguments as globals and take pointers to them.
    auto getGlobal = [&](LLVM::GlobalOp global) {
      // Get a pointer to the global's first element.
      Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
          loc, LLVM::LLVMPointerType::get(ctx, global.getAddrSpace()),
          global.getSymNameAttr());
      Value start =
          rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
                                       globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
      return start;
    };
    Value assertMessage = getGlobal(getOrCreateStringConstant(
        rewriter, loc, moduleOp, i8Type, "assert_message_", assertOp.getMsg()));
    Value assertFile = getGlobal(getOrCreateStringConstant(
        rewriter, loc, moduleOp, i8Type, "assert_file_", fileName));
    Value assertFunc = getGlobal(getOrCreateStringConstant(
        rewriter, loc, moduleOp, i8Type, "assert_func_", funcName));
    Value assertLine =
        rewriter.create<LLVM::ConstantOp>(loc, i32Type, fileLine);
    Value c1 = rewriter.create<LLVM::ConstantOp>(loc, i64Type, 1);

    // Replace the assert with a call to __assertfail.
    SmallVector<Value> arguments{assertMessage, assertFile, assertLine,
                                 assertFunc, c1};
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(assertOp, assertfailDecl,
                                              arguments);
    return success();
  }
};
#include "GPUToNVVM.cpp.inc"
struct LowerGpuOpsToNVVMOpsPass
    : public impl::ConvertGpuOpsToNVVMOpsBase<LowerGpuOpsToNVVMOpsPass> {
  using Base::Base;

  void runOnOperation() override {
    gpu::GPUModuleOp m = getOperation();

    // Request C wrapper emission for the module's functions.
    for (auto func : m.getOps<func::FuncOp>()) {
      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                    UnitAttr::get(&getContext()));
    }

    // Customize the bitwidth used for the device-side index computations.
    LowerToLLVMOptions options(
        m.getContext(),
        DataLayout(cast<DataLayoutOpInterface>(m.getOperation())));
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);
    options.useBarePtrCallConv = useBarePtrCallConv;

    // Apply in-dialect GPU rewrites before the dialect conversion proper.
    {
      RewritePatternSet patterns(m.getContext());
      populateGpuRewritePatterns(patterns);
      if (failed(applyPatternsGreedily(m, std::move(patterns))))
        return signalPassFailure();
    }
void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
  target.addIllegalDialect<gpu::GPUDialect>();
  target.addIllegalOp<LLVM::CopySignOp, LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op,
                      LLVM::FAbsOp, LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FMAOp,
                      LLVM::FRemOp, LLVM::LogOp, LLVM::Log10Op, LLVM::Log2Op,
                      LLVM::PowOp, LLVM::RoundEvenOp, LLVM::RoundOp,
                      LLVM::SinOp, LLVM::SqrtOp>();
  target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp>();
}
  populateGpuMemorySpaceAttributeConversions(
      converter, [](gpu::AddressSpace space) -> unsigned {
        switch (space) {
        case gpu::AddressSpace::Global:
          return static_cast<unsigned>(
              NVVM::NVVMMemorySpace::kGlobalMemorySpace);
        case gpu::AddressSpace::Workgroup:
          return static_cast<unsigned>(
              NVVM::NVVMMemorySpace::kSharedMemorySpace);
        case gpu::AddressSpace::Private:
          return 0;
        }
        llvm_unreachable("unknown address space enum value");
        return 0;
      });
439 template <
typename OpTy>
442 StringRef f64Func, StringRef f32ApproxFunc =
"",
443 StringRef f16Func =
"") {
446 f32ApproxFunc, f16Func);
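
// OpToFuncCallLowering then selects among f32Func, f64Func, f32ApproxFunc, and
// f16Func based on the operand's element type (and, for the approximate
// variant, roughly on whether the op carries an approximate-functions
// fast-math flag), so each math op below only names its libdevice entries.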
451 patterns.add<GPUSubgroupReduceOpLowering>(converter);
  using gpu::index_lowering::IndexKind;
  using gpu::index_lowering::IntrType;
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::ThreadIdOp, NVVM::ThreadIdXOp, NVVM::ThreadIdYOp,
      NVVM::ThreadIdZOp>>(converter, IndexKind::Block, IntrType::Id);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::BlockDimOp, NVVM::BlockDimXOp, NVVM::BlockDimYOp,
      NVVM::BlockDimZOp>>(converter, IndexKind::Block, IntrType::Dim);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::ClusterIdOp, NVVM::ClusterIdXOp, NVVM::ClusterIdYOp,
      NVVM::ClusterIdZOp>>(converter, IndexKind::Other, IntrType::Id);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::ClusterDimOp, NVVM::ClusterDimXOp, NVVM::ClusterDimYOp,
      NVVM::ClusterDimZOp>>(converter, IndexKind::Other, IntrType::Dim);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::ClusterBlockIdOp, NVVM::BlockInClusterIdXOp,
      NVVM::BlockInClusterIdYOp, NVVM::BlockInClusterIdZOp>>(
      converter, IndexKind::Other, IntrType::Id);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::ClusterDimBlocksOp, NVVM::ClusterDimBlocksXOp,
      NVVM::ClusterDimBlocksYOp, NVVM::ClusterDimBlocksZOp>>(
      converter, IndexKind::Other, IntrType::Dim);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::BlockIdOp, NVVM::BlockIdXOp, NVVM::BlockIdYOp, NVVM::BlockIdZOp>>(
      converter, IndexKind::Grid, IntrType::Id);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::GridDimOp, NVVM::GridDimXOp, NVVM::GridDimYOp, NVVM::GridDimZOp>>(
      converter, IndexKind::Grid, IntrType::Dim);

  patterns.add<GPUFuncOpLowering>(
      converter,
      GPUFuncOpLoweringOptions{
          /*allocaAddrSpace=*/0,
          /*workgroupAddrSpace=*/static_cast<unsigned>(
              NVVM::NVVMMemorySpace::kSharedMemorySpace),
          StringAttr::get(&converter.getContext(),
                          NVVM::NVVMDialect::getKernelFuncAttrName()),
          StringAttr::get(&converter.getContext(),
                          NVVM::NVVMDialect::getMaxntidAttrName())});
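
  // Map arith/math ops onto libdevice. Naming convention: "__nv_foof" is the
  // f32 variant, "__nv_foo" the f64 variant, and "__nv_fast_foof" the
  // reduced-precision f32 variant used when approximate math is allowed.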
510 populateOpPatterns<arith::RemFOp>(converter,
patterns,
"__nv_fmodf",
512 populateOpPatterns<math::AbsFOp>(converter,
patterns,
"__nv_fabsf",
514 populateOpPatterns<math::AcosOp>(converter,
patterns,
"__nv_acosf",
516 populateOpPatterns<math::AcoshOp>(converter,
patterns,
"__nv_acoshf",
518 populateOpPatterns<math::AsinOp>(converter,
patterns,
"__nv_asinf",
520 populateOpPatterns<math::AsinhOp>(converter,
patterns,
"__nv_asinhf",
522 populateOpPatterns<math::AtanOp>(converter,
patterns,
"__nv_atanf",
524 populateOpPatterns<math::Atan2Op>(converter,
patterns,
"__nv_atan2f",
526 populateOpPatterns<math::AtanhOp>(converter,
patterns,
"__nv_atanhf",
528 populateOpPatterns<math::CbrtOp>(converter,
patterns,
"__nv_cbrtf",
530 populateOpPatterns<math::CeilOp>(converter,
patterns,
"__nv_ceilf",
532 populateOpPatterns<math::CopySignOp>(converter,
patterns,
"__nv_copysignf",
534 populateOpPatterns<math::CosOp>(converter,
patterns,
"__nv_cosf",
"__nv_cos",
536 populateOpPatterns<math::CoshOp>(converter,
patterns,
"__nv_coshf",
538 populateOpPatterns<math::ErfOp>(converter,
patterns,
"__nv_erff",
"__nv_erf");
539 populateOpPatterns<math::ExpOp>(converter,
patterns,
"__nv_expf",
"__nv_exp",
541 populateOpPatterns<math::Exp2Op>(converter,
patterns,
"__nv_exp2f",
543 populateOpPatterns<math::ExpM1Op>(converter,
patterns,
"__nv_expm1f",
545 populateOpPatterns<math::FloorOp>(converter,
patterns,
"__nv_floorf",
547 populateOpPatterns<math::FmaOp>(converter,
patterns,
"__nv_fmaf",
"__nv_fma");
548 populateOpPatterns<math::LogOp>(converter,
patterns,
"__nv_logf",
"__nv_log",
550 populateOpPatterns<math::Log10Op>(converter,
patterns,
"__nv_log10f",
551 "__nv_log10",
"__nv_fast_log10f");
552 populateOpPatterns<math::Log1pOp>(converter,
patterns,
"__nv_log1pf",
554 populateOpPatterns<math::Log2Op>(converter,
patterns,
"__nv_log2f",
555 "__nv_log2",
"__nv_fast_log2f");
556 populateOpPatterns<math::PowFOp>(converter,
patterns,
"__nv_powf",
"__nv_pow",
558 populateOpPatterns<math::RoundOp>(converter,
patterns,
"__nv_roundf",
560 populateOpPatterns<math::RoundEvenOp>(converter,
patterns,
"__nv_rintf",
562 populateOpPatterns<math::RsqrtOp>(converter,
patterns,
"__nv_rsqrtf",
564 populateOpPatterns<math::SinOp>(converter,
patterns,
"__nv_sinf",
"__nv_sin",
566 populateOpPatterns<math::SinhOp>(converter,
patterns,
"__nv_sinhf",
568 populateOpPatterns<math::SqrtOp>(converter,
patterns,
"__nv_sqrtf",
570 populateOpPatterns<math::TanOp>(converter,
patterns,
"__nv_tanf",
"__nv_tan",
572 populateOpPatterns<math::TanhOp>(converter,
patterns,
"__nv_tanhf",
struct NVVMTargetConvertToLLVMAttrInterface
    : public ConvertToLLVMAttrInterface::ExternalModel<
          NVVMTargetConvertToLLVMAttrInterface, NVVM::NVVMTargetAttr> {
  void populateConvertToLLVMConversionPatterns(
      Attribute attr, ConversionTarget &target,
      LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) const;
};

void NVVMTargetConvertToLLVMAttrInterface::
    populateConvertToLLVMConversionPatterns(Attribute attr,
                                            ConversionTarget &target,
                                            LLVMTypeConverter &typeConverter,
                                            RewritePatternSet &patterns) const {
  configureGpuToNVVMConversionLegality(target);
  configureGpuToNVVMTypeConverter(typeConverter);
  populateGpuToNVVMConversionPatterns(typeConverter, patterns);
}

void mlir::registerConvertGpuToNVVMInterface(DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, NVVM::NVVMDialect *dialect) {
    NVVMTargetAttr::attachInterface<NVVMTargetConvertToLLVMAttrInterface>(*ctx);
  });
}
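
// Registration sketch: tools opt in before building their context, e.g.
//   DialectRegistry registry;
//   registerConvertGpuToNVVMInterface(registry);
//   MLIRContext context(registry);
// so the generic convert-to-llvm machinery can discover the NVVM lowering via
// the target attribute.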