#include "../GPUCommon/GPUOpsLowering.h"
#include "../GPUCommon/IndexIntrinsicsOpLowering.h"
#include "../GPUCommon/OpToFuncCallLowering.h"

#define GEN_PASS_DEF_CONVERTGPUOPSTONVVMOPS
#include "mlir/Conversion/Passes.h.inc"

/// Map gpu.shuffle modes to the corresponding NVVM shuffle kinds.
static NVVM::ShflKind convertShflKind(gpu::ShuffleMode mode) {
  switch (mode) {
  case gpu::ShuffleMode::XOR:
    return NVVM::ShflKind::bfly;
  case gpu::ShuffleMode::UP:
    return NVVM::ShflKind::up;
  case gpu::ShuffleMode::DOWN:
    return NVVM::ShflKind::down;
  case gpu::ShuffleMode::IDX:
    return NVVM::ShflKind::idx;
  }
  llvm_unreachable("unknown shuffle mode");
}

/// Map a gpu reduction operation onto the NVVM redux kind, or std::nullopt
/// when redux has no equivalent for that operation.
static std::optional<NVVM::ReduxKind>
convertReduxKind(gpu::AllReduceOperation mode) {
  switch (mode) {
  case gpu::AllReduceOperation::ADD:
    return NVVM::ReduxKind::ADD;
  case gpu::AllReduceOperation::MUL:
    return std::nullopt;
  case gpu::AllReduceOperation::MINSI:
    return NVVM::ReduxKind::MIN;
  case gpu::AllReduceOperation::MINUI:
    return std::nullopt;
  case gpu::AllReduceOperation::MINNUMF:
    return NVVM::ReduxKind::MIN;
  case gpu::AllReduceOperation::MAXSI:
    return NVVM::ReduxKind::MAX;
  case gpu::AllReduceOperation::MAXUI:
    return std::nullopt;
  case gpu::AllReduceOperation::MAXNUMF:
    return NVVM::ReduxKind::MAX;
  case gpu::AllReduceOperation::AND:
    return NVVM::ReduxKind::AND;
  case gpu::AllReduceOperation::OR:
    return NVVM::ReduxKind::OR;
  case gpu::AllReduceOperation::XOR:
    return NVVM::ReduxKind::XOR;
  case gpu::AllReduceOperation::MINIMUMF:
  case gpu::AllReduceOperation::MAXIMUMF:
    return std::nullopt;
  }
  return std::nullopt;
}

/// Lower gpu.subgroup_reduce to a single nvvm.redux.sync when possible.
struct GPUSubgroupReduceOpLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupReduceOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getClusterSize())
      return rewriter.notifyMatchFailure(
          op, "lowering for clustered reduce not implemented");

    if (!op.getUniform())
      return rewriter.notifyMatchFailure(
          op, "cannot be lowered to redux as the op must be run "
              "uniformly (entire subgroup).");
    if (!op.getValue().getType().isInteger(32))
      return rewriter.notifyMatchFailure(op, "unsupported data type");

    std::optional<NVVM::ReduxKind> mode = convertReduxKind(op.getOp());
    if (!mode.has_value())
      return rewriter.notifyMatchFailure(
          op, "unsupported reduction mode for redux");

    Location loc = op->getLoc();
    Type int32Type = IntegerType::get(rewriter.getContext(), 32);
    Value offset = LLVM::ConstantOp::create(rewriter, loc, int32Type, -1);

    auto reduxOp = NVVM::ReduxOp::create(rewriter, loc, int32Type,
                                         op.getValue(), mode.value(), offset);

    rewriter.replaceOp(op, reduxOp->getResult(0));
    return success();
  }
};
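
// Note: redux.sync performs the whole subgroup reduction in one warp
// instruction (PTX redux.sync, available on sm_80+), and the -1 membermask
// above assumes every lane participates -- which is why the pattern bails out
// unless the op is marked uniform and operates on i32.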

struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();

    auto valueTy = adaptor.getValue().getType();
    auto int32Type = IntegerType::get(rewriter.getContext(), 32);
    auto predTy = IntegerType::get(rewriter.getContext(), 1);

    Value one = LLVM::ConstantOp::create(rewriter, loc, int32Type, 1);
    Value minusOne = LLVM::ConstantOp::create(rewriter, loc, int32Type, -1);
    Value thirtyTwo = LLVM::ConstantOp::create(rewriter, loc, int32Type, 32);
    Value numLeadInactiveLane = LLVM::SubOp::create(
        rewriter, loc, int32Type, thirtyTwo, adaptor.getWidth());
    // Bit mask of active lanes: `(-1) >> (32 - activeWidth)`.
    Value activeMask = LLVM::LShrOp::create(rewriter, loc, int32Type, minusOne,
                                            numLeadInactiveLane);
    Value maskAndClamp;
    if (op.getMode() == gpu::ShuffleMode::UP) {
      // Clamp lane: `32 - activeWidth`.
      maskAndClamp = numLeadInactiveLane;
    } else {
      // Clamp lane: `activeWidth - 1`.
      maskAndClamp = LLVM::SubOp::create(rewriter, loc, int32Type,
                                         adaptor.getWidth(), one);
    }

    bool predIsUsed = !op->getResult(1).use_empty();
    UnitAttr returnValueAndIsValidAttr = nullptr;
    Type resultTy = valueTy;
    if (predIsUsed) {
      returnValueAndIsValidAttr = rewriter.getUnitAttr();
      resultTy = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
                                                  {valueTy, predTy});
    }
    Value shfl = NVVM::ShflOp::create(
        rewriter, loc, resultTy, activeMask, adaptor.getValue(),
        adaptor.getOffset(), maskAndClamp, convertShflKind(op.getMode()),
        returnValueAndIsValidAttr);
    if (predIsUsed) {
      Value shflValue = LLVM::ExtractValueOp::create(rewriter, loc, shfl, 0);
      Value isActiveSrcLane =
          LLVM::ExtractValueOp::create(rewriter, loc, shfl, 1);
      rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
    } else {
      rewriter.replaceOp(op, {shfl, Value()});
    }
    return success();
  }
};
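
// Worked example of the mask computation above (illustrative only): for a
// shuffle width of 16, numLeadInactiveLane = 32 - 16 = 16, so
// activeMask = 0xFFFFFFFF >> 16 = 0x0000FFFF (lanes 0-15 participate), and
// maskAndClamp is 15 for the xor/down/idx modes or 16 for the up mode; both
// values feed straight into the nvvm.shfl.sync op created above.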

struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
  using ConvertOpToLLVMPattern<gpu::LaneIdOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::LaneIdOp op, gpu::LaneIdOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    MLIRContext *context = rewriter.getContext();
    LLVM::ConstantRangeAttr bounds = nullptr;
    if (std::optional<APInt> upperBound = op.getUpperBound())
      bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
          /*bitWidth=*/32, /*lower=*/0, upperBound->getZExtValue());
    else
      bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
          /*bitWidth=*/32, /*lower=*/0, /*upper=*/kWarpSize);
    Value newOp =
        NVVM::LaneIdOp::create(rewriter, loc, rewriter.getI32Type(), bounds);
    // Truncate or extend the result depending on the index bitwidth specified
    // by the LLVMTypeConverter options.
    const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
    if (indexBitwidth > 32) {
      newOp = LLVM::SExtOp::create(
          rewriter, loc, IntegerType::get(context, indexBitwidth), newOp);
    } else if (indexBitwidth < 32) {
      newOp = LLVM::TruncOp::create(
          rewriter, loc, IntegerType::get(context, indexBitwidth), newOp);
    }
    rewriter.replaceOp(op, {newOp});
    return success();
  }
};

/// Lower cf.assert into a conditional call to the device-side __assertfail.
struct AssertOpToAssertfailLowering
    : public ConvertOpToLLVMPattern<cf::AssertOp> {
  using ConvertOpToLLVMPattern<cf::AssertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(cf::AssertOp assertOp, cf::AssertOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MLIRContext *ctx = rewriter.getContext();
    Location loc = assertOp.getLoc();
    Type i8Type = typeConverter->convertType(rewriter.getIntegerType(8));
    Type i32Type = typeConverter->convertType(rewriter.getIntegerType(32));
    Type i64Type = typeConverter->convertType(rewriter.getIntegerType(64));
    Type ptrType = LLVM::LLVMPointerType::get(ctx);
    Type voidType = LLVM::LLVMVoidType::get(ctx);

    // Find or declare __assertfail in the surrounding GPU module.
    auto moduleOp = assertOp->getParentOfType<gpu::GPUModuleOp>();
    auto assertfailType = LLVM::LLVMFunctionType::get(
        voidType, {ptrType, ptrType, i32Type, ptrType, i64Type});
    LLVM::LLVMFuncOp assertfailDecl = getOrDefineFunction(
        moduleOp, loc, rewriter, "__assertfail", assertfailType);
    assertfailDecl.setPassthroughAttr(
        ArrayAttr::get(ctx, StringAttr::get(ctx, "noreturn")));

    // Split the surrounding block so the assert body only executes when the
    // condition is false, then branch back to the continuation block.
    Block *beforeBlock = assertOp->getBlock();
    Block *assertBlock =
        rewriter.splitBlock(beforeBlock, assertOp->getIterator());
    Block *afterBlock =
        rewriter.splitBlock(assertBlock, ++assertOp->getIterator());
    rewriter.setInsertionPointToEnd(beforeBlock);
    cf::CondBranchOp::create(rewriter, loc, adaptor.getArg(), afterBlock,
                             assertBlock);
    rewriter.setInsertionPointToEnd(assertBlock);
    cf::BranchOp::create(rewriter, loc, afterBlock);

    // Continue lowering the assert op itself.
    rewriter.setInsertionPoint(assertOp);

    // Populate file name, line number and function name from the location of
    // the AssertOp.
    StringRef fileName = "(unknown)";
    StringRef funcName = "(unknown)";
    int32_t fileLine = 0;
    while (auto callSiteLoc = dyn_cast<CallSiteLoc>(loc))
      loc = callSiteLoc.getCallee();
    if (auto fileLineColLoc = dyn_cast<FileLineColRange>(loc)) {
      fileName = fileLineColLoc.getFilename().strref();
      fileLine = fileLineColLoc.getStartLine();
    } else if (auto nameLoc = dyn_cast<NameLoc>(loc)) {
      funcName = nameLoc.getName().strref();
      if (auto fileLineColLoc =
              dyn_cast<FileLineColRange>(nameLoc.getChildLoc())) {
        fileName = fileLineColLoc.getFilename().strref();
        fileLine = fileLineColLoc.getStartLine();
      }
    }

    // Create global string constants and take the address of their first
    // character.
    auto getGlobal = [&](LLVM::GlobalOp global) {
      Value globalPtr = LLVM::AddressOfOp::create(
          rewriter, loc,
          LLVM::LLVMPointerType::get(ctx, global.getAddrSpace()),
          global.getSymNameAttr());
      return LLVM::GEPOp::create(rewriter, loc, ptrType,
                                 global.getGlobalType(), globalPtr,
                                 ArrayRef<LLVM::GEPArg>{0, 0});
    };
    Value assertMessage = getGlobal(getOrCreateStringConstant(
        rewriter, loc, moduleOp, i8Type, "assert_message_", assertOp.getMsg()));
    Value assertFile = getGlobal(getOrCreateStringConstant(
        rewriter, loc, moduleOp, i8Type, "assert_file_", fileName));
    Value assertFunc = getGlobal(getOrCreateStringConstant(
        rewriter, loc, moduleOp, i8Type, "assert_func_", funcName));
    Value assertLine =
        LLVM::ConstantOp::create(rewriter, loc, i32Type, fileLine);
    Value c1 = LLVM::ConstantOp::create(rewriter, loc, i64Type, 1);

    // Replace the assert with a call to __assertfail.
    SmallVector<Value> arguments{assertMessage, assertFile, assertLine,
                                 assertFunc, c1};
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(assertOp, assertfailDecl,
                                              arguments);
    return success();
  }
};
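
// For reference, the declaration built above matches the CUDA device runtime
// entry point (signature assumed from the {ptr, ptr, i32, ptr, i64} function
// type, not taken from this file):
//   __device__ void __assertfail(const char *message, const char *file,
//                                unsigned line, const char *function,
//                                size_t charSize);
// with charSize == 1, which is the i64 constant `c1` passed as the last
// argument.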

#include "GPUToNVVM.cpp.inc"

struct LowerGpuOpsToNVVMOpsPass final
    : public impl::ConvertGpuOpsToNVVMOpsBase<LowerGpuOpsToNVVMOpsPass> {
  using Base::Base;
  void getDependentDialects(DialectRegistry &registry) const override {
    Base::getDependentDialects(registry);
    registerConvertToLLVMDependentDialectLoading(registry);
  }

  void runOnOperation() override {
    gpu::GPUModuleOp m = getOperation();

    // Request C wrapper emission.
    for (auto func : m.getOps<func::FuncOp>()) {
      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                    UnitAttr::get(&getContext()));
    }

    // Customize the bitwidth used for the device side index computations.
    LowerToLLVMOptions options(
        m.getContext(),
        DataLayout(cast<DataLayoutOpInterface>(m.getOperation())));
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);
    options.useBarePtrCallConv = useBarePtrCallConv;

    // Apply in-dialect lowering before running the dialect conversion.
    {
      RewritePatternSet patterns(m.getContext());
      populateGpuRewritePatterns(patterns);
      vector::populateVectorFromElementsUnrollPatterns(patterns);
      if (failed(applyPatternsGreedily(m, std::move(patterns))))
        return signalPassFailure();
    }

    LLVMTypeConverter converter(m.getContext(), options);
    configureGpuToNVVMTypeConverter(converter);
    RewritePatternSet llvmPatterns(m.getContext());
    LLVMConversionTarget target(getContext());
    populateGpuToNVVMConversionPatterns(converter, llvmPatterns,
                                        /*benefit=*/10);

    // Gather patterns from every loaded dialect that implements
    // ConvertToLLVMPatternInterface, unless restricted via `allowedDialects`.
    llvm::SmallDenseSet<StringRef> allowedDialectsSet(allowedDialects.begin(),
                                                      allowedDialects.end());
    for (Dialect *dialect : getContext().getLoadedDialects()) {
      // Skip math patterns as NVVM needs custom math lowering.
      if (isa<math::MathDialect>(dialect))
        continue;
      // An empty `allowedDialectsSet` means all dialects are allowed.
      bool allowed = allowedDialectsSet.contains(dialect->getNamespace());
      if (!allowedDialectsSet.empty() && !allowed)
        continue;
      auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
      if (!iface) {
        if (allowed) {
          m.emitError()
              << "dialect does not implement ConvertToLLVMPatternInterface: "
              << dialect->getNamespace();
          return signalPassFailure();
        }
        continue;
      }
      iface->populateConvertToLLVMConversionPatterns(target, converter,
                                                     llvmPatterns);
    }

    configureGpuToNVVMConversionLegality(target);
    if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
      signalPassFailure();
  }
};

void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
  target.addIllegalOp<func::FuncOp>();
  target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
  target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
  target.addIllegalDialect<gpu::GPUDialect>();
  // These LLVM math ops must be lowered further to libdevice calls.
  target.addIllegalOp<LLVM::CopySignOp, LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op,
                      LLVM::FAbsOp, LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp,
                      LLVM::LogOp, LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp,
                      LLVM::RoundEvenOp, LLVM::RoundOp, LLVM::SinOp,
                      LLVM::SincosOp, LLVM::SqrtOp>();
  target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp>();
}
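
// Note: the LLVM math ops above are declared illegal because the NVPTX
// backend cannot lower most of those intrinsics; keeping them illegal steers
// the conversion toward the libdevice call patterns (e.g. __nv_sinf /
// __nv_sin) registered for the corresponding math/arith ops below.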

void mlir::configureGpuToNVVMTypeConverter(LLVMTypeConverter &converter) {
  // Map gpu.address_space attributes onto NVVM address spaces.
  populateGpuMemorySpaceAttributeConversions(
      converter, [](gpu::AddressSpace space) -> unsigned {
        switch (space) {
        case gpu::AddressSpace::Global:
          return static_cast<unsigned>(NVVM::NVVMMemorySpace::Global);
        case gpu::AddressSpace::Workgroup:
          return static_cast<unsigned>(NVVM::NVVMMemorySpace::Shared);
        case gpu::AddressSpace::Private:
          return 0;
        }
        llvm_unreachable("unknown address space enum value");
        return static_cast<unsigned>(NVVM::NVVMMemorySpace::Generic);
      });
  // ... (additional type conversions elided)
}
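
// Illustrative effect of the mapping above: a memref such as
//   memref<16xf32, #gpu.address_space<workgroup>>
// is converted so its pointers live in NVVM address space 3 (shared memory),
// while private allocations fall back to the default address space 0.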

struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
  using ConvertOpToLLVMPattern<math::SincosOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(math::SincosOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value input = adaptor.getOperand();
    Type inputType = input.getType();
    // f16/bf16 inputs are extended to f32 before calling libdevice.
    auto convertedInput = maybeExt(input, rewriter);
    auto computeType = convertedInput.getType();

    StringRef sincosFunc;
    if (isa<Float32Type>(computeType)) {
      const arith::FastMathFlags flag = op.getFastmath();
      const bool useApprox =
          mlir::arith::bitEnumContainsAny(flag, arith::FastMathFlags::afn);
      sincosFunc = useApprox ? "__nv_fast_sincosf" : "__nv_sincosf";
    } else if (isa<Float64Type>(computeType)) {
      sincosFunc = "__nv_sincos";
    } else {
      return rewriter.notifyMatchFailure(
          op, "unsupported operand type for sincos");
    }

    auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());

    Value sinPtr, cosPtr;
    {
      OpBuilder::InsertionGuard guard(rewriter);
      auto *scope =
          op->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
      assert(scope && "Expected op to be inside automatic allocation scope");
      rewriter.setInsertionPointToStart(&scope->getRegion(0).front());
      auto one = rewriter.create<LLVM::ConstantOp>(
          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));
      sinPtr =
          rewriter.create<LLVM::AllocaOp>(loc, ptrType, computeType, one, 0);
      cosPtr =
          rewriter.create<LLVM::AllocaOp>(loc, ptrType, computeType, one, 0);
    }

    createSincosCall(rewriter, loc, sincosFunc, convertedInput, sinPtr, cosPtr,
                     op);

    auto sinResult = rewriter.create<LLVM::LoadOp>(loc, computeType, sinPtr);
    auto cosResult = rewriter.create<LLVM::LoadOp>(loc, computeType, cosPtr);

    rewriter.replaceOp(op, {maybeTrunc(sinResult, inputType, rewriter),
                            maybeTrunc(cosResult, inputType, rewriter)});
    return success();
  }
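
  // Sketch of the IR this produces for an f32 input (illustrative, not exact
  // MLIR syntax): two stack slots are allocated in the enclosing allocation
  // scope, `__nv_sincosf(%x, %sin_ptr, %cos_ptr)` is called, and the two
  // results of math.sincos are replaced by loads from those slots.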

private:
  static Value maybeExt(Value operand, PatternRewriter &rewriter) {
    if (isa<Float16Type, BFloat16Type>(operand.getType()))
      return rewriter.create<LLVM::FPExtOp>(
          operand.getLoc(), Float32Type::get(rewriter.getContext()), operand);
    return operand;
  }

  static Value maybeTrunc(Value operand, Type type,
                          PatternRewriter &rewriter) {
    if (operand.getType() != type)
      return rewriter.create<LLVM::FPTruncOp>(operand.getLoc(), type, operand);
    return operand;
  }

  static void createSincosCall(ConversionPatternRewriter &rewriter,
                               Location loc, StringRef funcName, Value input,
                               Value sinPtr, Value cosPtr, Operation *op) {
    auto voidType = LLVM::LLVMVoidType::get(rewriter.getContext());
    auto ptrType = sinPtr.getType();
    SmallVector<Type> operandTypes = {input.getType(), ptrType, ptrType};
    auto funcType = LLVM::LLVMFunctionType::get(voidType, operandTypes);

    auto funcAttr = StringAttr::get(op->getContext(), funcName);
    auto funcOp =
        SymbolTable::lookupNearestSymbolFrom<LLVM::LLVMFuncOp>(op, funcAttr);
    if (!funcOp) {
      auto parentFunc = op->getParentOfType<FunctionOpInterface>();
      assert(parentFunc && "expected there to be a parent function");
      OpBuilder b(parentFunc);
      auto globalloc = loc->findInstanceOfOrUnknown<FileLineColLoc>();
      funcOp = LLVM::LLVMFuncOp::create(b, globalloc, funcName, funcType);
    }

    SmallVector<Value> callOperands = {input, sinPtr, cosPtr};
    rewriter.create<LLVM::CallOp>(loc, funcOp, callOperands);
  }
};

template <typename OpTy>
static void populateOpPatterns(const LLVMTypeConverter &converter,
                               RewritePatternSet &patterns,
                               PatternBenefit benefit, StringRef f32Func,
                               StringRef f64Func, StringRef f32ApproxFunc = "",
                               StringRef f16Func = "") {
  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit);
  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
                                           f32ApproxFunc, f16Func,
                                           /*i32Func=*/"", benefit);
}

template <typename OpTy>
static void populateIntOpPatterns(const LLVMTypeConverter &converter,
                                  RewritePatternSet &patterns,
                                  PatternBenefit benefit, StringRef i32Func) {
  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit);
  patterns.add<OpToFuncCallLowering<OpTy>>(converter, "", "", "", "", i32Func,
                                           benefit);
}

template <typename OpTy>
static void populateFloatIntOpPatterns(const LLVMTypeConverter &converter,
                                       RewritePatternSet &patterns,
                                       PatternBenefit benefit,
                                       StringRef f32Func, StringRef f64Func) {
  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit);
  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func, "", "",
                                           "", benefit);
}
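
// Usage sketch for the helpers above (hypothetical call, mirroring the entries
// in populateLibDeviceConversionPatterns below): for instance
//   populateOpPatterns<math::ExpOp>(converter, patterns, benefit,
//                                   "__nv_expf", "__nv_exp", "__nv_fast_expf");
// registers one pattern that unrolls vector operands element-wise and one that
// rewrites the op into a call to __nv_expf (f32), __nv_exp (f64), or the
// approximate __nv_fast_expf when the afn fastmath flag is set.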

void mlir::populateGpuSubgroupReduceOpLoweringPattern(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
    PatternBenefit benefit) {
  patterns.add<GPUSubgroupReduceOpLowering>(converter, benefit);
}

void mlir::populateLibDeviceConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
    PatternBenefit benefit) {
  populateOpPatterns<arith::RemFOp>(converter, patterns, benefit, "__nv_fmodf",
                                    "__nv_fmod");
  populateOpPatterns<arith::MaxNumFOp>(converter, patterns, benefit,
                                       "__nv_fmaxf", "__nv_fmax");
  populateOpPatterns<arith::MinNumFOp>(converter, patterns, benefit,
                                       "__nv_fminf", "__nv_fmin");

  populateIntOpPatterns<math::AbsIOp>(converter, patterns, benefit, "__nv_abs");
  populateOpPatterns<math::AbsFOp>(converter, patterns, benefit, "__nv_fabsf",
                                   "__nv_fabs");
  populateOpPatterns<math::AcosOp>(converter, patterns, benefit, "__nv_acosf",
                                   "__nv_acos");
  populateOpPatterns<math::AcoshOp>(converter, patterns, benefit, "__nv_acoshf",
                                    "__nv_acosh");
  populateOpPatterns<math::AsinOp>(converter, patterns, benefit, "__nv_asinf",
                                   "__nv_asin");
  populateOpPatterns<math::AsinhOp>(converter, patterns, benefit, "__nv_asinhf",
                                    "__nv_asinh");
  populateOpPatterns<math::AtanOp>(converter, patterns, benefit, "__nv_atanf",
                                   "__nv_atan");
  populateOpPatterns<math::Atan2Op>(converter, patterns, benefit, "__nv_atan2f",
                                    "__nv_atan2");
  populateOpPatterns<math::AtanhOp>(converter, patterns, benefit, "__nv_atanhf",
                                    "__nv_atanh");
  populateOpPatterns<math::CbrtOp>(converter, patterns, benefit, "__nv_cbrtf",
                                   "__nv_cbrt");
  populateOpPatterns<math::CeilOp>(converter, patterns, benefit, "__nv_ceilf",
                                   "__nv_ceil");
  populateOpPatterns<math::CopySignOp>(converter, patterns, benefit,
                                       "__nv_copysignf", "__nv_copysign");
  populateOpPatterns<math::CosOp>(converter, patterns, benefit, "__nv_cosf",
                                  "__nv_cos", "__nv_fast_cosf");
  populateOpPatterns<math::CoshOp>(converter, patterns, benefit, "__nv_coshf",
                                   "__nv_cosh");
  populateOpPatterns<math::ErfOp>(converter, patterns, benefit, "__nv_erff",
                                  "__nv_erf");
  populateOpPatterns<math::ErfcOp>(converter, patterns, benefit, "__nv_erfcf",
                                   "__nv_erfc");
  populateOpPatterns<math::ExpOp>(converter, patterns, benefit, "__nv_expf",
                                  "__nv_exp", "__nv_fast_expf");
  populateOpPatterns<math::Exp2Op>(converter, patterns, benefit, "__nv_exp2f",
                                   "__nv_exp2");
  populateOpPatterns<math::ExpM1Op>(converter, patterns, benefit, "__nv_expm1f",
                                    "__nv_expm1");
  populateOpPatterns<math::FloorOp>(converter, patterns, benefit, "__nv_floorf",
                                    "__nv_floor");
  populateOpPatterns<math::FmaOp>(converter, patterns, benefit, "__nv_fmaf",
                                  "__nv_fma");
  populateOpPatterns<math::IsFiniteOp>(converter, patterns, benefit,
                                       "__nv_finitef", "__nv_isfinited");
  populateOpPatterns<math::IsInfOp>(converter, patterns, benefit, "__nv_isinff",
                                    "__nv_isinfd");
  populateOpPatterns<math::IsNaNOp>(converter, patterns, benefit, "__nv_isnanf",
                                    "__nv_isnand");
  populateOpPatterns<math::LogOp>(converter, patterns, benefit, "__nv_logf",
                                  "__nv_log", "__nv_fast_logf");
  populateOpPatterns<math::Log10Op>(converter, patterns, benefit, "__nv_log10f",
                                    "__nv_log10", "__nv_fast_log10f");
  populateOpPatterns<math::Log1pOp>(converter, patterns, benefit, "__nv_log1pf",
                                    "__nv_log1p");
  populateOpPatterns<math::Log2Op>(converter, patterns, benefit, "__nv_log2f",
                                   "__nv_log2", "__nv_fast_log2f");
  populateOpPatterns<math::PowFOp>(converter, patterns, benefit, "__nv_powf",
                                   "__nv_pow", "__nv_fast_powf");
  populateFloatIntOpPatterns<math::FPowIOp>(converter, patterns, benefit,
                                            "__nv_powif", "__nv_powi");
  populateOpPatterns<math::RoundOp>(converter, patterns, benefit, "__nv_roundf",
                                    "__nv_round");
  populateOpPatterns<math::RoundEvenOp>(converter, patterns, benefit,
                                        "__nv_rintf", "__nv_rint");
  populateOpPatterns<math::RsqrtOp>(converter, patterns, benefit, "__nv_rsqrtf",
                                    "__nv_rsqrt");
  populateOpPatterns<math::SinOp>(converter, patterns, benefit, "__nv_sinf",
                                  "__nv_sin", "__nv_fast_sinf");
  populateOpPatterns<math::SinhOp>(converter, patterns, benefit, "__nv_sinhf",
                                   "__nv_sinh");
  populateOpPatterns<math::SqrtOp>(converter, patterns, benefit, "__nv_sqrtf",
                                   "__nv_sqrt");
  populateOpPatterns<math::TanOp>(converter, patterns, benefit, "__nv_tanf",
                                  "__nv_tan", "__nv_fast_tanf");
  populateOpPatterns<math::TanhOp>(converter, patterns, benefit, "__nv_tanhf",
                                   "__nv_tanh");

  patterns.add<SincosOpLowering>(converter, benefit);
}
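
// End-to-end effect of the patterns above (illustrative): after conversion, a
// device-side `math.exp %x : f32` becomes `llvm.call @__nv_expf(%x)`, with the
// callee declared on demand, and vector operands are first unrolled to scalars
// by ScalarizeVectorOpLowering.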

void mlir::populateGpuToNVVMConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
    PatternBenefit benefit) {
  using gpu::index_lowering::IndexKind;
  using gpu::index_lowering::IntrType;

  patterns.add<gpu::index_lowering::OpLowering<
      gpu::ThreadIdOp, NVVM::ThreadIdXOp, NVVM::ThreadIdYOp,
      NVVM::ThreadIdZOp>>(converter, IndexKind::Block, IntrType::Id, benefit);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::BlockDimOp, NVVM::BlockDimXOp, NVVM::BlockDimYOp,
      NVVM::BlockDimZOp>>(converter, IndexKind::Block, IntrType::Dim, benefit);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::ClusterIdOp, NVVM::ClusterIdXOp, NVVM::ClusterIdYOp,
      NVVM::ClusterIdZOp>>(converter, IndexKind::Other, IntrType::Id, benefit);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::ClusterDimOp, NVVM::ClusterDimXOp, NVVM::ClusterDimYOp,
      NVVM::ClusterDimZOp>>(converter, IndexKind::Other, IntrType::Dim,
                            benefit);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::ClusterBlockIdOp, NVVM::BlockInClusterIdXOp,
      NVVM::BlockInClusterIdYOp, NVVM::BlockInClusterIdZOp>>(
      converter, IndexKind::Other, IntrType::Id, benefit);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::ClusterDimBlocksOp, NVVM::ClusterDimBlocksXOp,
      NVVM::ClusterDimBlocksYOp, NVVM::ClusterDimBlocksZOp>>(
      converter, IndexKind::Other, IntrType::Dim, benefit);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::BlockIdOp, NVVM::BlockIdXOp, NVVM::BlockIdYOp, NVVM::BlockIdZOp>>(
      converter, IndexKind::Grid, IntrType::Id, benefit);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::GridDimOp, NVVM::GridDimXOp, NVVM::GridDimYOp, NVVM::GridDimZOp>>(
      converter, IndexKind::Grid, IntrType::Dim, benefit);

  // ... (further patterns elided)

  patterns.add<GPUFuncOpLowering>(
      converter,
      GPUFuncOpLoweringOptions{
          /*allocaAddrSpace=*/0,
          /*workgroupAddrSpace=*/
          static_cast<unsigned>(NVVM::NVVMMemorySpace::Shared),
          StringAttr::get(&converter.getContext(),
                          NVVM::NVVMDialect::getKernelFuncAttrName()),
          StringAttr::get(&converter.getContext(),
                          NVVM::NVVMDialect::getMaxntidAttrName())},
      benefit);

  // ... (remaining patterns elided)
}
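
// Illustrative result of the index lowerings registered above: an op such as
// `%tid = gpu.thread_id x` becomes `%tid = nvvm.read.ptx.sreg.tid.x : i32`,
// truncated or extended to the configured index bitwidth, and kernels gain the
// NVVM kernel / maxntid function attributes via GPUFuncOpLowering.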

struct NVVMTargetConvertToLLVMAttrInterface
    : public ConvertToLLVMAttrInterface::ExternalModel<
          NVVMTargetConvertToLLVMAttrInterface, NVVM::NVVMTargetAttr> {
  void populateConvertToLLVMConversionPatterns(
      Attribute attr, ConversionTarget &target,
      LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) const;
};

void NVVMTargetConvertToLLVMAttrInterface::
    populateConvertToLLVMConversionPatterns(Attribute attr,
                                            ConversionTarget &target,
                                            LLVMTypeConverter &typeConverter,
                                            RewritePatternSet &patterns) const {
  configureGpuToNVVMConversionLegality(target);
  configureGpuToNVVMTypeConverter(typeConverter);
  populateGpuToNVVMConversionPatterns(typeConverter, patterns);
}

void mlir::registerConvertGpuToNVVMInterface(DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, NVVM::NVVMDialect *dialect) {
    NVVMTargetAttr::attachInterface<NVVMTargetConvertToLLVMAttrInterface>(*ctx);
  });
}
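
// Usage sketch (hypothetical client code): a tool that wants the generic
// convert-to-llvm machinery to understand #nvvm.target attributes registers
// this interface up front:
//   DialectRegistry registry;
//   registerConvertGpuToNVVMInterface(registry);
//   context.appendDialectRegistry(registry);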