MLIR  16.0.0git
LowerABIAttributesPass.cpp
Go to the documentation of this file.
1 //===- LowerABIAttributesPass.cpp - Lower ABI attributes ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to lower attributes that specify the shader ABI
10 // for the functions in the generated SPIR-V module.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "PassDetail.h"
21 #include "llvm/ADT/SetVector.h"
22 
23 using namespace mlir;
24 
25 /// Creates a global variable for an argument based on the ABI info.
26 static spirv::GlobalVariableOp
27 createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp,
28  unsigned argIndex,
30  auto spirvModule = funcOp->getParentOfType<spirv::ModuleOp>();
31  if (!spirvModule)
32  return nullptr;
33 
34  OpBuilder::InsertionGuard moduleInsertionGuard(builder);
35  builder.setInsertionPoint(funcOp.getOperation());
36  std::string varName =
37  funcOp.getName().str() + "_arg_" + std::to_string(argIndex);
38 
39  // Get the type of variable. If this is a scalar/vector type and has an ABI
40  // info create a variable of type !spv.ptr<!spv.struct<elementType>>. If not
41  // it must already be a !spv.ptr<!spv.struct<...>>.
42  auto varType = funcOp.getFunctionType().getInput(argIndex);
43  if (varType.cast<spirv::SPIRVType>().isScalarOrVector()) {
44  auto storageClass = abiInfo.getStorageClass();
45  if (!storageClass)
46  return nullptr;
47  varType =
48  spirv::PointerType::get(spirv::StructType::get(varType), *storageClass);
49  }
50  auto varPtrType = varType.cast<spirv::PointerType>();
51  auto varPointeeType = varPtrType.getPointeeType().cast<spirv::StructType>();
52 
53  // Set the offset information.
54  varPointeeType =
55  VulkanLayoutUtils::decorateType(varPointeeType).cast<spirv::StructType>();
56 
57  if (!varPointeeType)
58  return nullptr;
59 
60  varType =
61  spirv::PointerType::get(varPointeeType, varPtrType.getStorageClass());
62 
63  return builder.create<spirv::GlobalVariableOp>(
64  funcOp.getLoc(), varType, varName, abiInfo.getDescriptorSet(),
65  abiInfo.getBinding());
66 }
67 
68 /// Gets the global variables that need to be specified as interface variable
69 /// with an spv.EntryPointOp. Traverses the body of a entry function to do so.
70 static LogicalResult
71 getInterfaceVariables(spirv::FuncOp funcOp,
72  SmallVectorImpl<Attribute> &interfaceVars) {
73  auto module = funcOp->getParentOfType<spirv::ModuleOp>();
74  if (!module) {
75  return failure();
76  }
77  SetVector<Operation *> interfaceVarSet;
78 
79  // TODO: This should in reality traverse the entry function
80  // call graph and collect all the interfaces. For now, just traverse the
81  // instructions in this function.
82  funcOp.walk([&](spirv::AddressOfOp addressOfOp) {
83  auto var =
84  module.lookupSymbol<spirv::GlobalVariableOp>(addressOfOp.variable());
85  // TODO: Per SPIR-V spec: "Before version 1.4, the interface’s
86  // storage classes are limited to the Input and Output storage classes.
87  // Starting with version 1.4, the interface’s storage classes are all
88  // storage classes used in declaring all global variables referenced by the
89  // entry point’s call tree." We should consider the target environment here.
90  switch (var.type().cast<spirv::PointerType>().getStorageClass()) {
91  case spirv::StorageClass::Input:
92  case spirv::StorageClass::Output:
93  interfaceVarSet.insert(var.getOperation());
94  break;
95  default:
96  break;
97  }
98  });
99  for (auto &var : interfaceVarSet) {
100  interfaceVars.push_back(SymbolRefAttr::get(
101  funcOp.getContext(), cast<spirv::GlobalVariableOp>(var).sym_name()));
102  }
103  return success();
104 }
105 
106 /// Lowers the entry point attribute.
107 static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp,
108  OpBuilder &builder) {
109  auto entryPointAttrName = spirv::getEntryPointABIAttrName();
110  auto entryPointAttr =
111  funcOp->getAttrOfType<spirv::EntryPointABIAttr>(entryPointAttrName);
112  if (!entryPointAttr) {
113  return failure();
114  }
115 
116  OpBuilder::InsertionGuard moduleInsertionGuard(builder);
117  auto spirvModule = funcOp->getParentOfType<spirv::ModuleOp>();
118  builder.setInsertionPointToEnd(spirvModule.getBody());
119 
120  // Adds the spv.EntryPointOp after collecting all the interface variables
121  // needed.
122  SmallVector<Attribute, 1> interfaceVars;
123  if (failed(getInterfaceVariables(funcOp, interfaceVars))) {
124  return failure();
125  }
126 
127  spirv::TargetEnvAttr targetEnv = spirv::lookupTargetEnv(funcOp);
128  FailureOr<spirv::ExecutionModel> executionModel =
129  spirv::getExecutionModel(targetEnv);
130  if (failed(executionModel))
131  return funcOp.emitRemark("lower entry point failure: could not select "
132  "execution model based on 'spv.target_env'");
133 
134  builder.create<spirv::EntryPointOp>(funcOp.getLoc(), executionModel.value(),
135  funcOp, interfaceVars);
136 
137  // Specifies the spv.ExecutionModeOp.
138  auto localSizeAttr = entryPointAttr.getLocalSize();
139  if (localSizeAttr) {
140  auto values = localSizeAttr.getValues<int32_t>();
141  SmallVector<int32_t, 3> localSize(values);
142  builder.create<spirv::ExecutionModeOp>(
143  funcOp.getLoc(), funcOp, spirv::ExecutionMode::LocalSize, localSize);
144  funcOp->removeAttr(entryPointAttrName);
145  }
146  return success();
147 }
148 
149 namespace {
150 /// A pattern to convert function signature according to interface variable ABI
151 /// attributes.
152 ///
153 /// Specifically, this pattern creates global variables according to interface
154 /// variable ABI attributes attached to function arguments and converts all
155 /// function argument uses to those global variables. This is necessary because
156 /// Vulkan requires all shader entry points to be of void(void) type.
157 class ProcessInterfaceVarABI final : public OpConversionPattern<spirv::FuncOp> {
158 public:
160 
162  matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor,
163  ConversionPatternRewriter &rewriter) const override;
164 };
165 
166 /// Pass to implement the ABI information specified as attributes.
167 class LowerABIAttributesPass final
168  : public SPIRVLowerABIAttributesBase<LowerABIAttributesPass> {
169  void runOnOperation() override;
170 };
171 } // namespace
172 
173 LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
174  spirv::FuncOp funcOp, OpAdaptor adaptor,
175  ConversionPatternRewriter &rewriter) const {
176  if (!funcOp->getAttrOfType<spirv::EntryPointABIAttr>(
178  // TODO: Non-entry point functions are not handled.
179  return failure();
180  }
181  TypeConverter::SignatureConversion signatureConverter(
182  funcOp.getFunctionType().getNumInputs());
183 
184  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
185  auto indexType = typeConverter.getIndexType();
186 
187  auto attrName = spirv::getInterfaceVarABIAttrName();
188  for (const auto &argType :
189  llvm::enumerate(funcOp.getFunctionType().getInputs())) {
190  auto abiInfo = funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
191  argType.index(), attrName);
192  if (!abiInfo) {
193  // TODO: For non-entry point functions, it should be legal
194  // to pass around scalar/vector values and return a scalar/vector. For now
195  // non-entry point functions are not handled in this ABI lowering and will
196  // produce an error.
197  return failure();
198  }
199  spirv::GlobalVariableOp var = createGlobalVarForEntryPointArgument(
200  rewriter, funcOp, argType.index(), abiInfo);
201  if (!var)
202  return failure();
203 
204  OpBuilder::InsertionGuard funcInsertionGuard(rewriter);
205  rewriter.setInsertionPointToStart(&funcOp.front());
206  // Insert spirv::AddressOf and spirv::AccessChain operations.
207  Value replacement =
208  rewriter.create<spirv::AddressOfOp>(funcOp.getLoc(), var);
209  // Check if the arg is a scalar or vector type. In that case, the value
210  // needs to be loaded into registers.
211  // TODO: This is loading value of the scalar into registers
212  // at the start of the function. It is probably better to do the load just
213  // before the use. There might be multiple loads and currently there is no
214  // easy way to replace all uses with a sequence of operations.
215  if (argType.value().cast<spirv::SPIRVType>().isScalarOrVector()) {
216  auto zero =
217  spirv::ConstantOp::getZero(indexType, funcOp.getLoc(), rewriter);
218  auto loadPtr = rewriter.create<spirv::AccessChainOp>(
219  funcOp.getLoc(), replacement, zero.constant());
220  replacement = rewriter.create<spirv::LoadOp>(funcOp.getLoc(), loadPtr);
221  }
222  signatureConverter.remapInput(argType.index(), replacement);
223  }
224  if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), *getTypeConverter(),
225  &signatureConverter)))
226  return failure();
227 
228  // Creates a new function with the update signature.
229  rewriter.updateRootInPlace(funcOp, [&] {
230  funcOp.setType(rewriter.getFunctionType(
231  signatureConverter.getConvertedTypes(), llvm::None));
232  });
233  return success();
234 }
235 
236 void LowerABIAttributesPass::runOnOperation() {
237  // Uses the signature conversion methodology of the dialect conversion
238  // framework to implement the conversion.
239  spirv::ModuleOp module = getOperation();
240  MLIRContext *context = &getContext();
241 
242  spirv::TargetEnv targetEnv(spirv::lookupTargetEnv(module));
243 
244  SPIRVTypeConverter typeConverter(targetEnv);
245 
246  // Insert a bitcast in the case of a pointer type change.
247  typeConverter.addSourceMaterialization([](OpBuilder &builder,
248  spirv::PointerType type,
249  ValueRange inputs, Location loc) {
250  if (inputs.size() != 1 || !inputs[0].getType().isa<spirv::PointerType>())
251  return Value();
252  return builder.create<spirv::BitcastOp>(loc, type, inputs[0]).getResult();
253  });
254 
255  RewritePatternSet patterns(context);
256  patterns.add<ProcessInterfaceVarABI>(typeConverter, context);
257 
258  ConversionTarget target(*context);
259  // "Legal" function ops should have no interface variable ABI attributes.
260  target.addDynamicallyLegalOp<spirv::FuncOp>([&](spirv::FuncOp op) {
261  StringRef attrName = spirv::getInterfaceVarABIAttrName();
262  for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i)
263  if (op.getArgAttr(i, attrName))
264  return false;
265  return true;
266  });
267  // All other SPIR-V ops are legal.
268  target.markUnknownOpDynamicallyLegal([](Operation *op) {
269  return op->getDialect()->getNamespace() ==
270  spirv::SPIRVDialect::getDialectNamespace();
271  });
272  if (failed(applyPartialConversion(module, target, std::move(patterns))))
273  return signalPassFailure();
274 
275  // Walks over all the FuncOps in spirv::ModuleOp to lower the entry point
276  // attributes.
277  OpBuilder builder(context);
278  SmallVector<spirv::FuncOp, 1> entryPointFns;
279  auto entryPointAttrName = spirv::getEntryPointABIAttrName();
280  module.walk([&](spirv::FuncOp funcOp) {
281  if (funcOp->getAttrOfType<spirv::EntryPointABIAttr>(entryPointAttrName)) {
282  entryPointFns.push_back(funcOp);
283  }
284  });
285  for (auto fn : entryPointFns) {
286  if (failed(lowerEntryPointABIAttr(fn, builder))) {
287  return signalPassFailure();
288  }
289  }
290 }
291 
292 std::unique_ptr<OperationPass<spirv::ModuleOp>>
294  return std::make_unique<LowerABIAttributesPass>();
295 }
Include the generated interface declarations.
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
TargetEnvAttr lookupTargetEnv(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op...
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
LogicalResult applyPartialConversion(ArrayRef< Operation *> ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation *> *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
Type getPointeeType() const
Definition: SPIRVTypes.cpp:404
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:344
An attribute that specifies the information regarding the interface variable: descriptor set...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
StringRef getEntryPointABIAttrName()
Returns the attribute name for specifying entry point information.
static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp, OpBuilder &builder)
Lowers the entry point attribute.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Construct a literal StructType with at least one member.
Definition: SPIRVTypes.cpp:961
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:404
static spirv::GlobalVariableOp createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp, unsigned argIndex, spirv::InterfaceVarABIAttr abiInfo)
Creates a global variable for an argument based on the ABI info.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
This class provides all of the information necessary to convert a type signature. ...
This class provides support for representing a failure result, or a valid value of type T...
Definition: LogicalResult.h:78
std::unique_ptr< OperationPass< spirv::ModuleOp > > createLowerABIAttributesPass()
Creates an operation pass that lowers the ABI attributes specified during SPIR-V Lowering.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:233
StringRef getInterfaceVarABIAttrName()
Returns the attribute name for specifying argument ABI information.
Attribute removeAttr(StringAttr name)
Remove the attribute with the specified name if it exists.
Definition: Operation.h:407
void addSourceMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a legal type to an illega...
uint32_t getDescriptorSet()
Returns descriptor set.
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
Definition: PatternMatch.h:499
uint32_t getBinding()
Returns binding.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabiliti...
Definition: TargetAndABI.h:28
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:400
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
FailureOr< ExecutionModel > getExecutionModel(TargetEnvAttr targetAttr)
Returns execution model selected based on target environment.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:377
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:294
StringRef getNamespace() const
Definition: Dialect.h:57
type_range getType() const
Definition: ValueRange.cpp:46
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:151
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
SPIR-V struct type.
Definition: SPIRVTypes.h:278
This class implements a pattern rewriter for use with ConversionPatterns.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:382
static spirv::StructType decorateType(spirv::StructType structType)
Returns a new StructType with layout decoration.
Definition: LayoutUtils.cpp:21
This class describes a specific conversion target.
static LogicalResult getInterfaceVariables(spirv::FuncOp funcOp, SmallVectorImpl< Attribute > &interfaceVars)
Gets the global variables that need to be specified as interface variable with an spv...
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:67
This class helps build Operations.
Definition: Builders.h:192
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:345
An attribute that specifies the target version, allowed extensions and capabilities, and resource limits.
Optional< StorageClass > getStorageClass()
Returns spirv::StorageClass.
Type conversion from builtin types to SPIR-V types for shader interface.
FailureOr< Block * > convertRegionTypes(Region *region, TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Convert the types of block arguments within the given region.
U cast() const
Definition: Types.h:278