19 #include "llvm/ADT/DepthFirstIterator.h"
20 #include "llvm/ADT/StringExtras.h"
21 #include "llvm/Support/Debug.h"
23 #define DEBUG_TYPE "spirv-serialization"
// Trailing optional parameters of the block visitor: `skipHeader` defaults to
// false and `skipBlocks` defaults to an empty range.
45 bool skipHeader =
false,
BlockRange skipBlocks = {}) {
// The visited set is pre-populated with `skipBlocks`, so the depth-first walk
// below treats those blocks as already done.
46 llvm::df_iterator_default_set<Block *, 4> doneBlocks;
47 doneBlocks.insert(skipBlocks.begin(), skipBlocks.end());
// Depth-first traversal rooted at `headerBlock`, using `doneBlocks` as the
// externally owned visited set.
49 for (
Block *block : llvm::depth_first_ext(headerBlock, doneBlocks)) {
// The header block gets special treatment when `skipHeader` is set.
50 if (skipHeader && block == headerBlock)
// Every other visited block is passed to `blockHandler`; its failure is
// checked here.
52 if (failed(blockHandler(block)))
// Serializes a spirv.Constant op by preparing a constant from the op location,
// type and value, then binding the resulting id to the op result.
60 LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) {
62 prepareConstant(op.getLoc(), op.getType(), op.getValue())) {
// Record the id so later uses of this SSA value can be resolved via
// valueIDMap. `resultID` is presumably the value returned by prepareConstant.
63 valueIDMap[op.getResult()] = resultID;
// Serializes a spirv.SpecConstant op: a scalar constant is prepared from the
// default value, an optional spec_id attribute is turned into a SpecId
// decoration, and the result id is registered under the op symbol name.
69 LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) {
70 if (
auto resultID = prepareConstantScalar(op.getLoc(), op.getDefaultValue(),
// Only ops that carry an integer spec_id attribute receive the decoration.
73 if (
auto specID = op->getAttrOfType<IntegerAttr>(
"spec_id")) {
// The attribute value is narrowed to a 32-bit word before being emitted.
74 auto val =
static_cast<uint32_t
>(specID.getInt());
75 if (failed(emitDecoration(resultID, spirv::Decoration::SpecId, {val})))
// Map symbol name -> result id, then hand the same name to processName.
79 specConstIDMap[op.getSymName()] = resultID;
80 return processName(resultID, op.getSymName());
// Serializes a spirv.SpecConstantComposite op. The instruction operands are
// the result type id, a freshly allocated result id, and then the id of each
// constituent specialization constant.
86 Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) {
// Resolve the result type first; failure is checked here.
88 if (failed(processType(op.getLoc(), op.getType(), typeID))) {
92 auto resultID = getNextID();
95 operands.push_back(typeID);
96 operands.push_back(resultID);
98 auto constituents = op.getConstituents();
// Walk the constituents by index and look each one up by symbol name.
100 for (
auto index : llvm::seq<uint32_t>(0, constituents.size())) {
// NOTE(review): dyn_cast may return a null attribute and no null check is
// visible before getValue() is called on it; confirm the op verifier
// guarantees every constituent is a FlatSymbolRefAttr.
101 auto constituent = dyn_cast<FlatSymbolRefAttr>(constituents[index]);
103 auto constituentName = constituent.getValue();
104 auto constituentID = getSpecConstID(constituentName);
// A zero id is reported as an unknown specialization constant.
106 if (!constituentID) {
107 return op.emitError(
"unknown result <id> for specialization constant ")
111 operands.push_back(constituentID);
// Emitted as OpSpecConstantComposite; the result id is then registered under
// the op symbol name and passed to processName.
115 spirv::Opcode::OpSpecConstantComposite, operands);
116 specConstIDMap[op.getSymName()] = resultID;
118 return processName(resultID, op.getSymName());
// Serializes a spirv.SpecConstantOperation op. Operands are the result type
// id, a new result id, the opcode of the enclosed op, and the ids of the
// operands gathered in the loop below.
122 Serializer::processSpecConstantOperationOp(spirv::SpecConstantOperationOp op) {
124 if (failed(processType(op.getLoc(), op.getType(), typeID))) {
128 auto resultID = getNextID();
131 operands.push_back(typeID);
132 operands.push_back(resultID);
// Only the first block of the op region is inspected.
134 Block &block = op.getRegion().getBlocks().
front();
// The enclosed op name is built as a string and then mapped back to a SPIR-V
// opcode via symbolizeOpcode.
137 std::string enclosedOpName;
138 llvm::raw_string_ostream rss(enclosedOpName);
140 auto enclosedOpcode = spirv::symbolizeOpcode(enclosedOpName);
// An unrecognized name is reported as an error on the op.
142 if (!enclosedOpcode) {
143 op.emitError(
"Couldn't find op code for op ")
// The opcode is encoded as a plain 32-bit operand word.
148 operands.push_back(
static_cast<uint32_t
>(*enclosedOpcode));
// Every operand must already have an id; this is only asserted, so it is not
// checked in builds without assertions.
152 uint32_t
id = getValueID(operand);
153 assert(
id &&
"use before def!");
154 operands.push_back(
id);
// Bind the op result to the new id for later lookups.
159 valueIDMap[op.getResult()] = resultID;
// Serializes a spirv.Undef op. Ids for undef values are kept in undefValIDMap,
// keyed by the undef type.
164 LogicalResult Serializer::processUndefOp(spirv::UndefOp op) {
165 auto undefType = op.getType();
// `id` is a reference into the map slot for this type, so assigning to it
// also updates the map. Presumably it is filled in only on first use; the
// guarding condition is not visible here.
166 auto &
id = undefValIDMap[undefType];
170 if (failed(processType(op.getLoc(), undefType, typeID)))
// The op result is bound to the id stored for its type.
175 valueIDMap[op.getResult()] = id;
// Serializes the parameters of a spirv.func op: each argument gets a type id
// and a new value id, its decoration attributes are processed, and the value
// id is recorded in valueIDMap.
179 LogicalResult Serializer::processFuncParameter(spirv::FuncOp op) {
181 uint32_t argTypeID = 0;
182 if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) {
185 auto argValueID = getNextID();
// Argument attributes are read through the FunctionOpInterface view of the op.
188 auto funcOp = cast<FunctionOpInterface>(*op);
189 for (
auto argAttr : funcOp.getArgAttrs(idx)) {
// Only attributes named DecorationAttr::name are of interest.
190 if (argAttr.getName() != DecorationAttr::name)
// The attribute value must itself be a DecorationAttr to be processed.
193 if (
auto decAttr = dyn_cast<DecorationAttr>(argAttr.getValue())) {
194 if (failed(processDecorationAttr(op->getLoc(), argValueID,
195 decAttr.getValue(), decAttr)))
// Bind the block argument to its id; the pair (type id, value id) forms the
// operand list of the emitted instruction.
200 valueIDMap[arg] = argValueID;
202 {argTypeID, argValueID});
// Serializes a spirv.func op. Instructions are accumulated in the
// functionHeader and functionBody buffers, which are appended to `functions`
// and cleared at the end.
207 LogicalResult Serializer::processFuncOp(spirv::FuncOp op) {
208 LLVM_DEBUG(llvm::dbgs() <<
"-- start function '" << op.getName() <<
"' --\n");
// Both staging buffers must be empty on entry.
209 assert(functionHeader.empty() && functionBody.empty());
211 uint32_t fnTypeID = 0;
// Type id of the whole function type.
213 if (failed(processType(op.getLoc(), op.getFunctionType(), fnTypeID)))
218 uint32_t resTypeID = 0;
219 auto resultTypes = op.getFunctionType().getResults();
// Functions with more than one result are rejected.
220 if (resultTypes.size() > 1) {
221 return op.emitError(
"cannot serialize function with multiple return types");
// A function with no results uses the void type as its result type.
223 if (failed(processType(op.getLoc(),
224 (resultTypes.empty() ? getVoidType() : resultTypes[0]),
// Operand order: result type id, function id, function control, function
// type id.
228 operands.push_back(resTypeID);
229 auto funcID = getOrCreateFunctionID(op.getName());
230 operands.push_back(funcID);
231 operands.push_back(
static_cast<uint32_t
>(op.getFunctionControl()));
232 operands.push_back(fnTypeID);
236 if (failed(processName(funcID, op.getName()))) {
// External functions (declarations) are only accepted when they carry Import
// linkage.
241 auto linkageAttr = op.getLinkageAttributes();
242 auto hasImportLinkage =
243 linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
244 spirv::LinkageType::Import);
245 if (op.isExternal() && !hasImportLinkage) {
247 "'spirv.module' cannot contain external functions "
248 "without 'Import' linkage_attributes (LinkageAttributes)");
// Imported declarations take a separate path from function definitions;
// parameters are processed on both paths.
250 if (op.isExternal() && hasImportLinkage) {
260 if (failed(processFuncParameter(op)))
267 if (failed(processFuncParameter(op)))
// The entry block (op.front()) is handled first and on its own; the remaining
// blocks are reached through a traversal that also starts at op.front().
279 {getOrCreateBlockID(&op.front())});
280 if (failed(processBlock(&op.front(),
true)))
283 &op.front(), [&](
Block *block) { return processBlock(block); },
// Patch deferred phi operands: each recorded offset in functionBody is
// overwritten with the now-known id of the referenced value.
290 for (
const auto &deferredValue : deferredPhiValues) {
291 Value value = deferredValue.first;
292 uint32_t
id = getValueID(value);
293 LLVM_DEBUG(llvm::dbgs() <<
"[phi] fix reference of value " << value
294 <<
" to id = " <<
id <<
'\n');
// A missing id is only caught by this assert.
295 assert(
id &&
"OpPhi references undefined value!");
296 for (
size_t offset : deferredValue.second)
297 functionBody[offset] = id;
299 deferredPhiValues.clear();
301 LLVM_DEBUG(llvm::dbgs() <<
"-- completed function '" << op.getName()
// Attributes whose camel-cased name symbolizes to a spirv::Decoration are
// emitted as decorations on the function id.
307 for (
auto attr : op->getAttrs()) {
309 auto isValidDecoration = mlir::spirv::symbolizeEnum<spirv::Decoration>(
310 llvm::convertToCamelFromSnakeCase(attr.getName().strref(),
312 if (isValidDecoration != std::nullopt) {
313 if (failed(processDecoration(op.getLoc(), funcID, attr))) {
// Flush header then body into the module-level function stream and reset the
// staging buffers.
321 functions.append(functionHeader.begin(), functionHeader.end());
322 functions.append(functionBody.begin(), functionBody.end());
323 functionHeader.clear();
324 functionBody.clear();
// Serializes a spirv.Variable op. Operands are the result type id, a new
// result id, the storage class, and the ids of the op operands in ODS group 0.
// Remaining attributes are processed as decorations.
329 LogicalResult Serializer::processVariableOp(spirv::VariableOp op) {
332 uint32_t resultID = 0;
333 uint32_t resultTypeID = 0;
334 if (failed(processType(op.getLoc(), op.getType(), resultTypeID))) {
337 operands.push_back(resultTypeID);
// The result is bound to its id before the remaining operands are collected.
338 resultID = getNextID();
339 valueIDMap[op.getResult()] = resultID;
340 operands.push_back(resultID);
// The storage class attribute is encoded as an operand word and added to
// elidedAttrs so the attribute loop below skips it.
341 auto attr = op->getAttr(spirv::attributeName<spirv::StorageClass>());
344 static_cast<uint32_t
>(cast<spirv::StorageClassAttr>(attr).getValue()));
346 elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
// Operands in ODS group 0 must already have ids; otherwise an error is
// emitted.
347 for (
auto arg : op.getODSOperands(0)) {
348 auto argID = getValueID(arg);
350 return emitError(op.getLoc(),
"operand 0 has a use before def");
352 operands.push_back(argID);
// The debug line for this op is written into functionHeader.
354 if (failed(emitDebugLine(functionHeader, op.getLoc())))
// Attributes are compared by name against elidedAttrs; the others go to
// processDecoration.
357 for (
auto attr : op->getAttrs()) {
358 if (llvm::any_of(elidedAttrs, [&](StringRef elided) {
359 return attr.getName() == elided;
363 if (failed(processDecoration(op.getLoc(), resultID, attr))) {
// Serializes a spirv.GlobalVariable op. Operands are the result type id, a new
// result id, the storage class, and optionally the id of an initializer that
// is referenced by symbol name.
371 Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
373 uint32_t resultTypeID = 0;
375 if (failed(processType(varOp.getLoc(), varOp.getType(), resultTypeID))) {
// The type attribute is consumed above, so it is elided from the attribute
// loop at the end.
379 elidedAttrs.push_back(
"type");
381 operands.push_back(resultTypeID);
382 auto resultID = getNextID();
// The symbol name is passed to processName and then used as the key under
// which the id is stored in globalVarIDMap.
385 auto varName = varOp.getSymName();
387 if (failed(processName(resultID, varName))) {
390 globalVarIDMap[varName] = resultID;
391 operands.push_back(resultID);
394 operands.push_back(
static_cast<uint32_t
>(varOp.storageClass()));
// Optional initializer: the referenced symbol is looked up from the parent op.
397 StringRef initAttrName = varOp.getInitializerAttrName().getValue();
398 if (std::optional<StringRef> initSymbolName = varOp.getInitializer()) {
399 uint32_t initializerID = 0;
402 varOp->getParentOp(), initRef.getAttr());
// A global-variable initializer is resolved through getVariableID; the other
// visible path resolves it through getSpecConstID.
405 if (isa<spirv::GlobalVariableOp>(initOp))
406 initializerID = getVariableID(*initSymbolName);
408 initializerID = getSpecConstID(*initSymbolName);
412 "invalid usage of undefined variable as initializer");
414 operands.push_back(initializerID);
415 elidedAttrs.push_back(initAttrName);
// The debug line for a global variable is written into typesGlobalValues.
418 if (failed(emitDebugLine(typesGlobalValues, varOp.getLoc())))
// NOTE(review): initAttrName is pushed onto elidedAttrs here as well as
// above, so it can appear twice in the list. Looks harmless for the
// any_of lookup below but appears redundant; confirm.
421 elidedAttrs.push_back(initAttrName);
// Attributes are compared by name against elidedAttrs; the others go to
// processDecoration.
424 for (
auto attr : varOp->getAttrs()) {
425 if (llvm::any_of(elidedAttrs, [&](StringRef elided) {
426 return attr.getName() == elided;
430 if (failed(processDecoration(varOp.getLoc(), resultID, attr))) {
// Serializes a spirv.mlir.selection region op: assigns ids to its blocks,
// emits OpSelectionMerge through a callback while the header block is
// processed, then visits the remaining blocks and finally handles the merge
// block.
437 LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
// Every block in the body gets an id up front, so getBlockID below can
// resolve the header and merge blocks.
440 auto &body = selectionOp.getBody();
441 for (
Block &block : body)
442 getOrCreateBlockID(&block);
444 auto *headerBlock = selectionOp.getHeaderBlock();
445 auto *mergeBlock = selectionOp.getMergeBlock();
446 auto headerID = getBlockID(headerBlock);
447 auto mergeID = getBlockID(mergeBlock);
448 auto loc = selectionOp.getLoc();
// The merge block is expected to end in a spirv.mlir.merge op whose operands
// match the selection results one-to-one.
// NOTE(review): replaceAllUsesWith rewrites uses in the IR that is being
// serialized, so this function mutates its input; confirm that is intended.
454 auto mergeOp = cast<spirv::MergeOp>(mergeBlock->back());
455 assert(selectionOp.getNumResults() == mergeOp.getNumOperands());
456 for (
unsigned i = 0, e = selectionOp.getNumResults(); i != e; ++i)
457 selectionOp.getResult(i).replaceAllUsesWith(mergeOp.getOperand(i));
// Callback passed to processBlock for the header block: emits the debug line
// and an OpSelectionMerge carrying the merge block id and selection control,
// and sets lastProcessedWasMergeInst.
469 auto emitSelectionMerge = [&]() {
470 if (failed(emitDebugLine(functionBody, loc)))
472 lastProcessedWasMergeInst =
true;
474 functionBody, spirv::Opcode::OpSelectionMerge,
475 {mergeID,
static_cast<uint32_t
>(selectionOp.getSelectionControl())});
479 processBlock(headerBlock,
false, emitSelectionMerge)))
// Visit the rest of the region from the header, with the skip-header flag set
// and the merge block excluded from the traversal.
486 headerBlock, [&](
Block *block) {
return processBlock(block); },
487 true, {mergeBlock})))
// Phi instructions are emitted for the merge block arguments.
498 if (failed(emitPhiForBlockArguments(mergeBlock)))
501 LLVM_DEBUG(llvm::dbgs() <<
"done merge ");
502 LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs()));
503 LLVM_DEBUG(llvm::dbgs() <<
"\n");
// Serializes a spirv.mlir.loop region op: assigns block ids, emits OpLoopMerge
// through a callback while the header block is processed, visits the loop
// body blocks, and processes the continue block separately.
507 LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
// Block ids are created for every block except the first one in the region
// (drop_begin skips it).
511 auto &body = loopOp.getBody();
512 for (
Block &block : llvm::drop_begin(body))
513 getOrCreateBlockID(&block);
515 auto *headerBlock = loopOp.getHeaderBlock();
516 auto *continueBlock = loopOp.getContinueBlock();
517 auto *mergeBlock = loopOp.getMergeBlock();
518 auto headerID = getBlockID(headerBlock);
519 auto continueID = getBlockID(continueBlock);
520 auto mergeID = getBlockID(mergeBlock);
521 auto loc = loopOp.getLoc();
// The merge block is expected to end in a spirv.mlir.merge op whose operands
// match the loop results one-to-one.
// NOTE(review): replaceAllUsesWith rewrites uses in the IR that is being
// serialized, so this function mutates its input; confirm that is intended.
525 auto mergeOp = cast<spirv::MergeOp>(mergeBlock->back());
526 assert(loopOp.getNumResults() == mergeOp.getNumOperands());
527 for (
unsigned i = 0, e = loopOp.getNumResults(); i != e; ++i)
528 loopOp.getResult(i).replaceAllUsesWith(mergeOp.getOperand(i));
// Callback passed to processBlock for the header block: emits the debug line
// and an OpLoopMerge carrying the merge id, continue id and loop control, and
// sets lastProcessedWasMergeInst.
544 auto emitLoopMerge = [&]() {
545 if (failed(emitDebugLine(functionBody, loc)))
547 lastProcessedWasMergeInst =
true;
549 functionBody, spirv::Opcode::OpLoopMerge,
550 {mergeID, continueID,
static_cast<uint32_t
>(loopOp.getLoopControl())});
553 if (failed(processBlock(headerBlock,
false, emitLoopMerge)))
// Visit the rest of the region from the header, with the skip-header flag set
// and both the continue and merge blocks excluded from the traversal.
560 headerBlock, [&](
Block *block) {
return processBlock(block); },
561 true, {continueBlock, mergeBlock})))
// The continue block was excluded above and is processed explicitly here.
565 if (failed(processBlock(continueBlock)))
573 LLVM_DEBUG(llvm::dbgs() <<
"done merge ");
574 LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs()));
575 LLVM_DEBUG(llvm::dbgs() <<
"\n");
// Serializes a spirv.BranchConditional op using the condition id, the ids of
// the true and false target blocks, and any branch weights.
579 LogicalResult Serializer::processBranchConditionalOp(
580 spirv::BranchConditionalOp condBranchOp) {
// Target block ids are created on demand, so forward branches are allowed.
581 auto conditionID = getValueID(condBranchOp.getCondition());
582 auto trueLabelID = getOrCreateBlockID(condBranchOp.getTrueBlock());
583 auto falseLabelID = getOrCreateBlockID(condBranchOp.getFalseBlock());
// Optional branch weights are appended as extra arguments.
// NOTE(review): getInt() is pushed without an explicit narrowing cast, unlike
// the static_cast to uint32_t used elsewhere in this file; confirm the
// element type of `arguments`.
586 if (
auto weights = condBranchOp.getBranchWeights()) {
587 for (
auto val : weights->getValue())
588 arguments.push_back(cast<IntegerAttr>(val).getInt());
591 if (failed(emitDebugLine(functionBody, condBranchOp.getLoc())))
// Serializes a spirv.Branch op: emits the debug line into functionBody and
// uses the target block id as the single operand.
598 LogicalResult Serializer::processBranchOp(spirv::BranchOp branchOp) {
599 if (failed(emitDebugLine(functionBody, branchOp.getLoc())))
// The target block id is created on demand.
602 {getOrCreateBlockID(branchOp.getTarget())});
// Handles a spirv.mlir.addressof op by binding its pointer result directly to
// the id of the referenced global variable.
606 LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) {
607 auto varName = addressOfOp.getVariable();
608 auto variableID = getVariableID(varName);
// An unknown variable is reported as an error on the op.
610 return addressOfOp.emitError(
"unknown result <id> for variable ")
613 valueIDMap[addressOfOp.getPointer()] = variableID;
// Handles a spirv.mlir.referenceof op by binding its result directly to the id
// of the referenced specialization constant.
618 Serializer::processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp) {
619 auto constName = referenceOfOp.getSpecConst();
620 auto constID = getSpecConstID(constName);
// An unknown specialization constant is reported as an error on the op.
622 return referenceOfOp.emitError(
623 "unknown result <id> for specialization constant ")
626 valueIDMap[referenceOfOp.getReference()] = constID;
// Serializes a spirv.EntryPoint op. Operands are the execution model, the id
// of the entry function, and the ids of the interface variables.
632 Serializer::processOp<spirv::EntryPointOp>(spirv::EntryPointOp op) {
635 operands.push_back(
static_cast<uint32_t
>(op.getExecutionModel()));
// The function id is looked up, not created: the function must already be
// known at this point, otherwise an error is emitted.
637 auto funcID = getFunctionID(op.getFn());
639 return op.emitError(
"missing <id> for function ")
641 <<
"; function needs to be defined before spirv.EntryPoint is "
644 operands.push_back(funcID);
// Each interface entry is a symbol reference to a global variable that must
// already have an id.
649 if (
auto interface = op.getInterface()) {
650 for (
auto var : interface.getValue()) {
651 auto id = getVariableID(cast<FlatSymbolRefAttr>(var).getValue());
// NOTE(review): the first two literals below are concatenated with no space
// after the period, so the diagnostic reads as one run-on word at the
// sentence boundary; consider adding a space in a follow-up code change.
654 "referencing undefined global variable."
655 "spirv.EntryPoint is at the end of spirv.module. All "
656 "referenced variables should already be defined");
658 operands.push_back(
id);
// Serializes a spirv.ExecutionMode op. Operands are the function id, the
// execution mode, and the integer values attached to the op.
667 Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) {
// The function id is looked up, not created; a missing id is an error.
670 auto funcID = getFunctionID(op.getFn());
672 return op.emitError(
"missing <id> for function ")
674 <<
"; function needs to be serialized before ExecutionModeOp is "
677 operands.push_back(funcID);
679 operands.push_back(
static_cast<uint32_t
>(op.getExecutionMode()));
// Each value is an IntegerAttr whose zero-extended value is narrowed to a
// 32-bit operand word.
682 auto values = op.getValues();
684 for (
auto &intVal : values.getValue()) {
685 operands.push_back(
static_cast<uint32_t
>(
686 llvm::cast<IntegerAttr>(intVal).getValue().getZExtValue()));
// Serializes a spirv.FunctionCall op using the result type id, a new id for
// the call, the callee function id, and the ids of the call arguments.
696 Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) {
697 auto funcName = op.getCallee();
698 uint32_t resTypeID = 0;
// A call with no results uses the void type as its result type.
700 Type resultTy = op.getNumResults() ? *op.result_type_begin() : getVoidType();
701 if (failed(processType(op.getLoc(), resultTy, resTypeID)))
// The callee id is created on demand, so calls may precede the callee.
704 auto funcID = getOrCreateFunctionID(funcName);
705 auto funcCallID = getNextID();
// Arguments must already have ids; this is only asserted.
708 for (
auto value : op.getArguments()) {
709 auto valueID = getValueID(value);
710 assert(valueID &&
"cannot find a value for spirv.FunctionCall");
711 operands.push_back(valueID);
// Only bind a result when the result type is not NoneType. Presumably
// getVoidType() yields NoneType, so result-less calls skip this; confirm.
714 if (!isa<NoneType>(resultTy))
715 valueIDMap[op.getResult(0)] = funcCallID;
// Serializes a spirv.CopyMemory op. Operands are the ids of the op operands
// followed by the optional memory access, alignment, source memory access
// and source alignment attributes, in that order.
723 Serializer::processOp<spirv::CopyMemoryOp>(spirv::CopyMemoryOp op) {
// Every operand must already have an id; this is only asserted.
727 for (
Value operand : op->getOperands()) {
728 auto id = getValueID(operand);
729 assert(
id &&
"use before def!");
730 operands.push_back(
id);
// Each optional attribute is encoded as a 32-bit operand word when present,
// and its name is added to elidedAttrs.
733 StringAttr memoryAccess = op.getMemoryAccessAttrName();
734 if (
auto attr = op->getAttr(memoryAccess)) {
736 static_cast<uint32_t
>(cast<spirv::MemoryAccessAttr>(attr).getValue()));
739 elidedAttrs.push_back(memoryAccess.strref());
741 StringAttr alignment = op.getAlignmentAttrName();
742 if (
auto attr = op->getAttr(alignment)) {
743 operands.push_back(
static_cast<uint32_t
>(
744 cast<IntegerAttr>(attr).getValue().getZExtValue()));
747 elidedAttrs.push_back(alignment.strref());
// Same encoding for the source-side memory access and alignment attributes.
749 StringAttr sourceMemoryAccess = op.getSourceMemoryAccessAttrName();
750 if (
auto attr = op->getAttr(sourceMemoryAccess)) {
752 static_cast<uint32_t
>(cast<spirv::MemoryAccessAttr>(attr).getValue()));
755 elidedAttrs.push_back(sourceMemoryAccess.strref());
757 StringAttr sourceAlignment = op.getSourceAlignmentAttrName();
758 if (
auto attr = op->getAttr(sourceAlignment)) {
759 operands.push_back(
static_cast<uint32_t
>(
760 cast<IntegerAttr>(attr).getValue().getZExtValue()));
763 elidedAttrs.push_back(sourceAlignment.strref());
// The debug line for this op is written into functionBody.
764 if (failed(emitDebugLine(functionBody, op.getLoc())))
// Serializes a spirv.GenericCastToPtrExplicit op. Operands are the result type
// id, a new result id, the ids of the op operands, and the storage class of
// the result pointer type.
771 LogicalResult Serializer::processOp<spirv::GenericCastToPtrExplicitOp>(
772 spirv::GenericCastToPtrExplicitOp op) {
776 uint32_t resultTypeID = 0;
777 uint32_t resultID = 0;
778 resultTy = op->getResult(0).getType();
779 if (failed(processType(loc, resultTy, resultTypeID)))
781 operands.push_back(resultTypeID);
// Bind the op result to a freshly allocated id.
783 resultID = getNextID();
784 operands.push_back(resultID);
785 valueIDMap[op->getResult(0)] = resultID;
// NOTE(review): operand ids are pushed with no visible check for a zero
// result from getValueID, whereas the CopyMemoryOp serializer in this file
// asserts on it; confirm a use-before-def cannot reach this point.
787 for (
Value operand : op->getOperands())
788 operands.push_back(getValueID(operand));
// The result type is cast to a pointer type; its storage class becomes the
// trailing operand.
789 spirv::StorageClass resultStorage =
790 cast<spirv::PointerType>(resultTy).getStorageClass();
791 operands.push_back(
static_cast<uint32_t
>(resultStorage));
799 #define GET_SERIALIZATION_FNS
800 #include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc"
static LogicalResult visitInPrettyBlockOrder(Block *headerBlock, function_ref< LogicalResult(Block *)> blockHandler, bool skipHeader=false, BlockRange skipBlocks={})
A pre-order depth-first visitor function for processing basic blocks.
This class provides an abstraction over the different types of ranges over Blocks.
Block represents an ordered list of Operations.
OpListType & getOperations()
A symbol reference with a reference path containing a single element.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
StringRef stripDialect() const
Return the operation name with dialect name stripped, if it has one.
Operation is the basic unit of execution within MLIR.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator over the underlying Values.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of, or including, 'from' with the 'OpTrait::SymbolTable' trait.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void encodeStringLiteralInto(SmallVectorImpl< uint32_t > &binary, StringRef literal)
Encodes an SPIR-V literal string into the given binary vector.
void encodeInstructionInto(SmallVectorImpl< uint32_t > &binary, spirv::Opcode op, ArrayRef< uint32_t > operands)
Encodes an SPIR-V instruction with the given opcode and operands into the given binary vector.
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.