35 template <
typename SourceOp, spirv::BuiltIn builtin>
41 matchAndRewrite(SourceOp op,
typename SourceOp::Adaptor adaptor,
47 template <
typename SourceOp, spirv::BuiltIn builtin>
53 matchAndRewrite(SourceOp op,
typename SourceOp::Adaptor adaptor,
67 matchAndRewrite(gpu::BlockDimOp op, OpAdaptor adaptor,
77 matchAndRewrite(gpu::GPUFuncOp funcOp, OpAdaptor adaptor,
90 matchAndRewrite(gpu::GPUModuleOp moduleOp, OpAdaptor adaptor,
94 class GPUModuleEndConversion final
100 matchAndRewrite(gpu::ModuleEndOp endOp, OpAdaptor adaptor,
114 matchAndRewrite(gpu::ReturnOp returnOp, OpAdaptor adaptor,
124 matchAndRewrite(gpu::BarrierOp barrierOp, OpAdaptor adaptor,
134 matchAndRewrite(gpu::ShuffleOp shuffleOp, OpAdaptor adaptor,
144 template <
typename SourceOp, spirv::BuiltIn builtin>
145 LogicalResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
146 SourceOp op,
typename SourceOp::Adaptor adaptor,
148 auto *typeConverter = this->
template getTypeConverter<SPIRVTypeConverter>();
149 Type indexType = typeConverter->getIndexType();
163 typeConverter->getTargetEnv().allows(spirv::Capability::Shader);
168 Value dim = rewriter.
create<spirv::CompositeExtractOp>(
169 op.
getLoc(), builtinType, vector,
171 if (forShader && builtinType != indexType)
172 dim = rewriter.
create<spirv::UConvertOp>(op.
getLoc(), indexType, dim);
177 template <
typename SourceOp, spirv::BuiltIn builtin>
179 SingleDimLaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
180 SourceOp op,
typename SourceOp::Adaptor adaptor,
182 auto *typeConverter = this->
template getTypeConverter<SPIRVTypeConverter>();
183 Type indexType = typeConverter->getIndexType();
196 if (i32Type != indexType)
197 builtinValue = rewriter.
create<spirv::UConvertOp>(op.
getLoc(), indexType,
204 gpu::BlockDimOp op, OpAdaptor adaptor,
207 if (!workGroupSizeAttr)
211 workGroupSizeAttr.
asArrayRef()[
static_cast<int32_t
>(op.getDimension())];
229 spirv::EntryPointABIAttr entryPointInfo,
231 auto fnType = funcOp.getFunctionType();
232 if (fnType.getNumResults()) {
233 funcOp.emitError(
"SPIR-V lowering only supports entry functions"
234 "with no return values right now");
237 if (!argABIInfo.empty() && fnType.getNumInputs() != argABIInfo.size()) {
239 "lowering as entry functions requires ABI info for all arguments "
248 for (
const auto &argType :
249 enumerate(funcOp.getFunctionType().getInputs())) {
250 auto convertedType = typeConverter.
convertType(argType.value());
253 signatureConverter.
addInputs(argType.index(), convertedType);
256 auto newFuncOp = rewriter.
create<spirv::FuncOp>(
257 funcOp.getLoc(), funcOp.getName(),
260 for (
const auto &namedAttr : funcOp->getAttrs()) {
261 if (namedAttr.getName() == funcOp.getFunctionTypeAttrName() ||
264 newFuncOp->
setAttr(namedAttr.getName(), namedAttr.getValue());
270 &signatureConverter)))
276 for (
auto argIndex : llvm::seq<unsigned>(0, argABIInfo.size())) {
277 newFuncOp.setArgAttr(argIndex, argABIAttrName, argABIInfo[argIndex]);
293 for (
auto argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
299 std::optional<spirv::StorageClass> sc;
300 if (funcOp.getArgument(argIndex).getType().isIntOrIndexOrFloat())
301 sc = spirv::StorageClass::StorageBuffer;
309 gpu::GPUFuncOp funcOp, OpAdaptor adaptor,
311 if (!gpu::GPUDialect::isKernel(funcOp))
314 auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
319 for (
auto argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
325 "match failure: missing 'spirv.interface_var_abi' attribute at "
330 argABI.push_back(abiAttr);
335 if (!entryPointAttr) {
337 "match failure: missing 'spirv.entry_point_abi' attribute");
341 funcOp, *getTypeConverter(), rewriter, entryPointAttr, argABI);
344 newFuncOp->removeAttr(
345 rewriter.
getStringAttr(gpu::GPUDialect::getKernelFuncAttrName()));
354 gpu::GPUModuleOp moduleOp, OpAdaptor adaptor,
356 auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
359 targetEnv, typeConverter->getOptions().use64bitIndex);
362 return moduleOp.emitRemark(
363 "cannot deduce memory model from 'spirv.target_env'");
366 std::string spvModuleName = (
kSPIRVModule + moduleOp.getName()).str();
367 auto spvModule = rewriter.
create<spirv::ModuleOp>(
368 moduleOp.getLoc(), addressingModel, *memoryModel, std::nullopt,
369 StringRef(spvModuleName));
372 Region &spvModuleRegion = spvModule.getRegion();
374 spvModuleRegion.begin());
395 gpu::ReturnOp returnOp, OpAdaptor adaptor,
397 if (!adaptor.getOperands().empty())
409 gpu::BarrierOp barrierOp, OpAdaptor adaptor,
416 context, spirv::MemorySemantics::WorkgroupMemory |
417 spirv::MemorySemantics::AcquireRelease);
428 gpu::ShuffleOp shuffleOp, OpAdaptor adaptor,
433 auto targetEnv = getTypeConverter<SPIRVTypeConverter>()->getTargetEnv();
434 unsigned subgroupSize =
436 IntegerAttr widthAttr;
438 widthAttr.getValue().getZExtValue() != subgroupSize)
440 shuffleOp,
"shuffle width and target subgroup size mismatch");
444 shuffleOp.getLoc(), rewriter);
445 auto scope = rewriter.
getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
448 switch (shuffleOp.getMode()) {
449 case gpu::ShuffleMode::XOR:
450 result = rewriter.
create<spirv::GroupNonUniformShuffleXorOp>(
451 loc, scope, adaptor.getValue(), adaptor.getOffset());
453 case gpu::ShuffleMode::IDX:
454 result = rewriter.
create<spirv::GroupNonUniformShuffleOp>(
455 loc, scope, adaptor.getValue(), adaptor.getOffset());
461 rewriter.
replaceOp(shuffleOp, {result, trueVal});
469 template <
typename UniformOp,
typename NonUniformOp>
471 Value arg,
bool isGroup,
bool isUniform) {
474 isGroup ? spirv::Scope::Workgroup
475 : spirv::Scope::Subgroup);
477 spirv::GroupOperation::Reduce);
479 return builder.
create<UniformOp>(loc, type, scope, groupOp, arg)
482 return builder.
create<NonUniformOp>(loc, type, scope, groupOp, arg,
Value{})
488 gpu::AllReduceOperation opType,
489 bool isGroup,
bool isUniform) {
490 enum class ElemType { Float, Boolean, Integer };
493 gpu::AllReduceOperation kind;
499 ElemType elementType;
500 if (isa<FloatType>(type)) {
501 elementType = ElemType::Float;
502 }
else if (
auto intTy = dyn_cast<IntegerType>(type)) {
503 elementType = (intTy.getIntOrFloatBitWidth() == 1) ? ElemType::Boolean
514 using ReduceType = gpu::AllReduceOperation;
515 const OpHandler handlers[] = {
516 {ReduceType::ADD, ElemType::Integer,
518 spirv::GroupNonUniformIAddOp>},
519 {ReduceType::ADD, ElemType::Float,
521 spirv::GroupNonUniformFAddOp>},
522 {ReduceType::MUL, ElemType::Integer,
524 spirv::GroupNonUniformIMulOp>},
525 {ReduceType::MUL, ElemType::Float,
527 spirv::GroupNonUniformFMulOp>},
530 spirv::GroupNonUniformUMinOp>},
531 {ReduceType::MINSI, ElemType::Integer,
533 spirv::GroupNonUniformSMinOp>},
534 {ReduceType::MINNUMF, ElemType::Float,
536 spirv::GroupNonUniformFMinOp>},
537 {ReduceType::MAXUI, ElemType::Integer,
539 spirv::GroupNonUniformUMaxOp>},
540 {ReduceType::MAXSI, ElemType::Integer,
542 spirv::GroupNonUniformSMaxOp>},
543 {ReduceType::MAXNUMF, ElemType::Float,
545 spirv::GroupNonUniformFMaxOp>},
546 {ReduceType::MINIMUMF, ElemType::Float,
548 spirv::GroupNonUniformFMinOp>},
549 {ReduceType::MAXIMUMF, ElemType::Float,
551 spirv::GroupNonUniformFMaxOp>}};
553 for (
const OpHandler &handler : handlers)
554 if (handler.kind == opType && elementType == handler.elemType)
555 return handler.func(builder, loc, arg, isGroup, isUniform);
569 auto opType = op.getOp();
578 true, op.getUniform());
596 if (!isa<spirv::ScalarType>(adaptor.getValue().getType()))
601 false, adaptor.getUniform());
617 GPUBarrierConversion, GPUFuncOpConversion, GPUModuleConversion,
618 GPUModuleEndConversion, GPUReturnOpConversion, GPUShuffleConversion,
619 LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
620 LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
621 LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>,
622 LaunchConfigConversion<gpu::ThreadIdOp,
623 spirv::BuiltIn::LocalInvocationId>,
624 LaunchConfigConversion<gpu::GlobalIdOp,
625 spirv::BuiltIn::GlobalInvocationId>,
626 SingleDimLaunchConfigConversion<gpu::SubgroupIdOp,
627 spirv::BuiltIn::SubgroupId>,
628 SingleDimLaunchConfigConversion<gpu::NumSubgroupsOp,
629 spirv::BuiltIn::NumSubgroups>,
630 SingleDimLaunchConfigConversion<gpu::SubgroupSizeOp,
631 spirv::BuiltIn::SubgroupSize>,
static Value createGroupReduceOpImpl(OpBuilder &builder, Location loc, Value arg, bool isGroup, bool isUniform)
static LogicalResult getDefaultABIAttrs(const spirv::TargetEnv &targetEnv, gpu::GPUFuncOp funcOp, SmallVectorImpl< spirv::InterfaceVarABIAttr > &argABI)
Populates argABI with spirv.interface_var_abi attributes for lowering gpu.func to spirv....
static std::optional< Value > createGroupReduceOp(OpBuilder &builder, Location loc, Value arg, gpu::AllReduceOperation opType, bool isGroup, bool isUniform)
static constexpr const char kSPIRVModule[]
static spirv::FuncOp lowerAsEntryFunction(gpu::GPUFuncOp funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter, spirv::EntryPointABIAttr entryPointInfo, ArrayRef< spirv::InterfaceVarABIAttr > argABIInfo)
static MLIRContext * getContext(OpFoldResult val)
Pattern to convert a gpu.all_reduce op into a SPIR-V group op.
LogicalResult matchAndRewrite(gpu::AllReduceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
Pattern to convert a gpu.subgroup_reduce op into a SPIR-V group op.
LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
IntegerType getIntegerType(unsigned width)
StringAttr getStringAttr(const Twine &bytes)
MLIRContext * getContext() const
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Convert the types of block arguments within the given region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void eraseBlock(Block *block) override
PatternRewriter hook for erasing all operations in a block.
This class provides support for representing a failure result, or a valid value of type T.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
typename SourceOp::Adaptor OpAdaptor
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Type conversion from builtin types to SPIR-V types for shader interface.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
ArrayRef< T > asArrayRef() const
An attribute that specifies the information regarding the interface variable: descriptor set,...
An attribute that specifies the target version, allowed extensions and capabilities,...
ResourceLimitsAttr getResourceLimits() const
Returns the target resource limits.
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabiliti...
TargetEnvAttr getAttr() const
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
StringRef getInterfaceVarABIAttrName()
Returns the attribute name for specifying argument ABI information.
bool needsInterfaceVarABIAttrs(TargetEnvAttr targetAttr)
Returns whether the given SPIR-V target (described by TargetEnvAttr) needs ABI attributes for interfa...
InterfaceVarABIAttr getInterfaceVarABIAttr(unsigned descriptorSet, unsigned binding, std::optional< StorageClass > storageClass, MLIRContext *context)
Gets the InterfaceVarABIAttr given its fields.
Value getBuiltinVariableValue(Operation *op, BuiltIn builtin, Type integerType, OpBuilder &builder, StringRef prefix="__builtin__", StringRef suffix="__")
Returns the value for the given builtin variable.
EntryPointABIAttr lookupEntryPointABI(Operation *op)
Queries the entry point ABI on the nearest function-like op containing the given op.
StringRef getTargetEnvAttrName()
Returns the attribute name for specifying SPIR-V target environment.
DenseI32ArrayAttr lookupLocalWorkGroupSize(Operation *op)
Queries the local workgroup size from entry point ABI on the nearest function-like op containing the ...
AddressingModel getAddressingModel(TargetEnvAttr targetAttr, bool use64bitAddress)
Returns addressing model selected based on target environment.
FailureOr< MemoryModel > getMemoryModel(TargetEnvAttr targetAttr)
Returns memory model selected based on target environment.
StringRef getEntryPointABIAttrName()
Returns the attribute name for specifying entry point information.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
void populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating GPU Ops to SPIR-V ops.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.