25 #include "llvm/ADT/SetVector.h"
29 #define GEN_PASS_DEF_SPIRVLOWERABIATTRIBUTESPASS
30 #include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
37 static spirv::GlobalVariableOp
41 auto spirvModule = funcOp->getParentOfType<spirv::ModuleOp>();
48 funcOp.getName().str() +
"_arg_" + std::to_string(argIndex);
53 auto varType = funcOp.getFunctionType().getInput(argIndex);
54 if (cast<spirv::SPIRVType>(varType).isScalarOrVector()) {
61 auto varPtrType = cast<spirv::PointerType>(varType);
62 auto varPointeeType = cast<spirv::StructType>(varPtrType.getPointeeType());
74 return builder.
create<spirv::GlobalVariableOp>(
84 auto module = funcOp->getParentOfType<spirv::ModuleOp>();
93 funcOp.walk([&](spirv::AddressOfOp addressOfOp) {
95 module.lookupSymbol<spirv::GlobalVariableOp>(addressOfOp.getVariable());
101 switch (cast<spirv::PointerType>(var.getType()).getStorageClass()) {
102 case spirv::StorageClass::Input:
103 case spirv::StorageClass::Output:
104 interfaceVarSet.insert(var.getOperation());
110 for (
auto &var : interfaceVarSet) {
112 funcOp.getContext(), cast<spirv::GlobalVariableOp>(var).getSymName()));
121 auto entryPointAttr =
122 funcOp->getAttrOfType<spirv::EntryPointABIAttr>(entryPointAttrName);
123 if (!entryPointAttr) {
128 auto spirvModule = funcOp->getParentOfType<spirv::ModuleOp>();
140 FailureOr<spirv::ExecutionModel> executionModel =
142 if (failed(executionModel))
143 return funcOp.emitRemark(
"lower entry point failure: could not select "
144 "execution model based on 'spirv.target_env'");
146 builder.
create<spirv::EntryPointOp>(funcOp.getLoc(), *executionModel, funcOp,
151 std::optional<ArrayRef<spirv::Capability>> caps =
152 spirv::getCapabilities(spirv::ExecutionMode::LocalSize);
153 if (!caps || targetEnv.
allows(*caps)) {
154 builder.
create<spirv::ExecutionModeOp>(funcOp.getLoc(), funcOp,
155 spirv::ExecutionMode::LocalSize,
156 workgroupSizeAttr.asArrayRef());
160 entryPointAttr.getSubgroupSize(), entryPointAttr.getTargetWidth());
163 if (std::optional<int> subgroupSize = entryPointAttr.getSubgroupSize()) {
164 std::optional<ArrayRef<spirv::Capability>> caps =
165 spirv::getCapabilities(spirv::ExecutionMode::SubgroupSize);
166 if (!caps || targetEnv.
allows(*caps)) {
167 builder.
create<spirv::ExecutionModeOp>(funcOp.getLoc(), funcOp,
168 spirv::ExecutionMode::SubgroupSize,
172 entryPointAttr.getContext(), entryPointAttr.getWorkgroupSize(),
173 std::nullopt, entryPointAttr.getTargetWidth());
176 if (std::optional<int> targetWidth = entryPointAttr.getTargetWidth()) {
177 std::optional<ArrayRef<spirv::Capability>> caps =
178 spirv::getCapabilities(spirv::ExecutionMode::SignedZeroInfNanPreserve);
179 if (!caps || targetEnv.
allows(*caps)) {
180 builder.
create<spirv::ExecutionModeOp>(
181 funcOp.getLoc(), funcOp,
182 spirv::ExecutionMode::SignedZeroInfNanPreserve, *targetWidth);
185 entryPointAttr.getContext(), entryPointAttr.getWorkgroupSize(),
186 entryPointAttr.getSubgroupSize(), std::nullopt);
189 if (entryPointAttr.getWorkgroupSize() || entryPointAttr.getSubgroupSize() ||
190 entryPointAttr.getTargetWidth())
191 funcOp->setAttr(entryPointAttrName, entryPointAttr);
193 funcOp->removeAttr(entryPointAttrName);
210 matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor,
215 class LowerABIAttributesPass final
216 :
public spirv::impl::SPIRVLowerABIAttributesPassBase<
217 LowerABIAttributesPass> {
218 void runOnOperation()
override;
222 LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
223 spirv::FuncOp funcOp, OpAdaptor adaptor,
225 if (!funcOp->getAttrOfType<spirv::EntryPointABIAttr>(
231 funcOp.getFunctionType().getNumInputs());
233 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
234 auto indexType = typeConverter.getIndexType();
237 for (
const auto &argType :
240 argType.index(), attrName);
249 rewriter, funcOp, argType.index(), abiInfo);
257 rewriter.
create<spirv::AddressOfOp>(funcOp.getLoc(), var);
264 if (cast<spirv::SPIRVType>(argType.value()).isScalarOrVector()) {
267 auto loadPtr = rewriter.
create<spirv::AccessChainOp>(
268 funcOp.getLoc(), replacement, zero.getConstant());
269 replacement = rewriter.
create<spirv::LoadOp>(funcOp.getLoc(), loadPtr);
271 signatureConverter.remapInput(argType.index(), replacement);
274 &signatureConverter)))
280 signatureConverter.getConvertedTypes(), std::nullopt));
285 void LowerABIAttributesPass::runOnOperation() {
288 spirv::ModuleOp module = getOperation();
292 if (!targetEnvAttr) {
293 module->emitOpError(
"missing SPIR-V target env attribute");
294 return signalPassFailure();
301 typeConverter.addSourceMaterialization([](
OpBuilder &builder,
304 if (inputs.size() != 1 || !isa<spirv::PointerType>(inputs[0].getType()))
306 return builder.
create<spirv::BitcastOp>(loc, type, inputs[0]).getResult();
310 patterns.add<ProcessInterfaceVarABI>(typeConverter, context);
314 target.addDynamicallyLegalOp<spirv::FuncOp>([&](spirv::FuncOp op) {
316 for (
unsigned i = 0, e = op.getNumArguments(); i < e; ++i)
317 if (op.getArgAttr(i, attrName))
322 target.markUnknownOpDynamicallyLegal([](
Operation *op) {
324 spirv::SPIRVDialect::getDialectNamespace();
327 return signalPassFailure();
334 module.walk([&](spirv::FuncOp funcOp) {
335 if (funcOp->getAttrOfType<spirv::EntryPointABIAttr>(entryPointAttrName)) {
336 entryPointFns.push_back(funcOp);
339 for (
auto fn : entryPointFns) {
341 return signalPassFailure();
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
static spirv::GlobalVariableOp createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp, unsigned argIndex, spirv::InterfaceVarABIAttr abiInfo)
Creates a global variable for an argument based on the ABI info.
static LogicalResult getInterfaceVariables(spirv::FuncOp funcOp, SmallVectorImpl< Attribute > &interfaceVars)
Gets the global variables that need to be specified as interface variable with an spirv....
static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp, OpBuilder &builder)
Lowers the entry point attribute.
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
This class implements a pattern rewriter for use with ConversionPatterns.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
This class describes a specific conversion target.
StringRef getNamespace() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
Operation is the basic unit of execution within MLIR.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Type conversion from builtin types to SPIR-V types for shader interface.
This class provides all of the information necessary to convert a type signature.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
static spirv::StructType decorateType(spirv::StructType structType)
Returns a new StructType with layout decoration.
An attribute that specifies the information regarding the interface variable: descriptor set,...
uint32_t getBinding()
Returns binding.
uint32_t getDescriptorSet()
Returns descriptor set.
std::optional< StorageClass > getStorageClass()
Returns spirv::StorageClass.
static PointerType get(Type pointeeType, StorageClass storageClass)
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Construct a literal StructType with at least one member.
An attribute that specifies the target version, allowed extensions and capabilities,...
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabiliti...
bool allows(Capability) const
Returns true if the given capability is allowed.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
StringRef getInterfaceVarABIAttrName()
Returns the attribute name for specifying argument ABI information.
TargetEnvAttr lookupTargetEnv(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op.
FailureOr< ExecutionModel > getExecutionModel(TargetEnvAttr targetAttr)
Returns execution model selected based on target environment.
StringRef getEntryPointABIAttrName()
Returns the attribute name for specifying entry point information.
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.