MLIR 22.0.0git
LowerABIAttributesPass.cpp
Go to the documentation of this file.
1//===- LowerABIAttributesPass.cpp - Lower ABI attributes ------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a pass to lower attributes that specify the shader ABI
10// for the functions in the generated SPIR-V module.
11//
12//===----------------------------------------------------------------------===//
13
15
25#include "llvm/Support/FormatVariadic.h"
26
27namespace mlir {
28namespace spirv {
29#define GEN_PASS_DEF_SPIRVLOWERABIATTRIBUTESPASS
30#include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
31} // namespace spirv
32} // namespace mlir
33
34using namespace mlir;
35
36/// Creates a global variable for an argument based on the ABI info.
37static spirv::GlobalVariableOp
38createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp,
39 unsigned argIndex,
41 auto spirvModule = funcOp->getParentOfType<spirv::ModuleOp>();
42 if (!spirvModule)
43 return nullptr;
44
45 OpBuilder::InsertionGuard moduleInsertionGuard(builder);
46 builder.setInsertionPoint(funcOp.getOperation());
47 std::string varName =
48 funcOp.getName().str() + "_arg_" + std::to_string(argIndex);
49
50 // Get the type of variable. If this is a scalar/vector type and has an ABI
51 // info create a variable of type !spirv.ptr<!spirv.struct<elementType>>. If
52 // not it must already be a !spirv.ptr<!spirv.struct<...>>.
53 auto varType = funcOp.getFunctionType().getInput(argIndex);
54 if (cast<spirv::SPIRVType>(varType).isScalarOrVector()) {
55 auto storageClass = abiInfo.getStorageClass();
56 if (!storageClass)
57 return nullptr;
58 varType =
59 spirv::PointerType::get(spirv::StructType::get(varType), *storageClass);
60 }
61 auto varPtrType = cast<spirv::PointerType>(varType);
62 Type pointeeType = varPtrType.getPointeeType();
63
64 // Images are an opaque type and so we can just return a pointer to an image.
65 // Note that currently only sampled images are supported in the SPIR-V
66 // lowering.
67 if (isa<spirv::SampledImageType>(pointeeType))
68 return spirv::GlobalVariableOp::create(builder, funcOp.getLoc(), varType,
69 varName, abiInfo.getDescriptorSet(),
70 abiInfo.getBinding());
71
72 auto varPointeeType = cast<spirv::StructType>(pointeeType);
73
74 // Set the offset information.
75 varPointeeType =
76 cast<spirv::StructType>(VulkanLayoutUtils::decorateType(varPointeeType));
77
78 if (!varPointeeType)
79 return nullptr;
80
81 varType =
82 spirv::PointerType::get(varPointeeType, varPtrType.getStorageClass());
83
84 return spirv::GlobalVariableOp::create(builder, funcOp.getLoc(), varType,
85 varName, abiInfo.getDescriptorSet(),
86 abiInfo.getBinding());
87}
88
89/// Creates a global variable for an argument or result based on the ABI info.
90static spirv::GlobalVariableOp
91createGlobalVarForGraphEntryPoint(OpBuilder &builder, spirv::GraphARMOp graphOp,
92 unsigned index, bool isArg,
94 auto spirvModule = graphOp->getParentOfType<spirv::ModuleOp>();
95 if (!spirvModule)
96 return nullptr;
97
98 OpBuilder::InsertionGuard moduleInsertionGuard(builder);
99 builder.setInsertionPoint(graphOp.getOperation());
100 std::string varName = llvm::formatv("{}_{}_{}", graphOp.getName(),
101 isArg ? "arg" : "res", index);
102
103 Type varType = isArg ? graphOp.getFunctionType().getInput(index)
104 : graphOp.getFunctionType().getResult(index);
105
106 auto pointerType = spirv::PointerType::get(
107 varType,
108 abiInfo.getStorageClass().value_or(spirv::StorageClass::UniformConstant));
109
110 return spirv::GlobalVariableOp::create(builder, graphOp.getLoc(), pointerType,
111 varName, abiInfo.getDescriptorSet(),
112 abiInfo.getBinding());
113}
114
115/// Gets the global variables that need to be specified as interface variable
116/// with an spirv.EntryPointOp. Traverses the body of a entry function to do so.
117static LogicalResult
118getInterfaceVariables(mlir::FunctionOpInterface funcOp,
119 SmallVectorImpl<Attribute> &interfaceVars) {
120 auto module = funcOp->getParentOfType<spirv::ModuleOp>();
121 if (!module) {
122 return failure();
123 }
124 spirv::TargetEnvAttr targetEnvAttr = spirv::lookupTargetEnv(funcOp);
125 spirv::TargetEnv targetEnv(targetEnvAttr);
126
127 SetVector<Operation *> interfaceVarSet;
128
129 // TODO: This should in reality traverse the entry function
130 // call graph and collect all the interfaces. For now, just traverse the
131 // instructions in this function.
132 funcOp.walk([&](spirv::AddressOfOp addressOfOp) {
133 auto var =
134 module.lookupSymbol<spirv::GlobalVariableOp>(addressOfOp.getVariable());
135 // Per SPIR-V spec: "Before version 1.4, the interface's
136 // storage classes are limited to the Input and Output storage classes.
137 // Starting with version 1.4, the interface's storage classes are all
138 // storage classes used in declaring all global variables referenced by the
139 // entry point’s call tree."
140 const spirv::StorageClass storageClass =
141 cast<spirv::PointerType>(var.getType()).getStorageClass();
142 if ((targetEnvAttr && targetEnv.getVersion() >= spirv::Version::V_1_4) ||
143 (llvm::is_contained(
144 {spirv::StorageClass::Input, spirv::StorageClass::Output},
145 storageClass))) {
146 interfaceVarSet.insert(var.getOperation());
147 }
148 });
149 for (auto &var : interfaceVarSet) {
150 interfaceVars.push_back(SymbolRefAttr::get(
151 funcOp.getContext(), cast<spirv::GlobalVariableOp>(var).getSymName()));
152 }
153 return success();
154}
155
156/// Lowers the entry point attribute.
157static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp,
158 OpBuilder &builder) {
159 auto entryPointAttrName = spirv::getEntryPointABIAttrName();
160 auto entryPointAttr =
161 funcOp->getAttrOfType<spirv::EntryPointABIAttr>(entryPointAttrName);
162 if (!entryPointAttr) {
163 return failure();
164 }
165
166 spirv::TargetEnvAttr targetEnvAttr = spirv::lookupTargetEnv(funcOp);
167 spirv::TargetEnv targetEnv(targetEnvAttr);
168
169 OpBuilder::InsertionGuard moduleInsertionGuard(builder);
170 auto spirvModule = funcOp->getParentOfType<spirv::ModuleOp>();
171 builder.setInsertionPointToEnd(spirvModule.getBody());
172
173 // Adds the spirv.EntryPointOp after collecting all the interface variables
174 // needed.
175 SmallVector<Attribute, 1> interfaceVars;
176 if (failed(getInterfaceVariables(funcOp, interfaceVars))) {
177 return failure();
178 }
179
180 FailureOr<spirv::ExecutionModel> executionModel =
181 spirv::getExecutionModel(targetEnvAttr);
182 if (failed(executionModel))
183 return funcOp.emitRemark("lower entry point failure: could not select "
184 "execution model based on 'spirv.target_env'");
185
186 spirv::EntryPointOp::create(builder, funcOp.getLoc(), *executionModel, funcOp,
187 interfaceVars);
189 // Specifies the spirv.ExecutionModeOp.
190 if (DenseI32ArrayAttr workgroupSizeAttr = entryPointAttr.getWorkgroupSize()) {
191 std::optional<ArrayRef<spirv::Capability>> caps =
192 spirv::getCapabilities(spirv::ExecutionMode::LocalSize);
193 if (!caps || targetEnv.allows(*caps)) {
194 spirv::ExecutionModeOp::create(builder, funcOp.getLoc(), funcOp,
195 spirv::ExecutionMode::LocalSize,
196 workgroupSizeAttr.asArrayRef());
197 // Erase workgroup size.
198 entryPointAttr = spirv::EntryPointABIAttr::get(
199 entryPointAttr.getContext(), DenseI32ArrayAttr(),
200 entryPointAttr.getSubgroupSize(), entryPointAttr.getTargetWidth());
201 }
203 if (std::optional<int> subgroupSize = entryPointAttr.getSubgroupSize()) {
204 std::optional<ArrayRef<spirv::Capability>> caps =
205 spirv::getCapabilities(spirv::ExecutionMode::SubgroupSize);
206 if (!caps || targetEnv.allows(*caps)) {
207 spirv::ExecutionModeOp::create(builder, funcOp.getLoc(), funcOp,
208 spirv::ExecutionMode::SubgroupSize,
209 *subgroupSize);
210 // Erase subgroup size.
211 entryPointAttr = spirv::EntryPointABIAttr::get(
212 entryPointAttr.getContext(), entryPointAttr.getWorkgroupSize(),
213 std::nullopt, entryPointAttr.getTargetWidth());
214 }
215 }
216 if (std::optional<int> targetWidth = entryPointAttr.getTargetWidth()) {
217 std::optional<ArrayRef<spirv::Capability>> caps =
218 spirv::getCapabilities(spirv::ExecutionMode::SignedZeroInfNanPreserve);
219 if (!caps || targetEnv.allows(*caps)) {
220 spirv::ExecutionModeOp::create(
221 builder, funcOp.getLoc(), funcOp,
222 spirv::ExecutionMode::SignedZeroInfNanPreserve, *targetWidth);
223 // Erase target width.
224 entryPointAttr = spirv::EntryPointABIAttr::get(
225 entryPointAttr.getContext(), entryPointAttr.getWorkgroupSize(),
226 entryPointAttr.getSubgroupSize(), std::nullopt);
227 }
228 }
229 if (entryPointAttr.getWorkgroupSize() || entryPointAttr.getSubgroupSize() ||
230 entryPointAttr.getTargetWidth())
231 funcOp->setAttr(entryPointAttrName, entryPointAttr);
232 else
233 funcOp->removeAttr(entryPointAttrName);
234 return success();
235}
237namespace {
238/// A pattern to convert function signature according to interface variable ABI
239/// attributes.
240///
241/// Specifically, this pattern creates global variables according to interface
242/// variable ABI attributes attached to function arguments and converts all
243/// function argument uses to those global variables. This is necessary because
244/// Vulkan requires all shader entry points to be of void(void) type.
245class ProcessInterfaceVarABI final : public OpConversionPattern<spirv::FuncOp> {
246public:
247 using Base::Base;
248
249 LogicalResult
250 matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor,
251 ConversionPatternRewriter &rewriter) const override;
252};
253
254/// A pattern to convert graph signature according to interface variable ABI
255/// attributes.
256///
257/// Specifically, this pattern creates global variables according to interface
258/// variable ABI attributes attached to graph arguments and results.
259class ProcessGraphInterfaceVarABI final
260 : public OpConversionPattern<spirv::GraphARMOp> {
261public:
262 using OpConversionPattern::OpConversionPattern;
263
264 LogicalResult
265 matchAndRewrite(spirv::GraphARMOp graphOp, OpAdaptor adaptor,
266 ConversionPatternRewriter &rewriter) const override;
267};
268
269/// Pass to implement the ABI information specified as attributes.
270class LowerABIAttributesPass final
272 LowerABIAttributesPass> {
273 void runOnOperation() override;
274};
275} // namespace
276
277LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
278 spirv::FuncOp funcOp, OpAdaptor adaptor,
279 ConversionPatternRewriter &rewriter) const {
280 if (!funcOp->getAttrOfType<spirv::EntryPointABIAttr>(
282 // TODO: Non-entry point functions are not handled.
283 return failure();
284 }
285 TypeConverter::SignatureConversion signatureConverter(
286 funcOp.getFunctionType().getNumInputs());
287
288 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
289 auto indexType = typeConverter.getIndexType();
290
291 auto attrName = spirv::getInterfaceVarABIAttrName();
292
293 OpBuilder::InsertionGuard funcInsertionGuard(rewriter);
294 rewriter.setInsertionPointToStart(&funcOp.front());
295
296 for (const auto &argType :
297 llvm::enumerate(funcOp.getFunctionType().getInputs())) {
298 auto abiInfo = funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
299 argType.index(), attrName);
300 if (!abiInfo) {
301 // TODO: For non-entry point functions, it should be legal
302 // to pass around scalar/vector values and return a scalar/vector. For now
303 // non-entry point functions are not handled in this ABI lowering and will
304 // produce an error.
305 return failure();
306 }
307 spirv::GlobalVariableOp var = createGlobalVarForEntryPointArgument(
308 rewriter, funcOp, argType.index(), abiInfo);
309 if (!var)
310 return failure();
311
312 // Insert spirv::AddressOf and spirv::AccessChain operations.
313 Value replacement =
314 spirv::AddressOfOp::create(rewriter, funcOp.getLoc(), var);
315 // Check if the arg is a scalar or vector type. In that case, the value
316 // needs to be loaded into registers.
317 // TODO: This is loading value of the scalar into registers
318 // at the start of the function. It is probably better to do the load just
319 // before the use. There might be multiple loads and currently there is no
320 // easy way to replace all uses with a sequence of operations.
321 if (cast<spirv::SPIRVType>(argType.value()).isScalarOrVector()) {
322 auto zero =
323 spirv::ConstantOp::getZero(indexType, funcOp.getLoc(), rewriter);
324 auto loadPtr = spirv::AccessChainOp::create(
325 rewriter, funcOp.getLoc(), replacement, zero.getConstant());
326 replacement = spirv::LoadOp::create(rewriter, funcOp.getLoc(), loadPtr);
327 }
328 signatureConverter.remapInput(argType.index(), replacement);
329 }
330 if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), *getTypeConverter(),
331 &signatureConverter)))
332 return failure();
333
334 // Creates a new function with the update signature.
335 rewriter.modifyOpInPlace(funcOp, [&] {
336 funcOp.setType(
337 rewriter.getFunctionType(signatureConverter.getConvertedTypes(), {}));
338 });
339 return success();
340}
341
342LogicalResult ProcessGraphInterfaceVarABI::matchAndRewrite(
343 spirv::GraphARMOp graphOp, OpAdaptor adaptor,
344 ConversionPatternRewriter &rewriter) const {
345 // Non-entry point graphs are not handled.
346 if (!graphOp.getEntryPoint().value_or(false))
347 return failure();
348
349 TypeConverter::SignatureConversion signatureConverter(
350 graphOp.getFunctionType().getNumInputs());
351
352 StringRef attrName = spirv::getInterfaceVarABIAttrName();
353 SmallVector<Attribute, 4> interfaceVars;
354
355 // Convert arguments.
356 unsigned numInputs = graphOp.getFunctionType().getNumInputs();
357 unsigned numResults = graphOp.getFunctionType().getNumResults();
358 for (unsigned index = 0; index < numInputs; ++index) {
359 auto abiInfo =
360 graphOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(index, attrName);
361 if (!abiInfo)
362 return failure();
363 spirv::GlobalVariableOp var = createGlobalVarForGraphEntryPoint(
364 rewriter, graphOp, index, true, abiInfo);
365 if (!var)
366 return failure();
367 interfaceVars.push_back(
368 SymbolRefAttr::get(rewriter.getContext(), var.getSymName()));
369 }
370
371 for (unsigned index = 0; index < numResults; ++index) {
372 auto abiInfo = graphOp.getResultAttrOfType<spirv::InterfaceVarABIAttr>(
373 index, attrName);
374 if (!abiInfo)
375 return failure();
376 spirv::GlobalVariableOp var = createGlobalVarForGraphEntryPoint(
377 rewriter, graphOp, index, false, abiInfo);
378 if (!var)
379 return failure();
380 interfaceVars.push_back(
381 SymbolRefAttr::get(rewriter.getContext(), var.getSymName()));
382 }
383
384 // Update graph signature.
385 rewriter.modifyOpInPlace(graphOp, [&] {
386 for (unsigned index = 0; index < numInputs; ++index) {
387 graphOp.removeArgAttr(index, attrName);
388 }
389 for (unsigned index = 0; index < numResults; ++index) {
390 graphOp.removeResultAttr(index, rewriter.getStringAttr(attrName));
391 }
392 });
393
394 spirv::GraphEntryPointARMOp::create(rewriter, graphOp.getLoc(), graphOp,
395 interfaceVars);
396 return success();
397}
398
399void LowerABIAttributesPass::runOnOperation() {
400 // Uses the signature conversion methodology of the dialect conversion
401 // framework to implement the conversion.
402 spirv::ModuleOp module = getOperation();
403 MLIRContext *context = &getContext();
404
405 spirv::TargetEnvAttr targetEnvAttr = spirv::lookupTargetEnv(module);
406 if (!targetEnvAttr) {
407 module->emitOpError("missing SPIR-V target env attribute");
408 return signalPassFailure();
409 }
410 spirv::TargetEnv targetEnv(targetEnvAttr);
411
412 SPIRVTypeConverter typeConverter(targetEnv);
413
414 // Insert a bitcast in the case of a pointer type change.
415 typeConverter.addSourceMaterialization([](OpBuilder &builder,
416 spirv::PointerType type,
417 ValueRange inputs, Location loc) {
418 if (inputs.size() != 1 || !isa<spirv::PointerType>(inputs[0].getType()))
419 return Value();
420 return spirv::BitcastOp::create(builder, loc, type, inputs[0]).getResult();
421 });
422
423 RewritePatternSet patterns(context);
424 patterns.add<ProcessInterfaceVarABI, ProcessGraphInterfaceVarABI>(
425 typeConverter, context);
426
427 ConversionTarget target(*context);
428 // "Legal" function ops should have no interface variable ABI attributes.
429 target.addDynamicallyLegalOp<spirv::FuncOp>([&](spirv::FuncOp op) {
430 StringRef attrName = spirv::getInterfaceVarABIAttrName();
431 for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i)
432 if (op.getArgAttr(i, attrName))
433 return false;
434 return true;
435 });
436 target.addDynamicallyLegalOp<spirv::GraphARMOp>([&](spirv::GraphARMOp op) {
437 StringRef attrName = spirv::getInterfaceVarABIAttrName();
438 for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i)
439 if (op.getArgAttr(i, attrName))
440 return false;
441 for (unsigned i = 0, e = op.getNumResults(); i < e; ++i)
442 if (op.getResultAttr(i, attrName))
443 return false;
444 return true;
445 });
446
447 // All other SPIR-V ops are legal.
448 target.markUnknownOpDynamicallyLegal([](Operation *op) {
449 return op->getDialect()->getNamespace() ==
450 spirv::SPIRVDialect::getDialectNamespace();
451 });
452 if (failed(applyPartialConversion(module, target, std::move(patterns))))
453 return signalPassFailure();
454
455 // Walks over all the FuncOps in spirv::ModuleOp to lower the entry point
456 // attributes.
457 OpBuilder builder(context);
458 SmallVector<spirv::FuncOp, 1> entryPointFns;
459 auto entryPointAttrName = spirv::getEntryPointABIAttrName();
460 module.walk([&](spirv::FuncOp funcOp) {
461 if (funcOp->getAttrOfType<spirv::EntryPointABIAttr>(entryPointAttrName)) {
462 entryPointFns.push_back(funcOp);
463 }
464 });
465 for (auto fn : entryPointFns) {
466 if (failed(lowerEntryPointABIAttr(fn, builder))) {
467 return signalPassFailure();
468 }
469 }
470}
return success()
b getContext())
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static spirv::GlobalVariableOp createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp, unsigned argIndex, spirv::InterfaceVarABIAttr abiInfo)
Creates a global variable for an argument based on the ABI info.
static spirv::GlobalVariableOp createGlobalVarForGraphEntryPoint(OpBuilder &builder, spirv::GraphARMOp graphOp, unsigned index, bool isArg, spirv::InterfaceVarABIAttr abiInfo)
Creates a global variable for an argument or result based on the ABI info.
static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp, OpBuilder &builder)
Lowers the entry point attribute.
static LogicalResult getInterfaceVariables(mlir::FunctionOpInterface funcOp, SmallVectorImpl< Attribute > &interfaceVars)
Gets the global variables that need to be specified as interface variable with an spirv....
StringRef getNamespace() const
Definition Dialect.h:54
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition Builders.h:436
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition Operation.h:220
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
static spirv::StructType decorateType(spirv::StructType structType)
Returns a new StructType with layout decoration.
An attribute that specifies the information regarding the interface variable: descriptor set,...
uint32_t getBinding()
Returns binding.
uint32_t getDescriptorSet()
Returns descriptor set.
std::optional< StorageClass > getStorageClass()
Returns spirv::StorageClass.
static PointerType get(Type pointeeType, StorageClass storageClass)
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Construct a literal StructType with at least one member.
An attribute that specifies the target version, allowed extensions and capabilities,...
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabiliti...
Version getVersion() const
bool allows(Capability) const
Returns true if the given capability is allowed.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
StringRef getInterfaceVarABIAttrName()
Returns the attribute name for specifying argument ABI information.
TargetEnvAttr lookupTargetEnv(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op.
FailureOr< ExecutionModel > getExecutionModel(TargetEnvAttr targetAttr)
Returns execution model selected based on target environment.
StringRef getEntryPointABIAttrName()
Returns the attribute name for specifying entry point information.
Include the generated interface declarations.
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:131
const FrozenRewritePatternSet & patterns
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr