19 #include "llvm/ADT/DepthFirstIterator.h"
20 #include "llvm/ADT/StringExtras.h"
21 #include "llvm/Support/Debug.h"
23 #define DEBUG_TYPE "spirv-serialization"
45 bool skipHeader =
false,
BlockRange skipBlocks = {}) {
46 llvm::df_iterator_default_set<Block *, 4> doneBlocks;
47 doneBlocks.insert(skipBlocks.begin(), skipBlocks.end());
49 for (
Block *block : llvm::depth_first_ext(headerBlock, doneBlocks)) {
50 if (skipHeader && block == headerBlock)
52 if (
failed(blockHandler(block)))
// Serializes a spirv::ConstantOp. prepareConstant() is asked to materialize
// the op's (type, value) pair and return its result <id>; the op's SSA result
// is then mapped to that <id> in valueIDMap.
// NOTE(review): the head of the `if (... = prepareConstant(` condition and the
// success/failure returns are on lines not visible in this excerpt.
60 LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) {
62 prepareConstant(op.getLoc(), op.getType(), op.getValue())) {
// Later uses of this SSA value resolve to the constant's <id>.
63 valueIDMap[op.getResult()] = resultID;
// Serializes a spirv::EXTConstantCompositeReplicateOp. The constant is
// materialized by prepareConstantCompositeReplicate(); a non-zero <id> from
// that call is recorded in valueIDMap for the op's result.
// NOTE(review): the zero-<id> (failure) path is on lines not visible here.
69 LogicalResult Serializer::processConstantCompositeReplicateOp(
70 spirv::EXTConstantCompositeReplicateOp op) {
71 if (uint32_t resultID = prepareConstantCompositeReplicate(
72 op.getLoc(), op.getType(), op.getValue())) {
73 valueIDMap[op.getResult()] = resultID;
// Serializes a spirv::SpecConstantOp (a symbol-named specialization constant
// with a default value):
//   1. prepareConstantScalar() materializes the default value and yields the
//      result <id> (its remaining arguments are on lines not shown here);
//   2. if the op carries an IntegerAttr named "spec_id", a SpecId decoration
//      with that value is emitted on the result <id>;
//   3. the <id> is recorded in specConstIDMap under the op's symbol name;
//   4. the <id> and symbol name are passed to processName(), whose result is
//      returned.
79 LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) {
80 if (
auto resultID = prepareConstantScalar(op.getLoc(), op.getDefaultValue(),
// Optional user-assigned specialization ID.
83 if (
auto specID = op->getAttrOfType<IntegerAttr>(
"spec_id")) {
// NOTE(review): the attribute's integer is narrowed with
// static_cast<uint32_t>; no range check is visible before the narrowing.
84 auto val =
static_cast<uint32_t
>(specID.getInt());
85 if (
failed(emitDecoration(resultID, spirv::Decoration::SpecId, {val})))
// Register symbol name -> <id> so symbol references can be resolved
// (presumably via getSpecConstID(); its definition is not shown here).
89 specConstIDMap[op.getSymName()] = resultID;
90 return processName(resultID, op.getSymName());
// Serializes a spirv::SpecConstantCompositeOp as an OpSpecConstantComposite
// instruction. Operands are built in order:
//   - the composite's type <id>;
//   - a fresh result <id>;
//   - one <id> per constituent, each resolved by symbol name through
//     getSpecConstID().
// The result <id> is registered in specConstIDMap under the op's symbol name
// and passed to processName(), whose result is returned.
96 Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) {
98 if (
failed(processType(op.getLoc(), op.getType(), typeID))) {
102 auto resultID = getNextID();
105 operands.push_back(typeID);
106 operands.push_back(resultID);
108 auto constituents = op.getConstituents();
110 for (
auto index : llvm::seq<uint32_t>(0, constituents.size())) {
// NOTE(review): dyn_cast returns null for a non-FlatSymbolRefAttr, and no
// null check is visible before constituent.getValue() below. The sibling
// processSpecConstantCompositeReplicateOp does check its dyn_cast result;
// consider cast<> or an explicit error here.
111 auto constituent = dyn_cast<FlatSymbolRefAttr>(constituents[index]);
113 auto constituentName = constituent.getValue();
114 auto constituentID = getSpecConstID(constituentName);
// A zero <id> is reported as an unknown specialization constant.
116 if (!constituentID) {
117 return op.emitError(
"unknown result <id> for specialization constant ")
121 operands.push_back(constituentID);
125 spirv::Opcode::OpSpecConstantComposite, operands);
126 specConstIDMap[op.getSymName()] = resultID;
128 return processName(resultID, op.getSymName());
// Serializes a spirv::EXTSpecConstantCompositeReplicateOp as an
// OpSpecConstantCompositeReplicateEXT instruction with the operands
// {type <id>, result <id>, constituent <id>}.
// The single constituent is expected to be a FlatSymbolRefAttr naming a spec
// constant whose <id> getSpecConstID() already knows; otherwise an error is
// emitted. The result <id> is registered in specConstIDMap and passed to
// processName(), whose result is returned.
131 LogicalResult Serializer::processSpecConstantCompositeReplicateOp(
132 spirv::EXTSpecConstantCompositeReplicateOp op) {
134 if (
failed(processType(op.getLoc(), op.getType(), typeID))) {
// The error text below is used when this dyn_cast fails (the guarding
// condition is on lines not visible in this excerpt).
138 auto constituent = dyn_cast<FlatSymbolRefAttr>(op.getConstituent());
141 "expected flat symbol reference for constituent instead of ")
142 << op.getConstituent();
144 StringRef constituentName = constituent.getValue();
145 uint32_t constituentID = getSpecConstID(constituentName);
146 if (!constituentID) {
147 return op.emitError(
"unknown result <id> for replicated spec constant ")
151 uint32_t resultID = getNextID();
152 uint32_t operands[] = {typeID, resultID, constituentID};
155 spirv::Opcode::OpSpecConstantCompositeReplicateEXT,
158 specConstIDMap[op.getSymName()] = resultID;
160 return processName(resultID, op.getSymName());
// Serializes a spirv::SpecConstantOperationOp. The instruction operands are:
//   - the result type <id>;
//   - a fresh result <id>;
//   - the SPIR-V opcode of the enclosed operation, found by
//     spirv::symbolizeOpcode() on the op name collected in enclosedOpName;
//   - the <id> of each operand.
// The loop header that binds `operand` is not visible in this excerpt. The
// op's result is finally mapped to the result <id> in valueIDMap.
164 Serializer::processSpecConstantOperationOp(spirv::SpecConstantOperationOp op) {
166 if (
failed(processType(op.getLoc(), op.getType(), typeID))) {
170 auto resultID = getNextID();
173 operands.push_back(typeID);
174 operands.push_back(resultID);
// `block` is the first block of the op's region; the enclosed operation is
// presumably taken from it on lines not shown.
176 Block &block = op.getRegion().getBlocks().
front();
// The enclosed op's name is presumably streamed into enclosedOpName through
// `rss` on a line not shown, then looked up as a SPIR-V opcode.
179 std::string enclosedOpName;
180 llvm::raw_string_ostream rss(enclosedOpName);
182 auto enclosedOpcode = spirv::symbolizeOpcode(enclosedOpName);
184 if (!enclosedOpcode) {
// NOTE(review): the statement that returns after this error is not visible.
185 op.emitError(
"Couldn't find op code for op ")
190 operands.push_back(
static_cast<uint32_t
>(*enclosedOpcode));
// NOTE(review): this assert is compiled out under NDEBUG, in which case an
// unresolved (zero) <id> would be pushed as an operand.
194 uint32_t
id = getValueID(operand);
195 assert(
id &&
"use before def!");
196 operands.push_back(
id);
201 valueIDMap[op.getResult()] = resultID;
// Serializes a spirv::UndefOp. Undef <id>s are cached in undefValIDMap, keyed
// by type, and the op's result is mapped to the cached <id> in valueIDMap.
// NOTE(review): the code that fills a fresh cache entry (allocating the <id>
// and emitting the instruction) is on lines not visible in this excerpt.
206 LogicalResult Serializer::processUndefOp(spirv::UndefOp op) {
207 auto undefType = op.getType();
// Bound by reference to the map slot, so an assignment to `id` updates the
// cache entry in place (operator[] creates the slot for a new type).
208 auto &
id = undefValIDMap[undefType];
212 if (
failed(processType(op.getLoc(), undefType, typeID)))
217 valueIDMap[op.getResult()] = id;
// Serializes the parameters of function `op`. For each argument it:
//   - resolves the argument type's <id>;
//   - allocates a fresh value <id>;
//   - emits decorations carried as DecorationAttr argument attributes;
//   - maps the argument value to its <id> in valueIDMap.
// NOTE(review): the loop header that binds `arg` and `idx` is on a line not
// visible in this excerpt.
221 LogicalResult Serializer::processFuncParameter(spirv::FuncOp op) {
223 uint32_t argTypeID = 0;
224 if (
failed(processType(op.getLoc(), arg.getType(), argTypeID))) {
227 auto argValueID = getNextID();
// Per-argument attributes are read through FunctionOpInterface.
230 auto funcOp = cast<FunctionOpInterface>(*op);
231 for (
auto argAttr : funcOp.getArgAttrs(idx)) {
// Attributes not named DecorationAttr::name are presumably skipped (the
// statement following this check is not visible).
232 if (argAttr.getName() != DecorationAttr::name)
235 if (
auto decAttr = dyn_cast<DecorationAttr>(argAttr.getValue())) {
236 if (
failed(processDecorationAttr(op->getLoc(), argValueID,
237 decAttr.getValue(), decAttr)))
242 valueIDMap[arg] = argValueID;
// Tail of the instruction encoding for this parameter (head not visible);
// its operands are the argument type <id> and the argument value <id>.
244 {argTypeID, argValueID});
// Serializes a spirv::FuncOp and appends the result to `functions`. The
// visible steps are:
//   - resolve the function-type <id> and the return-type <id>. No result
//     means void; more than one result is an error;
//   - build the function's operand list: return type, function <id>,
//     function control, function type;
//   - call processName() for the function <id>;
//   - reject external functions unless they have 'Import' linkage;
//   - serialize the parameters, then the entry block and the remaining blocks;
//   - patch deferred phi operands once all values have <id>s;
//   - emit decorations for attributes whose names map to spirv::Decoration;
//   - append functionHeader/functionBody to `functions`, then clear both.
249 LogicalResult Serializer::processFuncOp(spirv::FuncOp op) {
250 LLVM_DEBUG(llvm::dbgs() <<
"-- start function '" << op.getName() <<
"' --\n");
// The per-function buffers must be empty on entry; they are cleared again
// at the end of this function.
251 assert(functionHeader.empty() && functionBody.empty());
253 uint32_t fnTypeID = 0;
255 if (
failed(processType(op.getLoc(), op.getFunctionType(), fnTypeID)))
// At most one result is supported; an empty result list serializes as
// getVoidType().
260 uint32_t resTypeID = 0;
261 auto resultTypes = op.getFunctionType().getResults();
262 if (resultTypes.size() > 1) {
263 return op.emitError(
"cannot serialize function with multiple return types");
265 if (
failed(processType(op.getLoc(),
266 (resultTypes.empty() ? getVoidType() : resultTypes[0]),
// Operand order: result type, function <id>, function control, function type.
// The instruction emission itself is on lines not visible here.
270 operands.push_back(resTypeID);
271 auto funcID = getOrCreateFunctionID(op.getName());
272 operands.push_back(funcID);
273 operands.push_back(
static_cast<uint32_t
>(op.getFunctionControl()));
274 operands.push_back(fnTypeID);
278 if (
failed(processName(funcID, op.getName()))) {
// External (body-less) functions are accepted only when their linkage
// attributes declare Import linkage.
283 auto linkageAttr = op.getLinkageAttributes();
284 auto hasImportLinkage =
285 linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
286 spirv::LinkageType::Import);
287 if (op.isExternal() && !hasImportLinkage) {
289 "'spirv.module' cannot contain external functions "
290 "without 'Import' linkage_attributes (LinkageAttributes)");
292 if (op.isExternal() && hasImportLinkage) {
// Parameters are serialized at two sites. The first is inside the
// imported-declaration branch. The second is presumably on the definition
// path; the branch structure around it is not fully visible.
302 if (
failed(processFuncParameter(op)))
309 if (
failed(processFuncParameter(op)))
// Tail of a call using the entry block's label <id> (call head not visible).
321 {getOrCreateBlockID(&op.front())});
// The entry block is processed first. The meaning of the `true` flag is
// defined by processBlock(), which is not shown here.
322 if (
failed(processBlock(&op.front(),
true)))
// The remaining blocks are visited starting from the entry block (call head
// not visible; presumably visitInPrettyBlockOrder).
325 &op.front(), [&](
Block *block) { return processBlock(block); },
// With every block serialized, patch operands of already-emitted phi
// instructions: each recorded offset into functionBody is overwritten with the
// now-known <id> of the deferred value.
332 for (
const auto &deferredValue : deferredPhiValues) {
333 Value value = deferredValue.first;
334 uint32_t
id = getValueID(value);
335 LLVM_DEBUG(llvm::dbgs() <<
"[phi] fix reference of value " << value
336 <<
" to id = " <<
id <<
'\n');
// NOTE(review): compiled out under NDEBUG; a zero <id> would then be written.
337 assert(
id &&
"OpPhi references undefined value!");
338 for (
size_t offset : deferredValue.second)
339 functionBody[offset] = id;
341 deferredPhiValues.clear();
343 LLVM_DEBUG(llvm::dbgs() <<
"-- completed function '" << op.getName()
// Attributes whose snake_case names convert to a spirv::Decoration enumerant
// are emitted as decorations on the function <id>.
349 for (
auto attr : op->getAttrs()) {
351 auto isValidDecoration = mlir::spirv::symbolizeEnum<spirv::Decoration>(
352 llvm::convertToCamelFromSnakeCase(attr.getName().strref(),
354 if (isValidDecoration != std::nullopt) {
355 if (
failed(processDecoration(op.getLoc(), funcID, attr))) {
// Flush this function into the module-level `functions` buffer and reset the
// per-function buffers for the next function.
363 functions.append(functionHeader.begin(), functionHeader.end());
364 functions.append(functionBody.begin(), functionBody.end());
365 functionHeader.clear();
366 functionBody.clear();
// Serializes a spirv::VariableOp. The operands are:
//   - the result type <id>;
//   - a fresh result <id>, also recorded in valueIDMap;
//   - the storage class;
//   - the <id>s of the op's ODS operand group 0.
// A debug line is emitted into functionHeader. Attributes not listed in
// elidedAttrs are then serialized as decorations on the result <id>.
371 LogicalResult Serializer::processVariableOp(spirv::VariableOp op) {
374 uint32_t resultID = 0;
375 uint32_t resultTypeID = 0;
376 if (
failed(processType(op.getLoc(), op.getType(), resultTypeID))) {
379 operands.push_back(resultTypeID);
380 resultID = getNextID();
381 valueIDMap[op.getResult()] = resultID;
382 operands.push_back(resultID);
// The storage class attribute is encoded as an operand (push_back head not
// visible), so its name is added to elidedAttrs.
383 auto attr = op->getAttr(spirv::attributeName<spirv::StorageClass>());
386 static_cast<uint32_t
>(cast<spirv::StorageClassAttr>(attr).getValue()));
388 elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
// An unresolved operand <id> is reported as use-before-def (the guarding
// condition is on a line not visible here).
389 for (
auto arg : op.getODSOperands(0)) {
390 auto argID = getValueID(arg);
392 return emitError(op.getLoc(),
"operand 0 has a use before def");
394 operands.push_back(argID);
396 if (
failed(emitDebugLine(functionHeader, op.getLoc())))
// Attributes matching an elidedAttrs entry are presumably skipped (the
// statement after the any_of check is not visible); the rest become
// decorations.
399 for (
auto attr : op->getAttrs()) {
400 if (llvm::any_of(elidedAttrs, [&](StringRef elided) {
401 return attr.getName() == elided;
405 if (
failed(processDecoration(op.getLoc(), resultID, attr))) {
// Serializes a spirv::GlobalVariableOp. The operands are:
//   - the result type <id>;
//   - a fresh result <id>, named via processName() and registered in
//     globalVarIDMap under the symbol name;
//   - the storage class;
//   - when an initializer symbol is present, the <id> of the referenced
//     global variable or spec constant.
// The debug line is emitted into typesGlobalValues; the instruction emission
// itself is not visible here. Non-elided attributes are serialized as
// decorations on the result <id>.
413 Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
415 uint32_t resultTypeID = 0;
417 if (
failed(processType(varOp.getLoc(), varOp.getType(), resultTypeID))) {
// The type is encoded as an operand, so "type" is added to elidedAttrs.
421 elidedAttrs.push_back(
"type");
423 operands.push_back(resultTypeID);
424 auto resultID = getNextID();
427 auto varName = varOp.getSymName();
429 if (
failed(processName(resultID, varName))) {
432 globalVarIDMap[varName] = resultID;
433 operands.push_back(resultID);
436 operands.push_back(
static_cast<uint32_t
>(varOp.storageClass()));
// Optional initializer. The symbol is looked up from the parent op; only the
// tail of that lookup is visible. It resolves via getVariableID() when it
// names a GlobalVariableOp, and otherwise presumably via getSpecConstID().
439 StringRef initAttrName = varOp.getInitializerAttrName().getValue();
440 if (std::optional<StringRef> initSymbolName = varOp.getInitializer()) {
441 uint32_t initializerID = 0;
444 varOp->getParentOp(), initRef.getAttr());
447 if (isa<spirv::GlobalVariableOp>(initOp))
448 initializerID = getVariableID(*initSymbolName);
450 initializerID = getSpecConstID(*initSymbolName);
454 "invalid usage of undefined variable as initializer");
456 operands.push_back(initializerID);
457 elidedAttrs.push_back(initAttrName);
460 if (
failed(emitDebugLine(typesGlobalValues, varOp.getLoc())))
// NOTE(review): initAttrName is also pushed at the initializer branch above.
// Depending on the control flow hidden from this excerpt, it may appear
// twice in elidedAttrs (harmless for the any_of check below).
463 elidedAttrs.push_back(initAttrName);
466 for (
auto attr : varOp->getAttrs()) {
467 if (llvm::any_of(elidedAttrs, [&](StringRef elided) {
468 return attr.getName() == elided;
472 if (
failed(processDecoration(varOp.getLoc(), resultID, attr))) {
// Serializes a spirv::SelectionOp (structured selection). The visible steps:
//   - pre-assign label <id>s to every block of the body region;
//   - forward each selection result to the corresponding operand of the
//     merge block's terminating spirv::MergeOp. The counts must match;
//   - process the header block, passing emitSelectionMerge as a callback;
//   - visit the remaining blocks from the header in depth-first order,
//     skipping the header and the merge block;
//   - emit phis for the merge block's arguments.
479 LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
482 auto &body = selectionOp.getBody();
483 for (
Block &block : body)
484 getOrCreateBlockID(&block);
486 auto *headerBlock = selectionOp.getHeaderBlock();
487 auto *mergeBlock = selectionOp.getMergeBlock();
488 auto headerID = getBlockID(headerBlock);
489 auto mergeID = getBlockID(mergeBlock);
490 auto loc = selectionOp.getLoc();
// Uses of the selection's results are redirected to the values yielded by the
// merge op.
496 auto mergeOp = cast<spirv::MergeOp>(mergeBlock->back());
497 assert(selectionOp.getNumResults() == mergeOp.getNumOperands());
498 for (
unsigned i = 0, e = selectionOp.getNumResults(); i != e; ++i)
499 selectionOp.getResult(i).replaceAllUsesWith(mergeOp.getOperand(i));
// Emits a debug line and an OpSelectionMerge {merge <id>, selection control}
// into functionBody, and sets lastProcessedWasMergeInst. When processBlock()
// invokes it is decided inside processBlock(), which is not shown here.
511 auto emitSelectionMerge = [&]() {
512 if (
failed(emitDebugLine(functionBody, loc)))
514 lastProcessedWasMergeInst =
true;
516 functionBody, spirv::Opcode::OpSelectionMerge,
517 {mergeID,
static_cast<uint32_t
>(selectionOp.getSelectionControl())});
521 processBlock(headerBlock,
false, emitSelectionMerge)))
// Tail of a visit over the blocks reachable from the header (call head not
// visible; presumably visitInPrettyBlockOrder). skipHeader = true and the
// merge block is excluded.
528 headerBlock, [&](
Block *block) {
return processBlock(block); },
529 true, {mergeBlock})))
540 if (
failed(emitPhiForBlockArguments(mergeBlock)))
543 LLVM_DEBUG(llvm::dbgs() <<
"done merge ");
544 LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs()));
545 LLVM_DEBUG(llvm::dbgs() <<
"\n");
// Serializes a spirv::LoopOp (structured loop). The visible steps:
//   - pre-assign label <id>s to every block of the body except the first;
//   - forward each loop result to the corresponding operand of the merge
//     block's terminating spirv::MergeOp;
//   - process the header block, passing emitLoopMerge as a callback;
//   - visit the remaining blocks from the header in depth-first order,
//     skipping the header, continue and merge blocks;
//   - process the continue block explicitly afterwards.
549 LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
553 auto &body = loopOp.getBody();
// llvm::drop_begin skips the region's first block here; how that block is
// handled is on lines not visible in this excerpt.
554 for (
Block &block : llvm::drop_begin(body))
555 getOrCreateBlockID(&block);
557 auto *headerBlock = loopOp.getHeaderBlock();
558 auto *continueBlock = loopOp.getContinueBlock();
559 auto *mergeBlock = loopOp.getMergeBlock();
560 auto headerID = getBlockID(headerBlock);
561 auto continueID = getBlockID(continueBlock);
562 auto mergeID = getBlockID(mergeBlock);
563 auto loc = loopOp.getLoc();
// Uses of the loop's results are redirected to the values yielded by the
// merge op.
567 auto mergeOp = cast<spirv::MergeOp>(mergeBlock->back());
568 assert(loopOp.getNumResults() == mergeOp.getNumOperands());
569 for (
unsigned i = 0, e = loopOp.getNumResults(); i != e; ++i)
570 loopOp.getResult(i).replaceAllUsesWith(mergeOp.getOperand(i));
// Emits a debug line and an OpLoopMerge {merge <id>, continue <id>, loop
// control} into functionBody, and sets lastProcessedWasMergeInst.
586 auto emitLoopMerge = [&]() {
587 if (
failed(emitDebugLine(functionBody, loc)))
589 lastProcessedWasMergeInst =
true;
591 functionBody, spirv::Opcode::OpLoopMerge,
592 {mergeID, continueID,
static_cast<uint32_t
>(loopOp.getLoopControl())});
595 if (
failed(processBlock(headerBlock,
false, emitLoopMerge)))
// Tail of a visit over the blocks reachable from the header (call head not
// visible; presumably visitInPrettyBlockOrder). The header is skipped and
// the continue and merge blocks are excluded.
602 headerBlock, [&](
Block *block) {
return processBlock(block); },
603 true, {continueBlock, mergeBlock})))
// The continue block is emitted after the other body blocks.
607 if (
failed(processBlock(continueBlock)))
615 LLVM_DEBUG(llvm::dbgs() <<
"done merge ");
616 LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs()));
617 LLVM_DEBUG(llvm::dbgs() <<
"\n");
// Serializes a spirv::BranchConditionalOp. It resolves the condition value
// <id> and the true/false target label <id>s, creating label <id>s on demand.
// Any branch weights are appended to `arguments`, whose declaration and
// initial contents are on lines not visible here. A debug line is emitted
// into functionBody.
// NOTE(review): conditionID is not checked for zero in the visible lines.
621 LogicalResult Serializer::processBranchConditionalOp(
622 spirv::BranchConditionalOp condBranchOp) {
623 auto conditionID = getValueID(condBranchOp.getCondition());
624 auto trueLabelID = getOrCreateBlockID(condBranchOp.getTrueBlock());
625 auto falseLabelID = getOrCreateBlockID(condBranchOp.getFalseBlock());
// NOTE(review): each weight's getInt() value is pushed without a visible
// range check.
628 if (
auto weights = condBranchOp.getBranchWeights()) {
629 for (
auto val : weights->getValue())
630 arguments.push_back(cast<IntegerAttr>(val).getInt());
633 if (
failed(emitDebugLine(functionBody, condBranchOp.getLoc())))
// Serializes a spirv::BranchOp. A debug line is emitted into functionBody,
// then an instruction whose visible operand list is just the target block's
// label <id>, created on demand by getOrCreateBlockID(). The head of that
// emission call is on a line not visible here.
640 LogicalResult Serializer::processBranchOp(spirv::BranchOp branchOp) {
641 if (
failed(emitDebugLine(functionBody, branchOp.getLoc())))
644 {getOrCreateBlockID(branchOp.getTarget())});
// Handles a spirv::AddressOfOp. No instruction emission is visible here: the
// op's pointer result is mapped directly to the referenced global variable's
// <id>. An unknown variable is reported with the error below; the guarding
// condition is on a line not visible here.
648 LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) {
649 auto varName = addressOfOp.getVariable();
650 auto variableID = getVariableID(varName);
652 return addressOfOp.emitError(
"unknown result <id> for variable ")
655 valueIDMap[addressOfOp.getPointer()] = variableID;
// Handles a spirv::ReferenceOfOp. No instruction emission is visible here:
// the op's result is mapped directly to the referenced specialization
// constant's <id> from getSpecConstID(). An unknown constant is reported with
// the error below; the guarding condition is not visible here.
660 Serializer::processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp) {
661 auto constName = referenceOfOp.getSpecConst();
662 auto constID = getSpecConstID(constName);
664 return referenceOfOp.emitError(
665 "unknown result <id> for specialization constant ")
668 valueIDMap[referenceOfOp.getReference()] = constID;
// Serializes a spirv::EntryPointOp. The visible operands are:
//   - the execution model;
//   - the entry function's <id>, which must already exist (looked up with
//     getFunctionID(), not created);
//   - after some lines not visible here, the <id> of every interface
//     variable, each of which must already be serialized.
674 Serializer::processOp<spirv::EntryPointOp>(spirv::EntryPointOp op) {
677 operands.push_back(
static_cast<uint32_t
>(op.getExecutionModel()));
679 auto funcID = getFunctionID(op.getFn());
681 return op.emitError(
"missing <id> for function ")
683 <<
"; function needs to be defined before spirv.EntryPoint is "
686 operands.push_back(funcID);
691 if (
auto interface = op.getInterface()) {
692 for (
auto var : interface.getValue()) {
693 auto id = getVariableID(cast<FlatSymbolRefAttr>(var).getValue());
696 "referencing undefined global variable."
697 "spirv.EntryPoint is at the end of spirv.module. All "
698 "referenced variables should already be defined");
// NOTE(review): the adjacent literals above concatenate to
// "...global variable.spirv.EntryPoint is..." because a space is missing
// after the first period.
700 operands.push_back(
id);
// Serializes a spirv::ExecutionModeOp. The operands are:
//   - the target function's <id>, which must already exist;
//   - the execution mode;
//   - each integer in the op's values array.
709 Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) {
712 auto funcID = getFunctionID(op.getFn());
714 return op.emitError(
"missing <id> for function ")
716 <<
"; function needs to be serialized before ExecutionModeOp is "
719 operands.push_back(funcID);
721 operands.push_back(
static_cast<uint32_t
>(op.getExecutionMode()));
724 auto values = op.getValues();
// NOTE(review): each value is zero-extended via getZExtValue() and then
// narrowed to uint32_t without a visible range check.
726 for (
auto &intVal : values.getValue()) {
727 operands.push_back(
static_cast<uint32_t
>(
728 llvm::cast<IntegerAttr>(intVal).getValue().getZExtValue()));
// Serializes a spirv::FunctionCallOp. It resolves the result type <id>, using
// getVoidType() when the call has no results, and the callee's <id> via
// getOrCreateFunctionID(). It allocates a fresh <id> for the call and collects
// the <id> of every argument. The call result is mapped in valueIDMap only
// when the result type is not NoneType.
// NOTE(review): presumably getVoidType() yields NoneType, so void calls get
// no valueIDMap entry. Its definition is not shown; TODO confirm.
738 Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) {
739 auto funcName = op.getCallee();
740 uint32_t resTypeID = 0;
742 Type resultTy = op.getNumResults() ? *op.result_type_begin() : getVoidType();
743 if (
failed(processType(op.getLoc(), resultTy, resTypeID)))
746 auto funcID = getOrCreateFunctionID(funcName);
747 auto funcCallID = getNextID();
// NOTE(review): this assert is compiled out under NDEBUG, in which case a zero
// <id> would be pushed for an unresolved argument.
750 for (
auto value : op.getArguments()) {
751 auto valueID = getValueID(value);
752 assert(valueID &&
"cannot find a value for spirv.FunctionCall");
753 operands.push_back(valueID);
756 if (!isa<NoneType>(resultTy))
757 valueIDMap[op.getResult(0)] = funcCallID;
// Serializes a spirv::CopyMemoryOp. Operands are the <id>s of all op operands,
// followed by each of these attributes that is present, in this order:
//   - memory access;
//   - alignment;
//   - source memory access;
//   - source alignment.
// Each of those attribute names is appended to elidedAttrs. A debug line is
// emitted into functionBody.
765 Serializer::processOp<spirv::CopyMemoryOp>(spirv::CopyMemoryOp op) {
// NOTE(review): this assert is compiled out under NDEBUG.
769 for (
Value operand : op->getOperands()) {
770 auto id = getValueID(operand);
771 assert(
id &&
"use before def!");
772 operands.push_back(
id);
// Target memory access mask (the push_back head is on a line not visible).
775 StringAttr memoryAccess = op.getMemoryAccessAttrName();
776 if (
auto attr = op->getAttr(memoryAccess)) {
778 static_cast<uint32_t
>(cast<spirv::MemoryAccessAttr>(attr).getValue()));
781 elidedAttrs.push_back(memoryAccess.strref());
// Target alignment, zero-extended then narrowed to uint32_t.
783 StringAttr alignment = op.getAlignmentAttrName();
784 if (
auto attr = op->getAttr(alignment)) {
785 operands.push_back(
static_cast<uint32_t
>(
786 cast<IntegerAttr>(attr).getValue().getZExtValue()));
789 elidedAttrs.push_back(alignment.strref());
// Source memory access mask (the push_back head is on a line not visible).
791 StringAttr sourceMemoryAccess = op.getSourceMemoryAccessAttrName();
792 if (
auto attr = op->getAttr(sourceMemoryAccess)) {
794 static_cast<uint32_t
>(cast<spirv::MemoryAccessAttr>(attr).getValue()));
797 elidedAttrs.push_back(sourceMemoryAccess.strref());
// Source alignment, zero-extended then narrowed to uint32_t.
799 StringAttr sourceAlignment = op.getSourceAlignmentAttrName();
800 if (
auto attr = op->getAttr(sourceAlignment)) {
801 operands.push_back(
static_cast<uint32_t
>(
802 cast<IntegerAttr>(attr).getValue().getZExtValue()));
805 elidedAttrs.push_back(sourceAlignment.strref());
806 if (
failed(emitDebugLine(functionBody, op.getLoc())))
// Serializes a spirv::GenericCastToPtrExplicitOp. The operands are:
//   - the result type <id>;
//   - a fresh result <id>, also recorded in valueIDMap;
//   - the <id>s of the op's operands;
//   - the storage class of the result pointer type.
// `loc` and `resultTy` are declared on lines not visible in this excerpt.
813 LogicalResult Serializer::processOp<spirv::GenericCastToPtrExplicitOp>(
814 spirv::GenericCastToPtrExplicitOp op) {
818 uint32_t resultTypeID = 0;
819 uint32_t resultID = 0;
820 resultTy = op->getResult(0).getType();
821 if (
failed(processType(loc, resultTy, resultTypeID)))
823 operands.push_back(resultTypeID);
825 resultID = getNextID();
826 operands.push_back(resultID);
827 valueIDMap[op->getResult(0)] = resultID;
// NOTE(review): unlike the CopyMemoryOp path, operand <id>s are pushed here
// without any visible zero / use-before-def check.
829 for (
Value operand : op->getOperands())
830 operands.push_back(getValueID(operand));
// The result type is cast<> to spirv::PointerType, so a non-pointer result
// type would fail that cast.
831 spirv::StorageClass resultStorage =
832 cast<spirv::PointerType>(resultTy).getStorageClass();
833 operands.push_back(
static_cast<uint32_t
>(resultStorage));
841 #define GET_SERIALIZATION_FNS
842 #include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc"
static LogicalResult visitInPrettyBlockOrder(Block *headerBlock, function_ref< LogicalResult(Block *)> blockHandler, bool skipHeader=false, BlockRange skipBlocks={})
A pre-order depth-first visitor function for processing basic blocks.
This class provides an abstraction over the different types of ranges over Blocks.
Block represents an ordered list of Operations.
OpListType & getOperations()
A symbol reference with a reference path containing a single element.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
StringRef stripDialect() const
Return the operation name with dialect name stripped, if it has one.
Operation is the basic unit of execution within MLIR.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void encodeStringLiteralInto(SmallVectorImpl< uint32_t > &binary, StringRef literal)
Encodes an SPIR-V literal string into the given binary vector.
void encodeInstructionInto(SmallVectorImpl< uint32_t > &binary, spirv::Opcode op, ArrayRef< uint32_t > operands)
Encodes an SPIR-V instruction with the given opcode and operands into the given binary vector.
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.