42#define GEN_PASS_DEF_CONVERTGPUOPSTONVVMOPS
43#include "mlir/Conversion/Passes.h.inc"
51static NVVM::ShflKind convertShflKind(gpu::ShuffleMode mode) {
53 case gpu::ShuffleMode::XOR:
54 return NVVM::ShflKind::bfly;
55 case gpu::ShuffleMode::UP:
56 return NVVM::ShflKind::up;
57 case gpu::ShuffleMode::DOWN:
58 return NVVM::ShflKind::down;
59 case gpu::ShuffleMode::IDX:
60 return NVVM::ShflKind::idx;
62 llvm_unreachable(
"unknown shuffle mode");
65static std::optional<NVVM::ReductionKind>
66convertToNVVMReductionKind(gpu::AllReduceOperation mode) {
68 case gpu::AllReduceOperation::ADD:
69 return NVVM::ReductionKind::ADD;
70 case gpu::AllReduceOperation::MUL:
72 case gpu::AllReduceOperation::MINSI:
73 return NVVM::ReductionKind::MIN;
74 case gpu::AllReduceOperation::MINUI:
76 case gpu::AllReduceOperation::MINNUMF:
77 return NVVM::ReductionKind::MIN;
78 case gpu::AllReduceOperation::MAXSI:
79 return NVVM::ReductionKind::MAX;
80 case gpu::AllReduceOperation::MAXUI:
82 case gpu::AllReduceOperation::MAXNUMF:
83 return NVVM::ReductionKind::MAX;
84 case gpu::AllReduceOperation::AND:
85 return NVVM::ReductionKind::AND;
86 case gpu::AllReduceOperation::OR:
87 return NVVM::ReductionKind::OR;
88 case gpu::AllReduceOperation::XOR:
89 return NVVM::ReductionKind::XOR;
90 case gpu::AllReduceOperation::MINIMUMF:
91 case gpu::AllReduceOperation::MAXIMUMF:
99struct GPUSubgroupReduceOpLowering
101 using ConvertOpToLLVMPattern<gpu::SubgroupReduceOp>::ConvertOpToLLVMPattern;
104 matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
105 ConversionPatternRewriter &rewriter)
const override {
106 if (op.getClusterSize())
107 return rewriter.notifyMatchFailure(
108 op,
"lowering for clustered reduce not implemented");
110 if (!op.getUniform())
111 return rewriter.notifyMatchFailure(
112 op,
"cannot be lowered to redux as the op must be run "
113 "uniformly (entire subgroup).");
114 if (!op.getValue().getType().isInteger(32))
115 return rewriter.notifyMatchFailure(op,
"unsupported data type");
117 std::optional<NVVM::ReductionKind> mode =
118 convertToNVVMReductionKind(op.getOp());
119 if (!mode.has_value())
120 return rewriter.notifyMatchFailure(
121 op,
"unsupported reduction mode for redux");
123 Location loc = op->getLoc();
124 auto int32Type = IntegerType::get(rewriter.getContext(), 32);
125 Value offset = LLVM::ConstantOp::create(rewriter, loc, int32Type, -1);
127 auto reduxOp = NVVM::ReduxOp::create(rewriter, loc, int32Type,
128 op.getValue(), mode.value(), offset);
130 rewriter.replaceOp(op, reduxOp->getResult(0));
136 using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern;
157 matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
158 ConversionPatternRewriter &rewriter)
const override {
159 Location loc = op->getLoc();
161 auto valueTy = adaptor.getValue().getType();
162 auto int32Type = IntegerType::get(rewriter.getContext(), 32);
163 auto predTy = IntegerType::get(rewriter.getContext(), 1);
165 Value one = LLVM::ConstantOp::create(rewriter, loc, int32Type, 1);
166 Value minusOne = LLVM::ConstantOp::create(rewriter, loc, int32Type, -1);
167 Value thirtyTwo = LLVM::ConstantOp::create(rewriter, loc, int32Type, 32);
168 Value numLeadInactiveLane = LLVM::SubOp::create(
169 rewriter, loc, int32Type, thirtyTwo, adaptor.getWidth());
171 Value activeMask = LLVM::LShrOp::create(rewriter, loc, int32Type, minusOne,
172 numLeadInactiveLane);
174 if (op.getMode() == gpu::ShuffleMode::UP) {
176 maskAndClamp = numLeadInactiveLane;
179 maskAndClamp = LLVM::SubOp::create(rewriter, loc, int32Type,
180 adaptor.getWidth(), one);
183 bool predIsUsed = !op->getResult(1).use_empty();
184 UnitAttr returnValueAndIsValidAttr =
nullptr;
185 Type resultTy = valueTy;
187 returnValueAndIsValidAttr = rewriter.getUnitAttr();
188 resultTy = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
191 Value shfl = NVVM::ShflOp::create(
192 rewriter, loc, resultTy, activeMask, adaptor.getValue(),
193 adaptor.getOffset(), maskAndClamp, convertShflKind(op.getMode()),
194 returnValueAndIsValidAttr);
196 Value shflValue = LLVM::ExtractValueOp::create(rewriter, loc, shfl, 0);
197 Value isActiveSrcLane =
198 LLVM::ExtractValueOp::create(rewriter, loc, shfl, 1);
199 rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
201 rewriter.replaceOp(op, {shfl,
nullptr});
208 using ConvertOpToLLVMPattern<gpu::LaneIdOp>::ConvertOpToLLVMPattern;
211 matchAndRewrite(gpu::LaneIdOp op, gpu::LaneIdOp::Adaptor adaptor,
212 ConversionPatternRewriter &rewriter)
const override {
213 auto loc = op->getLoc();
214 MLIRContext *context = rewriter.getContext();
215 LLVM::ConstantRangeAttr bounds =
nullptr;
216 if (std::optional<APInt> upperBound = op.getUpperBound())
217 bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
218 32, 0, upperBound->getZExtValue());
220 bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
223 NVVM::LaneIdOp::create(rewriter, loc, rewriter.getI32Type(), bounds);
226 const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
227 if (indexBitwidth > 32) {
228 newOp = LLVM::SExtOp::create(
229 rewriter, loc, IntegerType::get(context, indexBitwidth), newOp);
230 }
else if (indexBitwidth < 32) {
231 newOp = LLVM::TruncOp::create(
232 rewriter, loc, IntegerType::get(context, indexBitwidth), newOp);
234 rewriter.replaceOp(op, {newOp});
240struct AssertOpToAssertfailLowering
242 using ConvertOpToLLVMPattern<cf::AssertOp>::ConvertOpToLLVMPattern;
245 matchAndRewrite(cf::AssertOp assertOp, cf::AssertOpAdaptor adaptor,
246 ConversionPatternRewriter &rewriter)
const override {
247 MLIRContext *ctx = rewriter.getContext();
248 Location loc = assertOp.getLoc();
249 Type i8Type = typeConverter->convertType(rewriter.getIntegerType(8));
250 Type i32Type = typeConverter->convertType(rewriter.getIntegerType(32));
251 Type i64Type = typeConverter->convertType(rewriter.getIntegerType(64));
252 Type ptrType = LLVM::LLVMPointerType::get(ctx);
253 Type voidType = LLVM::LLVMVoidType::get(ctx);
256 auto moduleOp = assertOp->getParentOfType<gpu::GPUModuleOp>();
257 auto assertfailType = LLVM::LLVMFunctionType::get(
258 voidType, {ptrType, ptrType, i32Type, ptrType, i64Type});
260 moduleOp, loc, rewriter,
"__assertfail", assertfailType);
261 assertfailDecl.setPassthroughAttr(
262 ArrayAttr::get(ctx, StringAttr::get(ctx,
"noreturn")));
273 Block *beforeBlock = assertOp->getBlock();
275 rewriter.splitBlock(beforeBlock, assertOp->getIterator());
277 rewriter.splitBlock(assertBlock, ++assertOp->getIterator());
278 rewriter.setInsertionPointToEnd(beforeBlock);
279 cf::CondBranchOp::create(rewriter, loc, adaptor.getArg(), afterBlock,
281 rewriter.setInsertionPointToEnd(assertBlock);
282 cf::BranchOp::create(rewriter, loc, afterBlock);
285 rewriter.setInsertionPoint(assertOp);
289 StringRef fileName =
"(unknown)";
290 StringRef funcName =
"(unknown)";
291 int32_t fileLine = 0;
292 while (
auto callSiteLoc = dyn_cast<CallSiteLoc>(loc))
293 loc = callSiteLoc.getCallee();
294 if (
auto fileLineColLoc = dyn_cast<FileLineColRange>(loc)) {
295 fileName = fileLineColLoc.getFilename().strref();
296 fileLine = fileLineColLoc.getStartLine();
297 }
else if (
auto nameLoc = dyn_cast<NameLoc>(loc)) {
298 funcName = nameLoc.getName().strref();
299 if (
auto fileLineColLoc =
300 dyn_cast<FileLineColRange>(nameLoc.getChildLoc())) {
301 fileName = fileLineColLoc.getFilename().strref();
302 fileLine = fileLineColLoc.getStartLine();
307 auto getGlobal = [&](LLVM::GlobalOp global) {
309 Value globalPtr = LLVM::AddressOfOp::create(
310 rewriter, loc, LLVM::LLVMPointerType::get(ctx, global.getAddrSpace()),
311 global.getSymNameAttr());
313 LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(),
314 globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
318 rewriter, loc, moduleOp, i8Type,
"assert_message_", assertOp.getMsg()));
320 rewriter, loc, moduleOp, i8Type,
"assert_file_", fileName));
322 rewriter, loc, moduleOp, i8Type,
"assert_func_", funcName));
324 LLVM::ConstantOp::create(rewriter, loc, i32Type, fileLine);
325 Value c1 = LLVM::ConstantOp::create(rewriter, loc, i64Type, 1);
328 SmallVector<Value> arguments{assertMessage, assertFile, assertLine,
330 rewriter.replaceOpWithNewOp<LLVM::CallOp>(assertOp, assertfailDecl,
337#include "GPUToNVVM.cpp.inc"
// NOTE(review): fragment of the GPU->NVVM lowering pass. This extraction
// dropped many interior lines (the original line numbers fused onto each
// line jump, e.g. 358 -> 365), so the code below is incomplete; it is left
// byte-identical and only annotated. Recover the full body from upstream
// before compiling.
344struct LowerGpuOpsToNVVMOpsPass final
345 :
public impl::ConvertGpuOpsToNVVMOpsBase<LowerGpuOpsToNVVMOpsPass> {
// Registers dialects produced by this lowering (delegates to the base).
348 void getDependentDialects(DialectRegistry ®istry)
const override {
349 Base::getDependentDialects(registry);
353 void runOnOperation()
override {
354 gpu::GPUModuleOp m = getOperation();
// Marks every func.func in the module for C-wrapper emission.
357 for (
auto func : m.getOps<func::FuncOp>()) {
358 func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
// Lowering options: index bitwidth presumably derived from the module's
// data layout — the lines constructing `options` were dropped; confirm.
365 DataLayout(cast<DataLayoutOpInterface>(m.getOperation())));
367 options.overrideIndexBitwidth(indexBitwidth);
368 options.useBarePtrCallConv = useBarePtrCallConv;
// Pre-lowering rewrite patterns (greedy-applied; driver lines missing).
374 RewritePatternSet patterns(m.getContext());
378 vector::populateVectorFromElementsUnrollPatterns(patterns);
380 return signalPassFailure();
383 LLVMTypeConverter converter(m.getContext(),
options);
385 RewritePatternSet llvmPatterns(m.getContext());
// Optionally restricts which loaded dialects contribute LLVM conversion
// patterns via their ConvertToLLVMPatternInterface.
392 llvm::SmallDenseSet<StringRef> allowedDialectsSet(allowedDialects.begin(),
393 allowedDialects.end());
394 for (Dialect *dialect :
getContext().getLoadedDialects()) {
// Math dialect is skipped here (handled by dedicated NVVM/libdevice
// patterns elsewhere — TODO confirm).
396 if (isa<math::MathDialect>(dialect))
399 bool allowed = allowedDialectsSet.contains(dialect->getNamespace());
401 if (!allowedDialectsSet.empty() && !allowed)
404 auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
// An explicitly-allowed dialect without the interface is a hard error.
410 <<
"dialect does not implement ConvertToLLVMPatternInterface: "
411 << dialect->getNamespace();
412 return signalPassFailure();
417 iface->populateConvertToLLVMConversionPatterns(
target, converter,
425 ConversionConfig config;
426 config.allowPatternRollback = allowPatternRollback;
// Partial conversion: ops not marked illegal may remain.
428 applyPartialConversion(m,
target, std::move(llvmPatterns), config)))
// NOTE(review): body fragment of configureGpuToNVVMConversionLegality —
// the enclosing function signature was dropped by the extraction. Declares
// which ops/dialects are legal after GPU->NVVM lowering. Left byte-identical.
436 target.addIllegalOp<func::FuncOp>();
437 target.addIllegalOp<cf::AssertOp>();
438 target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
439 target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
440 target.addIllegalDialect<gpu::GPUDialect>();
// These LLVM math intrinsics are illegal so they get rewritten to NVVM /
// libdevice equivalents instead.
441 target.addIllegalOp<LLVM::CopySignOp, LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op,
442 LLVM::FAbsOp, LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp,
443 LLVM::LogOp, LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp,
444 LLVM::RoundEvenOp, LLVM::RoundOp, LLVM::SinOp,
445 LLVM::SincosOp, LLVM::SqrtOp>();
// gpu.yield / gpu.module stay legal; they are handled by later passes.
448 target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp>();
// NOTE(review): fragments of populateGpuSubgroupReduceOpLoweringPattern and
// populateGpuToNVVMConversionPatterns. Many interior lines (the pattern
// class names opening each `add<...>` call) were dropped by the extraction;
// code left byte-identical.
463 patterns.
add<GPUSubgroupReduceOpLowering>(converter, benefit);
473 populateWithGenerated(patterns);
// gpu.thread_id -> nvvm.read.ptx.sreg.tid.{x,y,z}
479 NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>>(
480 converter, IndexKind::Block, IntrType::Id, benefit);
// gpu.block_dim -> nvvm block-dim registers
483 NVVM::BlockDimYOp, NVVM::BlockDimZOp>>(
484 converter, IndexKind::Block, IntrType::Dim, benefit);
// cluster id / dim lowerings (sm_90+ cluster support)
487 NVVM::ClusterIdYOp, NVVM::ClusterIdZOp>>(
488 converter, IndexKind::Other, IntrType::Id, benefit);
490 gpu::ClusterDimOp, NVVM::ClusterDimXOp, NVVM::ClusterDimYOp,
491 NVVM::ClusterDimZOp>>(converter, IndexKind::Other, IntrType::Dim,
494 gpu::ClusterBlockIdOp, NVVM::BlockInClusterIdXOp,
495 NVVM::BlockInClusterIdYOp, NVVM::BlockInClusterIdZOp>>(
496 converter, IndexKind::Cluster, IntrType::Id, benefit);
498 gpu::ClusterDimBlocksOp, NVVM::ClusterDimBlocksXOp,
499 NVVM::ClusterDimBlocksYOp, NVVM::ClusterDimBlocksZOp>>(
500 converter, IndexKind::Cluster, IntrType::Dim, benefit);
// gpu.block_id / gpu.grid_dim -> ctaid / nctaid registers
502 gpu::BlockIdOp, NVVM::BlockIdXOp, NVVM::BlockIdYOp, NVVM::BlockIdZOp>>(
503 converter, IndexKind::Grid, IntrType::Id, benefit);
505 gpu::GridDimOp, NVVM::GridDimXOp, NVVM::GridDimYOp, NVVM::GridDimZOp>>(
506 converter, IndexKind::Grid, IntrType::Dim, benefit);
// GPU-function lowering options: shared memory address space plus the NVVM
// attribute names for kernel, maxntid, and cluster dims (call head dropped).
521 static_cast<unsigned>(NVVM::NVVMMemorySpace::Shared),
523 NVVM::NVVMDialect::getKernelFuncAttrName()),
525 NVVM::NVVMDialect::getMaxntidAttrName()),
527 NVVM::NVVMDialect::getClusterDimAttrName())},
// NOTE(review): fragment — external model attaching ConvertToLLVMAttrInterface
// to NVVM::NVVMTargetAttr, plus its out-of-line definition and the interface
// registration call. Interior lines were dropped by the extraction; code left
// byte-identical.
538struct NVVMTargetConvertToLLVMAttrInterface
539 :
public ConvertToLLVMAttrInterface::ExternalModel<
540 NVVMTargetConvertToLLVMAttrInterface, NVVM::NVVMTargetAttr> {
542 void populateConvertToLLVMConversionPatterns(
548void NVVMTargetConvertToLLVMAttrInterface::
549 populateConvertToLLVMConversionPatterns(
Attribute attr,
// Registration (presumably inside registerConvertGpuToNVVMInterface's
// extension lambda — confirm against upstream).
560 NVVMTargetAttr::attachInterface<NVVMTargetConvertToLLVMAttrInterface>(*ctx);
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
MLIRContext & getContext() const
Returns the MLIR context.
MLIRContext is the top-level object for a collection of MLIR operations.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65535 (the most useful match possible).
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
constexpr int kSharedMemoryAlignmentBit
void registerConvertGpuToNVVMInterface(DialectRegistry ®istry)
Registers the ConvertToLLVMAttrInterface interface on the NVVM::NVVMTargetAttr attribute.
gpu::DimensionKind IndexKind
void populateCommonGPUTypeAndAttributeConversions(TypeConverter &typeConverter)
Remap common GPU memory spaces (Workgroup, Private, etc) to LLVM address spaces.
Include the generated interface declarations.
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth from the data layout.
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest-benefit patterns in a greedy worklist-driven manner until a fixpoint is reached.
void populateGpuRewritePatterns(RewritePatternSet &patterns)
Collect all patterns to rewrite ops within the GPU dialect.
Type convertMMAToLLVMType(gpu::MMAMatrixType type)
Return the LLVMStructureType corresponding to the MMAMatrixType type.
LLVM::LLVMFuncOp getOrDefineFunction(Operation *moduleOp, Location loc, OpBuilder &b, StringRef name, LLVM::LLVMFunctionType type)
Note that these functions don't take a SymbolTable because GPU module lowerings can have name collisions as an intermediate state.
void configureGpuToNVVMTypeConverter(LLVMTypeConverter &converter)
Configure the LLVM type convert to convert types and address spaces from the GPU dialect to NVVM.
void configureGpuToNVVMConversionLegality(ConversionTarget &target)
Configure target to convert from the GPU dialect to NVVM.
void registerConvertToLLVMDependentDialectLoading(DialectRegistry ®istry)
Register the extension that will load dependent dialects for LLVM conversion.
void populateGpuSubgroupReduceOpLoweringPattern(const LLVMTypeConverter &converter, RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate GpuSubgroupReduce pattern to NVVM.
void populateGpuToNVVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of patterns to convert from the GPU dialect to NVVM.
LLVM::GlobalOp getOrCreateStringConstant(OpBuilder &b, Location loc, Operation *moduleOp, Type llvmI8, StringRef namePrefix, StringRef str, uint64_t alignment=0, unsigned addrSpace=0)
Create a global that contains the given string.
void populateLibDeviceConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the given list with patterns that convert from Math to NVVM libdevice calls.
void populateGpuWMMAToNVVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of patterns to convert WMMA ops from GPU dialect to NVVM.
Lowering for gpu.dynamic.shared.memory to LLVM dialect.
Lowering of gpu.printf to a vprintf standard library.