MLIR 23.0.0git
LowerGpuOpsToNVVMOps.cpp
Go to the documentation of this file.
1//===- LowerGpuOpsToNVVMOps.cpp - MLIR GPU to NVVM lowering passes --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a pass to generate NVVMIR operations for higher-level
10// GPU operations.
11//
12//===----------------------------------------------------------------------===//
13
33#include "mlir/IR/SymbolTable.h"
37
41#include <optional>
42
43namespace mlir {
44#define GEN_PASS_DEF_CONVERTGPUOPSTONVVMOPS
45#include "mlir/Conversion/Passes.h.inc"
46} // namespace mlir
47
48using namespace mlir;
49
50namespace {
51
52/// Convert gpu dialect shfl mode enum to the equivalent nvvm one.
53static NVVM::ShflKind convertShflKind(gpu::ShuffleMode mode) {
54 switch (mode) {
55 case gpu::ShuffleMode::XOR:
56 return NVVM::ShflKind::bfly;
57 case gpu::ShuffleMode::UP:
58 return NVVM::ShflKind::up;
59 case gpu::ShuffleMode::DOWN:
60 return NVVM::ShflKind::down;
61 case gpu::ShuffleMode::IDX:
62 return NVVM::ShflKind::idx;
63 }
64 llvm_unreachable("unknown shuffle mode");
65}
66
67static std::optional<NVVM::ReductionKind>
68convertToNVVMReductionKind(gpu::AllReduceOperation mode) {
69 switch (mode) {
70 case gpu::AllReduceOperation::ADD:
71 return NVVM::ReductionKind::ADD;
72 case gpu::AllReduceOperation::MUL:
73 return std::nullopt;
74 case gpu::AllReduceOperation::MINSI:
75 return NVVM::ReductionKind::MIN;
76 case gpu::AllReduceOperation::MINUI:
77 return std::nullopt;
78 case gpu::AllReduceOperation::MINNUMF:
79 return NVVM::ReductionKind::MIN;
80 case gpu::AllReduceOperation::MAXSI:
81 return NVVM::ReductionKind::MAX;
82 case gpu::AllReduceOperation::MAXUI:
83 return std::nullopt;
84 case gpu::AllReduceOperation::MAXNUMF:
85 return NVVM::ReductionKind::MAX;
86 case gpu::AllReduceOperation::AND:
87 return NVVM::ReductionKind::AND;
88 case gpu::AllReduceOperation::OR:
89 return NVVM::ReductionKind::OR;
90 case gpu::AllReduceOperation::XOR:
91 return NVVM::ReductionKind::XOR;
92 case gpu::AllReduceOperation::MINIMUMF:
93 case gpu::AllReduceOperation::MAXIMUMF:
94 return std::nullopt;
95 }
96 return std::nullopt;
97}
98
// Symbol-name prefix of the globals that hold named-barrier ids. It is used
// both as the name of each newly created global and to recognise (count) the
// ones that already exist in the enclosing symbol table.
99static constexpr llvm::StringLiteral kNVVMNamedBarrierIdPrefix =
100 "__named_barrier_id";
// Id assigned to the first named barrier; later barriers get consecutive ids.
101static constexpr int32_t kNVVMFirstNamedBarrierId = 1;
// Largest id that may be assigned; exceeding it is reported as an error.
102static constexpr int32_t kNVVMLastNamedBarrierId = 15;
// Factor by which a named barrier's member count is multiplied to obtain the
// thread count passed to the barrier.
// NOTE(review): GPULaneIdOpToNVVM in this file uses `kWarpSize` instead of this
// constant — confirm whether the two are meant to stay separate.
103static constexpr int32_t kNVVMWarpSize = 32;
104
105static FailureOr<StringAttr>
106createNVVMNamedBarrierIdGlobal(gpu::InitializeNamedBarrierOp op,
107 ConversionPatternRewriter &rewriter) {
108 auto funcOp = op->getParentOfType<FunctionOpInterface>();
109 if (!funcOp) {
110 op.emitOpError("must be inside a function-like op");
111 return failure();
112 }
113 Operation *symbolTableOp = funcOp->getParentWithTrait<OpTrait::SymbolTable>();
114 if (!symbolTableOp) {
115 op.emitOpError(
116 "enclosing function-like op must have a symbol-table parent");
117 return failure();
118 }
119
120 int32_t numNamedBarriers = 0;
121 for (auto globalOp :
122 symbolTableOp->getRegion(0).front().getOps<LLVM::GlobalOp>())
123 if (globalOp.getSymName().starts_with(kNVVMNamedBarrierIdPrefix))
124 ++numNamedBarriers;
125
126 int32_t barrierId = kNVVMFirstNamedBarrierId + numNamedBarriers;
127 if (barrierId > kNVVMLastNamedBarrierId) {
128 op.emitOpError("NVVM supports at most 15 named barriers per CTA");
129 return failure();
130 }
131
132 OpBuilder detachedBuilder(rewriter.getContext());
133 Type i32 = rewriter.getI32Type();
134 auto globalOp = LLVM::GlobalOp::create(
135 detachedBuilder, op.getLoc(), i32, /*isConstant=*/true,
136 LLVM::Linkage::Internal, kNVVMNamedBarrierIdPrefix,
137 rewriter.getI32IntegerAttr(barrierId), /*alignment=*/0,
138 /*addrSpace=*/0);
139 return SymbolTable(symbolTableOp).insert(globalOp);
140}
141
142/// This pass lowers gpu.subgroup_reduce op into to the nvvm.redux op. The op
143/// must be run by the entire subgroup, otherwise it is undefined behaviour.
144struct GPUSubgroupReduceOpLowering
145 : public ConvertOpToLLVMPattern<gpu::SubgroupReduceOp> {
146 using ConvertOpToLLVMPattern<gpu::SubgroupReduceOp>::ConvertOpToLLVMPattern;
147 LogicalResult
148
149 matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
150 ConversionPatternRewriter &rewriter) const override {
151 if (op.getClusterSize())
152 return rewriter.notifyMatchFailure(
153 op, "lowering for clustered reduce not implemented");
154
155 if (!op.getUniform())
156 return rewriter.notifyMatchFailure(
157 op, "cannot be lowered to redux as the op must be run "
158 "uniformly (entire subgroup).");
159 if (!op.getValue().getType().isInteger(32))
160 return rewriter.notifyMatchFailure(op, "unsupported data type");
161
162 std::optional<NVVM::ReductionKind> mode =
163 convertToNVVMReductionKind(op.getOp());
164 if (!mode.has_value())
165 return rewriter.notifyMatchFailure(
166 op, "unsupported reduction mode for redux");
167
168 Location loc = op->getLoc();
169 auto int32Type = IntegerType::get(rewriter.getContext(), 32);
170 Value offset = LLVM::ConstantOp::create(rewriter, loc, int32Type, -1);
171
172 auto reduxOp = NVVM::ReduxOp::create(rewriter, loc, int32Type,
173 op.getValue(), mode.value(), offset);
174
175 rewriter.replaceOp(op, reduxOp->getResult(0));
176 return success();
177 }
178};
179
180struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
181 using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern;
182
183 /// Lowers a shuffle to the corresponding NVVM op.
184 ///
185 /// Convert the `width` argument into an activeMask (a bitmask which specifies
186 /// which threads participate in the shuffle) and a maskAndClamp (specifying
187 /// the highest lane which participates in the shuffle).
188 ///
189 /// %one = llvm.constant(1 : i32) : i32
190 /// %minus_one = llvm.constant(-1 : i32) : i32
191 /// %thirty_two = llvm.constant(32 : i32) : i32
192 /// %num_lanes = llvm.sub %thirty_two, %width : i32
193 /// %active_mask = llvm.lshr %minus_one, %num_lanes : i32
194 /// %mask_and_clamp = llvm.sub %width, %one : i32
195 /// %shfl = nvvm.shfl.sync.bfly %active_mask, %value, %offset,
196 /// %mask_and_clamp : !llvm<"{ float, i1 }">
197 /// %shfl_value = llvm.extractvalue %shfl[0] :
198 /// !llvm<"{ float, i1 }">
199 /// %shfl_pred = llvm.extractvalue %shfl[1] :
200 /// !llvm<"{ float, i1 }">
201 LogicalResult
202 matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
203 ConversionPatternRewriter &rewriter) const override {
204 Location loc = op->getLoc();
205
206 auto valueTy = adaptor.getValue().getType();
207 auto int32Type = IntegerType::get(rewriter.getContext(), 32);
208 auto predTy = IntegerType::get(rewriter.getContext(), 1);
209
210 Value one = LLVM::ConstantOp::create(rewriter, loc, int32Type, 1);
211 Value minusOne = LLVM::ConstantOp::create(rewriter, loc, int32Type, -1);
212 Value thirtyTwo = LLVM::ConstantOp::create(rewriter, loc, int32Type, 32);
213 Value numLeadInactiveLane = LLVM::SubOp::create(
214 rewriter, loc, int32Type, thirtyTwo, adaptor.getWidth());
215 // Bit mask of active lanes: `(-1) >> (32 - activeWidth)`.
216 Value activeMask = LLVM::LShrOp::create(rewriter, loc, int32Type, minusOne,
217 numLeadInactiveLane);
218 Value maskAndClamp;
219 if (op.getMode() == gpu::ShuffleMode::UP) {
220 // Clamp lane: `32 - activeWidth`
221 maskAndClamp = numLeadInactiveLane;
222 } else {
223 // Clamp lane: `activeWidth - 1`
224 maskAndClamp = LLVM::SubOp::create(rewriter, loc, int32Type,
225 adaptor.getWidth(), one);
226 }
227
228 bool predIsUsed = !op->getResult(1).use_empty();
229 UnitAttr returnValueAndIsValidAttr = nullptr;
230 Type resultTy = valueTy;
231 if (predIsUsed) {
232 returnValueAndIsValidAttr = rewriter.getUnitAttr();
233 resultTy = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
234 {valueTy, predTy});
235 }
236 Value shfl = NVVM::ShflOp::create(
237 rewriter, loc, resultTy, activeMask, adaptor.getValue(),
238 adaptor.getOffset(), maskAndClamp, convertShflKind(op.getMode()),
239 returnValueAndIsValidAttr);
240 if (predIsUsed) {
241 Value shflValue = LLVM::ExtractValueOp::create(rewriter, loc, shfl, 0);
242 Value isActiveSrcLane =
243 LLVM::ExtractValueOp::create(rewriter, loc, shfl, 1);
244 rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
245 } else {
246 rewriter.replaceOp(op, {shfl, nullptr});
247 }
248 return success();
249 }
250};
251
252struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
253 using ConvertOpToLLVMPattern<gpu::LaneIdOp>::ConvertOpToLLVMPattern;
254
255 LogicalResult
256 matchAndRewrite(gpu::LaneIdOp op, gpu::LaneIdOp::Adaptor adaptor,
257 ConversionPatternRewriter &rewriter) const override {
258 auto loc = op->getLoc();
259 MLIRContext *context = rewriter.getContext();
260 LLVM::ConstantRangeAttr bounds = nullptr;
261 if (std::optional<APInt> upperBound = op.getUpperBound())
262 bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
263 /*bitWidth=*/32, /*lower=*/0, upperBound->getZExtValue());
264 else
265 bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
266 /*bitWidth=*/32, /*lower=*/0, /*upper=*/kWarpSize);
267 Value newOp =
268 NVVM::LaneIdOp::create(rewriter, loc, rewriter.getI32Type(), bounds);
269 // Truncate or extend the result depending on the index bitwidth specified
270 // by the LLVMTypeConverter options.
271 const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
272 if (indexBitwidth > 32) {
273 newOp = LLVM::SExtOp::create(
274 rewriter, loc, IntegerType::get(context, indexBitwidth), newOp);
275 } else if (indexBitwidth < 32) {
276 newOp = LLVM::TruncOp::create(
277 rewriter, loc, IntegerType::get(context, indexBitwidth), newOp);
278 }
279 rewriter.replaceOp(op, {newOp});
280 return success();
281 }
282};
283
284struct GPUBallotOpToNVVM : public ConvertOpToLLVMPattern<gpu::BallotOp> {
285 using ConvertOpToLLVMPattern<gpu::BallotOp>::ConvertOpToLLVMPattern;
286
287 LogicalResult
288 matchAndRewrite(gpu::BallotOp op, gpu::BallotOp::Adaptor adaptor,
289 ConversionPatternRewriter &rewriter) const override {
290 Location loc = op->getLoc();
291 auto int32Type = IntegerType::get(rewriter.getContext(), 32);
292 auto intType = cast<IntegerType>(op.getType());
293 unsigned width = intType.getWidth();
294
295 // NVVM ballot natively returns i32. For i64 results, zero-extend since
296 // NVIDIA warps have exactly 32 threads, so upper 32 bits are always zero.
297 if (width != 32 && width != 64)
298 return rewriter.notifyMatchFailure(
299 op, "nvvm.vote.sync ballot only supports i32 and i64 result types");
300
301 // Use full mask (-1) so all 32 lanes participate in the ballot.
302 Value mask = LLVM::ConstantOp::create(rewriter, loc, int32Type,
303 rewriter.getI32IntegerAttr(-1));
304
305 auto voteKind = NVVM::VoteSyncKindAttr::get(rewriter.getContext(),
306 NVVM::VoteSyncKind::ballot);
307 Value result = NVVM::VoteSyncOp::create(rewriter, loc, int32Type, mask,
308 adaptor.getPredicate(), voteKind);
309
310 if (width == 64)
311 result = LLVM::ZExtOp::create(rewriter, loc, op.getType(), result);
312
313 rewriter.replaceOp(op, result);
314 return success();
315 }
316};
317
318/// Lowering of cf.assert into a conditional __assertfail.
319struct AssertOpToAssertfailLowering
320 : public ConvertOpToLLVMPattern<cf::AssertOp> {
321 using ConvertOpToLLVMPattern<cf::AssertOp>::ConvertOpToLLVMPattern;
322
323 LogicalResult
324 matchAndRewrite(cf::AssertOp assertOp, cf::AssertOpAdaptor adaptor,
325 ConversionPatternRewriter &rewriter) const override {
326 MLIRContext *ctx = rewriter.getContext();
327 Location loc = assertOp.getLoc();
328 Type i8Type = typeConverter->convertType(rewriter.getIntegerType(8));
329 Type i32Type = typeConverter->convertType(rewriter.getIntegerType(32));
330 Type i64Type = typeConverter->convertType(rewriter.getIntegerType(64));
331 Type ptrType = LLVM::LLVMPointerType::get(ctx);
332 Type voidType = LLVM::LLVMVoidType::get(ctx);
333
334 // Find or create __assertfail function declaration.
335 auto moduleOp = assertOp->getParentOfType<gpu::GPUModuleOp>();
336 auto assertfailType = LLVM::LLVMFunctionType::get(
337 voidType, {ptrType, ptrType, i32Type, ptrType, i64Type});
338 LLVM::LLVMFuncOp assertfailDecl = getOrDefineFunction(
339 moduleOp, loc, rewriter, "__assertfail", assertfailType);
340 assertfailDecl.setPassthroughAttr(
341 ArrayAttr::get(ctx, StringAttr::get(ctx, "noreturn")));
342
343 // Split blocks and insert conditional branch.
344 // ^before:
345 // ...
346 // cf.cond_br %condition, ^after, ^assert
347 // ^assert:
348 // cf.assert
349 // cf.br ^after
350 // ^after:
351 // ...
352 Block *beforeBlock = assertOp->getBlock();
353 Block *assertBlock =
354 rewriter.splitBlock(beforeBlock, assertOp->getIterator());
355 Block *afterBlock =
356 rewriter.splitBlock(assertBlock, ++assertOp->getIterator());
357 rewriter.setInsertionPointToEnd(beforeBlock);
358 cf::CondBranchOp::create(rewriter, loc, adaptor.getArg(), afterBlock,
359 assertBlock);
360 rewriter.setInsertionPointToEnd(assertBlock);
361 cf::BranchOp::create(rewriter, loc, afterBlock);
362
363 // Continue cf.assert lowering.
364 rewriter.setInsertionPoint(assertOp);
365
366 // Populate file name, file number and function name from the location of
367 // the AssertOp.
368 StringRef fileName = "(unknown)";
369 StringRef funcName = "(unknown)";
370 int32_t fileLine = 0;
371 while (auto callSiteLoc = dyn_cast<CallSiteLoc>(loc))
372 loc = callSiteLoc.getCallee();
373 if (auto fileLineColLoc = dyn_cast<FileLineColRange>(loc)) {
374 fileName = fileLineColLoc.getFilename().strref();
375 fileLine = fileLineColLoc.getStartLine();
376 } else if (auto nameLoc = dyn_cast<NameLoc>(loc)) {
377 funcName = nameLoc.getName().strref();
378 if (auto fileLineColLoc =
379 dyn_cast<FileLineColRange>(nameLoc.getChildLoc())) {
380 fileName = fileLineColLoc.getFilename().strref();
381 fileLine = fileLineColLoc.getStartLine();
382 }
383 }
384
385 // Create constants.
386 auto getGlobal = [&](LLVM::GlobalOp global) {
387 // Get a pointer to the format string's first element.
388 Value globalPtr = LLVM::AddressOfOp::create(
389 rewriter, loc, LLVM::LLVMPointerType::get(ctx, global.getAddrSpace()),
390 global.getSymNameAttr());
391 Value start =
392 LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(),
393 globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
394 return start;
395 };
396 Value assertMessage = getGlobal(getOrCreateStringConstant(
397 rewriter, loc, moduleOp, i8Type, "assert_message_", assertOp.getMsg()));
398 Value assertFile = getGlobal(getOrCreateStringConstant(
399 rewriter, loc, moduleOp, i8Type, "assert_file_", fileName));
400 Value assertFunc = getGlobal(getOrCreateStringConstant(
401 rewriter, loc, moduleOp, i8Type, "assert_func_", funcName));
402 Value assertLine =
403 LLVM::ConstantOp::create(rewriter, loc, i32Type, fileLine);
404 Value c1 = LLVM::ConstantOp::create(rewriter, loc, i64Type, 1);
405
406 // Insert function call to __assertfail.
407 SmallVector<Value> arguments{assertMessage, assertFile, assertLine,
408 assertFunc, c1};
409 rewriter.replaceOpWithNewOp<LLVM::CallOp>(assertOp, assertfailDecl,
410 arguments);
411 return success();
412 }
413};
414
415struct GPUBarrierOpToNVVMLowering final
416 : public ConvertOpToLLVMPattern<gpu::BarrierOp> {
418
419 LogicalResult
420 matchAndRewrite(gpu::BarrierOp op, gpu::BarrierOp::Adaptor adaptor,
421 ConversionPatternRewriter &rewriter) const override {
422 if (Value namedBarrier = adaptor.getNamedBarrier()) {
423 Location loc = op.getLoc();
424 Value barrierId =
425 LLVM::ExtractValueOp::create(rewriter, loc, namedBarrier, 0);
426 Value numberOfThreads =
427 LLVM::ExtractValueOp::create(rewriter, loc, namedBarrier, 1);
428 NVVM::BarrierOp::create(rewriter, loc, barrierId, numberOfThreads,
429 NVVM::BarrierReductionAttr{}, Value{});
430 rewriter.eraseOp(op);
431 return success();
432 }
433
434 gpu::BarrierScope scope = op.getScope();
435 switch (scope) {
436 case gpu::BarrierScope::Workgroup:
437 rewriter.replaceOpWithNewOp<NVVM::BarrierOp>(op);
438 return success();
439 case gpu::BarrierScope::Subgroup: {
440 // Emit __syncwarp(0xFFFFFFFF) for full-warp sync.
441 Value mask =
442 LLVM::ConstantOp::create(rewriter, op.getLoc(), rewriter.getI32Type(),
443 rewriter.getI32IntegerAttr(0xFFFFFFFF));
444 rewriter.replaceOpWithNewOp<NVVM::SyncWarpOp>(op, mask);
445 return success();
446 }
447 default:
448 return rewriter.notifyMatchFailure(
449 op, "unsupported scope for NVVM barrier lowering");
450 }
451 }
452};
453
454struct GPUInitializeNamedBarrierOpToNVVMLowering final
455 : public ConvertOpToLLVMPattern<gpu::InitializeNamedBarrierOp> {
457
458 LogicalResult
459 matchAndRewrite(gpu::InitializeNamedBarrierOp op,
460 gpu::InitializeNamedBarrierOp::Adaptor adaptor,
461 ConversionPatternRewriter &rewriter) const override {
462 Location loc = op.getLoc();
463 MLIRContext *ctx = rewriter.getContext();
464 Type i32 = rewriter.getI32Type();
465 Type namedBarrierType =
466 getTypeConverter()->convertType(op.getResult().getType());
467 if (!namedBarrierType)
468 return rewriter.notifyMatchFailure(op, "failed to convert result type");
469
470 FailureOr<StringAttr> maybeGlobalName =
471 createNVVMNamedBarrierIdGlobal(op, rewriter);
472 if (failed(maybeGlobalName))
473 return failure();
474
475 auto addressOf = LLVM::AddressOfOp::create(
476 rewriter, loc, LLVM::LLVMPointerType::get(ctx), *maybeGlobalName);
477 Value barrierId =
478 LLVM::LoadOp::create(rewriter, loc, i32, addressOf.getResult());
479
480 Value warpSize = LLVM::ConstantOp::create(
481 rewriter, loc, i32, rewriter.getI32IntegerAttr(kNVVMWarpSize));
482 Value numberOfThreads =
483 LLVM::MulOp::create(rewriter, loc, adaptor.getMemberCount(), warpSize);
484
485 Value namedBarrier =
486 LLVM::PoisonOp::create(rewriter, loc, namedBarrierType);
487 DenseI64ArrayAttr barrierIdPos = rewriter.getDenseI64ArrayAttr({0});
488 DenseI64ArrayAttr numberOfThreadsPos = rewriter.getDenseI64ArrayAttr({1});
489 namedBarrier = LLVM::InsertValueOp::create(rewriter, loc, namedBarrier,
490 barrierId, barrierIdPos);
491 namedBarrier = LLVM::InsertValueOp::create(
492 rewriter, loc, namedBarrier, numberOfThreads, numberOfThreadsPos);
493 rewriter.replaceOp(op, namedBarrier);
494 return success();
495 }
496};
497
498/// A pass that replaces all occurrences of GPU device operations with their
499/// corresponding NVVM equivalent.
500///
501/// This pass only handles device code and is not meant to be run on GPU host
502/// code.
// Outline of runOnOperation as visible here:
//   1. tag every func.func in the module for C wrapper emission,
//   2. build the LLVM lowering options (index bitwidth, bare-pointer
//      calling convention) from the module's data layout and pass options,
//   3. run a greedy in-dialect rewrite,
//   4. collect GPU->NVVM patterns plus the patterns of every loaded dialect
//      that implements ConvertToLLVMPatternInterface (subject to the
//      `allowedDialects` filter), and
//   5. apply them as a partial conversion.
// NOTE(review): the embedded listing numbers skip 509, 534, 543 and 583, so
// some statements of the original file are not present in this view.
503struct LowerGpuOpsToNVVMOpsPass final
504 : public impl::ConvertGpuOpsToNVVMOpsBase<LowerGpuOpsToNVVMOpsPass> {
505 using Base::Base;
506
507 void getDependentDialects(DialectRegistry &registry) const override {
508 Base::getDependentDialects(registry);
// NOTE(review): listing line 509 is missing here — presumably a call that
// registers further dependent dialects; confirm against the original file.
510 }
511
512 void runOnOperation() override {
513 gpu::GPUModuleOp m = getOperation();
514
515 // Request C wrapper emission.
516 for (auto func : m.getOps<func::FuncOp>()) {
517 func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
518 UnitAttr::get(&getContext()));
519 }
520
521 // Customize the bitwidth used for the device side index computations.
522 LowerToLLVMOptions options(
523 m.getContext(),
524 DataLayout(cast<DataLayoutOpInterface>(m.getOperation())));
525 if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
526 options.overrideIndexBitwidth(indexBitwidth);
527 options.useBarePtrCallConv = useBarePtrCallConv;
528
529 // Apply in-dialect lowering. In-dialect lowering will replace
530 // ops which need to be lowered further, which is not supported by a
531 // single conversion pass.
532 {
533 RewritePatternSet patterns(m.getContext());
// NOTE(review): listing line 534 is missing here — presumably it adds the GPU
// dialect rewrite patterns to `patterns`; confirm against the original file.
535 // Transform N-D vector.from_elements to 1-D vector.from_elements before
536 // conversion.
537 vector::populateVectorFromElementsUnrollPatterns(patterns);
538 if (failed(applyPatternsGreedily(m, std::move(patterns))))
539 return signalPassFailure();
540 }
541
542 LLVMTypeConverter converter(m.getContext(), options);
// NOTE(review): listing line 543 is missing here — presumably the converter is
// configured for GPU types at this point; confirm against the original file.
544 RewritePatternSet llvmPatterns(m.getContext());
545 LLVMConversionTarget target(getContext());
546
547 // Set higher benefit, so patterns will run before generic LLVM lowering.
548 populateGpuToNVVMConversionPatterns(converter, llvmPatterns,
549 /*benefit=*/10);
550
551 llvm::SmallDenseSet<StringRef> allowedDialectsSet(allowedDialects.begin(),
552 allowedDialects.end());
553 for (Dialect *dialect : getContext().getLoadedDialects()) {
554 // Skip math patterns as nvvm needs custom math lowering.
555 if (isa<math::MathDialect>(dialect))
556 continue;
557
558 bool allowed = allowedDialectsSet.contains(dialect->getNamespace());
559 // Empty `allowedDialectsSet` means all dialects are allowed.
560 if (!allowedDialectsSet.empty() && !allowed)
561 continue;
562
563 auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
564 if (!iface) {
565 // Error out if dialect was explicitly specified but doesn't implement
566 // conversion interface.
567 if (allowed) {
568 m.emitError()
569 << "dialect does not implement ConvertToLLVMPatternInterface: "
570 << dialect->getNamespace();
571 return signalPassFailure();
572 }
573 continue;
574 }
575
576 iface->populateConvertToLLVMConversionPatterns(target, converter,
577 llvmPatterns);
578 }
579
580 populateGpuWMMAToNVVMConversionPatterns(converter, llvmPatterns);
// The subgroup-reduce pattern is only registered when the `hasRedux` pass
// option is set.
581 if (this->hasRedux)
582 populateGpuSubgroupReduceOpLoweringPattern(converter, llvmPatterns);
// NOTE(review): listing line 583 is missing here — presumably `target` gets
// its GPU->NVVM legality configured before the conversion; confirm.
584 ConversionConfig config;
585 config.allowPatternRollback = allowPatternRollback;
586 if (failed(
587 applyPartialConversion(m, target, std::move(llvmPatterns), config)))
588 signalPassFailure();
589 }
590};
591
592} // namespace
593
// Body of a function that configures the conversion `target`.
// NOTE(review): the signature line (listing line 594) is missing from this
// view — presumably `configureGpuToNVVMConversionLegality(ConversionTarget
// &target)`; confirm against the original file.
//
// func.func, cf.assert and the whole gpu dialect are marked illegal; the LLVM
// and NVVM dialects are marked legal. The listed LLVM math ops are then marked
// illegal individually, after the LLVM dialect as a whole was made legal.
// NOTE(review): presumably so these ops get rewritten by the libdevice
// patterns instead of surviving the conversion — confirm.
595 target.addIllegalOp<func::FuncOp>();
596 target.addIllegalOp<cf::AssertOp>();
597 target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
598 target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
599 target.addIllegalDialect<gpu::GPUDialect>();
600 target.addIllegalOp<LLVM::CopySignOp, LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op,
601 LLVM::FAbsOp, LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp,
602 LLVM::LogOp, LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp,
603 LLVM::RoundEvenOp, LLVM::RoundOp, LLVM::SinOp,
604 LLVM::SincosOp, LLVM::SqrtOp>();
605
606 // TODO: Remove once we support replacing non-root ops.
// gpu.yield and gpu.module are explicitly legal even though the gpu dialect
// as a whole was marked illegal above.
607 target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp>();
608}
609
612
// Tail of a function that registers type conversions on `converter`.
// NOTE(review): listing lines 610-611 (the signature and possibly a first
// statement) are missing from this view — presumably this is
// `configureGpuToNVVMTypeConverter(LLVMTypeConverter &converter)`; confirm
// against the original file.
//
// A gpu named-barrier type becomes a literal LLVM struct of two i32 fields.
// The barrier lowerings in this file use field 0 as the barrier id and field
// 1 as the number of threads.
613 converter.addConversion([&](gpu::NamedBarrierType type) -> Type {
614 Type i32 = IntegerType::get(type.getContext(), 32);
615 return LLVM::LLVMStructType::getLiteral(type.getContext(), {i32, i32});
616 });
617
618 // Lowering for MMAMatrixType.
619 converter.addConversion([&](gpu::MMAMatrixType type) -> Type {
620 return convertMMAToLLVMType(type);
621 });
622}
623
// Adds the GPUSubgroupReduceOpLowering pattern to `patterns` with the given
// benefit.
// NOTE(review): the first signature line (listing line 624) is missing from
// this view — presumably `populateGpuSubgroupReduceOpLoweringPattern(`;
// confirm against the original file.
625 const LLVMTypeConverter &converter, RewritePatternSet &patterns,
626 PatternBenefit benefit) {
627 patterns.add<GPUSubgroupReduceOpLowering>(converter, benefit);
628}
629
// Registers the GPU->NVVM conversion patterns on `patterns`, all with the
// caller-supplied `benefit`:
//   * barrier, named-barrier initialization, printf and assert lowerings,
//   * index lowerings that map the gpu id/dim ops (thread, block, cluster,
//     block-in-cluster, grid) onto the corresponding NVVM X/Y/Z ops,
//   * lane id, ballot, shuffle and return lowerings,
//   * the gpu.func lowering, configured with alloca address space 0 and the
//     NVVM shared memory space for workgroup attributions,
//   * the libdevice patterns.
// NOTE(review): this view is missing several listing lines of the function
// (630, 633-634, 652, 656, 660, 664, 667, 673, 681): the first signature line,
// the declarations that make `IndexKind`/`IntrType` visible, the opening
// `patterns.add<...` of five index lowerings, and the openings of the two
// calls whose arguments appear at 674 and 682-691. Consult the original file
// before editing this block.
631 const LLVMTypeConverter &converter, RewritePatternSet &patterns,
632 PatternBenefit benefit) {
635
636 patterns.add<GPUBarrierOpToNVVMLowering,
637 GPUInitializeNamedBarrierOpToNVVMLowering,
638 GPUPrintfOpToVPrintfLowering, AssertOpToAssertfailLowering>(
639 converter, benefit);
640 patterns.add<
641 gpu::index_lowering::OpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp,
642 NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>>(
643 converter, IndexKind::Block, IntrType::Id, benefit);
644 patterns.add<
645 gpu::index_lowering::OpLowering<gpu::BlockDimOp, NVVM::BlockDimXOp,
646 NVVM::BlockDimYOp, NVVM::BlockDimZOp>>(
647 converter, IndexKind::Block, IntrType::Dim, benefit);
648 patterns.add<
649 gpu::index_lowering::OpLowering<gpu::ClusterIdOp, NVVM::ClusterIdXOp,
650 NVVM::ClusterIdYOp, NVVM::ClusterIdZOp>>(
651 converter, IndexKind::Other, IntrType::Id, benefit);
653 gpu::ClusterDimOp, NVVM::ClusterDimXOp, NVVM::ClusterDimYOp,
654 NVVM::ClusterDimZOp>>(converter, IndexKind::Other, IntrType::Dim,
655 benefit);
657 gpu::ClusterBlockIdOp, NVVM::BlockInClusterIdXOp,
658 NVVM::BlockInClusterIdYOp, NVVM::BlockInClusterIdZOp>>(
659 converter, IndexKind::Cluster, IntrType::Id, benefit);
661 gpu::ClusterDimBlocksOp, NVVM::ClusterDimBlocksXOp,
662 NVVM::ClusterDimBlocksYOp, NVVM::ClusterDimBlocksZOp>>(
663 converter, IndexKind::Cluster, IntrType::Dim, benefit);
665 gpu::BlockIdOp, NVVM::BlockIdXOp, NVVM::BlockIdYOp, NVVM::BlockIdZOp>>(
666 converter, IndexKind::Grid, IntrType::Id, benefit);
668 gpu::GridDimOp, NVVM::GridDimXOp, NVVM::GridDimYOp, NVVM::GridDimZOp>>(
669 converter, IndexKind::Grid, IntrType::Dim, benefit);
670 patterns.add<GPULaneIdOpToNVVM, GPUBallotOpToNVVM, GPUShuffleOpLowering,
671 GPUReturnOpLowering>(converter, benefit);
672
674 converter, NVVM::kSharedMemoryAlignmentBit, benefit);
675
676 // Explicitly drop memory space when lowering private memory
677 // attributions since NVVM models it as `alloca`s in the default
678 // memory space and does not support `alloca`s with addrspace(5).
679 patterns.add<GPUFuncOpLowering>(
680 converter,
682 /*allocaAddrSpace=*/0,
683 /*workgroupAddrSpace=*/
684 static_cast<unsigned>(NVVM::NVVMMemorySpace::Shared),
685 StringAttr::get(&converter.getContext(),
686 NVVM::NVVMDialect::getKernelFuncAttrName()),
687 StringAttr::get(&converter.getContext(),
688 NVVM::NVVMDialect::getMaxntidAttrName()),
689 StringAttr::get(&converter.getContext(),
690 NVVM::NVVMDialect::getClusterDimAttrName())},
691 benefit);
692
693 populateLibDeviceConversionPatterns(converter, patterns, benefit);
694}
695
696//===----------------------------------------------------------------------===//
697// NVVMTargetAttr convert to LLVM attr interface
698//===----------------------------------------------------------------------===//
699
700namespace {
// External model that implements ConvertToLLVMAttrInterface for
// NVVM::NVVMTargetAttr; it is attached to the attribute by the registration
// code at the end of this file.
701struct NVVMTargetConvertToLLVMAttrInterface
702 : public ConvertToLLVMAttrInterface::ExternalModel<
703 NVVMTargetConvertToLLVMAttrInterface, NVVM::NVVMTargetAttr> {
704 /// Configure GPU to NVVM.
// NOTE(review): listing line 706 (the leading parameters of this declaration)
// is missing from this view; the out-of-line definition shows an
// `Attribute attr` parameter first. Confirm the full list against the
// original file.
705 void populateConvertToLLVMConversionPatterns(
707 LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) const;
708};
709} // namespace
710
// Interface hook: configures `typeConverter` for GPU types and adds the
// GPU->NVVM conversion patterns (with their default benefit) to `patterns`.
// NOTE(review): listing lines 713 and 716 are missing from this view —
// presumably a conversion-target parameter and a statement that configures
// it; confirm against the original file.
711void NVVMTargetConvertToLLVMAttrInterface::
712 populateConvertToLLVMConversionPatterns(Attribute attr,
714 LLVMTypeConverter &typeConverter,
715 RewritePatternSet &patterns) const {
717 configureGpuToNVVMTypeConverter(typeConverter);
718 populateGpuToNVVMConversionPatterns(typeConverter, patterns);
719}
720
// Adds a dialect extension to `registry` whose callback attaches
// NVVMTargetConvertToLLVMAttrInterface to NVVMTargetAttr in the given context.
// NOTE(review): the signature line (listing line 721) is missing from this
// view — presumably `registerConvertGpuToNVVMInterface(DialectRegistry
// &registry)`; confirm against the original file.
722 registry.addExtension(+[](MLIRContext *ctx, NVVMDialect *dialect) {
723 NVVMTargetAttr::attachInterface<NVVMTargetConvertToLLVMAttrInterface>(*ctx);
724 });
725}
return success()
b getContext())
constexpr int kWarpSize
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
Definition Attributes.h:25
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
Definition Block.h:203
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition Pattern.h:227
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition Pattern.h:233
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
MLIRContext & getContext() const
Returns the MLIR context.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:209
A trait used to provide symbol table functionalities to a region operation.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:711
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition Operation.h:273
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Block & front()
Definition Region.h:65
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition SymbolTable.h:24
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
Definition GPUDialect.h:139
constexpr int kSharedMemoryAlignmentBit
Definition NVVMDialect.h:49
void registerConvertGpuToNVVMInterface(DialectRegistry &registry)
Registers the ConvertToLLVMAttrInterface interface on the NVVM::NVVMTargetAttr attribute.
void populateCommonGPUTypeAndAttributeConversions(TypeConverter &typeConverter)
Remap common GPU memory spaces (Workgroup, Private, etc) to LLVM address spaces.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
void populateGpuRewritePatterns(RewritePatternSet &patterns)
Collect all patterns to rewrite ops within the GPU dialect.
Definition Passes.h:91
Type convertMMAToLLVMType(gpu::MMAMatrixType type)
Return the LLVMStructureType corresponding to the MMAMatrixType type.
LLVM::LLVMFuncOp getOrDefineFunction(Operation *moduleOp, Location loc, OpBuilder &b, StringRef name, LLVM::LLVMFunctionType type)
Note that these functions don't take a SymbolTable because GPU module lowerings can have name collisi...
void configureGpuToNVVMTypeConverter(LLVMTypeConverter &converter)
Configure the LLVM type converter to convert types and address spaces from the GPU dialect to NVVM.
void configureGpuToNVVMConversionLegality(ConversionTarget &target)
Configure target to convert from the GPU dialect to NVVM.
void registerConvertToLLVMDependentDialectLoading(DialectRegistry &registry)
Register the extension that will load dependent dialects for LLVM conversion.
void populateGpuSubgroupReduceOpLoweringPattern(const LLVMTypeConverter &converter, RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate GpuSubgroupReduce pattern to NVVM.
void populateGpuToNVVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of patterns to convert from the GPU dialect to NVVM.
LLVM::GlobalOp getOrCreateStringConstant(OpBuilder &b, Location loc, Operation *moduleOp, Type llvmI8, StringRef namePrefix, StringRef str, uint64_t alignment=0, unsigned addrSpace=0)
Create a global that contains the given string.
void populateLibDeviceConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the given list with patterns that convert from Math to NVVM libdevice calls.
void populateGpuWMMAToNVVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of patterns to convert WMMA ops from GPU dialect to NVVM.
Lowering for gpu.dynamic.shared.memory to LLVM dialect.
Lowering of gpu.printf to a vprintf standard library.