28 #define GEN_PASS_DEF_SPIRVLOWERABIATTRIBUTESPASS
29 #include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
36 static spirv::GlobalVariableOp
40 auto spirvModule = funcOp->getParentOfType<spirv::ModuleOp>();
47 funcOp.getName().str() +
"_arg_" + std::to_string(argIndex);
52 auto varType = funcOp.getFunctionType().getInput(argIndex);
53 if (cast<spirv::SPIRVType>(varType).isScalarOrVector()) {
60 auto varPtrType = cast<spirv::PointerType>(varType);
61 Type pointeeType = varPtrType.getPointeeType();
66 if (isa<spirv::SampledImageType>(pointeeType))
67 return spirv::GlobalVariableOp::create(builder, funcOp.getLoc(), varType,
71 auto varPointeeType = cast<spirv::StructType>(pointeeType);
83 return spirv::GlobalVariableOp::create(builder, funcOp.getLoc(), varType,
93 auto module = funcOp->getParentOfType<spirv::ModuleOp>();
105 funcOp.walk([&](spirv::AddressOfOp addressOfOp) {
107 module.lookupSymbol<spirv::GlobalVariableOp>(addressOfOp.getVariable());
113 const spirv::StorageClass storageClass =
114 cast<spirv::PointerType>(var.getType()).getStorageClass();
115 if ((targetEnvAttr && targetEnv.
getVersion() >= spirv::Version::V_1_4) ||
117 {spirv::StorageClass::Input, spirv::StorageClass::Output},
119 interfaceVarSet.insert(var.getOperation());
122 for (
auto &var : interfaceVarSet) {
124 funcOp.getContext(), cast<spirv::GlobalVariableOp>(var).getSymName()));
133 auto entryPointAttr =
134 funcOp->getAttrOfType<spirv::EntryPointABIAttr>(entryPointAttrName);
135 if (!entryPointAttr) {
143 auto spirvModule = funcOp->getParentOfType<spirv::ModuleOp>();
153 FailureOr<spirv::ExecutionModel> executionModel =
155 if (
failed(executionModel))
156 return funcOp.emitRemark(
"lower entry point failure: could not select "
157 "execution model based on 'spirv.target_env'");
159 spirv::EntryPointOp::create(builder, funcOp.getLoc(), *executionModel, funcOp,
164 std::optional<ArrayRef<spirv::Capability>> caps =
165 spirv::getCapabilities(spirv::ExecutionMode::LocalSize);
166 if (!caps || targetEnv.
allows(*caps)) {
167 spirv::ExecutionModeOp::create(builder, funcOp.getLoc(), funcOp,
168 spirv::ExecutionMode::LocalSize,
169 workgroupSizeAttr.asArrayRef());
173 entryPointAttr.getSubgroupSize(), entryPointAttr.getTargetWidth());
176 if (std::optional<int>
subgroupSize = entryPointAttr.getSubgroupSize()) {
177 std::optional<ArrayRef<spirv::Capability>> caps =
178 spirv::getCapabilities(spirv::ExecutionMode::SubgroupSize);
179 if (!caps || targetEnv.
allows(*caps)) {
180 spirv::ExecutionModeOp::create(builder, funcOp.getLoc(), funcOp,
181 spirv::ExecutionMode::SubgroupSize,
185 entryPointAttr.getContext(), entryPointAttr.getWorkgroupSize(),
186 std::nullopt, entryPointAttr.getTargetWidth());
189 if (std::optional<int> targetWidth = entryPointAttr.getTargetWidth()) {
190 std::optional<ArrayRef<spirv::Capability>> caps =
191 spirv::getCapabilities(spirv::ExecutionMode::SignedZeroInfNanPreserve);
192 if (!caps || targetEnv.
allows(*caps)) {
193 spirv::ExecutionModeOp::create(
194 builder, funcOp.getLoc(), funcOp,
195 spirv::ExecutionMode::SignedZeroInfNanPreserve, *targetWidth);
198 entryPointAttr.getContext(), entryPointAttr.getWorkgroupSize(),
199 entryPointAttr.getSubgroupSize(), std::nullopt);
202 if (entryPointAttr.getWorkgroupSize() || entryPointAttr.getSubgroupSize() ||
203 entryPointAttr.getTargetWidth())
204 funcOp->setAttr(entryPointAttrName, entryPointAttr);
206 funcOp->removeAttr(entryPointAttrName);
223 matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor,
228 class LowerABIAttributesPass final
229 :
public spirv::impl::SPIRVLowerABIAttributesPassBase<
230 LowerABIAttributesPass> {
231 void runOnOperation()
override;
235 LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
236 spirv::FuncOp funcOp, OpAdaptor adaptor,
238 if (!funcOp->getAttrOfType<spirv::EntryPointABIAttr>(
244 funcOp.getFunctionType().getNumInputs());
246 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
247 auto indexType = typeConverter.getIndexType();
254 for (
const auto &argType :
257 argType.index(), attrName);
266 rewriter, funcOp, argType.index(), abiInfo);
272 spirv::AddressOfOp::create(rewriter, funcOp.getLoc(), var);
279 if (cast<spirv::SPIRVType>(argType.value()).isScalarOrVector()) {
282 auto loadPtr = spirv::AccessChainOp::create(
283 rewriter, funcOp.getLoc(), replacement, zero.getConstant());
284 replacement = spirv::LoadOp::create(rewriter, funcOp.getLoc(), loadPtr);
286 signatureConverter.remapInput(argType.index(), replacement);
289 &signatureConverter)))
300 void LowerABIAttributesPass::runOnOperation() {
303 spirv::ModuleOp module = getOperation();
307 if (!targetEnvAttr) {
308 module->emitOpError(
"missing SPIR-V target env attribute");
309 return signalPassFailure();
316 typeConverter.addSourceMaterialization([](
OpBuilder &builder,
319 if (inputs.size() != 1 || !isa<spirv::PointerType>(inputs[0].getType()))
321 return spirv::BitcastOp::create(builder, loc, type, inputs[0]).getResult();
325 patterns.add<ProcessInterfaceVarABI>(typeConverter, context);
329 target.addDynamicallyLegalOp<spirv::FuncOp>([&](spirv::FuncOp op) {
331 for (
unsigned i = 0, e = op.getNumArguments(); i < e; ++i)
332 if (op.getArgAttr(i, attrName))
337 target.markUnknownOpDynamicallyLegal([](
Operation *op) {
339 spirv::SPIRVDialect::getDialectNamespace();
342 return signalPassFailure();
349 module.walk([&](spirv::FuncOp funcOp) {
350 if (funcOp->getAttrOfType<spirv::EntryPointABIAttr>(entryPointAttrName)) {
351 entryPointFns.push_back(funcOp);
354 for (
auto fn : entryPointFns) {
356 return signalPassFailure();
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
static spirv::GlobalVariableOp createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp, unsigned argIndex, spirv::InterfaceVarABIAttr abiInfo)
Creates a global variable for an argument based on the ABI info.
static LogicalResult getInterfaceVariables(spirv::FuncOp funcOp, SmallVectorImpl< Attribute > &interfaceVars)
Gets the global variables that need to be specified as interface variable with an spirv....
static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp, OpBuilder &builder)
Lowers the entry point attribute.
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
This class implements a pattern rewriter for use with ConversionPatterns.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
This class describes a specific conversion target.
StringRef getNamespace() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
Operation is the basic unit of execution within MLIR.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Type conversion from builtin types to SPIR-V types for shader interface.
This class provides all of the information necessary to convert a type signature.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
static spirv::StructType decorateType(spirv::StructType structType)
Returns a new StructType with layout decoration.
An attribute that specifies the information regarding the interface variable: descriptor set,...
uint32_t getBinding()
Returns binding.
uint32_t getDescriptorSet()
Returns descriptor set.
std::optional< StorageClass > getStorageClass()
Returns spirv::StorageClass.
static PointerType get(Type pointeeType, StorageClass storageClass)
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Construct a literal StructType with at least one member.
An attribute that specifies the target version, allowed extensions and capabilities,...
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabiliti...
Version getVersion() const
bool allows(Capability) const
Returns true if the given capability is allowed.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
StringRef getInterfaceVarABIAttrName()
Returns the attribute name for specifying argument ABI information.
TargetEnvAttr lookupTargetEnv(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op.
FailureOr< ExecutionModel > getExecutionModel(TargetEnvAttr targetAttr)
Returns execution model selected based on target environment.
StringRef getEntryPointABIAttrName()
Returns the attribute name for specifying entry point information.
constexpr unsigned subgroupSize
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.