MLIR  14.0.0git
LowerABIAttributesPass.cpp
Go to the documentation of this file.
1 //===- LowerABIAttributesPass.cpp - Decorate composite type ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to lower attributes that specify the shader ABI
10 // for the functions in the generated SPIR-V module.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "PassDetail.h"
21 #include "llvm/ADT/SetVector.h"
22 
23 using namespace mlir;
24 
25 /// Creates a global variable for an argument based on the ABI info.
26 static spirv::GlobalVariableOp
27 createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp,
28  unsigned argIndex,
30  auto spirvModule = funcOp->getParentOfType<spirv::ModuleOp>();
31  if (!spirvModule)
32  return nullptr;
33 
34  OpBuilder::InsertionGuard moduleInsertionGuard(builder);
35  builder.setInsertionPoint(funcOp.getOperation());
36  std::string varName =
37  funcOp.getName().str() + "_arg_" + std::to_string(argIndex);
38 
39  // Get the type of variable. If this is a scalar/vector type and has an ABI
40  // info create a variable of type !spv.ptr<!spv.struct<elementType>>. If not
41  // it must already be a !spv.ptr<!spv.struct<...>>.
42  auto varType = funcOp.getType().getInput(argIndex);
43  if (varType.cast<spirv::SPIRVType>().isScalarOrVector()) {
44  auto storageClass = abiInfo.getStorageClass();
45  if (!storageClass)
46  return nullptr;
47  varType =
48  spirv::PointerType::get(spirv::StructType::get(varType), *storageClass);
49  }
50  auto varPtrType = varType.cast<spirv::PointerType>();
51  auto varPointeeType = varPtrType.getPointeeType().cast<spirv::StructType>();
52 
53  // Set the offset information.
54  varPointeeType =
55  VulkanLayoutUtils::decorateType(varPointeeType).cast<spirv::StructType>();
56 
57  if (!varPointeeType)
58  return nullptr;
59 
60  varType =
61  spirv::PointerType::get(varPointeeType, varPtrType.getStorageClass());
62 
63  return builder.create<spirv::GlobalVariableOp>(
64  funcOp.getLoc(), varType, varName, abiInfo.getDescriptorSet(),
65  abiInfo.getBinding());
66 }
67 
68 /// Gets the global variables that need to be specified as interface variable
69 /// with an spv.EntryPointOp. Traverses the body of a entry function to do so.
70 static LogicalResult
71 getInterfaceVariables(spirv::FuncOp funcOp,
72  SmallVectorImpl<Attribute> &interfaceVars) {
73  auto module = funcOp->getParentOfType<spirv::ModuleOp>();
74  if (!module) {
75  return failure();
76  }
77  SetVector<Operation *> interfaceVarSet;
78 
79  // TODO: This should in reality traverse the entry function
80  // call graph and collect all the interfaces. For now, just traverse the
81  // instructions in this function.
82  funcOp.walk([&](spirv::AddressOfOp addressOfOp) {
83  auto var =
84  module.lookupSymbol<spirv::GlobalVariableOp>(addressOfOp.variable());
85  // TODO: Per SPIR-V spec: "Before version 1.4, the interface’s
86  // storage classes are limited to the Input and Output storage classes.
87  // Starting with version 1.4, the interface’s storage classes are all
88  // storage classes used in declaring all global variables referenced by the
89  // entry point’s call tree." We should consider the target environment here.
90  switch (var.type().cast<spirv::PointerType>().getStorageClass()) {
91  case spirv::StorageClass::Input:
92  case spirv::StorageClass::Output:
93  interfaceVarSet.insert(var.getOperation());
94  break;
95  default:
96  break;
97  }
98  });
99  for (auto &var : interfaceVarSet) {
100  interfaceVars.push_back(SymbolRefAttr::get(
101  funcOp.getContext(), cast<spirv::GlobalVariableOp>(var).sym_name()));
102  }
103  return success();
104 }
105 
106 /// Lowers the entry point attribute.
107 static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp,
108  OpBuilder &builder) {
109  auto entryPointAttrName = spirv::getEntryPointABIAttrName();
110  auto entryPointAttr =
111  funcOp->getAttrOfType<spirv::EntryPointABIAttr>(entryPointAttrName);
112  if (!entryPointAttr) {
113  return failure();
114  }
115 
116  OpBuilder::InsertionGuard moduleInsertionGuard(builder);
117  auto spirvModule = funcOp->getParentOfType<spirv::ModuleOp>();
118  builder.setInsertionPointToEnd(spirvModule.getBody());
119 
120  // Adds the spv.EntryPointOp after collecting all the interface variables
121  // needed.
122  SmallVector<Attribute, 1> interfaceVars;
123  if (failed(getInterfaceVariables(funcOp, interfaceVars))) {
124  return failure();
125  }
126 
127  spirv::TargetEnvAttr targetEnv = spirv::lookupTargetEnv(funcOp);
128  FailureOr<spirv::ExecutionModel> executionModel =
129  spirv::getExecutionModel(targetEnv);
130  if (failed(executionModel))
131  return funcOp.emitRemark("lower entry point failure: could not select "
132  "execution model based on 'spv.target_env'");
133 
134  builder.create<spirv::EntryPointOp>(
135  funcOp.getLoc(), executionModel.getValue(), funcOp, interfaceVars);
136 
137  // Specifies the spv.ExecutionModeOp.
138  auto localSizeAttr = entryPointAttr.local_size();
139  SmallVector<int32_t, 3> localSize(localSizeAttr.getValues<int32_t>());
140  builder.create<spirv::ExecutionModeOp>(
141  funcOp.getLoc(), funcOp, spirv::ExecutionMode::LocalSize, localSize);
142  funcOp->removeAttr(entryPointAttrName);
143  return success();
144 }
145 
146 namespace {
147 /// A pattern to convert function signature according to interface variable ABI
148 /// attributes.
149 ///
150 /// Specifically, this pattern creates global variables according to interface
151 /// variable ABI attributes attached to function arguments and converts all
152 /// function argument uses to those global variables. This is necessary because
153 /// Vulkan requires all shader entry points to be of void(void) type.
154 class ProcessInterfaceVarABI final : public OpConversionPattern<spirv::FuncOp> {
155 public:
157 
159  matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor,
160  ConversionPatternRewriter &rewriter) const override;
161 };
162 
163 /// Pass to implement the ABI information specified as attributes.
164 class LowerABIAttributesPass final
165  : public SPIRVLowerABIAttributesBase<LowerABIAttributesPass> {
166  void runOnOperation() override;
167 };
168 } // namespace
169 
170 LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
171  spirv::FuncOp funcOp, OpAdaptor adaptor,
172  ConversionPatternRewriter &rewriter) const {
173  if (!funcOp->getAttrOfType<spirv::EntryPointABIAttr>(
175  // TODO: Non-entry point functions are not handled.
176  return failure();
177  }
178  TypeConverter::SignatureConversion signatureConverter(
179  funcOp.getType().getNumInputs());
180 
181  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
182  auto indexType = typeConverter.getIndexType();
183 
184  auto attrName = spirv::getInterfaceVarABIAttrName();
185  for (const auto &argType : llvm::enumerate(funcOp.getType().getInputs())) {
186  auto abiInfo = funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
187  argType.index(), attrName);
188  if (!abiInfo) {
189  // TODO: For non-entry point functions, it should be legal
190  // to pass around scalar/vector values and return a scalar/vector. For now
191  // non-entry point functions are not handled in this ABI lowering and will
192  // produce an error.
193  return failure();
194  }
195  spirv::GlobalVariableOp var = createGlobalVarForEntryPointArgument(
196  rewriter, funcOp, argType.index(), abiInfo);
197  if (!var)
198  return failure();
199 
200  OpBuilder::InsertionGuard funcInsertionGuard(rewriter);
201  rewriter.setInsertionPointToStart(&funcOp.front());
202  // Insert spirv::AddressOf and spirv::AccessChain operations.
203  Value replacement =
204  rewriter.create<spirv::AddressOfOp>(funcOp.getLoc(), var);
205  // Check if the arg is a scalar or vector type. In that case, the value
206  // needs to be loaded into registers.
207  // TODO: This is loading value of the scalar into registers
208  // at the start of the function. It is probably better to do the load just
209  // before the use. There might be multiple loads and currently there is no
210  // easy way to replace all uses with a sequence of operations.
211  if (argType.value().cast<spirv::SPIRVType>().isScalarOrVector()) {
212  auto zero =
213  spirv::ConstantOp::getZero(indexType, funcOp.getLoc(), rewriter);
214  auto loadPtr = rewriter.create<spirv::AccessChainOp>(
215  funcOp.getLoc(), replacement, zero.constant());
216  replacement = rewriter.create<spirv::LoadOp>(funcOp.getLoc(), loadPtr);
217  }
218  signatureConverter.remapInput(argType.index(), replacement);
219  }
220  if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), *getTypeConverter(),
221  &signatureConverter)))
222  return failure();
223 
224  // Creates a new function with the update signature.
225  rewriter.updateRootInPlace(funcOp, [&] {
226  funcOp.setType(rewriter.getFunctionType(
227  signatureConverter.getConvertedTypes(), llvm::None));
228  });
229  return success();
230 }
231 
232 void LowerABIAttributesPass::runOnOperation() {
233  // Uses the signature conversion methodology of the dialect conversion
234  // framework to implement the conversion.
235  spirv::ModuleOp module = getOperation();
236  MLIRContext *context = &getContext();
237 
238  spirv::TargetEnv targetEnv(spirv::lookupTargetEnv(module));
239 
240  SPIRVTypeConverter typeConverter(targetEnv);
241 
242  // Insert a bitcast in the case of a pointer type change.
243  typeConverter.addSourceMaterialization([](OpBuilder &builder,
244  spirv::PointerType type,
245  ValueRange inputs, Location loc) {
246  if (inputs.size() != 1 || !inputs[0].getType().isa<spirv::PointerType>())
247  return Value();
248  return builder.create<spirv::BitcastOp>(loc, type, inputs[0]).getResult();
249  });
250 
251  RewritePatternSet patterns(context);
252  patterns.add<ProcessInterfaceVarABI>(typeConverter, context);
253 
254  ConversionTarget target(*context);
255  // "Legal" function ops should have no interface variable ABI attributes.
256  target.addDynamicallyLegalOp<spirv::FuncOp>([&](spirv::FuncOp op) {
257  StringRef attrName = spirv::getInterfaceVarABIAttrName();
258  for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i)
259  if (op.getArgAttr(i, attrName))
260  return false;
261  return true;
262  });
263  // All other SPIR-V ops are legal.
264  target.markUnknownOpDynamicallyLegal([](Operation *op) {
265  return op->getDialect()->getNamespace() ==
266  spirv::SPIRVDialect::getDialectNamespace();
267  });
268  if (failed(applyPartialConversion(module, target, std::move(patterns))))
269  return signalPassFailure();
270 
271  // Walks over all the FuncOps in spirv::ModuleOp to lower the entry point
272  // attributes.
273  OpBuilder builder(context);
274  SmallVector<spirv::FuncOp, 1> entryPointFns;
275  auto entryPointAttrName = spirv::getEntryPointABIAttrName();
276  module.walk([&](spirv::FuncOp funcOp) {
277  if (funcOp->getAttrOfType<spirv::EntryPointABIAttr>(entryPointAttrName)) {
278  entryPointFns.push_back(funcOp);
279  }
280  });
281  for (auto fn : entryPointFns) {
282  if (failed(lowerEntryPointABIAttr(fn, builder))) {
283  return signalPassFailure();
284  }
285  }
286 }
287 
288 std::unique_ptr<OperationPass<spirv::ModuleOp>>
290  return std::make_unique<LowerABIAttributesPass>();
291 }
Include the generated interface declarations.
OpTy create(Location location, Args &&...args)
Create an operation of specific op type at the current insertion point.
Definition: Builders.h:430
TargetEnvAttr lookupTargetEnv(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op...
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
LogicalResult applyPartialConversion(ArrayRef< Operation *> ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation *> *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
Type getPointeeType() const
Definition: SPIRVTypes.cpp:395
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:329
An attribute that specifies the information regarding the interface variable: descriptor set...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
StringRef getEntryPointABIAttrName()
Returns the attribute name for specifying entry point information.
static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp, OpBuilder &builder)
Lowers the entry point attribute.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Construct a literal StructType with at least one member.
Definition: SPIRVTypes.cpp:948
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
static spirv::GlobalVariableOp createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp, unsigned argIndex, spirv::InterfaceVarABIAttr abiInfo)
Creates a global variable for an argument based on the ABI info.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
This class provides all of the information necessary to convert a type signature. ...
This class provides support for representing a failure result, or a valid value of type T...
Definition: LogicalResult.h:77
std::unique_ptr< OperationPass< spirv::ModuleOp > > createLowerABIAttributesPass()
Creates an operation pass that lowers the ABI attributes specified during SPIR-V Lowering.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:206
StringRef getInterfaceVarABIAttrName()
Returns the attribute name for specifying argument ABI information.
void addSourceMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a legal type to an illega...
uint32_t getDescriptorSet()
Returns descriptor set.
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
Definition: PatternMatch.h:789
auto getType() const
uint32_t getBinding()
Returns binding.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabiliti...
Definition: TargetAndABI.h:28
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:391
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:84
FailureOr< ExecutionModel > getExecutionModel(TargetEnvAttr targetAttr)
Returns execution model selected based on target environment.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:362
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:279
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&... args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
Definition: PatternMatch.h:930
StringRef getNamespace() const
Definition: Dialect.h:58
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:103
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
SPIR-V struct type.
Definition: SPIRVTypes.h:278
This class implements a pattern rewriter for use with ConversionPatterns.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:367
static spirv::StructType decorateType(spirv::StructType structType)
Returns a new StructType with layout decoration.
Definition: LayoutUtils.cpp:21
This class describes a specific conversion target.
static LogicalResult getInterfaceVariables(spirv::FuncOp funcOp, SmallVectorImpl< Attribute > &interfaceVars)
Gets the global variables that need to be specified as interface variable with an spv...
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:67
This class helps build Operations.
Definition: Builders.h:177
This class provides an abstraction over the different types of ranges over Values.
An attribute that specifies the target version, allowed extensions and capabilities, and resource limits.
Optional< StorageClass > getStorageClass()
Returns spirv::StorageClass.
Type conversion from builtin types to SPIR-V types for shader interface.
FailureOr< Block * > convertRegionTypes(Region *region, TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Convert the types of block arguments within the given region.
U cast() const
Definition: Types.h:250