32 template <
typename SourceOp, spirv::BuiltIn builtin>
38 matchAndRewrite(SourceOp op,
typename SourceOp::Adaptor adaptor,
44 template <
typename SourceOp, spirv::BuiltIn builtin>
50 matchAndRewrite(SourceOp op,
typename SourceOp::Adaptor adaptor,
65 matchAndRewrite(gpu::BlockDimOp op, OpAdaptor adaptor,
75 matchAndRewrite(gpu::GPUFuncOp funcOp, OpAdaptor adaptor,
88 matchAndRewrite(gpu::GPUModuleOp moduleOp, OpAdaptor adaptor,
99 matchAndRewrite(gpu::ReturnOp returnOp, OpAdaptor adaptor,
109 matchAndRewrite(gpu::BarrierOp barrierOp, OpAdaptor adaptor,
119 matchAndRewrite(gpu::ShuffleOp shuffleOp, OpAdaptor adaptor,
129 matchAndRewrite(gpu::RotateOp rotateOp, OpAdaptor adaptor,
138 matchAndRewrite(gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor,
148 template <
typename SourceOp, spirv::BuiltIn builtin>
149 LogicalResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
150 SourceOp op,
typename SourceOp::Adaptor adaptor,
152 auto *typeConverter = this->
template getTypeConverter<SPIRVTypeConverter>();
153 Type indexType = typeConverter->getIndexType();
167 typeConverter->getTargetEnv().allows(spirv::Capability::Shader);
172 Value dim = spirv::CompositeExtractOp::create(
173 rewriter, op.getLoc(), builtinType, vector,
175 if (forShader && builtinType != indexType)
176 dim = spirv::UConvertOp::create(rewriter, op.getLoc(), indexType, dim);
181 template <
typename SourceOp, spirv::BuiltIn builtin>
183 SingleDimLaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
184 SourceOp op,
typename SourceOp::Adaptor adaptor,
186 auto *typeConverter = this->
template getTypeConverter<SPIRVTypeConverter>();
187 Type indexType = typeConverter->getIndexType();
200 if (i32Type != indexType)
201 builtinValue = spirv::UConvertOp::create(rewriter, op.getLoc(), indexType,
207 LogicalResult WorkGroupSizeConversion::matchAndRewrite(
208 gpu::BlockDimOp op, OpAdaptor adaptor,
211 if (!workGroupSizeAttr)
215 workGroupSizeAttr.
asArrayRef()[
static_cast<int32_t
>(op.getDimension())];
217 getTypeConverter()->convertType(op.getResult().getType());
233 spirv::EntryPointABIAttr entryPointInfo,
235 auto fnType = funcOp.getFunctionType();
236 if (fnType.getNumResults()) {
237 funcOp.emitError(
"SPIR-V lowering only supports entry functions"
238 "with no return values right now");
241 if (!argABIInfo.empty() && fnType.getNumInputs() != argABIInfo.size()) {
243 "lowering as entry functions requires ABI info for all arguments "
252 for (
const auto &argType :
253 enumerate(funcOp.getFunctionType().getInputs())) {
254 auto convertedType = typeConverter.
convertType(argType.value());
257 signatureConverter.
addInputs(argType.index(), convertedType);
260 auto newFuncOp = spirv::FuncOp::create(
261 rewriter, funcOp.getLoc(), funcOp.getName(),
263 for (
const auto &namedAttr : funcOp->getAttrs()) {
264 if (namedAttr.getName() == funcOp.getFunctionTypeAttrName() ||
267 newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
273 &signatureConverter)))
279 for (
auto argIndex : llvm::seq<unsigned>(0, argABIInfo.size())) {
280 newFuncOp.setArgAttr(argIndex, argABIAttrName, argABIInfo[argIndex]);
296 for (
auto argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
302 std::optional<spirv::StorageClass> sc;
303 if (funcOp.getArgument(argIndex).getType().isIntOrIndexOrFloat())
304 sc = spirv::StorageClass::StorageBuffer;
311 LogicalResult GPUFuncOpConversion::matchAndRewrite(
312 gpu::GPUFuncOp funcOp, OpAdaptor adaptor,
314 if (!gpu::GPUDialect::isKernel(funcOp))
317 auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
322 for (
auto argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
328 "match failure: missing 'spirv.interface_var_abi' attribute at "
333 argABI.push_back(abiAttr);
338 if (!entryPointAttr) {
340 "match failure: missing 'spirv.entry_point_abi' attribute");
344 funcOp, *getTypeConverter(), rewriter, entryPointAttr, argABI);
347 newFuncOp->removeAttr(
348 rewriter.
getStringAttr(gpu::GPUDialect::getKernelFuncAttrName()));
356 LogicalResult GPUModuleConversion::matchAndRewrite(
357 gpu::GPUModuleOp moduleOp, OpAdaptor adaptor,
359 auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
362 targetEnv, typeConverter->getOptions().use64bitIndex);
365 return moduleOp.emitRemark(
366 "cannot deduce memory model from 'spirv.target_env'");
369 std::string spvModuleName = (
kSPIRVModule + moduleOp.getName()).str();
370 auto spvModule = spirv::ModuleOp::create(
371 rewriter, moduleOp.getLoc(), addressingModel, *memoryModel, std::nullopt,
372 StringRef(spvModuleName));
375 Region &spvModuleRegion = spvModule.getRegion();
377 spvModuleRegion.
begin());
388 if (ArrayAttr targets = moduleOp.getTargetsAttr()) {
390 if (
auto spirvTargetEnvAttr =
391 dyn_cast<spirv::TargetEnvAttr>(targetAttr)) {
405 LogicalResult GPUReturnOpConversion::matchAndRewrite(
406 gpu::ReturnOp returnOp, OpAdaptor adaptor,
408 if (!adaptor.getOperands().empty())
419 LogicalResult GPUBarrierConversion::matchAndRewrite(
420 gpu::BarrierOp barrierOp, OpAdaptor adaptor,
427 context, spirv::MemorySemantics::WorkgroupMemory |
428 spirv::MemorySemantics::AcquireRelease);
438 LogicalResult GPUShuffleConversion::matchAndRewrite(
439 gpu::ShuffleOp shuffleOp, OpAdaptor adaptor,
444 auto targetEnv = getTypeConverter<SPIRVTypeConverter>()->getTargetEnv();
447 IntegerAttr widthAttr;
451 shuffleOp,
"shuffle width and target subgroup size mismatch");
453 assert(!adaptor.getOffset().getType().isSignedInteger() &&
454 "shuffle offset must be a signless/unsigned integer");
457 auto scope = rewriter.
getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
461 switch (shuffleOp.getMode()) {
462 case gpu::ShuffleMode::XOR: {
463 result = spirv::GroupNonUniformShuffleXorOp::create(
464 rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset());
465 validVal = spirv::ConstantOp::getOne(rewriter.
getI1Type(),
466 shuffleOp.getLoc(), rewriter);
469 case gpu::ShuffleMode::IDX: {
470 result = spirv::GroupNonUniformShuffleOp::create(
471 rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset());
472 validVal = spirv::ConstantOp::getOne(rewriter.
getI1Type(),
473 shuffleOp.getLoc(), rewriter);
476 case gpu::ShuffleMode::DOWN: {
477 result = spirv::GroupNonUniformShuffleDownOp::create(
478 rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset());
480 Value laneId = gpu::LaneIdOp::create(rewriter, loc, widthAttr);
482 arith::AddIOp::create(rewriter, loc, laneId, adaptor.getOffset());
483 validVal = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ult,
484 resultLaneId, adaptor.getWidth());
487 case gpu::ShuffleMode::UP: {
488 result = spirv::GroupNonUniformShuffleUpOp::create(
489 rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset());
491 Value laneId = gpu::LaneIdOp::create(rewriter, loc, widthAttr);
493 arith::SubIOp::create(rewriter, loc, laneId, adaptor.getOffset());
495 validVal = arith::CmpIOp::create(
496 rewriter, loc, arith::CmpIPredicate::sge, resultLaneId,
497 arith::ConstantOp::create(rewriter, loc, i32Type,
503 rewriter.
replaceOp(shuffleOp, {result, validVal});
511 LogicalResult GPURotateConversion::matchAndRewrite(
512 gpu::RotateOp rotateOp, OpAdaptor adaptor,
515 getTypeConverter<SPIRVTypeConverter>()->getTargetEnv();
518 unsigned width = rotateOp.getWidth();
521 rotateOp,
"rotate width is larger than target subgroup size");
524 auto scope = rewriter.
getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
526 arith::ConstantOp::create(rewriter, loc, adaptor.getOffsetAttr());
528 arith::ConstantOp::create(rewriter, loc, adaptor.getWidthAttr());
529 Value rotateResult = spirv::GroupNonUniformRotateKHROp::create(
530 rewriter, loc, scope, adaptor.getValue(), offsetVal, widthVal);
533 validVal = spirv::ConstantOp::getOne(rewriter.
getI1Type(), loc, rewriter);
535 IntegerAttr widthAttr = adaptor.getWidthAttr();
536 Value laneId = gpu::LaneIdOp::create(rewriter, loc, widthAttr);
537 validVal = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ult,
541 rewriter.
replaceOp(rotateOp, {rotateResult, validVal});
549 template <
typename UniformOp,
typename NonUniformOp>
551 Value arg,
bool isGroup,
bool isUniform,
552 std::optional<uint32_t> clusterSize) {
555 isGroup ? spirv::Scope::Workgroup
556 : spirv::Scope::Subgroup);
559 ? spirv::GroupOperation::ClusteredReduce
560 : spirv::GroupOperation::Reduce);
562 return UniformOp::create(builder, loc, type, scope, groupOp, arg)
566 Value clusterSizeValue;
567 if (clusterSize.has_value())
568 clusterSizeValue = spirv::ConstantOp::create(
572 return NonUniformOp::create(builder, loc, type, scope, groupOp, arg,
577 static std::optional<Value>
579 gpu::AllReduceOperation opType,
bool isGroup,
580 bool isUniform, std::optional<uint32_t> clusterSize) {
581 enum class ElemType { Float, Boolean, Integer };
583 std::optional<uint32_t>);
585 gpu::AllReduceOperation
kind;
591 ElemType elementType;
592 if (isa<FloatType>(type)) {
593 elementType = ElemType::Float;
594 }
else if (
auto intTy = dyn_cast<IntegerType>(type)) {
595 elementType = (intTy.getIntOrFloatBitWidth() == 1) ? ElemType::Boolean
606 using ReduceType = gpu::AllReduceOperation;
607 const OpHandler handlers[] = {
608 {ReduceType::ADD, ElemType::Integer,
610 spirv::GroupNonUniformIAddOp>},
611 {ReduceType::ADD, ElemType::Float,
613 spirv::GroupNonUniformFAddOp>},
614 {ReduceType::MUL, ElemType::Integer,
616 spirv::GroupNonUniformIMulOp>},
617 {ReduceType::MUL, ElemType::Float,
619 spirv::GroupNonUniformFMulOp>},
622 spirv::GroupNonUniformUMinOp>},
623 {ReduceType::MINSI, ElemType::Integer,
625 spirv::GroupNonUniformSMinOp>},
626 {ReduceType::MINNUMF, ElemType::Float,
628 spirv::GroupNonUniformFMinOp>},
629 {ReduceType::MAXUI, ElemType::Integer,
631 spirv::GroupNonUniformUMaxOp>},
632 {ReduceType::MAXSI, ElemType::Integer,
634 spirv::GroupNonUniformSMaxOp>},
635 {ReduceType::MAXNUMF, ElemType::Float,
637 spirv::GroupNonUniformFMaxOp>},
638 {ReduceType::MINIMUMF, ElemType::Float,
640 spirv::GroupNonUniformFMinOp>},
641 {ReduceType::MAXIMUMF, ElemType::Float,
643 spirv::GroupNonUniformFMaxOp>}};
645 for (
const OpHandler &handler : handlers)
646 if (handler.kind == opType && elementType == handler.elemType)
647 return handler.func(builder, loc, arg, isGroup, isUniform, clusterSize);
661 auto opType = op.getOp();
670 true, op.getUniform(), std::nullopt);
688 if (op.getClusterStride() > 1) {
690 op,
"lowering for cluster stride > 1 is not implemented");
693 if (!isa<spirv::ScalarType>(adaptor.getValue().getType()))
697 rewriter, op.getLoc(), adaptor.getValue(), adaptor.getOp(),
698 false, adaptor.getUniform(), op.getClusterSize());
711 static std::string
makeVarName(spirv::ModuleOp moduleOp, llvm::Twine prefix) {
717 name = (prefix + llvm::Twine(number++)).str();
718 }
while (moduleOp.lookupSymbol(name));
725 LogicalResult GPUPrintfConversion::matchAndRewrite(
726 gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor,
729 Location loc = gpuPrintfOp.getLoc();
731 auto moduleOp = gpuPrintfOp->getParentOfType<spirv::ModuleOp>();
738 std::string globalVarName =
makeVarName(moduleOp, llvm::Twine(
"printfMsg"));
739 spirv::GlobalVariableOp globalVar;
741 IntegerType i8Type = rewriter.
getI8Type();
749 auto createSpecConstant = [&](
unsigned value) {
751 std::string specCstName =
752 makeVarName(moduleOp, llvm::Twine(globalVarName) +
"_sc");
754 return spirv::SpecConstantOp::create(
761 ConversionPatternRewriter::InsertionGuard guard(rewriter);
772 formatString.push_back(
'\0');
774 for (
char c : formatString) {
775 spirv::SpecConstantOp cSpecConstantOp = createSpecConstant(c);
780 size_t contentSize = constituents.size();
782 spirv::SpecConstantCompositeOp specCstComposite;
785 std::string specCstCompositeName =
786 (llvm::Twine(globalVarName) +
"_scc").str();
788 specCstComposite = spirv::SpecConstantCompositeOp::create(
794 globalType, spirv::StorageClass::UniformConstant);
799 globalVar = spirv::GlobalVariableOp::create(
800 rewriter, loc, ptrType, globalVarName,
803 globalVar->setAttr(
"Constant", rewriter.
getUnitAttr());
807 Value globalPtr = spirv::AddressOfOp::create(rewriter, loc, globalVar);
808 Value fmtStr = spirv::BitcastOp::create(
814 auto printfArgs = llvm::to_vector_of<Value, 4>(adaptor.getArgs());
816 spirv::CLPrintfOp::create(rewriter, loc, i32Type, fmtStr, printfArgs);
833 GPUBarrierConversion, GPUFuncOpConversion, GPUModuleConversion,
834 GPUReturnOpConversion, GPUShuffleConversion, GPURotateConversion,
835 LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
836 LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
837 LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>,
838 LaunchConfigConversion<gpu::ThreadIdOp,
839 spirv::BuiltIn::LocalInvocationId>,
840 LaunchConfigConversion<gpu::GlobalIdOp,
841 spirv::BuiltIn::GlobalInvocationId>,
842 SingleDimLaunchConfigConversion<gpu::SubgroupIdOp,
843 spirv::BuiltIn::SubgroupId>,
844 SingleDimLaunchConfigConversion<gpu::NumSubgroupsOp,
845 spirv::BuiltIn::NumSubgroups>,
846 SingleDimLaunchConfigConversion<gpu::SubgroupSizeOp,
847 spirv::BuiltIn::SubgroupSize>,
848 SingleDimLaunchConfigConversion<
849 gpu::LaneIdOp, spirv::BuiltIn::SubgroupLocalInvocationId>,
static LogicalResult getDefaultABIAttrs(const spirv::TargetEnv &targetEnv, gpu::GPUFuncOp funcOp, SmallVectorImpl< spirv::InterfaceVarABIAttr > &argABI)
Populates argABI with spirv.interface_var_abi attributes for lowering gpu.func to spirv....
static constexpr const char kSPIRVModule[]
static std::optional< Value > createGroupReduceOp(OpBuilder &builder, Location loc, Value arg, gpu::AllReduceOperation opType, bool isGroup, bool isUniform, std::optional< uint32_t > clusterSize)
static Value createGroupReduceOpImpl(OpBuilder &builder, Location loc, Value arg, bool isGroup, bool isUniform, std::optional< uint32_t > clusterSize)
static spirv::FuncOp lowerAsEntryFunction(gpu::GPUFuncOp funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter, spirv::EntryPointABIAttr entryPointInfo, ArrayRef< spirv::InterfaceVarABIAttr > argABIInfo)
static std::string makeVarName(spirv::ModuleOp moduleOp, llvm::Twine prefix)
static MLIRContext * getContext(OpFoldResult val)
union mlir::linalg::@1242::ArityGroupAndKind::Kind kind
Pattern to convert a gpu.all_reduce op into a SPIR-V group op.
LogicalResult matchAndRewrite(gpu::AllReduceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
Methods that operate on the SourceOp type.
Pattern to convert a gpu.subgroup_reduce op into a SPIR-V group op.
LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
Methods that operate on the SourceOp type.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
IntegerAttr getIntegerAttr(Type type, int64_t value)
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
IntegerType getIntegerType(unsigned width)
StringAttr getStringAttr(const Twine &bytes)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
IntegerAttr getI8IntegerAttr(int8_t value)
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void eraseBlock(Block *block) override
PatternRewriter hook for erasing all operations in a block.
static FlatSymbolRefAttr get(StringAttr value)
Construct a symbol reference for the given value name.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Type conversion from builtin types to SPIR-V types for shader interface.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
ArrayRef< T > asArrayRef() const
static ArrayType get(Type elementType, unsigned elementCount)
An attribute that specifies the information regarding the interface variable: descriptor set,...
static PointerType get(Type pointeeType, StorageClass storageClass)
An attribute that specifies the target version, allowed extensions and capabilities,...
ResourceLimitsAttr getResourceLimits() const
Returns the target resource limits.
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabiliti...
TargetEnvAttr getAttr() const
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
StringRef getInterfaceVarABIAttrName()
Returns the attribute name for specifying argument ABI information.
bool needsInterfaceVarABIAttrs(TargetEnvAttr targetAttr)
Returns whether the given SPIR-V target (described by TargetEnvAttr) needs ABI attributes for interfa...
InterfaceVarABIAttr getInterfaceVarABIAttr(unsigned descriptorSet, unsigned binding, std::optional< StorageClass > storageClass, MLIRContext *context)
Gets the InterfaceVarABIAttr given its fields.
Value getBuiltinVariableValue(Operation *op, BuiltIn builtin, Type integerType, OpBuilder &builder, StringRef prefix="__builtin__", StringRef suffix="__")
Returns the value for the given builtin variable.
EntryPointABIAttr lookupEntryPointABI(Operation *op)
Queries the entry point ABI on the nearest function-like op containing the given op.
StringRef getTargetEnvAttrName()
Returns the attribute name for specifying SPIR-V target environment.
DenseI32ArrayAttr lookupLocalWorkGroupSize(Operation *op)
Queries the local workgroup size from entry point ABI on the nearest function-like op containing the ...
AddressingModel getAddressingModel(TargetEnvAttr targetAttr, bool use64bitAddress)
Returns addressing model selected based on target environment.
FailureOr< MemoryModel > getMemoryModel(TargetEnvAttr targetAttr)
Returns memory model selected based on target environment.
StringRef getEntryPointABIAttrName()
Returns the attribute name for specifying entry point information.
constexpr unsigned subgroupSize
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
const FrozenRewritePatternSet & patterns
void populateGPUToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating GPU Ops to SPIR-V ops.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.