//===- LowerGpuOpsToNVVMOps.cpp - MLIR GPU to NVVM lowering passes --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to generate NVVM IR operations for higher-level
// GPU operations.
//
//===----------------------------------------------------------------------===//

#include <optional>

namespace mlir {
#define GEN_PASS_DEF_CONVERTGPUOPSTONVVMOPS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {

/// Convert gpu dialect shfl mode enum to the equivalent nvvm one.
static NVVM::ShflKind convertShflKind(gpu::ShuffleMode mode) {
  switch (mode) {
  case gpu::ShuffleMode::XOR:
    return NVVM::ShflKind::bfly;
  case gpu::ShuffleMode::UP:
    return NVVM::ShflKind::up;
  case gpu::ShuffleMode::DOWN:
    return NVVM::ShflKind::down;
  case gpu::ShuffleMode::IDX:
    return NVVM::ShflKind::idx;
  }
  llvm_unreachable("unknown shuffle mode");
}

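/// Map a gpu reduction operation onto the matching NVVM redux kind, or
/// std::nullopt for operations that redux.sync does not support (mul, the
/// unsigned min/max variants, and the IEEE minimum/maximum).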
static std::optional<NVVM::ReductionKind>
convertToNVVMReductionKind(gpu::AllReduceOperation mode) {
  switch (mode) {
  case gpu::AllReduceOperation::ADD:
    return NVVM::ReductionKind::ADD;
  case gpu::AllReduceOperation::MUL:
    return std::nullopt;
  case gpu::AllReduceOperation::MINSI:
    return NVVM::ReductionKind::MIN;
  case gpu::AllReduceOperation::MINUI:
    return std::nullopt;
  case gpu::AllReduceOperation::MINNUMF:
    return NVVM::ReductionKind::MIN;
  case gpu::AllReduceOperation::MAXSI:
    return NVVM::ReductionKind::MAX;
  case gpu::AllReduceOperation::MAXUI:
    return std::nullopt;
  case gpu::AllReduceOperation::MAXNUMF:
    return NVVM::ReductionKind::MAX;
  case gpu::AllReduceOperation::AND:
    return NVVM::ReductionKind::AND;
  case gpu::AllReduceOperation::OR:
    return NVVM::ReductionKind::OR;
  case gpu::AllReduceOperation::XOR:
    return NVVM::ReductionKind::XOR;
  case gpu::AllReduceOperation::MINIMUMF:
  case gpu::AllReduceOperation::MAXIMUMF:
    return std::nullopt;
  }
  return std::nullopt;
}

/// This pattern lowers the gpu.subgroup_reduce op to the nvvm.redux op. The
/// op must be executed by the entire subgroup, otherwise the behaviour is
/// undefined.
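///
/// A minimal sketch of the rewrite (illustrative IR, not verbatim output):
///
///   %r = gpu.subgroup_reduce add %v uniform : (i32) -> i32
///
/// becomes, with a constant -1 (all lanes) as the membermask,
///
///   %mask = llvm.mlir.constant(-1 : i32) : i32
///   %r = nvvm.redux.sync add %v, %mask : i32 -> i32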
struct GPUSubgroupReduceOpLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupReduceOp> {
  using ConvertOpToLLVMPattern<gpu::SubgroupReduceOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getClusterSize())
      return rewriter.notifyMatchFailure(
          op, "lowering for clustered reduce not implemented");

    if (!op.getUniform())
      return rewriter.notifyMatchFailure(
          op, "cannot be lowered to redux as the op must be run "
              "uniformly (entire subgroup).");
    if (!op.getValue().getType().isInteger(32))
      return rewriter.notifyMatchFailure(op, "unsupported data type");

    std::optional<NVVM::ReductionKind> mode =
        convertToNVVMReductionKind(op.getOp());
    if (!mode.has_value())
      return rewriter.notifyMatchFailure(
          op, "unsupported reduction mode for redux");

    Location loc = op->getLoc();
    auto int32Type = IntegerType::get(rewriter.getContext(), 32);
    Value offset = LLVM::ConstantOp::create(rewriter, loc, int32Type, -1);

    auto reduxOp = NVVM::ReduxOp::create(rewriter, loc, int32Type,
                                         op.getValue(), mode.value(), offset);

    rewriter.replaceOp(op, reduxOp->getResult(0));
    return success();
  }
};

struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
  using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern;

  /// Lowers a shuffle to the corresponding NVVM op.
  ///
  /// Convert the `width` argument into an activeMask (a bitmask which
  /// specifies which threads participate in the shuffle) and a maskAndClamp
  /// (specifying the highest lane which participates in the shuffle).
  ///
  ///   %one = llvm.constant(1 : i32) : i32
  ///   %minus_one = llvm.constant(-1 : i32) : i32
  ///   %thirty_two = llvm.constant(32 : i32) : i32
  ///   %num_lanes = llvm.sub %thirty_two, %width : i32
  ///   %active_mask = llvm.lshr %minus_one, %num_lanes : i32
  ///   %mask_and_clamp = llvm.sub %width, %one : i32
  ///   %shfl = nvvm.shfl.sync.bfly %active_mask, %value, %offset,
  ///       %mask_and_clamp : !llvm<"{ float, i1 }">
  ///   %shfl_value = llvm.extractvalue %shfl[0] : !llvm<"{ float, i1 }">
  ///   %shfl_pred = llvm.extractvalue %shfl[1] : !llvm<"{ float, i1 }">
  LogicalResult
  matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();

    auto valueTy = adaptor.getValue().getType();
    auto int32Type = IntegerType::get(rewriter.getContext(), 32);
    auto predTy = IntegerType::get(rewriter.getContext(), 1);

    Value one = LLVM::ConstantOp::create(rewriter, loc, int32Type, 1);
    Value minusOne = LLVM::ConstantOp::create(rewriter, loc, int32Type, -1);
    Value thirtyTwo = LLVM::ConstantOp::create(rewriter, loc, int32Type, 32);
    Value numLeadInactiveLane = LLVM::SubOp::create(
        rewriter, loc, int32Type, thirtyTwo, adaptor.getWidth());
    // Bit mask of active lanes: `(-1) >> (32 - activeWidth)`.
    Value activeMask = LLVM::LShrOp::create(rewriter, loc, int32Type, minusOne,
                                            numLeadInactiveLane);
    Value maskAndClamp;
    if (op.getMode() == gpu::ShuffleMode::UP) {
      // Clamp lane: `32 - activeWidth`
      maskAndClamp = numLeadInactiveLane;
    } else {
      // Clamp lane: `activeWidth - 1`
      maskAndClamp = LLVM::SubOp::create(rewriter, loc, int32Type,
                                         adaptor.getWidth(), one);
    }

    bool predIsUsed = !op->getResult(1).use_empty();
    UnitAttr returnValueAndIsValidAttr = nullptr;
    Type resultTy = valueTy;
    if (predIsUsed) {
      returnValueAndIsValidAttr = rewriter.getUnitAttr();
      resultTy = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
                                                  {valueTy, predTy});
    }
    Value shfl = NVVM::ShflOp::create(
        rewriter, loc, resultTy, activeMask, adaptor.getValue(),
        adaptor.getOffset(), maskAndClamp, convertShflKind(op.getMode()),
        returnValueAndIsValidAttr);
    if (predIsUsed) {
      Value shflValue = LLVM::ExtractValueOp::create(rewriter, loc, shfl, 0);
      Value isActiveSrcLane =
          LLVM::ExtractValueOp::create(rewriter, loc, shfl, 1);
      rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
    } else {
      rewriter.replaceOp(op, {shfl, nullptr});
    }
    return success();
  }
};

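/// Lowers gpu.lane_id to NVVM::LaneIdOp, attaching a [0, 32) range (or the
/// op's own upper bound, if present) so LLVM can exploit the bounded value.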
struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
  using ConvertOpToLLVMPattern<gpu::LaneIdOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::LaneIdOp op, gpu::LaneIdOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    MLIRContext *context = rewriter.getContext();
    LLVM::ConstantRangeAttr bounds = nullptr;
    if (std::optional<APInt> upperBound = op.getUpperBound())
      bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
          /*bitWidth=*/32, /*lower=*/0, upperBound->getZExtValue());
    else
      bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
          /*bitWidth=*/32, /*lower=*/0, /*upper=*/kWarpSize);
    Value newOp =
        NVVM::LaneIdOp::create(rewriter, loc, rewriter.getI32Type(), bounds);
    // Truncate or extend the result depending on the index bitwidth specified
    // by the LLVMTypeConverter options.
    const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
    if (indexBitwidth > 32) {
      newOp = LLVM::SExtOp::create(
          rewriter, loc, IntegerType::get(context, indexBitwidth), newOp);
    } else if (indexBitwidth < 32) {
      newOp = LLVM::TruncOp::create(
          rewriter, loc, IntegerType::get(context, indexBitwidth), newOp);
    }
    rewriter.replaceOp(op, {newOp});
    return success();
  }
};

/// Lowering of cf.assert into a conditional call to __assertfail.
struct AssertOpToAssertfailLowering
    : public ConvertOpToLLVMPattern<cf::AssertOp> {
  using ConvertOpToLLVMPattern<cf::AssertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(cf::AssertOp assertOp, cf::AssertOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MLIRContext *ctx = rewriter.getContext();
    Location loc = assertOp.getLoc();
    Type i8Type = typeConverter->convertType(rewriter.getIntegerType(8));
    Type i32Type = typeConverter->convertType(rewriter.getIntegerType(32));
    Type i64Type = typeConverter->convertType(rewriter.getIntegerType(64));
    Type ptrType = LLVM::LLVMPointerType::get(ctx);
    Type voidType = LLVM::LLVMVoidType::get(ctx);

    // Find or create the __assertfail function declaration.
    auto moduleOp = assertOp->getParentOfType<gpu::GPUModuleOp>();
    auto assertfailType = LLVM::LLVMFunctionType::get(
        voidType, {ptrType, ptrType, i32Type, ptrType, i64Type});
    LLVM::LLVMFuncOp assertfailDecl = getOrDefineFunction(
        moduleOp, loc, rewriter, "__assertfail", assertfailType);
    assertfailDecl.setPassthroughAttr(
        ArrayAttr::get(ctx, StringAttr::get(ctx, "noreturn")));

    // Split blocks and insert conditional branch.
    // ^before:
    //   ...
    //   cf.cond_br %condition, ^after, ^assert
    // ^assert:
    //   cf.assert
    //   cf.br ^after
    // ^after:
    //   ...
    Block *beforeBlock = assertOp->getBlock();
    Block *assertBlock =
        rewriter.splitBlock(beforeBlock, assertOp->getIterator());
    Block *afterBlock =
        rewriter.splitBlock(assertBlock, ++assertOp->getIterator());
    rewriter.setInsertionPointToEnd(beforeBlock);
    cf::CondBranchOp::create(rewriter, loc, adaptor.getArg(), afterBlock,
                             assertBlock);
    rewriter.setInsertionPointToEnd(assertBlock);
    cf::BranchOp::create(rewriter, loc, afterBlock);

    // Continue cf.assert lowering.
    rewriter.setInsertionPoint(assertOp);

    // Populate the file name, line number, and function name from the
    // location of the AssertOp.
    StringRef fileName = "(unknown)";
    StringRef funcName = "(unknown)";
    int32_t fileLine = 0;
    while (auto callSiteLoc = dyn_cast<CallSiteLoc>(loc))
      loc = callSiteLoc.getCallee();
    if (auto fileLineColLoc = dyn_cast<FileLineColRange>(loc)) {
      fileName = fileLineColLoc.getFilename().strref();
      fileLine = fileLineColLoc.getStartLine();
    } else if (auto nameLoc = dyn_cast<NameLoc>(loc)) {
      funcName = nameLoc.getName().strref();
      if (auto fileLineColLoc =
              dyn_cast<FileLineColRange>(nameLoc.getChildLoc())) {
        fileName = fileLineColLoc.getFilename().strref();
        fileLine = fileLineColLoc.getStartLine();
      }
    }

    // Create constants.
    auto getGlobal = [&](LLVM::GlobalOp global) {
      // Get a pointer to the format string's first element.
      Value globalPtr = LLVM::AddressOfOp::create(
          rewriter, loc,
          LLVM::LLVMPointerType::get(ctx, global.getAddrSpace()),
          global.getSymNameAttr());
      Value start =
          LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(),
                              globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
      return start;
    };
    Value assertMessage = getGlobal(getOrCreateStringConstant(
        rewriter, loc, moduleOp, i8Type, "assert_message_",
        assertOp.getMsg()));
    Value assertFile = getGlobal(getOrCreateStringConstant(
        rewriter, loc, moduleOp, i8Type, "assert_file_", fileName));
    Value assertFunc = getGlobal(getOrCreateStringConstant(
        rewriter, loc, moduleOp, i8Type, "assert_func_", funcName));
    Value assertLine =
        LLVM::ConstantOp::create(rewriter, loc, i32Type, fileLine);
    Value c1 = LLVM::ConstantOp::create(rewriter, loc, i64Type, 1);

    // Insert the call to __assertfail.
    SmallVector<Value> arguments{assertMessage, assertFile, assertLine,
                                 assertFunc, c1};
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(assertOp, assertfailDecl,
                                              arguments);
    return success();
  }
};

/// Import the GPU Ops to NVVM Patterns.
#include "GPUToNVVM.cpp.inc"

/// A pass that replaces all occurrences of GPU device operations with their
/// corresponding NVVM equivalent.
///
/// This pass only handles device code and is not meant to be run on GPU host
/// code.
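///
/// Registered with mlir-opt as `convert-gpu-to-nvvm`; it runs on gpu.module
/// ops.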
struct LowerGpuOpsToNVVMOpsPass final
    : public impl::ConvertGpuOpsToNVVMOpsBase<LowerGpuOpsToNVVMOpsPass> {
  using Base::Base;

  void getDependentDialects(DialectRegistry &registry) const override {
    Base::getDependentDialects(registry);
    registerConvertToLLVMDependentDialectLoading(registry);
  }

  void runOnOperation() override {
    gpu::GPUModuleOp m = getOperation();

    // Request C wrapper emission.
    for (auto func : m.getOps<func::FuncOp>()) {
      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                    UnitAttr::get(&getContext()));
    }

    // Customize the bitwidth used for the device-side index computations.
    LowerToLLVMOptions options(
        m.getContext(),
        DataLayout(cast<DataLayoutOpInterface>(m.getOperation())));
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);
    options.useBarePtrCallConv = useBarePtrCallConv;

    // Apply in-dialect lowering. In-dialect lowering replaces ops that need
    // to be lowered further in multiple steps, which a single conversion
    // pass cannot do.
    {
      RewritePatternSet patterns(m.getContext());
      populateGpuRewritePatterns(patterns);
      // Transform N-D vector.from_elements to 1-D vector.from_elements
      // before conversion.
      vector::populateVectorFromElementsUnrollPatterns(patterns);
      if (failed(applyPatternsGreedily(m, std::move(patterns))))
        return signalPassFailure();
    }

    LLVMTypeConverter converter(m.getContext(), options);
    configureGpuToNVVMTypeConverter(converter);
    RewritePatternSet llvmPatterns(m.getContext());
    LLVMConversionTarget target(getContext());

    // Use a higher benefit so these patterns run before the generic LLVM
    // lowering.
    populateGpuToNVVMConversionPatterns(converter, llvmPatterns,
                                        /*benefit=*/10);

    llvm::SmallDenseSet<StringRef> allowedDialectsSet(allowedDialects.begin(),
                                                      allowedDialects.end());
    for (Dialect *dialect : getContext().getLoadedDialects()) {
      // Skip math patterns as NVVM needs custom math lowering.
      if (isa<math::MathDialect>(dialect))
        continue;

      bool allowed = allowedDialectsSet.contains(dialect->getNamespace());
      // An empty `allowedDialectsSet` means all dialects are allowed.
      if (!allowedDialectsSet.empty() && !allowed)
        continue;

      auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
      if (!iface) {
        // Error out if the dialect was explicitly specified but does not
        // implement the conversion interface.
        if (allowed) {
          m.emitError()
              << "dialect does not implement ConvertToLLVMPatternInterface: "
              << dialect->getNamespace();
          return signalPassFailure();
        }
        continue;
      }

      iface->populateConvertToLLVMConversionPatterns(target, converter,
                                                     llvmPatterns);
    }

    populateGpuWMMAToNVVMConversionPatterns(converter, llvmPatterns);
    if (this->hasRedux)
      populateGpuSubgroupReduceOpLoweringPattern(converter, llvmPatterns);
    configureGpuToNVVMConversionLegality(target);
    ConversionConfig config;
    config.allowPatternRollback = allowPatternRollback;
    if (failed(
            applyPartialConversion(m, target, std::move(llvmPatterns), config)))
      signalPassFailure();
  }
};

} // namespace

void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
  target.addIllegalOp<func::FuncOp>();
  target.addIllegalOp<cf::AssertOp>();
  target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
  target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
  target.addIllegalDialect<gpu::GPUDialect>();
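  // These LLVM math ops are marked illegal so that the libdevice conversion
  // patterns must rewrite them; most have no direct NVVM/PTX lowering.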
  target.addIllegalOp<LLVM::CopySignOp, LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op,
                      LLVM::FAbsOp, LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp,
                      LLVM::LogOp, LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp,
                      LLVM::RoundEvenOp, LLVM::RoundOp, LLVM::SinOp,
                      LLVM::SincosOp, LLVM::SqrtOp>();

  // TODO: Remove once we support replacing non-root ops.
  target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp>();
}

void mlir::configureGpuToNVVMTypeConverter(LLVMTypeConverter &converter) {
  populateCommonGPUTypeAndAttributeConversions(converter);

  // Lowering for MMAMatrixType.
  converter.addConversion([&](gpu::MMAMatrixType type) -> Type {
    return convertMMAToLLVMType(type);
  });
}

void mlir::populateGpuSubgroupReduceOpLoweringPattern(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
    PatternBenefit benefit) {
  patterns.add<GPUSubgroupReduceOpLowering>(converter, benefit);
}

void mlir::populateGpuToNVVMConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
    PatternBenefit benefit) {
  using gpu::index_lowering::IndexKind;
  using gpu::index_lowering::IntrType;

  // TODO: Pass benefit to generated patterns.
  populateWithGenerated(patterns);

  patterns.add<GPUPrintfOpToVPrintfLowering, AssertOpToAssertfailLowering>(
      converter, benefit);
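  // Each 3-D GPU index/dimension op lowers to three per-axis NVVM ops;
  // IndexKind selects the bound used for range annotations and IntrType
  // distinguishes id intrinsics from dim intrinsics.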
  patterns.add<
      gpu::index_lowering::OpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp,
                                      NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>>(
      converter, IndexKind::Block, IntrType::Id, benefit);
  patterns.add<
      gpu::index_lowering::OpLowering<gpu::BlockDimOp, NVVM::BlockDimXOp,
                                      NVVM::BlockDimYOp, NVVM::BlockDimZOp>>(
      converter, IndexKind::Block, IntrType::Dim, benefit);
  patterns.add<
      gpu::index_lowering::OpLowering<gpu::ClusterIdOp, NVVM::ClusterIdXOp,
                                      NVVM::ClusterIdYOp, NVVM::ClusterIdZOp>>(
      converter, IndexKind::Other, IntrType::Id, benefit);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::ClusterDimOp, NVVM::ClusterDimXOp, NVVM::ClusterDimYOp,
      NVVM::ClusterDimZOp>>(converter, IndexKind::Other, IntrType::Dim,
                            benefit);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::ClusterBlockIdOp, NVVM::BlockInClusterIdXOp,
      NVVM::BlockInClusterIdYOp, NVVM::BlockInClusterIdZOp>>(
      converter, IndexKind::Cluster, IntrType::Id, benefit);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::ClusterDimBlocksOp, NVVM::ClusterDimBlocksXOp,
      NVVM::ClusterDimBlocksYOp, NVVM::ClusterDimBlocksZOp>>(
      converter, IndexKind::Cluster, IntrType::Dim, benefit);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::BlockIdOp, NVVM::BlockIdXOp, NVVM::BlockIdYOp, NVVM::BlockIdZOp>>(
      converter, IndexKind::Grid, IntrType::Id, benefit);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::GridDimOp, NVVM::GridDimXOp, NVVM::GridDimYOp, NVVM::GridDimZOp>>(
      converter, IndexKind::Grid, IntrType::Dim, benefit);
  patterns.add<GPULaneIdOpToNVVM, GPUShuffleOpLowering, GPUReturnOpLowering>(
      converter, benefit);

  patterns.add<GPUDynamicSharedMemoryOpLowering>(
      converter, NVVM::kSharedMemoryAlignmentBit, benefit);

  // Explicitly drop memory space when lowering private memory
  // attributions since NVVM models it as `alloca`s in the default
  // memory space and does not support `alloca`s with addrspace(5).
  patterns.add<GPUFuncOpLowering>(
      converter,
      GPUFuncOpLoweringOptions{
          /*allocaAddrSpace=*/0,
          /*workgroupAddrSpace=*/
          static_cast<unsigned>(NVVM::NVVMMemorySpace::Shared),
          StringAttr::get(&converter.getContext(),
                          NVVM::NVVMDialect::getKernelFuncAttrName()),
          StringAttr::get(&converter.getContext(),
                          NVVM::NVVMDialect::getMaxntidAttrName()),
          StringAttr::get(&converter.getContext(),
                          NVVM::NVVMDialect::getClusterDimAttrName())},
      benefit);

  populateLibDeviceConversionPatterns(converter, patterns, benefit);
}

//===----------------------------------------------------------------------===//
// NVVMTargetAttr convert to LLVM attr interface
//===----------------------------------------------------------------------===//

namespace {
struct NVVMTargetConvertToLLVMAttrInterface
    : public ConvertToLLVMAttrInterface::ExternalModel<
          NVVMTargetConvertToLLVMAttrInterface, NVVM::NVVMTargetAttr> {
  /// Configure the GPU to NVVM lowering.
  void populateConvertToLLVMConversionPatterns(
      Attribute attr, ConversionTarget &target,
      LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) const;
};
} // namespace

void NVVMTargetConvertToLLVMAttrInterface::
    populateConvertToLLVMConversionPatterns(Attribute attr,
                                            ConversionTarget &target,
                                            LLVMTypeConverter &typeConverter,
                                            RewritePatternSet &patterns) const {
  configureGpuToNVVMConversionLegality(target);
  configureGpuToNVVMTypeConverter(typeConverter);
  populateGpuToNVVMConversionPatterns(typeConverter, patterns);
}

void mlir::registerConvertGpuToNVVMInterface(DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, NVVM::NVVMDialect *dialect) {
    NVVM::NVVMTargetAttr::attachInterface<
        NVVMTargetConvertToLLVMAttrInterface>(*ctx);
  });
}