/// Pattern to convert a multi-dimensional GPU launch-configuration op (the
/// concrete op is the SourceOp template parameter) into a read of the given
/// SPIR-V `builtin` variable.
/// NOTE(review): this excerpt is truncated — access specifiers, the
/// `LogicalResult` return type of matchAndRewrite, and the closing braces of
/// the class are missing from this view.
32template <
typename SourceOp, spirv::BuiltIn builtin>
33class LaunchConfigConversion :
public OpConversionPattern<SourceOp> {
35 using OpConversionPattern<SourceOp>::OpConversionPattern;
38 matchAndRewrite(SourceOp op,
typename SourceOp::Adaptor adaptor,
39 ConversionPatternRewriter &rewriter)
const override;
/// Pattern to convert a scalar (single-dimension) GPU launch-configuration op
/// into a read of the given SPIR-V `builtin` variable.
/// NOTE(review): truncated excerpt — access specifiers, the return type of
/// matchAndRewrite, and the closing braces are missing from this view.
44template <
typename SourceOp, spirv::BuiltIn builtin>
45class SingleDimLaunchConfigConversion :
public OpConversionPattern<SourceOp> {
47 using OpConversionPattern<SourceOp>::OpConversionPattern;
50 matchAndRewrite(SourceOp op,
typename SourceOp::Adaptor adaptor,
51 ConversionPatternRewriter &rewriter)
const override;
/// Pattern converting gpu.block_dim into a constant when the workgroup size
/// is statically known. The explicit pattern benefit of 10 makes this pattern
/// preferred over the generic builtin-variable lookup for gpu.block_dim.
/// NOTE(review): truncated excerpt — the constructor's parameter list and the
/// class's closing braces are missing from this view.
58class WorkGroupSizeConversion :
public OpConversionPattern<gpu::BlockDimOp> {
62 : OpConversionPattern(typeConverter, context, 10) {}
65 matchAndRewrite(gpu::BlockDimOp op, OpAdaptor adaptor,
66 ConversionPatternRewriter &rewriter)
const override;
/// Pattern lowering gpu.func to spirv.func (see the matchAndRewrite
/// definition later in this file). NOTE(review): truncated excerpt — member
/// declarations and closing braces are missing from this view.
70class GPUFuncOpConversion final :
public OpConversionPattern<gpu::GPUFuncOp> {
75 matchAndRewrite(gpu::GPUFuncOp funcOp, OpAdaptor adaptor,
76 ConversionPatternRewriter &rewriter)
const override;
/// Pattern lowering gpu.module to spirv.module. NOTE(review): truncated
/// excerpt — member declarations and closing braces are missing from this
/// view.
83class GPUModuleConversion final :
public OpConversionPattern<gpu::GPUModuleOp> {
88 matchAndRewrite(gpu::GPUModuleOp moduleOp, OpAdaptor adaptor,
89 ConversionPatternRewriter &rewriter)
const override;
/// Pattern lowering gpu.return (inside kernels) to spirv.Return.
/// NOTE(review): truncated excerpt — closing braces missing from this view.
94class GPUReturnOpConversion final :
public OpConversionPattern<gpu::ReturnOp> {
99 matchAndRewrite(gpu::ReturnOp returnOp, OpAdaptor adaptor,
100 ConversionPatternRewriter &rewriter)
const override;
/// Pattern lowering gpu.barrier to spirv.ControlBarrier.
/// NOTE(review): truncated excerpt — closing braces missing from this view.
104class GPUBarrierConversion final :
public OpConversionPattern<gpu::BarrierOp> {
109 matchAndRewrite(gpu::BarrierOp barrierOp, OpAdaptor adaptor,
110 ConversionPatternRewriter &rewriter)
const override;
/// Pattern lowering gpu.shuffle to the SPIR-V GroupNonUniformShuffle* ops.
/// NOTE(review): truncated excerpt — closing braces missing from this view.
114class GPUShuffleConversion final :
public OpConversionPattern<gpu::ShuffleOp> {
119 matchAndRewrite(gpu::ShuffleOp shuffleOp, OpAdaptor adaptor,
120 ConversionPatternRewriter &rewriter)
const override;
/// Pattern lowering gpu.rotate to spirv.GroupNonUniformRotateKHR.
/// NOTE(review): truncated excerpt — closing braces missing from this view.
124class GPURotateConversion final :
public OpConversionPattern<gpu::RotateOp> {
129 matchAndRewrite(gpu::RotateOp rotateOp, OpAdaptor adaptor,
130 ConversionPatternRewriter &rewriter)
const override;
/// Pattern lowering gpu.subgroup_broadcast to the SPIR-V group-non-uniform
/// broadcast ops. NOTE(review): truncated excerpt — closing braces missing
/// from this view.
135class GPUSubgroupBroadcastConversion final
136 :
public OpConversionPattern<gpu::SubgroupBroadcastOp> {
141 matchAndRewrite(gpu::SubgroupBroadcastOp op, OpAdaptor adaptor,
142 ConversionPatternRewriter &rewriter)
const override;
/// Pattern lowering gpu.ballot to spirv.GroupNonUniformBallot.
/// NOTE(review): truncated excerpt — closing braces missing from this view.
145class GPUBallotConversion final :
public OpConversionPattern<gpu::BallotOp> {
150 matchAndRewrite(gpu::BallotOp ballotOp, OpAdaptor adaptor,
151 ConversionPatternRewriter &rewriter)
const override;
/// Pattern lowering gpu.printf to spirv.CL.printf (OpenCL flavor).
/// NOTE(review): truncated excerpt — closing braces missing from this view.
154class GPUPrintfConversion final :
public OpConversionPattern<gpu::PrintfOp> {
159 matchAndRewrite(gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor,
160 ConversionPatternRewriter &rewriter)
const override;
// Replaces the launch-config op with one component of the corresponding
// SPIR-V builtin variable (a vector), extracted via CompositeExtract using
// the op's dimension attribute as index.
// NOTE(review): truncated excerpt — the lines materializing `forShader` and
// `vector` (the builtin-variable read), and the final `return success();`,
// are missing from this view.
169template <
typename SourceOp, spirv::BuiltIn builtin>
170LogicalResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
171 SourceOp op,
typename SourceOp::Adaptor adaptor,
172 ConversionPatternRewriter &rewriter)
const {
173 auto *typeConverter = this->
template getTypeConverter<SPIRVTypeConverter>();
174 Type indexType = typeConverter->getIndexType();
// Shader targets read the builtin as i32; otherwise the converter's index
// type is used directly.
188 typeConverter->getTargetEnv().allows(spirv::Capability::Shader);
189 Type builtinType = forShader ? rewriter.getIntegerType(32) : indexType;
193 Value dim = spirv::CompositeExtractOp::create(
194 rewriter, op.getLoc(), builtinType,
vector,
195 rewriter.getI32ArrayAttr({static_cast<int32_t>(op.getDimension())}));
// Widen the extracted i32 to the index type when they differ.
196 if (forShader && builtinType != indexType)
197 dim = spirv::UConvertOp::create(rewriter, op.getLoc(), indexType, dim);
198 rewriter.replaceOp(op, dim);
// Replaces a scalar launch-config op with a read of the scalar SPIR-V builtin
// variable, zero-extending from i32 to the index type when needed.
// NOTE(review): truncated excerpt — the line producing `builtinValue` (the
// builtin-variable read) and the final `return success();` are missing from
// this view.
202template <
typename SourceOp, spirv::BuiltIn builtin>
204SingleDimLaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
205 SourceOp op,
typename SourceOp::Adaptor adaptor,
206 ConversionPatternRewriter &rewriter)
const {
207 auto *typeConverter = this->
template getTypeConverter<SPIRVTypeConverter>();
208 Type indexType = typeConverter->getIndexType();
209 Type i32Type = rewriter.getIntegerType(32);
221 if (i32Type != indexType)
222 builtinValue = spirv::UConvertOp::create(rewriter, op.getLoc(), indexType,
224 rewriter.replaceOp(op, builtinValue);
// Folds gpu.block_dim to a spirv.Constant when the local workgroup size is
// statically known (queried from the entry-point ABI — presumably via
// spirv::lookupLocalWorkGroupSize; the lookup line is missing from this
// view).
// NOTE(review): truncated excerpt — the `workGroupSizeAttr` lookup, the
// guards' early returns, and the final `return success();` are missing.
228LogicalResult WorkGroupSizeConversion::matchAndRewrite(
229 gpu::BlockDimOp op, OpAdaptor adaptor,
230 ConversionPatternRewriter &rewriter)
const {
232 if (!workGroupSizeAttr)
// Index the static workgroup-size array by the op's dimension attribute.
236 workGroupSizeAttr.
asArrayRef()[
static_cast<int32_t
>(op.getDimension())];
238 getTypeConverter()->convertType(op.getResult().getType());
241 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
242 op, convertedType, IntegerAttr::get(convertedType, val));
// Lowers a gpu.func kernel to a spirv.func entry function: converts the
// signature (no results allowed), clones attributes, inlines the body,
// converts region types, and attaches per-argument interface-var ABI
// attributes.
// NOTE(review): truncated excerpt — the function header (return type/first
// parameters), several early returns, and the trailing entry-point handling
// are missing from this view.
253 ConversionPatternRewriter &rewriter,
254 spirv::EntryPointABIAttr entryPointInfo,
256 auto fnType = funcOp.getFunctionType();
// Entry functions must not return values.
257 if (fnType.getNumResults()) {
258 funcOp.emitError(
"SPIR-V lowering only supports entry functions"
259 "with no return values right now");
// NOTE(review): the two adjacent string literals above concatenate without a
// space ("functionswith") — verify against the original message.
262 if (!argABIInfo.empty() && fnType.getNumInputs() != argABIInfo.size()) {
264 "lowering as entry functions requires ABI info for all arguments "
// Build the converted signature input-by-input.
271 TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
273 for (
const auto &argType :
274 enumerate(funcOp.getFunctionType().getInputs())) {
275 auto convertedType = typeConverter.convertType(argType.value());
278 signatureConverter.addInputs(argType.index(), convertedType);
// Create the spirv.func with converted inputs and no results.
281 auto newFuncOp = spirv::FuncOp::create(
282 rewriter, funcOp.getLoc(), funcOp.getName(),
283 rewriter.getFunctionType(signatureConverter.getConvertedTypes(), {}));
// Copy attributes, skipping the function-type attribute (and presumably the
// symbol name — that condition's second operand is missing from this view).
284 for (
const auto &namedAttr : funcOp->getAttrs()) {
285 if (namedAttr.getName() == funcOp.getFunctionTypeAttrName() ||
288 newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
291 rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
293 if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
294 &signatureConverter)))
296 rewriter.eraseOp(funcOp);
// Attach the interface-var ABI attribute to each argument.
300 for (
auto argIndex : llvm::seq<unsigned>(0, argABIInfo.size())) {
301 newFuncOp.setArgAttr(argIndex, argABIAttrName, argABIInfo[argIndex]);
// Fragment of getDefaultABIAttrs: for each kernel argument, pick a storage
// class — scalar (int/index/float) arguments get StorageBuffer.
// NOTE(review): truncated excerpt — surrounding declaration, the attribute
// construction, and the return are missing from this view.
317 for (
auto argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
323 std::optional<spirv::StorageClass> sc;
324 if (funcOp.getArgument(argIndex).getType().isIntOrIndexOrFloat())
325 sc = spirv::StorageClass::StorageBuffer;
// Lowers a gpu.func marked as kernel: collects per-argument
// spirv.interface_var_abi attributes (failing the match if one is missing
// when the target requires them), requires a spirv.entry_point_abi
// attribute, delegates to lowerAsEntryFunction, then drops the gpu.kernel
// marker from the new function.
// NOTE(review): truncated excerpt — early returns, the entryPointAttr
// lookup, and the final `return success();` are missing from this view.
332LogicalResult GPUFuncOpConversion::matchAndRewrite(
333 gpu::GPUFuncOp funcOp, OpAdaptor adaptor,
334 ConversionPatternRewriter &rewriter)
const {
// Only kernel functions are handled by this pattern.
335 if (!gpu::GPUDialect::isKernel(funcOp))
338 auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
339 SmallVector<spirv::InterfaceVarABIAttr, 4> argABI;
343 for (
auto argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
345 auto abiAttr = funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
349 "match failure: missing 'spirv.interface_var_abi' attribute at "
354 argABI.push_back(abiAttr);
359 if (!entryPointAttr) {
361 "match failure: missing 'spirv.entry_point_abi' attribute");
365 funcOp, *getTypeConverter(), rewriter, entryPointAttr, argABI);
// The converted function is no longer a gpu kernel.
368 newFuncOp->removeAttr(
369 rewriter.getStringAttr(gpu::GPUDialect::getKernelFuncAttrName()));
// Lowers gpu.module to spirv.module: deduces addressing/memory models from
// the target environment, creates the spirv.module with a derived name,
// moves the body over, and forwards any spirv.target_env attribute found on
// the gpu.module (directly or in its `targets` array).
// NOTE(review): truncated excerpt — the addressingModel/memoryModel
// deduction lines, attribute forwarding bodies, and the final
// `return success();` are missing from this view.
377LogicalResult GPUModuleConversion::matchAndRewrite(
378 gpu::GPUModuleOp moduleOp, OpAdaptor adaptor,
379 ConversionPatternRewriter &rewriter)
const {
380 auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
381 const spirv::TargetEnv &targetEnv = typeConverter->getTargetEnv();
383 targetEnv, typeConverter->getOptions().use64bitIndex);
386 return moduleOp.emitRemark(
387 "cannot deduce memory model from 'spirv.target_env'");
// Name the new module with the kSPIRVModule prefix plus the gpu.module name.
390 std::string spvModuleName = (
kSPIRVModule + moduleOp.getName()).str();
391 auto spvModule = spirv::ModuleOp::create(
392 rewriter, moduleOp.getLoc(), addressingModel, *memoryModel, std::nullopt,
393 StringRef(spvModuleName));
// Move the gpu.module body into the spirv.module and drop the block the
// builder created for it.
396 Region &spvModuleRegion = spvModule.getRegion();
397 rewriter.inlineRegionBefore(moduleOp.getBodyRegion(), spvModuleRegion,
398 spvModuleRegion.
begin());
400 rewriter.eraseBlock(&spvModuleRegion.
back());
// Forward spirv.target_env if present, either as a direct attribute or
// inside the module's targets array.
406 if (
auto attr = moduleOp->getAttrOfType<spirv::TargetEnvAttr>(
409 if (
ArrayAttr targets = moduleOp.getTargetsAttr()) {
410 for (Attribute targetAttr : targets)
411 if (
auto spirvTargetEnvAttr =
412 dyn_cast<spirv::TargetEnvAttr>(targetAttr)) {
418 rewriter.eraseOp(moduleOp);
// Lowers gpu.return (with no operands — entry functions return nothing) to
// spirv.Return.
// NOTE(review): truncated excerpt — the early `return failure();` body and
// the final `return success();` are missing from this view.
426LogicalResult GPUReturnOpConversion::matchAndRewrite(
427 gpu::ReturnOp returnOp, OpAdaptor adaptor,
428 ConversionPatternRewriter &rewriter)
const {
429 if (!adaptor.getOperands().empty())
432 rewriter.replaceOpWithNewOp<spirv::ReturnOp>(returnOp);
// Lowers gpu.barrier to spirv.ControlBarrier with Workgroup execution and
// memory scopes and WorkgroupMemory|AcquireRelease semantics.
// NOTE(review): truncated excerpt — the `context` declaration, the trailing
// memorySemantics operand of the replace call, and the final
// `return success();` are missing from this view.
440LogicalResult GPUBarrierConversion::matchAndRewrite(
441 gpu::BarrierOp barrierOp, OpAdaptor adaptor,
442 ConversionPatternRewriter &rewriter)
const {
445 auto scope = spirv::ScopeAttr::get(context, spirv::Scope::Workgroup);
447 auto memorySemantics = spirv::MemorySemanticsAttr::get(
448 context, spirv::MemorySemantics::WorkgroupMemory |
449 spirv::MemorySemantics::AcquireRelease);
450 rewriter.replaceOpWithNewOp<spirv::ControlBarrierOp>(barrierOp, scope, scope,
// Lowers gpu.shuffle to a subgroup-scoped SPIR-V GroupNonUniformShuffle*
// op, selected by the shuffle mode. The shuffle width must equal the
// target's subgroup size. A validity bit is computed per mode: constant true
// for XOR/IDX; a lane-id range check for DOWN/UP.
// NOTE(review): truncated excerpt — the subgroupSize query, the widthAttr
// match, `result`/`validVal` declarations, `break`s, and the final
// `return success();` are missing from this view.
459LogicalResult GPUShuffleConversion::matchAndRewrite(
460 gpu::ShuffleOp shuffleOp, OpAdaptor adaptor,
461 ConversionPatternRewriter &rewriter)
const {
465 auto targetEnv = getTypeConverter<SPIRVTypeConverter>()->getTargetEnv();
466 unsigned subgroupSize =
468 IntegerAttr widthAttr;
470 widthAttr.getValue().getZExtValue() != subgroupSize)
471 return rewriter.notifyMatchFailure(
472 shuffleOp,
"shuffle width and target subgroup size mismatch");
474 assert(!adaptor.getOffset().getType().isSignedInteger() &&
475 "shuffle offset must be a signless/unsigned integer");
477 Location loc = shuffleOp.getLoc();
478 auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
482 switch (shuffleOp.getMode()) {
483 case gpu::ShuffleMode::XOR: {
484 result = spirv::GroupNonUniformShuffleXorOp::create(
485 rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset());
// XOR within the subgroup always hits a valid lane.
486 validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(),
487 shuffleOp.getLoc(), rewriter);
490 case gpu::ShuffleMode::IDX: {
491 result = spirv::GroupNonUniformShuffleOp::create(
492 rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset());
493 validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(),
494 shuffleOp.getLoc(), rewriter);
497 case gpu::ShuffleMode::DOWN: {
498 result = spirv::GroupNonUniformShuffleDownOp::create(
499 rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset());
// Valid only if laneId + offset stays below the shuffle width.
501 Value laneId = gpu::LaneIdOp::create(rewriter, loc, widthAttr);
503 arith::AddIOp::create(rewriter, loc, laneId, adaptor.getOffset());
504 validVal = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ult,
505 resultLaneId, adaptor.getWidth());
508 case gpu::ShuffleMode::UP: {
509 result = spirv::GroupNonUniformShuffleUpOp::create(
510 rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset());
// Valid only if laneId - offset does not go below zero (signed compare).
512 Value laneId = gpu::LaneIdOp::create(rewriter, loc, widthAttr);
514 arith::SubIOp::create(rewriter, loc, laneId, adaptor.getOffset());
515 auto i32Type = rewriter.getIntegerType(32);
516 validVal = arith::CmpIOp::create(
517 rewriter, loc, arith::CmpIPredicate::sge, resultLaneId,
518 arith::ConstantOp::create(rewriter, loc, i32Type,
519 rewriter.getIntegerAttr(i32Type, 0)));
524 rewriter.replaceOp(shuffleOp, {
result, validVal});
// Lowers gpu.rotate to spirv.GroupNonUniformRotateKHR at subgroup scope.
// Fails when the rotate width exceeds the target subgroup size; the validity
// bit is constant true when width == subgroupSize, otherwise a lane-id
// comparison.
// NOTE(review): truncated excerpt — the subgroupSize query, `validVal`
// declaration, the comparison's second operand, and the final
// `return success();` are missing from this view.
532LogicalResult GPURotateConversion::matchAndRewrite(
533 gpu::RotateOp rotateOp, OpAdaptor adaptor,
534 ConversionPatternRewriter &rewriter)
const {
535 const spirv::TargetEnv &targetEnv =
536 getTypeConverter<SPIRVTypeConverter>()->getTargetEnv();
537 unsigned subgroupSize =
539 unsigned width = rotateOp.getWidth();
540 if (width > subgroupSize)
541 return rewriter.notifyMatchFailure(
542 rotateOp,
"rotate width is larger than target subgroup size");
544 Location loc = rotateOp.getLoc();
545 auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
// Materialize offset and width as constants for the SPIR-V op.
547 arith::ConstantOp::create(rewriter, loc, adaptor.getOffsetAttr());
549 arith::ConstantOp::create(rewriter, loc, adaptor.getWidthAttr());
550 Value rotateResult = spirv::GroupNonUniformRotateKHROp::create(
551 rewriter, loc, scope, adaptor.getValue(), offsetVal, widthVal);
553 if (width == subgroupSize) {
554 validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(), loc, rewriter);
556 IntegerAttr widthAttr = adaptor.getWidthAttr();
557 Value laneId = gpu::LaneIdOp::create(rewriter, loc, widthAttr);
558 validVal = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ult,
562 rewriter.replaceOp(rotateOp, {rotateResult, validVal});
// Lowers gpu.subgroup_broadcast: `specific_lane` maps to
// spirv.GroupNonUniformBroadcast (taking the lane operand),
// `first_active_lane` to spirv.GroupNonUniformBroadcastFirst.
// NOTE(review): truncated excerpt — the `result` declaration, `break`s, and
// the final `return success();` are missing from this view.
570LogicalResult GPUSubgroupBroadcastConversion::matchAndRewrite(
571 gpu::SubgroupBroadcastOp op, OpAdaptor adaptor,
572 ConversionPatternRewriter &rewriter)
const {
573 Location loc = op.getLoc();
574 auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
577 switch (op.getBroadcastType()) {
578 case gpu::BroadcastType::specific_lane:
579 result = spirv::GroupNonUniformBroadcastOp::create(
580 rewriter, loc, scope, adaptor.getSrc(), adaptor.getLane());
582 case gpu::BroadcastType::first_active_lane:
583 result = spirv::GroupNonUniformBroadcastFirstOp::create(
584 rewriter, loc, scope, adaptor.getSrc());
588 rewriter.replaceOp(op,
result);
// Lowers gpu.ballot to spirv.GroupNonUniformBallot (which yields a vec4xi32
// mask). For an i32 result the first component is extracted; for i64 the low
// two components are zero-extended and combined as (high << 32) | low; other
// widths are rejected.
// NOTE(review): truncated excerpt — some `result` declarations/braces and
// the final `return success();` are missing from this view.
592LogicalResult GPUBallotConversion::matchAndRewrite(
593 gpu::BallotOp ballotOp, OpAdaptor adaptor,
594 ConversionPatternRewriter &rewriter)
const {
595 Location loc = ballotOp.getLoc();
596 auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
597 auto int32Type = rewriter.getI32Type();
598 auto vec4i32Type = VectorType::get({4}, int32Type);
601 Value ballot = spirv::GroupNonUniformBallotOp::create(
602 rewriter, loc, vec4i32Type, scope, adaptor.getPredicate());
604 auto intType = cast<IntegerType>(ballotOp.getType());
605 unsigned width = intType.getWidth();
// i32: only the first mask word is needed.
609 spirv::CompositeExtractOp::create(rewriter, loc, ballot, {0});
610 rewriter.replaceOp(ballotOp,
result);
611 }
else if (width == 64) {
613 Value low = spirv::CompositeExtractOp::create(rewriter, loc, ballot, {0});
614 Value high = spirv::CompositeExtractOp::create(rewriter, loc, ballot, {1});
616 auto int64Type = rewriter.getI64Type();
617 Value lowExt = spirv::UConvertOp::create(rewriter, loc, int64Type, low);
618 Value highExt = spirv::UConvertOp::create(rewriter, loc, int64Type, high);
620 Value shift32 = spirv::ConstantOp::create(
621 rewriter, loc, int64Type, rewriter.getIntegerAttr(int64Type, 32));
623 spirv::ShiftLeftLogicalOp::create(rewriter, loc, highExt, shift32);
626 spirv::BitwiseOrOp::create(rewriter, loc, lowExt, highShifted);
627 rewriter.replaceOp(ballotOp,
result);
629 return rewriter.notifyMatchFailure(
630 ballotOp,
"only i32 and i64 result types are supported for SPIR-V");
// Emits a single SPIR-V group-reduce op: UniformOp for uniform reductions,
// NonUniformOp otherwise (with ClusteredReduce + a constant cluster size
// when one is given). Scope is Workgroup for group reductions, Subgroup
// otherwise.
// NOTE(review): truncated excerpt — the function header (return type, first
// parameters), the uniform/else branch structure, and the NonUniformOp
// trailing operand are missing from this view.
640template <
typename UniformOp,
typename NonUniformOp>
642 Value arg,
bool isGroup,
bool isUniform,
643 std::optional<uint32_t> clusterSize) {
645 auto scope = mlir::spirv::ScopeAttr::get(builder.
getContext(),
646 isGroup ? spirv::Scope::Workgroup
647 : spirv::Scope::Subgroup);
648 auto groupOp = spirv::GroupOperationAttr::get(
650 ? spirv::GroupOperation::ClusteredReduce
651 : spirv::GroupOperation::Reduce);
653 return UniformOp::create(builder, loc, type, scope, groupOp, arg)
657 Value clusterSizeValue;
658 if (clusterSize.has_value())
659 clusterSizeValue = spirv::ConstantOp::create(
663 return NonUniformOp::create(builder, loc, type, scope, groupOp, arg,
// Dispatches a gpu reduction kind + element type (Float/Integer; Boolean is
// classified but has no handler entry below) to the matching
// createGroupReduceOpImpl instantiation via a static handler table. Returns
// std::nullopt when no handler matches.
// NOTE(review): truncated excerpt — handler-table rows show only the
// non-uniform op of each createGroupReduceOpImpl<...> pair (the uniform op
// and the wrapping call syntax are missing), as are the OpHandler struct
// body and the final `return std::nullopt;`.
668static std::optional<Value>
670 gpu::AllReduceOperation opType,
bool isGroup,
671 bool isUniform, std::optional<uint32_t> clusterSize) {
672 enum class ElemType { Float, Boolean, Integer };
674 std::optional<uint32_t>);
676 gpu::AllReduceOperation kind;
682 ElemType elementType;
683 if (isa<FloatType>(type)) {
684 elementType = ElemType::Float;
685 }
else if (
auto intTy = dyn_cast<IntegerType>(type)) {
686 elementType = (intTy.getIntOrFloatBitWidth() == 1) ? ElemType::Boolean
697 using ReduceType = gpu::AllReduceOperation;
698 const OpHandler handlers[] = {
699 {ReduceType::ADD, ElemType::Integer,
701 spirv::GroupNonUniformIAddOp>},
702 {ReduceType::ADD, ElemType::Float,
704 spirv::GroupNonUniformFAddOp>},
705 {ReduceType::MUL, ElemType::Integer,
707 spirv::GroupNonUniformIMulOp>},
708 {ReduceType::MUL, ElemType::Float,
710 spirv::GroupNonUniformFMulOp>},
711 {ReduceType::MINUI, ElemType::Integer,
713 spirv::GroupNonUniformUMinOp>},
714 {ReduceType::MINSI, ElemType::Integer,
716 spirv::GroupNonUniformSMinOp>},
717 {ReduceType::MINNUMF, ElemType::Float,
719 spirv::GroupNonUniformFMinOp>},
720 {ReduceType::MAXUI, ElemType::Integer,
722 spirv::GroupNonUniformUMaxOp>},
723 {ReduceType::MAXSI, ElemType::Integer,
725 spirv::GroupNonUniformSMaxOp>},
726 {ReduceType::MAXNUMF, ElemType::Float,
728 spirv::GroupNonUniformFMaxOp>},
729 {ReduceType::MINIMUMF, ElemType::Float,
731 spirv::GroupNonUniformFMinOp>},
732 {ReduceType::MAXIMUMF, ElemType::Float,
734 spirv::GroupNonUniformFMaxOp>}};
// First matching (kind, element type) row wins.
736 for (
const OpHandler &handler : handlers)
737 if (handler.kind == opType && elementType == handler.elemType)
738 return handler.func(builder, loc, arg, isGroup, isUniform, clusterSize);
// Pattern converting gpu.all_reduce into a SPIR-V group reduction via
// createGroupReduceOp (group scope, no cluster size).
// NOTE(review): heavily truncated excerpt — the class header, the
// matchAndRewrite signature's leading parameters, the createGroupReduceOp
// call, failure handling, and closing braces are missing from this view.
745 :
public OpConversionPattern<gpu::AllReduceOp> {
751 ConversionPatternRewriter &rewriter)
const override {
752 auto opType = op.getOp();
761 true, op.getUniform(), std::nullopt);
765 rewriter.replaceOp(op, *
result);
// Pattern converting gpu.subgroup_reduce into a SPIR-V subgroup reduction.
// Rejects cluster strides > 1 and non-scalar reduction types, then
// delegates to createGroupReduceOp (non-group scope, with cluster size).
// NOTE(review): truncated excerpt — the class header, matchAndRewrite's
// leading parameters, failure handling, and closing braces are missing from
// this view.
772 :
public OpConversionPattern<gpu::SubgroupReduceOp> {
778 ConversionPatternRewriter &rewriter)
const override {
779 if (op.getClusterStride() > 1) {
780 return rewriter.notifyMatchFailure(
781 op,
"lowering for cluster stride > 1 is not implemented");
784 if (!isa<spirv::ScalarType>(adaptor.getValue().getType()))
785 return rewriter.notifyMatchFailure(op,
"reduction type is not a scalar");
788 rewriter, op.getLoc(), adaptor.getValue(), adaptor.getOp(),
789 false, adaptor.getUniform(), op.getClusterSize());
793 rewriter.replaceOp(op, *
result);
// Generates a symbol name unique within the spirv.module by appending an
// incrementing counter to `prefix` until lookupSymbol finds no clash
// (do/while loop).
// NOTE(review): truncated excerpt — the `name`/`number` declarations, the
// `do` keyword, and the `return name;` are missing from this view.
802static std::string
makeVarName(spirv::ModuleOp moduleOp, llvm::Twine prefix) {
808 name = (prefix + llvm::Twine(number++)).str();
809 }
while (moduleOp.lookupSymbol(name));
// Lowers gpu.printf to spirv.CL.printf: materializes the NUL-terminated
// format string as per-character i8 spec constants, aggregates them into a
// spec-constant composite backing a UniformConstant global, takes its
// address, bitcasts it to the format-string pointer, and emits the printf
// call with the adapted arguments. The gpu.printf op itself is erased (the
// printf result is unused).
// NOTE(review): truncated excerpt — null checks, the globalType
// declaration, the spec-constant insertion-point setup, the BitcastOp
// operands, and the final `return success();` are missing from this view.
816LogicalResult GPUPrintfConversion::matchAndRewrite(
817 gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor,
818 ConversionPatternRewriter &rewriter)
const {
820 Location loc = gpuPrintfOp.getLoc();
822 auto moduleOp = gpuPrintfOp->getParentOfType<spirv::ModuleOp>();
// Fresh module-unique names keep repeated printf lowerings from clashing.
829 std::string globalVarName =
makeVarName(moduleOp, llvm::Twine(
"printfMsg"));
830 spirv::GlobalVariableOp globalVar;
832 IntegerType i8Type = rewriter.getI8Type();
833 IntegerType i32Type = rewriter.getI32Type();
// Helper emitting one i8 spec constant for a format-string character.
840 auto createSpecConstant = [&](
unsigned value) {
841 auto attr = rewriter.getI8IntegerAttr(value);
842 std::string specCstName =
843 makeVarName(moduleOp, llvm::Twine(globalVarName) +
"_sc");
845 return spirv::SpecConstantOp::create(
846 rewriter, loc, rewriter.getStringAttr(specCstName), attr);
852 ConversionPatternRewriter::InsertionGuard guard(rewriter);
855 rewriter.setInsertionPointToStart(
// Append the terminating NUL the CL printf contract expects.
862 llvm::SmallString<20> formatString(adaptor.getFormat());
863 formatString.push_back(
'\0');
864 SmallVector<Attribute, 4> constituents;
865 for (
char c : formatString) {
866 spirv::SpecConstantOp cSpecConstantOp = createSpecConstant(c);
867 constituents.push_back(SymbolRefAttr::get(cSpecConstantOp));
871 size_t contentSize = constituents.size();
873 spirv::SpecConstantCompositeOp specCstComposite;
876 std::string specCstCompositeName =
877 (llvm::Twine(globalVarName) +
"_scc").str();
879 specCstComposite = spirv::SpecConstantCompositeOp::create(
880 rewriter, loc, TypeAttr::get(globalType),
881 rewriter.getStringAttr(specCstCompositeName),
882 rewriter.getArrayAttr(constituents));
885 globalType, spirv::StorageClass::UniformConstant);
890 globalVar = spirv::GlobalVariableOp::create(
891 rewriter, loc, ptrType, globalVarName,
// Mark the format-string global immutable.
894 globalVar->setAttr(
"Constant", rewriter.getUnitAttr());
898 Value globalPtr = spirv::AddressOfOp::create(rewriter, loc, globalVar);
899 Value fmtStr = spirv::BitcastOp::create(
905 auto printfArgs = llvm::to_vector_of<Value, 4>(adaptor.getArgs());
907 spirv::CLPrintfOp::create(rewriter, loc, i32Type, fmtStr, printfArgs);
// gpu.printf has no used result, so the op is simply erased.
912 rewriter.eraseOp(gpuPrintfOp);
// Fragment of populateGPUToSPIRVPatterns: registers the conversion patterns
// defined above, including builtin-variable mappings for block/grid/thread
// ids and subgroup queries.
// NOTE(review): truncated excerpt — the function header, the patterns.add<
// opener, further pattern entries, and the closing call/braces are missing
// from this view.
924 GPUBarrierConversion, GPUBallotConversion, GPUFuncOpConversion,
925 GPUModuleConversion, GPUReturnOpConversion, GPUShuffleConversion,
926 GPURotateConversion, GPUSubgroupBroadcastConversion,
927 LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
928 LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
929 LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>,
930 LaunchConfigConversion<gpu::ThreadIdOp,
931 spirv::BuiltIn::LocalInvocationId>,
932 LaunchConfigConversion<gpu::GlobalIdOp,
933 spirv::BuiltIn::GlobalInvocationId>,
934 SingleDimLaunchConfigConversion<gpu::SubgroupIdOp,
935 spirv::BuiltIn::SubgroupId>,
936 SingleDimLaunchConfigConversion<gpu::NumSubgroupsOp,
937 spirv::BuiltIn::NumSubgroups>,
938 SingleDimLaunchConfigConversion<gpu::SubgroupSizeOp,
939 spirv::BuiltIn::SubgroupSize>,
940 SingleDimLaunchConfigConversion<
941 gpu::LaneIdOp, spirv::BuiltIn::SubgroupLocalInvocationId>,
static std::optional< Value > createGroupReduceOp(OpBuilder &builder, Location loc, Value arg, gpu::AllReduceOperation opType, bool isGroup, bool isUniform, std::optional< uint32_t > clusterSize)
static LogicalResult getDefaultABIAttrs(const spirv::TargetEnv &targetEnv, gpu::GPUFuncOp funcOp, SmallVectorImpl< spirv::InterfaceVarABIAttr > &argABI)
Populates argABI with spirv.interface_var_abi attributes for lowering gpu.func to spirv....
static constexpr const char kSPIRVModule[]
static Value createGroupReduceOpImpl(OpBuilder &builder, Location loc, Value arg, bool isGroup, bool isUniform, std::optional< uint32_t > clusterSize)
static spirv::FuncOp lowerAsEntryFunction(gpu::GPUFuncOp funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter, spirv::EntryPointABIAttr entryPointInfo, ArrayRef< spirv::InterfaceVarABIAttr > argABIInfo)
static std::string makeVarName(spirv::ModuleOp moduleOp, llvm::Twine prefix)
Pattern to convert a gpu.all_reduce op into a SPIR-V group op.
LogicalResult matchAndRewrite(gpu::AllReduceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
Pattern to convert a gpu.subgroup_reduce op into a SPIR-V group op.
LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
IntegerAttr getIntegerAttr(Type type, int64_t value)
MLIRContext * getContext() const
static FlatSymbolRefAttr get(StringAttr value)
Construct a symbol reference for the given value name.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Type conversion from builtin types to SPIR-V types for shader interface.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
ArrayRef< T > asArrayRef() const
static ArrayType get(Type elementType, unsigned elementCount)
An attribute that specifies the information regarding the interface variable: descriptor set,...
static PointerType get(Type pointeeType, StorageClass storageClass)
ResourceLimitsAttr getResourceLimits() const
Returns the target resource limits.
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabiliti...
TargetEnvAttr getAttr() const
StringRef getInterfaceVarABIAttrName()
Returns the attribute name for specifying argument ABI information.
bool needsInterfaceVarABIAttrs(TargetEnvAttr targetAttr)
Returns whether the given SPIR-V target (described by TargetEnvAttr) needs ABI attributes for interfa...
InterfaceVarABIAttr getInterfaceVarABIAttr(unsigned descriptorSet, unsigned binding, std::optional< StorageClass > storageClass, MLIRContext *context)
Gets the InterfaceVarABIAttr given its fields.
Value getBuiltinVariableValue(Operation *op, BuiltIn builtin, Type integerType, OpBuilder &builder, StringRef prefix="__builtin__", StringRef suffix="__")
Returns the value for the given builtin variable.
EntryPointABIAttr lookupEntryPointABI(Operation *op)
Queries the entry point ABI on the nearest function-like op containing the given op.
StringRef getTargetEnvAttrName()
Returns the attribute name for specifying SPIR-V target environment.
DenseI32ArrayAttr lookupLocalWorkGroupSize(Operation *op)
Queries the local workgroup size from entry point ABI on the nearest function-like op containing the ...
AddressingModel getAddressingModel(TargetEnvAttr targetAttr, bool use64bitAddress)
Returns addressing model selected based on target environment.
FailureOr< MemoryModel > getMemoryModel(TargetEnvAttr targetAttr)
Returns memory model selected based on target environment.
StringRef getEntryPointABIAttrName()
Returns the attribute name for specifying entry point information.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
void populateGPUToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating GPU Ops to SPIR-V ops.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.