//===- LowerGpuOpsToNVVMOps.cpp - MLIR GPU to NVVM lowering passes --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to generate NVVM IR operations for higher-level
// GPU operations.
//
//===----------------------------------------------------------------------===//

// (#include directives elided in the original listing.)
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_CONVERTGPUOPSTONVVMOPS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {

/// Convert gpu dialect shfl mode enum to the equivalent nvvm one.
static NVVM::ShflKind convertShflKind(gpu::ShuffleMode mode) {
  switch (mode) {
  case gpu::ShuffleMode::XOR:
    return NVVM::ShflKind::bfly;
  case gpu::ShuffleMode::UP:
    return NVVM::ShflKind::up;
  case gpu::ShuffleMode::DOWN:
    return NVVM::ShflKind::down;
  case gpu::ShuffleMode::IDX:
    return NVVM::ShflKind::idx;
  }
  llvm_unreachable("unknown shuffle mode");
}

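/// Map a gpu reduction operation to the NVVM redux kind it lowers to, or
/// std::nullopt when redux has no equivalent (the pattern then bails out).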
static std::optional<NVVM::ReduxKind>
convertReduxKind(gpu::AllReduceOperation mode) {
  switch (mode) {
  case gpu::AllReduceOperation::ADD:
    return NVVM::ReduxKind::ADD;
  case gpu::AllReduceOperation::MUL:
    return std::nullopt;
  case gpu::AllReduceOperation::MINSI:
    return NVVM::ReduxKind::MIN;
  case gpu::AllReduceOperation::MINUI:
    return std::nullopt;
  case gpu::AllReduceOperation::MINNUMF:
    return NVVM::ReduxKind::MIN;
  case gpu::AllReduceOperation::MAXSI:
    return NVVM::ReduxKind::MAX;
  case gpu::AllReduceOperation::MAXUI:
    return std::nullopt;
  case gpu::AllReduceOperation::MAXNUMF:
    return NVVM::ReduxKind::MAX;
  case gpu::AllReduceOperation::AND:
    return NVVM::ReduxKind::AND;
  case gpu::AllReduceOperation::OR:
    return NVVM::ReduxKind::OR;
  case gpu::AllReduceOperation::XOR:
    return NVVM::ReduxKind::XOR;
  case gpu::AllReduceOperation::MINIMUMF:
  case gpu::AllReduceOperation::MAXIMUMF:
    return std::nullopt;
  }
  return std::nullopt;
}

/// Lowers the gpu.subgroup_reduce op into the nvvm.redux op. The op must be
/// run by the entire subgroup; otherwise the behavior is undefined.
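///
/// Illustrative rewrite (syntax is approximate and depends on the MLIR
/// version):
///
///   %sum = gpu.subgroup_reduce add %val uniform : (i32) -> i32
///
/// becomes roughly:
///
///   %mask = llvm.mlir.constant(-1 : i32) : i32
///   %sum = nvvm.redux.sync add %val, %mask : i32 -> i32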
struct GPUSubgroupReduceOpLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupReduceOp> {
  using ConvertOpToLLVMPattern<gpu::SubgroupReduceOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getClusterSize())
      return rewriter.notifyMatchFailure(
          op, "lowering for clustered reduce not implemented");

    if (!op.getUniform())
      return rewriter.notifyMatchFailure(
          op, "cannot be lowered to redux as the op must be run "
              "uniformly (entire subgroup).");
    if (!op.getValue().getType().isInteger(32))
      return rewriter.notifyMatchFailure(op, "unsupported data type");

    std::optional<NVVM::ReduxKind> mode = convertReduxKind(op.getOp());
    if (!mode.has_value())
      return rewriter.notifyMatchFailure(
          op, "unsupported reduction mode for redux");

    Location loc = op->getLoc();
    auto int32Type = IntegerType::get(rewriter.getContext(), 32);
    Value offset = LLVM::ConstantOp::create(rewriter, loc, int32Type, -1);

    auto reduxOp = NVVM::ReduxOp::create(rewriter, loc, int32Type,
                                         op.getValue(), mode.value(), offset);

    rewriter.replaceOp(op, reduxOp->getResult(0));
    return success();
  }
};

struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
  using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern;

  /// Lowers a shuffle to the corresponding NVVM op.
  ///
  /// Convert the `width` argument into an activeMask (a bitmask which
  /// specifies which threads participate in the shuffle) and a maskAndClamp
  /// (specifying the highest lane which participates in the shuffle).
  ///
  ///   %one = llvm.constant(1 : i32) : i32
  ///   %minus_one = llvm.constant(-1 : i32) : i32
  ///   %thirty_two = llvm.constant(32 : i32) : i32
  ///   %num_lanes = llvm.sub %thirty_two, %width : i32
  ///   %active_mask = llvm.lshr %minus_one, %num_lanes : i32
  ///   %mask_and_clamp = llvm.sub %width, %one : i32
  ///   %shfl = nvvm.shfl.sync.bfly %active_mask, %value, %offset,
  ///       %mask_and_clamp : !llvm<"{ float, i1 }">
  ///   %shfl_value = llvm.extractvalue %shfl[0] : !llvm<"{ float, i1 }">
  ///   %shfl_pred = llvm.extractvalue %shfl[1] : !llvm<"{ float, i1 }">
  LogicalResult
  matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();

    auto valueTy = adaptor.getValue().getType();
    auto int32Type = IntegerType::get(rewriter.getContext(), 32);
    auto predTy = IntegerType::get(rewriter.getContext(), 1);

    Value one = LLVM::ConstantOp::create(rewriter, loc, int32Type, 1);
    Value minusOne = LLVM::ConstantOp::create(rewriter, loc, int32Type, -1);
    Value thirtyTwo = LLVM::ConstantOp::create(rewriter, loc, int32Type, 32);
    Value numLeadInactiveLane = LLVM::SubOp::create(
        rewriter, loc, int32Type, thirtyTwo, adaptor.getWidth());
    // Bit mask of active lanes: `(-1) >> (32 - activeWidth)`.
    Value activeMask = LLVM::LShrOp::create(rewriter, loc, int32Type, minusOne,
                                            numLeadInactiveLane);
    Value maskAndClamp;
    if (op.getMode() == gpu::ShuffleMode::UP) {
      // Clamp lane: `32 - activeWidth`.
      maskAndClamp = numLeadInactiveLane;
    } else {
      // Clamp lane: `activeWidth - 1`.
      maskAndClamp = LLVM::SubOp::create(rewriter, loc, int32Type,
                                         adaptor.getWidth(), one);
    }

    bool predIsUsed = !op->getResult(1).use_empty();
    UnitAttr returnValueAndIsValidAttr = nullptr;
    Type resultTy = valueTy;
    if (predIsUsed) {
      returnValueAndIsValidAttr = rewriter.getUnitAttr();
      resultTy = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
                                                  {valueTy, predTy});
    }
    Value shfl = NVVM::ShflOp::create(
        rewriter, loc, resultTy, activeMask, adaptor.getValue(),
        adaptor.getOffset(), maskAndClamp, convertShflKind(op.getMode()),
        returnValueAndIsValidAttr);
    if (predIsUsed) {
      Value shflValue = LLVM::ExtractValueOp::create(rewriter, loc, shfl, 0);
      Value isActiveSrcLane =
          LLVM::ExtractValueOp::create(rewriter, loc, shfl, 1);
      rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
    } else {
      rewriter.replaceOp(op, {shfl, nullptr});
    }
    return success();
  }
};

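/// Lowers gpu.lane_id to the NVVM laneid intrinsic, attaching a
/// [0, kWarpSize) range bound (or the op's own upper bound, when present) so
/// that LLVM can optimize on it, then sign-extends or truncates the i32
/// result to the index bitwidth configured on the type converter.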
struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
  using ConvertOpToLLVMPattern<gpu::LaneIdOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::LaneIdOp op, gpu::LaneIdOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    MLIRContext *context = rewriter.getContext();
    LLVM::ConstantRangeAttr bounds = nullptr;
    if (std::optional<APInt> upperBound = op.getUpperBound())
      bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
          /*bitWidth=*/32, /*lower=*/0, upperBound->getZExtValue());
    else
      bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
          /*bitWidth=*/32, /*lower=*/0, /*upper=*/kWarpSize);
    Value newOp =
        NVVM::LaneIdOp::create(rewriter, loc, rewriter.getI32Type(), bounds);
    // Truncate or extend the result depending on the index bitwidth specified
    // by the LLVMTypeConverter options.
    const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
    if (indexBitwidth > 32) {
      newOp = LLVM::SExtOp::create(
          rewriter, loc, IntegerType::get(context, indexBitwidth), newOp);
    } else if (indexBitwidth < 32) {
      newOp = LLVM::TruncOp::create(
          rewriter, loc, IntegerType::get(context, indexBitwidth), newOp);
    }
    rewriter.replaceOp(op, {newOp});
    return success();
  }
};

/// Lowering of cf.assert into a conditional call to __assertfail.
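///
/// The call targets the device-side runtime entry point; its signature,
/// reconstructed from the function type built below (parameter names are
/// illustrative only), is:
///
///   void __assertfail(const char *message, const char *file, unsigned line,
///                     const char *function, size_t charSize);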
struct AssertOpToAssertfailLowering
    : public ConvertOpToLLVMPattern<cf::AssertOp> {
  using ConvertOpToLLVMPattern<cf::AssertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(cf::AssertOp assertOp, cf::AssertOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MLIRContext *ctx = rewriter.getContext();
    Location loc = assertOp.getLoc();
    Type i8Type = typeConverter->convertType(rewriter.getIntegerType(8));
    Type i32Type = typeConverter->convertType(rewriter.getIntegerType(32));
    Type i64Type = typeConverter->convertType(rewriter.getIntegerType(64));
    Type ptrType = LLVM::LLVMPointerType::get(ctx);
    Type voidType = LLVM::LLVMVoidType::get(ctx);

    // Find or create the __assertfail function declaration.
    auto moduleOp = assertOp->getParentOfType<gpu::GPUModuleOp>();
    auto assertfailType = LLVM::LLVMFunctionType::get(
        voidType, {ptrType, ptrType, i32Type, ptrType, i64Type});
    LLVM::LLVMFuncOp assertfailDecl = getOrDefineFunction(
        moduleOp, loc, rewriter, "__assertfail", assertfailType);
    assertfailDecl.setPassthroughAttr(
        ArrayAttr::get(ctx, StringAttr::get(ctx, "noreturn")));

    // Split blocks and insert a conditional branch:
    // ^before:
    //   ...
    //   cf.cond_br %condition, ^after, ^assert
    // ^assert:
    //   cf.assert
    //   cf.br ^after
    // ^after:
    //   ...
    Block *beforeBlock = assertOp->getBlock();
    Block *assertBlock =
        rewriter.splitBlock(beforeBlock, assertOp->getIterator());
    Block *afterBlock =
        rewriter.splitBlock(assertBlock, ++assertOp->getIterator());
    rewriter.setInsertionPointToEnd(beforeBlock);
    cf::CondBranchOp::create(rewriter, loc, adaptor.getArg(), afterBlock,
                             assertBlock);
    rewriter.setInsertionPointToEnd(assertBlock);
    cf::BranchOp::create(rewriter, loc, afterBlock);

    // Continue cf.assert lowering.
    rewriter.setInsertionPoint(assertOp);

    // Populate the file name, line number, and function name from the
    // location of the AssertOp.
    StringRef fileName = "(unknown)";
    StringRef funcName = "(unknown)";
    int32_t fileLine = 0;
    while (auto callSiteLoc = dyn_cast<CallSiteLoc>(loc))
      loc = callSiteLoc.getCallee();
    if (auto fileLineColLoc = dyn_cast<FileLineColRange>(loc)) {
      fileName = fileLineColLoc.getFilename().strref();
      fileLine = fileLineColLoc.getStartLine();
    } else if (auto nameLoc = dyn_cast<NameLoc>(loc)) {
      funcName = nameLoc.getName().strref();
      if (auto fileLineColLoc =
              dyn_cast<FileLineColRange>(nameLoc.getChildLoc())) {
        fileName = fileLineColLoc.getFilename().strref();
        fileLine = fileLineColLoc.getStartLine();
      }
    }

    // Create constants.
    auto getGlobal = [&](LLVM::GlobalOp global) {
      // Get a pointer to the string's first element.
      Value globalPtr = LLVM::AddressOfOp::create(
          rewriter, loc,
          LLVM::LLVMPointerType::get(ctx, global.getAddrSpace()),
          global.getSymNameAttr());
      Value start =
          LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(),
                              globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
      return start;
    };
    Value assertMessage = getGlobal(getOrCreateStringConstant(
        rewriter, loc, moduleOp, i8Type, "assert_message_",
        assertOp.getMsg()));
    Value assertFile = getGlobal(getOrCreateStringConstant(
        rewriter, loc, moduleOp, i8Type, "assert_file_", fileName));
    Value assertFunc = getGlobal(getOrCreateStringConstant(
        rewriter, loc, moduleOp, i8Type, "assert_func_", funcName));
    Value assertLine =
        LLVM::ConstantOp::create(rewriter, loc, i32Type, fileLine);
    Value c1 = LLVM::ConstantOp::create(rewriter, loc, i64Type, 1);

    // Insert the call to __assertfail.
    SmallVector<Value> arguments{assertMessage, assertFile, assertLine,
                                 assertFunc, c1};
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(assertOp, assertfailDecl,
                                              arguments);
    return success();
  }
};

/// Import the GPU Ops to NVVM Patterns.
#include "GPUToNVVM.cpp.inc"

/// A pass that replaces all occurrences of GPU device operations with their
/// corresponding NVVM equivalents.
///
/// This pass only handles device code and is not meant to be run on GPU host
/// code.
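///
/// Example invocation, assuming the pass is registered under its usual
/// `convert-gpu-to-nvvm` flag:
///
///   mlir-opt --convert-gpu-to-nvvm='index-bitwidth=32' kernel.mlir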
struct LowerGpuOpsToNVVMOpsPass final
    : public impl::ConvertGpuOpsToNVVMOpsBase<LowerGpuOpsToNVVMOpsPass> {
  using Base::Base;

  void getDependentDialects(DialectRegistry &registry) const override {
    Base::getDependentDialects(registry);
    registerConvertToLLVMDependentDialectLoading(registry);
  }

  void runOnOperation() override {
    gpu::GPUModuleOp m = getOperation();

    // Request C wrapper emission.
    for (auto func : m.getOps<func::FuncOp>()) {
      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                    UnitAttr::get(&getContext()));
    }

    // Customize the bitwidth used for the device-side index computations.
    LowerToLLVMOptions options(
        m.getContext(),
        DataLayout(cast<DataLayoutOpInterface>(m.getOperation())));
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);
    options.useBarePtrCallConv = useBarePtrCallConv;

    // Apply in-dialect lowering first. It replaces ops that need to be
    // rewritten before conversion, which a single conversion pass cannot do.
    {
      RewritePatternSet patterns(m.getContext());
      populateGpuRewritePatterns(patterns);
      // Transform N-D vector.from_elements to 1-D vector.from_elements before
      // conversion.
      vector::populateVectorFromElementsUnrollPatterns(patterns);
      if (failed(applyPatternsGreedily(m, std::move(patterns))))
        return signalPassFailure();
    }

    LLVMTypeConverter converter(m.getContext(), options);
    configureGpuToNVVMTypeConverter(converter);
    RewritePatternSet llvmPatterns(m.getContext());
    LLVMConversionTarget target(getContext());

    // Set a higher benefit so these patterns run before the generic LLVM
    // lowering.
    populateGpuToNVVMConversionPatterns(converter, llvmPatterns,
                                        /*benefit=*/10);

    llvm::SmallDenseSet<StringRef> allowedDialectsSet(allowedDialects.begin(),
                                                      allowedDialects.end());
    for (Dialect *dialect : getContext().getLoadedDialects()) {
      // Skip math patterns as nvvm needs custom math lowering.
      if (isa<math::MathDialect>(dialect))
        continue;

      bool allowed = allowedDialectsSet.contains(dialect->getNamespace());
      // An empty `allowedDialectsSet` means all dialects are allowed.
      if (!allowedDialectsSet.empty() && !allowed)
        continue;

      auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
      if (!iface) {
        // Error out if the dialect was explicitly specified but doesn't
        // implement the conversion interface.
        if (allowed) {
          m.emitError()
              << "dialect does not implement ConvertToLLVMPatternInterface: "
              << dialect->getNamespace();
          return signalPassFailure();
        }
        continue;
      }

      iface->populateConvertToLLVMConversionPatterns(target, converter,
                                                     llvmPatterns);
    }

    populateGpuWMMAToNVVMConversionPatterns(converter, llvmPatterns);
    if (this->hasRedux)
      populateGpuSubgroupReduceOpLoweringPattern(converter, llvmPatterns);
    configureGpuToNVVMConversionLegality(target);
    if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
      signalPassFailure();
  }
};

} // namespace

void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
  target.addIllegalOp<func::FuncOp>();
  target.addIllegalOp<cf::AssertOp>();
  target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
  target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
  target.addIllegalDialect<gpu::GPUDialect>();
  target.addIllegalOp<LLVM::CopySignOp, LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op,
                      LLVM::FAbsOp, LLVM::FCeilOp, LLVM::FFloorOp,
                      LLVM::FRemOp, LLVM::LogOp, LLVM::Log10Op, LLVM::Log2Op,
                      LLVM::PowOp, LLVM::RoundEvenOp, LLVM::RoundOp,
                      LLVM::SinOp, LLVM::SincosOp, LLVM::SqrtOp>();

  // TODO: Remove once we support replacing non-root ops.
  target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp>();
}

void mlir::configureGpuToNVVMTypeConverter(LLVMTypeConverter &converter) {
  // NVVM uses alloca in the default address space to represent private
  // memory allocations, so drop private annotations. NVVM uses address
  // space 3 for shared memory. NVVM uses the default address space to
  // represent global memory.
  populateGpuMemorySpaceAttributeConversions(
      converter, [](gpu::AddressSpace space) -> unsigned {
        switch (space) {
        case gpu::AddressSpace::Global:
          return static_cast<unsigned>(NVVM::NVVMMemorySpace::Global);
        case gpu::AddressSpace::Workgroup:
          return static_cast<unsigned>(NVVM::NVVMMemorySpace::Shared);
        case gpu::AddressSpace::Private:
          return 0;
        }
        llvm_unreachable("unknown address space enum value");
        return static_cast<unsigned>(NVVM::NVVMMemorySpace::Generic);
      });
  // Lowering for MMAMatrixType.
  converter.addConversion([&](gpu::MMAMatrixType type) -> Type {
    return convertMMAToLLVMType(type);
  });
}

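// For example, under this mapping a workgroup-memory memref such as
//   memref<32xf32, #gpu.address_space<workgroup>>
// lowers to an LLVM pointer in NVVM's shared memory space (!llvm.ptr<3>),
// while private attributions are rewritten into the default address space.
// (The printed type syntax is illustrative and version-dependent.)
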
struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
  using ConvertOpToLLVMPattern<math::SincosOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(math::SincosOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value input = adaptor.getOperand();
    Type inputType = input.getType();
    auto convertedInput = maybeExt(input, rewriter);
    auto computeType = convertedInput.getType();

    StringRef sincosFunc;
    if (isa<Float32Type>(computeType)) {
      const arith::FastMathFlags flag = op.getFastmath();
      const bool useApprox =
          mlir::arith::bitEnumContainsAny(flag, arith::FastMathFlags::afn);
      sincosFunc = useApprox ? "__nv_fast_sincosf" : "__nv_sincosf";
    } else if (isa<Float64Type>(computeType)) {
      sincosFunc = "__nv_sincos";
    } else {
      return rewriter.notifyMatchFailure(
          op, "unsupported operand type for sincos");
    }

    auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());

    Value sinPtr, cosPtr;
    {
      OpBuilder::InsertionGuard guard(rewriter);
      auto *scope =
          op->getParentWithTrait<mlir::OpTrait::AutomaticAllocationScope>();
      assert(scope && "Expected op to be inside automatic allocation scope");
      rewriter.setInsertionPointToStart(&scope->getRegion(0).front());
      auto one = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
                                          rewriter.getI32IntegerAttr(1));
      sinPtr =
          LLVM::AllocaOp::create(rewriter, loc, ptrType, computeType, one, 0);
      cosPtr =
          LLVM::AllocaOp::create(rewriter, loc, ptrType, computeType, one, 0);
    }

    createSincosCall(rewriter, loc, sincosFunc, convertedInput, sinPtr, cosPtr,
                     op);

    auto sinResult = LLVM::LoadOp::create(rewriter, loc, computeType, sinPtr);
    auto cosResult = LLVM::LoadOp::create(rewriter, loc, computeType, cosPtr);

    rewriter.replaceOp(op, {maybeTrunc(sinResult, inputType, rewriter),
                            maybeTrunc(cosResult, inputType, rewriter)});
    return success();
  }

private:
  Value maybeExt(Value operand, PatternRewriter &rewriter) const {
    if (isa<Float16Type, BFloat16Type>(operand.getType()))
      return LLVM::FPExtOp::create(rewriter, operand.getLoc(),
                                   Float32Type::get(rewriter.getContext()),
                                   operand);
    return operand;
  }

  Value maybeTrunc(Value operand, Type type, PatternRewriter &rewriter) const {
    if (operand.getType() != type)
      return LLVM::FPTruncOp::create(rewriter, operand.getLoc(), type,
                                     operand);
    return operand;
  }

  void createSincosCall(ConversionPatternRewriter &rewriter, Location loc,
                        StringRef funcName, Value input, Value sinPtr,
                        Value cosPtr, Operation *op) const {
    auto voidType = LLVM::LLVMVoidType::get(rewriter.getContext());
    auto ptrType = sinPtr.getType();

    SmallVector<Type> operandTypes = {input.getType(), ptrType, ptrType};
    auto funcType = LLVM::LLVMFunctionType::get(voidType, operandTypes);

    auto funcAttr = StringAttr::get(op->getContext(), funcName);
    auto funcOp = dyn_cast_or_null<LLVM::LLVMFuncOp>(
        SymbolTable::lookupNearestSymbolFrom(op, funcAttr));

    if (!funcOp) {
      auto parentFunc = op->getParentOfType<FunctionOpInterface>();
      assert(parentFunc && "expected there to be a parent function");
      OpBuilder b(parentFunc);

      auto globalloc = loc->findInstanceOfOrUnknown<FileLineColLoc>();
      funcOp = LLVM::LLVMFuncOp::create(b, globalloc, funcName, funcType);
    }

    SmallVector<Value> callOperands = {input, sinPtr, cosPtr};
    LLVM::CallOp::create(rewriter, loc, funcOp, callOperands);
  }
};

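// Sketch of the IR produced for an f32 `math.sincos` by the pattern above
// (illustrative output, assuming no fastmath and the default address space):
//
//   %c1 = llvm.mlir.constant(1 : i32) : i32
//   %sin_ptr = llvm.alloca %c1 x f32 : (i32) -> !llvm.ptr
//   %cos_ptr = llvm.alloca %c1 x f32 : (i32) -> !llvm.ptr
//   llvm.call @__nv_sincosf(%x, %sin_ptr, %cos_ptr)
//       : (f32, !llvm.ptr, !llvm.ptr) -> ()
//   %sin = llvm.load %sin_ptr : !llvm.ptr -> f32
//   %cos = llvm.load %cos_ptr : !llvm.ptr -> f32
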
template <typename OpTy>
static void populateOpPatterns(const LLVMTypeConverter &converter,
                               RewritePatternSet &patterns,
                               PatternBenefit benefit, StringRef f32Func,
                               StringRef f64Func, StringRef f32ApproxFunc = "",
                               StringRef f16Func = "") {
  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit);
  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
                                           f32ApproxFunc, f16Func,
                                           /*i32Func=*/"", benefit);
}

template <typename OpTy>
static void populateIntOpPatterns(const LLVMTypeConverter &converter,
                                  RewritePatternSet &patterns,
                                  PatternBenefit benefit, StringRef i32Func) {
  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit);
  patterns.add<OpToFuncCallLowering<OpTy>>(converter, "", "", "", "", i32Func,
                                           benefit);
}

template <typename OpTy>
static void populateFloatIntOpPatterns(const LLVMTypeConverter &converter,
                                       RewritePatternSet &patterns,
                                       PatternBenefit benefit,
                                       StringRef f32Func, StringRef f64Func) {
  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit);
  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func, "", "",
                                           /*i32Func=*/"", benefit);
}

void mlir::populateGpuSubgroupReduceOpLoweringPattern(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
    PatternBenefit benefit) {
  patterns.add<GPUSubgroupReduceOpLowering>(converter, benefit);
}

void mlir::populateLibDeviceConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
    PatternBenefit benefit) {
  populateOpPatterns<arith::RemFOp>(converter, patterns, benefit, "__nv_fmodf",
                                    "__nv_fmod");
  populateFloatIntOpPatterns<arith::MaxNumFOp>(converter, patterns, benefit,
                                               "__nv_fmaxf", "__nv_fmax");
  populateFloatIntOpPatterns<arith::MinNumFOp>(converter, patterns, benefit,
                                               "__nv_fminf", "__nv_fmin");

  populateIntOpPatterns<math::AbsIOp>(converter, patterns, benefit, "__nv_abs");
  populateOpPatterns<math::AbsFOp>(converter, patterns, benefit, "__nv_fabsf",
                                   "__nv_fabs");
  populateOpPatterns<math::AcosOp>(converter, patterns, benefit, "__nv_acosf",
                                   "__nv_acos");
  populateOpPatterns<math::AcoshOp>(converter, patterns, benefit, "__nv_acoshf",
                                    "__nv_acosh");
  populateOpPatterns<math::AsinOp>(converter, patterns, benefit, "__nv_asinf",
                                   "__nv_asin");
  populateOpPatterns<math::AsinhOp>(converter, patterns, benefit, "__nv_asinhf",
                                    "__nv_asinh");
  populateOpPatterns<math::AtanOp>(converter, patterns, benefit, "__nv_atanf",
                                   "__nv_atan");
  populateOpPatterns<math::Atan2Op>(converter, patterns, benefit, "__nv_atan2f",
                                    "__nv_atan2");
  populateOpPatterns<math::AtanhOp>(converter, patterns, benefit, "__nv_atanhf",
                                    "__nv_atanh");
  populateOpPatterns<math::CbrtOp>(converter, patterns, benefit, "__nv_cbrtf",
                                   "__nv_cbrt");
  populateOpPatterns<math::CeilOp>(converter, patterns, benefit, "__nv_ceilf",
                                   "__nv_ceil");
  populateOpPatterns<math::CopySignOp>(converter, patterns, benefit,
                                       "__nv_copysignf", "__nv_copysign");
  populateOpPatterns<math::CosOp>(converter, patterns, benefit, "__nv_cosf",
                                  "__nv_cos", "__nv_fast_cosf");
  populateOpPatterns<math::CoshOp>(converter, patterns, benefit, "__nv_coshf",
                                   "__nv_cosh");
  populateOpPatterns<math::ErfOp>(converter, patterns, benefit, "__nv_erff",
                                  "__nv_erf");
  populateOpPatterns<math::ErfcOp>(converter, patterns, benefit, "__nv_erfcf",
                                   "__nv_erfc");
  populateOpPatterns<math::ExpOp>(converter, patterns, benefit, "__nv_expf",
                                  "__nv_exp", "__nv_fast_expf");
  populateOpPatterns<math::Exp2Op>(converter, patterns, benefit, "__nv_exp2f",
                                   "__nv_exp2");
  populateOpPatterns<math::ExpM1Op>(converter, patterns, benefit, "__nv_expm1f",
                                    "__nv_expm1");
  populateOpPatterns<math::FloorOp>(converter, patterns, benefit, "__nv_floorf",
                                    "__nv_floor");
  populateOpPatterns<math::FmaOp>(converter, patterns, benefit, "__nv_fmaf",
                                  "__nv_fma");
  // Note: libdevice uses a different name for 32-bit finite checking.
  populateOpPatterns<math::IsFiniteOp>(converter, patterns, benefit,
                                       "__nv_finitef", "__nv_isfinited");
  populateOpPatterns<math::IsInfOp>(converter, patterns, benefit, "__nv_isinff",
                                    "__nv_isinfd");
  populateOpPatterns<math::IsNaNOp>(converter, patterns, benefit, "__nv_isnanf",
                                    "__nv_isnand");
  populateOpPatterns<math::LogOp>(converter, patterns, benefit, "__nv_logf",
                                  "__nv_log", "__nv_fast_logf");
  populateOpPatterns<math::Log10Op>(converter, patterns, benefit, "__nv_log10f",
                                    "__nv_log10", "__nv_fast_log10f");
  populateOpPatterns<math::Log1pOp>(converter, patterns, benefit, "__nv_log1pf",
                                    "__nv_log1p");
  populateOpPatterns<math::Log2Op>(converter, patterns, benefit, "__nv_log2f",
                                   "__nv_log2", "__nv_fast_log2f");
  populateOpPatterns<math::PowFOp>(converter, patterns, benefit, "__nv_powf",
                                   "__nv_pow", "__nv_fast_powf");
  populateOpPatterns<math::FPowIOp>(converter, patterns, benefit,
                                    "__nv_powif", "__nv_powi");
  populateOpPatterns<math::RoundOp>(converter, patterns, benefit, "__nv_roundf",
                                    "__nv_round");
  populateOpPatterns<math::RoundEvenOp>(converter, patterns, benefit,
                                        "__nv_rintf", "__nv_rint");
  populateOpPatterns<math::RsqrtOp>(converter, patterns, benefit, "__nv_rsqrtf",
                                    "__nv_rsqrt");
  populateOpPatterns<math::SinOp>(converter, patterns, benefit, "__nv_sinf",
                                  "__nv_sin", "__nv_fast_sinf");
  populateOpPatterns<math::SinhOp>(converter, patterns, benefit, "__nv_sinhf",
                                   "__nv_sinh");
  populateOpPatterns<math::SqrtOp>(converter, patterns, benefit, "__nv_sqrtf",
                                   "__nv_sqrt");
  populateOpPatterns<math::TanOp>(converter, patterns, benefit, "__nv_tanf",
                                  "__nv_tan", "__nv_fast_tanf");
  populateOpPatterns<math::TanhOp>(converter, patterns, benefit, "__nv_tanhf",
                                   "__nv_tanh");

  // Custom pattern for sincos since it returns two values.
  patterns.add<SincosOpLowering>(converter, benefit);
}

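// For instance, `math.exp %x : f32` becomes `llvm.call @__nv_expf(%x)`, the
// f64 form calls `@__nv_exp`, and when the `afn` fastmath flag is present the
// f32 form calls `@__nv_fast_expf` instead (selection as wired up above).
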
void mlir::populateGpuToNVVMConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
    PatternBenefit benefit) {
  using gpu::index_lowering::IndexKind;
  using gpu::index_lowering::IntrType;

  // TODO: Pass benefit to generated patterns.
  populateWithGenerated(patterns);

  patterns.add<GPUPrintfOpToVPrintfLowering, AssertOpToAssertfailLowering>(
      converter, benefit);
  patterns.add<
      gpu::index_lowering::OpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp,
                                      NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>>(
      converter, IndexKind::Block, IntrType::Id, benefit);
  patterns.add<
      gpu::index_lowering::OpLowering<gpu::BlockDimOp, NVVM::BlockDimXOp,
                                      NVVM::BlockDimYOp, NVVM::BlockDimZOp>>(
      converter, IndexKind::Block, IntrType::Dim, benefit);
  patterns.add<
      gpu::index_lowering::OpLowering<gpu::ClusterIdOp, NVVM::ClusterIdXOp,
                                      NVVM::ClusterIdYOp, NVVM::ClusterIdZOp>>(
      converter, IndexKind::Other, IntrType::Id, benefit);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::ClusterDimOp, NVVM::ClusterDimXOp, NVVM::ClusterDimYOp,
      NVVM::ClusterDimZOp>>(converter, IndexKind::Other, IntrType::Dim,
                            benefit);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::ClusterBlockIdOp, NVVM::BlockInClusterIdXOp,
      NVVM::BlockInClusterIdYOp, NVVM::BlockInClusterIdZOp>>(
      converter, IndexKind::Other, IntrType::Id, benefit);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::ClusterDimBlocksOp, NVVM::ClusterDimBlocksXOp,
      NVVM::ClusterDimBlocksYOp, NVVM::ClusterDimBlocksZOp>>(
      converter, IndexKind::Other, IntrType::Dim, benefit);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::BlockIdOp, NVVM::BlockIdXOp, NVVM::BlockIdYOp, NVVM::BlockIdZOp>>(
      converter, IndexKind::Grid, IntrType::Id, benefit);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::GridDimOp, NVVM::GridDimXOp, NVVM::GridDimYOp, NVVM::GridDimZOp>>(
      converter, IndexKind::Grid, IntrType::Dim, benefit);
  patterns.add<GPULaneIdOpToNVVM, GPUShuffleOpLowering, GPUReturnOpLowering>(
      converter, benefit);

  patterns.add<GPUDynamicSharedMemoryOpLowering>(
      converter, NVVM::kSharedMemoryAlignmentBit, benefit);

  // Explicitly drop memory space when lowering private memory
  // attributions since NVVM models it as `alloca`s in the default
  // memory space and does not support `alloca`s with addrspace(5).
  patterns.add<GPUFuncOpLowering>(
      converter,
      GPUFuncOpLoweringOptions{
          /*allocaAddrSpace=*/0,
          /*workgroupAddrSpace=*/
          static_cast<unsigned>(NVVM::NVVMMemorySpace::Shared),
          StringAttr::get(&converter.getContext(),
                          NVVM::NVVMDialect::getKernelFuncAttrName()),
          StringAttr::get(&converter.getContext(),
                          NVVM::NVVMDialect::getMaxntidAttrName())},
      benefit);

  populateLibDeviceConversionPatterns(converter, patterns, benefit);
}

//===----------------------------------------------------------------------===//
// NVVMTargetAttr convert to LLVM attr interface
//===----------------------------------------------------------------------===//

namespace {
struct NVVMTargetConvertToLLVMAttrInterface
    : public ConvertToLLVMAttrInterface::ExternalModel<
          NVVMTargetConvertToLLVMAttrInterface, NVVM::NVVMTargetAttr> {
  /// Configure GPU to NVVM.
  void populateConvertToLLVMConversionPatterns(
      Attribute attr, ConversionTarget &target,
      LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) const;
};
} // namespace

void NVVMTargetConvertToLLVMAttrInterface::
    populateConvertToLLVMConversionPatterns(Attribute attr,
                                            ConversionTarget &target,
                                            LLVMTypeConverter &typeConverter,
                                            RewritePatternSet &patterns) const {
  configureGpuToNVVMConversionLegality(target);
  configureGpuToNVVMTypeConverter(typeConverter);
  populateGpuToNVVMConversionPatterns(typeConverter, patterns);
}

void mlir::registerConvertGpuToNVVMInterface(DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, NVVM::NVVMDialect *dialect) {
    NVVM::NVVMTargetAttr::attachInterface<
        NVVMTargetConvertToLLVMAttrInterface>(*ctx);
  });
}
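
// Hypothetical caller-side sketch of how this extension gets wired in (not
// part of this file):
//
//   mlir::DialectRegistry registry;
//   mlir::registerConvertGpuToNVVMInterface(registry);
//   mlir::MLIRContext context(registry);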