MLIR 22.0.0git
DecorateCompositeTypeLayoutPass.cpp
Go to the documentation of this file.
1//===- DecorateCompositeTypeLayoutPass.cpp - Decorate composite type ------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a pass to decorate the composite types used by
10// composite objects in the StorageBuffer, PhysicalStorageBuffer, Uniform, and
11// PushConstant storage classes with layout information. See SPIR-V spec
12// "2.16.2. Validation Rules for Shader Capabilities" for more details.
13//
14//===----------------------------------------------------------------------===//
15
17
23
24#include "llvm/Support/FormatVariadic.h"
25
26using namespace mlir;
27
28namespace mlir {
29namespace spirv {
30#define GEN_PASS_DEF_SPIRVCOMPOSITETYPELAYOUTPASS
31#include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
32} // namespace spirv
33} // namespace mlir
34
35namespace {
36class SPIRVGlobalVariableOpLayoutInfoDecoration
37 : public OpRewritePattern<spirv::GlobalVariableOp> {
38public:
39 using Base::Base;
40
41 LogicalResult matchAndRewrite(spirv::GlobalVariableOp op,
42 PatternRewriter &rewriter) const override {
43 SmallVector<NamedAttribute, 4> globalVarAttrs;
44
45 auto ptrType = cast<spirv::PointerType>(op.getType());
46 auto pointeeType = cast<spirv::StructType>(ptrType.getPointeeType());
47 spirv::StructType structType = VulkanLayoutUtils::decorateType(pointeeType);
48
49 if (!structType)
50 return op->emitError(llvm::formatv(
51 "failed to decorate (unsuported pointee type: '{0}')", pointeeType));
52
53 auto decoratedType =
54 spirv::PointerType::get(structType, ptrType.getStorageClass());
55
56 // Save all named attributes except "type" attribute.
57 for (const auto &attr : op->getAttrs()) {
58 if (attr.getName() == "type")
59 continue;
60 globalVarAttrs.push_back(attr);
61 }
62
63 rewriter.replaceOpWithNewOp<spirv::GlobalVariableOp>(
64 op, TypeAttr::get(decoratedType), globalVarAttrs);
65 return success();
66 }
67};
68
69class SPIRVAddressOfOpLayoutInfoDecoration
70 : public OpRewritePattern<spirv::AddressOfOp> {
71public:
72 using Base::Base;
73
74 LogicalResult matchAndRewrite(spirv::AddressOfOp op,
75 PatternRewriter &rewriter) const override {
76 auto spirvModule = op->getParentOfType<spirv::ModuleOp>();
77 auto varName = op.getVariableAttr();
78 auto varOp = spirvModule.lookupSymbol<spirv::GlobalVariableOp>(varName);
79
80 rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(
81 op, varOp.getType(), SymbolRefAttr::get(varName.getAttr()));
82 return success();
83 }
84};
85
86template <typename OpT>
87class SPIRVPassThroughConversion : public OpConversionPattern<OpT> {
88public:
89 using OpConversionPattern<OpT>::OpConversionPattern;
90
91 LogicalResult
92 matchAndRewrite(OpT op, typename OpT::Adaptor adaptor,
93 ConversionPatternRewriter &rewriter) const override {
94 rewriter.modifyOpInPlace(op,
95 [&] { op->setOperands(adaptor.getOperands()); });
96 return success();
97 }
98};
99} // namespace
100
102 patterns.add<SPIRVGlobalVariableOpLayoutInfoDecoration,
103 SPIRVAddressOfOpLayoutInfoDecoration,
104 SPIRVPassThroughConversion<spirv::AccessChainOp>,
105 SPIRVPassThroughConversion<spirv::LoadOp>,
106 SPIRVPassThroughConversion<spirv::StoreOp>>(
107 patterns.getContext());
108}
110namespace {
111class DecorateSPIRVCompositeTypeLayoutPass
113 DecorateSPIRVCompositeTypeLayoutPass> {
114 void runOnOperation() override;
116} // namespace
117
118void DecorateSPIRVCompositeTypeLayoutPass::runOnOperation() {
119 auto module = getOperation();
120 RewritePatternSet patterns(module.getContext());
122 ConversionTarget target(*(module.getContext()));
123 target.addLegalDialect<spirv::SPIRVDialect>();
124 target.addLegalOp<func::FuncOp>();
125 target.addDynamicallyLegalOp<spirv::GlobalVariableOp>(
126 [](spirv::GlobalVariableOp op) {
127 return VulkanLayoutUtils::isLegalType(op.getType());
128 });
129
130 // Change the type for the direct users.
131 target.addDynamicallyLegalOp<spirv::AddressOfOp>([](spirv::AddressOfOp op) {
132 return VulkanLayoutUtils::isLegalType(op.getPointer().getType());
133 });
134
135 // Change the type for the indirect users.
136 target.addDynamicallyLegalOp<spirv::AccessChainOp, spirv::LoadOp,
137 spirv::StoreOp>([&](Operation *op) {
138 for (Value operand : op->getOperands()) {
139 auto addrOp = operand.getDefiningOp<spirv::AddressOfOp>();
140 if (addrOp &&
141 !VulkanLayoutUtils::isLegalType(addrOp.getPointer().getType()))
142 return false;
144 return true;
145 });
146
147 FrozenRewritePatternSet frozenPatterns(std::move(patterns));
148 for (auto spirvModule : module.getOps<spirv::ModuleOp>())
149 if (failed(applyFullConversion(spirvModule, target, frozenPatterns)))
150 signalPassFailure();
151}
return success()
static void populateSPIRVLayoutInfoPatterns(RewritePatternSet &patterns)
This class represents a frozen set of patterns that can be processed by a pattern applicator.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
virtual void runOnOperation()=0
The polymorphic API that runs the pass over the currently held operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
static bool isLegalType(Type type)
Checks whether a type is legal in terms of Vulkan layout info decoration.
static spirv::StructType decorateType(spirv::StructType structType)
Returns a new StructType with layout decoration.
static PointerType get(Type pointeeType, StorageClass storageClass)
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...