24 #include "llvm/Support/FormatVariadic.h"
30 #define GEN_PASS_DEF_SPIRVCOMPOSITETYPELAYOUTPASS
31 #include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
36 class SPIRVGlobalVariableOpLayoutInfoDecoration
41 LogicalResult matchAndRewrite(spirv::GlobalVariableOp op,
45 auto ptrType = cast<spirv::PointerType>(op.getType());
46 auto pointeeType = cast<spirv::StructType>(ptrType.getPointeeType());
50 return op->emitError(llvm::formatv(
51 "failed to decorate (unsuported pointee type: '{0}')", pointeeType));
57 for (
const auto &attr : op->getAttrs()) {
58 if (attr.getName() ==
"type")
60 globalVarAttrs.push_back(attr);
69 class SPIRVAddressOfOpLayoutInfoDecoration
74 LogicalResult matchAndRewrite(spirv::AddressOfOp op,
76 auto spirvModule = op->getParentOfType<spirv::ModuleOp>();
77 auto varName = op.getVariableAttr();
78 auto varOp = spirvModule.lookupSymbol<spirv::GlobalVariableOp>(varName);
86 template <
typename OpT>
92 matchAndRewrite(OpT op,
typename OpT::Adaptor adaptor,
95 [&] { op->setOperands(adaptor.getOperands()); });
102 patterns.add<SPIRVGlobalVariableOpLayoutInfoDecoration,
103 SPIRVAddressOfOpLayoutInfoDecoration,
104 SPIRVPassThroughConversion<spirv::AccessChainOp>,
105 SPIRVPassThroughConversion<spirv::LoadOp>,
106 SPIRVPassThroughConversion<spirv::StoreOp>>(
111 class DecorateSPIRVCompositeTypeLayoutPass
112 :
public spirv::impl::SPIRVCompositeTypeLayoutPassBase<
113 DecorateSPIRVCompositeTypeLayoutPass> {
114 void runOnOperation()
override;
118 void DecorateSPIRVCompositeTypeLayoutPass::runOnOperation() {
119 auto module = getOperation();
123 target.addLegalDialect<spirv::SPIRVDialect>();
124 target.addLegalOp<func::FuncOp>();
125 target.addDynamicallyLegalOp<spirv::GlobalVariableOp>(
126 [](spirv::GlobalVariableOp op) {
131 target.addDynamicallyLegalOp<spirv::AddressOfOp>([](spirv::AddressOfOp op) {
136 target.addDynamicallyLegalOp<spirv::AccessChainOp, spirv::LoadOp,
138 for (
Value operand : op->getOperands()) {
139 auto addrOp = operand.getDefiningOp<spirv::AddressOfOp>();
148 for (
auto spirvModule : module.getOps<spirv::ModuleOp>())
static void populateSPIRVLayoutInfoPatterns(RewritePatternSet &patterns)
This class implements a pattern rewriter for use with ConversionPatterns.
This class describes a specific conversion target.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Operation is the basic unit of execution within MLIR.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (replacing the given op).
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
static bool isLegalType(Type type)
Checks whether a type is legal in terms of Vulkan layout info decoration.
static spirv::StructType decorateType(spirv::StructType structType)
Returns a new StructType with layout decoration.
static PointerType get(Type pointeeType, StorageClass storageClass)
Include the generated interface declarations.
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.