25#include "llvm/Support/FormatVariadic.h"
29#define GEN_PASS_DEF_SPIRVLOWERABIATTRIBUTESPASS
30#include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
/// Creates a spirv.GlobalVariable that stands in for argument `argIndex` of
/// the entry-point function `funcOp`, using the interface-var ABI info
/// `abiInfo`.
/// NOTE(review): this listing is incomplete -- the embedded source line
/// numbers jump (e.g. 37 -> 41 -> 48), so parts of the signature and body,
/// including closing braces, are not shown here.
37static spirv::GlobalVariableOp
41 auto spirvModule = funcOp->getParentOfType<spirv::ModuleOp>();
// The variable is named "<function name>_arg_<argIndex>".
48 funcOp.getName().str() +
"_arg_" + std::to_string(argIndex);
53 auto varType = funcOp.getFunctionType().getInput(argIndex);
// Scalar and vector arguments take a dedicated path; its body (original
// lines 55-60) is not shown in this listing.
54 if (cast<spirv::SPIRVType>(varType).isScalarOrVector()) {
// Any other argument must already be a spirv.ptr (cast<> asserts otherwise).
61 auto varPtrType = cast<spirv::PointerType>(varType);
62 Type pointeeType = varPtrType.getPointeeType();
// Sampled-image pointees return early: the global is created directly with
// the argument's own type `varType`.
67 if (isa<spirv::SampledImageType>(pointeeType))
68 return spirv::GlobalVariableOp::create(builder, funcOp.getLoc(), varType,
// All remaining pointees must be structs (cast<> asserts otherwise).
// NOTE(review): original lines 73-83 are elided; they presumably apply a
// layout decoration to the struct and rebuild `varType` before the final
// create below -- confirm against the full source.
72 auto varPointeeType = cast<spirv::StructType>(pointeeType);
84 return spirv::GlobalVariableOp::create(builder, funcOp.getLoc(), varType,
/// Creates a spirv.GlobalVariable for input (`isArg` == true) or result
/// (`isArg` == false) number `index` of the graph entry point `graphOp`,
/// using the interface-var ABI info `abiInfo`.
/// NOTE(review): parts of the signature and body are elided in this listing
/// (the embedded line numbers are not contiguous).
90static spirv::GlobalVariableOp
92 unsigned index,
bool isArg,
94 auto spirvModule = graphOp->getParentOfType<spirv::ModuleOp>();
// Name is "<graph name>_arg_<index>" for inputs and "<graph name>_res_<index>"
// for results.
100 std::string varName = llvm::formatv(
"{}_{}_{}", graphOp.getName(),
101 isArg ?
"arg" :
"res",
index);
// `varType` is the matching input or result type of the graph.
103 Type varType = isArg ? graphOp.getFunctionType().getInput(
index)
104 : graphOp.getFunctionType().getResult(
index);
// The storage class falls back to UniformConstant when the ABI attribute does
// not specify one. The statement consuming it (which presumably builds
// `pointerType` from `varType`) starts on an elided line -- confirm.
108 abiInfo.
getStorageClass().value_or(spirv::StorageClass::UniformConstant));
110 return spirv::GlobalVariableOp::create(builder, graphOp.getLoc(), pointerType,
/// Collects into `interfaceVars` symbol references to the global variables
/// that `funcOp` uses through spirv.mlir.addressof and that need to be
/// specified as interface variables of its entry point.
/// NOTE(review): lines are elided from this listing; `targetEnvAttr`,
/// `targetEnv`, `var`, and `interfaceVarSet` are declared on lines not shown.
120 auto module = funcOp->getParentOfType<spirv::ModuleOp>();
// Visit every address-of op in the function and resolve the global it names.
132 funcOp.walk([&](spirv::AddressOfOp addressOfOp) {
134 module.lookupSymbol<spirv::GlobalVariableOp>(addressOfOp.getVariable());
140 const spirv::StorageClass storageClass =
141 cast<spirv::PointerType>(var.getType()).getStorageClass();
// Per the SPIR-V spec for OpEntryPoint, before version 1.4 the interface is
// limited to Input and Output variables; from 1.4 on it includes every global
// variable referenced by the entry point. Without a target env only the
// Input/Output rule applies. (The membership test on original line 143 is
// elided.)
142 if ((targetEnvAttr && targetEnv.
getVersion() >= spirv::Version::V_1_4) ||
144 {spirv::StorageClass::Input, spirv::StorageClass::Output},
// `interfaceVarSet` is presumably an llvm::SetVector, so a global referenced
// by several address-of ops is recorded once, in first-seen order -- confirm.
146 interfaceVarSet.insert(var.getOperation());
// Emit each collected global as a symbol reference.
149 for (
auto &var : interfaceVarSet) {
150 interfaceVars.push_back(SymbolRefAttr::get(
151 funcOp.getContext(), cast<spirv::GlobalVariableOp>(var).getSymName()));
/// Lowers the entry-point ABI attribute of `funcOp`. It emits a
/// spirv.EntryPoint op and spirv.ExecutionMode ops for the workgroup size,
/// subgroup size, and target width the attribute carries.
/// After each execution mode is emitted, the matching field is cleared from
/// `entryPointAttr`. If any field is left at the end, the attribute is
/// written back to `funcOp`; otherwise it is removed.
/// NOTE(review): lines are elided from this listing; `entryPointAttrName`,
/// `builder` setup, `targetEnv`, and `workgroupSizeAttr` are defined on lines
/// not shown here.
160 auto entryPointAttr =
161 funcOp->getAttrOfType<spirv::EntryPointABIAttr>(entryPointAttrName);
// Handling for a missing attribute (original lines 163-169) is elided.
162 if (!entryPointAttr) {
170 auto spirvModule = funcOp->getParentOfType<spirv::ModuleOp>();
// The execution model is selected from the target environment.
180 FailureOr<spirv::ExecutionModel> executionModel =
// Returning the emitted remark converts it to failure().
182 if (failed(executionModel))
183 return funcOp.emitRemark(
"lower entry point failure: could not select "
184 "execution model based on 'spirv.target_env'");
// The remaining EntryPoint operands are on elided lines.
186 spirv::EntryPointOp::create(builder, funcOp.getLoc(), *executionModel, funcOp,
// Workgroup size -> LocalSize. `caps` is std::nullopt when the mode needs no
// capability, in which case the mode is emitted unconditionally; otherwise
// it is emitted only if the target env allows the listed capabilities.
191 std::optional<ArrayRef<spirv::Capability>> caps =
192 spirv::getCapabilities(spirv::ExecutionMode::LocalSize);
193 if (!caps || targetEnv.
allows(*caps)) {
194 spirv::ExecutionModeOp::create(builder, funcOp.getLoc(), funcOp,
195 spirv::ExecutionMode::LocalSize,
196 workgroupSizeAttr.asArrayRef());
// Rebuild the attribute without the workgroup size. The arguments on elided
// original line 199 presumably pass the context and a null workgroup size --
// confirm.
198 entryPointAttr = spirv::EntryPointABIAttr::get(
200 entryPointAttr.getSubgroupSize(), entryPointAttr.getTargetWidth());
// Subgroup size -> SubgroupSize execution mode, gated the same way.
203 if (std::optional<int> subgroupSize = entryPointAttr.getSubgroupSize()) {
204 std::optional<ArrayRef<spirv::Capability>> caps =
205 spirv::getCapabilities(spirv::ExecutionMode::SubgroupSize);
206 if (!caps || targetEnv.
allows(*caps)) {
207 spirv::ExecutionModeOp::create(builder, funcOp.getLoc(), funcOp,
208 spirv::ExecutionMode::SubgroupSize,
// Drop the subgroup size from the attribute.
211 entryPointAttr = spirv::EntryPointABIAttr::get(
212 entryPointAttr.getContext(), entryPointAttr.getWorkgroupSize(),
213 std::nullopt, entryPointAttr.getTargetWidth());
// Target width -> SignedZeroInfNanPreserve execution mode, gated the same way.
216 if (std::optional<int> targetWidth = entryPointAttr.getTargetWidth()) {
217 std::optional<ArrayRef<spirv::Capability>> caps =
218 spirv::getCapabilities(spirv::ExecutionMode::SignedZeroInfNanPreserve);
219 if (!caps || targetEnv.
allows(*caps)) {
220 spirv::ExecutionModeOp::create(
221 builder, funcOp.getLoc(), funcOp,
222 spirv::ExecutionMode::SignedZeroInfNanPreserve, *targetWidth);
// Drop the target width from the attribute.
224 entryPointAttr = spirv::EntryPointABIAttr::get(
225 entryPointAttr.getContext(), entryPointAttr.getWorkgroupSize(),
226 entryPointAttr.getSubgroupSize(), std::nullopt);
// Keep the attribute only while some field is still unlowered; otherwise
// remove it. (Original line 232, presumably `else`, is elided.)
229 if (entryPointAttr.getWorkgroupSize() || entryPointAttr.getSubgroupSize() ||
230 entryPointAttr.getTargetWidth())
231 funcOp->setAttr(entryPointAttrName, entryPointAttr);
233 funcOp->removeAttr(entryPointAttrName);
/// Conversion pattern for entry-point spirv.func ops. It replaces arguments
/// that carry interface-var ABI attributes with values obtained from newly
/// created global variables, then rewrites the function type.
/// NOTE(review): original lines 246-249 of the class body are elided.
245class ProcessInterfaceVarABI final :
public OpConversionPattern<spirv::FuncOp> {
250 matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor,
251 ConversionPatternRewriter &rewriter)
const override;
/// Conversion pattern for spirv::GraphARMOp entry points. It creates a global
/// variable per input and result from their interface-var ABI attributes and
/// emits a spirv::GraphEntryPointARMOp for the graph.
259class ProcessGraphInterfaceVarABI final
260 :
public OpConversionPattern<spirv::GraphARMOp> {
262 using OpConversionPattern::OpConversionPattern;
265 matchAndRewrite(spirv::GraphARMOp graphOp, OpAdaptor adaptor,
266 ConversionPatternRewriter &rewriter)
const override;
/// Pass that lowers SPIR-V ABI attributes on the ops inside a spirv.module.
/// This covers interface-var ABI on function arguments and graph
/// inputs/results, plus the entry-point ABI attribute.
/// NOTE(review): original line 271 (the base-class clause) is elided; it is
/// presumably the tablegen-generated base enabled by
/// GEN_PASS_DEF_SPIRVLOWERABIATTRIBUTESPASS -- confirm.
270class LowerABIAttributesPass final
272 LowerABIAttributesPass> {
273 void runOnOperation()
override;
/// Replaces each argument of an entry-point spirv.func with a value derived
/// from a global variable created from that argument's interface-var ABI
/// attribute. It then converts the body's block signature and sets a new
/// function type with the converted inputs and no results.
/// NOTE(review): lines are elided from this listing; `attrName`, `abiInfo`,
/// `var`, `replacement`, and `zero` are defined on lines not shown here.
277LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
278 spirv::FuncOp funcOp, OpAdaptor adaptor,
279 ConversionPatternRewriter &rewriter)
const {
// Only functions carrying the entry-point ABI attribute are handled; the rest
// of this check (original lines 281-284) is elided.
280 if (!funcOp->getAttrOfType<spirv::EntryPointABIAttr>(
// One slot per original input; each is remapped to a replacement below.
285 TypeConverter::SignatureConversion signatureConverter(
286 funcOp.getFunctionType().getNumInputs());
288 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
// Type of the zero index constant used by the access chain below.
289 auto indexType = typeConverter.getIndexType();
// New ops are created at the start of the function's entry block.
294 rewriter.setInsertionPointToStart(&funcOp.front());
296 for (
const auto &argType :
297 llvm::enumerate(funcOp.getFunctionType().getInputs())) {
299 argType.index(), attrName);
308 rewriter, funcOp, argType.index(), abiInfo);
// Take the address of the new global. The result is presumably bound to
// `replacement` on the elided preceding line -- confirm.
314 spirv::AddressOfOp::create(rewriter, funcOp.getLoc(), var);
// Scalar/vector arguments are loaded through element 0 of the global.
// Presumably the global wraps them in a single-member struct; the wrapping
// code is not visible in this listing -- confirm.
321 if (cast<spirv::SPIRVType>(argType.value()).isScalarOrVector()) {
323 spirv::ConstantOp::getZero(indexType, funcOp.getLoc(), rewriter);
324 auto loadPtr = spirv::AccessChainOp::create(
325 rewriter, funcOp.getLoc(),
replacement, zero.getConstant());
326 replacement = spirv::LoadOp::create(rewriter, funcOp.getLoc(), loadPtr);
// Redirect every use of the original argument to the replacement value.
328 signatureConverter.remapInput(argType.index(),
replacement);
// Apply the signature conversion to the body's block arguments.
330 if (
failed(rewriter.convertRegionTypes(&funcOp.getBody(), *getTypeConverter(),
331 &signatureConverter)))
// Update the function type in place: converted inputs, no results. (The
// setter call on original line 336 is elided.)
335 rewriter.modifyOpInPlace(funcOp, [&] {
337 rewriter.getFunctionType(signatureConverter.getConvertedTypes(), {}));
/// For a spirv::GraphARMOp whose entry-point flag is set, this creates one
/// global variable per input and per result from their interface-var ABI
/// attributes and strips those attributes. It then emits a
/// spirv::GraphEntryPointARMOp, presumably listing the variables as its
/// interface.
/// NOTE(review): lines are elided; `attrName`, `abiInfo`, and `var` are
/// defined on lines not shown here.
342LogicalResult ProcessGraphInterfaceVarABI::matchAndRewrite(
343 spirv::GraphARMOp graphOp, OpAdaptor adaptor,
344 ConversionPatternRewriter &rewriter)
const {
// Graphs not marked as entry points (flag absent or false) are not handled;
// the body of this check (original lines 347-348) is elided.
346 if (!graphOp.getEntryPoint().value_or(
false))
// NOTE(review): `signatureConverter` is not referenced in the visible lines --
// confirm whether it is still needed.
349 TypeConverter::SignatureConversion signatureConverter(
350 graphOp.getFunctionType().getNumInputs());
// Symbol references to every created global: inputs first, then results.
353 SmallVector<Attribute, 4> interfaceVars;
356 unsigned numInputs = graphOp.getFunctionType().getNumInputs();
357 unsigned numResults = graphOp.getFunctionType().getNumResults();
// One global per input (isArg = true).
358 for (
unsigned index = 0; index < numInputs; ++index) {
360 graphOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(index, attrName);
364 rewriter, graphOp, index,
true, abiInfo);
367 interfaceVars.push_back(
368 SymbolRefAttr::get(rewriter.getContext(), var.getSymName()));
// One global per result (isArg = false).
371 for (
unsigned index = 0; index < numResults; ++index) {
372 auto abiInfo = graphOp.getResultAttrOfType<spirv::InterfaceVarABIAttr>(
377 rewriter, graphOp, index,
false, abiInfo);
380 interfaceVars.push_back(
381 SymbolRefAttr::get(rewriter.getContext(), var.getSymName()));
// Strip the now-lowered ABI attributes from inputs and results.
// NOTE(review): removeArgAttr receives `attrName` directly while
// removeResultAttr receives rewriter.getStringAttr(attrName); consider making
// the two calls consistent.
385 rewriter.modifyOpInPlace(graphOp, [&] {
386 for (
unsigned index = 0; index < numInputs; ++index) {
387 graphOp.removeArgAttr(index, attrName);
389 for (
unsigned index = 0; index < numResults; ++index) {
390 graphOp.removeResultAttr(index, rewriter.getStringAttr(attrName));
// Emit the graph entry point; the remaining operands are on elided lines.
394 spirv::GraphEntryPointARMOp::create(rewriter, graphOp.getLoc(), graphOp,
/// Runs the pass in two phases. First, dialect conversion turns interface-var
/// ABI attributes on function arguments and graph inputs/results into global
/// variables. Second, the entry-point ABI attribute on each spirv.func that
/// still carries it is lowered.
/// NOTE(review): lines are elided; `context`, `targetEnvAttr`, `attrName`, and
/// `entryPointAttrName` are defined on lines not shown here.
399void LowerABIAttributesPass::runOnOperation() {
402 spirv::ModuleOp module = getOperation();
// A SPIR-V target environment is required; the pass fails without one.
406 if (!targetEnvAttr) {
407 module->emitOpError("missing SPIR-V target env attribute");
408 return signalPassFailure();
410 spirv::TargetEnv targetEnv(targetEnvAttr);
412 SPIRVTypeConverter typeConverter(targetEnv);
// Materialize a spirv.ptr value by bitcasting a single pointer-typed input.
// Any other input shape is presumably rejected on elided line 419 -- confirm.
415 typeConverter.addSourceMaterialization([](OpBuilder &builder,
416 spirv::PointerType type,
418 if (inputs.size() != 1 || !isa<spirv::PointerType>(inputs[0].getType()))
420 return spirv::BitcastOp::create(builder, loc, type, inputs[0]).getResult();
423 RewritePatternSet
patterns(context);
424 patterns.add<ProcessInterfaceVarABI, ProcessGraphInterfaceVarABI>(
425 typeConverter, context);
427 ConversionTarget
target(*context);
// A spirv.func counts as legal once none of its arguments carries the
// interface-var ABI attribute (the return statements are elided).
429 target.addDynamicallyLegalOp<spirv::FuncOp>([&](spirv::FuncOp op) {
431 for (
unsigned i = 0, e = op.getNumArguments(); i < e; ++i)
432 if (op.getArgAttr(i, attrName))
// Likewise, a graph counts as legal once no input or result carries it.
436 target.addDynamicallyLegalOp<spirv::GraphARMOp>([&](spirv::GraphARMOp op) {
438 for (
unsigned i = 0, e = op.getNumArguments(); i < e; ++i)
439 if (op.getArgAttr(i, attrName))
441 for (
unsigned i = 0, e = op.getNumResults(); i < e; ++i)
442 if (op.getResultAttr(i, attrName))
// Ops without explicit legality are judged against the SPIR-V dialect
// namespace (the comparison on original line 449 is elided).
448 target.markUnknownOpDynamicallyLegal([](Operation *op) {
450 spirv::SPIRVDialect::getDialectNamespace();
// A failed conversion aborts the pass (the conversion call is elided).
453 return signalPassFailure();
// Collect entry-point functions first and lower them afterwards, presumably
// so that ops inserted by the lowering are not visited by the walk -- confirm.
457 OpBuilder builder(context);
458 SmallVector<spirv::FuncOp, 1> entryPointFns;
460 module.walk([&](spirv::FuncOp funcOp) {
461 if (funcOp->getAttrOfType<spirv::EntryPointABIAttr>(entryPointAttrName)) {
462 entryPointFns.push_back(funcOp);
// Any per-function lowering failure fails the pass (the lowering call on
// original line 466 is elided).
465 for (
auto fn : entryPointFns) {
467 return signalPassFailure();
if copies could not be generated due to yet-unimplemented cases. `copyInPlacementStart` and `copyOutPlacementStart` in `copyPlacementBlock` specify the insertion points where the incoming and outgoing copies should be placed. The output argument `nBegin` is set to its replacement (set to `begin` if no invalidation happens). Since outgoing copies could have been inserted at `end`
static spirv::GlobalVariableOp createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp, unsigned argIndex, spirv::InterfaceVarABIAttr abiInfo)
Creates a global variable for an argument based on the ABI info.
static spirv::GlobalVariableOp createGlobalVarForGraphEntryPoint(OpBuilder &builder, spirv::GraphARMOp graphOp, unsigned index, bool isArg, spirv::InterfaceVarABIAttr abiInfo)
Creates a global variable for an argument or result based on the ABI info.
static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp, OpBuilder &builder)
Lowers the entry point attribute.
static LogicalResult getInterfaceVariables(mlir::FunctionOpInterface funcOp, SmallVectorImpl< Attribute > &interfaceVars)
Gets the global variables that need to be specified as interface variable with an spirv....
StringRef getNamespace() const
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
static spirv::StructType decorateType(spirv::StructType structType)
Returns a new StructType with layout decoration.
An attribute that specifies the information regarding the interface variable: descriptor set,...
uint32_t getBinding()
Returns binding.
uint32_t getDescriptorSet()
Returns descriptor set.
std::optional< StorageClass > getStorageClass()
Returns spirv::StorageClass.
static PointerType get(Type pointeeType, StorageClass storageClass)
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Construct a literal StructType with at least one member.
An attribute that specifies the target version, allowed extensions and capabilities,...
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabiliti...
Version getVersion() const
bool allows(Capability) const
Returns true if the given capability is allowed.
StringRef getInterfaceVarABIAttrName()
Returns the attribute name for specifying argument ABI information.
TargetEnvAttr lookupTargetEnv(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op.
FailureOr< ExecutionModel > getExecutionModel(TargetEnvAttr targetAttr)
Returns execution model selected based on target environment.
StringRef getEntryPointABIAttrName()
Returns the attribute name for specifying entry point information.
Include the generated interface declarations.
llvm::SetVector< T, Vector, Set, N > SetVector
const FrozenRewritePatternSet & patterns
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr