MLIR  19.0.0git
DecorateCompositeTypeLayoutPass.cpp
Go to the documentation of this file.
1 //===- DecorateCompositeTypeLayoutPass.cpp - Decorate composite type ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to decorate the composite types used by
10 // composite objects in the StorageBuffer, PhysicalStorageBuffer, Uniform, and
11 // PushConstant storage classes with layout information. See SPIR-V spec
12 // "2.16.2. Validation Rules for Shader Capabilities" for more details.
13 //
14 //===----------------------------------------------------------------------===//
15 
17 
23 
24 #include "llvm/Support/FormatVariadic.h"
25 
26 using namespace mlir;
27 
28 namespace mlir {
29 namespace spirv {
30 #define GEN_PASS_DEF_SPIRVCOMPOSITETYPELAYOUTPASS
31 #include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
32 } // namespace spirv
33 } // namespace mlir
34 
35 namespace {
36 class SPIRVGlobalVariableOpLayoutInfoDecoration
37  : public OpRewritePattern<spirv::GlobalVariableOp> {
38 public:
40 
41  LogicalResult matchAndRewrite(spirv::GlobalVariableOp op,
42  PatternRewriter &rewriter) const override {
43  SmallVector<NamedAttribute, 4> globalVarAttrs;
44 
45  auto ptrType = cast<spirv::PointerType>(op.getType());
46  auto pointeeType = cast<spirv::StructType>(ptrType.getPointeeType());
47  spirv::StructType structType = VulkanLayoutUtils::decorateType(pointeeType);
48 
49  if (!structType)
50  return op->emitError(llvm::formatv(
51  "failed to decorate (unsuported pointee type: '{0}')", pointeeType));
52 
53  auto decoratedType =
54  spirv::PointerType::get(structType, ptrType.getStorageClass());
55 
56  // Save all named attributes except "type" attribute.
57  for (const auto &attr : op->getAttrs()) {
58  if (attr.getName() == "type")
59  continue;
60  globalVarAttrs.push_back(attr);
61  }
62 
63  rewriter.replaceOpWithNewOp<spirv::GlobalVariableOp>(
64  op, TypeAttr::get(decoratedType), globalVarAttrs);
65  return success();
66  }
67 };
68 
69 class SPIRVAddressOfOpLayoutInfoDecoration
70  : public OpRewritePattern<spirv::AddressOfOp> {
71 public:
73 
74  LogicalResult matchAndRewrite(spirv::AddressOfOp op,
75  PatternRewriter &rewriter) const override {
76  auto spirvModule = op->getParentOfType<spirv::ModuleOp>();
77  auto varName = op.getVariableAttr();
78  auto varOp = spirvModule.lookupSymbol<spirv::GlobalVariableOp>(varName);
79 
80  rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(
81  op, varOp.getType(), SymbolRefAttr::get(varName.getAttr()));
82  return success();
83  }
84 };
85 
86 template <typename OpT>
87 class SPIRVPassThroughConversion : public OpConversionPattern<OpT> {
88 public:
90 
92  matchAndRewrite(OpT op, typename OpT::Adaptor adaptor,
93  ConversionPatternRewriter &rewriter) const override {
94  rewriter.modifyOpInPlace(op,
95  [&] { op->setOperands(adaptor.getOperands()); });
96  return success();
97  }
98 };
99 } // namespace
100 
102  patterns.add<SPIRVGlobalVariableOpLayoutInfoDecoration,
103  SPIRVAddressOfOpLayoutInfoDecoration,
104  SPIRVPassThroughConversion<spirv::AccessChainOp>,
105  SPIRVPassThroughConversion<spirv::LoadOp>,
106  SPIRVPassThroughConversion<spirv::StoreOp>>(
107  patterns.getContext());
108 }
109 
110 namespace {
111 class DecorateSPIRVCompositeTypeLayoutPass
112  : public spirv::impl::SPIRVCompositeTypeLayoutPassBase<
113  DecorateSPIRVCompositeTypeLayoutPass> {
114  void runOnOperation() override;
115 };
116 } // namespace
117 
118 void DecorateSPIRVCompositeTypeLayoutPass::runOnOperation() {
119  auto module = getOperation();
120  RewritePatternSet patterns(module.getContext());
122  ConversionTarget target(*(module.getContext()));
123  target.addLegalDialect<spirv::SPIRVDialect>();
124  target.addLegalOp<func::FuncOp>();
125  target.addDynamicallyLegalOp<spirv::GlobalVariableOp>(
126  [](spirv::GlobalVariableOp op) {
127  return VulkanLayoutUtils::isLegalType(op.getType());
128  });
129 
130  // Change the type for the direct users.
131  target.addDynamicallyLegalOp<spirv::AddressOfOp>([](spirv::AddressOfOp op) {
132  return VulkanLayoutUtils::isLegalType(op.getPointer().getType());
133  });
134 
135  // Change the type for the indirect users.
136  target.addDynamicallyLegalOp<spirv::AccessChainOp, spirv::LoadOp,
137  spirv::StoreOp>([&](Operation *op) {
138  for (Value operand : op->getOperands()) {
139  auto addrOp = operand.getDefiningOp<spirv::AddressOfOp>();
140  if (addrOp &&
141  !VulkanLayoutUtils::isLegalType(addrOp.getPointer().getType()))
142  return false;
143  }
144  return true;
145  });
146 
147  FrozenRewritePatternSet frozenPatterns(std::move(patterns));
148  for (auto spirvModule : module.getOps<spirv::ModuleOp>())
149  if (failed(applyFullConversion(spirvModule, target, frozenPatterns)))
150  signalPassFailure();
151 }
static void populateSPIRVLayoutInfoPatterns(RewritePatternSet &patterns)
This class implements a pattern rewriter for use with ConversionPatterns.
This class describes a specific conversion target.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:507
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
Definition: Operation.cpp:237
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
static bool isLegalType(Type type)
Checks whether a type is legal in terms of Vulkan layout info decoration.
static spirv::StructType decorateType(spirv::StructType structType)
Returns a new StructType with layout decoration.
Definition: LayoutUtils.cpp:21
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:481
SPIR-V struct type.
Definition: SPIRVTypes.h:293
Include the generated interface declarations.
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358