25 #include "llvm/Support/FormatVariadic.h"
29 #define GEN_PASS_DEF_SPIRVLOWERABIATTRIBUTESPASS
30 #include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
37 static spirv::GlobalVariableOp
41 auto spirvModule = funcOp->getParentOfType<spirv::ModuleOp>();
48 funcOp.getName().str() +
"_arg_" + std::to_string(argIndex);
53 auto varType = funcOp.getFunctionType().getInput(argIndex);
54 if (cast<spirv::SPIRVType>(varType).isScalarOrVector()) {
61 auto varPtrType = cast<spirv::PointerType>(varType);
62 Type pointeeType = varPtrType.getPointeeType();
67 if (isa<spirv::SampledImageType>(pointeeType))
68 return spirv::GlobalVariableOp::create(builder, funcOp.getLoc(), varType,
72 auto varPointeeType = cast<spirv::StructType>(pointeeType);
84 return spirv::GlobalVariableOp::create(builder, funcOp.getLoc(), varType,
90 static spirv::GlobalVariableOp
92 unsigned index,
bool isArg,
94 auto spirvModule = graphOp->getParentOfType<spirv::ModuleOp>();
100 std::string varName = llvm::formatv(
"{}_{}_{}", graphOp.getName(),
101 isArg ?
"arg" :
"res", index);
103 Type varType = isArg ? graphOp.getFunctionType().getInput(index)
104 : graphOp.getFunctionType().getResult(index);
108 abiInfo.
getStorageClass().value_or(spirv::StorageClass::UniformConstant));
110 return spirv::GlobalVariableOp::create(builder, graphOp.getLoc(), pointerType,
120 auto module = funcOp->getParentOfType<spirv::ModuleOp>();
132 funcOp.walk([&](spirv::AddressOfOp addressOfOp) {
134 module.lookupSymbol<spirv::GlobalVariableOp>(addressOfOp.getVariable());
140 const spirv::StorageClass storageClass =
141 cast<spirv::PointerType>(var.getType()).getStorageClass();
142 if ((targetEnvAttr && targetEnv.
getVersion() >= spirv::Version::V_1_4) ||
144 {spirv::StorageClass::Input, spirv::StorageClass::Output},
146 interfaceVarSet.insert(var.getOperation());
149 for (
auto &var : interfaceVarSet) {
151 funcOp.getContext(), cast<spirv::GlobalVariableOp>(var).getSymName()));
160 auto entryPointAttr =
161 funcOp->getAttrOfType<spirv::EntryPointABIAttr>(entryPointAttrName);
162 if (!entryPointAttr) {
170 auto spirvModule = funcOp->getParentOfType<spirv::ModuleOp>();
180 FailureOr<spirv::ExecutionModel> executionModel =
182 if (
failed(executionModel))
183 return funcOp.emitRemark(
"lower entry point failure: could not select "
184 "execution model based on 'spirv.target_env'");
186 spirv::EntryPointOp::create(builder, funcOp.getLoc(), *executionModel, funcOp,
191 std::optional<ArrayRef<spirv::Capability>> caps =
192 spirv::getCapabilities(spirv::ExecutionMode::LocalSize);
193 if (!caps || targetEnv.
allows(*caps)) {
194 spirv::ExecutionModeOp::create(builder, funcOp.getLoc(), funcOp,
195 spirv::ExecutionMode::LocalSize,
196 workgroupSizeAttr.asArrayRef());
200 entryPointAttr.getSubgroupSize(), entryPointAttr.getTargetWidth());
203 if (std::optional<int>
subgroupSize = entryPointAttr.getSubgroupSize()) {
204 std::optional<ArrayRef<spirv::Capability>> caps =
205 spirv::getCapabilities(spirv::ExecutionMode::SubgroupSize);
206 if (!caps || targetEnv.
allows(*caps)) {
207 spirv::ExecutionModeOp::create(builder, funcOp.getLoc(), funcOp,
208 spirv::ExecutionMode::SubgroupSize,
212 entryPointAttr.getContext(), entryPointAttr.getWorkgroupSize(),
213 std::nullopt, entryPointAttr.getTargetWidth());
216 if (std::optional<int> targetWidth = entryPointAttr.getTargetWidth()) {
217 std::optional<ArrayRef<spirv::Capability>> caps =
218 spirv::getCapabilities(spirv::ExecutionMode::SignedZeroInfNanPreserve);
219 if (!caps || targetEnv.
allows(*caps)) {
220 spirv::ExecutionModeOp::create(
221 builder, funcOp.getLoc(), funcOp,
222 spirv::ExecutionMode::SignedZeroInfNanPreserve, *targetWidth);
225 entryPointAttr.getContext(), entryPointAttr.getWorkgroupSize(),
226 entryPointAttr.getSubgroupSize(), std::nullopt);
229 if (entryPointAttr.getWorkgroupSize() || entryPointAttr.getSubgroupSize() ||
230 entryPointAttr.getTargetWidth())
231 funcOp->setAttr(entryPointAttrName, entryPointAttr);
233 funcOp->removeAttr(entryPointAttrName);
250 matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor,
259 class ProcessGraphInterfaceVarABI final
265 matchAndRewrite(spirv::GraphARMOp graphOp, OpAdaptor adaptor,
270 class LowerABIAttributesPass final
271 :
public spirv::impl::SPIRVLowerABIAttributesPassBase<
272 LowerABIAttributesPass> {
273 void runOnOperation()
override;
277 LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
278 spirv::FuncOp funcOp, OpAdaptor adaptor,
280 if (!funcOp->getAttrOfType<spirv::EntryPointABIAttr>(
286 funcOp.getFunctionType().getNumInputs());
288 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
289 auto indexType = typeConverter.getIndexType();
296 for (
const auto &argType :
299 argType.index(), attrName);
308 rewriter, funcOp, argType.index(), abiInfo);
314 spirv::AddressOfOp::create(rewriter, funcOp.getLoc(), var);
321 if (cast<spirv::SPIRVType>(argType.value()).isScalarOrVector()) {
324 auto loadPtr = spirv::AccessChainOp::create(
325 rewriter, funcOp.getLoc(), replacement, zero.getConstant());
326 replacement = spirv::LoadOp::create(rewriter, funcOp.getLoc(), loadPtr);
328 signatureConverter.remapInput(argType.index(), replacement);
331 &signatureConverter)))
342 LogicalResult ProcessGraphInterfaceVarABI::matchAndRewrite(
343 spirv::GraphARMOp graphOp, OpAdaptor adaptor,
346 if (!graphOp.getEntryPoint().value_or(
false))
350 graphOp.getFunctionType().getNumInputs());
356 unsigned numInputs = graphOp.getFunctionType().getNumInputs();
357 unsigned numResults = graphOp.getFunctionType().getNumResults();
358 for (
unsigned index = 0; index < numInputs; ++index) {
364 rewriter, graphOp, index,
true, abiInfo);
367 interfaceVars.push_back(
371 for (
unsigned index = 0; index < numResults; ++index) {
377 rewriter, graphOp, index,
false, abiInfo);
380 interfaceVars.push_back(
386 for (
unsigned index = 0; index < numInputs; ++index) {
387 graphOp.removeArgAttr(index, attrName);
389 for (
unsigned index = 0; index < numResults; ++index) {
390 graphOp.removeResultAttr(index, rewriter.
getStringAttr(attrName));
394 spirv::GraphEntryPointARMOp::create(rewriter, graphOp.getLoc(), graphOp,
399 void LowerABIAttributesPass::runOnOperation() {
402 spirv::ModuleOp module = getOperation();
406 if (!targetEnvAttr) {
407 module->emitOpError(
"missing SPIR-V target env attribute");
408 return signalPassFailure();
415 typeConverter.addSourceMaterialization([](
OpBuilder &builder,
418 if (inputs.size() != 1 || !isa<spirv::PointerType>(inputs[0].getType()))
420 return spirv::BitcastOp::create(builder, loc, type, inputs[0]).getResult();
424 patterns.add<ProcessInterfaceVarABI, ProcessGraphInterfaceVarABI>(
425 typeConverter, context);
429 target.addDynamicallyLegalOp<spirv::FuncOp>([&](spirv::FuncOp op) {
431 for (
unsigned i = 0, e = op.getNumArguments(); i < e; ++i)
432 if (op.getArgAttr(i, attrName))
436 target.addDynamicallyLegalOp<spirv::GraphARMOp>([&](spirv::GraphARMOp op) {
438 for (
unsigned i = 0, e = op.getNumArguments(); i < e; ++i)
439 if (op.getArgAttr(i, attrName))
441 for (
unsigned i = 0, e = op.getNumResults(); i < e; ++i)
442 if (op.getResultAttr(i, attrName))
448 target.markUnknownOpDynamicallyLegal([](
Operation *op) {
450 spirv::SPIRVDialect::getDialectNamespace();
453 return signalPassFailure();
460 module.walk([&](spirv::FuncOp funcOp) {
461 if (funcOp->getAttrOfType<spirv::EntryPointABIAttr>(entryPointAttrName)) {
462 entryPointFns.push_back(funcOp);
465 for (
auto fn : entryPointFns) {
467 return signalPassFailure();
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
static spirv::GlobalVariableOp createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp, unsigned argIndex, spirv::InterfaceVarABIAttr abiInfo)
Creates a global variable for an argument based on the ABI info.
static spirv::GlobalVariableOp createGlobalVarForGraphEntryPoint(OpBuilder &builder, spirv::GraphARMOp graphOp, unsigned index, bool isArg, spirv::InterfaceVarABIAttr abiInfo)
Creates a global variable for an argument or result based on the ABI info.
static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp, OpBuilder &builder)
Lowers the entry point attribute.
static LogicalResult getInterfaceVariables(mlir::FunctionOpInterface funcOp, SmallVectorImpl< Attribute > &interfaceVars)
Gets the global variables that need to be specified as interface variable with an spirv....
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
StringAttr getStringAttr(const Twine &bytes)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
This class describes a specific conversion target.
StringRef getNamespace() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Type conversion from builtin types to SPIR-V types for shader interface.
This class provides all of the information necessary to convert a type signature.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
static spirv::StructType decorateType(spirv::StructType structType)
Returns a new StructType with layout decoration.
An attribute that specifies the information regarding the interface variable: descriptor set,...
uint32_t getBinding()
Returns binding.
uint32_t getDescriptorSet()
Returns descriptor set.
std::optional< StorageClass > getStorageClass()
Returns spirv::StorageClass.
static PointerType get(Type pointeeType, StorageClass storageClass)
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Construct a literal StructType with at least one member.
An attribute that specifies the target version, allowed extensions and capabilities,...
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabiliti...
Version getVersion() const
bool allows(Capability) const
Returns true if the given capability is allowed.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
StringRef getInterfaceVarABIAttrName()
Returns the attribute name for specifying argument ABI information.
TargetEnvAttr lookupTargetEnv(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op.
FailureOr< ExecutionModel > getExecutionModel(TargetEnvAttr targetAttr)
Returns execution model selected based on target environment.
StringRef getEntryPointABIAttrName()
Returns the attribute name for specifying entry point information.
constexpr unsigned subgroupSize
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.