template <typename SourceOp, spirv::BuiltIn builtin>
class LaunchConfigConversion : public OpConversionPattern<SourceOp> {
public:
  using OpConversionPattern<SourceOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
                                ConversionPatternRewriter &rewriter) const override;
};

template <typename SourceOp, spirv::BuiltIn builtin>
class SingleDimLaunchConfigConversion : public OpConversionPattern<SourceOp> {
public:
  using OpConversionPattern<SourceOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
                                ConversionPatternRewriter &rewriter) const override;
};

// Additional conversion patterns declared in this file, each overriding
// matchAndRewrite(<source op>, OpAdaptor, ConversionPatternRewriter &):
//   WorkGroupSizeConversion - gpu::BlockDimOp
//   GPUFuncOpConversion     - gpu::GPUFuncOp
//   GPUModuleConversion     - gpu::GPUModuleOp
//   GPUModuleEndConversion  - gpu::ModuleEndOp
//   GPUReturnOpConversion   - gpu::ReturnOp
//   GPUBarrierConversion    - gpu::BarrierOp
//   GPUShuffleConversion    - gpu::ShuffleOp

template <typename SourceOp, spirv::BuiltIn builtin>
LogicalResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
    SourceOp op, typename SourceOp::Adaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
  auto indexType = typeConverter->getIndexType();

  // Load the corresponding SPIR-V builtin variable and extract the requested
  // dimension from it.
  Value spirvBuiltin =
      spirv::getBuiltinVariableValue(op, builtin, indexType, rewriter);
  rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
      op, indexType, spirvBuiltin,
      rewriter.getI32ArrayAttr({static_cast<int32_t>(op.getDimension())}));
  return success();
}

template <typename SourceOp, spirv::BuiltIn builtin>
LogicalResult
SingleDimLaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
    SourceOp op, typename SourceOp::Adaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
  auto indexType = typeConverter->getIndexType();

  // These builtins are scalars, so the loaded value replaces the op directly.
  rewriter.replaceOp(
      op, spirv::getBuiltinVariableValue(op, builtin, indexType, rewriter));
  return success();
}

LogicalResult WorkGroupSizeConversion::matchAndRewrite(
    gpu::BlockDimOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  DenseI32ArrayAttr workGroupSizeAttr = spirv::lookupLocalWorkGroupSize(op);
  if (!workGroupSizeAttr)
    return failure();

  int32_t val =
      workGroupSizeAttr.asArrayRef()[static_cast<int32_t>(op.getDimension())];
  auto convertedType = getTypeConverter()->convertType(op.getResult().getType());
  if (!convertedType)
    return failure();
  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
      op, convertedType, IntegerAttr::get(convertedType, val));
  return success();
}

static spirv::FuncOp
lowerAsEntryFunction(gpu::GPUFuncOp funcOp, TypeConverter &typeConverter,
                     ConversionPatternRewriter &rewriter,
                     spirv::EntryPointABIAttr entryPointInfo,
                     ArrayRef<spirv::InterfaceVarABIAttr> argABIInfo) {
  auto fnType = funcOp.getFunctionType();
  if (fnType.getNumResults()) {
    funcOp.emitError("SPIR-V lowering only supports entry functions "
                     "with no return values right now");
    return nullptr;
  }
  if (!argABIInfo.empty() && fnType.getNumInputs() != argABIInfo.size()) {
    funcOp.emitError(
        "lowering as entry functions requires ABI info for all arguments "
        "or none of them");
    return nullptr;
  }

  // Convert the signature to SPIR-V types.
  TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
  for (const auto &argType : enumerate(funcOp.getFunctionType().getInputs())) {
    auto convertedType = typeConverter.convertType(argType.value());
    if (!convertedType)
      return nullptr;
    signatureConverter.addInputs(argType.index(), convertedType);
  }

  auto newFuncOp = rewriter.create<spirv::FuncOp>(
      funcOp.getLoc(), funcOp.getName(),
      rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
                               std::nullopt));
  for (const auto &namedAttr : funcOp->getAttrs()) {
    if (namedAttr.getName() == funcOp.getFunctionTypeAttrName() ||
        namedAttr.getName() == SymbolTable::getSymbolAttrName())
      continue;
    newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
  }

  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
                              newFuncOp.getBody().end());
  if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
                                         &signatureConverter)))
    return nullptr;
  rewriter.eraseOp(funcOp);

  // Attach the ABI attributes to the arguments and the function.
  StringRef argABIAttrName = spirv::getInterfaceVarABIAttrName();
  for (auto argIndex : llvm::seq<unsigned>(0, argABIInfo.size())) {
    newFuncOp.setArgAttr(argIndex, argABIAttrName, argABIInfo[argIndex]);
  }
  newFuncOp->setAttr(spirv::getEntryPointABIAttrName(), entryPointInfo);
  return newFuncOp;
}

/// Populates argABI with spirv.interface_var_abi attributes for lowering
/// gpu.func to spirv.func when the arguments do not already carry them.
static LogicalResult
getDefaultABIAttrs(MLIRContext *context, gpu::GPUFuncOp funcOp,
                   SmallVectorImpl<spirv::InterfaceVarABIAttr> &argABI) {
  // ...
  for (auto argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
    // ...
    // Vulkan requires scalar arguments to be wrapped in a struct held in a
    // storage buffer.
    std::optional<spirv::StorageClass> sc;
    if (funcOp.getArgument(argIndex).getType().isIntOrIndexOrFloat())
      sc = spirv::StorageClass::StorageBuffer;
    argABI.push_back(spirv::getInterfaceVarABIAttr(0, argIndex, sc, context));
  }
  return success();
}

LogicalResult GPUFuncOpConversion::matchAndRewrite(
    gpu::GPUFuncOp funcOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (!gpu::GPUDialect::isKernel(funcOp))
    return failure();

  // Gather the 'spirv.interface_var_abi' attribute for every kernel argument,
  // deriving defaults where the target environment permits.
  SmallVector<spirv::InterfaceVarABIAttr, 4> argABI;
  if (failed(getDefaultABIAttrs(rewriter.getContext(), funcOp, argABI))) {
    argABI.clear();
    for (auto argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
      auto abiAttr = funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
          argIndex, spirv::getInterfaceVarABIAttrName());
      if (!abiAttr) {
        funcOp.emitRemark(
            "match failure: missing 'spirv.interface_var_abi' attribute at "
            "argument ") << argIndex;
        return failure();
      }
      argABI.push_back(abiAttr);
    }
  }

  auto entryPointAttr = spirv::lookupEntryPointABI(funcOp);
  if (!entryPointAttr) {
    funcOp.emitRemark(
        "match failure: missing 'spirv.entry_point_abi' attribute");
    return failure();
  }
  spirv::FuncOp newFuncOp = lowerAsEntryFunction(
      funcOp, *getTypeConverter(), rewriter, entryPointAttr, argABI);
  if (!newFuncOp)
    return failure();
  newFuncOp->removeAttr(
      rewriter.getStringAttr(gpu::GPUDialect::getKernelFuncAttrName()));
  return success();
}

LogicalResult GPUModuleConversion::matchAndRewrite(
    gpu::GPUModuleOp moduleOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto targetEnvAttr = spirv::lookupTargetEnvOrDefault(moduleOp);
  auto addressingModel = spirv::getAddressingModel(targetEnvAttr);
  FailureOr<spirv::MemoryModel> memoryModel = spirv::getMemoryModel(targetEnvAttr);
  if (failed(memoryModel))
    return moduleOp.emitRemark("match failure: could not select memory model "
                               "based on 'spirv.target_env'");

  // Add a prefix to avoid a symbol name collision with the gpu.module.
  std::string spvModuleName = (kSPIRVModule + moduleOp.getName()).str();
  auto spvModule = rewriter.create<spirv::ModuleOp>(
      moduleOp.getLoc(), addressingModel, *memoryModel, std::nullopt,
      StringRef(spvModuleName));

  // Move the region from the gpu.module into the new spirv.module, drop the
  // block added by the spirv.module builder, and erase the original op.
  Region &spvModuleRegion = spvModule.getRegion();
  rewriter.inlineRegionBefore(moduleOp.getBodyRegion(), spvModuleRegion,
                              spvModuleRegion.begin());
  rewriter.eraseBlock(&spvModuleRegion.back());
  rewriter.eraseOp(moduleOp);
  return success();
}

LogicalResult GPUReturnOpConversion::matchAndRewrite(
    gpu::ReturnOp returnOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (!adaptor.getOperands().empty())
    return failure();
  rewriter.replaceOpWithNewOp<spirv::ReturnOp>(returnOp);
  return success();
}

LogicalResult GPUBarrierConversion::matchAndRewrite(
    gpu::BarrierOp barrierOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  MLIRContext *context = getContext();
  // Both the execution scope and the memory scope are the workgroup, with
  // acquire/release semantics on workgroup memory.
  auto scope = spirv::ScopeAttr::get(context, spirv::Scope::Workgroup);
  auto memorySemantics = spirv::MemorySemanticsAttr::get(
      context, spirv::MemorySemantics::WorkgroupMemory |
                   spirv::MemorySemantics::AcquireRelease);
  rewriter.replaceOpWithNewOp<spirv::ControlBarrierOp>(barrierOp, scope, scope,
                                                       memorySemantics);
  return success();
}

LogicalResult GPUShuffleConversion::matchAndRewrite(
    gpu::ShuffleOp shuffleOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // Require the shuffle width to match the target subgroup size, so the
  // shuffle is guaranteed to stay within a subgroup.
  auto targetEnv = getTypeConverter<SPIRVTypeConverter>()->getTargetEnv();
  unsigned subgroupSize =
      targetEnv.getAttr().getResourceLimits().getSubgroupSize();
  IntegerAttr widthAttr;
  if (!matchPattern(shuffleOp.getWidth(), m_Constant(&widthAttr)) ||
      widthAttr.getValue().getZExtValue() != subgroupSize)
    return rewriter.notifyMatchFailure(
        shuffleOp, "shuffle width and target subgroup size mismatch");

  Location loc = shuffleOp.getLoc();
  Value trueVal = spirv::ConstantOp::getOne(rewriter.getI1Type(),
                                            shuffleOp.getLoc(), rewriter);
  auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
  Value result;

  switch (shuffleOp.getMode()) {
  case gpu::ShuffleMode::XOR:
    result = rewriter.create<spirv::GroupNonUniformShuffleXorOp>(
        loc, scope, adaptor.getValue(), adaptor.getOffset());
    break;
  case gpu::ShuffleMode::IDX:
    result = rewriter.create<spirv::GroupNonUniformShuffleOp>(
        loc, scope, adaptor.getValue(), adaptor.getOffset());
    break;
  default:
    return rewriter.notifyMatchFailure(shuffleOp, "unimplemented shuffle mode");
  }

  rewriter.replaceOp(shuffleOp, {result, trueVal});
  return success();
}

template <typename UniformOp, typename NonUniformOp>
static Value createGroupReduceOpImpl(OpBuilder &builder, Location loc,
                                     Value arg, bool isGroup, bool isUniform) {
  Type type = arg.getType();
  auto scope = mlir::spirv::ScopeAttr::get(builder.getContext(),
                                           isGroup ? spirv::Scope::Workgroup
                                                   : spirv::Scope::Subgroup);
  auto groupOp = spirv::GroupOperationAttr::get(builder.getContext(),
                                                spirv::GroupOperation::Reduce);
  if (isUniform) {
    return builder.create<UniformOp>(loc, type, scope, groupOp, arg)
        .getResult();
  }
  return builder.create<NonUniformOp>(loc, type, scope, groupOp, arg, Value{})
      .getResult();
}

static std::optional<Value> createGroupReduceOp(OpBuilder &builder,
                                                Location loc, Value arg,
                                                gpu::AllReduceOperation opType,
                                                bool isGroup, bool isUniform) {
  using FuncT = Value (*)(OpBuilder &, Location, Value, bool, bool);
  struct OpHandler {
    gpu::AllReduceOperation type;
    FuncT intFunc;
    FuncT floatFunc;
  };

  Type type = arg.getType();
  using MembptrT = FuncT OpHandler::*;
  MembptrT handlerPtr;
  if (type.isa<FloatType>()) {
    handlerPtr = &OpHandler::floatFunc;
  } else if (type.isa<IntegerType>()) {
    handlerPtr = &OpHandler::intFunc;
  } else {
    return std::nullopt;
  }

  using ReduceType = gpu::AllReduceOperation;
  namespace spv = spirv;
  const OpHandler handlers[] = {
      {ReduceType::ADD,
       &createGroupReduceOpImpl<spv::GroupIAddOp, spv::GroupNonUniformIAddOp>,
       &createGroupReduceOpImpl<spv::GroupFAddOp, spv::GroupNonUniformFAddOp>},
      {ReduceType::MUL,
       &createGroupReduceOpImpl<spv::GroupIMulKHROp,
                                spv::GroupNonUniformIMulOp>,
       &createGroupReduceOpImpl<spv::GroupFMulKHROp,
                                spv::GroupNonUniformFMulOp>},
  };

  for (auto &handler : handlers)
    if (handler.type == opType)
      return (handler.*handlerPtr)(builder, loc, arg, isGroup, isUniform);

  return std::nullopt;
}

// Pattern to convert a gpu.all_reduce op into a SPIR-V group op:
// GPUAllReduceConversion reads op.getOp() and emits the group reduction via
// createGroupReduceOp(..., /*isGroup=*/true, op.getUniform()).
// Pattern to convert a gpu.subgroup_reduce op into a SPIR-V group op:
// GPUSubgroupReduceConversion does the same with /*isGroup=*/false.

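// The class below is an editorial sketch, not the file's verbatim code: it
// shows how such a reduce pattern can be written against createGroupReduceOp,
// assuming the gpu.all_reduce accessors getOp(), getValue(), and getUniform()
// from the GPU dialect. The class name is hypothetical.
class GPUAllReduceConversionSketch final
    : public OpConversionPattern<gpu::AllReduceOp> {
public:
  using OpConversionPattern<gpu::AllReduceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::AllReduceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only the attribute form of the reduction is handled here.
    auto opType = op.getOp();
    if (!opType)
      return failure();
    std::optional<Value> result =
        createGroupReduceOp(rewriter, op.getLoc(), adaptor.getValue(), *opType,
                            /*isGroup=*/true, op.getUniform());
    if (!result)
      return failure();
    rewriter.replaceOp(op, *result);
    return success();
  }
};
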
548 GPUBarrierConversion, GPUFuncOpConversion, GPUModuleConversion,
549 GPUModuleEndConversion, GPUReturnOpConversion, GPUShuffleConversion,
550 LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
551 LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
552 LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>,
553 LaunchConfigConversion<gpu::ThreadIdOp,
554 spirv::BuiltIn::LocalInvocationId>,
555 LaunchConfigConversion<gpu::GlobalIdOp,
556 spirv::BuiltIn::GlobalInvocationId>,
557 SingleDimLaunchConfigConversion<gpu::SubgroupIdOp,
558 spirv::BuiltIn::SubgroupId>,
559 SingleDimLaunchConfigConversion<gpu::NumSubgroupsOp,
560 spirv::BuiltIn::NumSubgroups>,
561 SingleDimLaunchConfigConversion<gpu::SubgroupSizeOp,
562 spirv::BuiltIn::SubgroupSize>,
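// The snippet below is an editorial sketch, not part of the original file: it
// shows one way a conversion pass might drive these patterns, assuming a
// gpu.module that carries (or inherits) a 'spirv.target_env' attribute. The
// helper name convertGPUModuleToSPIRV and the surrounding setup are
// hypothetical.
#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

static LogicalResult convertGPUModuleToSPIRV(gpu::GPUModuleOp gpuModule) {
  MLIRContext *context = gpuModule.getContext();
  // Derive the conversion target and type converter from the target env.
  spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(gpuModule);
  std::unique_ptr<ConversionTarget> target =
      SPIRVConversionTarget::get(targetAttr);
  SPIRVTypeConverter typeConverter(targetAttr);

  RewritePatternSet patterns(context);
  populateGPUToSPIRVPatterns(typeConverter, patterns);

  // A full conversion requires every remaining op in the gpu.module to be
  // legal for the SPIR-V target.
  return applyFullConversion(gpuModule, *target, std::move(patterns));
}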