34 template <
typename SourceOp, spirv::BuiltIn builtin>
40 matchAndRewrite(SourceOp op,
typename SourceOp::Adaptor adaptor,
46 template <
typename SourceOp, spirv::BuiltIn builtin>
52 matchAndRewrite(SourceOp op,
typename SourceOp::Adaptor adaptor,
66 matchAndRewrite(gpu::BlockDimOp op, OpAdaptor adaptor,
76 matchAndRewrite(gpu::GPUFuncOp funcOp, OpAdaptor adaptor,
89 matchAndRewrite(gpu::GPUModuleOp moduleOp, OpAdaptor adaptor,
93 class GPUModuleEndConversion final
99 matchAndRewrite(gpu::ModuleEndOp endOp, OpAdaptor adaptor,
113 matchAndRewrite(gpu::ReturnOp returnOp, OpAdaptor adaptor,
123 matchAndRewrite(gpu::BarrierOp barrierOp, OpAdaptor adaptor,
133 matchAndRewrite(gpu::ShuffleOp shuffleOp, OpAdaptor adaptor,
143 template <
typename SourceOp, spirv::BuiltIn builtin>
144 LogicalResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
145 SourceOp op,
typename SourceOp::Adaptor adaptor,
147 auto *typeConverter = this->
template getTypeConverter<SPIRVTypeConverter>();
148 Type indexType = typeConverter->getIndexType();
162 typeConverter->getTargetEnv().allows(spirv::Capability::Shader);
167 Value dim = rewriter.
create<spirv::CompositeExtractOp>(
168 op.
getLoc(), builtinType, vector,
170 if (forShader && builtinType != indexType)
171 dim = rewriter.
create<spirv::UConvertOp>(op.
getLoc(), indexType, dim);
176 template <
typename SourceOp, spirv::BuiltIn builtin>
178 SingleDimLaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
179 SourceOp op,
typename SourceOp::Adaptor adaptor,
181 auto *typeConverter = this->
template getTypeConverter<SPIRVTypeConverter>();
182 Type indexType = typeConverter->getIndexType();
195 if (i32Type != indexType)
196 builtinValue = rewriter.
create<spirv::UConvertOp>(op.
getLoc(), indexType,
202 LogicalResult WorkGroupSizeConversion::matchAndRewrite(
203 gpu::BlockDimOp op, OpAdaptor adaptor,
206 if (!workGroupSizeAttr)
210 workGroupSizeAttr.
asArrayRef()[
static_cast<int32_t
>(op.getDimension())];
228 spirv::EntryPointABIAttr entryPointInfo,
230 auto fnType = funcOp.getFunctionType();
231 if (fnType.getNumResults()) {
232 funcOp.emitError(
"SPIR-V lowering only supports entry functions"
233 "with no return values right now");
236 if (!argABIInfo.empty() && fnType.getNumInputs() != argABIInfo.size()) {
238 "lowering as entry functions requires ABI info for all arguments "
247 for (
const auto &argType :
248 enumerate(funcOp.getFunctionType().getInputs())) {
249 auto convertedType = typeConverter.
convertType(argType.value());
252 signatureConverter.
addInputs(argType.index(), convertedType);
255 auto newFuncOp = rewriter.
create<spirv::FuncOp>(
256 funcOp.getLoc(), funcOp.getName(),
259 for (
const auto &namedAttr : funcOp->getAttrs()) {
260 if (namedAttr.getName() == funcOp.getFunctionTypeAttrName() ||
263 newFuncOp->
setAttr(namedAttr.getName(), namedAttr.getValue());
269 &signatureConverter)))
275 for (
auto argIndex : llvm::seq<unsigned>(0, argABIInfo.size())) {
276 newFuncOp.setArgAttr(argIndex, argABIAttrName, argABIInfo[argIndex]);
292 for (
auto argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
298 std::optional<spirv::StorageClass> sc;
299 if (funcOp.getArgument(argIndex).getType().isIntOrIndexOrFloat())
300 sc = spirv::StorageClass::StorageBuffer;
307 LogicalResult GPUFuncOpConversion::matchAndRewrite(
308 gpu::GPUFuncOp funcOp, OpAdaptor adaptor,
310 if (!gpu::GPUDialect::isKernel(funcOp))
313 auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
318 for (
auto argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
324 "match failure: missing 'spirv.interface_var_abi' attribute at "
329 argABI.push_back(abiAttr);
334 if (!entryPointAttr) {
336 "match failure: missing 'spirv.entry_point_abi' attribute");
340 funcOp, *getTypeConverter(), rewriter, entryPointAttr, argABI);
343 newFuncOp->removeAttr(
344 rewriter.
getStringAttr(gpu::GPUDialect::getKernelFuncAttrName()));
352 LogicalResult GPUModuleConversion::matchAndRewrite(
353 gpu::GPUModuleOp moduleOp, OpAdaptor adaptor,
355 auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
358 targetEnv, typeConverter->getOptions().use64bitIndex);
360 if (failed(memoryModel))
361 return moduleOp.emitRemark(
362 "cannot deduce memory model from 'spirv.target_env'");
365 std::string spvModuleName = (
kSPIRVModule + moduleOp.getName()).str();
366 auto spvModule = rewriter.
create<spirv::ModuleOp>(
367 moduleOp.getLoc(), addressingModel, *memoryModel, std::nullopt,
368 StringRef(spvModuleName));
371 Region &spvModuleRegion = spvModule.getRegion();
373 spvModuleRegion.begin());
393 LogicalResult GPUReturnOpConversion::matchAndRewrite(
394 gpu::ReturnOp returnOp, OpAdaptor adaptor,
396 if (!adaptor.getOperands().empty())
407 LogicalResult GPUBarrierConversion::matchAndRewrite(
408 gpu::BarrierOp barrierOp, OpAdaptor adaptor,
415 context, spirv::MemorySemantics::WorkgroupMemory |
416 spirv::MemorySemantics::AcquireRelease);
426 LogicalResult GPUShuffleConversion::matchAndRewrite(
427 gpu::ShuffleOp shuffleOp, OpAdaptor adaptor,
432 auto targetEnv = getTypeConverter<SPIRVTypeConverter>()->getTargetEnv();
433 unsigned subgroupSize =
435 IntegerAttr widthAttr;
437 widthAttr.getValue().getZExtValue() != subgroupSize)
439 shuffleOp,
"shuffle width and target subgroup size mismatch");
443 shuffleOp.getLoc(), rewriter);
444 auto scope = rewriter.
getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
447 switch (shuffleOp.getMode()) {
448 case gpu::ShuffleMode::XOR:
449 result = rewriter.
create<spirv::GroupNonUniformShuffleXorOp>(
450 loc, scope, adaptor.getValue(), adaptor.getOffset());
452 case gpu::ShuffleMode::IDX:
453 result = rewriter.
create<spirv::GroupNonUniformShuffleOp>(
454 loc, scope, adaptor.getValue(), adaptor.getOffset());
460 rewriter.
replaceOp(shuffleOp, {result, trueVal});
468 template <
typename UniformOp,
typename NonUniformOp>
470 Value arg,
bool isGroup,
bool isUniform) {
473 isGroup ? spirv::Scope::Workgroup
474 : spirv::Scope::Subgroup);
476 spirv::GroupOperation::Reduce);
478 return builder.
create<UniformOp>(loc, type, scope, groupOp, arg)
481 return builder.
create<NonUniformOp>(loc, type, scope, groupOp, arg,
Value{})
487 gpu::AllReduceOperation opType,
488 bool isGroup,
bool isUniform) {
489 enum class ElemType { Float, Boolean, Integer };
492 gpu::AllReduceOperation kind;
498 ElemType elementType;
499 if (isa<FloatType>(type)) {
500 elementType = ElemType::Float;
501 }
else if (
auto intTy = dyn_cast<IntegerType>(type)) {
502 elementType = (intTy.getIntOrFloatBitWidth() == 1) ? ElemType::Boolean
513 using ReduceType = gpu::AllReduceOperation;
514 const OpHandler handlers[] = {
515 {ReduceType::ADD, ElemType::Integer,
517 spirv::GroupNonUniformIAddOp>},
518 {ReduceType::ADD, ElemType::Float,
520 spirv::GroupNonUniformFAddOp>},
521 {ReduceType::MUL, ElemType::Integer,
523 spirv::GroupNonUniformIMulOp>},
524 {ReduceType::MUL, ElemType::Float,
526 spirv::GroupNonUniformFMulOp>},
529 spirv::GroupNonUniformUMinOp>},
530 {ReduceType::MINSI, ElemType::Integer,
532 spirv::GroupNonUniformSMinOp>},
533 {ReduceType::MINNUMF, ElemType::Float,
535 spirv::GroupNonUniformFMinOp>},
536 {ReduceType::MAXUI, ElemType::Integer,
538 spirv::GroupNonUniformUMaxOp>},
539 {ReduceType::MAXSI, ElemType::Integer,
541 spirv::GroupNonUniformSMaxOp>},
542 {ReduceType::MAXNUMF, ElemType::Float,
544 spirv::GroupNonUniformFMaxOp>},
545 {ReduceType::MINIMUMF, ElemType::Float,
547 spirv::GroupNonUniformFMinOp>},
548 {ReduceType::MAXIMUMF, ElemType::Float,
550 spirv::GroupNonUniformFMaxOp>}};
552 for (
const OpHandler &handler : handlers)
553 if (handler.kind == opType && elementType == handler.elemType)
554 return handler.func(builder, loc, arg, isGroup, isUniform);
568 auto opType = op.getOp();
577 true, op.getUniform());
595 if (!isa<spirv::ScalarType>(adaptor.getValue().getType()))
600 false, adaptor.getUniform());
616 GPUBarrierConversion, GPUFuncOpConversion, GPUModuleConversion,
617 GPUModuleEndConversion, GPUReturnOpConversion, GPUShuffleConversion,
618 LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
619 LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
620 LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>,
621 LaunchConfigConversion<gpu::ThreadIdOp,
622 spirv::BuiltIn::LocalInvocationId>,
623 LaunchConfigConversion<gpu::GlobalIdOp,
624 spirv::BuiltIn::GlobalInvocationId>,
625 SingleDimLaunchConfigConversion<gpu::SubgroupIdOp,
626 spirv::BuiltIn::SubgroupId>,
627 SingleDimLaunchConfigConversion<gpu::NumSubgroupsOp,
628 spirv::BuiltIn::NumSubgroups>,
629 SingleDimLaunchConfigConversion<gpu::SubgroupSizeOp,
630 spirv::BuiltIn::SubgroupSize>,
static Value createGroupReduceOpImpl(OpBuilder &builder, Location loc, Value arg, bool isGroup, bool isUniform)
static LogicalResult getDefaultABIAttrs(const spirv::TargetEnv &targetEnv, gpu::GPUFuncOp funcOp, SmallVectorImpl< spirv::InterfaceVarABIAttr > &argABI)
Populates argABI with spirv.interface_var_abi attributes for lowering gpu.func to spirv....
static std::optional< Value > createGroupReduceOp(OpBuilder &builder, Location loc, Value arg, gpu::AllReduceOperation opType, bool isGroup, bool isUniform)
static constexpr const char kSPIRVModule[]
static spirv::FuncOp lowerAsEntryFunction(gpu::GPUFuncOp funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter, spirv::EntryPointABIAttr entryPointInfo, ArrayRef< spirv::InterfaceVarABIAttr > argABIInfo)
static MLIRContext * getContext(OpFoldResult val)
Pattern to convert a gpu.all_reduce op into a SPIR-V group op.
LogicalResult matchAndRewrite(gpu::AllReduceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
Pattern to convert a gpu.subgroup_reduce op into a SPIR-V group op.
LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
IntegerType getIntegerType(unsigned width)
StringAttr getStringAttr(const Twine &bytes)
MLIRContext * getContext() const
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void eraseBlock(Block *block) override
PatternRewriter hook for erase all operations in a block.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
typename SourceOp::Adaptor OpAdaptor
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Type conversion from builtin types to SPIR-V types for shader interface.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
ArrayRef< T > asArrayRef() const
An attribute that specifies the information regarding the interface variable: descriptor set,...
An attribute that specifies the target version, allowed extensions and capabilities,...
ResourceLimitsAttr getResourceLimits() const
Returns the target resource limits.
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabiliti...
TargetEnvAttr getAttr() const
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
StringRef getInterfaceVarABIAttrName()
Returns the attribute name for specifying argument ABI information.
bool needsInterfaceVarABIAttrs(TargetEnvAttr targetAttr)
Returns whether the given SPIR-V target (described by TargetEnvAttr) needs ABI attributes for interfa...
InterfaceVarABIAttr getInterfaceVarABIAttr(unsigned descriptorSet, unsigned binding, std::optional< StorageClass > storageClass, MLIRContext *context)
Gets the InterfaceVarABIAttr given its fields.
Value getBuiltinVariableValue(Operation *op, BuiltIn builtin, Type integerType, OpBuilder &builder, StringRef prefix="__builtin__", StringRef suffix="__")
Returns the value for the given builtin variable.
EntryPointABIAttr lookupEntryPointABI(Operation *op)
Queries the entry point ABI on the nearest function-like op containing the given op.
StringRef getTargetEnvAttrName()
Returns the attribute name for specifying SPIR-V target environment.
DenseI32ArrayAttr lookupLocalWorkGroupSize(Operation *op)
Queries the local workgroup size from entry point ABI on the nearest function-like op containing the ...
AddressingModel getAddressingModel(TargetEnvAttr targetAttr, bool use64bitAddress)
Returns addressing model selected based on target environment.
FailureOr< MemoryModel > getMemoryModel(TargetEnvAttr targetAttr)
Returns memory model selected based on target environment.
StringRef getEntryPointABIAttrName()
Returns the attribute name for specifying entry point information.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
void populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating GPU Ops to SPIR-V ops.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.