42#define GEN_PASS_DEF_CONVERTGPUOPSTONVVMOPS
43#include "mlir/Conversion/Passes.h.inc"
51static NVVM::ShflKind convertShflKind(gpu::ShuffleMode mode) {
53 case gpu::ShuffleMode::XOR:
54 return NVVM::ShflKind::bfly;
55 case gpu::ShuffleMode::UP:
56 return NVVM::ShflKind::up;
57 case gpu::ShuffleMode::DOWN:
58 return NVVM::ShflKind::down;
59 case gpu::ShuffleMode::IDX:
60 return NVVM::ShflKind::idx;
62 llvm_unreachable(
"unknown shuffle mode");
65static std::optional<NVVM::ReductionKind>
66convertToNVVMReductionKind(gpu::AllReduceOperation mode) {
68 case gpu::AllReduceOperation::ADD:
69 return NVVM::ReductionKind::ADD;
70 case gpu::AllReduceOperation::MUL:
72 case gpu::AllReduceOperation::MINSI:
73 return NVVM::ReductionKind::MIN;
74 case gpu::AllReduceOperation::MINUI:
76 case gpu::AllReduceOperation::MINNUMF:
77 return NVVM::ReductionKind::MIN;
78 case gpu::AllReduceOperation::MAXSI:
79 return NVVM::ReductionKind::MAX;
80 case gpu::AllReduceOperation::MAXUI:
82 case gpu::AllReduceOperation::MAXNUMF:
83 return NVVM::ReductionKind::MAX;
84 case gpu::AllReduceOperation::AND:
85 return NVVM::ReductionKind::AND;
86 case gpu::AllReduceOperation::OR:
87 return NVVM::ReductionKind::OR;
88 case gpu::AllReduceOperation::XOR:
89 return NVVM::ReductionKind::XOR;
90 case gpu::AllReduceOperation::MINIMUMF:
91 case gpu::AllReduceOperation::MAXIMUMF:
99struct GPUSubgroupReduceOpLowering
101 using ConvertOpToLLVMPattern<gpu::SubgroupReduceOp>::ConvertOpToLLVMPattern;
104 matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
105 ConversionPatternRewriter &rewriter)
const override {
106 if (op.getClusterSize())
107 return rewriter.notifyMatchFailure(
108 op,
"lowering for clustered reduce not implemented");
110 if (!op.getUniform())
111 return rewriter.notifyMatchFailure(
112 op,
"cannot be lowered to redux as the op must be run "
113 "uniformly (entire subgroup).");
114 if (!op.getValue().getType().isInteger(32))
115 return rewriter.notifyMatchFailure(op,
"unsupported data type");
117 std::optional<NVVM::ReductionKind> mode =
118 convertToNVVMReductionKind(op.getOp());
119 if (!mode.has_value())
120 return rewriter.notifyMatchFailure(
121 op,
"unsupported reduction mode for redux");
123 Location loc = op->getLoc();
124 auto int32Type = IntegerType::get(rewriter.getContext(), 32);
125 Value offset = LLVM::ConstantOp::create(rewriter, loc, int32Type, -1);
127 auto reduxOp = NVVM::ReduxOp::create(rewriter, loc, int32Type,
128 op.getValue(), mode.value(), offset);
130 rewriter.replaceOp(op, reduxOp->getResult(0));
136 using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern;
157 matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
158 ConversionPatternRewriter &rewriter)
const override {
159 Location loc = op->getLoc();
161 auto valueTy = adaptor.getValue().getType();
162 auto int32Type = IntegerType::get(rewriter.getContext(), 32);
163 auto predTy = IntegerType::get(rewriter.getContext(), 1);
165 Value one = LLVM::ConstantOp::create(rewriter, loc, int32Type, 1);
166 Value minusOne = LLVM::ConstantOp::create(rewriter, loc, int32Type, -1);
167 Value thirtyTwo = LLVM::ConstantOp::create(rewriter, loc, int32Type, 32);
168 Value numLeadInactiveLane = LLVM::SubOp::create(
169 rewriter, loc, int32Type, thirtyTwo, adaptor.getWidth());
171 Value activeMask = LLVM::LShrOp::create(rewriter, loc, int32Type, minusOne,
172 numLeadInactiveLane);
174 if (op.getMode() == gpu::ShuffleMode::UP) {
176 maskAndClamp = numLeadInactiveLane;
179 maskAndClamp = LLVM::SubOp::create(rewriter, loc, int32Type,
180 adaptor.getWidth(), one);
183 bool predIsUsed = !op->getResult(1).use_empty();
184 UnitAttr returnValueAndIsValidAttr =
nullptr;
185 Type resultTy = valueTy;
187 returnValueAndIsValidAttr = rewriter.getUnitAttr();
188 resultTy = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
191 Value shfl = NVVM::ShflOp::create(
192 rewriter, loc, resultTy, activeMask, adaptor.getValue(),
193 adaptor.getOffset(), maskAndClamp, convertShflKind(op.getMode()),
194 returnValueAndIsValidAttr);
196 Value shflValue = LLVM::ExtractValueOp::create(rewriter, loc, shfl, 0);
197 Value isActiveSrcLane =
198 LLVM::ExtractValueOp::create(rewriter, loc, shfl, 1);
199 rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
201 rewriter.replaceOp(op, {shfl,
nullptr});
208 using ConvertOpToLLVMPattern<gpu::LaneIdOp>::ConvertOpToLLVMPattern;
211 matchAndRewrite(gpu::LaneIdOp op, gpu::LaneIdOp::Adaptor adaptor,
212 ConversionPatternRewriter &rewriter)
const override {
213 auto loc = op->getLoc();
214 MLIRContext *context = rewriter.getContext();
215 LLVM::ConstantRangeAttr bounds =
nullptr;
216 if (std::optional<APInt> upperBound = op.getUpperBound())
217 bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
218 32, 0, upperBound->getZExtValue());
220 bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
223 NVVM::LaneIdOp::create(rewriter, loc, rewriter.getI32Type(), bounds);
226 const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
227 if (indexBitwidth > 32) {
228 newOp = LLVM::SExtOp::create(
229 rewriter, loc, IntegerType::get(context, indexBitwidth), newOp);
230 }
else if (indexBitwidth < 32) {
231 newOp = LLVM::TruncOp::create(
232 rewriter, loc, IntegerType::get(context, indexBitwidth), newOp);
234 rewriter.replaceOp(op, {newOp});
240 using ConvertOpToLLVMPattern<gpu::BallotOp>::ConvertOpToLLVMPattern;
243 matchAndRewrite(gpu::BallotOp op, gpu::BallotOp::Adaptor adaptor,
244 ConversionPatternRewriter &rewriter)
const override {
245 Location loc = op->getLoc();
246 auto int32Type = IntegerType::get(rewriter.getContext(), 32);
247 auto intType = cast<IntegerType>(op.getType());
248 unsigned width = intType.getWidth();
252 if (width != 32 && width != 64)
253 return rewriter.notifyMatchFailure(
254 op,
"nvvm.vote.sync ballot only supports i32 and i64 result types");
257 Value mask = LLVM::ConstantOp::create(rewriter, loc, int32Type,
258 rewriter.getI32IntegerAttr(-1));
260 auto voteKind = NVVM::VoteSyncKindAttr::get(rewriter.getContext(),
261 NVVM::VoteSyncKind::ballot);
262 Value
result = NVVM::VoteSyncOp::create(rewriter, loc, int32Type, mask,
263 adaptor.getPredicate(), voteKind);
266 result = LLVM::ZExtOp::create(rewriter, loc, op.getType(),
result);
268 rewriter.replaceOp(op,
result);
274struct AssertOpToAssertfailLowering
276 using ConvertOpToLLVMPattern<cf::AssertOp>::ConvertOpToLLVMPattern;
279 matchAndRewrite(cf::AssertOp assertOp, cf::AssertOpAdaptor adaptor,
280 ConversionPatternRewriter &rewriter)
const override {
281 MLIRContext *ctx = rewriter.getContext();
282 Location loc = assertOp.getLoc();
283 Type i8Type = typeConverter->convertType(rewriter.getIntegerType(8));
284 Type i32Type = typeConverter->convertType(rewriter.getIntegerType(32));
285 Type i64Type = typeConverter->convertType(rewriter.getIntegerType(64));
286 Type ptrType = LLVM::LLVMPointerType::get(ctx);
287 Type voidType = LLVM::LLVMVoidType::get(ctx);
290 auto moduleOp = assertOp->getParentOfType<gpu::GPUModuleOp>();
291 auto assertfailType = LLVM::LLVMFunctionType::get(
292 voidType, {ptrType, ptrType, i32Type, ptrType, i64Type});
294 moduleOp, loc, rewriter,
"__assertfail", assertfailType);
295 assertfailDecl.setPassthroughAttr(
296 ArrayAttr::get(ctx, StringAttr::get(ctx,
"noreturn")));
307 Block *beforeBlock = assertOp->getBlock();
309 rewriter.splitBlock(beforeBlock, assertOp->getIterator());
311 rewriter.splitBlock(assertBlock, ++assertOp->getIterator());
312 rewriter.setInsertionPointToEnd(beforeBlock);
313 cf::CondBranchOp::create(rewriter, loc, adaptor.getArg(), afterBlock,
315 rewriter.setInsertionPointToEnd(assertBlock);
316 cf::BranchOp::create(rewriter, loc, afterBlock);
319 rewriter.setInsertionPoint(assertOp);
323 StringRef fileName =
"(unknown)";
324 StringRef funcName =
"(unknown)";
325 int32_t fileLine = 0;
326 while (
auto callSiteLoc = dyn_cast<CallSiteLoc>(loc))
327 loc = callSiteLoc.getCallee();
328 if (
auto fileLineColLoc = dyn_cast<FileLineColRange>(loc)) {
329 fileName = fileLineColLoc.getFilename().strref();
330 fileLine = fileLineColLoc.getStartLine();
331 }
else if (
auto nameLoc = dyn_cast<NameLoc>(loc)) {
332 funcName = nameLoc.getName().strref();
333 if (
auto fileLineColLoc =
334 dyn_cast<FileLineColRange>(nameLoc.getChildLoc())) {
335 fileName = fileLineColLoc.getFilename().strref();
336 fileLine = fileLineColLoc.getStartLine();
341 auto getGlobal = [&](LLVM::GlobalOp global) {
343 Value globalPtr = LLVM::AddressOfOp::create(
344 rewriter, loc, LLVM::LLVMPointerType::get(ctx, global.getAddrSpace()),
345 global.getSymNameAttr());
347 LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(),
348 globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
352 rewriter, loc, moduleOp, i8Type,
"assert_message_", assertOp.getMsg()));
354 rewriter, loc, moduleOp, i8Type,
"assert_file_", fileName));
356 rewriter, loc, moduleOp, i8Type,
"assert_func_", funcName));
358 LLVM::ConstantOp::create(rewriter, loc, i32Type, fileLine);
359 Value c1 = LLVM::ConstantOp::create(rewriter, loc, i64Type, 1);
362 SmallVector<Value> arguments{assertMessage, assertFile, assertLine,
364 rewriter.replaceOpWithNewOp<LLVM::CallOp>(assertOp, assertfailDecl,
371#include "GPUToNVVM.cpp.inc"
// Pass that lowers the contents of a gpu.module to the NVVM/LLVM dialects.
// NOTE(review): this region was mangled during extraction — original source
// line numbers are fused into the text and many statements are missing.
// Lines are kept verbatim; comments are hedged where code is absent.
378struct LowerGpuOpsToNVVMOpsPass final
379 :
public impl::ConvertGpuOpsToNVVMOpsBase<LowerGpuOpsToNVVMOpsPass> {
// Delegates dependent-dialect registration to the generated base class.
// (The `®istry` token is an HTML-entity mangling of `&registry` — TODO fix
// when restoring the file.)
382 void getDependentDialects(DialectRegistry &registry)
const override {
383 Base::getDependentDialects(registry);
// Main entry point, run once per gpu.module.
387 void runOnOperation()
override {
388 gpu::GPUModuleOp m = getOperation();
// Tags each func.func with the LLVM emit-C-wrapper attribute (the rest of
// the attribute value is in dropped text — TODO confirm against upstream).
391 for (
auto func : m.getOps<func::FuncOp>()) {
392 func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
// Lowering options: index bitwidth and bare-pointer calling convention are
// taken from the pass options; the data layout comes from the module.
399 DataLayout(cast<DataLayoutOpInterface>(m.getOperation())));
401 options.overrideIndexBitwidth(indexBitwidth);
402 options.useBarePtrCallConv = useBarePtrCallConv;
// Pre-conversion GPU-dialect rewrites (vector unrolling is visible here);
// a failure below signals pass failure.
408 RewritePatternSet patterns(m.getContext());
412 vector::populateVectorFromElementsUnrollPatterns(patterns);
414 return signalPassFailure();
// Type converter configured with the options computed above.
417 LLVMTypeConverter converter(m.getContext(),
options);
419 RewritePatternSet llvmPatterns(m.getContext());
// Optional allow-list of dialects contributing conversion patterns; an
// empty set means all loaded dialects are considered.
426 llvm::SmallDenseSet<StringRef> allowedDialectsSet(allowedDialects.begin(),
427 allowedDialects.end());
428 for (Dialect *dialect :
getContext().getLoadedDialects()) {
// Math dialect is special-cased (the skipped action is in dropped text —
// presumably handled by NVVM libdevice patterns; TODO confirm).
430 if (isa<math::MathDialect>(dialect))
433 bool allowed = allowedDialectsSet.contains(dialect->getNamespace());
435 if (!allowedDialectsSet.empty() && !allowed)
// Allow-listed dialects must implement ConvertToLLVMPatternInterface;
// otherwise the pass emits the diagnostic below and fails.
438 auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
444 <<
"dialect does not implement ConvertToLLVMPatternInterface: "
445 << dialect->getNamespace();
446 return signalPassFailure();
451 iface->populateConvertToLLVMConversionPatterns(
target, converter,
// Partial conversion; pattern rollback is controlled by a pass option.
459 ConversionConfig config;
460 config.allowPatternRollback = allowPatternRollback;
462 applyPartialConversion(m,
target, std::move(llvmPatterns), config)))
470 target.addIllegalOp<func::FuncOp>();
471 target.addIllegalOp<cf::AssertOp>();
472 target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
473 target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
474 target.addIllegalDialect<gpu::GPUDialect>();
475 target.addIllegalOp<LLVM::CopySignOp, LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op,
476 LLVM::FAbsOp, LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp,
477 LLVM::LogOp, LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp,
478 LLVM::RoundEvenOp, LLVM::RoundOp, LLVM::SinOp,
479 LLVM::SincosOp, LLVM::SqrtOp>();
482 target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp>();
497 patterns.
add<GPUSubgroupReduceOpLowering>(converter, benefit);
// Interior of populateGpuToNVVMConversionPatterns (its header and the
// opening line of each `patterns.add<...>` call were dropped by extraction;
// lines are kept verbatim).
// First the TableGen-generated GPU->NVVM patterns.
507 populateWithGenerated(patterns);
// gpu.thread_id -> NVVM ThreadId{X,Y,Z} ops (IndexKind::Block, Id).
513 NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>>(
514 converter, IndexKind::Block, IntrType::Id, benefit);
// gpu.block_dim -> NVVM BlockDim{X,Y,Z} ops.
517 NVVM::BlockDimYOp, NVVM::BlockDimZOp>>(
518 converter, IndexKind::Block, IntrType::Dim, benefit);
// gpu.cluster_id -> NVVM ClusterId{X,Y,Z} ops.
521 NVVM::ClusterIdYOp, NVVM::ClusterIdZOp>>(
522 converter, IndexKind::Other, IntrType::Id, benefit);
// gpu.cluster_dim -> NVVM ClusterDim{X,Y,Z} ops.
524 gpu::ClusterDimOp, NVVM::ClusterDimXOp, NVVM::ClusterDimYOp,
525 NVVM::ClusterDimZOp>>(converter, IndexKind::Other, IntrType::Dim,
// gpu.cluster_block_id -> NVVM BlockInClusterId{X,Y,Z} ops.
528 gpu::ClusterBlockIdOp, NVVM::BlockInClusterIdXOp,
529 NVVM::BlockInClusterIdYOp, NVVM::BlockInClusterIdZOp>>(
530 converter, IndexKind::Cluster, IntrType::Id, benefit);
// gpu.cluster_dim_blocks -> NVVM ClusterDimBlocks{X,Y,Z} ops.
532 gpu::ClusterDimBlocksOp, NVVM::ClusterDimBlocksXOp,
533 NVVM::ClusterDimBlocksYOp, NVVM::ClusterDimBlocksZOp>>(
534 converter, IndexKind::Cluster, IntrType::Dim, benefit);
// gpu.block_id -> NVVM BlockId{X,Y,Z}; gpu.grid_dim -> NVVM GridDim{X,Y,Z}.
536 gpu::BlockIdOp, NVVM::BlockIdXOp, NVVM::BlockIdYOp, NVVM::BlockIdZOp>>(
537 converter, IndexKind::Grid, IntrType::Id, benefit);
539 gpu::GridDimOp, NVVM::GridDimXOp, NVVM::GridDimYOp, NVVM::GridDimZOp>>(
540 converter, IndexKind::Grid, IntrType::Dim, benefit);
// Warp-level patterns defined earlier in this file.
541 patterns.
add<GPULaneIdOpToNVVM, GPUBallotOpToNVVM, GPUShuffleOpLowering,
// Remainder (mostly dropped): appears to configure the gpu.func lowering —
// shared-memory address space plus the NVVM kernel / maxntid / cluster-dim
// attribute names. TODO confirm against upstream.
555 static_cast<unsigned>(NVVM::NVVMMemorySpace::Shared),
557 NVVM::NVVMDialect::getKernelFuncAttrName()),
559 NVVM::NVVMDialect::getMaxntidAttrName()),
561 NVVM::NVVMDialect::getClusterDimAttrName())},
// External model that attaches ConvertToLLVMAttrInterface to
// NVVM::NVVMTargetAttr so an NVVM target attribute can populate the generic
// to-LLVM conversion. NOTE(review): declaration and definition below are
// fragmentary (parameter lists and bodies dropped by extraction); lines are
// kept verbatim.
572struct NVVMTargetConvertToLLVMAttrInterface
573 :
public ConvertToLLVMAttrInterface::ExternalModel<
574 NVVMTargetConvertToLLVMAttrInterface, NVVM::NVVMTargetAttr> {
// Interface hook; the rest of the parameter list is in dropped text.
576 void populateConvertToLLVMConversionPatterns(
// Out-of-line definition of the hook above (body dropped by extraction).
582void NVVMTargetConvertToLLVMAttrInterface::
583 populateConvertToLLVMConversionPatterns(
Attribute attr,
// Attaches the external model to NVVMTargetAttr in the given context —
// presumably inside registerConvertGpuToNVVMInterface; TODO confirm.
594 NVVMTargetAttr::attachInterface<NVVMTargetConvertToLLVMAttrInterface>(*ctx);
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
MLIRContext & getContext() const
Returns the MLIR context.
MLIRContext is the top-level object for a collection of MLIR operations.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) upwards.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
constexpr int kSharedMemoryAlignmentBit
void registerConvertGpuToNVVMInterface(DialectRegistry ®istry)
Registers the ConvertToLLVMAttrInterface interface on the NVVM::NVVMTargetAttr attribute.
gpu::DimensionKind IndexKind
void populateCommonGPUTypeAndAttributeConversions(TypeConverter &typeConverter)
Remap common GPU memory spaces (Workgroup, Private, etc) to LLVM address spaces.
Include the generated interface declarations.
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth from the data layout.
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest-benefit patterns until a fixed point is reached.
void populateGpuRewritePatterns(RewritePatternSet &patterns)
Collect all patterns to rewrite ops within the GPU dialect.
Type convertMMAToLLVMType(gpu::MMAMatrixType type)
Return the LLVMStructureType corresponding to the MMAMatrixType type.
LLVM::LLVMFuncOp getOrDefineFunction(Operation *moduleOp, Location loc, OpBuilder &b, StringRef name, LLVM::LLVMFunctionType type)
Note that these functions don't take a SymbolTable because GPU module lowerings can have name collisions.
void configureGpuToNVVMTypeConverter(LLVMTypeConverter &converter)
Configure the LLVM type convert to convert types and address spaces from the GPU dialect to NVVM.
void configureGpuToNVVMConversionLegality(ConversionTarget &target)
Configure target to convert from the GPU dialect to NVVM.
void registerConvertToLLVMDependentDialectLoading(DialectRegistry ®istry)
Register the extension that will load dependent dialects for LLVM conversion.
void populateGpuSubgroupReduceOpLoweringPattern(const LLVMTypeConverter &converter, RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate GpuSubgroupReduce pattern to NVVM.
void populateGpuToNVVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of patterns to convert from the GPU dialect to NVVM.
LLVM::GlobalOp getOrCreateStringConstant(OpBuilder &b, Location loc, Operation *moduleOp, Type llvmI8, StringRef namePrefix, StringRef str, uint64_t alignment=0, unsigned addrSpace=0)
Create a global that contains the given string.
void populateLibDeviceConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the given list with patterns that convert from Math to NVVM libdevice calls.
void populateGpuWMMAToNVVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of patterns to convert WMMA ops from GPU dialect to NVVM.
Lowering for gpu.dynamic.shared.memory to LLVM dialect.
Lowering of gpu.printf to a vprintf standard library.