25 #include "llvm/ADT/SetVector.h"
29 #define GEN_PASS_DEF_SPIRVLOWERABIATTRIBUTESPASS
30 #include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
37 static spirv::GlobalVariableOp
41 auto spirvModule = funcOp->getParentOfType<spirv::ModuleOp>();
48 funcOp.getName().str() +
"_arg_" + std::to_string(argIndex);
53 auto varType = funcOp.getFunctionType().getInput(argIndex);
54 if (cast<spirv::SPIRVType>(varType).isScalarOrVector()) {
61 auto varPtrType = cast<spirv::PointerType>(varType);
62 auto varPointeeType = cast<spirv::StructType>(varPtrType.getPointeeType());
74 return builder.
create<spirv::GlobalVariableOp>(
84 auto module = funcOp->getParentOfType<spirv::ModuleOp>();
96 funcOp.walk([&](spirv::AddressOfOp addressOfOp) {
98 module.lookupSymbol<spirv::GlobalVariableOp>(addressOfOp.getVariable());
104 const spirv::StorageClass storageClass =
105 cast<spirv::PointerType>(var.getType()).getStorageClass();
106 if ((targetEnvAttr && targetEnv.
getVersion() >= spirv::Version::V_1_4) ||
108 {spirv::StorageClass::Input, spirv::StorageClass::Output},
110 interfaceVarSet.insert(var.getOperation());
113 for (
auto &var : interfaceVarSet) {
115 funcOp.getContext(), cast<spirv::GlobalVariableOp>(var).getSymName()));
124 auto entryPointAttr =
125 funcOp->getAttrOfType<spirv::EntryPointABIAttr>(entryPointAttrName);
126 if (!entryPointAttr) {
134 auto spirvModule = funcOp->getParentOfType<spirv::ModuleOp>();
144 FailureOr<spirv::ExecutionModel> executionModel =
146 if (failed(executionModel))
147 return funcOp.emitRemark(
"lower entry point failure: could not select "
148 "execution model based on 'spirv.target_env'");
150 builder.
create<spirv::EntryPointOp>(funcOp.getLoc(), *executionModel, funcOp,
155 std::optional<ArrayRef<spirv::Capability>> caps =
156 spirv::getCapabilities(spirv::ExecutionMode::LocalSize);
157 if (!caps || targetEnv.
allows(*caps)) {
158 builder.
create<spirv::ExecutionModeOp>(funcOp.getLoc(), funcOp,
159 spirv::ExecutionMode::LocalSize,
160 workgroupSizeAttr.asArrayRef());
164 entryPointAttr.getSubgroupSize(), entryPointAttr.getTargetWidth());
167 if (std::optional<int> subgroupSize = entryPointAttr.getSubgroupSize()) {
168 std::optional<ArrayRef<spirv::Capability>> caps =
169 spirv::getCapabilities(spirv::ExecutionMode::SubgroupSize);
170 if (!caps || targetEnv.
allows(*caps)) {
171 builder.
create<spirv::ExecutionModeOp>(funcOp.getLoc(), funcOp,
172 spirv::ExecutionMode::SubgroupSize,
176 entryPointAttr.getContext(), entryPointAttr.getWorkgroupSize(),
177 std::nullopt, entryPointAttr.getTargetWidth());
180 if (std::optional<int> targetWidth = entryPointAttr.getTargetWidth()) {
181 std::optional<ArrayRef<spirv::Capability>> caps =
182 spirv::getCapabilities(spirv::ExecutionMode::SignedZeroInfNanPreserve);
183 if (!caps || targetEnv.
allows(*caps)) {
184 builder.
create<spirv::ExecutionModeOp>(
185 funcOp.getLoc(), funcOp,
186 spirv::ExecutionMode::SignedZeroInfNanPreserve, *targetWidth);
189 entryPointAttr.getContext(), entryPointAttr.getWorkgroupSize(),
190 entryPointAttr.getSubgroupSize(), std::nullopt);
193 if (entryPointAttr.getWorkgroupSize() || entryPointAttr.getSubgroupSize() ||
194 entryPointAttr.getTargetWidth())
195 funcOp->setAttr(entryPointAttrName, entryPointAttr);
197 funcOp->removeAttr(entryPointAttrName);
214 matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor,
219 class LowerABIAttributesPass final
220 :
public spirv::impl::SPIRVLowerABIAttributesPassBase<
221 LowerABIAttributesPass> {
222 void runOnOperation()
override;
226 LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
227 spirv::FuncOp funcOp, OpAdaptor adaptor,
229 if (!funcOp->getAttrOfType<spirv::EntryPointABIAttr>(
235 funcOp.getFunctionType().getNumInputs());
237 auto &
typeConverter = *getTypeConverter<SPIRVTypeConverter>();
245 for (
const auto &argType :
248 argType.index(), attrName);
257 rewriter, funcOp, argType.index(), abiInfo);
263 rewriter.
create<spirv::AddressOfOp>(funcOp.getLoc(), var);
270 if (cast<spirv::SPIRVType>(argType.value()).isScalarOrVector()) {
273 auto loadPtr = rewriter.
create<spirv::AccessChainOp>(
274 funcOp.getLoc(), replacement, zero.getConstant());
275 replacement = rewriter.
create<spirv::LoadOp>(funcOp.getLoc(), loadPtr);
277 signatureConverter.remapInput(argType.index(), replacement);
280 &signatureConverter)))
286 signatureConverter.getConvertedTypes(), std::nullopt));
291 void LowerABIAttributesPass::runOnOperation() {
294 spirv::ModuleOp module = getOperation();
298 if (!targetEnvAttr) {
299 module->emitOpError(
"missing SPIR-V target env attribute");
300 return signalPassFailure();
310 if (inputs.size() != 1 || !isa<spirv::PointerType>(inputs[0].getType()))
312 return builder.
create<spirv::BitcastOp>(loc, type, inputs[0]).getResult();
320 target.addDynamicallyLegalOp<spirv::FuncOp>([&](spirv::FuncOp op) {
322 for (
unsigned i = 0, e = op.getNumArguments(); i < e; ++i)
323 if (op.getArgAttr(i, attrName))
328 target.markUnknownOpDynamicallyLegal([](
Operation *op) {
330 spirv::SPIRVDialect::getDialectNamespace();
333 return signalPassFailure();
340 module.walk([&](spirv::FuncOp funcOp) {
341 if (funcOp->getAttrOfType<spirv::EntryPointABIAttr>(entryPointAttrName)) {
342 entryPointFns.push_back(funcOp);
345 for (
auto fn : entryPointFns) {
347 return signalPassFailure();
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
static spirv::GlobalVariableOp createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp, unsigned argIndex, spirv::InterfaceVarABIAttr abiInfo)
Creates a global variable for an argument based on the ABI info.
static LogicalResult getInterfaceVariables(spirv::FuncOp funcOp, SmallVectorImpl< Attribute > &interfaceVars)
Gets the global variables that need to be specified as interface variable with an spirv....
static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp, OpBuilder &builder)
Lowers the entry point attribute.
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
This class implements a pattern rewriter for use with ConversionPatterns.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
This class describes a specific conversion target.
StringRef getNamespace() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
Operation is the basic unit of execution within MLIR.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Type conversion from builtin types to SPIR-V types for shader interface.
This class provides all of the information necessary to convert a type signature.
void addSourceMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a replacement value back ...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
static spirv::StructType decorateType(spirv::StructType structType)
Returns a new StructType with layout decoration.
An attribute that specifies the information regarding the interface variable: descriptor set,...
uint32_t getBinding()
Returns binding.
uint32_t getDescriptorSet()
Returns descriptor set.
std::optional< StorageClass > getStorageClass()
Returns spirv::StorageClass.
static PointerType get(Type pointeeType, StorageClass storageClass)
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Construct a literal StructType with at least one member.
An attribute that specifies the target version, allowed extensions and capabilities,...
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabiliti...
Version getVersion() const
bool allows(Capability) const
Returns true if the given capability is allowed.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
StringRef getInterfaceVarABIAttrName()
Returns the attribute name for specifying argument ABI information.
TargetEnvAttr lookupTargetEnv(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op.
FailureOr< ExecutionModel > getExecutionModel(TargetEnvAttr targetAttr)
Returns execution model selected based on target environment.
StringRef getEntryPointABIAttrName()
Returns the attribute name for specifying entry point information.
Include the generated interface declarations.
TypeConverter & typeConverter
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.