32 template <
typename SourceOp, spirv::BuiltIn builtin>
38 matchAndRewrite(SourceOp op,
typename SourceOp::Adaptor adaptor,
44 template <
typename SourceOp, spirv::BuiltIn builtin>
50 matchAndRewrite(SourceOp op,
typename SourceOp::Adaptor adaptor,
65 matchAndRewrite(gpu::BlockDimOp op, OpAdaptor adaptor,
75 matchAndRewrite(gpu::GPUFuncOp funcOp, OpAdaptor adaptor,
88 matchAndRewrite(gpu::GPUModuleOp moduleOp, OpAdaptor adaptor,
99 matchAndRewrite(gpu::ReturnOp returnOp, OpAdaptor adaptor,
109 matchAndRewrite(gpu::BarrierOp barrierOp, OpAdaptor adaptor,
119 matchAndRewrite(gpu::ShuffleOp shuffleOp, OpAdaptor adaptor,
129 matchAndRewrite(gpu::RotateOp rotateOp, OpAdaptor adaptor,
138 matchAndRewrite(gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor,
148 template <
typename SourceOp, spirv::BuiltIn builtin>
149 LogicalResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
150 SourceOp op,
typename SourceOp::Adaptor adaptor,
152 auto *typeConverter = this->
template getTypeConverter<SPIRVTypeConverter>();
153 Type indexType = typeConverter->getIndexType();
167 typeConverter->getTargetEnv().allows(spirv::Capability::Shader);
172 Value dim = rewriter.
create<spirv::CompositeExtractOp>(
173 op.getLoc(), builtinType, vector,
175 if (forShader && builtinType != indexType)
176 dim = rewriter.
create<spirv::UConvertOp>(op.getLoc(), indexType, dim);
181 template <
typename SourceOp, spirv::BuiltIn builtin>
183 SingleDimLaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
184 SourceOp op,
typename SourceOp::Adaptor adaptor,
186 auto *typeConverter = this->
template getTypeConverter<SPIRVTypeConverter>();
187 Type indexType = typeConverter->getIndexType();
200 if (i32Type != indexType)
201 builtinValue = rewriter.
create<spirv::UConvertOp>(op.getLoc(), indexType,
207 LogicalResult WorkGroupSizeConversion::matchAndRewrite(
208 gpu::BlockDimOp op, OpAdaptor adaptor,
211 if (!workGroupSizeAttr)
215 workGroupSizeAttr.
asArrayRef()[
static_cast<int32_t
>(op.getDimension())];
217 getTypeConverter()->convertType(op.getResult().getType());
233 spirv::EntryPointABIAttr entryPointInfo,
235 auto fnType = funcOp.getFunctionType();
236 if (fnType.getNumResults()) {
237 funcOp.emitError(
"SPIR-V lowering only supports entry functions"
238 "with no return values right now");
241 if (!argABIInfo.empty() && fnType.getNumInputs() != argABIInfo.size()) {
243 "lowering as entry functions requires ABI info for all arguments "
252 for (
const auto &argType :
253 enumerate(funcOp.getFunctionType().getInputs())) {
254 auto convertedType = typeConverter.
convertType(argType.value());
257 signatureConverter.
addInputs(argType.index(), convertedType);
260 auto newFuncOp = rewriter.
create<spirv::FuncOp>(
261 funcOp.getLoc(), funcOp.getName(),
263 for (
const auto &namedAttr : funcOp->getAttrs()) {
264 if (namedAttr.getName() == funcOp.getFunctionTypeAttrName() ||
267 newFuncOp->
setAttr(namedAttr.getName(), namedAttr.getValue());
273 &signatureConverter)))
279 for (
auto argIndex : llvm::seq<unsigned>(0, argABIInfo.size())) {
280 newFuncOp.setArgAttr(argIndex, argABIAttrName, argABIInfo[argIndex]);
296 for (
auto argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
302 std::optional<spirv::StorageClass> sc;
303 if (funcOp.getArgument(argIndex).getType().isIntOrIndexOrFloat())
304 sc = spirv::StorageClass::StorageBuffer;
311 LogicalResult GPUFuncOpConversion::matchAndRewrite(
312 gpu::GPUFuncOp funcOp, OpAdaptor adaptor,
314 if (!gpu::GPUDialect::isKernel(funcOp))
317 auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
322 for (
auto argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
328 "match failure: missing 'spirv.interface_var_abi' attribute at "
333 argABI.push_back(abiAttr);
338 if (!entryPointAttr) {
340 "match failure: missing 'spirv.entry_point_abi' attribute");
344 funcOp, *getTypeConverter(), rewriter, entryPointAttr, argABI);
347 newFuncOp->removeAttr(
348 rewriter.
getStringAttr(gpu::GPUDialect::getKernelFuncAttrName()));
356 LogicalResult GPUModuleConversion::matchAndRewrite(
357 gpu::GPUModuleOp moduleOp, OpAdaptor adaptor,
359 auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
362 targetEnv, typeConverter->getOptions().use64bitIndex);
364 if (failed(memoryModel))
365 return moduleOp.emitRemark(
366 "cannot deduce memory model from 'spirv.target_env'");
369 std::string spvModuleName = (
kSPIRVModule + moduleOp.getName()).str();
370 auto spvModule = rewriter.
create<spirv::ModuleOp>(
371 moduleOp.getLoc(), addressingModel, *memoryModel, std::nullopt,
372 StringRef(spvModuleName));
375 Region &spvModuleRegion = spvModule.getRegion();
377 spvModuleRegion.begin());
397 LogicalResult GPUReturnOpConversion::matchAndRewrite(
398 gpu::ReturnOp returnOp, OpAdaptor adaptor,
400 if (!adaptor.getOperands().empty())
411 LogicalResult GPUBarrierConversion::matchAndRewrite(
412 gpu::BarrierOp barrierOp, OpAdaptor adaptor,
419 context, spirv::MemorySemantics::WorkgroupMemory |
420 spirv::MemorySemantics::AcquireRelease);
430 LogicalResult GPUShuffleConversion::matchAndRewrite(
431 gpu::ShuffleOp shuffleOp, OpAdaptor adaptor,
436 auto targetEnv = getTypeConverter<SPIRVTypeConverter>()->getTargetEnv();
439 IntegerAttr widthAttr;
443 shuffleOp,
"shuffle width and target subgroup size mismatch");
445 assert(!adaptor.getOffset().getType().isSignedInteger() &&
446 "shuffle offset must be a signless/unsigned integer");
449 auto scope = rewriter.
getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
453 switch (shuffleOp.getMode()) {
454 case gpu::ShuffleMode::XOR: {
455 result = rewriter.
create<spirv::GroupNonUniformShuffleXorOp>(
456 loc, scope, adaptor.getValue(), adaptor.getOffset());
457 validVal = spirv::ConstantOp::getOne(rewriter.
getI1Type(),
458 shuffleOp.getLoc(), rewriter);
461 case gpu::ShuffleMode::IDX: {
462 result = rewriter.
create<spirv::GroupNonUniformShuffleOp>(
463 loc, scope, adaptor.getValue(), adaptor.getOffset());
464 validVal = spirv::ConstantOp::getOne(rewriter.
getI1Type(),
465 shuffleOp.getLoc(), rewriter);
468 case gpu::ShuffleMode::DOWN: {
469 result = rewriter.
create<spirv::GroupNonUniformShuffleDownOp>(
470 loc, scope, adaptor.getValue(), adaptor.getOffset());
472 Value laneId = rewriter.
create<gpu::LaneIdOp>(loc, widthAttr);
474 rewriter.
create<arith::AddIOp>(loc, laneId, adaptor.getOffset());
475 validVal = rewriter.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
476 resultLaneId, adaptor.getWidth());
479 case gpu::ShuffleMode::UP: {
480 result = rewriter.
create<spirv::GroupNonUniformShuffleUpOp>(
481 loc, scope, adaptor.getValue(), adaptor.getOffset());
483 Value laneId = rewriter.
create<gpu::LaneIdOp>(loc, widthAttr);
485 rewriter.
create<arith::SubIOp>(loc, laneId, adaptor.getOffset());
487 validVal = rewriter.
create<arith::CmpIOp>(
488 loc, arith::CmpIPredicate::sge, resultLaneId,
489 rewriter.
create<arith::ConstantOp>(
495 rewriter.
replaceOp(shuffleOp, {result, validVal});
503 LogicalResult GPURotateConversion::matchAndRewrite(
504 gpu::RotateOp rotateOp, OpAdaptor adaptor,
507 getTypeConverter<SPIRVTypeConverter>()->getTargetEnv();
510 IntegerAttr widthAttr;
515 "rotate width is not a constant or larger than target subgroup size");
518 auto scope = rewriter.
getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
519 Value rotateResult = rewriter.
create<spirv::GroupNonUniformRotateKHROp>(
520 loc, scope, adaptor.getValue(), adaptor.getOffset(), adaptor.getWidth());
522 if (widthAttr.getValue().getZExtValue() ==
subgroupSize) {
523 validVal = spirv::ConstantOp::getOne(rewriter.
getI1Type(), loc, rewriter);
525 Value laneId = rewriter.
create<gpu::LaneIdOp>(loc, widthAttr);
526 validVal = rewriter.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
527 laneId, adaptor.getWidth());
530 rewriter.
replaceOp(rotateOp, {rotateResult, validVal});
538 template <
typename UniformOp,
typename NonUniformOp>
540 Value arg,
bool isGroup,
bool isUniform,
541 std::optional<uint32_t> clusterSize) {
544 isGroup ? spirv::Scope::Workgroup
545 : spirv::Scope::Subgroup);
548 ? spirv::GroupOperation::ClusteredReduce
549 : spirv::GroupOperation::Reduce);
551 return builder.
create<UniformOp>(loc, type, scope, groupOp, arg)
555 Value clusterSizeValue;
556 if (clusterSize.has_value())
557 clusterSizeValue = builder.
create<spirv::ConstantOp>(
562 .
create<NonUniformOp>(loc, type, scope, groupOp, arg, clusterSizeValue)
566 static std::optional<Value>
568 gpu::AllReduceOperation opType,
bool isGroup,
569 bool isUniform, std::optional<uint32_t> clusterSize) {
570 enum class ElemType { Float, Boolean, Integer };
572 std::optional<uint32_t>);
574 gpu::AllReduceOperation
kind;
580 ElemType elementType;
581 if (isa<FloatType>(type)) {
582 elementType = ElemType::Float;
583 }
else if (
auto intTy = dyn_cast<IntegerType>(type)) {
584 elementType = (intTy.getIntOrFloatBitWidth() == 1) ? ElemType::Boolean
595 using ReduceType = gpu::AllReduceOperation;
596 const OpHandler handlers[] = {
597 {ReduceType::ADD, ElemType::Integer,
599 spirv::GroupNonUniformIAddOp>},
600 {ReduceType::ADD, ElemType::Float,
602 spirv::GroupNonUniformFAddOp>},
603 {ReduceType::MUL, ElemType::Integer,
605 spirv::GroupNonUniformIMulOp>},
606 {ReduceType::MUL, ElemType::Float,
608 spirv::GroupNonUniformFMulOp>},
611 spirv::GroupNonUniformUMinOp>},
612 {ReduceType::MINSI, ElemType::Integer,
614 spirv::GroupNonUniformSMinOp>},
615 {ReduceType::MINNUMF, ElemType::Float,
617 spirv::GroupNonUniformFMinOp>},
618 {ReduceType::MAXUI, ElemType::Integer,
620 spirv::GroupNonUniformUMaxOp>},
621 {ReduceType::MAXSI, ElemType::Integer,
623 spirv::GroupNonUniformSMaxOp>},
624 {ReduceType::MAXNUMF, ElemType::Float,
626 spirv::GroupNonUniformFMaxOp>},
627 {ReduceType::MINIMUMF, ElemType::Float,
629 spirv::GroupNonUniformFMinOp>},
630 {ReduceType::MAXIMUMF, ElemType::Float,
632 spirv::GroupNonUniformFMaxOp>}};
634 for (
const OpHandler &handler : handlers)
635 if (handler.kind == opType && elementType == handler.elemType)
636 return handler.func(builder, loc, arg, isGroup, isUniform, clusterSize);
650 auto opType = op.getOp();
659 true, op.getUniform(), std::nullopt);
677 if (op.getClusterStride() > 1) {
679 op,
"lowering for cluster stride > 1 is not implemented");
682 if (!isa<spirv::ScalarType>(adaptor.getValue().getType()))
686 rewriter, op.getLoc(), adaptor.getValue(), adaptor.getOp(),
687 false, adaptor.getUniform(), op.getClusterSize());
700 static std::string
makeVarName(spirv::ModuleOp moduleOp, llvm::Twine prefix) {
706 name = (prefix + llvm::Twine(number++)).str();
707 }
while (moduleOp.lookupSymbol(name));
714 LogicalResult GPUPrintfConversion::matchAndRewrite(
715 gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor,
718 Location loc = gpuPrintfOp.getLoc();
720 auto moduleOp = gpuPrintfOp->getParentOfType<spirv::ModuleOp>();
727 std::string globalVarName =
makeVarName(moduleOp, llvm::Twine(
"printfMsg"));
728 spirv::GlobalVariableOp globalVar;
730 IntegerType i8Type = rewriter.
getI8Type();
738 auto createSpecConstant = [&](
unsigned value) {
740 std::string specCstName =
741 makeVarName(moduleOp, llvm::Twine(globalVarName) +
"_sc");
743 return rewriter.
create<spirv::SpecConstantOp>(
750 ConversionPatternRewriter::InsertionGuard guard(rewriter);
761 formatString.push_back(
'\0');
763 for (
char c : formatString) {
764 spirv::SpecConstantOp cSpecConstantOp = createSpecConstant(c);
769 size_t contentSize = constituents.size();
771 spirv::SpecConstantCompositeOp specCstComposite;
774 std::string specCstCompositeName =
775 (llvm::Twine(globalVarName) +
"_scc").str();
777 specCstComposite = rewriter.
create<spirv::SpecConstantCompositeOp>(
783 globalType, spirv::StorageClass::UniformConstant);
788 globalVar = rewriter.
create<spirv::GlobalVariableOp>(
795 Value globalPtr = rewriter.
create<spirv::AddressOfOp>(loc, globalVar);
802 auto printfArgs = llvm::to_vector_of<Value, 4>(adaptor.getArgs());
804 rewriter.
create<spirv::CLPrintfOp>(loc, i32Type, fmtStr, printfArgs);
821 GPUBarrierConversion, GPUFuncOpConversion, GPUModuleConversion,
822 GPUReturnOpConversion, GPUShuffleConversion, GPURotateConversion,
823 LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
824 LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
825 LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>,
826 LaunchConfigConversion<gpu::ThreadIdOp,
827 spirv::BuiltIn::LocalInvocationId>,
828 LaunchConfigConversion<gpu::GlobalIdOp,
829 spirv::BuiltIn::GlobalInvocationId>,
830 SingleDimLaunchConfigConversion<gpu::SubgroupIdOp,
831 spirv::BuiltIn::SubgroupId>,
832 SingleDimLaunchConfigConversion<gpu::NumSubgroupsOp,
833 spirv::BuiltIn::NumSubgroups>,
834 SingleDimLaunchConfigConversion<gpu::SubgroupSizeOp,
835 spirv::BuiltIn::SubgroupSize>,
836 SingleDimLaunchConfigConversion<
837 gpu::LaneIdOp, spirv::BuiltIn::SubgroupLocalInvocationId>,
static LogicalResult getDefaultABIAttrs(const spirv::TargetEnv &targetEnv, gpu::GPUFuncOp funcOp, SmallVectorImpl< spirv::InterfaceVarABIAttr > &argABI)
Populates argABI with spirv.interface_var_abi attributes for lowering gpu.func to spirv....
static constexpr const char kSPIRVModule[]
static std::optional< Value > createGroupReduceOp(OpBuilder &builder, Location loc, Value arg, gpu::AllReduceOperation opType, bool isGroup, bool isUniform, std::optional< uint32_t > clusterSize)
static Value createGroupReduceOpImpl(OpBuilder &builder, Location loc, Value arg, bool isGroup, bool isUniform, std::optional< uint32_t > clusterSize)
static spirv::FuncOp lowerAsEntryFunction(gpu::GPUFuncOp funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter, spirv::EntryPointABIAttr entryPointInfo, ArrayRef< spirv::InterfaceVarABIAttr > argABIInfo)
static std::string makeVarName(spirv::ModuleOp moduleOp, llvm::Twine prefix)
static MLIRContext * getContext(OpFoldResult val)
union mlir::linalg::@1221::ArityGroupAndKind::Kind kind
Pattern to convert a gpu.all_reduce op into a SPIR-V group op.
LogicalResult matchAndRewrite(gpu::AllReduceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
Methods that operate on the SourceOp type.
Pattern to convert a gpu.subgroup_reduce op into a SPIR-V group op.
LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
Methods that operate on the SourceOp type.
Block represents an ordered list of Operations.
IntegerAttr getIntegerAttr(Type type, int64_t value)
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
IntegerType getIntegerType(unsigned width)
StringAttr getStringAttr(const Twine &bytes)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
IntegerAttr getI8IntegerAttr(int8_t value)
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void eraseBlock(Block *block) override
PatternRewriter hook for erase all operations in a block.
static FlatSymbolRefAttr get(StringAttr value)
Construct a symbol reference for the given value name.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Type conversion from builtin types to SPIR-V types for shader interface.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
ArrayRef< T > asArrayRef() const
static ArrayType get(Type elementType, unsigned elementCount)
An attribute that specifies the information regarding the interface variable: descriptor set,...
static PointerType get(Type pointeeType, StorageClass storageClass)
An attribute that specifies the target version, allowed extensions and capabilities,...
ResourceLimitsAttr getResourceLimits() const
Returns the target resource limits.
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabiliti...
TargetEnvAttr getAttr() const
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
StringRef getInterfaceVarABIAttrName()
Returns the attribute name for specifying argument ABI information.
bool needsInterfaceVarABIAttrs(TargetEnvAttr targetAttr)
Returns whether the given SPIR-V target (described by TargetEnvAttr) needs ABI attributes for interfa...
InterfaceVarABIAttr getInterfaceVarABIAttr(unsigned descriptorSet, unsigned binding, std::optional< StorageClass > storageClass, MLIRContext *context)
Gets the InterfaceVarABIAttr given its fields.
Value getBuiltinVariableValue(Operation *op, BuiltIn builtin, Type integerType, OpBuilder &builder, StringRef prefix="__builtin__", StringRef suffix="__")
Returns the value for the given builtin variable.
EntryPointABIAttr lookupEntryPointABI(Operation *op)
Queries the entry point ABI on the nearest function-like op containing the given op.
StringRef getTargetEnvAttrName()
Returns the attribute name for specifying SPIR-V target environment.
DenseI32ArrayAttr lookupLocalWorkGroupSize(Operation *op)
Queries the local workgroup size from entry point ABI on the nearest function-like op containing the ...
AddressingModel getAddressingModel(TargetEnvAttr targetAttr, bool use64bitAddress)
Returns addressing model selected based on target environment.
FailureOr< MemoryModel > getMemoryModel(TargetEnvAttr targetAttr)
Returns memory model selected based on target environment.
StringRef getEntryPointABIAttrName()
Returns the attribute name for specifying entry point information.
constexpr unsigned subgroupSize
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
const FrozenRewritePatternSet & patterns
void populateGPUToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating GPU Ops to SPIR-V ops.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.