19 #include "llvm/ADT/DepthFirstIterator.h"
20 #include "llvm/ADT/StringExtras.h"
21 #include "llvm/Support/Debug.h"
23 #define DEBUG_TYPE "spirv-serialization"
45 bool skipHeader =
false,
BlockRange skipBlocks = {}) {
// Depth-first, pre-order walk of the blocks reachable from headerBlock,
// invoking blockHandler on each visited block. When skipHeader is set the
// handler is not called for headerBlock itself.
46 llvm::df_iterator_default_set<Block *, 4> doneBlocks;
// Pre-seeding the external visited set makes depth_first_ext treat every
// block in skipBlocks as already visited, so the walk never enters them.
47 doneBlocks.insert(skipBlocks.begin(), skipBlocks.end());
49 for (
Block *block : llvm::depth_first_ext(headerBlock, doneBlocks)) {
50 if (skipHeader && block == headerBlock)
// A handler failure is checked here; presumably the walk stops on the first
// failure (the handling line is not visible in this view).
52 if (failed(blockHandler(block)))
// Serializes a spirv::ConstantOp. The constant is materialized through
// prepareConstant (op location, type and value attribute), and the returned
// id is recorded in valueIDMap for the op's result.
60 LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) {
62 prepareConstant(op.getLoc(), op.getType(), op.getValue())) {
63 valueIDMap[op.getResult()] = resultID;
// Serializes a spirv::SpecConstantOp:
//  - emits its default value through prepareConstantScalar;
//  - adds a SpecId decoration when the op carries a "spec_id" attribute;
//  - records the result id in specConstIDMap under the op's symbol name;
//  - finally emits the debug name via processName.
69 LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) {
70 if (
auto resultID = prepareConstantScalar(op.getLoc(), op.getDefaultValue(),
73 if (
auto specID = op->getAttrOfType<IntegerAttr>(
"spec_id")) {
// The attribute's integer value is narrowed to uint32_t for the SpecId
// decoration operand.
74 auto val =
static_cast<uint32_t
>(specID.getInt());
75 if (failed(emitDecoration(resultID, spirv::Decoration::SpecId, {val})))
79 specConstIDMap[op.getSymName()] = resultID;
80 return processName(resultID, op.getSymName());
86 Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) {
// Serializes a spirv::SpecConstantCompositeOp as OpSpecConstantComposite.
// Operands are, in order:
//  - the composite type id;
//  - a fresh result id;
//  - one id per constituent, looked up by symbol name via getSpecConstID.
// The result id is registered in specConstIDMap under the op's symbol name.
88 if (failed(processType(op.getLoc(), op.getType(), typeID))) {
92 auto resultID = getNextID();
95 operands.push_back(typeID);
96 operands.push_back(resultID);
98 auto constituents = op.getConstituents();
100 for (
auto index : llvm::seq<uint32_t>(0, constituents.size())) {
101 auto constituent = dyn_cast<FlatSymbolRefAttr>(constituents[index]);
103 auto constituentName = constituent.getValue();
104 auto constituentID = getSpecConstID(constituentName);
// A zero id from getSpecConstID is reported as an unknown specialization
// constant, i.e. constituents must already have been assigned ids.
106 if (!constituentID) {
107 return op.emitError(
"unknown result <id> for specialization constant ")
111 operands.push_back(constituentID);
115 spirv::Opcode::OpSpecConstantComposite, operands);
116 specConstIDMap[op.getSymName()] = resultID;
118 return processName(resultID, op.getSymName());
122 Serializer::processSpecConstantOperationOp(spirv::SpecConstantOperationOp op) {
// Serializes a spirv::SpecConstantOperationOp. Operands are, in order:
//  - the result type id;
//  - a fresh result id;
//  - the opcode of the enclosed op, resolved from the name written into
//    enclosedOpName via spirv::symbolizeOpcode;
//  - operand ids.
// The enclosed op presumably comes from the first block of the op's region
// (the lines selecting it are not visible). The op's result is mapped to the
// new id in valueIDMap.
124 if (failed(processType(op.getLoc(), op.getType(), typeID))) {
128 auto resultID = getNextID();
131 operands.push_back(typeID);
132 operands.push_back(resultID);
134 Block &block = op.getRegion().getBlocks().
front();
137 std::string enclosedOpName;
138 llvm::raw_string_ostream rss(enclosedOpName);
140 auto enclosedOpcode = spirv::symbolizeOpcode(enclosedOpName);
142 if (!enclosedOpcode) {
143 op.emitError(
"Couldn't find op code for op ")
148 operands.push_back(
static_cast<uint32_t
>(*enclosedOpcode));
// Every operand must already have an id; a zero id means use before def.
152 uint32_t
id = getValueID(operand);
153 assert(
id &&
"use before def!");
154 operands.push_back(
id);
159 valueIDMap[op.getResult()] = resultID;
// Serializes a spirv::UndefOp. Ids are cached in undefValIDMap keyed by the
// undef type, so undef values of the same type presumably share one id (the
// lines that allocate and emit a new id are not visible). The op's result is
// mapped to the cached id.
164 LogicalResult Serializer::processUndefOp(spirv::UndefOp op) {
165 auto undefType = op.getType();
// Bound by reference so that an id assigned to `id` is stored back into the
// undefValIDMap cache.
166 auto &
id = undefValIDMap[undefType];
170 if (failed(processType(op.getLoc(), undefType, typeID)))
175 valueIDMap[op.getResult()] = id;
// Serializes the parameters of a spirv::FuncOp. For each argument (the loop
// header is not visible here):
//  - its type is serialized;
//  - a fresh value id is allocated;
//  - decoration attributes attached to the argument are emitted;
//  - the argument is mapped to its id in valueIDMap.
// The trailing {argTypeID, argValueID} operand list is presumably for the
// per-parameter instruction; confirm against the missing lines.
179 LogicalResult Serializer::processFuncParameter(spirv::FuncOp op) {
181 uint32_t argTypeID = 0;
182 if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) {
185 auto argValueID = getNextID();
188 auto funcOp = cast<FunctionOpInterface>(*op);
// Only argument attributes named DecorationAttr::name whose value is a
// DecorationAttr are serialized, as decorations on the argument's value id.
189 for (
auto argAttr : funcOp.getArgAttrs(idx)) {
190 if (argAttr.getName() != DecorationAttr::name)
193 if (
auto decAttr = dyn_cast<DecorationAttr>(argAttr.getValue())) {
194 if (failed(processDecorationAttr(op->getLoc(), argValueID,
195 decAttr.getValue(), decAttr)))
200 valueIDMap[arg] = argValueID;
202 {argTypeID, argValueID});
// Serializes a spirv::FuncOp. The function-level operands are, in order:
//  - the result type id;
//  - the function id;
//  - the function control;
//  - the function type id.
// Parameters and blocks are written into the per-function buffers
// (functionHeader / functionBody), which are then appended to `functions`.
// External functions are accepted only with Import linkage.
207 LogicalResult Serializer::processFuncOp(spirv::FuncOp op) {
208 LLVM_DEBUG(llvm::dbgs() <<
"-- start function '" << op.getName() <<
"' --\n");
// The per-function buffers must have been flushed by the previous function.
209 assert(functionHeader.empty() && functionBody.empty());
211 uint32_t fnTypeID = 0;
213 if (failed(processType(op.getLoc(), op.getFunctionType(), fnTypeID)))
218 uint32_t resTypeID = 0;
219 auto resultTypes = op.getFunctionType().getResults();
// At most one result can be serialized; an empty result list is emitted with
// the void type.
220 if (resultTypes.size() > 1) {
221 return op.emitError(
"cannot serialize function with multiple return types");
223 if (failed(processType(op.getLoc(),
224 (resultTypes.empty() ? getVoidType() : resultTypes[0]),
228 operands.push_back(resTypeID);
229 auto funcID = getOrCreateFunctionID(op.getName());
230 operands.push_back(funcID);
231 operands.push_back(
static_cast<uint32_t
>(op.getFunctionControl()));
232 operands.push_back(fnTypeID);
236 if (failed(processName(funcID, op.getName()))) {
// A body-less (external) function is rejected unless it carries Import
// linkage attributes.
241 auto linkageAttr = op.getLinkageAttributes();
242 auto hasImportLinkage =
243 linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
244 spirv::LinkageType::Import);
245 if (op.isExternal() && !hasImportLinkage) {
247 "'spirv.module' cannot contain external functions "
248 "without 'Import' linkage_attributes (LinkageAttributes)");
250 if (op.isExternal() && hasImportLinkage) {
260 if (failed(processFuncParameter(op)))
267 if (failed(processFuncParameter(op)))
279 {getOrCreateBlockID(&op.front())});
// The entry block (op.front()) is processed first. The remaining blocks are
// then handed to processBlock by a block-visiting helper rooted at the entry
// block; the call head is not visible, but it is presumably
// visitInPrettyBlockOrder.
280 if (failed(processBlock(&op.front(),
true)))
283 &op.front(), [&](
Block *block) { return processBlock(block); },
// Patch deferred OpPhi references: every functionBody offset recorded for a
// value is overwritten with that value's now-known id.
290 for (
const auto &deferredValue : deferredPhiValues) {
291 Value value = deferredValue.first;
292 uint32_t
id = getValueID(value);
293 LLVM_DEBUG(llvm::dbgs() <<
"[phi] fix reference of value " << value
294 <<
" to id = " <<
id <<
'\n');
295 assert(
id &&
"OpPhi references undefined value!");
296 for (
size_t offset : deferredValue.second)
297 functionBody[offset] = id;
299 deferredPhiValues.clear();
301 LLVM_DEBUG(llvm::dbgs() <<
"-- completed function '" << op.getName()
// Function attributes whose name, converted from snake_case to camel case,
// symbolizes to a spirv::Decoration are emitted as decorations on funcID.
307 for (
auto attr : op->getAttrs()) {
309 auto isValidDecoration = mlir::spirv::symbolizeEnum<spirv::Decoration>(
310 llvm::convertToCamelFromSnakeCase(attr.getName().strref(),
312 if (isValidDecoration != std::nullopt) {
313 if (failed(processDecoration(op.getLoc(), funcID, attr))) {
// Flush this function into the module-level `functions` buffer, then reset
// the per-function buffers for the next function.
321 functions.append(functionHeader.begin(), functionHeader.end());
322 functions.append(functionBody.begin(), functionBody.end());
323 functionHeader.clear();
324 functionBody.clear();
// Serializes a function-local spirv::VariableOp. Operands are, in order:
//  - the result type id;
//  - a fresh result id (also recorded in valueIDMap);
//  - the storage class;
//  - the ids of ODS operand group 0 (presumably the optional initializer).
// The debug line goes into functionHeader. Attributes not listed in
// elidedAttrs are emitted as decorations on the result id.
329 LogicalResult Serializer::processVariableOp(spirv::VariableOp op) {
332 uint32_t resultID = 0;
333 uint32_t resultTypeID = 0;
334 if (failed(processType(op.getLoc(), op.getType(), resultTypeID))) {
337 operands.push_back(resultTypeID);
338 resultID = getNextID();
339 valueIDMap[op.getResult()] = resultID;
340 operands.push_back(resultID);
341 auto attr = op->getAttr(spirv::attributeName<spirv::StorageClass>());
344 static_cast<uint32_t
>(cast<spirv::StorageClassAttr>(attr).getValue()));
// The storage class is already encoded as an operand, so it is added to
// elidedAttrs and the attribute loop below can exclude it.
346 elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
347 for (
auto arg : op.getODSOperands(0)) {
348 auto argID = getValueID(arg);
350 return emitError(op.getLoc(),
"operand 0 has a use before def");
352 operands.push_back(argID);
354 if (failed(emitDebugLine(functionHeader, op.getLoc())))
357 for (
auto attr : op->getAttrs()) {
358 if (llvm::any_of(elidedAttrs, [&](StringRef elided) {
359 return attr.getName() == elided;
363 if (failed(processDecoration(op.getLoc(), resultID, attr))) {
371 Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
// Serializes a module-level spirv::GlobalVariableOp. Operands are, in order:
//  - the result type id;
//  - a fresh result id, which is also named via processName and recorded
//    in globalVarIDMap under the symbol name;
//  - the storage class;
//  - optionally, the id of the initializer symbol.
// The debug line is emitted into typesGlobalValues. Attributes not listed in
// elidedAttrs become decorations on the result id.
373 uint32_t resultTypeID = 0;
375 if (failed(processType(varOp.getLoc(), varOp.getType(), resultTypeID))) {
379 elidedAttrs.push_back(
"type");
381 operands.push_back(resultTypeID);
382 auto resultID = getNextID();
385 auto varName = varOp.getSymName();
387 if (failed(processName(resultID, varName))) {
390 globalVarIDMap[varName] = resultID;
391 operands.push_back(resultID);
394 operands.push_back(
static_cast<uint32_t
>(varOp.storageClass()));
397 StringRef initAttrName = varOp.getInitializerAttrName().getValue();
// The initializer symbol is resolved relative to the parent op. If it names
// another GlobalVariableOp, its id comes from getVariableID; otherwise it
// comes from getSpecConstID.
398 if (std::optional<StringRef> initSymbolName = varOp.getInitializer()) {
399 uint32_t initializerID = 0;
402 varOp->getParentOp(), initRef.getAttr());
405 if (isa<spirv::GlobalVariableOp>(initOp))
406 initializerID = getVariableID(*initSymbolName);
408 initializerID = getSpecConstID(*initSymbolName);
412 "invalid usage of undefined variable as initializer");
414 operands.push_back(initializerID);
415 elidedAttrs.push_back(initAttrName);
418 if (failed(emitDebugLine(typesGlobalValues, varOp.getLoc())))
421 elidedAttrs.push_back(initAttrName);
424 for (
auto attr : varOp->getAttrs()) {
425 if (llvm::any_of(elidedAttrs, [&](StringRef elided) {
426 return attr.getName() == elided;
430 if (failed(processDecoration(varOp.getLoc(), resultID, attr))) {
// Serializes a spirv::SelectionOp region:
//  - every block in the body is given an id up front;
//  - the header block is processed with emitSelectionMerge as a callback,
//    which emits OpSelectionMerge (merge block id and selection control);
//  - the remaining blocks are visited depth-first with the header skipped
//    and the merge block excluded (skipHeader = true, skipBlocks =
//    {mergeBlock}).
437 LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
// Assign ids to all blocks first, presumably so that branches can reference
// blocks that have not been emitted yet.
440 auto &body = selectionOp.getBody();
441 for (
Block &block : body)
442 getOrCreateBlockID(&block);
444 auto *headerBlock = selectionOp.getHeaderBlock();
445 auto *mergeBlock = selectionOp.getMergeBlock();
446 auto headerID = getBlockID(headerBlock);
447 auto mergeID = getBlockID(mergeBlock);
448 auto loc = selectionOp.getLoc();
// Callback given to processBlock for the header block. It emits a debug line,
// sets lastProcessedWasMergeInst, and writes OpSelectionMerge into
// functionBody.
460 auto emitSelectionMerge = [&]() {
461 if (failed(emitDebugLine(functionBody, loc)))
463 lastProcessedWasMergeInst =
true;
465 functionBody, spirv::Opcode::OpSelectionMerge,
466 {mergeID,
static_cast<uint32_t
>(selectionOp.getSelectionControl())});
470 processBlock(headerBlock,
false, emitSelectionMerge)))
477 headerBlock, [&](
Block *block) {
return processBlock(block); },
478 true, {mergeBlock})))
// Debug-only dump of the merge block.
486 LLVM_DEBUG(llvm::dbgs() <<
"done merge ");
487 LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs()));
488 LLVM_DEBUG(llvm::dbgs() <<
"\n");
// Serializes a spirv::LoopOp region:
//  - every block except the first one in the body is given an id;
//  - the header block is processed with emitLoopMerge as a callback, which
//    emits OpLoopMerge (merge id, continue id and loop control);
//  - the loop body is visited depth-first with the header skipped and the
//    continue and merge blocks excluded;
//  - the continue block is processed last.
492 LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
// llvm::drop_begin skips the first block of the body when assigning ids.
496 auto &body = loopOp.getBody();
497 for (
Block &block : llvm::drop_begin(body))
498 getOrCreateBlockID(&block);
500 auto *headerBlock = loopOp.getHeaderBlock();
501 auto *continueBlock = loopOp.getContinueBlock();
502 auto *mergeBlock = loopOp.getMergeBlock();
503 auto headerID = getBlockID(headerBlock);
504 auto continueID = getBlockID(continueBlock);
505 auto mergeID = getBlockID(mergeBlock);
506 auto loc = loopOp.getLoc();
// Callback given to processBlock for the header block. It emits a debug line,
// sets lastProcessedWasMergeInst, and writes OpLoopMerge into functionBody.
522 auto emitLoopMerge = [&]() {
523 if (failed(emitDebugLine(functionBody, loc)))
525 lastProcessedWasMergeInst =
true;
527 functionBody, spirv::Opcode::OpLoopMerge,
528 {mergeID, continueID,
static_cast<uint32_t
>(loopOp.getLoopControl())});
531 if (failed(processBlock(headerBlock,
false, emitLoopMerge)))
538 headerBlock, [&](
Block *block) {
return processBlock(block); },
539 true, {continueBlock, mergeBlock})))
// The continue block was excluded from the traversal above, so it is
// processed explicitly after the loop body.
543 if (failed(processBlock(continueBlock)))
551 LLVM_DEBUG(llvm::dbgs() <<
"done merge ");
552 LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs()));
553 LLVM_DEBUG(llvm::dbgs() <<
"\n");
// Serializes a spirv::BranchConditionalOp. The inputs are the condition id,
// the true and false target block ids (created on demand), and any branch
// weights. A debug line is emitted into functionBody.
557 LogicalResult Serializer::processBranchConditionalOp(
558 spirv::BranchConditionalOp condBranchOp) {
559 auto conditionID = getValueID(condBranchOp.getCondition());
560 auto trueLabelID = getOrCreateBlockID(condBranchOp.getTrueBlock());
561 auto falseLabelID = getOrCreateBlockID(condBranchOp.getFalseBlock());
// Optional branch weights are appended to `arguments` as integer values.
564 if (
auto weights = condBranchOp.getBranchWeights()) {
565 for (
auto val : weights->getValue())
566 arguments.push_back(cast<IntegerAttr>(val).getInt());
569 if (failed(emitDebugLine(functionBody, condBranchOp.getLoc())))
// Serializes a spirv::BranchOp. It emits a debug line into functionBody and
// then an instruction (presumably OpBranch) whose single operand is the
// target block id, created on demand.
576 LogicalResult Serializer::processBranchOp(spirv::BranchOp branchOp) {
577 if (failed(emitDebugLine(functionBody, branchOp.getLoc())))
580 {getOrCreateBlockID(branchOp.getTarget())});
// Serializes a spirv::AddressOfOp. No instruction is emitted in the visible
// lines: the op's pointer result is aliased to the id of the referenced
// global variable, and an error is reported if that variable has no id yet.
584 LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) {
585 auto varName = addressOfOp.getVariable();
586 auto variableID = getVariableID(varName);
588 return addressOfOp.emitError(
"unknown result <id> for variable ")
591 valueIDMap[addressOfOp.getPointer()] = variableID;
596 Serializer::processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp) {
// Serializes a spirv::ReferenceOfOp by aliasing its result to the id of the
// referenced specialization constant. An error is reported if that constant
// has no id yet.
597 auto constName = referenceOfOp.getSpecConst();
598 auto constID = getSpecConstID(constName);
600 return referenceOfOp.emitError(
601 "unknown result <id> for specialization constant ")
604 valueIDMap[referenceOfOp.getReference()] = constID;
610 Serializer::processOp<spirv::EntryPointOp>(spirv::EntryPointOp op) {
// Serializes a spirv::EntryPointOp. Operands are, in order:
//  - the execution model;
//  - the id of the referenced function, which must already exist;
//  - further operands from lines not visible here (presumably the
//    entry-point name);
//  - the ids of the interface global variables, which must already be
//    defined.
613 operands.push_back(
static_cast<uint32_t
>(op.getExecutionModel()));
615 auto funcID = getFunctionID(op.getFn());
617 return op.emitError(
"missing <id> for function ")
619 <<
"; function needs to be defined before spirv.EntryPoint is "
622 operands.push_back(funcID);
627 if (
auto interface = op.getInterface()) {
628 for (
auto var : interface.getValue()) {
// NOTE(review): the adjacent string literals in the error below concatenate
// without a separating space ("...variable.spirv.EntryPoint is..."); consider
// adding one.
629 auto id = getVariableID(cast<FlatSymbolRefAttr>(var).getValue());
632 "referencing undefined global variable."
633 "spirv.EntryPoint is at the end of spirv.module. All "
634 "referenced variables should already be defined");
636 operands.push_back(
id);
645 Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) {
// Serializes a spirv::ExecutionModeOp. Operands are, in order:
//  - the id of the target function, which must already exist;
//  - the execution mode;
//  - each integer value from the op's values attribute.
648 auto funcID = getFunctionID(op.getFn());
650 return op.emitError(
"missing <id> for function ")
652 <<
"; function needs to be serialized before ExecutionModeOp is "
655 operands.push_back(funcID);
657 operands.push_back(
static_cast<uint32_t
>(op.getExecutionMode()));
660 auto values = op.getValues();
// Each value is zero-extended and then narrowed to uint32_t.
// NOTE(review): values wider than 32 bits would be silently truncated here;
// confirm they are rejected earlier.
662 for (
auto &intVal : values.getValue()) {
663 operands.push_back(
static_cast<uint32_t
>(
664 llvm::cast<IntegerAttr>(intVal).getValue().getZExtValue()));
674 Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) {
// Serializes a spirv::FunctionCallOp. The inputs are:
//  - the result type id (void when the call has no results);
//  - the callee's function id;
//  - a fresh call id;
//  - one id per argument.
// getOrCreateFunctionID is used rather than getFunctionID, so the callee
// presumably does not need to have been serialized yet.
675 auto funcName = op.getCallee();
676 uint32_t resTypeID = 0;
678 Type resultTy = op.getNumResults() ? *op.result_type_begin() : getVoidType();
679 if (failed(processType(op.getLoc(), resultTy, resTypeID)))
682 auto funcID = getOrCreateFunctionID(funcName);
683 auto funcCallID = getNextID();
686 for (
auto value : op.getArguments()) {
687 auto valueID = getValueID(value);
688 assert(valueID &&
"cannot find a value for spirv.FunctionCall");
689 operands.push_back(valueID);
// Only calls with a non-NoneType result get a valueIDMap entry. This assumes
// getVoidType() yields NoneType; confirm against its definition.
692 if (!isa<NoneType>(resultTy))
693 valueIDMap[op.getResult(0)] = funcCallID;
701 Serializer::processOp<spirv::CopyMemoryOp>(spirv::CopyMemoryOp op) {
// Serializes a spirv::CopyMemoryOp. Operands are, in order:
//  - the ids of all operands, each of which must already be defined;
//  - the optional target memory-access mask and alignment;
//  - the optional source memory-access mask and alignment.
// Each of these attribute names is added to elidedAttrs.
705 for (
Value operand : op->getOperands()) {
706 auto id = getValueID(operand);
707 assert(
id &&
"use before def!");
708 operands.push_back(
id);
// Target-side memory operands.
711 StringAttr memoryAccess = op.getMemoryAccessAttrName();
712 if (
auto attr = op->getAttr(memoryAccess)) {
714 static_cast<uint32_t
>(cast<spirv::MemoryAccessAttr>(attr).getValue()));
717 elidedAttrs.push_back(memoryAccess.strref());
719 StringAttr alignment = op.getAlignmentAttrName();
720 if (
auto attr = op->getAttr(alignment)) {
721 operands.push_back(
static_cast<uint32_t
>(
722 cast<IntegerAttr>(attr).getValue().getZExtValue()));
725 elidedAttrs.push_back(alignment.strref());
// Source-side memory operands.
727 StringAttr sourceMemoryAccess = op.getSourceMemoryAccessAttrName();
728 if (
auto attr = op->getAttr(sourceMemoryAccess)) {
730 static_cast<uint32_t
>(cast<spirv::MemoryAccessAttr>(attr).getValue()));
733 elidedAttrs.push_back(sourceMemoryAccess.strref());
735 StringAttr sourceAlignment = op.getSourceAlignmentAttrName();
736 if (
auto attr = op->getAttr(sourceAlignment)) {
737 operands.push_back(
static_cast<uint32_t
>(
738 cast<IntegerAttr>(attr).getValue().getZExtValue()));
741 elidedAttrs.push_back(sourceAlignment.strref());
742 if (failed(emitDebugLine(functionBody, op.getLoc())))
749 LogicalResult Serializer::processOp<spirv::GenericCastToPtrExplicitOp>(
750 spirv::GenericCastToPtrExplicitOp op) {
// Serializes a spirv::GenericCastToPtrExplicitOp. Operands are, in order:
//  - the result type id;
//  - a fresh result id (recorded in valueIDMap);
//  - the ids of all operands;
//  - the storage class of the result pointer type, as a trailing literal.
754 uint32_t resultTypeID = 0;
755 uint32_t resultID = 0;
756 resultTy = op->getResult(0).getType();
757 if (failed(processType(loc, resultTy, resultTypeID)))
759 operands.push_back(resultTypeID);
761 resultID = getNextID();
762 operands.push_back(resultID);
763 valueIDMap[op->getResult(0)] = resultID;
// NOTE(review): unlike the other handlers, operand ids are not checked here.
// An operand without an id would be emitted as 0; consider the
// `assert(id && "use before def!")` check used elsewhere.
765 for (
Value operand : op->getOperands())
766 operands.push_back(getValueID(operand));
// cast<> requires the result type to be a spirv::PointerType.
767 spirv::StorageClass resultStorage =
768 cast<spirv::PointerType>(resultTy).getStorageClass();
769 operands.push_back(
static_cast<uint32_t
>(resultStorage));
777 #define GET_SERIALIZATION_FNS
778 #include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc"
static LogicalResult visitInPrettyBlockOrder(Block *headerBlock, function_ref< LogicalResult(Block *)> blockHandler, bool skipHeader=false, BlockRange skipBlocks={})
A pre-order depth-first visitor function for processing basic blocks.
This class provides an abstraction over the different types of ranges over Blocks.
Block represents an ordered list of Operations.
OpListType & getOperations()
A symbol reference with a reference path containing a single element.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
StringRef stripDialect() const
Return the operation name with dialect name stripped, if it has one.
Operation is the basic unit of execution within MLIR.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void encodeStringLiteralInto(SmallVectorImpl< uint32_t > &binary, StringRef literal)
Encodes an SPIR-V literal string into the given binary vector.
void encodeInstructionInto(SmallVectorImpl< uint32_t > &binary, spirv::Opcode op, ArrayRef< uint32_t > operands)
Encodes an SPIR-V instruction with the given opcode and operands into the given binary vector.
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.