20 #include "llvm/ADT/DepthFirstIterator.h"
21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/Support/Debug.h"
24 #define DEBUG_TYPE "spirv-serialization"
46 bool skipHeader =
false,
BlockRange skipBlocks = {}) {
47 llvm::df_iterator_default_set<Block *, 4> doneBlocks;
48 doneBlocks.insert(skipBlocks.begin(), skipBlocks.end());
50 for (
Block *block : llvm::depth_first_ext(headerBlock, doneBlocks)) {
51 if (skipHeader && block == headerBlock)
53 if (
failed(blockHandler(block)))
61 LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) {
63 prepareConstant(op.
getLoc(), op.getType(), op.getValue())) {
70 LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) {
71 if (
auto resultID = prepareConstantScalar(op.
getLoc(), op.getDefaultValue(),
74 if (
auto specID = op->
getAttrOfType<IntegerAttr>(
"spec_id")) {
75 auto val =
static_cast<uint32_t
>(specID.getInt());
76 if (
failed(emitDecoration(resultID, spirv::Decoration::SpecId, {val})))
80 specConstIDMap[op.getSymName()] = resultID;
81 return processName(resultID, op.getSymName());
87 Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) {
89 if (
failed(processType(op.
getLoc(), op.getType(), typeID))) {
93 auto resultID = getNextID();
96 operands.push_back(typeID);
97 operands.push_back(resultID);
99 auto constituents = op.getConstituents();
101 for (
auto index : llvm::seq<uint32_t>(0, constituents.size())) {
102 auto constituent = dyn_cast<FlatSymbolRefAttr>(constituents[index]);
104 auto constituentName = constituent.getValue();
105 auto constituentID = getSpecConstID(constituentName);
107 if (!constituentID) {
108 return op.
emitError(
"unknown result <id> for specialization constant ")
112 operands.push_back(constituentID);
116 spirv::Opcode::OpSpecConstantComposite, operands);
117 specConstIDMap[op.getSymName()] = resultID;
119 return processName(resultID, op.getSymName());
123 Serializer::processSpecConstantOperationOp(spirv::SpecConstantOperationOp op) {
125 if (
failed(processType(op.
getLoc(), op.getType(), typeID))) {
129 auto resultID = getNextID();
132 operands.push_back(typeID);
133 operands.push_back(resultID);
138 std::string enclosedOpName;
139 llvm::raw_string_ostream rss(enclosedOpName);
141 auto enclosedOpcode = spirv::symbolizeOpcode(rss.str());
143 if (!enclosedOpcode) {
144 op.
emitError(
"Couldn't find op code for op ")
149 operands.push_back(
static_cast<uint32_t
>(*enclosedOpcode));
153 uint32_t
id = getValueID(operand);
154 assert(
id &&
"use before def!");
155 operands.push_back(
id);
165 LogicalResult Serializer::processUndefOp(spirv::UndefOp op) {
166 auto undefType = op.getType();
167 auto &
id = undefValIDMap[undefType];
171 if (
failed(processType(op.
getLoc(), undefType, typeID)))
181 LLVM_DEBUG(llvm::dbgs() <<
"-- start function '" << op.
getName() <<
"' --\n");
182 assert(functionHeader.empty() && functionBody.empty());
184 uint32_t fnTypeID = 0;
186 if (
failed(processType(op.
getLoc(), op.getFunctionType(), fnTypeID)))
191 uint32_t resTypeID = 0;
192 auto resultTypes = op.getFunctionType().
getResults();
193 if (resultTypes.size() > 1) {
194 return op.
emitError(
"cannot serialize function with multiple return types");
197 (resultTypes.empty() ? getVoidType() : resultTypes[0]),
201 operands.push_back(resTypeID);
202 auto funcID = getOrCreateFunctionID(op.
getName());
203 operands.push_back(funcID);
204 operands.push_back(
static_cast<uint32_t
>(op.getFunctionControl()));
205 operands.push_back(fnTypeID);
214 auto linkageAttr = op.getLinkageAttributes();
215 auto hasImportLinkage =
216 linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
217 spirv::LinkageType::Import);
218 if (op.isExternal() && !hasImportLinkage) {
220 "'spirv.module' cannot contain external functions "
221 "without 'Import' linkage_attributes (LinkageAttributes)");
222 }
else if (op.isExternal() && hasImportLinkage) {
232 for (
auto arg : op.getArguments()) {
233 uint32_t argTypeID = 0;
234 if (
failed(processType(op.
getLoc(), arg.getType(), argTypeID))) {
237 auto argValueID = getNextID();
238 valueIDMap[arg] = argValueID;
240 {argTypeID, argValueID});
248 for (
auto arg : op.getArguments()) {
249 uint32_t argTypeID = 0;
250 if (
failed(processType(op.
getLoc(), arg.getType(), argTypeID))) {
253 auto argValueID = getNextID();
254 valueIDMap[arg] = argValueID;
256 {argTypeID, argValueID});
268 {getOrCreateBlockID(&op.front())});
269 if (
failed(processBlock(&op.front(),
true)))
272 &op.front(), [&](
Block *block) { return processBlock(block); },
279 for (
const auto &deferredValue : deferredPhiValues) {
280 Value value = deferredValue.first;
281 uint32_t
id = getValueID(value);
282 LLVM_DEBUG(llvm::dbgs() <<
"[phi] fix reference of value " << value
283 <<
" to id = " <<
id <<
'\n');
284 assert(
id &&
"OpPhi references undefined value!");
285 for (
size_t offset : deferredValue.second)
286 functionBody[offset] = id;
288 deferredPhiValues.clear();
290 LLVM_DEBUG(llvm::dbgs() <<
"-- completed function '" << op.
getName()
298 auto isValidDecoration = mlir::spirv::symbolizeEnum<spirv::Decoration>(
299 llvm::convertToCamelFromSnakeCase(attr.getName().strref(),
301 if (isValidDecoration != std::nullopt) {
302 if (
failed(processDecoration(op.
getLoc(), funcID, attr))) {
310 functions.append(functionHeader.begin(), functionHeader.end());
311 functions.append(functionBody.begin(), functionBody.end());
312 functionHeader.clear();
313 functionBody.clear();
318 LogicalResult Serializer::processVariableOp(spirv::VariableOp op) {
321 uint32_t resultID = 0;
322 uint32_t resultTypeID = 0;
323 if (
failed(processType(op.
getLoc(), op.getType(), resultTypeID))) {
326 operands.push_back(resultTypeID);
327 resultID = getNextID();
329 operands.push_back(resultID);
330 auto attr = op->
getAttr(spirv::attributeName<spirv::StorageClass>());
333 static_cast<uint32_t
>(cast<spirv::StorageClassAttr>(attr).getValue()));
335 elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
336 for (
auto arg : op.getODSOperands(0)) {
337 auto argID = getValueID(arg);
341 operands.push_back(argID);
347 if (llvm::any_of(elidedAttrs, [&](StringRef elided) {
348 return attr.getName() == elided;
352 if (
failed(processDecoration(op.
getLoc(), resultID, attr))) {
360 Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
362 uint32_t resultTypeID = 0;
364 if (
failed(processType(varOp.getLoc(), varOp.getType(), resultTypeID))) {
368 elidedAttrs.push_back(
"type");
370 operands.push_back(resultTypeID);
371 auto resultID = getNextID();
374 auto varName = varOp.getSymName();
376 if (
failed(processName(resultID, varName))) {
379 globalVarIDMap[varName] = resultID;
380 operands.push_back(resultID);
383 operands.push_back(
static_cast<uint32_t
>(varOp.storageClass()));
386 if (
auto initializer = varOp.getInitializer()) {
387 auto initializerID = getVariableID(*initializer);
388 if (!initializerID) {
390 "invalid usage of undefined variable as initializer");
392 operands.push_back(initializerID);
393 elidedAttrs.push_back(
"initializer");
396 if (
failed(emitDebugLine(typesGlobalValues, varOp.getLoc())))
399 elidedAttrs.push_back(
"initializer");
402 for (
auto attr : varOp->getAttrs()) {
403 if (llvm::any_of(elidedAttrs, [&](StringRef elided) {
404 return attr.getName() == elided;
408 if (
failed(processDecoration(varOp.getLoc(), resultID, attr))) {
415 LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
418 auto &body = selectionOp.getBody();
419 for (
Block &block : body)
420 getOrCreateBlockID(&block);
422 auto *headerBlock = selectionOp.getHeaderBlock();
423 auto *mergeBlock = selectionOp.getMergeBlock();
424 auto headerID = getBlockID(headerBlock);
425 auto mergeID = getBlockID(mergeBlock);
426 auto loc = selectionOp.getLoc();
438 auto emitSelectionMerge = [&]() {
439 if (
failed(emitDebugLine(functionBody, loc)))
441 lastProcessedWasMergeInst =
true;
443 functionBody, spirv::Opcode::OpSelectionMerge,
444 {mergeID,
static_cast<uint32_t
>(selectionOp.getSelectionControl())});
448 processBlock(headerBlock,
false, emitSelectionMerge)))
455 headerBlock, [&](
Block *block) {
return processBlock(block); },
456 true, {mergeBlock})))
464 LLVM_DEBUG(llvm::dbgs() <<
"done merge ");
465 LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs()));
466 LLVM_DEBUG(llvm::dbgs() <<
"\n");
470 LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
474 auto &body = loopOp.getBody();
475 for (
Block &block : llvm::drop_begin(body))
476 getOrCreateBlockID(&block);
478 auto *headerBlock = loopOp.getHeaderBlock();
479 auto *continueBlock = loopOp.getContinueBlock();
480 auto *mergeBlock = loopOp.getMergeBlock();
481 auto headerID = getBlockID(headerBlock);
482 auto continueID = getBlockID(continueBlock);
483 auto mergeID = getBlockID(mergeBlock);
484 auto loc = loopOp.getLoc();
500 auto emitLoopMerge = [&]() {
501 if (
failed(emitDebugLine(functionBody, loc)))
503 lastProcessedWasMergeInst =
true;
505 functionBody, spirv::Opcode::OpLoopMerge,
506 {mergeID, continueID,
static_cast<uint32_t
>(loopOp.getLoopControl())});
509 if (
failed(processBlock(headerBlock,
false, emitLoopMerge)))
516 headerBlock, [&](
Block *block) {
return processBlock(block); },
517 true, {continueBlock, mergeBlock})))
521 if (
failed(processBlock(continueBlock)))
529 LLVM_DEBUG(llvm::dbgs() <<
"done merge ");
530 LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs()));
531 LLVM_DEBUG(llvm::dbgs() <<
"\n");
536 spirv::BranchConditionalOp condBranchOp) {
537 auto conditionID = getValueID(condBranchOp.getCondition());
538 auto trueLabelID = getOrCreateBlockID(condBranchOp.getTrueBlock());
539 auto falseLabelID = getOrCreateBlockID(condBranchOp.getFalseBlock());
542 if (
auto weights = condBranchOp.getBranchWeights()) {
543 for (
auto val : weights->getValue())
544 arguments.push_back(cast<IntegerAttr>(val).getInt());
547 if (
failed(emitDebugLine(functionBody, condBranchOp.getLoc())))
554 LogicalResult Serializer::processBranchOp(spirv::BranchOp branchOp) {
555 if (
failed(emitDebugLine(functionBody, branchOp.getLoc())))
558 {getOrCreateBlockID(branchOp.getTarget())});
562 LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) {
563 auto varName = addressOfOp.getVariable();
564 auto variableID = getVariableID(varName);
566 return addressOfOp.emitError(
"unknown result <id> for variable ")
569 valueIDMap[addressOfOp.getPointer()] = variableID;
574 Serializer::processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp) {
575 auto constName = referenceOfOp.getSpecConst();
576 auto constID = getSpecConstID(constName);
578 return referenceOfOp.emitError(
579 "unknown result <id> for specialization constant ")
582 valueIDMap[referenceOfOp.getReference()] = constID;
588 Serializer::processOp<spirv::EntryPointOp>(spirv::EntryPointOp op) {
591 operands.push_back(
static_cast<uint32_t
>(op.getExecutionModel()));
593 auto funcID = getFunctionID(op.getFn());
595 return op.
emitError(
"missing <id> for function ")
597 <<
"; function needs to be defined before spirv.EntryPoint is "
600 operands.push_back(funcID);
605 if (
auto interface = op.getInterface()) {
606 for (
auto var : interface.getValue()) {
607 auto id = getVariableID(cast<FlatSymbolRefAttr>(var).getValue());
610 "referencing undefined global variable."
611 "spirv.EntryPoint is at the end of spirv.module. All "
612 "referenced variables should already be defined");
614 operands.push_back(
id);
623 Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) {
626 auto funcID = getFunctionID(op.getFn());
628 return op.
emitError(
"missing <id> for function ")
630 <<
"; function needs to be serialized before ExecutionModeOp is "
633 operands.push_back(funcID);
635 operands.push_back(
static_cast<uint32_t
>(op.getExecutionMode()));
638 auto values = op.getValues();
640 for (
auto &intVal : values.getValue()) {
641 operands.push_back(
static_cast<uint32_t
>(
642 llvm::cast<IntegerAttr>(intVal).getValue().getZExtValue()));
652 Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) {
653 auto funcName = op.getCallee();
654 uint32_t resTypeID = 0;
657 if (
failed(processType(op.
getLoc(), resultTy, resTypeID)))
660 auto funcID = getOrCreateFunctionID(funcName);
661 auto funcCallID = getNextID();
664 for (
auto value : op.getArguments()) {
665 auto valueID = getValueID(value);
666 assert(valueID &&
"cannot find a value for spirv.FunctionCall");
667 operands.push_back(valueID);
670 if (!isa<NoneType>(resultTy))
671 valueIDMap[op.
getResult(0)] = funcCallID;
679 Serializer::processOp<spirv::CopyMemoryOp>(spirv::CopyMemoryOp op) {
684 auto id = getValueID(operand);
685 assert(
id &&
"use before def!");
686 operands.push_back(
id);
689 if (
auto attr = op->
getAttr(
"memory_access")) {
691 static_cast<uint32_t
>(cast<spirv::MemoryAccessAttr>(attr).getValue()));
694 elidedAttrs.push_back(
"memory_access");
696 if (
auto attr = op->
getAttr(
"alignment")) {
697 operands.push_back(
static_cast<uint32_t
>(
698 cast<IntegerAttr>(attr).getValue().getZExtValue()));
701 elidedAttrs.push_back(
"alignment");
703 if (
auto attr = op->
getAttr(
"source_memory_access")) {
705 static_cast<uint32_t
>(cast<spirv::MemoryAccessAttr>(attr).getValue()));
708 elidedAttrs.push_back(
"source_memory_access");
710 if (
auto attr = op->
getAttr(
"source_alignment")) {
711 operands.push_back(
static_cast<uint32_t
>(
712 cast<IntegerAttr>(attr).getValue().getZExtValue()));
715 elidedAttrs.push_back(
"source_alignment");
724 spirv::GenericCastToPtrExplicitOp op) {
728 uint32_t resultTypeID = 0;
729 uint32_t resultID = 0;
731 if (
failed(processType(loc, resultTy, resultTypeID)))
733 operands.push_back(resultTypeID);
735 resultID = getNextID();
736 operands.push_back(resultID);
740 operands.push_back(getValueID(operand));
741 spirv::StorageClass resultStorage =
742 cast<spirv::PointerType>(resultTy).getStorageClass();
743 operands.push_back(
static_cast<uint32_t
>(resultStorage));
751 #define GET_SERIALIZATION_FNS
752 #include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc"
static LogicalResult visitInPrettyBlockOrder(Block *headerBlock, function_ref< LogicalResult(Block *)> blockHandler, bool skipHeader=false, BlockRange skipBlocks={})
A pre-order depth-first visitor function for processing basic blocks.
This class provides an abstraction over the different types of ranges over Blocks.
Block represents an ordered list of Operations.
OpListType & getOperations()
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
StringRef stripDialect() const
Return the operation name with dialect name stripped, if it has one.
Operation is the basic unit of execution within MLIR.
AttrClass getAttrOfType(StringAttr name)
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
result_type_iterator result_type_begin()
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
BlockListType & getBlocks()
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
void encodeStringLiteralInto(SmallVectorImpl< uint32_t > &binary, StringRef literal)
Encodes an SPIR-V literal string into the given binary vector.
void encodeInstructionInto(SmallVectorImpl< uint32_t > &binary, spirv::Opcode op, ArrayRef< uint32_t > operands)
Encodes an SPIR-V instruction with the given opcode and operands into the given binary vector.
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.