MLIR  22.0.0git
LowerABIAttributesPass.cpp
Go to the documentation of this file.
1 //===- LowerABIAttributesPass.cpp - Decorate composite type ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to lower attributes that specify the shader ABI
10 // for the functions in the generated SPIR-V module.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
25 
26 namespace mlir {
27 namespace spirv {
28 #define GEN_PASS_DEF_SPIRVLOWERABIATTRIBUTESPASS
29 #include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
30 } // namespace spirv
31 } // namespace mlir
32 
33 using namespace mlir;
34 
35 /// Creates a global variable for an argument based on the ABI info.
36 static spirv::GlobalVariableOp
37 createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp,
38  unsigned argIndex,
40  auto spirvModule = funcOp->getParentOfType<spirv::ModuleOp>();
41  if (!spirvModule)
42  return nullptr;
43 
44  OpBuilder::InsertionGuard moduleInsertionGuard(builder);
45  builder.setInsertionPoint(funcOp.getOperation());
46  std::string varName =
47  funcOp.getName().str() + "_arg_" + std::to_string(argIndex);
48 
49  // Get the type of variable. If this is a scalar/vector type and has an ABI
50  // info create a variable of type !spirv.ptr<!spirv.struct<elementType>>. If
51  // not it must already be a !spirv.ptr<!spirv.struct<...>>.
52  auto varType = funcOp.getFunctionType().getInput(argIndex);
53  if (cast<spirv::SPIRVType>(varType).isScalarOrVector()) {
54  auto storageClass = abiInfo.getStorageClass();
55  if (!storageClass)
56  return nullptr;
57  varType =
58  spirv::PointerType::get(spirv::StructType::get(varType), *storageClass);
59  }
60  auto varPtrType = cast<spirv::PointerType>(varType);
61  Type pointeeType = varPtrType.getPointeeType();
62 
63  // Images are an opaque type and so we can just return a pointer to an image.
64  // Note that currently only sampled images are supported in the SPIR-V
65  // lowering.
66  if (isa<spirv::SampledImageType>(pointeeType))
67  return spirv::GlobalVariableOp::create(builder, funcOp.getLoc(), varType,
68  varName, abiInfo.getDescriptorSet(),
69  abiInfo.getBinding());
70 
71  auto varPointeeType = cast<spirv::StructType>(pointeeType);
72 
73  // Set the offset information.
74  varPointeeType =
75  cast<spirv::StructType>(VulkanLayoutUtils::decorateType(varPointeeType));
76 
77  if (!varPointeeType)
78  return nullptr;
79 
80  varType =
81  spirv::PointerType::get(varPointeeType, varPtrType.getStorageClass());
82 
83  return spirv::GlobalVariableOp::create(builder, funcOp.getLoc(), varType,
84  varName, abiInfo.getDescriptorSet(),
85  abiInfo.getBinding());
86 }
87 
88 /// Gets the global variables that need to be specified as interface variable
89 /// with an spirv.EntryPointOp. Traverses the body of a entry function to do so.
90 static LogicalResult
91 getInterfaceVariables(spirv::FuncOp funcOp,
92  SmallVectorImpl<Attribute> &interfaceVars) {
93  auto module = funcOp->getParentOfType<spirv::ModuleOp>();
94  if (!module) {
95  return failure();
96  }
97  spirv::TargetEnvAttr targetEnvAttr = spirv::lookupTargetEnv(funcOp);
98  spirv::TargetEnv targetEnv(targetEnvAttr);
99 
100  SetVector<Operation *> interfaceVarSet;
101 
102  // TODO: This should in reality traverse the entry function
103  // call graph and collect all the interfaces. For now, just traverse the
104  // instructions in this function.
105  funcOp.walk([&](spirv::AddressOfOp addressOfOp) {
106  auto var =
107  module.lookupSymbol<spirv::GlobalVariableOp>(addressOfOp.getVariable());
108  // Per SPIR-V spec: "Before version 1.4, the interface's
109  // storage classes are limited to the Input and Output storage classes.
110  // Starting with version 1.4, the interface's storage classes are all
111  // storage classes used in declaring all global variables referenced by the
112  // entry point’s call tree."
113  const spirv::StorageClass storageClass =
114  cast<spirv::PointerType>(var.getType()).getStorageClass();
115  if ((targetEnvAttr && targetEnv.getVersion() >= spirv::Version::V_1_4) ||
116  (llvm::is_contained(
117  {spirv::StorageClass::Input, spirv::StorageClass::Output},
118  storageClass))) {
119  interfaceVarSet.insert(var.getOperation());
120  }
121  });
122  for (auto &var : interfaceVarSet) {
123  interfaceVars.push_back(SymbolRefAttr::get(
124  funcOp.getContext(), cast<spirv::GlobalVariableOp>(var).getSymName()));
125  }
126  return success();
127 }
128 
129 /// Lowers the entry point attribute.
130 static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp,
131  OpBuilder &builder) {
132  auto entryPointAttrName = spirv::getEntryPointABIAttrName();
133  auto entryPointAttr =
134  funcOp->getAttrOfType<spirv::EntryPointABIAttr>(entryPointAttrName);
135  if (!entryPointAttr) {
136  return failure();
137  }
138 
139  spirv::TargetEnvAttr targetEnvAttr = spirv::lookupTargetEnv(funcOp);
140  spirv::TargetEnv targetEnv(targetEnvAttr);
141 
142  OpBuilder::InsertionGuard moduleInsertionGuard(builder);
143  auto spirvModule = funcOp->getParentOfType<spirv::ModuleOp>();
144  builder.setInsertionPointToEnd(spirvModule.getBody());
145 
146  // Adds the spirv.EntryPointOp after collecting all the interface variables
147  // needed.
148  SmallVector<Attribute, 1> interfaceVars;
149  if (failed(getInterfaceVariables(funcOp, interfaceVars))) {
150  return failure();
151  }
152 
153  FailureOr<spirv::ExecutionModel> executionModel =
154  spirv::getExecutionModel(targetEnvAttr);
155  if (failed(executionModel))
156  return funcOp.emitRemark("lower entry point failure: could not select "
157  "execution model based on 'spirv.target_env'");
158 
159  spirv::EntryPointOp::create(builder, funcOp.getLoc(), *executionModel, funcOp,
160  interfaceVars);
161 
162  // Specifies the spirv.ExecutionModeOp.
163  if (DenseI32ArrayAttr workgroupSizeAttr = entryPointAttr.getWorkgroupSize()) {
164  std::optional<ArrayRef<spirv::Capability>> caps =
165  spirv::getCapabilities(spirv::ExecutionMode::LocalSize);
166  if (!caps || targetEnv.allows(*caps)) {
167  spirv::ExecutionModeOp::create(builder, funcOp.getLoc(), funcOp,
168  spirv::ExecutionMode::LocalSize,
169  workgroupSizeAttr.asArrayRef());
170  // Erase workgroup size.
171  entryPointAttr = spirv::EntryPointABIAttr::get(
172  entryPointAttr.getContext(), DenseI32ArrayAttr(),
173  entryPointAttr.getSubgroupSize(), entryPointAttr.getTargetWidth());
174  }
175  }
176  if (std::optional<int> subgroupSize = entryPointAttr.getSubgroupSize()) {
177  std::optional<ArrayRef<spirv::Capability>> caps =
178  spirv::getCapabilities(spirv::ExecutionMode::SubgroupSize);
179  if (!caps || targetEnv.allows(*caps)) {
180  spirv::ExecutionModeOp::create(builder, funcOp.getLoc(), funcOp,
181  spirv::ExecutionMode::SubgroupSize,
182  *subgroupSize);
183  // Erase subgroup size.
184  entryPointAttr = spirv::EntryPointABIAttr::get(
185  entryPointAttr.getContext(), entryPointAttr.getWorkgroupSize(),
186  std::nullopt, entryPointAttr.getTargetWidth());
187  }
188  }
189  if (std::optional<int> targetWidth = entryPointAttr.getTargetWidth()) {
190  std::optional<ArrayRef<spirv::Capability>> caps =
191  spirv::getCapabilities(spirv::ExecutionMode::SignedZeroInfNanPreserve);
192  if (!caps || targetEnv.allows(*caps)) {
193  spirv::ExecutionModeOp::create(
194  builder, funcOp.getLoc(), funcOp,
195  spirv::ExecutionMode::SignedZeroInfNanPreserve, *targetWidth);
196  // Erase target width.
197  entryPointAttr = spirv::EntryPointABIAttr::get(
198  entryPointAttr.getContext(), entryPointAttr.getWorkgroupSize(),
199  entryPointAttr.getSubgroupSize(), std::nullopt);
200  }
201  }
202  if (entryPointAttr.getWorkgroupSize() || entryPointAttr.getSubgroupSize() ||
203  entryPointAttr.getTargetWidth())
204  funcOp->setAttr(entryPointAttrName, entryPointAttr);
205  else
206  funcOp->removeAttr(entryPointAttrName);
207  return success();
208 }
209 
210 namespace {
211 /// A pattern to convert function signature according to interface variable ABI
212 /// attributes.
213 ///
214 /// Specifically, this pattern creates global variables according to interface
215 /// variable ABI attributes attached to function arguments and converts all
216 /// function argument uses to those global variables. This is necessary because
217 /// Vulkan requires all shader entry points to be of void(void) type.
218 class ProcessInterfaceVarABI final : public OpConversionPattern<spirv::FuncOp> {
219 public:
221 
222  LogicalResult
223  matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor,
224  ConversionPatternRewriter &rewriter) const override;
225 };
226 
227 /// Pass to implement the ABI information specified as attributes.
228 class LowerABIAttributesPass final
229  : public spirv::impl::SPIRVLowerABIAttributesPassBase<
230  LowerABIAttributesPass> {
231  void runOnOperation() override;
232 };
233 } // namespace
234 
235 LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
236  spirv::FuncOp funcOp, OpAdaptor adaptor,
237  ConversionPatternRewriter &rewriter) const {
238  if (!funcOp->getAttrOfType<spirv::EntryPointABIAttr>(
240  // TODO: Non-entry point functions are not handled.
241  return failure();
242  }
243  TypeConverter::SignatureConversion signatureConverter(
244  funcOp.getFunctionType().getNumInputs());
245 
246  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
247  auto indexType = typeConverter.getIndexType();
248 
249  auto attrName = spirv::getInterfaceVarABIAttrName();
250 
251  OpBuilder::InsertionGuard funcInsertionGuard(rewriter);
252  rewriter.setInsertionPointToStart(&funcOp.front());
253 
254  for (const auto &argType :
255  llvm::enumerate(funcOp.getFunctionType().getInputs())) {
256  auto abiInfo = funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
257  argType.index(), attrName);
258  if (!abiInfo) {
259  // TODO: For non-entry point functions, it should be legal
260  // to pass around scalar/vector values and return a scalar/vector. For now
261  // non-entry point functions are not handled in this ABI lowering and will
262  // produce an error.
263  return failure();
264  }
265  spirv::GlobalVariableOp var = createGlobalVarForEntryPointArgument(
266  rewriter, funcOp, argType.index(), abiInfo);
267  if (!var)
268  return failure();
269 
270  // Insert spirv::AddressOf and spirv::AccessChain operations.
271  Value replacement =
272  spirv::AddressOfOp::create(rewriter, funcOp.getLoc(), var);
273  // Check if the arg is a scalar or vector type. In that case, the value
274  // needs to be loaded into registers.
275  // TODO: This is loading value of the scalar into registers
276  // at the start of the function. It is probably better to do the load just
277  // before the use. There might be multiple loads and currently there is no
278  // easy way to replace all uses with a sequence of operations.
279  if (cast<spirv::SPIRVType>(argType.value()).isScalarOrVector()) {
280  auto zero =
281  spirv::ConstantOp::getZero(indexType, funcOp.getLoc(), rewriter);
282  auto loadPtr = spirv::AccessChainOp::create(
283  rewriter, funcOp.getLoc(), replacement, zero.getConstant());
284  replacement = spirv::LoadOp::create(rewriter, funcOp.getLoc(), loadPtr);
285  }
286  signatureConverter.remapInput(argType.index(), replacement);
287  }
288  if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), *getTypeConverter(),
289  &signatureConverter)))
290  return failure();
291 
292  // Creates a new function with the update signature.
293  rewriter.modifyOpInPlace(funcOp, [&] {
294  funcOp.setType(
295  rewriter.getFunctionType(signatureConverter.getConvertedTypes(), {}));
296  });
297  return success();
298 }
299 
300 void LowerABIAttributesPass::runOnOperation() {
301  // Uses the signature conversion methodology of the dialect conversion
302  // framework to implement the conversion.
303  spirv::ModuleOp module = getOperation();
304  MLIRContext *context = &getContext();
305 
306  spirv::TargetEnvAttr targetEnvAttr = spirv::lookupTargetEnv(module);
307  if (!targetEnvAttr) {
308  module->emitOpError("missing SPIR-V target env attribute");
309  return signalPassFailure();
310  }
311  spirv::TargetEnv targetEnv(targetEnvAttr);
312 
313  SPIRVTypeConverter typeConverter(targetEnv);
314 
315  // Insert a bitcast in the case of a pointer type change.
316  typeConverter.addSourceMaterialization([](OpBuilder &builder,
317  spirv::PointerType type,
318  ValueRange inputs, Location loc) {
319  if (inputs.size() != 1 || !isa<spirv::PointerType>(inputs[0].getType()))
320  return Value();
321  return spirv::BitcastOp::create(builder, loc, type, inputs[0]).getResult();
322  });
323 
324  RewritePatternSet patterns(context);
325  patterns.add<ProcessInterfaceVarABI>(typeConverter, context);
326 
327  ConversionTarget target(*context);
328  // "Legal" function ops should have no interface variable ABI attributes.
329  target.addDynamicallyLegalOp<spirv::FuncOp>([&](spirv::FuncOp op) {
330  StringRef attrName = spirv::getInterfaceVarABIAttrName();
331  for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i)
332  if (op.getArgAttr(i, attrName))
333  return false;
334  return true;
335  });
336  // All other SPIR-V ops are legal.
337  target.markUnknownOpDynamicallyLegal([](Operation *op) {
338  return op->getDialect()->getNamespace() ==
339  spirv::SPIRVDialect::getDialectNamespace();
340  });
341  if (failed(applyPartialConversion(module, target, std::move(patterns))))
342  return signalPassFailure();
343 
344  // Walks over all the FuncOps in spirv::ModuleOp to lower the entry point
345  // attributes.
346  OpBuilder builder(context);
347  SmallVector<spirv::FuncOp, 1> entryPointFns;
348  auto entryPointAttrName = spirv::getEntryPointABIAttrName();
349  module.walk([&](spirv::FuncOp funcOp) {
350  if (funcOp->getAttrOfType<spirv::EntryPointABIAttr>(entryPointAttrName)) {
351  entryPointFns.push_back(funcOp);
352  }
353  });
354  for (auto fn : entryPointFns) {
355  if (failed(lowerEntryPointABIAttr(fn, builder))) {
356  return signalPassFailure();
357  }
358  }
359 }
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
static spirv::GlobalVariableOp createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp, unsigned argIndex, spirv::InterfaceVarABIAttr abiInfo)
Creates a global variable for an argument based on the ABI info.
static LogicalResult getInterfaceVariables(spirv::FuncOp funcOp, SmallVectorImpl< Attribute > &interfaceVars)
Gets the global variables that need to be specified as interface variables with a spirv.EntryPointOp.
static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp, OpBuilder &builder)
Lowers the entry point attribute.
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:75
This class implements a pattern rewriter for use with ConversionPatterns.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
This class describes a specific conversion target.
StringRef getNamespace() const
Definition: Dialect.h:54
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:348
This class helps build Operations.
Definition: Builders.h:207
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:431
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:398
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:436
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:220
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:628
Type conversion from builtin types to SPIR-V types for shader interface.
This class provides all of the information necessary to convert a type signature.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
static spirv::StructType decorateType(spirv::StructType structType)
Returns a new StructType with layout decoration.
Definition: LayoutUtils.cpp:20
An attribute that specifies the information regarding the interface variable: descriptor set,...
uint32_t getBinding()
Returns binding.
uint32_t getDescriptorSet()
Returns descriptor set.
std::optional< StorageClass > getStorageClass()
Returns spirv::StorageClass.
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:447
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Construct a literal StructType with at least one member.
An attribute that specifies the target version, allowed extensions and capabilities,...
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabiliti...
Definition: TargetAndABI.h:29
Version getVersion() const
bool allows(Capability) const
Returns true if the given capability is allowed.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
StringRef getInterfaceVarABIAttrName()
Returns the attribute name for specifying argument ABI information.
TargetEnvAttr lookupTargetEnv(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op.
FailureOr< ExecutionModel > getExecutionModel(TargetEnvAttr targetAttr)
Returns execution model selected based on target environment.
StringRef getEntryPointABIAttrName()
Returns the attribute name for specifying entry point information.
constexpr unsigned subgroupSize
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.