//===- LowerGpuOpsToNVVMOps.cpp - MLIR GPU to NVVM lowering passes --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to generate NVVM IR operations for higher-level
// GPU operations.
//
//===----------------------------------------------------------------------===//

#include <optional>
namespace mlir {
#define GEN_PASS_DEF_CONVERTGPUOPSTONVVMOPS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {

/// Convert gpu dialect shfl mode enum to the equivalent nvvm one.
static NVVM::ShflKind convertShflKind(gpu::ShuffleMode mode) {
  switch (mode) {
  case gpu::ShuffleMode::XOR:
    return NVVM::ShflKind::bfly;
  case gpu::ShuffleMode::UP:
    return NVVM::ShflKind::up;
  case gpu::ShuffleMode::DOWN:
    return NVVM::ShflKind::down;
  case gpu::ShuffleMode::IDX:
    return NVVM::ShflKind::idx;
  }
  llvm_unreachable("unknown shuffle mode");
}

static std::optional<NVVM::ReduxKind>
convertReduxKind(gpu::AllReduceOperation mode) {
  switch (mode) {
  case gpu::AllReduceOperation::ADD:
    return NVVM::ReduxKind::ADD;
  case gpu::AllReduceOperation::MUL:
    return std::nullopt;
  case gpu::AllReduceOperation::MINSI:
    return NVVM::ReduxKind::MIN;
  case gpu::AllReduceOperation::MINUI:
    return std::nullopt;
  case gpu::AllReduceOperation::MINNUMF:
    return NVVM::ReduxKind::MIN;
  case gpu::AllReduceOperation::MAXSI:
    return NVVM::ReduxKind::MAX;
  case gpu::AllReduceOperation::MAXUI:
    return std::nullopt;
  case gpu::AllReduceOperation::MAXNUMF:
    return NVVM::ReduxKind::MAX;
  case gpu::AllReduceOperation::AND:
    return NVVM::ReduxKind::AND;
  case gpu::AllReduceOperation::OR:
    return NVVM::ReduxKind::OR;
  case gpu::AllReduceOperation::XOR:
    return NVVM::ReduxKind::XOR;
  case gpu::AllReduceOperation::MINIMUMF:
  case gpu::AllReduceOperation::MAXIMUMF:
    return std::nullopt;
  }
  return std::nullopt;
}

/// This pattern lowers the gpu.subgroup_reduce op to the nvvm.redux op. The op
/// must be executed by the entire subgroup; otherwise the behaviour is
/// undefined.
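///
/// For illustration, a rough sketch of the rewrite this pattern performs (the
/// exact gpu/NVVM assembly syntax may differ between MLIR versions):
///
///   %sum = gpu.subgroup_reduce add %value uniform : (i32) -> i32
///
/// becomes, approximately,
///
///   %mask = llvm.mlir.constant(-1 : i32) : i32
///   %sum = nvvm.redux.sync add %value, %mask : i32 -> i32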
struct GPUSubgroupReduceOpLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupReduceOp> {
  using ConvertOpToLLVMPattern<gpu::SubgroupReduceOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getClusterSize())
      return rewriter.notifyMatchFailure(
          op, "lowering for clustered reduce not implemented");

    if (!op.getUniform())
      return rewriter.notifyMatchFailure(
          op, "cannot be lowered to redux as the op must be run "
              "uniformly (entire subgroup).");
    if (!op.getValue().getType().isInteger(32))
      return rewriter.notifyMatchFailure(op, "unsupported data type");

    std::optional<NVVM::ReduxKind> mode = convertReduxKind(op.getOp());
    if (!mode.has_value())
      return rewriter.notifyMatchFailure(
          op, "unsupported reduction mode for redux");

    Location loc = op->getLoc();
    auto int32Type = IntegerType::get(rewriter.getContext(), 32);
    Value offset = LLVM::ConstantOp::create(rewriter, loc, int32Type, -1);

    auto reduxOp = NVVM::ReduxOp::create(rewriter, loc, int32Type,
                                         op.getValue(), mode.value(), offset);

    rewriter.replaceOp(op, reduxOp->getResult(0));
    return success();
  }
};

struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
  using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern;

  /// Lowers a shuffle to the corresponding NVVM op.
  ///
  /// Convert the `width` argument into an activeMask (a bitmask which
  /// specifies which threads participate in the shuffle) and a maskAndClamp
  /// (specifying the highest lane which participates in the shuffle).
  ///
  ///   %one = llvm.constant(1 : i32) : i32
  ///   %minus_one = llvm.constant(-1 : i32) : i32
  ///   %thirty_two = llvm.constant(32 : i32) : i32
  ///   %num_lanes = llvm.sub %thirty_two, %width : i32
  ///   %active_mask = llvm.lshr %minus_one, %num_lanes : i32
  ///   %mask_and_clamp = llvm.sub %width, %one : i32
  ///   %shfl = nvvm.shfl.sync.bfly %active_mask, %value, %offset,
  ///       %mask_and_clamp : !llvm<"{ float, i1 }">
  ///   %shfl_value = llvm.extractvalue %shfl[0] : !llvm<"{ float, i1 }">
  ///   %shfl_pred = llvm.extractvalue %shfl[1] : !llvm<"{ float, i1 }">
  LogicalResult
  matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();

    auto valueTy = adaptor.getValue().getType();
    auto int32Type = IntegerType::get(rewriter.getContext(), 32);
    auto predTy = IntegerType::get(rewriter.getContext(), 1);

    Value one = LLVM::ConstantOp::create(rewriter, loc, int32Type, 1);
    Value minusOne = LLVM::ConstantOp::create(rewriter, loc, int32Type, -1);
    Value thirtyTwo = LLVM::ConstantOp::create(rewriter, loc, int32Type, 32);
    Value numLeadInactiveLane = LLVM::SubOp::create(
        rewriter, loc, int32Type, thirtyTwo, adaptor.getWidth());
    // Bit mask of active lanes: `(-1) >> (32 - activeWidth)`.
    Value activeMask = LLVM::LShrOp::create(rewriter, loc, int32Type, minusOne,
                                            numLeadInactiveLane);
    Value maskAndClamp;
    if (op.getMode() == gpu::ShuffleMode::UP) {
      // Clamp lane: `32 - activeWidth`.
      maskAndClamp = numLeadInactiveLane;
    } else {
      // Clamp lane: `activeWidth - 1`.
      maskAndClamp = LLVM::SubOp::create(rewriter, loc, int32Type,
                                         adaptor.getWidth(), one);
    }

    bool predIsUsed = !op->getResult(1).use_empty();
    UnitAttr returnValueAndIsValidAttr = nullptr;
    Type resultTy = valueTy;
    if (predIsUsed) {
      returnValueAndIsValidAttr = rewriter.getUnitAttr();
      resultTy = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
                                                  {valueTy, predTy});
    }
    Value shfl = NVVM::ShflOp::create(
        rewriter, loc, resultTy, activeMask, adaptor.getValue(),
        adaptor.getOffset(), maskAndClamp, convertShflKind(op.getMode()),
        returnValueAndIsValidAttr);
    if (predIsUsed) {
      Value shflValue = LLVM::ExtractValueOp::create(rewriter, loc, shfl, 0);
      Value isActiveSrcLane =
          LLVM::ExtractValueOp::create(rewriter, loc, shfl, 1);
      rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
    } else {
      rewriter.replaceOp(op, {shfl, nullptr});
    }
    return success();
  }
};

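/// Lowers gpu.lane_id to the NVVM laneid intrinsic, attaching a value range of
/// [0, kWarpSize) (or the op's own upper bound, when present) and then
/// sign-extending or truncating the i32 result to the converter's index width.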
struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
  using ConvertOpToLLVMPattern<gpu::LaneIdOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::LaneIdOp op, gpu::LaneIdOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    MLIRContext *context = rewriter.getContext();
    LLVM::ConstantRangeAttr bounds = nullptr;
    if (std::optional<APInt> upperBound = op.getUpperBound())
      bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
          /*bitWidth=*/32, /*lower=*/0, upperBound->getZExtValue());
    else
      bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
          /*bitWidth=*/32, /*lower=*/0, /*upper=*/kWarpSize);
    Value newOp =
        NVVM::LaneIdOp::create(rewriter, loc, rewriter.getI32Type(), bounds);
    // Truncate or extend the result depending on the index bitwidth specified
    // by the LLVMTypeConverter options.
    const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
    if (indexBitwidth > 32) {
      newOp = LLVM::SExtOp::create(
          rewriter, loc, IntegerType::get(context, indexBitwidth), newOp);
    } else if (indexBitwidth < 32) {
      newOp = LLVM::TruncOp::create(
          rewriter, loc, IntegerType::get(context, indexBitwidth), newOp);
    }
    rewriter.replaceOp(op, {newOp});
    return success();
  }
};

/// Lowering of cf.assert into a conditional __assertfail.
struct AssertOpToAssertfailLowering
    : public ConvertOpToLLVMPattern<cf::AssertOp> {
  using ConvertOpToLLVMPattern<cf::AssertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(cf::AssertOp assertOp, cf::AssertOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MLIRContext *ctx = rewriter.getContext();
    Location loc = assertOp.getLoc();
    Type i8Type = typeConverter->convertType(rewriter.getIntegerType(8));
    Type i32Type = typeConverter->convertType(rewriter.getIntegerType(32));
    Type i64Type = typeConverter->convertType(rewriter.getIntegerType(64));
    Type ptrType = LLVM::LLVMPointerType::get(ctx);
    Type voidType = LLVM::LLVMVoidType::get(ctx);

    // Find or create the __assertfail function declaration.
    auto moduleOp = assertOp->getParentOfType<gpu::GPUModuleOp>();
    auto assertfailType = LLVM::LLVMFunctionType::get(
        voidType, {ptrType, ptrType, i32Type, ptrType, i64Type});
    LLVM::LLVMFuncOp assertfailDecl = getOrDefineFunction(
        moduleOp, loc, rewriter, "__assertfail", assertfailType);
    assertfailDecl.setPassthroughAttr(
        ArrayAttr::get(ctx, StringAttr::get(ctx, "noreturn")));

    // Split blocks and insert conditional branch.
    // ^before:
    //   ...
    //   cf.cond_br %condition, ^after, ^assert
    // ^assert:
    //   cf.assert
    //   cf.br ^after
    // ^after:
    //   ...
    Block *beforeBlock = assertOp->getBlock();
    Block *assertBlock =
        rewriter.splitBlock(beforeBlock, assertOp->getIterator());
    Block *afterBlock =
        rewriter.splitBlock(assertBlock, ++assertOp->getIterator());
    rewriter.setInsertionPointToEnd(beforeBlock);
    cf::CondBranchOp::create(rewriter, loc, adaptor.getArg(), afterBlock,
                             assertBlock);
    rewriter.setInsertionPointToEnd(assertBlock);
    cf::BranchOp::create(rewriter, loc, afterBlock);

    // Continue cf.assert lowering.
    rewriter.setInsertionPoint(assertOp);

    // Populate the file name, file line and function name from the location
    // of the AssertOp.
    StringRef fileName = "(unknown)";
    StringRef funcName = "(unknown)";
    int32_t fileLine = 0;
    while (auto callSiteLoc = dyn_cast<CallSiteLoc>(loc))
      loc = callSiteLoc.getCallee();
    if (auto fileLineColLoc = dyn_cast<FileLineColRange>(loc)) {
      fileName = fileLineColLoc.getFilename().strref();
      fileLine = fileLineColLoc.getStartLine();
    } else if (auto nameLoc = dyn_cast<NameLoc>(loc)) {
      funcName = nameLoc.getName().strref();
      if (auto fileLineColLoc =
              dyn_cast<FileLineColRange>(nameLoc.getChildLoc())) {
        fileName = fileLineColLoc.getFilename().strref();
        fileLine = fileLineColLoc.getStartLine();
      }
    }

    // Create constants.
    auto getGlobal = [&](LLVM::GlobalOp global) {
      // Get a pointer to the format string's first element.
      Value globalPtr = LLVM::AddressOfOp::create(
          rewriter, loc,
          LLVM::LLVMPointerType::get(ctx, global.getAddrSpace()),
          global.getSymNameAttr());
      Value start =
          LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(),
                              globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
      return start;
    };
    Value assertMessage = getGlobal(getOrCreateStringConstant(
        rewriter, loc, moduleOp, i8Type, "assert_message_",
        assertOp.getMsg()));
    Value assertFile = getGlobal(getOrCreateStringConstant(
        rewriter, loc, moduleOp, i8Type, "assert_file_", fileName));
    Value assertFunc = getGlobal(getOrCreateStringConstant(
        rewriter, loc, moduleOp, i8Type, "assert_func_", funcName));
    Value assertLine =
        LLVM::ConstantOp::create(rewriter, loc, i32Type, fileLine);
    Value c1 = LLVM::ConstantOp::create(rewriter, loc, i64Type, 1);

    // Insert the call to __assertfail.
    SmallVector<Value> arguments{assertMessage, assertFile, assertLine,
                                 assertFunc, c1};
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(assertOp, assertfailDecl,
                                              arguments);
    return success();
  }
};

/// Import the GPU Ops to NVVM Patterns.
#include "GPUToNVVM.cpp.inc"

/// A pass that replaces all occurrences of GPU device operations with their
/// corresponding NVVM equivalent.
///
/// This pass only handles device code and is not meant to be run on GPU host
/// code.
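///
/// A minimal usage sketch (the `convert-gpu-to-nvvm` flag name is assumed from
/// the generated ConvertGpuOpsToNVVMOps pass definition and may differ):
///
///   mlir-opt --pass-pipeline='builtin.module(gpu.module(convert-gpu-to-nvvm))' in.mlir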
struct LowerGpuOpsToNVVMOpsPass final
    : public impl::ConvertGpuOpsToNVVMOpsBase<LowerGpuOpsToNVVMOpsPass> {
  using Base::Base;

  void getDependentDialects(DialectRegistry &registry) const override {
    Base::getDependentDialects(registry);
    registerConvertToLLVMDependentDialectLoading(registry);
  }

  void runOnOperation() override {
    gpu::GPUModuleOp m = getOperation();

    // Request C wrapper emission.
    for (auto func : m.getOps<func::FuncOp>()) {
      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                    UnitAttr::get(&getContext()));
    }

    // Customize the bitwidth used for the device side index computations.
    LowerToLLVMOptions options(
        m.getContext(),
        DataLayout(cast<DataLayoutOpInterface>(m.getOperation())));
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);
    options.useBarePtrCallConv = useBarePtrCallConv;

    // Apply in-dialect lowering. In-dialect lowering will replace
    // ops which need to be lowered further, which is not supported by a
    // single conversion pass.
    {
      RewritePatternSet patterns(m.getContext());
      populateGpuRewritePatterns(patterns);
      // Transform N-D vector.from_elements to 1-D vector.from_elements before
      // conversion.
      vector::populateVectorFromElementsUnrollPatterns(patterns);
      if (failed(applyPatternsGreedily(m, std::move(patterns))))
        return signalPassFailure();
    }

    LLVMTypeConverter converter(m.getContext(), options);
    configureGpuToNVVMTypeConverter(converter);
    RewritePatternSet llvmPatterns(m.getContext());
    LLVMConversionTarget target(getContext());

    // Set a higher benefit so these patterns run before the generic LLVM
    // lowering patterns.
    populateGpuToNVVMConversionPatterns(converter, llvmPatterns,
                                        /*benefit=*/10);

    llvm::SmallDenseSet<StringRef> allowedDialectsSet(allowedDialects.begin(),
                                                      allowedDialects.end());
    for (Dialect *dialect : getContext().getLoadedDialects()) {
      // Skip math patterns as nvvm needs custom math lowering.
      if (isa<math::MathDialect>(dialect))
        continue;

      bool allowed = allowedDialectsSet.contains(dialect->getNamespace());
      // An empty `allowedDialectsSet` means all dialects are allowed.
      if (!allowedDialectsSet.empty() && !allowed)
        continue;

      auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
      if (!iface) {
        // Error out if the dialect was explicitly specified but does not
        // implement the conversion interface.
        if (allowed) {
          m.emitError()
              << "dialect does not implement ConvertToLLVMPatternInterface: "
              << dialect->getNamespace();
          return signalPassFailure();
        }
        continue;
      }

      iface->populateConvertToLLVMConversionPatterns(target, converter,
                                                     llvmPatterns);
    }

    populateGpuWMMAToNVVMConversionPatterns(converter, llvmPatterns);
    if (this->hasRedux)
      populateGpuSubgroupReduceOpLoweringPattern(converter, llvmPatterns);
    configureGpuToNVVMConversionLegality(target);
    ConversionConfig config;
    config.allowPatternRollback = allowPatternRollback;
    if (failed(
            applyPartialConversion(m, target, std::move(llvmPatterns), config)))
      signalPassFailure();
  }
};

} // namespace

void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
  target.addIllegalOp<func::FuncOp>();
  target.addIllegalOp<cf::AssertOp>();
  target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
  target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
  target.addIllegalDialect<gpu::GPUDialect>();
  target.addIllegalOp<LLVM::CopySignOp, LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op,
                      LLVM::FAbsOp, LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp,
                      LLVM::LogOp, LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp,
                      LLVM::RoundEvenOp, LLVM::RoundOp, LLVM::SinOp,
                      LLVM::SincosOp, LLVM::SqrtOp>();

  // TODO: Remove once we support replacing non-root ops.
  target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp>();
}

void mlir::configureGpuToNVVMTypeConverter(LLVMTypeConverter &converter) {
  populateCommonGPUTypeAndAttributeConversions(converter);

  // Lowering for MMAMatrixType.
  converter.addConversion([&](gpu::MMAMatrixType type) -> Type {
    return convertMMAToLLVMType(type);
  });
}

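/// Lowers math.sincos to a call to the __nv_sincosf / __nv_sincos (or
/// __nv_fast_sincosf) libdevice functions, passing stack-allocated pointers
/// for the two results and loading the sine and cosine back afterwards.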
struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
  using ConvertOpToLLVMPattern<math::SincosOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(math::SincosOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value input = adaptor.getOperand();
    Type inputType = input.getType();
    auto convertedInput = maybeExt(input, rewriter);
    auto computeType = convertedInput.getType();

    StringRef sincosFunc;
    if (isa<Float32Type>(computeType)) {
      const arith::FastMathFlags flag = op.getFastmath();
      const bool useApprox =
          mlir::arith::bitEnumContainsAny(flag, arith::FastMathFlags::afn);
      sincosFunc = useApprox ? "__nv_fast_sincosf" : "__nv_sincosf";
    } else if (isa<Float64Type>(computeType)) {
      sincosFunc = "__nv_sincos";
    } else {
      return rewriter.notifyMatchFailure(
          op, "unsupported operand type for sincos");
    }

    auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());

    Value sinPtr, cosPtr;
    {
      OpBuilder::InsertionGuard guard(rewriter);
      auto *scope =
          op->getParentWithTrait<mlir::OpTrait::AutomaticAllocationScope>();
      assert(scope && "Expected op to be inside automatic allocation scope");
      rewriter.setInsertionPointToStart(&scope->getRegion(0).front());
      auto one = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
                                          rewriter.getI32IntegerAttr(1));
      sinPtr =
          LLVM::AllocaOp::create(rewriter, loc, ptrType, computeType, one, 0);
      cosPtr =
          LLVM::AllocaOp::create(rewriter, loc, ptrType, computeType, one, 0);
    }

    createSincosCall(rewriter, loc, sincosFunc, convertedInput, sinPtr, cosPtr,
                     op);

    auto sinResult = LLVM::LoadOp::create(rewriter, loc, computeType, sinPtr);
    auto cosResult = LLVM::LoadOp::create(rewriter, loc, computeType, cosPtr);

    rewriter.replaceOp(op, {maybeTrunc(sinResult, inputType, rewriter),
                            maybeTrunc(cosResult, inputType, rewriter)});
    return success();
  }

private:
  Value maybeExt(Value operand, PatternRewriter &rewriter) const {
    if (isa<Float16Type, BFloat16Type>(operand.getType()))
      return LLVM::FPExtOp::create(rewriter, operand.getLoc(),
                                   Float32Type::get(rewriter.getContext()),
                                   operand);
    return operand;
  }

  Value maybeTrunc(Value operand, Type type, PatternRewriter &rewriter) const {
    if (operand.getType() != type)
      return LLVM::FPTruncOp::create(rewriter, operand.getLoc(), type, operand);
    return operand;
  }

  void createSincosCall(ConversionPatternRewriter &rewriter, Location loc,
                        StringRef funcName, Value input, Value sinPtr,
                        Value cosPtr, Operation *op) const {
    auto voidType = LLVM::LLVMVoidType::get(rewriter.getContext());
    auto ptrType = sinPtr.getType();

    SmallVector<Type> operandTypes = {input.getType(), ptrType, ptrType};
    auto funcType = LLVM::LLVMFunctionType::get(voidType, operandTypes);

    auto funcAttr = StringAttr::get(op->getContext(), funcName);
    auto funcOp = dyn_cast_or_null<LLVM::LLVMFuncOp>(
        SymbolTable::lookupNearestSymbolFrom(op, funcAttr));

    if (!funcOp) {
      auto parentFunc = op->getParentOfType<FunctionOpInterface>();
      assert(parentFunc && "expected there to be a parent function");
      OpBuilder b(parentFunc);

      auto globalloc = loc->findInstanceOfOrUnknown<FileLineColLoc>();
      funcOp = LLVM::LLVMFuncOp::create(b, globalloc, funcName, funcType);
    }

    SmallVector<Value> callOperands = {input, sinPtr, cosPtr};
    LLVM::CallOp::create(rewriter, loc, funcOp, callOperands);
  }
};

template <typename OpTy>
static void populateOpPatterns(const LLVMTypeConverter &converter,
                               RewritePatternSet &patterns,
                               PatternBenefit benefit, StringRef f32Func,
                               StringRef f64Func, StringRef f32ApproxFunc = "",
                               StringRef f16Func = "") {
  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit);
  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
                                           f32ApproxFunc, f16Func,
                                           /*i32Func=*/"", benefit);
}
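
// For illustration: with these helpers, a scalar float op such as math.exp is
// rewritten into a libdevice call. A rough sketch of the intended result
// (exact LLVM dialect syntax may vary between MLIR versions):
//
//   %y = math.exp %x : f32
// becomes roughly
//   %y = llvm.call @__nv_expf(%x) : (f32) -> f32
//
// with vector operands first unrolled to scalars by ScalarizeVectorOpLowering.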

template <typename OpTy>
static void populateIntOpPatterns(const LLVMTypeConverter &converter,
                                  RewritePatternSet &patterns,
                                  PatternBenefit benefit, StringRef i32Func) {
  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit);
  patterns.add<OpToFuncCallLowering<OpTy>>(converter, "", "", "", "", i32Func,
                                           benefit);
}

template <typename OpTy>
static void populateFloatIntOpPatterns(const LLVMTypeConverter &converter,
                                       RewritePatternSet &patterns,
                                       PatternBenefit benefit,
                                       StringRef f32Func, StringRef f64Func) {
  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit);
  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func, "", "",
                                           /*i32Func=*/"", benefit);
}

void mlir::populateGpuSubgroupReduceOpLoweringPattern(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
    PatternBenefit benefit) {
  patterns.add<GPUSubgroupReduceOpLowering>(converter, benefit);
}

void mlir::populateLibDeviceConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
    PatternBenefit benefit) {
  populateOpPatterns<arith::RemFOp>(converter, patterns, benefit, "__nv_fmodf",
                                    "__nv_fmod");
  populateOpPatterns<arith::MaxNumFOp>(converter, patterns, benefit,
                                       "__nv_fmaxf", "__nv_fmax");
  populateOpPatterns<arith::MinNumFOp>(converter, patterns, benefit,
                                       "__nv_fminf", "__nv_fmin");

  populateIntOpPatterns<math::AbsIOp>(converter, patterns, benefit, "__nv_abs");
  populateOpPatterns<math::AbsFOp>(converter, patterns, benefit, "__nv_fabsf",
                                   "__nv_fabs");
  populateOpPatterns<math::AcosOp>(converter, patterns, benefit, "__nv_acosf",
                                   "__nv_acos");
  populateOpPatterns<math::AcoshOp>(converter, patterns, benefit, "__nv_acoshf",
                                    "__nv_acosh");
  populateOpPatterns<math::AsinOp>(converter, patterns, benefit, "__nv_asinf",
                                   "__nv_asin");
  populateOpPatterns<math::AsinhOp>(converter, patterns, benefit, "__nv_asinhf",
                                    "__nv_asinh");
  populateOpPatterns<math::AtanOp>(converter, patterns, benefit, "__nv_atanf",
                                   "__nv_atan");
  populateOpPatterns<math::Atan2Op>(converter, patterns, benefit, "__nv_atan2f",
                                    "__nv_atan2");
  populateOpPatterns<math::AtanhOp>(converter, patterns, benefit, "__nv_atanhf",
                                    "__nv_atanh");
  populateOpPatterns<math::CbrtOp>(converter, patterns, benefit, "__nv_cbrtf",
                                   "__nv_cbrt");
  populateOpPatterns<math::CeilOp>(converter, patterns, benefit, "__nv_ceilf",
                                   "__nv_ceil");
  populateOpPatterns<math::CopySignOp>(converter, patterns, benefit,
                                       "__nv_copysignf", "__nv_copysign");
  populateOpPatterns<math::CosOp>(converter, patterns, benefit, "__nv_cosf",
                                  "__nv_cos", "__nv_fast_cosf");
  populateOpPatterns<math::CoshOp>(converter, patterns, benefit, "__nv_coshf",
                                   "__nv_cosh");
  populateOpPatterns<math::ErfOp>(converter, patterns, benefit, "__nv_erff",
                                  "__nv_erf");
  populateOpPatterns<math::ErfcOp>(converter, patterns, benefit, "__nv_erfcf",
                                   "__nv_erfc");
  populateOpPatterns<math::ExpOp>(converter, patterns, benefit, "__nv_expf",
                                  "__nv_exp", "__nv_fast_expf");
  populateOpPatterns<math::Exp2Op>(converter, patterns, benefit, "__nv_exp2f",
                                   "__nv_exp2");
  populateOpPatterns<math::ExpM1Op>(converter, patterns, benefit, "__nv_expm1f",
                                    "__nv_expm1");
  populateOpPatterns<math::FloorOp>(converter, patterns, benefit, "__nv_floorf",
                                    "__nv_floor");
  populateOpPatterns<math::FmaOp>(converter, patterns, benefit, "__nv_fmaf",
                                  "__nv_fma");
  // Note: libdevice uses a different name for 32-bit finite checking.
  populateOpPatterns<math::IsFiniteOp>(converter, patterns, benefit,
                                       "__nv_finitef", "__nv_isfinited");
  populateOpPatterns<math::IsInfOp>(converter, patterns, benefit, "__nv_isinff",
                                    "__nv_isinfd");
  populateOpPatterns<math::IsNaNOp>(converter, patterns, benefit, "__nv_isnanf",
                                    "__nv_isnand");
  populateOpPatterns<math::LogOp>(converter, patterns, benefit, "__nv_logf",
                                  "__nv_log", "__nv_fast_logf");
  populateOpPatterns<math::Log10Op>(converter, patterns, benefit, "__nv_log10f",
                                    "__nv_log10", "__nv_fast_log10f");
  populateOpPatterns<math::Log1pOp>(converter, patterns, benefit, "__nv_log1pf",
                                    "__nv_log1p");
  populateOpPatterns<math::Log2Op>(converter, patterns, benefit, "__nv_log2f",
                                   "__nv_log2", "__nv_fast_log2f");
  populateOpPatterns<math::PowFOp>(converter, patterns, benefit, "__nv_powf",
                                   "__nv_pow", "__nv_fast_powf");
  populateFloatIntOpPatterns<math::FPowIOp>(converter, patterns, benefit,
                                            "__nv_powif", "__nv_powi");
  populateOpPatterns<math::RoundOp>(converter, patterns, benefit, "__nv_roundf",
                                    "__nv_round");
  populateOpPatterns<math::RoundEvenOp>(converter, patterns, benefit,
                                        "__nv_rintf", "__nv_rint");
  populateOpPatterns<math::RsqrtOp>(converter, patterns, benefit, "__nv_rsqrtf",
                                    "__nv_rsqrt");
  populateOpPatterns<math::SinOp>(converter, patterns, benefit, "__nv_sinf",
                                  "__nv_sin", "__nv_fast_sinf");
  populateOpPatterns<math::SinhOp>(converter, patterns, benefit, "__nv_sinhf",
                                   "__nv_sinh");
  populateOpPatterns<math::SqrtOp>(converter, patterns, benefit, "__nv_sqrtf",
                                   "__nv_sqrt");
  populateOpPatterns<math::TanOp>(converter, patterns, benefit, "__nv_tanf",
                                  "__nv_tan", "__nv_fast_tanf");
  populateOpPatterns<math::TanhOp>(converter, patterns, benefit, "__nv_tanhf",
                                   "__nv_tanh");

  // Custom pattern for sincos since it returns two values.
  patterns.add<SincosOpLowering>(converter, benefit);
}

void mlir::populateGpuToNVVMConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
    PatternBenefit benefit) {
  using gpu::index_lowering::IndexKind;
  using gpu::index_lowering::IntrType;

  // TODO: Pass benefit to generated patterns.
  populateWithGenerated(patterns);

  patterns.add<GPUPrintfOpToVPrintfLowering, AssertOpToAssertfailLowering>(
      converter, benefit);
  patterns.add<
      gpu::index_lowering::OpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp,
                                      NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>>(
      converter, IndexKind::Block, IntrType::Id, benefit);
  patterns.add<
      gpu::index_lowering::OpLowering<gpu::BlockDimOp, NVVM::BlockDimXOp,
                                      NVVM::BlockDimYOp, NVVM::BlockDimZOp>>(
      converter, IndexKind::Block, IntrType::Dim, benefit);
  patterns.add<
      gpu::index_lowering::OpLowering<gpu::ClusterIdOp, NVVM::ClusterIdXOp,
                                      NVVM::ClusterIdYOp, NVVM::ClusterIdZOp>>(
      converter, IndexKind::Other, IntrType::Id, benefit);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::ClusterDimOp, NVVM::ClusterDimXOp, NVVM::ClusterDimYOp,
      NVVM::ClusterDimZOp>>(converter, IndexKind::Other, IntrType::Dim,
                            benefit);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::ClusterBlockIdOp, NVVM::BlockInClusterIdXOp,
      NVVM::BlockInClusterIdYOp, NVVM::BlockInClusterIdZOp>>(
      converter, IndexKind::Other, IntrType::Id, benefit);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::ClusterDimBlocksOp, NVVM::ClusterDimBlocksXOp,
      NVVM::ClusterDimBlocksYOp, NVVM::ClusterDimBlocksZOp>>(
      converter, IndexKind::Other, IntrType::Dim, benefit);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::BlockIdOp, NVVM::BlockIdXOp, NVVM::BlockIdYOp, NVVM::BlockIdZOp>>(
      converter, IndexKind::Grid, IntrType::Id, benefit);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::GridDimOp, NVVM::GridDimXOp, NVVM::GridDimYOp, NVVM::GridDimZOp>>(
      converter, IndexKind::Grid, IntrType::Dim, benefit);
  patterns.add<GPULaneIdOpToNVVM, GPUShuffleOpLowering, GPUReturnOpLowering>(
      converter, benefit);

  patterns.add<GPUDynamicSharedMemoryOpLowering>(
      converter, NVVM::kSharedMemoryAlignmentBit, benefit);

  // Explicitly drop memory space when lowering private memory
  // attributions since NVVM models it as `alloca`s in the default
  // memory space and does not support `alloca`s with addrspace(5).
  patterns.add<GPUFuncOpLowering>(
      converter,
      GPUFuncOpLoweringOptions{
          /*allocaAddrSpace=*/0,
          /*workgroupAddrSpace=*/
          static_cast<unsigned>(NVVM::NVVMMemorySpace::Shared),
          StringAttr::get(&converter.getContext(),
                          NVVM::NVVMDialect::getKernelFuncAttrName()),
          StringAttr::get(&converter.getContext(),
                          NVVM::NVVMDialect::getMaxntidAttrName())},
      benefit);

  populateLibDeviceConversionPatterns(converter, patterns, benefit);
}

//===----------------------------------------------------------------------===//
// NVVMTargetAttr convert to LLVM attr interface
//===----------------------------------------------------------------------===//

namespace {
struct NVVMTargetConvertToLLVMAttrInterface
    : public ConvertToLLVMAttrInterface::ExternalModel<
          NVVMTargetConvertToLLVMAttrInterface, NVVM::NVVMTargetAttr> {
  /// Configure GPU to NVVM.
  void populateConvertToLLVMConversionPatterns(
      Attribute attr, ConversionTarget &target,
      LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) const;
};
} // namespace

void NVVMTargetConvertToLLVMAttrInterface::
    populateConvertToLLVMConversionPatterns(Attribute attr,
                                            ConversionTarget &target,
                                            LLVMTypeConverter &typeConverter,
                                            RewritePatternSet &patterns) const {
  configureGpuToNVVMConversionLegality(target);
  configureGpuToNVVMTypeConverter(typeConverter);
  populateGpuToNVVMConversionPatterns(typeConverter, patterns);
}

void mlir::NVVM::registerConvertGpuToNVVMInterface(DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, NVVMDialect *dialect) {
    NVVMTargetAttr::attachInterface<NVVMTargetConvertToLLVMAttrInterface>(*ctx);
  });
}