34 template <
typename SourceOp, spirv::BuiltIn builtin>
40 matchAndRewrite(SourceOp op,
typename SourceOp::Adaptor adaptor,
46 template <
typename SourceOp, spirv::BuiltIn builtin>
52 matchAndRewrite(SourceOp op,
typename SourceOp::Adaptor adaptor,
67 matchAndRewrite(gpu::BlockDimOp op, OpAdaptor adaptor,
77 matchAndRewrite(gpu::GPUFuncOp funcOp, OpAdaptor adaptor,
90 matchAndRewrite(gpu::GPUModuleOp moduleOp, OpAdaptor adaptor,
101 matchAndRewrite(gpu::ReturnOp returnOp, OpAdaptor adaptor,
111 matchAndRewrite(gpu::BarrierOp barrierOp, OpAdaptor adaptor,
121 matchAndRewrite(gpu::ShuffleOp shuffleOp, OpAdaptor adaptor,
130 matchAndRewrite(gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor,
140 template <
typename SourceOp, spirv::BuiltIn builtin>
141 LogicalResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
142 SourceOp op,
typename SourceOp::Adaptor adaptor,
144 auto *typeConverter = this->
template getTypeConverter<SPIRVTypeConverter>();
145 Type indexType = typeConverter->getIndexType();
159 typeConverter->getTargetEnv().allows(spirv::Capability::Shader);
164 Value dim = rewriter.
create<spirv::CompositeExtractOp>(
165 op.getLoc(), builtinType, vector,
167 if (forShader && builtinType != indexType)
168 dim = rewriter.
create<spirv::UConvertOp>(op.getLoc(), indexType, dim);
173 template <
typename SourceOp, spirv::BuiltIn builtin>
175 SingleDimLaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
176 SourceOp op,
typename SourceOp::Adaptor adaptor,
178 auto *typeConverter = this->
template getTypeConverter<SPIRVTypeConverter>();
179 Type indexType = typeConverter->getIndexType();
192 if (i32Type != indexType)
193 builtinValue = rewriter.
create<spirv::UConvertOp>(op.getLoc(), indexType,
199 LogicalResult WorkGroupSizeConversion::matchAndRewrite(
200 gpu::BlockDimOp op, OpAdaptor adaptor,
203 if (!workGroupSizeAttr)
207 workGroupSizeAttr.
asArrayRef()[
static_cast<int32_t
>(op.getDimension())];
209 getTypeConverter()->convertType(op.getResult().getType());
225 spirv::EntryPointABIAttr entryPointInfo,
227 auto fnType = funcOp.getFunctionType();
228 if (fnType.getNumResults()) {
229 funcOp.emitError(
"SPIR-V lowering only supports entry functions"
230 "with no return values right now");
233 if (!argABIInfo.empty() && fnType.getNumInputs() != argABIInfo.size()) {
235 "lowering as entry functions requires ABI info for all arguments "
244 for (
const auto &argType :
245 enumerate(funcOp.getFunctionType().getInputs())) {
246 auto convertedType = typeConverter.
convertType(argType.value());
249 signatureConverter.
addInputs(argType.index(), convertedType);
252 auto newFuncOp = rewriter.
create<spirv::FuncOp>(
253 funcOp.getLoc(), funcOp.getName(),
256 for (
const auto &namedAttr : funcOp->getAttrs()) {
257 if (namedAttr.getName() == funcOp.getFunctionTypeAttrName() ||
260 newFuncOp->
setAttr(namedAttr.getName(), namedAttr.getValue());
266 &signatureConverter)))
272 for (
auto argIndex : llvm::seq<unsigned>(0, argABIInfo.size())) {
273 newFuncOp.setArgAttr(argIndex, argABIAttrName, argABIInfo[argIndex]);
289 for (
auto argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
295 std::optional<spirv::StorageClass> sc;
296 if (funcOp.getArgument(argIndex).getType().isIntOrIndexOrFloat())
297 sc = spirv::StorageClass::StorageBuffer;
304 LogicalResult GPUFuncOpConversion::matchAndRewrite(
305 gpu::GPUFuncOp funcOp, OpAdaptor adaptor,
307 if (!gpu::GPUDialect::isKernel(funcOp))
310 auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
315 for (
auto argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
321 "match failure: missing 'spirv.interface_var_abi' attribute at "
326 argABI.push_back(abiAttr);
331 if (!entryPointAttr) {
333 "match failure: missing 'spirv.entry_point_abi' attribute");
337 funcOp, *getTypeConverter(), rewriter, entryPointAttr, argABI);
340 newFuncOp->removeAttr(
341 rewriter.
getStringAttr(gpu::GPUDialect::getKernelFuncAttrName()));
349 LogicalResult GPUModuleConversion::matchAndRewrite(
350 gpu::GPUModuleOp moduleOp, OpAdaptor adaptor,
352 auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
355 targetEnv, typeConverter->getOptions().use64bitIndex);
357 if (failed(memoryModel))
358 return moduleOp.emitRemark(
359 "cannot deduce memory model from 'spirv.target_env'");
362 std::string spvModuleName = (
kSPIRVModule + moduleOp.getName()).str();
363 auto spvModule = rewriter.
create<spirv::ModuleOp>(
364 moduleOp.getLoc(), addressingModel, *memoryModel, std::nullopt,
365 StringRef(spvModuleName));
368 Region &spvModuleRegion = spvModule.getRegion();
370 spvModuleRegion.begin());
390 LogicalResult GPUReturnOpConversion::matchAndRewrite(
391 gpu::ReturnOp returnOp, OpAdaptor adaptor,
393 if (!adaptor.getOperands().empty())
404 LogicalResult GPUBarrierConversion::matchAndRewrite(
405 gpu::BarrierOp barrierOp, OpAdaptor adaptor,
412 context, spirv::MemorySemantics::WorkgroupMemory |
413 spirv::MemorySemantics::AcquireRelease);
423 LogicalResult GPUShuffleConversion::matchAndRewrite(
424 gpu::ShuffleOp shuffleOp, OpAdaptor adaptor,
429 auto targetEnv = getTypeConverter<SPIRVTypeConverter>()->getTargetEnv();
432 IntegerAttr widthAttr;
436 shuffleOp,
"shuffle width and target subgroup size mismatch");
438 assert(!adaptor.getOffset().getType().isSignedInteger() &&
439 "shuffle offset must be a signless/unsigned integer");
442 auto scope = rewriter.
getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
446 switch (shuffleOp.getMode()) {
447 case gpu::ShuffleMode::XOR: {
448 result = rewriter.
create<spirv::GroupNonUniformShuffleXorOp>(
449 loc, scope, adaptor.getValue(), adaptor.getOffset());
450 validVal = spirv::ConstantOp::getOne(rewriter.
getI1Type(),
451 shuffleOp.getLoc(), rewriter);
454 case gpu::ShuffleMode::IDX: {
455 result = rewriter.
create<spirv::GroupNonUniformShuffleOp>(
456 loc, scope, adaptor.getValue(), adaptor.getOffset());
457 validVal = spirv::ConstantOp::getOne(rewriter.
getI1Type(),
458 shuffleOp.getLoc(), rewriter);
461 case gpu::ShuffleMode::DOWN: {
462 result = rewriter.
create<spirv::GroupNonUniformShuffleDownOp>(
463 loc, scope, adaptor.getValue(), adaptor.getOffset());
465 Value laneId = rewriter.
create<gpu::LaneIdOp>(loc, widthAttr);
467 rewriter.
create<arith::AddIOp>(loc, laneId, adaptor.getOffset());
468 validVal = rewriter.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
469 resultLaneId, adaptor.getWidth());
472 case gpu::ShuffleMode::UP: {
473 result = rewriter.
create<spirv::GroupNonUniformShuffleUpOp>(
474 loc, scope, adaptor.getValue(), adaptor.getOffset());
476 Value laneId = rewriter.
create<gpu::LaneIdOp>(loc, widthAttr);
478 rewriter.
create<arith::SubIOp>(loc, laneId, adaptor.getOffset());
480 validVal = rewriter.
create<arith::CmpIOp>(
481 loc, arith::CmpIPredicate::sge, resultLaneId,
482 rewriter.
create<arith::ConstantOp>(
488 rewriter.
replaceOp(shuffleOp, {result, validVal});
496 template <
typename UniformOp,
typename NonUniformOp>
498 Value arg,
bool isGroup,
bool isUniform,
499 std::optional<uint32_t> clusterSize) {
502 isGroup ? spirv::Scope::Workgroup
503 : spirv::Scope::Subgroup);
506 ? spirv::GroupOperation::ClusteredReduce
507 : spirv::GroupOperation::Reduce);
509 return builder.
create<UniformOp>(loc, type, scope, groupOp, arg)
513 Value clusterSizeValue;
514 if (clusterSize.has_value())
515 clusterSizeValue = builder.
create<spirv::ConstantOp>(
520 .
create<NonUniformOp>(loc, type, scope, groupOp, arg, clusterSizeValue)
524 static std::optional<Value>
526 gpu::AllReduceOperation opType,
bool isGroup,
527 bool isUniform, std::optional<uint32_t> clusterSize) {
528 enum class ElemType { Float, Boolean, Integer };
530 std::optional<uint32_t>);
532 gpu::AllReduceOperation
kind;
538 ElemType elementType;
539 if (isa<FloatType>(type)) {
540 elementType = ElemType::Float;
541 }
else if (
auto intTy = dyn_cast<IntegerType>(type)) {
542 elementType = (intTy.getIntOrFloatBitWidth() == 1) ? ElemType::Boolean
553 using ReduceType = gpu::AllReduceOperation;
554 const OpHandler handlers[] = {
555 {ReduceType::ADD, ElemType::Integer,
557 spirv::GroupNonUniformIAddOp>},
558 {ReduceType::ADD, ElemType::Float,
560 spirv::GroupNonUniformFAddOp>},
561 {ReduceType::MUL, ElemType::Integer,
563 spirv::GroupNonUniformIMulOp>},
564 {ReduceType::MUL, ElemType::Float,
566 spirv::GroupNonUniformFMulOp>},
569 spirv::GroupNonUniformUMinOp>},
570 {ReduceType::MINSI, ElemType::Integer,
572 spirv::GroupNonUniformSMinOp>},
573 {ReduceType::MINNUMF, ElemType::Float,
575 spirv::GroupNonUniformFMinOp>},
576 {ReduceType::MAXUI, ElemType::Integer,
578 spirv::GroupNonUniformUMaxOp>},
579 {ReduceType::MAXSI, ElemType::Integer,
581 spirv::GroupNonUniformSMaxOp>},
582 {ReduceType::MAXNUMF, ElemType::Float,
584 spirv::GroupNonUniformFMaxOp>},
585 {ReduceType::MINIMUMF, ElemType::Float,
587 spirv::GroupNonUniformFMinOp>},
588 {ReduceType::MAXIMUMF, ElemType::Float,
590 spirv::GroupNonUniformFMaxOp>}};
592 for (
const OpHandler &handler : handlers)
593 if (handler.kind == opType && elementType == handler.elemType)
594 return handler.func(builder, loc, arg, isGroup, isUniform, clusterSize);
608 auto opType = op.getOp();
617 true, op.getUniform(), std::nullopt);
635 if (op.getClusterStride() > 1) {
637 op,
"lowering for cluster stride > 1 is not implemented");
640 if (!isa<spirv::ScalarType>(adaptor.getValue().getType()))
644 rewriter, op.getLoc(), adaptor.getValue(), adaptor.getOp(),
645 false, adaptor.getUniform(), op.getClusterSize());
658 static std::string
makeVarName(spirv::ModuleOp moduleOp, llvm::Twine prefix) {
664 name = (prefix + llvm::Twine(number++)).str();
665 }
while (moduleOp.lookupSymbol(name));
672 LogicalResult GPUPrintfConversion::matchAndRewrite(
673 gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor,
676 Location loc = gpuPrintfOp.getLoc();
678 auto moduleOp = gpuPrintfOp->getParentOfType<spirv::ModuleOp>();
685 std::string globalVarName =
makeVarName(moduleOp, llvm::Twine(
"printfMsg"));
686 spirv::GlobalVariableOp globalVar;
688 IntegerType i8Type = rewriter.
getI8Type();
696 auto createSpecConstant = [&](
unsigned value) {
698 std::string specCstName =
699 makeVarName(moduleOp, llvm::Twine(globalVarName) +
"_sc");
701 return rewriter.
create<spirv::SpecConstantOp>(
708 ConversionPatternRewriter::InsertionGuard guard(rewriter);
719 formatString.push_back(
'\0');
721 for (
char c : formatString) {
722 spirv::SpecConstantOp cSpecConstantOp = createSpecConstant(c);
727 size_t contentSize = constituents.size();
729 spirv::SpecConstantCompositeOp specCstComposite;
732 std::string specCstCompositeName =
733 (llvm::Twine(globalVarName) +
"_scc").str();
735 specCstComposite = rewriter.
create<spirv::SpecConstantCompositeOp>(
741 globalType, spirv::StorageClass::UniformConstant);
746 globalVar = rewriter.
create<spirv::GlobalVariableOp>(
753 Value globalPtr = rewriter.
create<spirv::AddressOfOp>(loc, globalVar);
760 auto printfArgs = llvm::to_vector_of<Value, 4>(adaptor.getArgs());
762 rewriter.
create<spirv::CLPrintfOp>(loc, i32Type, fmtStr, printfArgs);
779 GPUBarrierConversion, GPUFuncOpConversion, GPUModuleConversion,
780 GPUReturnOpConversion, GPUShuffleConversion,
781 LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
782 LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
783 LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>,
784 LaunchConfigConversion<gpu::ThreadIdOp,
785 spirv::BuiltIn::LocalInvocationId>,
786 LaunchConfigConversion<gpu::GlobalIdOp,
787 spirv::BuiltIn::GlobalInvocationId>,
788 SingleDimLaunchConfigConversion<gpu::SubgroupIdOp,
789 spirv::BuiltIn::SubgroupId>,
790 SingleDimLaunchConfigConversion<gpu::NumSubgroupsOp,
791 spirv::BuiltIn::NumSubgroups>,
792 SingleDimLaunchConfigConversion<gpu::SubgroupSizeOp,
793 spirv::BuiltIn::SubgroupSize>,
794 SingleDimLaunchConfigConversion<
795 gpu::LaneIdOp, spirv::BuiltIn::SubgroupLocalInvocationId>,
static LogicalResult getDefaultABIAttrs(const spirv::TargetEnv &targetEnv, gpu::GPUFuncOp funcOp, SmallVectorImpl< spirv::InterfaceVarABIAttr > &argABI)
Populates argABI with spirv.interface_var_abi attributes for lowering gpu.func to spirv....
static constexpr const char kSPIRVModule[]
static std::optional< Value > createGroupReduceOp(OpBuilder &builder, Location loc, Value arg, gpu::AllReduceOperation opType, bool isGroup, bool isUniform, std::optional< uint32_t > clusterSize)
static Value createGroupReduceOpImpl(OpBuilder &builder, Location loc, Value arg, bool isGroup, bool isUniform, std::optional< uint32_t > clusterSize)
static spirv::FuncOp lowerAsEntryFunction(gpu::GPUFuncOp funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter, spirv::EntryPointABIAttr entryPointInfo, ArrayRef< spirv::InterfaceVarABIAttr > argABIInfo)
static std::string makeVarName(spirv::ModuleOp moduleOp, llvm::Twine prefix)
static MLIRContext * getContext(OpFoldResult val)
union mlir::linalg::@1204::ArityGroupAndKind::Kind kind
constexpr unsigned subgroupSize
HW dependent constants.
Pattern to convert a gpu.all_reduce op into a SPIR-V group op.
LogicalResult matchAndRewrite(gpu::AllReduceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
Methods that operate on the SourceOp type.
Pattern to convert a gpu.subgroup_reduce op into a SPIR-V group op.
LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
Methods that operate on the SourceOp type.
Block represents an ordered list of Operations.
IntegerAttr getIntegerAttr(Type type, int64_t value)
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
IntegerType getIntegerType(unsigned width)
StringAttr getStringAttr(const Twine &bytes)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
IntegerAttr getI8IntegerAttr(int8_t value)
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void eraseBlock(Block *block) override
PatternRewriter hook for erase all operations in a block.
static FlatSymbolRefAttr get(StringAttr value)
Construct a symbol reference for the given value name.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Type conversion from builtin types to SPIR-V types for shader interface.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
ArrayRef< T > asArrayRef() const
static ArrayType get(Type elementType, unsigned elementCount)
An attribute that specifies the information regarding the interface variable: descriptor set,...
static PointerType get(Type pointeeType, StorageClass storageClass)
An attribute that specifies the target version, allowed extensions and capabilities,...
ResourceLimitsAttr getResourceLimits() const
Returns the target resource limits.
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabiliti...
TargetEnvAttr getAttr() const
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
StringRef getInterfaceVarABIAttrName()
Returns the attribute name for specifying argument ABI information.
bool needsInterfaceVarABIAttrs(TargetEnvAttr targetAttr)
Returns whether the given SPIR-V target (described by TargetEnvAttr) needs ABI attributes for interfa...
InterfaceVarABIAttr getInterfaceVarABIAttr(unsigned descriptorSet, unsigned binding, std::optional< StorageClass > storageClass, MLIRContext *context)
Gets the InterfaceVarABIAttr given its fields.
Value getBuiltinVariableValue(Operation *op, BuiltIn builtin, Type integerType, OpBuilder &builder, StringRef prefix="__builtin__", StringRef suffix="__")
Returns the value for the given builtin variable.
EntryPointABIAttr lookupEntryPointABI(Operation *op)
Queries the entry point ABI on the nearest function-like op containing the given op.
StringRef getTargetEnvAttrName()
Returns the attribute name for specifying SPIR-V target environment.
DenseI32ArrayAttr lookupLocalWorkGroupSize(Operation *op)
Queries the local workgroup size from entry point ABI on the nearest function-like op containing the ...
AddressingModel getAddressingModel(TargetEnvAttr targetAttr, bool use64bitAddress)
Returns addressing model selected based on target environment.
FailureOr< MemoryModel > getMemoryModel(TargetEnvAttr targetAttr)
Returns memory model selected based on target environment.
StringRef getEntryPointABIAttrName()
Returns the attribute name for specifying entry point information.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
const FrozenRewritePatternSet & patterns
void populateGPUToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating GPU Ops to SPIR-V ops.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.