20#include "llvm/ADT/STLExtras.h"
21#include "llvm/ADT/SmallVector.h"
22#include "llvm/Support/Debug.h"
27#define DEBUG_TYPE "spirv-deserialization"
35 return static_cast<spirv::Opcode
>(word & 0xffff);
45 return spirv::ConstantOp::create(opBuilder, unknownLoc, constInfo->second,
48 if (std::optional<std::pair<Attribute, Type>> constCompositeReplicateInfo =
50 return spirv::EXTConstantCompositeReplicateOp::create(
51 opBuilder, unknownLoc, constCompositeReplicateInfo->second,
52 constCompositeReplicateInfo->first);
56 spirv::AddressOfOp::create(opBuilder, unknownLoc, varOp.getType(),
57 SymbolRefAttr::get(varOp.getOperation()));
58 return addressOfOp.getPointer();
61 auto referenceOfOp = spirv::ReferenceOfOp::create(
62 opBuilder, unknownLoc, constOp.getDefaultValue().getType(),
63 SymbolRefAttr::get(constOp.getOperation()));
64 return referenceOfOp.getReference();
66 if (SpecConstantCompositeOp specConstCompositeOp =
68 auto referenceOfOp = spirv::ReferenceOfOp::create(
69 opBuilder, unknownLoc, specConstCompositeOp.getType(),
70 SymbolRefAttr::get(specConstCompositeOp.getOperation()));
71 return referenceOfOp.getReference();
73 if (
auto specConstCompositeReplicateOp =
75 auto referenceOfOp = spirv::ReferenceOfOp::create(
76 opBuilder, unknownLoc, specConstCompositeReplicateOp.getType(),
77 SymbolRefAttr::get(specConstCompositeReplicateOp.getOperation()));
78 return referenceOfOp.getReference();
82 id, specConstOperationInfo->enclodesOpcode,
83 specConstOperationInfo->resultTypeID,
84 specConstOperationInfo->enclosedOpOperands);
87 return spirv::UndefOp::create(opBuilder, unknownLoc, undef);
89 if (std::optional<spirv::GraphConstantARMOpMaterializationInfo>
91 IntegerAttr graphConstantID = graphConstantARMInfo->graphConstantID;
92 Type resultType = graphConstantARMInfo->resultType;
93 return spirv::GraphConstantARMOp::create(opBuilder, unknownLoc, resultType,
96 return valueMap.lookup(
id);
101 std::optional<spirv::Opcode> expectedOpcode) {
102 auto binarySize = binary.size();
103 if (curOffset >= binarySize) {
104 return emitError(unknownLoc,
"expected ")
105 << (expectedOpcode ? spirv::stringifyOpcode(*expectedOpcode)
113 uint32_t wordCount = binary[curOffset] >> 16;
116 return emitError(unknownLoc,
"word count cannot be zero");
118 uint32_t nextOffset = curOffset + wordCount;
119 if (nextOffset > binarySize)
120 return emitError(unknownLoc,
"insufficient words for the last instruction");
123 operands = binary.slice(curOffset + 1, wordCount - 1);
124 curOffset = nextOffset;
130 LLVM_DEBUG(logger.startLine() <<
"[inst] processing instruction "
131 << spirv::stringifyOpcode(opcode) <<
"\n");
136 case spirv::Opcode::OpCapability:
137 return processCapability(operands);
138 case spirv::Opcode::OpExtension:
139 return processExtension(operands);
140 case spirv::Opcode::OpExtInst:
142 case spirv::Opcode::OpExtInstImport:
143 return processExtInstImport(operands);
144 case spirv::Opcode::OpMemberName:
145 return processMemberName(operands);
146 case spirv::Opcode::OpMemoryModel:
147 return processMemoryModel(operands);
148 case spirv::Opcode::OpEntryPoint:
149 case spirv::Opcode::OpExecutionMode:
150 if (deferInstructions) {
151 deferredInstructions.emplace_back(opcode, operands);
155 case spirv::Opcode::OpVariable:
156 if (isa<spirv::ModuleOp>(opBuilder.getBlock()->getParentOp())) {
160 case spirv::Opcode::OpLine:
162 case spirv::Opcode::OpNoLine:
165 case spirv::Opcode::OpName:
166 return processName(operands);
167 case spirv::Opcode::OpString:
169 case spirv::Opcode::OpModuleProcessed:
170 case spirv::Opcode::OpSource:
171 case spirv::Opcode::OpSourceContinued:
172 case spirv::Opcode::OpSourceExtension:
176 case spirv::Opcode::OpTypeVoid:
177 case spirv::Opcode::OpTypeBool:
178 case spirv::Opcode::OpTypeInt:
179 case spirv::Opcode::OpTypeFloat:
180 case spirv::Opcode::OpTypeVector:
181 case spirv::Opcode::OpTypeMatrix:
182 case spirv::Opcode::OpTypeArray:
183 case spirv::Opcode::OpTypeFunction:
184 case spirv::Opcode::OpTypeImage:
185 case spirv::Opcode::OpTypeSampledImage:
186 case spirv::Opcode::OpTypeRuntimeArray:
187 case spirv::Opcode::OpTypeStruct:
188 case spirv::Opcode::OpTypePointer:
189 case spirv::Opcode::OpTypeTensorARM:
190 case spirv::Opcode::OpTypeGraphARM:
191 case spirv::Opcode::OpTypeCooperativeMatrixKHR:
193 case spirv::Opcode::OpTypeForwardPointer:
195 case spirv::Opcode::OpConstant:
197 case spirv::Opcode::OpSpecConstant:
199 case spirv::Opcode::OpConstantComposite:
201 case spirv::Opcode::OpConstantCompositeReplicateEXT:
203 case spirv::Opcode::OpSpecConstantComposite:
205 case spirv::Opcode::OpSpecConstantCompositeReplicateEXT:
207 case spirv::Opcode::OpSpecConstantOp:
209 case spirv::Opcode::OpConstantTrue:
211 case spirv::Opcode::OpSpecConstantTrue:
213 case spirv::Opcode::OpConstantFalse:
215 case spirv::Opcode::OpSpecConstantFalse:
217 case spirv::Opcode::OpConstantNull:
219 case spirv::Opcode::OpGraphConstantARM:
221 case spirv::Opcode::OpDecorate:
222 return processDecoration(operands);
223 case spirv::Opcode::OpMemberDecorate:
224 return processMemberDecoration(operands);
225 case spirv::Opcode::OpFunction:
227 case spirv::Opcode::OpGraphEntryPointARM:
228 if (deferInstructions) {
229 deferredInstructions.emplace_back(opcode, operands);
233 case spirv::Opcode::OpGraphARM:
235 case spirv::Opcode::OpGraphSetOutputARM:
237 case spirv::Opcode::OpGraphEndARM:
239 case spirv::Opcode::OpLabel:
241 case spirv::Opcode::OpBranch:
243 case spirv::Opcode::OpBranchConditional:
245 case spirv::Opcode::OpSelectionMerge:
247 case spirv::Opcode::OpLoopMerge:
249 case spirv::Opcode::OpPhi:
251 case spirv::Opcode::OpSwitch:
253 case spirv::Opcode::OpUndef:
263 unsigned numOperands) {
265 uint32_t valueID = 0;
267 size_t wordIndex = 0;
269 if (wordIndex >= words.size())
271 "expected result type <id> while deserializing for ")
275 auto type =
getType(words[wordIndex]);
277 return emitError(unknownLoc,
"unknown type result <id>: ")
279 resultTypes.push_back(type);
283 if (wordIndex >= words.size())
285 "expected result <id> while deserializing for ")
287 valueID = words[wordIndex];
295 size_t operandIndex = 0;
296 for (; operandIndex < numOperands && wordIndex < words.size();
297 ++operandIndex, ++wordIndex) {
298 auto arg =
getValue(words[wordIndex]);
300 return emitError(unknownLoc,
"unknown result <id>: ") << words[wordIndex];
301 operands.push_back(arg);
303 if (operandIndex != numOperands) {
306 "found less operands than expected when deserializing for ")
307 << opName <<
"; only " << operandIndex <<
" of " << numOperands
310 if (wordIndex != words.size()) {
313 "found more operands than expected when deserializing for ")
314 << opName <<
"; only " << wordIndex <<
" of " << words.size()
319 if (decorations.count(valueID)) {
320 auto attrs = decorations[valueID].getAttrs();
321 attributes.append(attrs.begin(), attrs.end());
342 if (operands.size() != 2) {
343 return emitError(unknownLoc,
"OpUndef instruction must have two operands");
345 auto type =
getType(operands[0]);
347 return emitError(unknownLoc,
"unknown type <id> with OpUndef instruction");
349 undefMap[operands[1]] = type;
354 if (operands.size() < 4) {
356 "OpExtInst must have at least 4 operands, result type "
357 "<id>, result <id>, set <id> and instruction opcode");
359 if (!extendedInstSets.count(operands[2])) {
360 return emitError(unknownLoc,
"undefined set <id> in OpExtInst");
363 slicedOperands.append(operands.begin(), std::next(operands.begin(), 2));
364 slicedOperands.append(std::next(operands.begin(), 4), operands.end());
366 extendedInstSets[operands[2]], operands[3], slicedOperands);
375 unsigned wordIndex = 0;
376 if (wordIndex >= words.size()) {
378 "missing Execution Model specification in OpEntryPoint");
380 auto execModel = spirv::ExecutionModelAttr::get(
381 context,
static_cast<spirv::ExecutionModel
>(words[wordIndex++]));
382 if (wordIndex >= words.size()) {
383 return emitError(unknownLoc,
"missing <id> in OpEntryPoint");
386 auto fnID = words[wordIndex++];
390 auto parsedFunc = getFunction(fnID);
392 return emitError(unknownLoc,
"no function matching <id> ") << fnID;
394 if (parsedFunc.getName() != fnName) {
398 if (!parsedFunc.getName().starts_with(
"spirv_fn_"))
400 "function name mismatch between OpEntryPoint "
401 "and OpFunction with <id> ")
402 << fnID <<
": " << fnName <<
" vs. " << parsedFunc.getName();
403 parsedFunc.setName(fnName);
406 while (wordIndex < words.size()) {
407 auto arg = getGlobalVariable(words[wordIndex]);
409 return emitError(unknownLoc,
"undefined result <id> ")
410 << words[wordIndex] <<
" while decoding OpEntryPoint";
412 interface.push_back(SymbolRefAttr::get(arg.getOperation()));
415 spirv::EntryPointOp::create(
416 opBuilder, unknownLoc, execModel,
417 SymbolRefAttr::get(opBuilder.getContext(), fnName),
418 opBuilder.getArrayAttr(interface));
425 unsigned wordIndex = 0;
426 if (wordIndex >= words.size()) {
428 "missing function result <id> in OpExecutionMode");
431 auto fnID = words[wordIndex++];
432 auto fn = getFunction(fnID);
434 return emitError(unknownLoc,
"no function matching <id> ") << fnID;
437 if (wordIndex >= words.size()) {
438 return emitError(unknownLoc,
"missing Execution Mode in OpExecutionMode");
440 auto execMode = spirv::ExecutionModeAttr::get(
441 context,
static_cast<spirv::ExecutionMode
>(words[wordIndex++]));
445 while (wordIndex < words.size()) {
446 attrListElems.push_back(opBuilder.getI32IntegerAttr(words[wordIndex++]));
448 auto values = opBuilder.getArrayAttr(attrListElems);
449 spirv::ExecutionModeOp::create(
450 opBuilder, unknownLoc,
451 SymbolRefAttr::get(opBuilder.getContext(), fn.getName()), execMode,
459 if (operands.size() < 3) {
461 "OpFunctionCall must have at least 3 operands");
466 return emitError(unknownLoc,
"undefined result type from <id> ")
471 if (isVoidType(resultType))
472 resultType =
nullptr;
474 auto resultID = operands[1];
475 auto functionID = operands[2];
477 auto functionName = getFunctionSymbol(functionID);
480 for (
auto operand : llvm::drop_begin(operands, 3)) {
481 auto value = getValue(operand);
483 return emitError(unknownLoc,
"unknown <id> ")
484 << operand <<
" used by OpFunctionCall";
486 arguments.push_back(value);
489 auto opFunctionCall = spirv::FunctionCallOp::create(
490 opBuilder, unknownLoc, resultType,
491 SymbolRefAttr::get(opBuilder.getContext(), functionName), arguments);
494 valueMap[resultID] = opFunctionCall.getResult(0);
502 size_t wordIndex = 0;
506 if (wordIndex < words.size()) {
507 auto arg = getValue(words[wordIndex]);
510 return emitError(unknownLoc,
"unknown result <id> : ")
514 operands.push_back(arg);
518 if (wordIndex < words.size()) {
519 auto arg = getValue(words[wordIndex]);
522 return emitError(unknownLoc,
"unknown result <id> : ")
526 operands.push_back(arg);
530 bool isAlignedAttr =
false;
532 if (wordIndex < words.size()) {
533 auto attrValue = words[wordIndex++];
534 auto attr = opBuilder.getAttr<spirv::MemoryAccessAttr>(
535 static_cast<spirv::MemoryAccess
>(attrValue));
536 attributes.push_back(
538 isAlignedAttr = (attrValue == 2);
541 if (isAlignedAttr && wordIndex < words.size()) {
542 attributes.push_back(opBuilder.getNamedAttr(
543 "alignment", opBuilder.getI32IntegerAttr(words[wordIndex++])));
546 if (wordIndex < words.size()) {
547 auto attrValue = words[wordIndex++];
548 auto attr = opBuilder.getAttr<spirv::MemoryAccessAttr>(
549 static_cast<spirv::MemoryAccess
>(attrValue));
550 attributes.push_back(opBuilder.getNamedAttr(
"source_memory_access", attr));
553 if (wordIndex < words.size()) {
554 attributes.push_back(opBuilder.getNamedAttr(
555 "source_alignment", opBuilder.getI32IntegerAttr(words[wordIndex++])));
558 if (wordIndex != words.size()) {
560 "found more operands than expected when deserializing "
561 "spirv::CopyMemoryOp, only ")
562 << wordIndex <<
" of " << words.size() <<
" processed";
565 Location loc = createFileLineColLoc(opBuilder);
566 spirv::CopyMemoryOp::create(opBuilder, loc, resultTypes, operands,
575 if (words.size() != 4) {
577 "expected 4 words in GenericCastToPtrExplicitOp"
583 uint32_t valueID = 0;
587 return emitError(unknownLoc,
"unknown type result <id> : ") << words[0];
588 resultTypes.push_back(type);
592 auto arg = getValue(words[2]);
594 return emitError(unknownLoc,
"unknown result <id> : ") << words[2];
595 operands.push_back(arg);
597 Location loc = createFileLineColLoc(opBuilder);
598 Operation *op = spirv::GenericCastToPtrExplicitOp::create(
599 opBuilder, loc, resultTypes, operands);
606#define GET_DESERIALIZATION_FNS
607#include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc"
static spirv::Opcode extractOpcode(uint32_t word)
Extracts the opcode from the given first word of a SPIR-V instruction.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Value materializeSpecConstantOperation(uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID, ArrayRef< uint32_t > enclosedOpOperands)
Materializes/emits an OpSpecConstantOp instruction.
Value getValue(uint32_t id)
Get the Value associated with a result <id>.
LogicalResult processGlobalVariable(ArrayRef< uint32_t > operands)
Processes the OpVariable instructions at the current offset into the binary.
std::optional< SpecConstOperationMaterializationInfo > getSpecConstantOperation(uint32_t id)
Gets the info needed to materialize the spec constant operation op associated with the given <id>.
LogicalResult processConstantNull(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpConstantNull instruction with the given operands.
LogicalResult processSpecConstantComposite(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSpecConstantComposite instruction with the given operands.
LogicalResult processInstruction(spirv::Opcode opcode, ArrayRef< uint32_t > operands, bool deferInstructions=true)
Processes a SPIR-V instruction with the given opcode and operands.
LogicalResult processBranchConditional(ArrayRef< uint32_t > operands)
spirv::GlobalVariableOp getGlobalVariable(uint32_t id)
Gets the global variable associated with a result <id> of OpVariable.
LogicalResult processGraphARM(ArrayRef< uint32_t > operands)
LogicalResult processLabel(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpLabel instruction with the given operands.
LogicalResult processExtInst(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpExtInst with given operands.
std::optional< spirv::GraphConstantARMOpMaterializationInfo > getGraphConstantARM(uint32_t id)
Gets the GraphConstantARM ID attribute and result type with the given result <id>.
std::optional< std::pair< Attribute, Type > > getConstant(uint32_t id)
Gets the constant's attribute and type associated with the given <id>.
LogicalResult processType(spirv::Opcode opcode, ArrayRef< uint32_t > operands)
Processes a SPIR-V type instruction with given opcode and operands and registers the type into module...
LogicalResult processLoopMerge(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpLoopMerge instruction with the given operands.
LogicalResult dispatchToExtensionSetAutogenDeserialization(StringRef extensionSetName, uint32_t instructionID, ArrayRef< uint32_t > words)
Dispatches the deserialization of extended instruction set operation based on the extended instructio...
LogicalResult sliceInstruction(spirv::Opcode &opcode, ArrayRef< uint32_t > &operands, std::optional< spirv::Opcode > expectedOpcode=std::nullopt)
Slices the first instruction out of binary and returns its opcode and operands via the opcode and operands reference parameters.
spirv::SpecConstantCompositeOp getSpecConstantComposite(uint32_t id)
Gets the composite specialization constant with the given result <id>.
spirv::EXTSpecConstantCompositeReplicateOp getSpecConstantCompositeReplicate(uint32_t id)
Gets the replicated composite specialization constant with the given result <id>.
LogicalResult processOp(ArrayRef< uint32_t > words)
Method to deserialize an operation in the SPIR-V dialect that is a mirror of an instruction in the SP...
Type getUndefType(uint32_t id)
Get the type associated with the result <id> of an OpUndef.
LogicalResult processSpecConstantCompositeReplicateEXT(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSpecConstantCompositeReplicateEXT instruction with the given operands.
LogicalResult processGraphEntryPointARM(ArrayRef< uint32_t > operands)
LogicalResult processFunction(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpFunction instruction with the given operands.
LogicalResult processDebugString(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpString instruction with the given operands.
LogicalResult processPhi(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpPhi instruction with the given operands.
void clearDebugLine()
Discontinues any source-level location information that might be active from a previous OpLine instru...
LogicalResult processTypeForwardPointer(ArrayRef< uint32_t > operands)
LogicalResult processSwitch(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSwitch instruction with the given operands.
LogicalResult dispatchToAutogenDeserialization(spirv::Opcode opcode, ArrayRef< uint32_t > words)
Method to dispatch to the specialized deserialization function for an operation in SPIR-V dialect tha...
LogicalResult processOpWithoutGrammarAttr(ArrayRef< uint32_t > words, StringRef opName, bool hasResult, unsigned numOperands)
Processes a SPIR-V instruction from the given operands.
LogicalResult processGraphEndARM(ArrayRef< uint32_t > operands)
LogicalResult processConstantComposite(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpConstantComposite instruction with the given operands.
LogicalResult processBranch(ArrayRef< uint32_t > operands)
std::optional< std::pair< Attribute, Type > > getConstantCompositeReplicate(uint32_t id)
Gets the replicated composite constant's attribute and type associated with the given <id>.
LogicalResult processUndef(ArrayRef< uint32_t > operands)
Processes an OpUndef instruction.
LogicalResult processSpecConstantOperation(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSpecConstantOp instruction with the given operands.
LogicalResult processConstant(ArrayRef< uint32_t > operands, bool isSpec)
Processes a SPIR-V Op{|Spec}Constant instruction with the given operands.
Location createFileLineColLoc(OpBuilder opBuilder)
Creates a FileLineColLoc with the OpLine location information.
LogicalResult processGraphConstantARM(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpGraphConstantARM instruction with the given operands.
LogicalResult processConstantBool(bool isTrue, ArrayRef< uint32_t > operands, bool isSpec)
Processes a SPIR-V Op{|Spec}Constant{True|False} instruction with the given operands.
spirv::SpecConstantOp getSpecConstant(uint32_t id)
Gets the specialization constant with the given result <id>.
LogicalResult processConstantCompositeReplicateEXT(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpConstantCompositeReplicateEXT instruction with the given operands.
LogicalResult processSelectionMerge(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSelectionMerge instruction with the given operands.
LogicalResult processOpGraphSetOutputARM(ArrayRef< uint32_t > operands)
LogicalResult processDebugLine(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpLine instruction with the given operands.
constexpr StringRef attributeName()
StringRef decodeStringLiteral(ArrayRef< uint32_t > words, unsigned &wordIndex)
Decodes a string literal in words starting at wordIndex.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addTypes(ArrayRef< Type > newTypes)