24#include "llvm/Support/FormatVariadic.h"
30#define GEN_PASS_DEF_SPIRVCOMPOSITETYPELAYOUTPASS
31#include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
36class SPIRVGlobalVariableOpLayoutInfoDecoration
41 LogicalResult matchAndRewrite(spirv::GlobalVariableOp op,
42 PatternRewriter &rewriter)
const override {
43 SmallVector<NamedAttribute, 4> globalVarAttrs;
45 auto ptrType = cast<spirv::PointerType>(op.getType());
46 auto pointeeType = cast<spirv::StructType>(ptrType.getPointeeType());
50 return op->emitError(llvm::formatv(
51 "failed to decorate (unsuported pointee type: '{0}')", pointeeType));
57 for (
const auto &attr : op->getAttrs()) {
58 if (attr.getName() ==
"type")
60 globalVarAttrs.push_back(attr);
64 op, TypeAttr::get(decoratedType), globalVarAttrs);
69class SPIRVAddressOfOpLayoutInfoDecoration
74 LogicalResult matchAndRewrite(spirv::AddressOfOp op,
75 PatternRewriter &rewriter)
const override {
76 auto spirvModule = op->getParentOfType<spirv::ModuleOp>();
77 auto varName = op.getVariableAttr();
78 auto varOp = spirvModule.lookupSymbol<spirv::GlobalVariableOp>(varName);
81 op, varOp.getType(), SymbolRefAttr::get(varName.getAttr()));
86template <
typename OpT>
87class SPIRVPassThroughConversion :
public OpConversionPattern<OpT> {
89 using OpConversionPattern<OpT>::OpConversionPattern;
92 matchAndRewrite(OpT op,
typename OpT::Adaptor adaptor,
93 ConversionPatternRewriter &rewriter)
const override {
94 rewriter.modifyOpInPlace(op,
95 [&] { op->setOperands(adaptor.getOperands()); });
102 patterns.add<SPIRVGlobalVariableOpLayoutInfoDecoration,
103 SPIRVAddressOfOpLayoutInfoDecoration,
104 SPIRVPassThroughConversion<spirv::AccessChainOp>,
105 SPIRVPassThroughConversion<spirv::LoadOp>,
106 SPIRVPassThroughConversion<spirv::StoreOp>>(
111class DecorateSPIRVCompositeTypeLayoutPass
113 DecorateSPIRVCompositeTypeLayoutPass> {
118void DecorateSPIRVCompositeTypeLayoutPass::runOnOperation() {
119 auto module = getOperation();
123 target.addLegalDialect<spirv::SPIRVDialect>();
125 target.addDynamicallyLegalOp<spirv::GlobalVariableOp>(
126 [](spirv::GlobalVariableOp op) {
131 target.addDynamicallyLegalOp<spirv::AddressOfOp>([](spirv::AddressOfOp op) {
136 target.addDynamicallyLegalOp<spirv::AccessChainOp, spirv::LoadOp,
138 for (
Value operand : op->getOperands()) {
139 auto addrOp = operand.getDefiningOp<spirv::AddressOfOp>();
148 for (
auto spirvModule : module.getOps<spirv::ModuleOp>())
149 if (failed(applyFullConversion(spirvModule,
target, frozenPatterns)))
static void populateSPIRVLayoutInfoPatterns(RewritePatternSet &patterns)
This class represents a frozen set of patterns that can be processed by a pattern applicator.
Operation is the basic unit of execution within MLIR.
virtual void runOnOperation()=0
The polymorphic API that runs the pass over the currently held operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
static bool isLegalType(Type type)
Checks whether a type is legal in terms of Vulkan layout info decoration.
static spirv::StructType decorateType(spirv::StructType structType)
Returns a new StructType with layout decoration.
static PointerType get(Type pointeeType, StorageClass storageClass)
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...