//===- LowerGpuOpsToNVVMOps.cpp - MLIR GPU to NVVM lowering passes --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to generate NVVMIR operations for higher-level
// GPU operations.
//
//===----------------------------------------------------------------------===//

#include <optional>

namespace mlir {
#define GEN_PASS_DEF_CONVERTGPUOPSTONVVMOPS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {

/// Convert the GPU dialect shuffle mode enum to the equivalent NVVM one.
static NVVM::ShflKind convertShflKind(gpu::ShuffleMode mode) {
  switch (mode) {
  case gpu::ShuffleMode::XOR:
    return NVVM::ShflKind::bfly;
  case gpu::ShuffleMode::UP:
    return NVVM::ShflKind::up;
  case gpu::ShuffleMode::DOWN:
    return NVVM::ShflKind::down;
  case gpu::ShuffleMode::IDX:
    return NVVM::ShflKind::idx;
  }
  llvm_unreachable("unknown shuffle mode");
}

/// Convert a GPU dialect reduction kind to the equivalent NVVM redux kind, or
/// std::nullopt when redux has no equivalent (e.g. mul, unsigned min/max, and
/// the NaN-propagating float min/max variants).
static std::optional<NVVM::ReductionKind>
convertToNVVMReductionKind(gpu::AllReduceOperation mode) {
  switch (mode) {
  case gpu::AllReduceOperation::ADD:
    return NVVM::ReductionKind::ADD;
  case gpu::AllReduceOperation::MUL:
    return std::nullopt;
  case gpu::AllReduceOperation::MINSI:
    return NVVM::ReductionKind::MIN;
  case gpu::AllReduceOperation::MINUI:
    return std::nullopt;
  case gpu::AllReduceOperation::MINNUMF:
    return NVVM::ReductionKind::MIN;
  case gpu::AllReduceOperation::MAXSI:
    return NVVM::ReductionKind::MAX;
  case gpu::AllReduceOperation::MAXUI:
    return std::nullopt;
  case gpu::AllReduceOperation::MAXNUMF:
    return NVVM::ReductionKind::MAX;
  case gpu::AllReduceOperation::AND:
    return NVVM::ReductionKind::AND;
  case gpu::AllReduceOperation::OR:
    return NVVM::ReductionKind::OR;
  case gpu::AllReduceOperation::XOR:
    return NVVM::ReductionKind::XOR;
  case gpu::AllReduceOperation::MINIMUMF:
  case gpu::AllReduceOperation::MAXIMUMF:
    return std::nullopt;
  }
  return std::nullopt;
}

/// Lowers the gpu.subgroup_reduce op to the nvvm.redux op. The op must be run
/// by the entire subgroup; otherwise the behavior is undefined.
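///
/// A minimal sketch of the rewrite (illustrative; the exact textual IR may
/// differ between MLIR versions):
///
///   %0 = gpu.subgroup_reduce add %value uniform : (i32) -> i32
///
/// becomes roughly
///
///   %mask = llvm.mlir.constant(-1 : i32) : i32
///   %0 = nvvm.redux.sync add %value, %mask : i32 -> i32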
struct GPUSubgroupReduceOpLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupReduceOp> {
  using ConvertOpToLLVMPattern<gpu::SubgroupReduceOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getClusterSize())
      return rewriter.notifyMatchFailure(
          op, "lowering for clustered reduce not implemented");

    if (!op.getUniform())
      return rewriter.notifyMatchFailure(
          op, "cannot be lowered to redux as the op must be run "
              "uniformly (entire subgroup).");
    if (!op.getValue().getType().isInteger(32))
      return rewriter.notifyMatchFailure(op, "unsupported data type");

    std::optional<NVVM::ReductionKind> mode =
        convertToNVVMReductionKind(op.getOp());
    if (!mode.has_value())
      return rewriter.notifyMatchFailure(
          op, "unsupported reduction mode for redux");

    Location loc = op->getLoc();
    auto int32Type = IntegerType::get(rewriter.getContext(), 32);
    Value offset = LLVM::ConstantOp::create(rewriter, loc, int32Type, -1);

    auto reduxOp = NVVM::ReduxOp::create(rewriter, loc, int32Type,
                                         op.getValue(), mode.value(), offset);

    rewriter.replaceOp(op, reduxOp->getResult(0));
    return success();
  }
};

struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
  using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern;

  /// Lowers a shuffle to the corresponding NVVM op.
  ///
  /// Converts the `width` argument into an activeMask (a bitmask which
  /// specifies which threads participate in the shuffle) and a maskAndClamp
  /// (specifying the highest lane which participates in the shuffle):
  ///
  ///     %one = llvm.mlir.constant(1 : i32) : i32
  ///     %minus_one = llvm.mlir.constant(-1 : i32) : i32
  ///     %thirty_two = llvm.mlir.constant(32 : i32) : i32
  ///     %num_lanes = llvm.sub %thirty_two, %width : i32
  ///     %active_mask = llvm.lshr %minus_one, %num_lanes : i32
  ///     %mask_and_clamp = llvm.sub %width, %one : i32
  ///     %shfl = nvvm.shfl.sync.bfly %active_mask, %value, %offset,
  ///         %mask_and_clamp : !llvm.struct<(f32, i1)>
  ///     %shfl_value = llvm.extractvalue %shfl[0] : !llvm.struct<(f32, i1)>
  ///     %shfl_pred = llvm.extractvalue %shfl[1] : !llvm.struct<(f32, i1)>
  LogicalResult
  matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();

    auto valueTy = adaptor.getValue().getType();
    auto int32Type = IntegerType::get(rewriter.getContext(), 32);
    auto predTy = IntegerType::get(rewriter.getContext(), 1);

    Value one = LLVM::ConstantOp::create(rewriter, loc, int32Type, 1);
    Value minusOne = LLVM::ConstantOp::create(rewriter, loc, int32Type, -1);
    Value thirtyTwo = LLVM::ConstantOp::create(rewriter, loc, int32Type, 32);
    Value numLeadInactiveLane = LLVM::SubOp::create(
        rewriter, loc, int32Type, thirtyTwo, adaptor.getWidth());
    // Bit mask of active lanes: `(-1) >> (32 - activeWidth)`.
    Value activeMask = LLVM::LShrOp::create(rewriter, loc, int32Type, minusOne,
                                            numLeadInactiveLane);
    Value maskAndClamp;
    if (op.getMode() == gpu::ShuffleMode::UP) {
      // Clamp lane: `32 - activeWidth`.
      maskAndClamp = numLeadInactiveLane;
    } else {
      // Clamp lane: `activeWidth - 1`.
      maskAndClamp = LLVM::SubOp::create(rewriter, loc, int32Type,
                                         adaptor.getWidth(), one);
    }

    bool predIsUsed = !op->getResult(1).use_empty();
    UnitAttr returnValueAndIsValidAttr = nullptr;
    Type resultTy = valueTy;
    if (predIsUsed) {
      returnValueAndIsValidAttr = rewriter.getUnitAttr();
      resultTy = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
                                                  {valueTy, predTy});
    }
    Value shfl = NVVM::ShflOp::create(
        rewriter, loc, resultTy, activeMask, adaptor.getValue(),
        adaptor.getOffset(), maskAndClamp, convertShflKind(op.getMode()),
        returnValueAndIsValidAttr);
    if (predIsUsed) {
      Value shflValue = LLVM::ExtractValueOp::create(rewriter, loc, shfl, 0);
      Value isActiveSrcLane =
          LLVM::ExtractValueOp::create(rewriter, loc, shfl, 1);
      rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
    } else {
      rewriter.replaceOp(op, {shfl, nullptr});
    }
    return success();
  }
};

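/// Lowers gpu.lane_id to the NVVM laneid intrinsic, attaching a range
/// attribute so LLVM may assume the result lies in [0, 32).
///
/// A minimal sketch of the rewrite with a 64-bit index type (illustrative;
/// the exact textual IR may differ between MLIR versions):
///
///   %0 = gpu.lane_id
///
/// becomes roughly
///
///   %1 = nvvm.read.ptx.sreg.laneid range <i32, 0, 32> : i32
///   %0 = llvm.sext %1 : i32 to i64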
struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
  using ConvertOpToLLVMPattern<gpu::LaneIdOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::LaneIdOp op, gpu::LaneIdOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    MLIRContext *context = rewriter.getContext();
    LLVM::ConstantRangeAttr bounds = nullptr;
    if (std::optional<APInt> upperBound = op.getUpperBound())
      bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
          /*bitWidth=*/32, /*lower=*/0, upperBound->getZExtValue());
    else
      bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
          /*bitWidth=*/32, /*lower=*/0, /*upper=*/kWarpSize);
    Value newOp =
        NVVM::LaneIdOp::create(rewriter, loc, rewriter.getI32Type(), bounds);
    // Truncate or extend the result depending on the index bitwidth specified
    // by the LLVMTypeConverter options.
    const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
    if (indexBitwidth > 32) {
      newOp = LLVM::SExtOp::create(
          rewriter, loc, IntegerType::get(context, indexBitwidth), newOp);
    } else if (indexBitwidth < 32) {
      newOp = LLVM::TruncOp::create(
          rewriter, loc, IntegerType::get(context, indexBitwidth), newOp);
    }
    rewriter.replaceOp(op, {newOp});
    return success();
  }
};

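/// Lowers gpu.ballot to nvvm.vote.sync with the ballot kind, using a full
/// 32-bit mask so every lane of the warp participates.
///
/// A minimal sketch of the rewrite for an i64 result (illustrative; the exact
/// textual IR may differ between MLIR versions):
///
///   %0 = gpu.ballot %pred : i64
///
/// becomes roughly
///
///   %mask = llvm.mlir.constant(-1 : i32) : i32
///   %1 = nvvm.vote.sync ballot %mask, %pred -> i32
///   %0 = llvm.zext %1 : i32 to i64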
struct GPUBallotOpToNVVM : public ConvertOpToLLVMPattern<gpu::BallotOp> {
  using ConvertOpToLLVMPattern<gpu::BallotOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::BallotOp op, gpu::BallotOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    auto int32Type = IntegerType::get(rewriter.getContext(), 32);
    auto intType = cast<IntegerType>(op.getType());
    unsigned width = intType.getWidth();

    // NVVM ballot natively returns i32. For i64 results, zero-extend since
    // NVIDIA warps have exactly 32 threads, so the upper 32 bits are always
    // zero.
    if (width != 32 && width != 64)
      return rewriter.notifyMatchFailure(
          op, "nvvm.vote.sync ballot only supports i32 and i64 result types");

    // Use the full mask (-1) so all 32 lanes participate in the ballot.
    Value mask = LLVM::ConstantOp::create(rewriter, loc, int32Type,
                                          rewriter.getI32IntegerAttr(-1));

    auto voteKind = NVVM::VoteSyncKindAttr::get(rewriter.getContext(),
                                                NVVM::VoteSyncKind::ballot);
    Value result = NVVM::VoteSyncOp::create(rewriter, loc, int32Type, mask,
                                            adaptor.getPredicate(), voteKind);

    if (width == 64)
      result = LLVM::ZExtOp::create(rewriter, loc, op.getType(), result);

    rewriter.replaceOp(op, result);
    return success();
  }
};

/// Lowering of cf.assert into a conditional call to __assertfail.
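///
/// The call targets the CUDA device-side `__assertfail` routine. Its expected
/// C signature, matching the LLVM function type built below (this prototype is
/// stated for orientation, not defined in this file), is:
///
///   void __assertfail(const char *message, const char *file, unsigned line,
///                     const char *function, size_t charSize);
///
/// where the lowering always passes 1 for `charSize`.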
struct AssertOpToAssertfailLowering
    : public ConvertOpToLLVMPattern<cf::AssertOp> {
  using ConvertOpToLLVMPattern<cf::AssertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(cf::AssertOp assertOp, cf::AssertOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MLIRContext *ctx = rewriter.getContext();
    Location loc = assertOp.getLoc();
    Type i8Type = typeConverter->convertType(rewriter.getIntegerType(8));
    Type i32Type = typeConverter->convertType(rewriter.getIntegerType(32));
    Type i64Type = typeConverter->convertType(rewriter.getIntegerType(64));
    Type ptrType = LLVM::LLVMPointerType::get(ctx);
    Type voidType = LLVM::LLVMVoidType::get(ctx);

    // Find or create the __assertfail function declaration.
    auto moduleOp = assertOp->getParentOfType<gpu::GPUModuleOp>();
    auto assertfailType = LLVM::LLVMFunctionType::get(
        voidType, {ptrType, ptrType, i32Type, ptrType, i64Type});
    LLVM::LLVMFuncOp assertfailDecl = getOrDefineFunction(
        moduleOp, loc, rewriter, "__assertfail", assertfailType);
    assertfailDecl.setPassthroughAttr(
        ArrayAttr::get(ctx, StringAttr::get(ctx, "noreturn")));

    // Split blocks and insert a conditional branch.
    // ^before:
    //   ...
    //   cf.cond_br %condition, ^after, ^assert
    // ^assert:
    //   cf.assert
    //   cf.br ^after
    // ^after:
    //   ...
    Block *beforeBlock = assertOp->getBlock();
    Block *assertBlock =
        rewriter.splitBlock(beforeBlock, assertOp->getIterator());
    Block *afterBlock =
        rewriter.splitBlock(assertBlock, ++assertOp->getIterator());
    rewriter.setInsertionPointToEnd(beforeBlock);
    cf::CondBranchOp::create(rewriter, loc, adaptor.getArg(), afterBlock,
                             assertBlock);
    rewriter.setInsertionPointToEnd(assertBlock);
    cf::BranchOp::create(rewriter, loc, afterBlock);

    // Continue cf.assert lowering.
    rewriter.setInsertionPoint(assertOp);

    // Populate the file name, line number, and function name from the
    // location of the AssertOp.
    StringRef fileName = "(unknown)";
    StringRef funcName = "(unknown)";
    int32_t fileLine = 0;
    while (auto callSiteLoc = dyn_cast<CallSiteLoc>(loc))
      loc = callSiteLoc.getCallee();
    if (auto fileLineColLoc = dyn_cast<FileLineColRange>(loc)) {
      fileName = fileLineColLoc.getFilename().strref();
      fileLine = fileLineColLoc.getStartLine();
    } else if (auto nameLoc = dyn_cast<NameLoc>(loc)) {
      funcName = nameLoc.getName().strref();
      if (auto fileLineColLoc =
              dyn_cast<FileLineColRange>(nameLoc.getChildLoc())) {
        fileName = fileLineColLoc.getFilename().strref();
        fileLine = fileLineColLoc.getStartLine();
      }
    }

    // Create constants.
    auto getGlobal = [&](LLVM::GlobalOp global) {
      // Get a pointer to the global string's first element.
      Value globalPtr = LLVM::AddressOfOp::create(
          rewriter, loc,
          LLVM::LLVMPointerType::get(ctx, global.getAddrSpace()),
          global.getSymNameAttr());
      Value start =
          LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(),
                              globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
      return start;
    };
    Value assertMessage = getGlobal(getOrCreateStringConstant(
        rewriter, loc, moduleOp, i8Type, "assert_message_", assertOp.getMsg()));
    Value assertFile = getGlobal(getOrCreateStringConstant(
        rewriter, loc, moduleOp, i8Type, "assert_file_", fileName));
    Value assertFunc = getGlobal(getOrCreateStringConstant(
        rewriter, loc, moduleOp, i8Type, "assert_func_", funcName));
    Value assertLine =
        LLVM::ConstantOp::create(rewriter, loc, i32Type, fileLine);
    Value c1 = LLVM::ConstantOp::create(rewriter, loc, i64Type, 1);

    // Insert the call to __assertfail.
    SmallVector<Value> arguments{assertMessage, assertFile, assertLine,
                                 assertFunc, c1};
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(assertOp, assertfailDecl,
                                              arguments);
    return success();
  }
};

/// Import the GPU Ops to NVVM Patterns.
#include "GPUToNVVM.cpp.inc"

/// A pass that replaces all occurrences of GPU device operations with their
/// corresponding NVVM equivalent.
///
/// This pass only handles device code and is not meant to be run on GPU host
/// code.
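///
/// Assuming the pass is registered under the upstream `convert-gpu-to-nvvm`
/// name and anchored on gpu.module (as in this file), it can be driven from
/// the command line, e.g. on an illustrative input file `kernels.mlir`:
///
///   mlir-opt \
///       --pass-pipeline='builtin.module(gpu.module(convert-gpu-to-nvvm))' \
///       kernels.mlir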
struct LowerGpuOpsToNVVMOpsPass final
    : public impl::ConvertGpuOpsToNVVMOpsBase<LowerGpuOpsToNVVMOpsPass> {
  using Base::Base;

  void getDependentDialects(DialectRegistry &registry) const override {
    Base::getDependentDialects(registry);
    registerConvertToLLVMDependentDialectLoading(registry);
  }

  void runOnOperation() override {
    gpu::GPUModuleOp m = getOperation();

    // Request C wrapper emission.
    for (auto func : m.getOps<func::FuncOp>()) {
      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                    UnitAttr::get(&getContext()));
    }

    // Customize the bitwidth used for the device side index computations.
    LowerToLLVMOptions options(
        m.getContext(),
        DataLayout(cast<DataLayoutOpInterface>(m.getOperation())));
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);
    options.useBarePtrCallConv = useBarePtrCallConv;

    // Apply in-dialect lowering first. In-dialect lowering replaces ops with
    // other ops that still need further lowering, which a single conversion
    // pass does not support.
    {
      RewritePatternSet patterns(m.getContext());
      populateGpuRewritePatterns(patterns);
      // Transform N-D vector.from_elements into 1-D vector.from_elements
      // before conversion.
      vector::populateVectorFromElementsUnrollPatterns(patterns);
      if (failed(applyPatternsGreedily(m, std::move(patterns))))
        return signalPassFailure();
    }

    LLVMTypeConverter converter(m.getContext(), options);
    configureGpuToNVVMTypeConverter(converter);
    RewritePatternSet llvmPatterns(m.getContext());
    LLVMConversionTarget target(getContext());

    // Set a higher benefit so these patterns run before the generic LLVM
    // lowering patterns.
    populateGpuToNVVMConversionPatterns(converter, llvmPatterns,
                                        /*benefit=*/10);

    llvm::SmallDenseSet<StringRef> allowedDialectsSet(allowedDialects.begin(),
                                                      allowedDialects.end());
    for (Dialect *dialect : getContext().getLoadedDialects()) {
      // Skip math patterns as NVVM needs custom math lowering.
      if (isa<math::MathDialect>(dialect))
        continue;

      bool allowed = allowedDialectsSet.contains(dialect->getNamespace());
      // An empty `allowedDialectsSet` means all dialects are allowed.
      if (!allowedDialectsSet.empty() && !allowed)
        continue;

      auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
      if (!iface) {
        // Error out if the dialect was explicitly specified but does not
        // implement the conversion interface.
        if (allowed) {
          m.emitError()
              << "dialect does not implement ConvertToLLVMPatternInterface: "
              << dialect->getNamespace();
          return signalPassFailure();
        }
        continue;
      }

      iface->populateConvertToLLVMConversionPatterns(target, converter,
                                                     llvmPatterns);
    }

    populateGpuWMMAToNVVMConversionPatterns(converter, llvmPatterns);
    if (this->hasRedux)
      populateGpuSubgroupReduceOpLoweringPattern(converter, llvmPatterns);
    configureGpuToNVVMConversionLegality(target);
    ConversionConfig config;
    config.allowPatternRollback = allowPatternRollback;
    if (failed(
            applyPartialConversion(m, target, std::move(llvmPatterns), config)))
      signalPassFailure();
  }
};

} // namespace

void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
  target.addIllegalOp<func::FuncOp>();
  target.addIllegalOp<cf::AssertOp>();
  target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
  target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
  target.addIllegalDialect<gpu::GPUDialect>();
  target.addIllegalOp<LLVM::CopySignOp, LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op,
                      LLVM::FAbsOp, LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp,
                      LLVM::LogOp, LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp,
                      LLVM::RoundEvenOp, LLVM::RoundOp, LLVM::SinOp,
                      LLVM::SincosOp, LLVM::SqrtOp>();

  // TODO: Remove once we support replacing non-root ops.
  target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp>();
}

void mlir::configureGpuToNVVMTypeConverter(LLVMTypeConverter &converter) {
  populateCommonGPUTypeAndAttributeConversions(converter);

  // Lowering for MMAMatrixType.
  converter.addConversion([&](gpu::MMAMatrixType type) -> Type {
    return convertMMAToLLVMType(type);
  });
}

void mlir::populateGpuSubgroupReduceOpLoweringPattern(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
    PatternBenefit benefit) {
  patterns.add<GPUSubgroupReduceOpLowering>(converter, benefit);
}

void mlir::populateGpuToNVVMConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
    PatternBenefit benefit) {
  using gpu::index_lowering::IndexKind;
  using gpu::index_lowering::IntrType;

  // TODO: Pass benefit to generated patterns.
  populateWithGenerated(patterns);

  patterns.add<GPUPrintfOpToVPrintfLowering, AssertOpToAssertfailLowering>(
      converter, benefit);
  patterns.add<
      gpu::index_lowering::OpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp,
                                      NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>>(
      converter, IndexKind::Block, IntrType::Id, benefit);
  patterns.add<
      gpu::index_lowering::OpLowering<gpu::BlockDimOp, NVVM::BlockDimXOp,
                                      NVVM::BlockDimYOp, NVVM::BlockDimZOp>>(
      converter, IndexKind::Block, IntrType::Dim, benefit);
  patterns.add<
      gpu::index_lowering::OpLowering<gpu::ClusterIdOp, NVVM::ClusterIdXOp,
                                      NVVM::ClusterIdYOp, NVVM::ClusterIdZOp>>(
      converter, IndexKind::Other, IntrType::Id, benefit);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::ClusterDimOp, NVVM::ClusterDimXOp, NVVM::ClusterDimYOp,
      NVVM::ClusterDimZOp>>(converter, IndexKind::Other, IntrType::Dim,
                            benefit);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::ClusterBlockIdOp, NVVM::BlockInClusterIdXOp,
      NVVM::BlockInClusterIdYOp, NVVM::BlockInClusterIdZOp>>(
      converter, IndexKind::Cluster, IntrType::Id, benefit);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::ClusterDimBlocksOp, NVVM::ClusterDimBlocksXOp,
      NVVM::ClusterDimBlocksYOp, NVVM::ClusterDimBlocksZOp>>(
      converter, IndexKind::Cluster, IntrType::Dim, benefit);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::BlockIdOp, NVVM::BlockIdXOp, NVVM::BlockIdYOp, NVVM::BlockIdZOp>>(
      converter, IndexKind::Grid, IntrType::Id, benefit);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::GridDimOp, NVVM::GridDimXOp, NVVM::GridDimYOp, NVVM::GridDimZOp>>(
      converter, IndexKind::Grid, IntrType::Dim, benefit);
  patterns.add<GPULaneIdOpToNVVM, GPUBallotOpToNVVM, GPUShuffleOpLowering,
               GPUReturnOpLowering>(converter, benefit);

  patterns.add<GPUDynamicSharedMemoryOpLowering>(
      converter, NVVM::kSharedMemoryAlignmentBit, benefit);

  // Explicitly drop the memory space when lowering private memory
  // attributions since NVVM models them as `alloca`s in the default memory
  // space and does not support `alloca`s with addrspace(5).
  patterns.add<GPUFuncOpLowering>(
      converter,
      GPUFuncOpLoweringOptions{
          /*allocaAddrSpace=*/0,
          /*workgroupAddrSpace=*/
          static_cast<unsigned>(NVVM::NVVMMemorySpace::Shared),
          StringAttr::get(&converter.getContext(),
                          NVVM::NVVMDialect::getKernelFuncAttrName()),
          StringAttr::get(&converter.getContext(),
                          NVVM::NVVMDialect::getMaxntidAttrName()),
          StringAttr::get(&converter.getContext(),
                          NVVM::NVVMDialect::getClusterDimAttrName())},
      benefit);

  populateLibDeviceConversionPatterns(converter, patterns, benefit);
}

//===----------------------------------------------------------------------===//
// NVVMTargetAttr convert to LLVM attr interface
//===----------------------------------------------------------------------===//

namespace {
struct NVVMTargetConvertToLLVMAttrInterface
    : public ConvertToLLVMAttrInterface::ExternalModel<
          NVVMTargetConvertToLLVMAttrInterface, NVVM::NVVMTargetAttr> {
  /// Configure GPU to NVVM.
  void populateConvertToLLVMConversionPatterns(
      Attribute attr, ConversionTarget &target,
      LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) const;
};
} // namespace

void NVVMTargetConvertToLLVMAttrInterface::
    populateConvertToLLVMConversionPatterns(Attribute attr,
                                            ConversionTarget &target,
                                            LLVMTypeConverter &typeConverter,
                                            RewritePatternSet &patterns) const {
  configureGpuToNVVMConversionLegality(target);
  configureGpuToNVVMTypeConverter(typeConverter);
  populateGpuToNVVMConversionPatterns(typeConverter, patterns);
}

void mlir::registerConvertGpuToNVVMInterface(DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, NVVMDialect *dialect) {
    NVVMTargetAttr::attachInterface<NVVMTargetConvertToLLVMAttrInterface>(*ctx);
  });
}
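
// Illustrative use of the registration above (a minimal sketch; real tools
// typically call this alongside their other extension registrations):
//
//   DialectRegistry registry;
//   registerConvertGpuToNVVMInterface(registry);
//   MLIRContext context(registry);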