20 #include "llvm/ADT/DepthFirstIterator.h"
21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/Support/Debug.h"
24 #define DEBUG_TYPE "spirv-serialization"
46 bool skipHeader =
false,
BlockRange skipBlocks = {}) {
47 llvm::df_iterator_default_set<Block *, 4> doneBlocks;
48 doneBlocks.insert(skipBlocks.begin(), skipBlocks.end());
50 for (
Block *block : llvm::depth_first_ext(headerBlock, doneBlocks)) {
51 if (skipHeader && block == headerBlock)
53 if (
failed(blockHandler(block)))
61 LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) {
63 prepareConstant(op.
getLoc(), op.getType(), op.getValue())) {
70 LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) {
71 if (
auto resultID = prepareConstantScalar(op.
getLoc(), op.getDefaultValue(),
74 if (
auto specID = op->
getAttrOfType<IntegerAttr>(
"spec_id")) {
75 auto val =
static_cast<uint32_t
>(specID.getInt());
76 if (
failed(emitDecoration(resultID, spirv::Decoration::SpecId, {val})))
80 specConstIDMap[op.getSymName()] = resultID;
81 return processName(resultID, op.getSymName());
87 Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) {
89 if (
failed(processType(op.
getLoc(), op.getType(), typeID))) {
93 auto resultID = getNextID();
96 operands.push_back(typeID);
97 operands.push_back(resultID);
99 auto constituents = op.getConstituents();
101 for (
auto index : llvm::seq<uint32_t>(0, constituents.size())) {
102 auto constituent = dyn_cast<FlatSymbolRefAttr>(constituents[index]);
104 auto constituentName = constituent.getValue();
105 auto constituentID = getSpecConstID(constituentName);
107 if (!constituentID) {
108 return op.
emitError(
"unknown result <id> for specialization constant ")
112 operands.push_back(constituentID);
116 spirv::Opcode::OpSpecConstantComposite, operands);
117 specConstIDMap[op.getSymName()] = resultID;
119 return processName(resultID, op.getSymName());
123 Serializer::processSpecConstantOperationOp(spirv::SpecConstantOperationOp op) {
125 if (
failed(processType(op.
getLoc(), op.getType(), typeID))) {
129 auto resultID = getNextID();
132 operands.push_back(typeID);
133 operands.push_back(resultID);
138 std::string enclosedOpName;
139 llvm::raw_string_ostream rss(enclosedOpName);
141 auto enclosedOpcode = spirv::symbolizeOpcode(rss.str());
143 if (!enclosedOpcode) {
144 op.
emitError(
"Couldn't find op code for op ")
149 operands.push_back(
static_cast<uint32_t
>(*enclosedOpcode));
153 uint32_t
id = getValueID(operand);
154 assert(
id &&
"use before def!");
155 operands.push_back(
id);
165 LogicalResult Serializer::processUndefOp(spirv::UndefOp op) {
166 auto undefType = op.getType();
167 auto &
id = undefValIDMap[undefType];
171 if (
failed(processType(op.
getLoc(), undefType, typeID)))
180 LogicalResult Serializer::processFuncParameter(spirv::FuncOp op) {
182 uint32_t argTypeID = 0;
183 if (
failed(processType(op.
getLoc(), arg.getType(), argTypeID))) {
186 auto argValueID = getNextID();
189 auto funcOp = cast<FunctionOpInterface>(*op);
190 for (
auto argAttr : funcOp.getArgAttrs(idx)) {
191 if (argAttr.getName() != DecorationAttr::name)
194 if (
auto decAttr = dyn_cast<DecorationAttr>(argAttr.getValue())) {
195 if (
failed(processDecorationAttr(op->
getLoc(), argValueID,
196 decAttr.getValue(), decAttr)))
201 valueIDMap[arg] = argValueID;
203 {argTypeID, argValueID});
209 LLVM_DEBUG(llvm::dbgs() <<
"-- start function '" << op.
getName() <<
"' --\n");
210 assert(functionHeader.empty() && functionBody.empty());
212 uint32_t fnTypeID = 0;
214 if (
failed(processType(op.
getLoc(), op.getFunctionType(), fnTypeID)))
219 uint32_t resTypeID = 0;
220 auto resultTypes = op.getFunctionType().
getResults();
221 if (resultTypes.size() > 1) {
222 return op.
emitError(
"cannot serialize function with multiple return types");
225 (resultTypes.empty() ? getVoidType() : resultTypes[0]),
229 operands.push_back(resTypeID);
230 auto funcID = getOrCreateFunctionID(op.
getName());
231 operands.push_back(funcID);
232 operands.push_back(
static_cast<uint32_t
>(op.getFunctionControl()));
233 operands.push_back(fnTypeID);
242 auto linkageAttr = op.getLinkageAttributes();
243 auto hasImportLinkage =
244 linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
245 spirv::LinkageType::Import);
246 if (op.isExternal() && !hasImportLinkage) {
248 "'spirv.module' cannot contain external functions "
249 "without 'Import' linkage_attributes (LinkageAttributes)");
251 if (op.isExternal() && hasImportLinkage) {
261 if (
failed(processFuncParameter(op)))
268 if (
failed(processFuncParameter(op)))
280 {getOrCreateBlockID(&op.front())});
281 if (
failed(processBlock(&op.front(),
true)))
284 &op.front(), [&](
Block *block) { return processBlock(block); },
291 for (
const auto &deferredValue : deferredPhiValues) {
292 Value value = deferredValue.first;
293 uint32_t
id = getValueID(value);
294 LLVM_DEBUG(llvm::dbgs() <<
"[phi] fix reference of value " << value
295 <<
" to id = " <<
id <<
'\n');
296 assert(
id &&
"OpPhi references undefined value!");
297 for (
size_t offset : deferredValue.second)
298 functionBody[offset] = id;
300 deferredPhiValues.clear();
302 LLVM_DEBUG(llvm::dbgs() <<
"-- completed function '" << op.
getName()
310 auto isValidDecoration = mlir::spirv::symbolizeEnum<spirv::Decoration>(
311 llvm::convertToCamelFromSnakeCase(attr.getName().strref(),
313 if (isValidDecoration != std::nullopt) {
314 if (
failed(processDecoration(op.
getLoc(), funcID, attr))) {
322 functions.append(functionHeader.begin(), functionHeader.end());
323 functions.append(functionBody.begin(), functionBody.end());
324 functionHeader.clear();
325 functionBody.clear();
330 LogicalResult Serializer::processVariableOp(spirv::VariableOp op) {
333 uint32_t resultID = 0;
334 uint32_t resultTypeID = 0;
335 if (
failed(processType(op.
getLoc(), op.getType(), resultTypeID))) {
338 operands.push_back(resultTypeID);
339 resultID = getNextID();
341 operands.push_back(resultID);
342 auto attr = op->
getAttr(spirv::attributeName<spirv::StorageClass>());
345 static_cast<uint32_t
>(cast<spirv::StorageClassAttr>(attr).getValue()));
347 elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
348 for (
auto arg : op.getODSOperands(0)) {
349 auto argID = getValueID(arg);
353 operands.push_back(argID);
359 if (llvm::any_of(elidedAttrs, [&](StringRef elided) {
360 return attr.getName() == elided;
364 if (
failed(processDecoration(op.
getLoc(), resultID, attr))) {
372 Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
374 uint32_t resultTypeID = 0;
376 if (
failed(processType(varOp.getLoc(), varOp.getType(), resultTypeID))) {
380 elidedAttrs.push_back(
"type");
382 operands.push_back(resultTypeID);
383 auto resultID = getNextID();
386 auto varName = varOp.getSymName();
388 if (
failed(processName(resultID, varName))) {
391 globalVarIDMap[varName] = resultID;
392 operands.push_back(resultID);
395 operands.push_back(
static_cast<uint32_t
>(varOp.storageClass()));
398 StringRef initAttrName = varOp.getInitializerAttrName().getValue();
399 if (std::optional<StringRef> initSymbolName = varOp.getInitializer()) {
400 uint32_t initializerID = 0;
403 varOp->getParentOp(), initRef.getAttr());
406 if (isa<spirv::GlobalVariableOp>(initOp))
407 initializerID = getVariableID(*initSymbolName);
409 initializerID = getSpecConstID(*initSymbolName);
413 "invalid usage of undefined variable as initializer");
415 operands.push_back(initializerID);
416 elidedAttrs.push_back(initAttrName);
419 if (
failed(emitDebugLine(typesGlobalValues, varOp.getLoc())))
422 elidedAttrs.push_back(initAttrName);
425 for (
auto attr : varOp->getAttrs()) {
426 if (llvm::any_of(elidedAttrs, [&](StringRef elided) {
427 return attr.getName() == elided;
431 if (
failed(processDecoration(varOp.getLoc(), resultID, attr))) {
438 LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
441 auto &body = selectionOp.getBody();
442 for (
Block &block : body)
443 getOrCreateBlockID(&block);
445 auto *headerBlock = selectionOp.getHeaderBlock();
446 auto *mergeBlock = selectionOp.getMergeBlock();
447 auto headerID = getBlockID(headerBlock);
448 auto mergeID = getBlockID(mergeBlock);
449 auto loc = selectionOp.getLoc();
461 auto emitSelectionMerge = [&]() {
462 if (
failed(emitDebugLine(functionBody, loc)))
464 lastProcessedWasMergeInst =
true;
466 functionBody, spirv::Opcode::OpSelectionMerge,
467 {mergeID,
static_cast<uint32_t
>(selectionOp.getSelectionControl())});
471 processBlock(headerBlock,
false, emitSelectionMerge)))
478 headerBlock, [&](
Block *block) {
return processBlock(block); },
479 true, {mergeBlock})))
487 LLVM_DEBUG(llvm::dbgs() <<
"done merge ");
488 LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs()));
489 LLVM_DEBUG(llvm::dbgs() <<
"\n");
493 LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
497 auto &body = loopOp.getBody();
498 for (
Block &block : llvm::drop_begin(body))
499 getOrCreateBlockID(&block);
501 auto *headerBlock = loopOp.getHeaderBlock();
502 auto *continueBlock = loopOp.getContinueBlock();
503 auto *mergeBlock = loopOp.getMergeBlock();
504 auto headerID = getBlockID(headerBlock);
505 auto continueID = getBlockID(continueBlock);
506 auto mergeID = getBlockID(mergeBlock);
507 auto loc = loopOp.getLoc();
523 auto emitLoopMerge = [&]() {
524 if (
failed(emitDebugLine(functionBody, loc)))
526 lastProcessedWasMergeInst =
true;
528 functionBody, spirv::Opcode::OpLoopMerge,
529 {mergeID, continueID,
static_cast<uint32_t
>(loopOp.getLoopControl())});
532 if (
failed(processBlock(headerBlock,
false, emitLoopMerge)))
539 headerBlock, [&](
Block *block) {
return processBlock(block); },
540 true, {continueBlock, mergeBlock})))
544 if (
failed(processBlock(continueBlock)))
552 LLVM_DEBUG(llvm::dbgs() <<
"done merge ");
553 LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs()));
554 LLVM_DEBUG(llvm::dbgs() <<
"\n");
559 spirv::BranchConditionalOp condBranchOp) {
560 auto conditionID = getValueID(condBranchOp.getCondition());
561 auto trueLabelID = getOrCreateBlockID(condBranchOp.getTrueBlock());
562 auto falseLabelID = getOrCreateBlockID(condBranchOp.getFalseBlock());
565 if (
auto weights = condBranchOp.getBranchWeights()) {
566 for (
auto val : weights->getValue())
567 arguments.push_back(cast<IntegerAttr>(val).getInt());
570 if (
failed(emitDebugLine(functionBody, condBranchOp.getLoc())))
577 LogicalResult Serializer::processBranchOp(spirv::BranchOp branchOp) {
578 if (
failed(emitDebugLine(functionBody, branchOp.getLoc())))
581 {getOrCreateBlockID(branchOp.getTarget())});
585 LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) {
586 auto varName = addressOfOp.getVariable();
587 auto variableID = getVariableID(varName);
589 return addressOfOp.emitError(
"unknown result <id> for variable ")
592 valueIDMap[addressOfOp.getPointer()] = variableID;
597 Serializer::processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp) {
598 auto constName = referenceOfOp.getSpecConst();
599 auto constID = getSpecConstID(constName);
601 return referenceOfOp.emitError(
602 "unknown result <id> for specialization constant ")
605 valueIDMap[referenceOfOp.getReference()] = constID;
611 Serializer::processOp<spirv::EntryPointOp>(spirv::EntryPointOp op) {
614 operands.push_back(
static_cast<uint32_t
>(op.getExecutionModel()));
616 auto funcID = getFunctionID(op.getFn());
618 return op.
emitError(
"missing <id> for function ")
620 <<
"; function needs to be defined before spirv.EntryPoint is "
623 operands.push_back(funcID);
628 if (
auto interface = op.getInterface()) {
629 for (
auto var : interface.getValue()) {
630 auto id = getVariableID(cast<FlatSymbolRefAttr>(var).getValue());
633 "referencing undefined global variable."
634 "spirv.EntryPoint is at the end of spirv.module. All "
635 "referenced variables should already be defined");
637 operands.push_back(
id);
646 Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) {
649 auto funcID = getFunctionID(op.getFn());
651 return op.
emitError(
"missing <id> for function ")
653 <<
"; function needs to be serialized before ExecutionModeOp is "
656 operands.push_back(funcID);
658 operands.push_back(
static_cast<uint32_t
>(op.getExecutionMode()));
661 auto values = op.getValues();
663 for (
auto &intVal : values.getValue()) {
664 operands.push_back(
static_cast<uint32_t
>(
665 llvm::cast<IntegerAttr>(intVal).getValue().getZExtValue()));
675 Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) {
676 auto funcName = op.getCallee();
677 uint32_t resTypeID = 0;
680 if (
failed(processType(op.
getLoc(), resultTy, resTypeID)))
683 auto funcID = getOrCreateFunctionID(funcName);
684 auto funcCallID = getNextID();
687 for (
auto value : op.getArguments()) {
688 auto valueID = getValueID(value);
689 assert(valueID &&
"cannot find a value for spirv.FunctionCall");
690 operands.push_back(valueID);
693 if (!isa<NoneType>(resultTy))
694 valueIDMap[op.
getResult(0)] = funcCallID;
702 Serializer::processOp<spirv::CopyMemoryOp>(spirv::CopyMemoryOp op) {
707 auto id = getValueID(operand);
708 assert(
id &&
"use before def!");
709 operands.push_back(
id);
712 StringAttr memoryAccess = op.getMemoryAccessAttrName();
713 if (
auto attr = op->
getAttr(memoryAccess)) {
715 static_cast<uint32_t
>(cast<spirv::MemoryAccessAttr>(attr).getValue()));
718 elidedAttrs.push_back(memoryAccess.strref());
720 StringAttr alignment = op.getAlignmentAttrName();
721 if (
auto attr = op->
getAttr(alignment)) {
722 operands.push_back(
static_cast<uint32_t
>(
723 cast<IntegerAttr>(attr).getValue().getZExtValue()));
726 elidedAttrs.push_back(alignment.strref());
728 StringAttr sourceMemoryAccess = op.getSourceMemoryAccessAttrName();
729 if (
auto attr = op->
getAttr(sourceMemoryAccess)) {
731 static_cast<uint32_t
>(cast<spirv::MemoryAccessAttr>(attr).getValue()));
734 elidedAttrs.push_back(sourceMemoryAccess.strref());
736 StringAttr sourceAlignment = op.getSourceAlignmentAttrName();
737 if (
auto attr = op->
getAttr(sourceAlignment)) {
738 operands.push_back(
static_cast<uint32_t
>(
739 cast<IntegerAttr>(attr).getValue().getZExtValue()));
742 elidedAttrs.push_back(sourceAlignment.strref());
751 spirv::GenericCastToPtrExplicitOp op) {
755 uint32_t resultTypeID = 0;
756 uint32_t resultID = 0;
758 if (
failed(processType(loc, resultTy, resultTypeID)))
760 operands.push_back(resultTypeID);
762 resultID = getNextID();
763 operands.push_back(resultID);
767 operands.push_back(getValueID(operand));
768 spirv::StorageClass resultStorage =
769 cast<spirv::PointerType>(resultTy).getStorageClass();
770 operands.push_back(
static_cast<uint32_t
>(resultStorage));
778 #define GET_SERIALIZATION_FNS
779 #include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc"
static LogicalResult visitInPrettyBlockOrder(Block *headerBlock, function_ref< LogicalResult(Block *)> blockHandler, bool skipHeader=false, BlockRange skipBlocks={})
A pre-order depth-first visitor function for processing basic blocks.
This class provides an abstraction over the different types of ranges over Blocks.
Block represents an ordered list of Operations.
OpListType & getOperations()
A symbol reference with a reference path containing a single element.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
StringRef stripDialect() const
Return the operation name with dialect name stripped, if it has one.
Operation is the basic unit of execution within MLIR.
AttrClass getAttrOfType(StringAttr name)
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
result_type_iterator result_type_begin()
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
BlockListType & getBlocks()
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void encodeStringLiteralInto(SmallVectorImpl< uint32_t > &binary, StringRef literal)
Encodes an SPIR-V literal string into the given binary vector.
void encodeInstructionInto(SmallVectorImpl< uint32_t > &binary, spirv::Opcode op, ArrayRef< uint32_t > operands)
Encodes an SPIR-V instruction with the given opcode and operands into the given binary vector.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.