33 template <
typename SourceOp, spirv::BuiltIn builtin>
39 matchAndRewrite(SourceOp op,
typename SourceOp::Adaptor adaptor,
45 template <
typename SourceOp, spirv::BuiltIn builtin>
51 matchAndRewrite(SourceOp op,
typename SourceOp::Adaptor adaptor,
65 matchAndRewrite(gpu::BlockDimOp op, OpAdaptor adaptor,
75 matchAndRewrite(gpu::GPUFuncOp funcOp, OpAdaptor adaptor,
88 matchAndRewrite(gpu::GPUModuleOp moduleOp, OpAdaptor adaptor,
92 class GPUModuleEndConversion final
98 matchAndRewrite(gpu::ModuleEndOp endOp, OpAdaptor adaptor,
112 matchAndRewrite(gpu::ReturnOp returnOp, OpAdaptor adaptor,
122 matchAndRewrite(gpu::BarrierOp barrierOp, OpAdaptor adaptor,
132 matchAndRewrite(gpu::ShuffleOp shuffleOp, OpAdaptor adaptor,
142 template <
typename SourceOp, spirv::BuiltIn builtin>
143 LogicalResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
144 SourceOp op,
typename SourceOp::Adaptor adaptor,
146 auto *typeConverter = this->
template getTypeConverter<SPIRVTypeConverter>();
147 Type indexType = typeConverter->getIndexType();
161 typeConverter->getTargetEnv().allows(spirv::Capability::Shader);
166 Value dim = rewriter.
create<spirv::CompositeExtractOp>(
167 op.
getLoc(), builtinType, vector,
169 if (forShader && builtinType != indexType)
170 dim = rewriter.
create<spirv::UConvertOp>(op.
getLoc(), indexType, dim);
175 template <
typename SourceOp, spirv::BuiltIn builtin>
177 SingleDimLaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
178 SourceOp op,
typename SourceOp::Adaptor adaptor,
180 auto *typeConverter = this->
template getTypeConverter<SPIRVTypeConverter>();
181 Type indexType = typeConverter->getIndexType();
194 if (i32Type != indexType)
195 builtinValue = rewriter.
create<spirv::UConvertOp>(op.
getLoc(), indexType,
202 gpu::BlockDimOp op, OpAdaptor adaptor,
205 if (!workGroupSizeAttr)
209 workGroupSizeAttr.
asArrayRef()[
static_cast<int32_t
>(op.getDimension())];
227 spirv::EntryPointABIAttr entryPointInfo,
229 auto fnType = funcOp.getFunctionType();
230 if (fnType.getNumResults()) {
231 funcOp.emitError(
"SPIR-V lowering only supports entry functions"
232 "with no return values right now");
235 if (!argABIInfo.empty() && fnType.getNumInputs() != argABIInfo.size()) {
237 "lowering as entry functions requires ABI info for all arguments "
246 for (
const auto &argType :
247 enumerate(funcOp.getFunctionType().getInputs())) {
248 auto convertedType = typeConverter.
convertType(argType.value());
251 signatureConverter.
addInputs(argType.index(), convertedType);
254 auto newFuncOp = rewriter.
create<spirv::FuncOp>(
255 funcOp.getLoc(), funcOp.getName(),
258 for (
const auto &namedAttr : funcOp->getAttrs()) {
259 if (namedAttr.getName() == funcOp.getFunctionTypeAttrName() ||
262 newFuncOp->
setAttr(namedAttr.getName(), namedAttr.getValue());
268 &signatureConverter)))
274 for (
auto argIndex : llvm::seq<unsigned>(0, argABIInfo.size())) {
275 newFuncOp.setArgAttr(argIndex, argABIAttrName, argABIInfo[argIndex]);
291 for (
auto argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
297 std::optional<spirv::StorageClass> sc;
298 if (funcOp.getArgument(argIndex).getType().isIntOrIndexOrFloat())
299 sc = spirv::StorageClass::StorageBuffer;
307 gpu::GPUFuncOp funcOp, OpAdaptor adaptor,
309 if (!gpu::GPUDialect::isKernel(funcOp))
312 auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
317 for (
auto argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
323 "match failure: missing 'spirv.interface_var_abi' attribute at "
328 argABI.push_back(abiAttr);
333 if (!entryPointAttr) {
335 "match failure: missing 'spirv.entry_point_abi' attribute");
339 funcOp, *getTypeConverter(), rewriter, entryPointAttr, argABI);
342 newFuncOp->removeAttr(
343 rewriter.
getStringAttr(gpu::GPUDialect::getKernelFuncAttrName()));
352 gpu::GPUModuleOp moduleOp, OpAdaptor adaptor,
354 auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
357 targetEnv, typeConverter->getOptions().use64bitIndex);
360 return moduleOp.emitRemark(
361 "cannot deduce memory model from 'spirv.target_env'");
364 std::string spvModuleName = (
kSPIRVModule + moduleOp.getName()).str();
365 auto spvModule = rewriter.
create<spirv::ModuleOp>(
366 moduleOp.getLoc(), addressingModel, *memoryModel, std::nullopt,
367 StringRef(spvModuleName));
370 Region &spvModuleRegion = spvModule.getRegion();
372 spvModuleRegion.begin());
393 gpu::ReturnOp returnOp, OpAdaptor adaptor,
395 if (!adaptor.getOperands().empty())
407 gpu::BarrierOp barrierOp, OpAdaptor adaptor,
414 context, spirv::MemorySemantics::WorkgroupMemory |
415 spirv::MemorySemantics::AcquireRelease);
426 gpu::ShuffleOp shuffleOp, OpAdaptor adaptor,
431 auto targetEnv = getTypeConverter<SPIRVTypeConverter>()->getTargetEnv();
432 unsigned subgroupSize =
434 IntegerAttr widthAttr;
436 widthAttr.getValue().getZExtValue() != subgroupSize)
438 shuffleOp,
"shuffle width and target subgroup size mismatch");
442 shuffleOp.getLoc(), rewriter);
443 auto scope = rewriter.
getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
446 switch (shuffleOp.getMode()) {
447 case gpu::ShuffleMode::XOR:
448 result = rewriter.
create<spirv::GroupNonUniformShuffleXorOp>(
449 loc, scope, adaptor.getValue(), adaptor.getOffset());
451 case gpu::ShuffleMode::IDX:
452 result = rewriter.
create<spirv::GroupNonUniformShuffleOp>(
453 loc, scope, adaptor.getValue(), adaptor.getOffset());
459 rewriter.
replaceOp(shuffleOp, {result, trueVal});
467 template <
typename UniformOp,
typename NonUniformOp>
469 Value arg,
bool isGroup,
bool isUniform) {
472 isGroup ? spirv::Scope::Workgroup
473 : spirv::Scope::Subgroup);
475 spirv::GroupOperation::Reduce);
477 return builder.
create<UniformOp>(loc, type, scope, groupOp, arg)
480 return builder.
create<NonUniformOp>(loc, type, scope, groupOp, arg,
Value{})
486 gpu::AllReduceOperation opType,
487 bool isGroup,
bool isUniform) {
490 gpu::AllReduceOperation type;
496 using MembptrT = FuncT OpHandler::*;
498 if (isa<FloatType>(type)) {
499 handlerPtr = &OpHandler::floatFunc;
500 }
else if (isa<IntegerType>(type)) {
501 handlerPtr = &OpHandler::intFunc;
506 using ReduceType = gpu::AllReduceOperation;
507 namespace spv = spirv;
508 const OpHandler handlers[] = {
510 &createGroupReduceOpImpl<spv::GroupIAddOp, spv::GroupNonUniformIAddOp>,
511 &createGroupReduceOpImpl<spv::GroupFAddOp, spv::GroupNonUniformFAddOp>},
514 spv::GroupNonUniformIMulOp>,
516 spv::GroupNonUniformFMulOp>},
518 &createGroupReduceOpImpl<spv::GroupSMinOp, spv::GroupNonUniformSMinOp>,
519 &createGroupReduceOpImpl<spv::GroupFMinOp, spv::GroupNonUniformFMinOp>},
521 &createGroupReduceOpImpl<spv::GroupSMaxOp, spv::GroupNonUniformSMaxOp>,
522 &createGroupReduceOpImpl<spv::GroupFMaxOp, spv::GroupNonUniformFMaxOp>},
525 for (
auto &handler : handlers)
526 if (handler.type == opType)
527 return (handler.*handlerPtr)(builder, loc, arg, isGroup, isUniform);
541 auto opType = op.getOp();
550 true, op.getUniform());
568 auto opType = op.getOp();
571 false, op.getUniform());
587 GPUBarrierConversion, GPUFuncOpConversion, GPUModuleConversion,
588 GPUModuleEndConversion, GPUReturnOpConversion, GPUShuffleConversion,
589 LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
590 LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
591 LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>,
592 LaunchConfigConversion<gpu::ThreadIdOp,
593 spirv::BuiltIn::LocalInvocationId>,
594 LaunchConfigConversion<gpu::GlobalIdOp,
595 spirv::BuiltIn::GlobalInvocationId>,
596 SingleDimLaunchConfigConversion<gpu::SubgroupIdOp,
597 spirv::BuiltIn::SubgroupId>,
598 SingleDimLaunchConfigConversion<gpu::NumSubgroupsOp,
599 spirv::BuiltIn::NumSubgroups>,
600 SingleDimLaunchConfigConversion<gpu::SubgroupSizeOp,
601 spirv::BuiltIn::SubgroupSize>,
static Value createGroupReduceOpImpl(OpBuilder &builder, Location loc, Value arg, bool isGroup, bool isUniform)
static LogicalResult getDefaultABIAttrs(const spirv::TargetEnv &targetEnv, gpu::GPUFuncOp funcOp, SmallVectorImpl< spirv::InterfaceVarABIAttr > &argABI)
Populates argABI with spirv.interface_var_abi attributes for lowering gpu.func to spirv....
static std::optional< Value > createGroupReduceOp(OpBuilder &builder, Location loc, Value arg, gpu::AllReduceOperation opType, bool isGroup, bool isUniform)
static constexpr const char kSPIRVModule[]
static spirv::FuncOp lowerAsEntryFunction(gpu::GPUFuncOp funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter, spirv::EntryPointABIAttr entryPointInfo, ArrayRef< spirv::InterfaceVarABIAttr > argABIInfo)
static MLIRContext * getContext(OpFoldResult val)
Pattern to convert a gpu.all_reduce op into a SPIR-V group op.
LogicalResult matchAndRewrite(gpu::AllReduceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
Pattern to convert a gpu.subgroup_reduce op into a SPIR-V group op.
LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
IntegerType getIntegerType(unsigned width)
StringAttr getStringAttr(const Twine &bytes)
MLIRContext * getContext() const
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Convert the types of block arguments within the given region.
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before) override
PatternRewriter hook for moving blocks out of a region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void eraseBlock(Block *block) override
PatternRewriter hook for erasing all operations in a block.
This class provides support for representing a failure result, or a valid value of type T.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
typename SourceOp::Adaptor OpAdaptor
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the results of the given op with those of a new op that is created without verification.
Type conversion from builtin types to SPIR-V types for shader interface.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
ArrayRef< T > asArrayRef() const
An attribute that specifies the information regarding the interface variable: descriptor set,...
An attribute that specifies the target version, allowed extensions and capabilities,...
ResourceLimitsAttr getResourceLimits() const
Returns the target resource limits.
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabiliti...
TargetEnvAttr getAttr() const
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
StringRef getInterfaceVarABIAttrName()
Returns the attribute name for specifying argument ABI information.
bool needsInterfaceVarABIAttrs(TargetEnvAttr targetAttr)
Returns whether the given SPIR-V target (described by TargetEnvAttr) needs ABI attributes for interfa...
InterfaceVarABIAttr getInterfaceVarABIAttr(unsigned descriptorSet, unsigned binding, std::optional< StorageClass > storageClass, MLIRContext *context)
Gets the InterfaceVarABIAttr given its fields.
Value getBuiltinVariableValue(Operation *op, BuiltIn builtin, Type integerType, OpBuilder &builder, StringRef prefix="__builtin__", StringRef suffix="__")
Returns the value for the given builtin variable.
EntryPointABIAttr lookupEntryPointABI(Operation *op)
Queries the entry point ABI on the nearest function-like op containing the given op.
StringRef getTargetEnvAttrName()
Returns the attribute name for specifying SPIR-V target environment.
DenseI32ArrayAttr lookupLocalWorkGroupSize(Operation *op)
Queries the local workgroup size from entry point ABI on the nearest function-like op containing the ...
AddressingModel getAddressingModel(TargetEnvAttr targetAttr, bool use64bitAddress)
Returns addressing model selected based on target environment.
FailureOr< MemoryModel > getMemoryModel(TargetEnvAttr targetAttr)
Returns memory model selected based on target environment.
StringRef getEntryPointABIAttrName()
Returns the attribute name for specifying entry point information.
This header declares functions that assist transformations in the MemRef dialect.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
void populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating GPU Ops to SPIR-V ops.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.