LowerGpuOpsToNVVMOps.cpp (MLIR 23.0.0git)
1//===- LowerGpuOpsToNVVMOps.cpp - MLIR GPU to NVVM lowering passes --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a pass to generate NVVMIR operations for higher-level
10// GPU operations.
11//
12//===----------------------------------------------------------------------===//
13
34
38#include <optional>
39
40namespace mlir {
41#define GEN_PASS_DEF_CONVERTGPUOPSTONVVMOPS
42#include "mlir/Conversion/Passes.h.inc"
43} // namespace mlir
44
45using namespace mlir;
46
47namespace {
48
49/// Convert gpu dialect shfl mode enum to the equivalent nvvm one.
50static NVVM::ShflKind convertShflKind(gpu::ShuffleMode mode) {
51 switch (mode) {
52 case gpu::ShuffleMode::XOR:
53 return NVVM::ShflKind::bfly;
54 case gpu::ShuffleMode::UP:
55 return NVVM::ShflKind::up;
56 case gpu::ShuffleMode::DOWN:
57 return NVVM::ShflKind::down;
58 case gpu::ShuffleMode::IDX:
59 return NVVM::ShflKind::idx;
60 }
61 llvm_unreachable("unknown shuffle mode");
62}
63
64static std::optional<NVVM::ReductionKind>
65convertToNVVMReductionKind(gpu::AllReduceOperation mode) {
66 switch (mode) {
67 case gpu::AllReduceOperation::ADD:
68 return NVVM::ReductionKind::ADD;
69 case gpu::AllReduceOperation::MUL:
70 return std::nullopt;
71 case gpu::AllReduceOperation::MINSI:
72 return NVVM::ReductionKind::MIN;
73 case gpu::AllReduceOperation::MINUI:
74 return std::nullopt;
75 case gpu::AllReduceOperation::MINNUMF:
76 return NVVM::ReductionKind::MIN;
77 case gpu::AllReduceOperation::MAXSI:
78 return NVVM::ReductionKind::MAX;
79 case gpu::AllReduceOperation::MAXUI:
80 return std::nullopt;
81 case gpu::AllReduceOperation::MAXNUMF:
82 return NVVM::ReductionKind::MAX;
83 case gpu::AllReduceOperation::AND:
84 return NVVM::ReductionKind::AND;
85 case gpu::AllReduceOperation::OR:
86 return NVVM::ReductionKind::OR;
87 case gpu::AllReduceOperation::XOR:
88 return NVVM::ReductionKind::XOR;
89 case gpu::AllReduceOperation::MINIMUMF:
90 case gpu::AllReduceOperation::MAXIMUMF:
91 return std::nullopt;
92 }
93 return std::nullopt;
94}
95
96/// This pattern lowers the gpu.subgroup_reduce op into the nvvm.redux op. The
97/// op must be run by the entire subgroup, otherwise it is undefined behaviour.
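/// An illustrative example (not part of the original file): a uniform i32 add
/// reduction such as
///   %sum = gpu.subgroup_reduce add %val uniform : (i32) -> i32
/// is rewritten to, roughly,
///   %mask = llvm.mlir.constant(-1 : i32) : i32
///   %sum  = nvvm.redux.sync add %val, %mask : i32 -> i32
/// where the all-ones membermask reflects the uniformity requirement above.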
98struct GPUSubgroupReduceOpLowering
99 : public ConvertOpToLLVMPattern<gpu::SubgroupReduceOp> {
100 using ConvertOpToLLVMPattern<gpu::SubgroupReduceOp>::ConvertOpToLLVMPattern;
101
102 LogicalResult
103 matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
104 ConversionPatternRewriter &rewriter) const override {
105 if (op.getClusterSize())
106 return rewriter.notifyMatchFailure(
107 op, "lowering for clustered reduce not implemented");
108
109 if (!op.getUniform())
110 return rewriter.notifyMatchFailure(
111 op, "cannot be lowered to redux as the op must be run "
112 "uniformly (entire subgroup).");
113 if (!op.getValue().getType().isInteger(32))
114 return rewriter.notifyMatchFailure(op, "unsupported data type");
115
116 std::optional<NVVM::ReductionKind> mode =
117 convertToNVVMReductionKind(op.getOp());
118 if (!mode.has_value())
119 return rewriter.notifyMatchFailure(
120 op, "unsupported reduction mode for redux");
121
122 Location loc = op->getLoc();
123 auto int32Type = IntegerType::get(rewriter.getContext(), 32);
124 Value offset = LLVM::ConstantOp::create(rewriter, loc, int32Type, -1);
125
126 auto reduxOp = NVVM::ReduxOp::create(rewriter, loc, int32Type,
127 op.getValue(), mode.value(), offset);
128
129 rewriter.replaceOp(op, reduxOp->getResult(0));
130 return success();
131 }
132};
133
134struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
135 using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern;
136
137 /// Lowers a shuffle to the corresponding NVVM op.
138 ///
139 /// Convert the `width` argument into an activeMask (a bitmask which specifies
140 /// which threads participate in the shuffle) and a maskAndClamp (specifying
141 /// the highest lane which participates in the shuffle).
142 ///
143 /// %one = llvm.constant(1 : i32) : i32
144 /// %minus_one = llvm.constant(-1 : i32) : i32
145 /// %thirty_two = llvm.constant(32 : i32) : i32
146 /// %num_lanes = llvm.sub %thirty_two, %width : i32
147 /// %active_mask = llvm.lshr %minus_one, %num_lanes : i32
148 /// %mask_and_clamp = llvm.sub %width, %one : i32
149 /// %shfl = nvvm.shfl.sync.bfly %active_mask, %value, %offset,
150 /// %mask_and_clamp : !llvm<"{ float, i1 }">
151 /// %shfl_value = llvm.extractvalue %shfl[0] :
152 /// !llvm<"{ float, i1 }">
153 /// %shfl_pred = llvm.extractvalue %shfl[1] :
154 /// !llvm<"{ float, i1 }">
155 LogicalResult
156 matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
157 ConversionPatternRewriter &rewriter) const override {
158 Location loc = op->getLoc();
159
160 auto valueTy = adaptor.getValue().getType();
161 auto int32Type = IntegerType::get(rewriter.getContext(), 32);
162 auto predTy = IntegerType::get(rewriter.getContext(), 1);
163
164 Value one = LLVM::ConstantOp::create(rewriter, loc, int32Type, 1);
165 Value minusOne = LLVM::ConstantOp::create(rewriter, loc, int32Type, -1);
166 Value thirtyTwo = LLVM::ConstantOp::create(rewriter, loc, int32Type, 32);
167 Value numLeadInactiveLane = LLVM::SubOp::create(
168 rewriter, loc, int32Type, thirtyTwo, adaptor.getWidth());
169 // Bit mask of active lanes: `(-1) >> (32 - activeWidth)`.
170 Value activeMask = LLVM::LShrOp::create(rewriter, loc, int32Type, minusOne,
171 numLeadInactiveLane);
172 Value maskAndClamp;
173 if (op.getMode() == gpu::ShuffleMode::UP) {
174 // Clamp lane: `32 - activeWidth`
175 maskAndClamp = numLeadInactiveLane;
176 } else {
177 // Clamp lane: `activeWidth - 1`
178 maskAndClamp = LLVM::SubOp::create(rewriter, loc, int32Type,
179 adaptor.getWidth(), one);
180 }
181
182 bool predIsUsed = !op->getResult(1).use_empty();
183 UnitAttr returnValueAndIsValidAttr = nullptr;
184 Type resultTy = valueTy;
185 if (predIsUsed) {
186 returnValueAndIsValidAttr = rewriter.getUnitAttr();
187 resultTy = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
188 {valueTy, predTy});
189 }
190 Value shfl = NVVM::ShflOp::create(
191 rewriter, loc, resultTy, activeMask, adaptor.getValue(),
192 adaptor.getOffset(), maskAndClamp, convertShflKind(op.getMode()),
193 returnValueAndIsValidAttr);
194 if (predIsUsed) {
195 Value shflValue = LLVM::ExtractValueOp::create(rewriter, loc, shfl, 0);
196 Value isActiveSrcLane =
197 LLVM::ExtractValueOp::create(rewriter, loc, shfl, 1);
198 rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
199 } else {
200 rewriter.replaceOp(op, {shfl, nullptr});
201 }
202 return success();
203 }
204};
205
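/// Lowers gpu.lane_id to NVVM's lane-id special register read (NVVM::LaneIdOp),
/// attaching a [0, 32) result range (or the op's tighter upper bound when one
/// is provided) and then sign-extending or truncating the i32 result to the
/// index bitwidth configured on the type converter.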
206struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
207 using ConvertOpToLLVMPattern<gpu::LaneIdOp>::ConvertOpToLLVMPattern;
208
209 LogicalResult
210 matchAndRewrite(gpu::LaneIdOp op, gpu::LaneIdOp::Adaptor adaptor,
211 ConversionPatternRewriter &rewriter) const override {
212 auto loc = op->getLoc();
213 MLIRContext *context = rewriter.getContext();
214 LLVM::ConstantRangeAttr bounds = nullptr;
215 if (std::optional<APInt> upperBound = op.getUpperBound())
216 bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
217 /*bitWidth=*/32, /*lower=*/0, upperBound->getZExtValue());
218 else
219 bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
220 /*bitWidth=*/32, /*lower=*/0, /*upper=*/kWarpSize);
221 Value newOp =
222 NVVM::LaneIdOp::create(rewriter, loc, rewriter.getI32Type(), bounds);
223 // Truncate or extend the result depending on the index bitwidth specified
224 // by the LLVMTypeConverter options.
225 const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
226 if (indexBitwidth > 32) {
227 newOp = LLVM::SExtOp::create(
228 rewriter, loc, IntegerType::get(context, indexBitwidth), newOp);
229 } else if (indexBitwidth < 32) {
230 newOp = LLVM::TruncOp::create(
231 rewriter, loc, IntegerType::get(context, indexBitwidth), newOp);
232 }
233 rewriter.replaceOp(op, {newOp});
234 return success();
235 }
236};
237
238/// Lowering of cf.assert into a conditional __assertfail.
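///
/// For reference (assumed prototype, not part of the original file), the
/// device-side handler declared below corresponds to:
///   void __assertfail(const char *message, const char *file, unsigned line,
///                     const char *function, size_t charSize);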
239struct AssertOpToAssertfailLowering
240 : public ConvertOpToLLVMPattern<cf::AssertOp> {
241 using ConvertOpToLLVMPattern<cf::AssertOp>::ConvertOpToLLVMPattern;
242
243 LogicalResult
244 matchAndRewrite(cf::AssertOp assertOp, cf::AssertOpAdaptor adaptor,
245 ConversionPatternRewriter &rewriter) const override {
246 MLIRContext *ctx = rewriter.getContext();
247 Location loc = assertOp.getLoc();
248 Type i8Type = typeConverter->convertType(rewriter.getIntegerType(8));
249 Type i32Type = typeConverter->convertType(rewriter.getIntegerType(32));
250 Type i64Type = typeConverter->convertType(rewriter.getIntegerType(64));
251 Type ptrType = LLVM::LLVMPointerType::get(ctx);
252 Type voidType = LLVM::LLVMVoidType::get(ctx);
253
254 // Find or create __assertfail function declaration.
255 auto moduleOp = assertOp->getParentOfType<gpu::GPUModuleOp>();
256 auto assertfailType = LLVM::LLVMFunctionType::get(
257 voidType, {ptrType, ptrType, i32Type, ptrType, i64Type});
258 LLVM::LLVMFuncOp assertfailDecl = getOrDefineFunction(
259 moduleOp, loc, rewriter, "__assertfail", assertfailType);
260 assertfailDecl.setPassthroughAttr(
261 ArrayAttr::get(ctx, StringAttr::get(ctx, "noreturn")));
262
263 // Split blocks and insert conditional branch.
264 // ^before:
265 // ...
266 // cf.cond_br %condition, ^after, ^assert
267 // ^assert:
268 // cf.assert
269 // cf.br ^after
270 // ^after:
271 // ...
272 Block *beforeBlock = assertOp->getBlock();
273 Block *assertBlock =
274 rewriter.splitBlock(beforeBlock, assertOp->getIterator());
275 Block *afterBlock =
276 rewriter.splitBlock(assertBlock, ++assertOp->getIterator());
277 rewriter.setInsertionPointToEnd(beforeBlock);
278 cf::CondBranchOp::create(rewriter, loc, adaptor.getArg(), afterBlock,
279 assertBlock);
280 rewriter.setInsertionPointToEnd(assertBlock);
281 cf::BranchOp::create(rewriter, loc, afterBlock);
282
283 // Continue cf.assert lowering.
284 rewriter.setInsertionPoint(assertOp);
285
286 // Populate file name, file number and function name from the location of
287 // the AssertOp.
288 StringRef fileName = "(unknown)";
289 StringRef funcName = "(unknown)";
290 int32_t fileLine = 0;
291 while (auto callSiteLoc = dyn_cast<CallSiteLoc>(loc))
292 loc = callSiteLoc.getCallee();
293 if (auto fileLineColLoc = dyn_cast<FileLineColRange>(loc)) {
294 fileName = fileLineColLoc.getFilename().strref();
295 fileLine = fileLineColLoc.getStartLine();
296 } else if (auto nameLoc = dyn_cast<NameLoc>(loc)) {
297 funcName = nameLoc.getName().strref();
298 if (auto fileLineColLoc =
299 dyn_cast<FileLineColRange>(nameLoc.getChildLoc())) {
300 fileName = fileLineColLoc.getFilename().strref();
301 fileLine = fileLineColLoc.getStartLine();
302 }
303 }
304
305 // Create constants.
306 auto getGlobal = [&](LLVM::GlobalOp global) {
307 // Get a pointer to the format string's first element.
308 Value globalPtr = LLVM::AddressOfOp::create(
309 rewriter, loc, LLVM::LLVMPointerType::get(ctx, global.getAddrSpace()),
310 global.getSymNameAttr());
311 Value start =
312 LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(),
313 globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
314 return start;
315 };
316 Value assertMessage = getGlobal(getOrCreateStringConstant(
317 rewriter, loc, moduleOp, i8Type, "assert_message_", assertOp.getMsg()));
318 Value assertFile = getGlobal(getOrCreateStringConstant(
319 rewriter, loc, moduleOp, i8Type, "assert_file_", fileName));
320 Value assertFunc = getGlobal(getOrCreateStringConstant(
321 rewriter, loc, moduleOp, i8Type, "assert_func_", funcName));
322 Value assertLine =
323 LLVM::ConstantOp::create(rewriter, loc, i32Type, fileLine);
324 Value c1 = LLVM::ConstantOp::create(rewriter, loc, i64Type, 1);
325
326 // Insert function call to __assertfail.
327 SmallVector<Value> arguments{assertMessage, assertFile, assertLine,
328 assertFunc, c1};
329 rewriter.replaceOpWithNewOp<LLVM::CallOp>(assertOp, assertfailDecl,
330 arguments);
331 return success();
332 }
333};
334
335/// Import the GPU Ops to NVVM Patterns.
336#include "GPUToNVVM.cpp.inc"
337
338/// A pass that replaces all occurrences of GPU device operations with their
339/// corresponding NVVM equivalent.
340///
341/// This pass only handles device code and is not meant to be run on GPU host
342/// code.
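///
/// Typical invocation (a sketch; the surrounding pipeline may differ):
///   mlir-opt --pass-pipeline='builtin.module(gpu.module(convert-gpu-to-nvvm))' in.mlir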
343struct LowerGpuOpsToNVVMOpsPass final
344 : public impl::ConvertGpuOpsToNVVMOpsBase<LowerGpuOpsToNVVMOpsPass> {
345 using Base::Base;
346
347 void getDependentDialects(DialectRegistry &registry) const override {
348 Base::getDependentDialects(registry);
349 registerConvertToLLVMDependentDialectLoading(registry);
350 }
351
352 void runOnOperation() override {
353 gpu::GPUModuleOp m = getOperation();
354
355 // Request C wrapper emission.
356 for (auto func : m.getOps<func::FuncOp>()) {
357 func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
358 UnitAttr::get(&getContext()));
359 }
360
361 // Customize the bitwidth used for the device side index computations.
362 LowerToLLVMOptions options(
363 m.getContext(),
364 DataLayout(cast<DataLayoutOpInterface>(m.getOperation())));
365 if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
366 options.overrideIndexBitwidth(indexBitwidth);
367 options.useBarePtrCallConv = useBarePtrCallConv;
368
369 // Apply in-dialect lowering. In-dialect lowering will replace
370 // ops which need to be lowered further, which is not supported by a
371 // single conversion pass.
372 {
373 RewritePatternSet patterns(m.getContext());
374 populateGpuRewritePatterns(patterns);
375 // Transform N-D vector.from_elements to 1-D vector.from_elements before
376 // conversion.
377 vector::populateVectorFromElementsUnrollPatterns(patterns);
378 if (failed(applyPatternsGreedily(m, std::move(patterns))))
379 return signalPassFailure();
380 }
381
382 LLVMTypeConverter converter(m.getContext(), options);
383 configureGpuToNVVMTypeConverter(converter);
384 RewritePatternSet llvmPatterns(m.getContext());
385 LLVMConversionTarget target(getContext());
386
387 // Set higher benefit, so patterns will run before generic LLVM lowering.
388 populateGpuToNVVMConversionPatterns(converter, llvmPatterns,
389 /*benefit=*/10);
390
391 llvm::SmallDenseSet<StringRef> allowedDialectsSet(allowedDialects.begin(),
392 allowedDialects.end());
393 for (Dialect *dialect : getContext().getLoadedDialects()) {
394 // Skip math patterns as nvvm needs custom math lowering.
395 if (isa<math::MathDialect>(dialect))
396 continue;
397
398 bool allowed = allowedDialectsSet.contains(dialect->getNamespace());
399 // Empty `allowedDialectsSet` means all dialects are allowed.
400 if (!allowedDialectsSet.empty() && !allowed)
401 continue;
402
403 auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
404 if (!iface) {
405 // Error out if dialect was explicitly specified but doesn't implement
406 // conversion interface.
407 if (allowed) {
408 m.emitError()
409 << "dialect does not implement ConvertToLLVMPatternInterface: "
410 << dialect->getNamespace();
411 return signalPassFailure();
412 }
413 continue;
414 }
415
416 iface->populateConvertToLLVMConversionPatterns(target, converter,
417 llvmPatterns);
418 }
419
420 populateGpuWMMAToNVVMConversionPatterns(converter, llvmPatterns);
421 if (this->hasRedux)
422 populateGpuSubgroupReduceOpLoweringPattern(converter, llvmPatterns);
423 configureGpuToNVVMConversionLegality(target);
424 ConversionConfig config;
425 config.allowPatternRollback = allowPatternRollback;
426 if (failed(
427 applyPartialConversion(m, target, std::move(llvmPatterns), config)))
428 signalPassFailure();
429 }
430};
431
432} // namespace
433
434void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
435 target.addIllegalOp<func::FuncOp>();
436 target.addIllegalOp<cf::AssertOp>();
437 target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
438 target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
439 target.addIllegalDialect<gpu::GPUDialect>();
440 target.addIllegalOp<LLVM::CopySignOp, LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op,
441 LLVM::FAbsOp, LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp,
442 LLVM::LogOp, LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp,
443 LLVM::RoundEvenOp, LLVM::RoundOp, LLVM::SinOp,
444 LLVM::SincosOp, LLVM::SqrtOp>();
445
446 // TODO: Remove once we support replacing non-root ops.
447 target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp>();
448}
449
450void mlir::configureGpuToNVVMTypeConverter(LLVMTypeConverter &converter) {
451 populateCommonGPUTypeAndAttributeConversions(converter);
452
453 // Lowering for MMAMatrixType.
454 converter.addConversion([&](gpu::MMAMatrixType type) -> Type {
455 return convertMMAToLLVMType(type);
456 });
457}
458
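/// Lowers math.sincos to a call to the libdevice helpers (__nv_sincosf or
/// __nv_fast_sincosf for f32, __nv_sincos for f64). The two results are
/// returned through stack slots allocated in the enclosing automatic
/// allocation scope and reloaded after the call; f16/bf16 operands are
/// extended to f32 first and the results truncated back.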
459struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
460 using ConvertOpToLLVMPattern<math::SincosOp>::ConvertOpToLLVMPattern;
461
462 LogicalResult
463 matchAndRewrite(math::SincosOp op, OpAdaptor adaptor,
464 ConversionPatternRewriter &rewriter) const override {
465 Location loc = op.getLoc();
466 Value input = adaptor.getOperand();
467 Type inputType = input.getType();
468 auto convertedInput = maybeExt(input, rewriter);
469 auto computeType = convertedInput.getType();
470
471 StringRef sincosFunc;
472 if (isa<Float32Type>(computeType)) {
473 const arith::FastMathFlags flag = op.getFastmath();
474 const bool useApprox =
475 mlir::arith::bitEnumContainsAny(flag, arith::FastMathFlags::afn);
476 sincosFunc = useApprox ? "__nv_fast_sincosf" : "__nv_sincosf";
477 } else if (isa<Float64Type>(computeType)) {
478 sincosFunc = "__nv_sincos";
479 } else {
480 return rewriter.notifyMatchFailure(op,
481 "unsupported operand type for sincos");
482 }
483
484 auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
485
486 Value sinPtr, cosPtr;
487 {
488 OpBuilder::InsertionGuard guard(rewriter);
489 auto *scope =
490 op->getParentWithTrait<mlir::OpTrait::AutomaticAllocationScope>();
491 assert(scope && "Expected op to be inside automatic allocation scope");
492 rewriter.setInsertionPointToStart(&scope->getRegion(0).front());
493 auto one = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
494 rewriter.getI32IntegerAttr(1));
495 sinPtr =
496 LLVM::AllocaOp::create(rewriter, loc, ptrType, computeType, one, 0);
497 cosPtr =
498 LLVM::AllocaOp::create(rewriter, loc, ptrType, computeType, one, 0);
499 }
500
501 createSincosCall(rewriter, loc, sincosFunc, convertedInput, sinPtr, cosPtr,
502 op);
503
504 auto sinResult = LLVM::LoadOp::create(rewriter, loc, computeType, sinPtr);
505 auto cosResult = LLVM::LoadOp::create(rewriter, loc, computeType, cosPtr);
506
507 rewriter.replaceOp(op, {maybeTrunc(sinResult, inputType, rewriter),
508 maybeTrunc(cosResult, inputType, rewriter)});
509 return success();
510 }
511
512private:
513 Value maybeExt(Value operand, PatternRewriter &rewriter) const {
514 if (isa<Float16Type, BFloat16Type>(operand.getType()))
515 return LLVM::FPExtOp::create(rewriter, operand.getLoc(),
516 Float32Type::get(rewriter.getContext()),
517 operand);
518 return operand;
519 }
520
521 Value maybeTrunc(Value operand, Type type, PatternRewriter &rewriter) const {
522 if (operand.getType() != type)
523 return LLVM::FPTruncOp::create(rewriter, operand.getLoc(), type, operand);
524 return operand;
525 }
526
527 void createSincosCall(ConversionPatternRewriter &rewriter, Location loc,
528 StringRef funcName, Value input, Value sinPtr,
529 Value cosPtr, Operation *op) const {
530 auto voidType = LLVM::LLVMVoidType::get(rewriter.getContext());
531 auto ptrType = sinPtr.getType();
532
533 SmallVector<Type> operandTypes = {input.getType(), ptrType, ptrType};
534 auto funcType = LLVM::LLVMFunctionType::get(voidType, operandTypes);
535
536 auto funcAttr = StringAttr::get(op->getContext(), funcName);
537 auto funcOp =
538 dyn_cast_or_null<LLVM::LLVMFuncOp>(SymbolTable::lookupNearestSymbolFrom(op, funcAttr));
539
540 if (!funcOp) {
541 auto parentFunc = op->getParentOfType<FunctionOpInterface>();
542 assert(parentFunc && "expected there to be a parent function");
543 OpBuilder b(parentFunc);
544
545 auto globalloc = loc->findInstanceOfOrUnknown<FileLineColLoc>();
546 funcOp = LLVM::LLVMFuncOp::create(b, globalloc, funcName, funcType);
547 }
548
549 SmallVector<Value> callOperands = {input, sinPtr, cosPtr};
550 LLVM::CallOp::create(rewriter, loc, funcOp, callOperands);
551 }
552};
553
554template <typename OpTy>
555static void populateOpPatterns(const LLVMTypeConverter &converter,
556 RewritePatternSet &patterns,
557 PatternBenefit benefit, StringRef f32Func,
558 StringRef f64Func, StringRef f32ApproxFunc = "",
559 StringRef f16Func = "") {
560 patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit);
561 patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
562 f32ApproxFunc, f16Func,
563 /*i32Func=*/"", benefit);
564}
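// Illustrative effect of the helper above (a sketch, not from the original
// source): with f32Func = "__nv_expf", a scalar `math.exp %x : f32` becomes
// `llvm.call @__nv_expf(%x) : (f32) -> f32`; vector operands are first
// unrolled by ScalarizeVectorOpLowering and each element is lowered the same
// way.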
565
566template <typename OpTy>
567static void populateIntOpPatterns(const LLVMTypeConverter &converter,
568 RewritePatternSet &patterns,
569 PatternBenefit benefit, StringRef i32Func) {
570 patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit);
571 patterns.add<OpToFuncCallLowering<OpTy>>(converter, "", "", "", "", i32Func,
572 benefit);
573}
574
575template <typename OpTy>
576static void populateFloatIntOpPatterns(const LLVMTypeConverter &converter,
577 RewritePatternSet &patterns,
578 PatternBenefit benefit,
579 StringRef f32Func, StringRef f64Func) {
580 patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit);
581 patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func, "", "",
582 /*i32Func=*/"", benefit);
583}
584
585void mlir::populateGpuSubgroupReduceOpLoweringPattern(
586 const LLVMTypeConverter &converter, RewritePatternSet &patterns,
587 PatternBenefit benefit) {
588 patterns.add<GPUSubgroupReduceOpLowering>(converter, benefit);
589}
590
591void mlir::populateLibDeviceConversionPatterns(
592 const LLVMTypeConverter &converter, RewritePatternSet &patterns,
593 PatternBenefit benefit) {
594 populateOpPatterns<arith::RemFOp>(converter, patterns, benefit, "__nv_fmodf",
595 "__nv_fmod");
597 "__nv_fmaxf", "__nv_fmax");
599 "__nv_fminf", "__nv_fmin");
600
601 populateIntOpPatterns<math::AbsIOp>(converter, patterns, benefit, "__nv_abs");
602 populateOpPatterns<math::AbsFOp>(converter, patterns, benefit, "__nv_fabsf",
603 "__nv_fabs");
604 populateOpPatterns<math::AcosOp>(converter, patterns, benefit, "__nv_acosf",
605 "__nv_acos");
606 populateOpPatterns<math::AcoshOp>(converter, patterns, benefit, "__nv_acoshf",
607 "__nv_acosh");
608 populateOpPatterns<math::AsinOp>(converter, patterns, benefit, "__nv_asinf",
609 "__nv_asin");
610 populateOpPatterns<math::AsinhOp>(converter, patterns, benefit, "__nv_asinhf",
611 "__nv_asinh");
612 populateOpPatterns<math::AtanOp>(converter, patterns, benefit, "__nv_atanf",
613 "__nv_atan");
614 populateOpPatterns<math::Atan2Op>(converter, patterns, benefit, "__nv_atan2f",
615 "__nv_atan2");
616 populateOpPatterns<math::AtanhOp>(converter, patterns, benefit, "__nv_atanhf",
617 "__nv_atanh");
618 populateOpPatterns<math::CbrtOp>(converter, patterns, benefit, "__nv_cbrtf",
619 "__nv_cbrt");
620 populateOpPatterns<math::CeilOp>(converter, patterns, benefit, "__nv_ceilf",
621 "__nv_ceil");
623 "__nv_copysignf", "__nv_copysign");
624 populateOpPatterns<math::CosOp>(converter, patterns, benefit, "__nv_cosf",
625 "__nv_cos", "__nv_fast_cosf");
626 populateOpPatterns<math::CoshOp>(converter, patterns, benefit, "__nv_coshf",
627 "__nv_cosh");
628 populateOpPatterns<math::ErfOp>(converter, patterns, benefit, "__nv_erff",
629 "__nv_erf");
630 populateOpPatterns<math::ErfcOp>(converter, patterns, benefit, "__nv_erfcf",
631 "__nv_erfc");
632 populateOpPatterns<math::ExpOp>(converter, patterns, benefit, "__nv_expf",
633 "__nv_exp", "__nv_fast_expf");
634 populateOpPatterns<math::Exp2Op>(converter, patterns, benefit, "__nv_exp2f",
635 "__nv_exp2");
636 populateOpPatterns<math::ExpM1Op>(converter, patterns, benefit, "__nv_expm1f",
637 "__nv_expm1");
638 populateOpPatterns<math::FloorOp>(converter, patterns, benefit, "__nv_floorf",
639 "__nv_floor");
640 populateOpPatterns<math::FmaOp>(converter, patterns, benefit, "__nv_fmaf",
641 "__nv_fma");
642 // Note: libdevice uses a different name for 32-bit finite checking
644 "__nv_finitef", "__nv_isfinited");
645 populateOpPatterns<math::IsInfOp>(converter, patterns, benefit, "__nv_isinff",
646 "__nv_isinfd");
647 populateOpPatterns<math::IsNaNOp>(converter, patterns, benefit, "__nv_isnanf",
648 "__nv_isnand");
649 populateOpPatterns<math::LogOp>(converter, patterns, benefit, "__nv_logf",
650 "__nv_log", "__nv_fast_logf");
651 populateOpPatterns<math::Log10Op>(converter, patterns, benefit, "__nv_log10f",
652 "__nv_log10", "__nv_fast_log10f");
653 populateOpPatterns<math::Log1pOp>(converter, patterns, benefit, "__nv_log1pf",
654 "__nv_log1p");
655 populateOpPatterns<math::Log2Op>(converter, patterns, benefit, "__nv_log2f",
656 "__nv_log2", "__nv_fast_log2f");
657 populateOpPatterns<math::PowFOp>(converter, patterns, benefit, "__nv_powf",
658 "__nv_pow", "__nv_fast_powf");
660 "__nv_powif", "__nv_powi");
661 populateOpPatterns<math::RoundOp>(converter, patterns, benefit, "__nv_roundf",
662 "__nv_round");
664 "__nv_rintf", "__nv_rint");
665 populateOpPatterns<math::RsqrtOp>(converter, patterns, benefit, "__nv_rsqrtf",
666 "__nv_rsqrt");
667 populateOpPatterns<math::SinOp>(converter, patterns, benefit, "__nv_sinf",
668 "__nv_sin", "__nv_fast_sinf");
669 populateOpPatterns<math::SinhOp>(converter, patterns, benefit, "__nv_sinhf",
670 "__nv_sinh");
671 populateOpPatterns<math::SqrtOp>(converter, patterns, benefit, "__nv_sqrtf",
672 "__nv_sqrt");
673 populateOpPatterns<math::TanOp>(converter, patterns, benefit, "__nv_tanf",
674 "__nv_tan", "__nv_fast_tanf");
675 populateOpPatterns<math::TanhOp>(converter, patterns, benefit, "__nv_tanhf",
676 "__nv_tanh");
677
678 // Custom pattern for sincos since it returns two values
679 patterns.add<SincosOpLowering>(converter, benefit);
680}
681
682void mlir::populateGpuToNVVMConversionPatterns(
683 const LLVMTypeConverter &converter, RewritePatternSet &patterns,
684 PatternBenefit benefit) {
685 using gpu::index_lowering::IndexKind;
686 using gpu::index_lowering::IntrType;
687
688 // TODO: Pass benefit to generated patterns.
689 populateWithGenerated(patterns);
690
691 patterns.add<GPUPrintfOpToVPrintfLowering, AssertOpToAssertfailLowering>(
692 converter, benefit);
693 patterns.add<
694 gpu::index_lowering::OpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp,
695 NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>>(
696 converter, IndexKind::Block, IntrType::Id, benefit);
697 patterns.add<
698 gpu::index_lowering::OpLowering<gpu::BlockDimOp, NVVM::BlockDimXOp,
699 NVVM::BlockDimYOp, NVVM::BlockDimZOp>>(
700 converter, IndexKind::Block, IntrType::Dim, benefit);
701 patterns.add<
702 gpu::index_lowering::OpLowering<gpu::ClusterIdOp, NVVM::ClusterIdXOp,
703 NVVM::ClusterIdYOp, NVVM::ClusterIdZOp>>(
704 converter, IndexKind::Other, IntrType::Id, benefit);
705 patterns.add<gpu::index_lowering::OpLowering<
706 gpu::ClusterDimOp, NVVM::ClusterDimXOp, NVVM::ClusterDimYOp,
707 NVVM::ClusterDimZOp>>(converter, IndexKind::Other, IntrType::Dim,
708 benefit);
709 patterns.add<gpu::index_lowering::OpLowering<
710 gpu::ClusterBlockIdOp, NVVM::BlockInClusterIdXOp,
711 NVVM::BlockInClusterIdYOp, NVVM::BlockInClusterIdZOp>>(
712 converter, IndexKind::Cluster, IntrType::Id, benefit);
713 patterns.add<gpu::index_lowering::OpLowering<
714 gpu::ClusterDimBlocksOp, NVVM::ClusterDimBlocksXOp,
715 NVVM::ClusterDimBlocksYOp, NVVM::ClusterDimBlocksZOp>>(
716 converter, IndexKind::Cluster, IntrType::Dim, benefit);
717 patterns.add<gpu::index_lowering::OpLowering<
718 gpu::BlockIdOp, NVVM::BlockIdXOp, NVVM::BlockIdYOp, NVVM::BlockIdZOp>>(
719 converter, IndexKind::Grid, IntrType::Id, benefit);
720 patterns.add<gpu::index_lowering::OpLowering<
721 gpu::GridDimOp, NVVM::GridDimXOp, NVVM::GridDimYOp, NVVM::GridDimZOp>>(
722 converter, IndexKind::Grid, IntrType::Dim, benefit);
723 patterns.add<GPULaneIdOpToNVVM, GPUShuffleOpLowering, GPUReturnOpLowering>(
724 converter, benefit);
725
726 patterns.add<GPUDynamicSharedMemoryOpLowering>(
727 converter, NVVM::kSharedMemoryAlignmentBit, benefit);
728
729 // Explicitly drop memory space when lowering private memory
730 // attributions since NVVM models it as `alloca`s in the default
731 // memory space and does not support `alloca`s with addrspace(5).
732 patterns.add<GPUFuncOpLowering>(
733 converter,
734 GPUFuncOpLoweringOptions{
735 /*allocaAddrSpace=*/0,
736 /*workgroupAddrSpace=*/
737 static_cast<unsigned>(NVVM::NVVMMemorySpace::Shared),
738 StringAttr::get(&converter.getContext(),
739 NVVM::NVVMDialect::getKernelFuncAttrName()),
740 StringAttr::get(&converter.getContext(),
741 NVVM::NVVMDialect::getMaxntidAttrName()),
742 StringAttr::get(&converter.getContext(),
743 NVVM::NVVMDialect::getClusterDimAttrName())},
744 benefit);
745
746 populateLibDeviceConversionPatterns(converter, patterns, benefit);
747}
748
749//===----------------------------------------------------------------------===//
750// NVVMTargetAttr convert to LLVM attr interface
751//===----------------------------------------------------------------------===//
752
753namespace {
754struct NVVMTargetConvertToLLVMAttrInterface
755 : public ConvertToLLVMAttrInterface::ExternalModel<
756 NVVMTargetConvertToLLVMAttrInterface, NVVM::NVVMTargetAttr> {
757 /// Configure GPU to NVVM.
758 void populateConvertToLLVMConversionPatterns(
760 LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) const;
761};
762} // namespace
763
764void NVVMTargetConvertToLLVMAttrInterface::
765 populateConvertToLLVMConversionPatterns(Attribute attr,
767 LLVMTypeConverter &typeConverter,
770 configureGpuToNVVMTypeConverter(typeConverter);
772}
773
774void mlir::registerConvertGpuToNVVMInterface(DialectRegistry &registry) {
775 registry.addExtension(+[](MLIRContext *ctx, NVVMDialect *dialect) {
776 NVVMTargetAttr::attachInterface<NVVMTargetConvertToLLVMAttrInterface>(*ctx);
777 });
778}