25 #include "llvm/ADT/SetVector.h"
29 #define GEN_PASS_DEF_SPIRVLOWERABIATTRIBUTESPASS
30 #include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
37 static spirv::GlobalVariableOp
41 auto spirvModule = funcOp->getParentOfType<spirv::ModuleOp>();
48 funcOp.getName().str() +
"_arg_" + std::to_string(argIndex);
53 auto varType = funcOp.getFunctionType().getInput(argIndex);
54 if (cast<spirv::SPIRVType>(varType).isScalarOrVector()) {
61 auto varPtrType = cast<spirv::PointerType>(varType);
62 auto varPointeeType = cast<spirv::StructType>(varPtrType.getPointeeType());
74 return builder.
create<spirv::GlobalVariableOp>(
84 auto module = funcOp->getParentOfType<spirv::ModuleOp>();
93 funcOp.walk([&](spirv::AddressOfOp addressOfOp) {
95 module.lookupSymbol<spirv::GlobalVariableOp>(addressOfOp.getVariable());
101 switch (cast<spirv::PointerType>(var.getType()).getStorageClass()) {
102 case spirv::StorageClass::Input:
103 case spirv::StorageClass::Output:
104 interfaceVarSet.insert(var.getOperation());
110 for (
auto &var : interfaceVarSet) {
112 funcOp.getContext(), cast<spirv::GlobalVariableOp>(var).getSymName()));
121 auto entryPointAttr =
122 funcOp->getAttrOfType<spirv::EntryPointABIAttr>(entryPointAttrName);
123 if (!entryPointAttr) {
128 auto spirvModule = funcOp->getParentOfType<spirv::ModuleOp>();
142 if (
failed(executionModel))
143 return funcOp.emitRemark(
"lower entry point failure: could not select "
144 "execution model based on 'spirv.target_env'");
146 builder.
create<spirv::EntryPointOp>(funcOp.getLoc(), *executionModel, funcOp,
151 std::optional<ArrayRef<spirv::Capability>> caps =
152 spirv::getCapabilities(spirv::ExecutionMode::LocalSize);
153 if (!caps || targetEnv.
allows(*caps)) {
154 builder.
create<spirv::ExecutionModeOp>(funcOp.getLoc(), funcOp,
155 spirv::ExecutionMode::LocalSize,
156 workgroupSizeAttr.asArrayRef());
160 entryPointAttr.getSubgroupSize());
163 if (std::optional<int> subgroupSize = entryPointAttr.getSubgroupSize()) {
164 std::optional<ArrayRef<spirv::Capability>> caps =
165 spirv::getCapabilities(spirv::ExecutionMode::SubgroupSize);
166 if (!caps || targetEnv.
allows(*caps)) {
167 builder.
create<spirv::ExecutionModeOp>(funcOp.getLoc(), funcOp,
168 spirv::ExecutionMode::SubgroupSize,
172 entryPointAttr.getContext(), entryPointAttr.getWorkgroupSize(),
176 if (entryPointAttr.getWorkgroupSize() || entryPointAttr.getSubgroupSize())
177 funcOp->setAttr(entryPointAttrName, entryPointAttr);
179 funcOp->removeAttr(entryPointAttrName);
196 matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor,
201 class LowerABIAttributesPass final
202 :
public spirv::impl::SPIRVLowerABIAttributesPassBase<
203 LowerABIAttributesPass> {
204 void runOnOperation()
override;
209 spirv::FuncOp funcOp, OpAdaptor adaptor,
211 if (!funcOp->getAttrOfType<spirv::EntryPointABIAttr>(
217 funcOp.getFunctionType().getNumInputs());
219 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
220 auto indexType = typeConverter.getIndexType();
223 for (
const auto &argType :
226 argType.index(), attrName);
235 rewriter, funcOp, argType.index(), abiInfo);
243 rewriter.
create<spirv::AddressOfOp>(funcOp.getLoc(), var);
250 if (cast<spirv::SPIRVType>(argType.value()).isScalarOrVector()) {
253 auto loadPtr = rewriter.
create<spirv::AccessChainOp>(
254 funcOp.getLoc(), replacement, zero.getConstant());
255 replacement = rewriter.
create<spirv::LoadOp>(funcOp.getLoc(), loadPtr);
257 signatureConverter.remapInput(argType.index(), replacement);
260 &signatureConverter)))
266 signatureConverter.getConvertedTypes(), std::nullopt));
271 void LowerABIAttributesPass::runOnOperation() {
274 spirv::ModuleOp module = getOperation();
278 if (!targetEnvAttr) {
279 module->emitOpError(
"missing SPIR-V target env attribute");
280 return signalPassFailure();
287 typeConverter.addSourceMaterialization([](
OpBuilder &builder,
290 if (inputs.size() != 1 || !isa<spirv::PointerType>(inputs[0].getType()))
292 return builder.
create<spirv::BitcastOp>(loc, type, inputs[0]).getResult();
296 patterns.add<ProcessInterfaceVarABI>(typeConverter, context);
300 target.addDynamicallyLegalOp<spirv::FuncOp>([&](spirv::FuncOp op) {
302 for (
unsigned i = 0, e = op.getNumArguments(); i < e; ++i)
303 if (op.getArgAttr(i, attrName))
308 target.markUnknownOpDynamicallyLegal([](
Operation *op) {
310 spirv::SPIRVDialect::getDialectNamespace();
313 return signalPassFailure();
320 module.walk([&](spirv::FuncOp funcOp) {
321 if (funcOp->getAttrOfType<spirv::EntryPointABIAttr>(entryPointAttrName)) {
322 entryPointFns.push_back(funcOp);
325 for (
auto fn : entryPointFns) {
327 return signalPassFailure();
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
static spirv::GlobalVariableOp createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp, unsigned argIndex, spirv::InterfaceVarABIAttr abiInfo)
Creates a global variable for an argument based on the ABI info.
static LogicalResult getInterfaceVariables(spirv::FuncOp funcOp, SmallVectorImpl< Attribute > &interfaceVars)
Gets the global variables that need to be specified as interface variables with a spirv.EntryPointOp. Traverses the body of an entry function to do so.
static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp, OpBuilder &builder)
Lowers the entry point attribute.
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
This class implements a pattern rewriter for use with ConversionPatterns.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Convert the types of block arguments within the given region.
This class describes a specific conversion target.
StringRef getNamespace() const
This class provides support for representing a failure result, or a valid value of type T.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Operation is the basic unit of execution within MLIR.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loaded.
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
Type conversion from builtin types to SPIR-V types for shader interface.
This class provides all of the information necessary to convert a type signature.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
static spirv::StructType decorateType(spirv::StructType structType)
Returns a new StructType with layout decoration.
An attribute that specifies the information regarding the interface variable: descriptor set, binding, storage class.
uint32_t getBinding()
Returns binding.
uint32_t getDescriptorSet()
Returns descriptor set.
std::optional< StorageClass > getStorageClass()
Returns spirv::StorageClass.
static PointerType get(Type pointeeType, StorageClass storageClass)
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Construct a literal StructType with at least one member.
An attribute that specifies the target version, allowed extensions and capabilities, and resource limits.
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabilities/extensions.
bool allows(Capability) const
Returns true if the given capability is allowed.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
StringRef getInterfaceVarABIAttrName()
Returns the attribute name for specifying argument ABI information.
TargetEnvAttr lookupTargetEnv(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op.
FailureOr< ExecutionModel > getExecutionModel(TargetEnvAttr targetAttr)
Returns execution model selected based on target environment.
StringRef getEntryPointABIAttrName()
Returns the attribute name for specifying entry point information.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation * > *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.