MLIR 23.0.0git
LowerGpuOpsToNVVMOps.cpp
Go to the documentation of this file.
1//===- LowerGpuOpsToNVVMOps.cpp - MLIR GPU to NVVM lowering passes --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a pass to generate NVVMIR operations for higher-level
10// GPU operations.
11//
12//===----------------------------------------------------------------------===//
13
33#include "mlir/IR/SymbolTable.h"
37
41#include <optional>
42
43namespace mlir {
44#define GEN_PASS_DEF_CONVERTGPUOPSTONVVMOPS
45#include "mlir/Conversion/Passes.h.inc"
46} // namespace mlir
47
48using namespace mlir;
49
50namespace {
51
52/// Convert gpu dialect shfl mode enum to the equivalent nvvm one.
53static NVVM::ShflKind convertShflKind(gpu::ShuffleMode mode) {
54 switch (mode) {
55 case gpu::ShuffleMode::XOR:
56 return NVVM::ShflKind::bfly;
57 case gpu::ShuffleMode::UP:
58 return NVVM::ShflKind::up;
59 case gpu::ShuffleMode::DOWN:
60 return NVVM::ShflKind::down;
61 case gpu::ShuffleMode::IDX:
62 return NVVM::ShflKind::idx;
63 }
64 llvm_unreachable("unknown shuffle mode");
65}
66
67static std::optional<NVVM::ReductionKind>
68convertToNVVMReductionKind(gpu::AllReduceOperation mode) {
69 switch (mode) {
70 case gpu::AllReduceOperation::ADD:
71 return NVVM::ReductionKind::ADD;
72 case gpu::AllReduceOperation::MUL:
73 return std::nullopt;
74 case gpu::AllReduceOperation::MINSI:
75 return NVVM::ReductionKind::MIN;
76 case gpu::AllReduceOperation::MINUI:
77 return std::nullopt;
78 case gpu::AllReduceOperation::MINNUMF:
79 return NVVM::ReductionKind::MIN;
80 case gpu::AllReduceOperation::MAXSI:
81 return NVVM::ReductionKind::MAX;
82 case gpu::AllReduceOperation::MAXUI:
83 return std::nullopt;
84 case gpu::AllReduceOperation::MAXNUMF:
85 return NVVM::ReductionKind::MAX;
86 case gpu::AllReduceOperation::AND:
87 return NVVM::ReductionKind::AND;
88 case gpu::AllReduceOperation::OR:
89 return NVVM::ReductionKind::OR;
90 case gpu::AllReduceOperation::XOR:
91 return NVVM::ReductionKind::XOR;
92 case gpu::AllReduceOperation::MINIMUMF:
93 case gpu::AllReduceOperation::MAXIMUMF:
94 return std::nullopt;
95 }
96 return std::nullopt;
97}
98
99static constexpr llvm::StringLiteral kNVVMNamedBarrierIdPrefix =
100 "__named_barrier_id";
101static constexpr int32_t kNVVMFirstNamedBarrierId = 1;
102static constexpr int32_t kNVVMLastNamedBarrierId = 15;
103static constexpr int32_t kNVVMWarpSize = 32;
104
105static FailureOr<StringAttr>
106createNVVMNamedBarrierIdGlobal(gpu::InitializeNamedBarrierOp op,
107 ConversionPatternRewriter &rewriter) {
108 auto funcOp = op->getParentOfType<FunctionOpInterface>();
109 if (!funcOp) {
110 op.emitOpError("must be inside a function-like op");
111 return failure();
112 }
113 Operation *symbolTableOp = funcOp->getParentWithTrait<OpTrait::SymbolTable>();
114 if (!symbolTableOp) {
115 op.emitOpError(
116 "enclosing function-like op must have a symbol-table parent");
117 return failure();
118 }
119
120 int32_t numNamedBarriers = 0;
121 for (auto globalOp :
122 symbolTableOp->getRegion(0).front().getOps<LLVM::GlobalOp>())
123 if (globalOp.getSymName().starts_with(kNVVMNamedBarrierIdPrefix))
124 ++numNamedBarriers;
125
126 int32_t barrierId = kNVVMFirstNamedBarrierId + numNamedBarriers;
127 if (barrierId > kNVVMLastNamedBarrierId) {
128 op.emitOpError("NVVM supports at most 15 named barriers per CTA");
129 return failure();
130 }
131
132 OpBuilder detachedBuilder(rewriter.getContext());
133 Type i32 = rewriter.getI32Type();
134 auto globalOp = LLVM::GlobalOp::create(
135 detachedBuilder, op.getLoc(), i32, /*isConstant=*/true,
136 LLVM::Linkage::Internal, kNVVMNamedBarrierIdPrefix,
137 rewriter.getI32IntegerAttr(barrierId), /*alignment=*/0,
138 /*addrSpace=*/0);
139 return SymbolTable(symbolTableOp).insert(globalOp);
140}
141
142/// This pass lowers gpu.subgroup_reduce op into to the nvvm.redux op. The op
143/// must be run by the entire subgroup, otherwise it is undefined behaviour.
144struct GPUSubgroupReduceOpLowering
145 : public ConvertOpToLLVMPattern<gpu::SubgroupReduceOp> {
146 using ConvertOpToLLVMPattern<gpu::SubgroupReduceOp>::ConvertOpToLLVMPattern;
147 LogicalResult
148
149 matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
150 ConversionPatternRewriter &rewriter) const override {
151 if (op.getClusterSize())
152 return rewriter.notifyMatchFailure(
153 op, "lowering for clustered reduce not implemented");
154
155 if (!op.getUniform())
156 return rewriter.notifyMatchFailure(
157 op, "cannot be lowered to redux as the op must be run "
158 "uniformly (entire subgroup).");
159 if (!op.getValue().getType().isInteger(32))
160 return rewriter.notifyMatchFailure(op, "unsupported data type");
161
162 std::optional<NVVM::ReductionKind> mode =
163 convertToNVVMReductionKind(op.getOp());
164 if (!mode.has_value())
165 return rewriter.notifyMatchFailure(
166 op, "unsupported reduction mode for redux");
167
168 Location loc = op->getLoc();
169 auto int32Type = IntegerType::get(rewriter.getContext(), 32);
170 Value offset = LLVM::ConstantOp::create(rewriter, loc, int32Type, -1);
171
172 auto reduxOp = NVVM::ReduxOp::create(rewriter, loc, int32Type,
173 op.getValue(), mode.value(), offset);
174
175 rewriter.replaceOp(op, reduxOp->getResult(0));
176 return success();
177 }
178};
179
180struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
181 using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern;
182
183 /// Lowers a shuffle to the corresponding NVVM op.
184 ///
185 /// Convert the `width` argument into an activeMask (a bitmask which specifies
186 /// which threads participate in the shuffle) and a maskAndClamp (specifying
187 /// the highest lane which participates in the shuffle).
188 ///
189 /// %one = llvm.constant(1 : i32) : i32
190 /// %minus_one = llvm.constant(-1 : i32) : i32
191 /// %thirty_two = llvm.constant(32 : i32) : i32
192 /// %num_lanes = llvm.sub %thirty_two, %width : i32
193 /// %active_mask = llvm.lshr %minus_one, %num_lanes : i32
194 /// %mask_and_clamp = llvm.sub %width, %one : i32
195 /// %shfl = nvvm.shfl.sync.bfly %active_mask, %value, %offset,
196 /// %mask_and_clamp : !llvm<"{ float, i1 }">
197 /// %shfl_value = llvm.extractvalue %shfl[0] :
198 /// !llvm<"{ float, i1 }">
199 /// %shfl_pred = llvm.extractvalue %shfl[1] :
200 /// !llvm<"{ float, i1 }">
201 LogicalResult
202 matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
203 ConversionPatternRewriter &rewriter) const override {
204 Location loc = op->getLoc();
205
206 auto valueTy = adaptor.getValue().getType();
207 auto int32Type = IntegerType::get(rewriter.getContext(), 32);
208 auto predTy = IntegerType::get(rewriter.getContext(), 1);
209
210 Value one = LLVM::ConstantOp::create(rewriter, loc, int32Type, 1);
211 Value minusOne = LLVM::ConstantOp::create(rewriter, loc, int32Type, -1);
212 Value thirtyTwo = LLVM::ConstantOp::create(rewriter, loc, int32Type, 32);
213 Value numLeadInactiveLane = LLVM::SubOp::create(
214 rewriter, loc, int32Type, thirtyTwo, adaptor.getWidth());
215 // Bit mask of active lanes: `(-1) >> (32 - activeWidth)`.
216 Value activeMask = LLVM::LShrOp::create(rewriter, loc, int32Type, minusOne,
217 numLeadInactiveLane);
218 Value maskAndClamp;
219 if (op.getMode() == gpu::ShuffleMode::UP) {
220 // Clamp lane: `32 - activeWidth`
221 maskAndClamp = numLeadInactiveLane;
222 } else {
223 // Clamp lane: `activeWidth - 1`
224 maskAndClamp = LLVM::SubOp::create(rewriter, loc, int32Type,
225 adaptor.getWidth(), one);
226 }
227
228 bool predIsUsed = !op->getResult(1).use_empty();
229 UnitAttr returnValueAndIsValidAttr = nullptr;
230 Type resultTy = valueTy;
231 if (predIsUsed) {
232 returnValueAndIsValidAttr = rewriter.getUnitAttr();
233 resultTy = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
234 {valueTy, predTy});
235 }
236 Value shfl = NVVM::ShflOp::create(
237 rewriter, loc, resultTy, activeMask, adaptor.getValue(),
238 adaptor.getOffset(), maskAndClamp, convertShflKind(op.getMode()),
239 returnValueAndIsValidAttr);
240 if (predIsUsed) {
241 Value shflValue = LLVM::ExtractValueOp::create(rewriter, loc, shfl, 0);
242 Value isActiveSrcLane =
243 LLVM::ExtractValueOp::create(rewriter, loc, shfl, 1);
244 rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
245 } else {
246 rewriter.replaceOp(op, {shfl, nullptr});
247 }
248 return success();
249 }
250};
251
252struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
253 using ConvertOpToLLVMPattern<gpu::LaneIdOp>::ConvertOpToLLVMPattern;
254
255 LogicalResult
256 matchAndRewrite(gpu::LaneIdOp op, gpu::LaneIdOp::Adaptor adaptor,
257 ConversionPatternRewriter &rewriter) const override {
258 auto loc = op->getLoc();
259 MLIRContext *context = rewriter.getContext();
260 LLVM::ConstantRangeAttr bounds = nullptr;
261 if (std::optional<APInt> upperBound = op.getUpperBound())
262 bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
263 /*bitWidth=*/32, /*lower=*/0, upperBound->getZExtValue());
264 else
265 bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
266 /*bitWidth=*/32, /*lower=*/0, /*upper=*/kWarpSize);
267 Value newOp =
268 NVVM::LaneIdOp::create(rewriter, loc, rewriter.getI32Type(), bounds);
269 // Truncate or extend the result depending on the index bitwidth specified
270 // by the LLVMTypeConverter options.
271 const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
272 if (indexBitwidth > 32) {
273 newOp = LLVM::SExtOp::create(
274 rewriter, loc, IntegerType::get(context, indexBitwidth), newOp);
275 } else if (indexBitwidth < 32) {
276 newOp = LLVM::TruncOp::create(
277 rewriter, loc, IntegerType::get(context, indexBitwidth), newOp);
278 }
279 rewriter.replaceOp(op, {newOp});
280 return success();
281 }
282};
283
284struct GPUBallotOpToNVVM : public ConvertOpToLLVMPattern<gpu::BallotOp> {
285 using ConvertOpToLLVMPattern<gpu::BallotOp>::ConvertOpToLLVMPattern;
286
287 LogicalResult
288 matchAndRewrite(gpu::BallotOp op, gpu::BallotOp::Adaptor adaptor,
289 ConversionPatternRewriter &rewriter) const override {
290 Location loc = op->getLoc();
291 auto int32Type = IntegerType::get(rewriter.getContext(), 32);
292 auto intType = cast<IntegerType>(op.getType());
293 unsigned width = intType.getWidth();
294
295 // NVVM ballot natively returns i32. For i64 results, zero-extend since
296 // NVIDIA warps have exactly 32 threads, so upper 32 bits are always zero.
297 if (width != 32 && width != 64)
298 return rewriter.notifyMatchFailure(
299 op, "nvvm.vote.sync ballot only supports i32 and i64 result types");
300
301 // Use full mask (-1) so all 32 lanes participate in the ballot.
302 Value mask = LLVM::ConstantOp::create(rewriter, loc, int32Type,
303 rewriter.getI32IntegerAttr(-1));
304
305 auto voteKind = NVVM::VoteSyncKindAttr::get(rewriter.getContext(),
306 NVVM::VoteSyncKind::ballot);
307 Value result = NVVM::VoteSyncOp::create(rewriter, loc, int32Type, mask,
308 adaptor.getPredicate(), voteKind);
309
310 if (width == 64)
311 result = LLVM::ZExtOp::create(rewriter, loc, op.getType(), result);
312
313 rewriter.replaceOp(op, result);
314 return success();
315 }
316};
317
318/// Lowering of cf.assert into a conditional __assertfail.
319struct AssertOpToAssertfailLowering
320 : public ConvertOpToLLVMPattern<cf::AssertOp> {
321 using ConvertOpToLLVMPattern<cf::AssertOp>::ConvertOpToLLVMPattern;
322
323 LogicalResult
324 matchAndRewrite(cf::AssertOp assertOp, cf::AssertOpAdaptor adaptor,
325 ConversionPatternRewriter &rewriter) const override {
326 MLIRContext *ctx = rewriter.getContext();
327 Location loc = assertOp.getLoc();
328 Type i8Type = typeConverter->convertType(rewriter.getIntegerType(8));
329 Type i32Type = typeConverter->convertType(rewriter.getIntegerType(32));
330 Type i64Type = typeConverter->convertType(rewriter.getIntegerType(64));
331 Type ptrType = LLVM::LLVMPointerType::get(ctx);
332 Type voidType = LLVM::LLVMVoidType::get(ctx);
333
334 // Find or create __assertfail function declaration.
335 auto moduleOp = assertOp->getParentOfType<gpu::GPUModuleOp>();
336 auto assertfailType = LLVM::LLVMFunctionType::get(
337 voidType, {ptrType, ptrType, i32Type, ptrType, i64Type});
338 LLVM::LLVMFuncOp assertfailDecl = getOrDefineFunction(
339 moduleOp, loc, rewriter, "__assertfail", assertfailType);
340 assertfailDecl.setPassthroughAttr(
341 ArrayAttr::get(ctx, StringAttr::get(ctx, "noreturn")));
342
343 // Split blocks and insert conditional branch.
344 // ^before:
345 // ...
346 // cf.cond_br %condition, ^after, ^assert
347 // ^assert:
348 // cf.assert
349 // cf.br ^after
350 // ^after:
351 // ...
352 Block *beforeBlock = assertOp->getBlock();
353 Block *assertBlock =
354 rewriter.splitBlock(beforeBlock, assertOp->getIterator());
355 Block *afterBlock =
356 rewriter.splitBlock(assertBlock, ++assertOp->getIterator());
357 rewriter.setInsertionPointToEnd(beforeBlock);
358 cf::CondBranchOp::create(rewriter, loc, adaptor.getArg(), afterBlock,
359 assertBlock);
360 rewriter.setInsertionPointToEnd(assertBlock);
361 cf::BranchOp::create(rewriter, loc, afterBlock);
362
363 // Continue cf.assert lowering.
364 rewriter.setInsertionPoint(assertOp);
365
366 // Populate file name, file number and function name from the location of
367 // the AssertOp.
368 StringRef fileName = "(unknown)";
369 StringRef funcName = "(unknown)";
370 int32_t fileLine = 0;
371 while (auto callSiteLoc = dyn_cast<CallSiteLoc>(loc))
372 loc = callSiteLoc.getCallee();
373 if (auto fileLineColLoc = dyn_cast<FileLineColRange>(loc)) {
374 fileName = fileLineColLoc.getFilename().strref();
375 fileLine = fileLineColLoc.getStartLine();
376 } else if (auto nameLoc = dyn_cast<NameLoc>(loc)) {
377 funcName = nameLoc.getName().strref();
378 if (auto fileLineColLoc =
379 dyn_cast<FileLineColRange>(nameLoc.getChildLoc())) {
380 fileName = fileLineColLoc.getFilename().strref();
381 fileLine = fileLineColLoc.getStartLine();
382 }
383 }
384
385 // Create constants.
386 auto getGlobal = [&](LLVM::GlobalOp global) {
387 // Get a pointer to the format string's first element.
388 Value globalPtr = LLVM::AddressOfOp::create(
389 rewriter, loc, LLVM::LLVMPointerType::get(ctx, global.getAddrSpace()),
390 global.getSymNameAttr());
391 Value start =
392 LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(),
393 globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
394 return start;
395 };
396 Value assertMessage = getGlobal(getOrCreateStringConstant(
397 rewriter, loc, moduleOp, i8Type, "assert_message_", assertOp.getMsg()));
398 Value assertFile = getGlobal(getOrCreateStringConstant(
399 rewriter, loc, moduleOp, i8Type, "assert_file_", fileName));
400 Value assertFunc = getGlobal(getOrCreateStringConstant(
401 rewriter, loc, moduleOp, i8Type, "assert_func_", funcName));
402 Value assertLine =
403 LLVM::ConstantOp::create(rewriter, loc, i32Type, fileLine);
404 Value c1 = LLVM::ConstantOp::create(rewriter, loc, i64Type, 1);
405
406 // Insert function call to __assertfail.
407 SmallVector<Value> arguments{assertMessage, assertFile, assertLine,
408 assertFunc, c1};
409 rewriter.replaceOpWithNewOp<LLVM::CallOp>(assertOp, assertfailDecl,
410 arguments);
411 return success();
412 }
413};
414
415struct GPUBarrierOpToNVVMLowering final
416 : public ConvertOpToLLVMPattern<gpu::BarrierOp> {
418
419 LogicalResult
420 matchAndRewrite(gpu::BarrierOp op, gpu::BarrierOp::Adaptor adaptor,
421 ConversionPatternRewriter &rewriter) const override {
422 if (Value namedBarrier = adaptor.getNamedBarrier()) {
423 Location loc = op.getLoc();
424 Value barrierId =
425 LLVM::ExtractValueOp::create(rewriter, loc, namedBarrier, 0);
426 Value numberOfThreads =
427 LLVM::ExtractValueOp::create(rewriter, loc, namedBarrier, 1);
428 NVVM::BarrierOp::create(rewriter, loc, barrierId, numberOfThreads);
429 rewriter.eraseOp(op);
430 return success();
431 }
432
433 gpu::BarrierScope scope = op.getScope();
434 switch (scope) {
435 case gpu::BarrierScope::Workgroup:
436 rewriter.replaceOpWithNewOp<NVVM::BarrierOp>(op);
437 return success();
438 case gpu::BarrierScope::Subgroup: {
439 // Emit __syncwarp(0xFFFFFFFF) for full-warp sync.
440 Value mask =
441 LLVM::ConstantOp::create(rewriter, op.getLoc(), rewriter.getI32Type(),
442 rewriter.getI32IntegerAttr(0xFFFFFFFF));
443 rewriter.replaceOpWithNewOp<NVVM::SyncWarpOp>(op, mask);
444 return success();
445 }
446 default:
447 return rewriter.notifyMatchFailure(
448 op, "unsupported scope for NVVM barrier lowering");
449 }
450 }
451};
452
453struct GPUInitializeNamedBarrierOpToNVVMLowering final
454 : public ConvertOpToLLVMPattern<gpu::InitializeNamedBarrierOp> {
456
457 LogicalResult
458 matchAndRewrite(gpu::InitializeNamedBarrierOp op,
459 gpu::InitializeNamedBarrierOp::Adaptor adaptor,
460 ConversionPatternRewriter &rewriter) const override {
461 Location loc = op.getLoc();
462 MLIRContext *ctx = rewriter.getContext();
463 Type i32 = rewriter.getI32Type();
464 Type namedBarrierType =
465 getTypeConverter()->convertType(op.getResult().getType());
466 if (!namedBarrierType)
467 return rewriter.notifyMatchFailure(op, "failed to convert result type");
468
469 FailureOr<StringAttr> maybeGlobalName =
470 createNVVMNamedBarrierIdGlobal(op, rewriter);
471 if (failed(maybeGlobalName))
472 return failure();
473
474 auto addressOf = LLVM::AddressOfOp::create(
475 rewriter, loc, LLVM::LLVMPointerType::get(ctx), *maybeGlobalName);
476 Value barrierId =
477 LLVM::LoadOp::create(rewriter, loc, i32, addressOf.getResult());
478
479 Value warpSize = LLVM::ConstantOp::create(
480 rewriter, loc, i32, rewriter.getI32IntegerAttr(kNVVMWarpSize));
481 Value numberOfThreads =
482 LLVM::MulOp::create(rewriter, loc, adaptor.getMemberCount(), warpSize);
483
484 Value namedBarrier =
485 LLVM::PoisonOp::create(rewriter, loc, namedBarrierType);
486 DenseI64ArrayAttr barrierIdPos = rewriter.getDenseI64ArrayAttr({0});
487 DenseI64ArrayAttr numberOfThreadsPos = rewriter.getDenseI64ArrayAttr({1});
488 namedBarrier = LLVM::InsertValueOp::create(rewriter, loc, namedBarrier,
489 barrierId, barrierIdPos);
490 namedBarrier = LLVM::InsertValueOp::create(
491 rewriter, loc, namedBarrier, numberOfThreads, numberOfThreadsPos);
492 rewriter.replaceOp(op, namedBarrier);
493 return success();
494 }
495};
496
497/// A pass that replaces all occurrences of GPU device operations with their
498/// corresponding NVVM equivalent.
499///
500/// This pass only handles device code and is not meant to be run on GPU host
501/// code.
502struct LowerGpuOpsToNVVMOpsPass final
503 : public impl::ConvertGpuOpsToNVVMOpsBase<LowerGpuOpsToNVVMOpsPass> {
504 using Base::Base;
505
506 void getDependentDialects(DialectRegistry &registry) const override {
507 Base::getDependentDialects(registry);
509 }
510
511 void runOnOperation() override {
512 gpu::GPUModuleOp m = getOperation();
513
514 // Request C wrapper emission.
515 for (auto func : m.getOps<func::FuncOp>()) {
516 func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
517 UnitAttr::get(&getContext()));
518 }
519
520 // Customize the bitwidth used for the device side index computations.
521 LowerToLLVMOptions options(
522 m.getContext(),
523 DataLayout(cast<DataLayoutOpInterface>(m.getOperation())));
524 if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
525 options.overrideIndexBitwidth(indexBitwidth);
526 options.useBarePtrCallConv = useBarePtrCallConv;
527
528 // Apply in-dialect lowering. In-dialect lowering will replace
529 // ops which need to be lowered further, which is not supported by a
530 // single conversion pass.
531 {
532 RewritePatternSet patterns(m.getContext());
534 // Transform N-D vector.from_elements to 1-D vector.from_elements before
535 // conversion.
536 vector::populateVectorFromElementsUnrollPatterns(patterns);
537 if (failed(applyPatternsGreedily(m, std::move(patterns))))
538 return signalPassFailure();
539 }
540
541 LLVMTypeConverter converter(m.getContext(), options);
543 RewritePatternSet llvmPatterns(m.getContext());
544 LLVMConversionTarget target(getContext());
545
546 // Set higher benefit, so patterns will run before generic LLVM lowering.
547 populateGpuToNVVMConversionPatterns(converter, llvmPatterns,
548 /*benefit=*/10);
549
550 llvm::SmallDenseSet<StringRef> allowedDialectsSet(allowedDialects.begin(),
551 allowedDialects.end());
552 for (Dialect *dialect : getContext().getLoadedDialects()) {
553 // Skip math patterns as nvvm needs custom math lowering.
554 if (isa<math::MathDialect>(dialect))
555 continue;
556
557 bool allowed = allowedDialectsSet.contains(dialect->getNamespace());
558 // Empty `allowedDialectsSet` means all dialects are allowed.
559 if (!allowedDialectsSet.empty() && !allowed)
560 continue;
561
562 auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
563 if (!iface) {
564 // Error out if dialect was explicily specified but doesn't implement
565 // conversion interface.
566 if (allowed) {
567 m.emitError()
568 << "dialect does not implement ConvertToLLVMPatternInterface: "
569 << dialect->getNamespace();
570 return signalPassFailure();
571 }
572 continue;
573 }
574
575 iface->populateConvertToLLVMConversionPatterns(target, converter,
576 llvmPatterns);
577 }
578
579 populateGpuWMMAToNVVMConversionPatterns(converter, llvmPatterns);
580 if (this->hasRedux)
581 populateGpuSubgroupReduceOpLoweringPattern(converter, llvmPatterns);
583 ConversionConfig config;
584 config.allowPatternRollback = allowPatternRollback;
585 if (failed(
586 applyPartialConversion(m, target, std::move(llvmPatterns), config)))
587 signalPassFailure();
588 }
589};
590
591} // namespace
592
594 target.addIllegalOp<func::FuncOp>();
595 target.addIllegalOp<cf::AssertOp>();
596 target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
597 target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
598 target.addIllegalDialect<gpu::GPUDialect>();
599 target.addIllegalOp<LLVM::CopySignOp, LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op,
600 LLVM::FAbsOp, LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp,
601 LLVM::LogOp, LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp,
602 LLVM::RoundEvenOp, LLVM::RoundOp, LLVM::SinOp,
603 LLVM::SincosOp, LLVM::SqrtOp>();
604
605 // TODO: Remove once we support replacing non-root ops.
606 target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp>();
607}
608
611
612 converter.addConversion([&](gpu::NamedBarrierType type) -> Type {
613 Type i32 = IntegerType::get(type.getContext(), 32);
614 return LLVM::LLVMStructType::getLiteral(type.getContext(), {i32, i32});
615 });
616
617 // Lowering for MMAMatrixType.
618 converter.addConversion([&](gpu::MMAMatrixType type) -> Type {
619 return convertMMAToLLVMType(type);
620 });
621}
622
624 const LLVMTypeConverter &converter, RewritePatternSet &patterns,
625 PatternBenefit benefit) {
626 patterns.add<GPUSubgroupReduceOpLowering>(converter, benefit);
627}
628
630 const LLVMTypeConverter &converter, RewritePatternSet &patterns,
631 PatternBenefit benefit) {
634
635 patterns.add<GPUBarrierOpToNVVMLowering,
636 GPUInitializeNamedBarrierOpToNVVMLowering,
637 GPUPrintfOpToVPrintfLowering, AssertOpToAssertfailLowering>(
638 converter, benefit);
639 patterns.add<
640 gpu::index_lowering::OpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp,
641 NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>>(
642 converter, IndexKind::Block, IntrType::Id, benefit);
643 patterns.add<
644 gpu::index_lowering::OpLowering<gpu::BlockDimOp, NVVM::BlockDimXOp,
645 NVVM::BlockDimYOp, NVVM::BlockDimZOp>>(
646 converter, IndexKind::Block, IntrType::Dim, benefit);
647 patterns.add<
648 gpu::index_lowering::OpLowering<gpu::ClusterIdOp, NVVM::ClusterIdXOp,
649 NVVM::ClusterIdYOp, NVVM::ClusterIdZOp>>(
650 converter, IndexKind::Other, IntrType::Id, benefit);
652 gpu::ClusterDimOp, NVVM::ClusterDimXOp, NVVM::ClusterDimYOp,
653 NVVM::ClusterDimZOp>>(converter, IndexKind::Other, IntrType::Dim,
654 benefit);
656 gpu::ClusterBlockIdOp, NVVM::BlockInClusterIdXOp,
657 NVVM::BlockInClusterIdYOp, NVVM::BlockInClusterIdZOp>>(
658 converter, IndexKind::Cluster, IntrType::Id, benefit);
660 gpu::ClusterDimBlocksOp, NVVM::ClusterDimBlocksXOp,
661 NVVM::ClusterDimBlocksYOp, NVVM::ClusterDimBlocksZOp>>(
662 converter, IndexKind::Cluster, IntrType::Dim, benefit);
664 gpu::BlockIdOp, NVVM::BlockIdXOp, NVVM::BlockIdYOp, NVVM::BlockIdZOp>>(
665 converter, IndexKind::Grid, IntrType::Id, benefit);
667 gpu::GridDimOp, NVVM::GridDimXOp, NVVM::GridDimYOp, NVVM::GridDimZOp>>(
668 converter, IndexKind::Grid, IntrType::Dim, benefit);
669 patterns.add<GPULaneIdOpToNVVM, GPUBallotOpToNVVM, GPUShuffleOpLowering,
670 GPUReturnOpLowering>(converter, benefit);
671
673 converter, NVVM::kSharedMemoryAlignmentBit, benefit);
674
675 // Explicitly drop memory space when lowering private memory
676 // attributions since NVVM models it as `alloca`s in the default
677 // memory space and does not support `alloca`s with addrspace(5).
678 patterns.add<GPUFuncOpLowering>(
679 converter,
681 /*allocaAddrSpace=*/0,
682 /*workgroupAddrSpace=*/
683 static_cast<unsigned>(NVVM::NVVMMemorySpace::Shared),
684 StringAttr::get(&converter.getContext(),
685 NVVM::NVVMDialect::getKernelFuncAttrName()),
686 StringAttr::get(&converter.getContext(),
687 NVVM::NVVMDialect::getMaxntidAttrName()),
688 StringAttr::get(&converter.getContext(),
689 NVVM::NVVMDialect::getClusterDimAttrName())},
690 benefit);
691
692 populateLibDeviceConversionPatterns(converter, patterns, benefit);
693}
694
695//===----------------------------------------------------------------------===//
696// NVVMTargetAttr convert to LLVM attr interface
697//===----------------------------------------------------------------------===//
698
699namespace {
700struct NVVMTargetConvertToLLVMAttrInterface
701 : public ConvertToLLVMAttrInterface::ExternalModel<
702 NVVMTargetConvertToLLVMAttrInterface, NVVM::NVVMTargetAttr> {
703 /// Configure GPU to NVVM.
704 void populateConvertToLLVMConversionPatterns(
706 LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) const;
707};
708} // namespace
709
710void NVVMTargetConvertToLLVMAttrInterface::
711 populateConvertToLLVMConversionPatterns(Attribute attr,
713 LLVMTypeConverter &typeConverter,
714 RewritePatternSet &patterns) const {
716 configureGpuToNVVMTypeConverter(typeConverter);
717 populateGpuToNVVMConversionPatterns(typeConverter, patterns);
718}
719
721 registry.addExtension(+[](MLIRContext *ctx, NVVMDialect *dialect) {
722 NVVMTargetAttr::attachInterface<NVVMTargetConvertToLLVMAttrInterface>(*ctx);
723 });
724}
return success()
b getContext())
constexpr int kWarpSize
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
Definition Attributes.h:25
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
Definition Block.h:203
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition Pattern.h:227
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition Pattern.h:233
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
MLIRContext & getContext() const
Returns the MLIR context.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:209
A trait used to provide symbol table functionalities to a region operation.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:711
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition Operation.h:273
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Block & front()
Definition Region.h:65
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition SymbolTable.h:24
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
Definition GPUDialect.h:139
constexpr int kSharedMemoryAlignmentBit
Definition NVVMDialect.h:49
void registerConvertGpuToNVVMInterface(DialectRegistry &registry)
Registers the ConvertToLLVMAttrInterface interface on the NVVM::NVVMTargetAttr attribute.
void populateCommonGPUTypeAndAttributeConversions(TypeConverter &typeConverter)
Remap common GPU memory spaces (Workgroup, Private, etc) to LLVM address spaces.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
void populateGpuRewritePatterns(RewritePatternSet &patterns)
Collect all patterns to rewrite ops within the GPU dialect.
Definition Passes.h:91
Type convertMMAToLLVMType(gpu::MMAMatrixType type)
Return the LLVMStructureType corresponding to the MMAMatrixType type.
LLVM::LLVMFuncOp getOrDefineFunction(Operation *moduleOp, Location loc, OpBuilder &b, StringRef name, LLVM::LLVMFunctionType type)
Note that these functions don't take a SymbolTable because GPU module lowerings can have name collisi...
void configureGpuToNVVMTypeConverter(LLVMTypeConverter &converter)
Configure the LLVM type convert to convert types and address spaces from the GPU dialect to NVVM.
void configureGpuToNVVMConversionLegality(ConversionTarget &target)
Configure target to convert from the GPU dialect to NVVM.
void registerConvertToLLVMDependentDialectLoading(DialectRegistry &registry)
Register the extension that will load dependent dialects for LLVM conversion.
void populateGpuSubgroupReduceOpLoweringPattern(const LLVMTypeConverter &converter, RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate GpuSubgroupReduce pattern to NVVM.
void populateGpuToNVVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of patterns to convert from the GPU dialect to NVVM.
LLVM::GlobalOp getOrCreateStringConstant(OpBuilder &b, Location loc, Operation *moduleOp, Type llvmI8, StringRef namePrefix, StringRef str, uint64_t alignment=0, unsigned addrSpace=0)
Create a global that contains the given string.
void populateLibDeviceConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the given list with patterns that convert from Math to NVVM libdevice calls.
void populateGpuWMMAToNVVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of patterns to convert WMMA ops from GPU dialect to NVVM.
Lowering for gpu.dynamic.shared.memory to LLVM dialect.
Lowering of gpu.printf to a vprintf standard library.