19 #include "llvm/ADT/DepthFirstIterator.h"
20 #include "llvm/ADT/StringExtras.h"
21 #include "llvm/Support/Debug.h"
23 #define DEBUG_TYPE "spirv-serialization"
45 bool skipHeader =
false,
BlockRange skipBlocks = {}) {
46 llvm::df_iterator_default_set<Block *, 4> doneBlocks;
47 doneBlocks.insert(skipBlocks.begin(), skipBlocks.end());
49 for (
Block *block : llvm::depth_first_ext(headerBlock, doneBlocks)) {
50 if (skipHeader && block == headerBlock)
52 if (failed(blockHandler(block)))
60 LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) {
62 prepareConstant(op.getLoc(), op.getType(), op.getValue())) {
63 valueIDMap[op.getResult()] = resultID;
69 LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) {
70 if (
auto resultID = prepareConstantScalar(op.getLoc(), op.getDefaultValue(),
73 if (
auto specID = op->getAttrOfType<IntegerAttr>(
"spec_id")) {
74 auto val =
static_cast<uint32_t
>(specID.getInt());
75 if (failed(emitDecoration(resultID, spirv::Decoration::SpecId, {val})))
79 specConstIDMap[op.getSymName()] = resultID;
80 return processName(resultID, op.getSymName());
86 Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) {
88 if (failed(processType(op.getLoc(), op.getType(), typeID))) {
92 auto resultID = getNextID();
95 operands.push_back(typeID);
96 operands.push_back(resultID);
98 auto constituents = op.getConstituents();
100 for (
auto index : llvm::seq<uint32_t>(0, constituents.size())) {
101 auto constituent = dyn_cast<FlatSymbolRefAttr>(constituents[index]);
103 auto constituentName = constituent.getValue();
104 auto constituentID = getSpecConstID(constituentName);
106 if (!constituentID) {
107 return op.emitError(
"unknown result <id> for specialization constant ")
111 operands.push_back(constituentID);
115 spirv::Opcode::OpSpecConstantComposite, operands);
116 specConstIDMap[op.getSymName()] = resultID;
118 return processName(resultID, op.getSymName());
122 Serializer::processSpecConstantOperationOp(spirv::SpecConstantOperationOp op) {
124 if (failed(processType(op.getLoc(), op.getType(), typeID))) {
128 auto resultID = getNextID();
131 operands.push_back(typeID);
132 operands.push_back(resultID);
134 Block &block = op.getRegion().getBlocks().
front();
137 std::string enclosedOpName;
138 llvm::raw_string_ostream rss(enclosedOpName);
140 auto enclosedOpcode = spirv::symbolizeOpcode(enclosedOpName);
142 if (!enclosedOpcode) {
143 op.emitError(
"Couldn't find op code for op ")
148 operands.push_back(
static_cast<uint32_t
>(*enclosedOpcode));
152 uint32_t
id = getValueID(operand);
153 assert(
id &&
"use before def!");
154 operands.push_back(
id);
159 valueIDMap[op.getResult()] = resultID;
164 LogicalResult Serializer::processUndefOp(spirv::UndefOp op) {
165 auto undefType = op.getType();
166 auto &
id = undefValIDMap[undefType];
170 if (failed(processType(op.getLoc(), undefType, typeID)))
175 valueIDMap[op.getResult()] = id;
179 LogicalResult Serializer::processFuncParameter(spirv::FuncOp op) {
181 uint32_t argTypeID = 0;
182 if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) {
185 auto argValueID = getNextID();
188 auto funcOp = cast<FunctionOpInterface>(*op);
189 for (
auto argAttr : funcOp.getArgAttrs(idx)) {
190 if (argAttr.getName() != DecorationAttr::name)
193 if (
auto decAttr = dyn_cast<DecorationAttr>(argAttr.getValue())) {
194 if (failed(processDecorationAttr(op->getLoc(), argValueID,
195 decAttr.getValue(), decAttr)))
200 valueIDMap[arg] = argValueID;
202 {argTypeID, argValueID});
207 LogicalResult Serializer::processFuncOp(spirv::FuncOp op) {
208 LLVM_DEBUG(llvm::dbgs() <<
"-- start function '" << op.getName() <<
"' --\n");
209 assert(functionHeader.empty() && functionBody.empty());
211 uint32_t fnTypeID = 0;
213 if (failed(processType(op.getLoc(), op.getFunctionType(), fnTypeID)))
218 uint32_t resTypeID = 0;
219 auto resultTypes = op.getFunctionType().getResults();
220 if (resultTypes.size() > 1) {
221 return op.emitError(
"cannot serialize function with multiple return types");
223 if (failed(processType(op.getLoc(),
224 (resultTypes.empty() ? getVoidType() : resultTypes[0]),
228 operands.push_back(resTypeID);
229 auto funcID = getOrCreateFunctionID(op.getName());
230 operands.push_back(funcID);
231 operands.push_back(
static_cast<uint32_t
>(op.getFunctionControl()));
232 operands.push_back(fnTypeID);
236 if (failed(processName(funcID, op.getName()))) {
241 auto linkageAttr = op.getLinkageAttributes();
242 auto hasImportLinkage =
243 linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
244 spirv::LinkageType::Import);
245 if (op.isExternal() && !hasImportLinkage) {
247 "'spirv.module' cannot contain external functions "
248 "without 'Import' linkage_attributes (LinkageAttributes)");
250 if (op.isExternal() && hasImportLinkage) {
260 if (failed(processFuncParameter(op)))
267 if (failed(processFuncParameter(op)))
279 {getOrCreateBlockID(&op.front())});
280 if (failed(processBlock(&op.front(),
true)))
283 &op.front(), [&](
Block *block) { return processBlock(block); },
290 for (
const auto &deferredValue : deferredPhiValues) {
291 Value value = deferredValue.first;
292 uint32_t
id = getValueID(value);
293 LLVM_DEBUG(llvm::dbgs() <<
"[phi] fix reference of value " << value
294 <<
" to id = " <<
id <<
'\n');
295 assert(
id &&
"OpPhi references undefined value!");
296 for (
size_t offset : deferredValue.second)
297 functionBody[offset] = id;
299 deferredPhiValues.clear();
301 LLVM_DEBUG(llvm::dbgs() <<
"-- completed function '" << op.getName()
307 for (
auto attr : op->getAttrs()) {
309 auto isValidDecoration = mlir::spirv::symbolizeEnum<spirv::Decoration>(
310 llvm::convertToCamelFromSnakeCase(attr.getName().strref(),
312 if (isValidDecoration != std::nullopt) {
313 if (failed(processDecoration(op.getLoc(), funcID, attr))) {
321 functions.append(functionHeader.begin(), functionHeader.end());
322 functions.append(functionBody.begin(), functionBody.end());
323 functionHeader.clear();
324 functionBody.clear();
329 LogicalResult Serializer::processVariableOp(spirv::VariableOp op) {
332 uint32_t resultID = 0;
333 uint32_t resultTypeID = 0;
334 if (failed(processType(op.getLoc(), op.getType(), resultTypeID))) {
337 operands.push_back(resultTypeID);
338 resultID = getNextID();
339 valueIDMap[op.getResult()] = resultID;
340 operands.push_back(resultID);
341 auto attr = op->getAttr(spirv::attributeName<spirv::StorageClass>());
344 static_cast<uint32_t
>(cast<spirv::StorageClassAttr>(attr).getValue()));
346 elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
347 for (
auto arg : op.getODSOperands(0)) {
348 auto argID = getValueID(arg);
350 return emitError(op.getLoc(),
"operand 0 has a use before def");
352 operands.push_back(argID);
354 if (failed(emitDebugLine(functionHeader, op.getLoc())))
357 for (
auto attr : op->getAttrs()) {
358 if (llvm::any_of(elidedAttrs, [&](StringRef elided) {
359 return attr.getName() == elided;
363 if (failed(processDecoration(op.getLoc(), resultID, attr))) {
371 Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
373 uint32_t resultTypeID = 0;
375 if (failed(processType(varOp.getLoc(), varOp.getType(), resultTypeID))) {
379 elidedAttrs.push_back(
"type");
381 operands.push_back(resultTypeID);
382 auto resultID = getNextID();
385 auto varName = varOp.getSymName();
387 if (failed(processName(resultID, varName))) {
390 globalVarIDMap[varName] = resultID;
391 operands.push_back(resultID);
394 operands.push_back(
static_cast<uint32_t
>(varOp.storageClass()));
397 StringRef initAttrName = varOp.getInitializerAttrName().getValue();
398 if (std::optional<StringRef> initSymbolName = varOp.getInitializer()) {
399 uint32_t initializerID = 0;
402 varOp->getParentOp(), initRef.getAttr());
405 if (isa<spirv::GlobalVariableOp>(initOp))
406 initializerID = getVariableID(*initSymbolName);
408 initializerID = getSpecConstID(*initSymbolName);
412 "invalid usage of undefined variable as initializer");
414 operands.push_back(initializerID);
415 elidedAttrs.push_back(initAttrName);
418 if (failed(emitDebugLine(typesGlobalValues, varOp.getLoc())))
421 elidedAttrs.push_back(initAttrName);
424 for (
auto attr : varOp->getAttrs()) {
425 if (llvm::any_of(elidedAttrs, [&](StringRef elided) {
426 return attr.getName() == elided;
430 if (failed(processDecoration(varOp.getLoc(), resultID, attr))) {
437 LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
440 auto &body = selectionOp.getBody();
441 for (
Block &block : body)
442 getOrCreateBlockID(&block);
444 auto *headerBlock = selectionOp.getHeaderBlock();
445 auto *mergeBlock = selectionOp.getMergeBlock();
446 auto headerID = getBlockID(headerBlock);
447 auto mergeID = getBlockID(mergeBlock);
448 auto loc = selectionOp.getLoc();
454 auto mergeOp = cast<spirv::MergeOp>(mergeBlock->back());
455 assert(selectionOp.getNumResults() == mergeOp.getNumOperands());
456 for (
unsigned i = 0, e = selectionOp.getNumResults(); i != e; ++i)
457 selectionOp.getResult(i).replaceAllUsesWith(mergeOp.getOperand(i));
469 auto emitSelectionMerge = [&]() {
470 if (failed(emitDebugLine(functionBody, loc)))
472 lastProcessedWasMergeInst =
true;
474 functionBody, spirv::Opcode::OpSelectionMerge,
475 {mergeID,
static_cast<uint32_t
>(selectionOp.getSelectionControl())});
479 processBlock(headerBlock,
false, emitSelectionMerge)))
486 headerBlock, [&](
Block *block) {
return processBlock(block); },
487 true, {mergeBlock})))
498 if (failed(emitPhiForBlockArguments(mergeBlock)))
501 LLVM_DEBUG(llvm::dbgs() <<
"done merge ");
502 LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs()));
503 LLVM_DEBUG(llvm::dbgs() <<
"\n");
507 LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
511 auto &body = loopOp.getBody();
512 for (
Block &block : llvm::drop_begin(body))
513 getOrCreateBlockID(&block);
515 auto *headerBlock = loopOp.getHeaderBlock();
516 auto *continueBlock = loopOp.getContinueBlock();
517 auto *mergeBlock = loopOp.getMergeBlock();
518 auto headerID = getBlockID(headerBlock);
519 auto continueID = getBlockID(continueBlock);
520 auto mergeID = getBlockID(mergeBlock);
521 auto loc = loopOp.getLoc();
537 auto emitLoopMerge = [&]() {
538 if (failed(emitDebugLine(functionBody, loc)))
540 lastProcessedWasMergeInst =
true;
542 functionBody, spirv::Opcode::OpLoopMerge,
543 {mergeID, continueID,
static_cast<uint32_t
>(loopOp.getLoopControl())});
546 if (failed(processBlock(headerBlock,
false, emitLoopMerge)))
553 headerBlock, [&](
Block *block) {
return processBlock(block); },
554 true, {continueBlock, mergeBlock})))
558 if (failed(processBlock(continueBlock)))
566 LLVM_DEBUG(llvm::dbgs() <<
"done merge ");
567 LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs()));
568 LLVM_DEBUG(llvm::dbgs() <<
"\n");
572 LogicalResult Serializer::processBranchConditionalOp(
573 spirv::BranchConditionalOp condBranchOp) {
574 auto conditionID = getValueID(condBranchOp.getCondition());
575 auto trueLabelID = getOrCreateBlockID(condBranchOp.getTrueBlock());
576 auto falseLabelID = getOrCreateBlockID(condBranchOp.getFalseBlock());
579 if (
auto weights = condBranchOp.getBranchWeights()) {
580 for (
auto val : weights->getValue())
581 arguments.push_back(cast<IntegerAttr>(val).getInt());
584 if (failed(emitDebugLine(functionBody, condBranchOp.getLoc())))
591 LogicalResult Serializer::processBranchOp(spirv::BranchOp branchOp) {
592 if (failed(emitDebugLine(functionBody, branchOp.getLoc())))
595 {getOrCreateBlockID(branchOp.getTarget())});
599 LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) {
600 auto varName = addressOfOp.getVariable();
601 auto variableID = getVariableID(varName);
603 return addressOfOp.emitError(
"unknown result <id> for variable ")
606 valueIDMap[addressOfOp.getPointer()] = variableID;
611 Serializer::processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp) {
612 auto constName = referenceOfOp.getSpecConst();
613 auto constID = getSpecConstID(constName);
615 return referenceOfOp.emitError(
616 "unknown result <id> for specialization constant ")
619 valueIDMap[referenceOfOp.getReference()] = constID;
625 Serializer::processOp<spirv::EntryPointOp>(spirv::EntryPointOp op) {
628 operands.push_back(
static_cast<uint32_t
>(op.getExecutionModel()));
630 auto funcID = getFunctionID(op.getFn());
632 return op.emitError(
"missing <id> for function ")
634 <<
"; function needs to be defined before spirv.EntryPoint is "
637 operands.push_back(funcID);
642 if (
auto interface = op.getInterface()) {
643 for (
auto var : interface.getValue()) {
644 auto id = getVariableID(cast<FlatSymbolRefAttr>(var).getValue());
647 "referencing undefined global variable."
648 "spirv.EntryPoint is at the end of spirv.module. All "
649 "referenced variables should already be defined");
651 operands.push_back(
id);
660 Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) {
663 auto funcID = getFunctionID(op.getFn());
665 return op.emitError(
"missing <id> for function ")
667 <<
"; function needs to be serialized before ExecutionModeOp is "
670 operands.push_back(funcID);
672 operands.push_back(
static_cast<uint32_t
>(op.getExecutionMode()));
675 auto values = op.getValues();
677 for (
auto &intVal : values.getValue()) {
678 operands.push_back(
static_cast<uint32_t
>(
679 llvm::cast<IntegerAttr>(intVal).getValue().getZExtValue()));
689 Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) {
690 auto funcName = op.getCallee();
691 uint32_t resTypeID = 0;
693 Type resultTy = op.getNumResults() ? *op.result_type_begin() : getVoidType();
694 if (failed(processType(op.getLoc(), resultTy, resTypeID)))
697 auto funcID = getOrCreateFunctionID(funcName);
698 auto funcCallID = getNextID();
701 for (
auto value : op.getArguments()) {
702 auto valueID = getValueID(value);
703 assert(valueID &&
"cannot find a value for spirv.FunctionCall");
704 operands.push_back(valueID);
707 if (!isa<NoneType>(resultTy))
708 valueIDMap[op.getResult(0)] = funcCallID;
716 Serializer::processOp<spirv::CopyMemoryOp>(spirv::CopyMemoryOp op) {
720 for (
Value operand : op->getOperands()) {
721 auto id = getValueID(operand);
722 assert(
id &&
"use before def!");
723 operands.push_back(
id);
726 StringAttr memoryAccess = op.getMemoryAccessAttrName();
727 if (
auto attr = op->getAttr(memoryAccess)) {
729 static_cast<uint32_t
>(cast<spirv::MemoryAccessAttr>(attr).getValue()));
732 elidedAttrs.push_back(memoryAccess.strref());
734 StringAttr alignment = op.getAlignmentAttrName();
735 if (
auto attr = op->getAttr(alignment)) {
736 operands.push_back(
static_cast<uint32_t
>(
737 cast<IntegerAttr>(attr).getValue().getZExtValue()));
740 elidedAttrs.push_back(alignment.strref());
742 StringAttr sourceMemoryAccess = op.getSourceMemoryAccessAttrName();
743 if (
auto attr = op->getAttr(sourceMemoryAccess)) {
745 static_cast<uint32_t
>(cast<spirv::MemoryAccessAttr>(attr).getValue()));
748 elidedAttrs.push_back(sourceMemoryAccess.strref());
750 StringAttr sourceAlignment = op.getSourceAlignmentAttrName();
751 if (
auto attr = op->getAttr(sourceAlignment)) {
752 operands.push_back(
static_cast<uint32_t
>(
753 cast<IntegerAttr>(attr).getValue().getZExtValue()));
756 elidedAttrs.push_back(sourceAlignment.strref());
757 if (failed(emitDebugLine(functionBody, op.getLoc())))
764 LogicalResult Serializer::processOp<spirv::GenericCastToPtrExplicitOp>(
765 spirv::GenericCastToPtrExplicitOp op) {
769 uint32_t resultTypeID = 0;
770 uint32_t resultID = 0;
771 resultTy = op->getResult(0).getType();
772 if (failed(processType(loc, resultTy, resultTypeID)))
774 operands.push_back(resultTypeID);
776 resultID = getNextID();
777 operands.push_back(resultID);
778 valueIDMap[op->getResult(0)] = resultID;
780 for (
Value operand : op->getOperands())
781 operands.push_back(getValueID(operand));
782 spirv::StorageClass resultStorage =
783 cast<spirv::PointerType>(resultTy).getStorageClass();
784 operands.push_back(
static_cast<uint32_t
>(resultStorage));
792 #define GET_SERIALIZATION_FNS
793 #include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc"
static LogicalResult visitInPrettyBlockOrder(Block *headerBlock, function_ref< LogicalResult(Block *)> blockHandler, bool skipHeader=false, BlockRange skipBlocks={})
A pre-order depth-first visitor function for processing basic blocks.
This class provides an abstraction over the different types of ranges over Blocks.
Block represents an ordered list of Operations.
OpListType & getOperations()
A symbol reference with a reference path containing a single element.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
StringRef stripDialect() const
Return the operation name with dialect name stripped, if it has one.
Operation is the basic unit of execution within MLIR.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void encodeStringLiteralInto(SmallVectorImpl< uint32_t > &binary, StringRef literal)
Encodes an SPIR-V literal string into the given binary vector.
void encodeInstructionInto(SmallVectorImpl< uint32_t > &binary, spirv::Opcode op, ArrayRef< uint32_t > operands)
Encodes an SPIR-V instruction with the given opcode and operands into the given binary vector.
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.