MLIR  19.0.0git
LowerABIAttributesPass.cpp
Go to the documentation of this file.
1 //===- LowerABIAttributesPass.cpp - Decorate composite type ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to lower attributes that specify the shader ABI
10 // for the functions in the generated SPIR-V module.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
25 #include "llvm/ADT/SetVector.h"
26 
27 namespace mlir {
28 namespace spirv {
29 #define GEN_PASS_DEF_SPIRVLOWERABIATTRIBUTESPASS
30 #include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
31 } // namespace spirv
32 } // namespace mlir
33 
34 using namespace mlir;
35 
36 /// Creates a global variable for an argument based on the ABI info.
37 static spirv::GlobalVariableOp
38 createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp,
39  unsigned argIndex,
41  auto spirvModule = funcOp->getParentOfType<spirv::ModuleOp>();
42  if (!spirvModule)
43  return nullptr;
44 
45  OpBuilder::InsertionGuard moduleInsertionGuard(builder);
46  builder.setInsertionPoint(funcOp.getOperation());
47  std::string varName =
48  funcOp.getName().str() + "_arg_" + std::to_string(argIndex);
49 
50  // Get the type of variable. If this is a scalar/vector type and has an ABI
51  // info create a variable of type !spirv.ptr<!spirv.struct<elementType>>. If
52  // not it must already be a !spirv.ptr<!spirv.struct<...>>.
53  auto varType = funcOp.getFunctionType().getInput(argIndex);
54  if (cast<spirv::SPIRVType>(varType).isScalarOrVector()) {
55  auto storageClass = abiInfo.getStorageClass();
56  if (!storageClass)
57  return nullptr;
58  varType =
59  spirv::PointerType::get(spirv::StructType::get(varType), *storageClass);
60  }
61  auto varPtrType = cast<spirv::PointerType>(varType);
62  auto varPointeeType = cast<spirv::StructType>(varPtrType.getPointeeType());
63 
64  // Set the offset information.
65  varPointeeType =
66  cast<spirv::StructType>(VulkanLayoutUtils::decorateType(varPointeeType));
67 
68  if (!varPointeeType)
69  return nullptr;
70 
71  varType =
72  spirv::PointerType::get(varPointeeType, varPtrType.getStorageClass());
73 
74  return builder.create<spirv::GlobalVariableOp>(
75  funcOp.getLoc(), varType, varName, abiInfo.getDescriptorSet(),
76  abiInfo.getBinding());
77 }
78 
79 /// Gets the global variables that need to be specified as interface variable
80 /// with an spirv.EntryPointOp. Traverses the body of a entry function to do so.
81 static LogicalResult
82 getInterfaceVariables(spirv::FuncOp funcOp,
83  SmallVectorImpl<Attribute> &interfaceVars) {
84  auto module = funcOp->getParentOfType<spirv::ModuleOp>();
85  if (!module) {
86  return failure();
87  }
88  SetVector<Operation *> interfaceVarSet;
89 
90  // TODO: This should in reality traverse the entry function
91  // call graph and collect all the interfaces. For now, just traverse the
92  // instructions in this function.
93  funcOp.walk([&](spirv::AddressOfOp addressOfOp) {
94  auto var =
95  module.lookupSymbol<spirv::GlobalVariableOp>(addressOfOp.getVariable());
96  // TODO: Per SPIR-V spec: "Before version 1.4, the interface’s
97  // storage classes are limited to the Input and Output storage classes.
98  // Starting with version 1.4, the interface’s storage classes are all
99  // storage classes used in declaring all global variables referenced by the
100  // entry point’s call tree." We should consider the target environment here.
101  switch (cast<spirv::PointerType>(var.getType()).getStorageClass()) {
102  case spirv::StorageClass::Input:
103  case spirv::StorageClass::Output:
104  interfaceVarSet.insert(var.getOperation());
105  break;
106  default:
107  break;
108  }
109  });
110  for (auto &var : interfaceVarSet) {
111  interfaceVars.push_back(SymbolRefAttr::get(
112  funcOp.getContext(), cast<spirv::GlobalVariableOp>(var).getSymName()));
113  }
114  return success();
115 }
116 
117 /// Lowers the entry point attribute.
118 static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp,
119  OpBuilder &builder) {
120  auto entryPointAttrName = spirv::getEntryPointABIAttrName();
121  auto entryPointAttr =
122  funcOp->getAttrOfType<spirv::EntryPointABIAttr>(entryPointAttrName);
123  if (!entryPointAttr) {
124  return failure();
125  }
126 
127  OpBuilder::InsertionGuard moduleInsertionGuard(builder);
128  auto spirvModule = funcOp->getParentOfType<spirv::ModuleOp>();
129  builder.setInsertionPointToEnd(spirvModule.getBody());
130 
131  // Adds the spirv.EntryPointOp after collecting all the interface variables
132  // needed.
133  SmallVector<Attribute, 1> interfaceVars;
134  if (failed(getInterfaceVariables(funcOp, interfaceVars))) {
135  return failure();
136  }
137 
138  spirv::TargetEnvAttr targetEnvAttr = spirv::lookupTargetEnv(funcOp);
139  spirv::TargetEnv targetEnv(targetEnvAttr);
140  FailureOr<spirv::ExecutionModel> executionModel =
141  spirv::getExecutionModel(targetEnvAttr);
142  if (failed(executionModel))
143  return funcOp.emitRemark("lower entry point failure: could not select "
144  "execution model based on 'spirv.target_env'");
145 
146  builder.create<spirv::EntryPointOp>(funcOp.getLoc(), *executionModel, funcOp,
147  interfaceVars);
148 
149  // Specifies the spirv.ExecutionModeOp.
150  if (DenseI32ArrayAttr workgroupSizeAttr = entryPointAttr.getWorkgroupSize()) {
151  std::optional<ArrayRef<spirv::Capability>> caps =
152  spirv::getCapabilities(spirv::ExecutionMode::LocalSize);
153  if (!caps || targetEnv.allows(*caps)) {
154  builder.create<spirv::ExecutionModeOp>(funcOp.getLoc(), funcOp,
155  spirv::ExecutionMode::LocalSize,
156  workgroupSizeAttr.asArrayRef());
157  // Erase workgroup size.
158  entryPointAttr = spirv::EntryPointABIAttr::get(
159  entryPointAttr.getContext(), DenseI32ArrayAttr(),
160  entryPointAttr.getSubgroupSize(), entryPointAttr.getTargetWidth());
161  }
162  }
163  if (std::optional<int> subgroupSize = entryPointAttr.getSubgroupSize()) {
164  std::optional<ArrayRef<spirv::Capability>> caps =
165  spirv::getCapabilities(spirv::ExecutionMode::SubgroupSize);
166  if (!caps || targetEnv.allows(*caps)) {
167  builder.create<spirv::ExecutionModeOp>(funcOp.getLoc(), funcOp,
168  spirv::ExecutionMode::SubgroupSize,
169  *subgroupSize);
170  // Erase subgroup size.
171  entryPointAttr = spirv::EntryPointABIAttr::get(
172  entryPointAttr.getContext(), entryPointAttr.getWorkgroupSize(),
173  std::nullopt, entryPointAttr.getTargetWidth());
174  }
175  }
176  if (std::optional<int> targetWidth = entryPointAttr.getTargetWidth()) {
177  std::optional<ArrayRef<spirv::Capability>> caps =
178  spirv::getCapabilities(spirv::ExecutionMode::SignedZeroInfNanPreserve);
179  if (!caps || targetEnv.allows(*caps)) {
180  builder.create<spirv::ExecutionModeOp>(
181  funcOp.getLoc(), funcOp,
182  spirv::ExecutionMode::SignedZeroInfNanPreserve, *targetWidth);
183  // Erase target width.
184  entryPointAttr = spirv::EntryPointABIAttr::get(
185  entryPointAttr.getContext(), entryPointAttr.getWorkgroupSize(),
186  entryPointAttr.getSubgroupSize(), std::nullopt);
187  }
188  }
189  if (entryPointAttr.getWorkgroupSize() || entryPointAttr.getSubgroupSize() ||
190  entryPointAttr.getTargetWidth())
191  funcOp->setAttr(entryPointAttrName, entryPointAttr);
192  else
193  funcOp->removeAttr(entryPointAttrName);
194  return success();
195 }
196 
197 namespace {
198 /// A pattern to convert function signature according to interface variable ABI
199 /// attributes.
200 ///
201 /// Specifically, this pattern creates global variables according to interface
202 /// variable ABI attributes attached to function arguments and converts all
203 /// function argument uses to those global variables. This is necessary because
204 /// Vulkan requires all shader entry points to be of void(void) type.
205 class ProcessInterfaceVarABI final : public OpConversionPattern<spirv::FuncOp> {
206 public:
208 
210  matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor,
211  ConversionPatternRewriter &rewriter) const override;
212 };
213 
214 /// Pass to implement the ABI information specified as attributes.
215 class LowerABIAttributesPass final
216  : public spirv::impl::SPIRVLowerABIAttributesPassBase<
217  LowerABIAttributesPass> {
218  void runOnOperation() override;
219 };
220 } // namespace
221 
222 LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
223  spirv::FuncOp funcOp, OpAdaptor adaptor,
224  ConversionPatternRewriter &rewriter) const {
225  if (!funcOp->getAttrOfType<spirv::EntryPointABIAttr>(
227  // TODO: Non-entry point functions are not handled.
228  return failure();
229  }
230  TypeConverter::SignatureConversion signatureConverter(
231  funcOp.getFunctionType().getNumInputs());
232 
233  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
234  auto indexType = typeConverter.getIndexType();
235 
236  auto attrName = spirv::getInterfaceVarABIAttrName();
237  for (const auto &argType :
238  llvm::enumerate(funcOp.getFunctionType().getInputs())) {
239  auto abiInfo = funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
240  argType.index(), attrName);
241  if (!abiInfo) {
242  // TODO: For non-entry point functions, it should be legal
243  // to pass around scalar/vector values and return a scalar/vector. For now
244  // non-entry point functions are not handled in this ABI lowering and will
245  // produce an error.
246  return failure();
247  }
248  spirv::GlobalVariableOp var = createGlobalVarForEntryPointArgument(
249  rewriter, funcOp, argType.index(), abiInfo);
250  if (!var)
251  return failure();
252 
253  OpBuilder::InsertionGuard funcInsertionGuard(rewriter);
254  rewriter.setInsertionPointToStart(&funcOp.front());
255  // Insert spirv::AddressOf and spirv::AccessChain operations.
256  Value replacement =
257  rewriter.create<spirv::AddressOfOp>(funcOp.getLoc(), var);
258  // Check if the arg is a scalar or vector type. In that case, the value
259  // needs to be loaded into registers.
260  // TODO: This is loading value of the scalar into registers
261  // at the start of the function. It is probably better to do the load just
262  // before the use. There might be multiple loads and currently there is no
263  // easy way to replace all uses with a sequence of operations.
264  if (cast<spirv::SPIRVType>(argType.value()).isScalarOrVector()) {
265  auto zero =
266  spirv::ConstantOp::getZero(indexType, funcOp.getLoc(), rewriter);
267  auto loadPtr = rewriter.create<spirv::AccessChainOp>(
268  funcOp.getLoc(), replacement, zero.getConstant());
269  replacement = rewriter.create<spirv::LoadOp>(funcOp.getLoc(), loadPtr);
270  }
271  signatureConverter.remapInput(argType.index(), replacement);
272  }
273  if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), *getTypeConverter(),
274  &signatureConverter)))
275  return failure();
276 
277  // Creates a new function with the update signature.
278  rewriter.modifyOpInPlace(funcOp, [&] {
279  funcOp.setType(rewriter.getFunctionType(
280  signatureConverter.getConvertedTypes(), std::nullopt));
281  });
282  return success();
283 }
284 
285 void LowerABIAttributesPass::runOnOperation() {
286  // Uses the signature conversion methodology of the dialect conversion
287  // framework to implement the conversion.
288  spirv::ModuleOp module = getOperation();
289  MLIRContext *context = &getContext();
290 
291  spirv::TargetEnvAttr targetEnvAttr = spirv::lookupTargetEnv(module);
292  if (!targetEnvAttr) {
293  module->emitOpError("missing SPIR-V target env attribute");
294  return signalPassFailure();
295  }
296  spirv::TargetEnv targetEnv(targetEnvAttr);
297 
298  SPIRVTypeConverter typeConverter(targetEnv);
299 
300  // Insert a bitcast in the case of a pointer type change.
301  typeConverter.addSourceMaterialization([](OpBuilder &builder,
302  spirv::PointerType type,
303  ValueRange inputs, Location loc) {
304  if (inputs.size() != 1 || !isa<spirv::PointerType>(inputs[0].getType()))
305  return Value();
306  return builder.create<spirv::BitcastOp>(loc, type, inputs[0]).getResult();
307  });
308 
309  RewritePatternSet patterns(context);
310  patterns.add<ProcessInterfaceVarABI>(typeConverter, context);
311 
312  ConversionTarget target(*context);
313  // "Legal" function ops should have no interface variable ABI attributes.
314  target.addDynamicallyLegalOp<spirv::FuncOp>([&](spirv::FuncOp op) {
315  StringRef attrName = spirv::getInterfaceVarABIAttrName();
316  for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i)
317  if (op.getArgAttr(i, attrName))
318  return false;
319  return true;
320  });
321  // All other SPIR-V ops are legal.
322  target.markUnknownOpDynamicallyLegal([](Operation *op) {
323  return op->getDialect()->getNamespace() ==
324  spirv::SPIRVDialect::getDialectNamespace();
325  });
326  if (failed(applyPartialConversion(module, target, std::move(patterns))))
327  return signalPassFailure();
328 
329  // Walks over all the FuncOps in spirv::ModuleOp to lower the entry point
330  // attributes.
331  OpBuilder builder(context);
332  SmallVector<spirv::FuncOp, 1> entryPointFns;
333  auto entryPointAttrName = spirv::getEntryPointABIAttrName();
334  module.walk([&](spirv::FuncOp funcOp) {
335  if (funcOp->getAttrOfType<spirv::EntryPointABIAttr>(entryPointAttrName)) {
336  entryPointFns.push_back(funcOp);
337  }
338  });
339  for (auto fn : entryPointFns) {
340  if (failed(lowerEntryPointABIAttr(fn, builder))) {
341  return signalPassFailure();
342  }
343  }
344 }
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
static spirv::GlobalVariableOp createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp, unsigned argIndex, spirv::InterfaceVarABIAttr abiInfo)
Creates a global variable for an argument based on the ABI info.
static LogicalResult getInterfaceVariables(spirv::FuncOp funcOp, SmallVectorImpl< Attribute > &interfaceVars)
Gets the global variables that need to be specified as interface variables with an spirv.EntryPointOp.
static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp, OpBuilder &builder)
Lowers the entry point attribute.
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:96
This class implements a pattern rewriter for use with ConversionPatterns.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Convert the types of block arguments within the given region.
This class describes a specific conversion target.
StringRef getNamespace() const
Definition: Dialect.h:57
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
This class helps build Operations.
Definition: Builders.h:209
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:433
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:438
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loaded.
Definition: Operation.h:220
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
Type conversion from builtin types to SPIR-V types for shader interface.
This class provides all of the information necessary to convert a type signature.
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
static spirv::StructType decorateType(spirv::StructType structType)
Returns a new StructType with layout decoration.
Definition: LayoutUtils.cpp:21
An attribute that specifies the information regarding the interface variable: descriptor set,...
uint32_t getBinding()
Returns binding.
uint32_t getDescriptorSet()
Returns descriptor set.
std::optional< StorageClass > getStorageClass()
Returns spirv::StorageClass.
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:481
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Construct a literal StructType with at least one member.
An attribute that specifies the target version, allowed extensions and capabilities,...
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabiliti...
Definition: TargetAndABI.h:29
bool allows(Capability) const
Returns true if the given capability is allowed.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
StringRef getInterfaceVarABIAttrName()
Returns the attribute name for specifying argument ABI information.
TargetEnvAttr lookupTargetEnv(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op.
FailureOr< ExecutionModel > getExecutionModel(TargetEnvAttr targetAttr)
Returns execution model selected based on target environment.
StringRef getEntryPointABIAttrName()
Returns the attribute name for specifying entry point information.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26