20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/Support/Debug.h"
27 #define DEBUG_TYPE "spirv-deserialization"
35 return static_cast<spirv::Opcode
>(word & 0xffff);
42 Value spirv::Deserializer::getValue(uint32_t
id) {
43 if (
auto constInfo = getConstant(
id)) {
45 return opBuilder.
create<spirv::ConstantOp>(unknownLoc, constInfo->second,
48 if (
auto varOp = getGlobalVariable(
id)) {
49 auto addressOfOp = opBuilder.
create<spirv::AddressOfOp>(
51 return addressOfOp.getPointer();
53 if (
auto constOp = getSpecConstant(
id)) {
54 auto referenceOfOp = opBuilder.
create<spirv::ReferenceOfOp>(
55 unknownLoc, constOp.getDefaultValue().getType(),
57 return referenceOfOp.getReference();
59 if (
auto constCompositeOp = getSpecConstantComposite(
id)) {
60 auto referenceOfOp = opBuilder.
create<spirv::ReferenceOfOp>(
61 unknownLoc, constCompositeOp.getType(),
63 return referenceOfOp.getReference();
65 if (
auto specConstOperationInfo = getSpecConstantOperation(
id)) {
66 return materializeSpecConstantOperation(
67 id, specConstOperationInfo->enclodesOpcode,
68 specConstOperationInfo->resultTypeID,
69 specConstOperationInfo->enclosedOpOperands);
71 if (
auto undef = getUndefType(
id)) {
72 return opBuilder.
create<spirv::UndefOp>(unknownLoc, undef);
74 return valueMap.lookup(
id);
79 std::optional<spirv::Opcode> expectedOpcode) {
80 auto binarySize = binary.size();
81 if (curOffset >= binarySize) {
83 << (expectedOpcode ? spirv::stringifyOpcode(*expectedOpcode)
91 uint32_t wordCount = binary[curOffset] >> 16;
94 return emitError(unknownLoc,
"word count cannot be zero");
96 uint32_t nextOffset = curOffset + wordCount;
97 if (nextOffset > binarySize)
98 return emitError(unknownLoc,
"insufficient words for the last instruction");
101 operands = binary.slice(curOffset + 1, wordCount - 1);
102 curOffset = nextOffset;
108 LLVM_DEBUG(logger.startLine() <<
"[inst] processing instruction "
109 << spirv::stringifyOpcode(opcode) <<
"\n");
114 case spirv::Opcode::OpCapability:
115 return processCapability(operands);
116 case spirv::Opcode::OpExtension:
117 return processExtension(operands);
118 case spirv::Opcode::OpExtInst:
119 return processExtInst(operands);
120 case spirv::Opcode::OpExtInstImport:
121 return processExtInstImport(operands);
122 case spirv::Opcode::OpMemberName:
123 return processMemberName(operands);
124 case spirv::Opcode::OpMemoryModel:
125 return processMemoryModel(operands);
126 case spirv::Opcode::OpEntryPoint:
127 case spirv::Opcode::OpExecutionMode:
128 if (deferInstructions) {
129 deferredInstructions.emplace_back(opcode, operands);
133 case spirv::Opcode::OpVariable:
134 if (isa<spirv::ModuleOp>(opBuilder.getBlock()->getParentOp())) {
135 return processGlobalVariable(operands);
138 case spirv::Opcode::OpLine:
139 return processDebugLine(operands);
140 case spirv::Opcode::OpNoLine:
143 case spirv::Opcode::OpName:
144 return processName(operands);
145 case spirv::Opcode::OpString:
146 return processDebugString(operands);
147 case spirv::Opcode::OpModuleProcessed:
148 case spirv::Opcode::OpSource:
149 case spirv::Opcode::OpSourceContinued:
150 case spirv::Opcode::OpSourceExtension:
154 case spirv::Opcode::OpTypeVoid:
155 case spirv::Opcode::OpTypeBool:
156 case spirv::Opcode::OpTypeInt:
157 case spirv::Opcode::OpTypeFloat:
158 case spirv::Opcode::OpTypeVector:
159 case spirv::Opcode::OpTypeMatrix:
160 case spirv::Opcode::OpTypeArray:
161 case spirv::Opcode::OpTypeFunction:
162 case spirv::Opcode::OpTypeImage:
163 case spirv::Opcode::OpTypeSampledImage:
164 case spirv::Opcode::OpTypeRuntimeArray:
165 case spirv::Opcode::OpTypeStruct:
166 case spirv::Opcode::OpTypePointer:
167 case spirv::Opcode::OpTypeCooperativeMatrixKHR:
168 return processType(opcode, operands);
169 case spirv::Opcode::OpTypeForwardPointer:
170 return processTypeForwardPointer(operands);
171 case spirv::Opcode::OpTypeJointMatrixINTEL:
172 return processType(opcode, operands);
173 case spirv::Opcode::OpConstant:
174 return processConstant(operands,
false);
175 case spirv::Opcode::OpSpecConstant:
176 return processConstant(operands,
true);
177 case spirv::Opcode::OpConstantComposite:
178 return processConstantComposite(operands);
179 case spirv::Opcode::OpSpecConstantComposite:
180 return processSpecConstantComposite(operands);
181 case spirv::Opcode::OpSpecConstantOp:
182 return processSpecConstantOperation(operands);
183 case spirv::Opcode::OpConstantTrue:
184 return processConstantBool(
true, operands,
false);
185 case spirv::Opcode::OpSpecConstantTrue:
186 return processConstantBool(
true, operands,
true);
187 case spirv::Opcode::OpConstantFalse:
188 return processConstantBool(
false, operands,
false);
189 case spirv::Opcode::OpSpecConstantFalse:
190 return processConstantBool(
false, operands,
true);
191 case spirv::Opcode::OpConstantNull:
192 return processConstantNull(operands);
193 case spirv::Opcode::OpDecorate:
194 return processDecoration(operands);
195 case spirv::Opcode::OpMemberDecorate:
196 return processMemberDecoration(operands);
197 case spirv::Opcode::OpFunction:
198 return processFunction(operands);
199 case spirv::Opcode::OpLabel:
200 return processLabel(operands);
201 case spirv::Opcode::OpBranch:
202 return processBranch(operands);
203 case spirv::Opcode::OpBranchConditional:
204 return processBranchConditional(operands);
205 case spirv::Opcode::OpSelectionMerge:
206 return processSelectionMerge(operands);
207 case spirv::Opcode::OpLoopMerge:
208 return processLoopMerge(operands);
209 case spirv::Opcode::OpPhi:
210 return processPhi(operands);
211 case spirv::Opcode::OpUndef:
212 return processUndef(operands);
216 return dispatchToAutogenDeserialization(opcode, operands);
219 LogicalResult spirv::Deserializer::processOpWithoutGrammarAttr(
221 unsigned numOperands) {
223 uint32_t valueID = 0;
225 size_t wordIndex = 0;
227 if (wordIndex >= words.size())
229 "expected result type <id> while deserializing for ")
233 auto type = getType(words[wordIndex]);
235 return emitError(unknownLoc,
"unknown type result <id>: ")
237 resultTypes.push_back(type);
241 if (wordIndex >= words.size())
243 "expected result <id> while deserializing for ")
245 valueID = words[wordIndex];
253 size_t operandIndex = 0;
254 for (; operandIndex < numOperands && wordIndex < words.size();
255 ++operandIndex, ++wordIndex) {
256 auto arg = getValue(words[wordIndex]);
258 return emitError(unknownLoc,
"unknown result <id>: ") << words[wordIndex];
259 operands.push_back(arg);
261 if (operandIndex != numOperands) {
264 "found less operands than expected when deserializing for ")
265 << opName <<
"; only " << operandIndex <<
" of " << numOperands
268 if (wordIndex != words.size()) {
271 "found more operands than expected when deserializing for ")
272 << opName <<
"; only " << wordIndex <<
" of " << words.size()
277 if (decorations.count(valueID)) {
278 auto attrs = decorations[valueID].getAttrs();
279 attributes.append(attrs.begin(), attrs.end());
283 Location loc = createFileLineColLoc(opBuilder);
285 opState.addOperands(operands);
287 opState.addTypes(resultTypes);
288 opState.addAttributes(attributes);
300 if (operands.size() != 2) {
301 return emitError(unknownLoc,
"OpUndef instruction must have two operands");
303 auto type = getType(operands[0]);
305 return emitError(unknownLoc,
"unknown type <id> with OpUndef instruction");
307 undefMap[operands[1]] = type;
312 if (operands.size() < 4) {
314 "OpExtInst must have at least 4 operands, result type "
315 "<id>, result <id>, set <id> and instruction opcode");
317 if (!extendedInstSets.count(operands[2])) {
318 return emitError(unknownLoc,
"undefined set <id> in OpExtInst");
321 slicedOperands.append(operands.begin(), std::next(operands.begin(), 2));
322 slicedOperands.append(std::next(operands.begin(), 4), operands.end());
323 return dispatchToExtensionSetAutogenDeserialization(
324 extendedInstSets[operands[2]], operands[3], slicedOperands);
333 unsigned wordIndex = 0;
334 if (wordIndex >= words.size()) {
336 "missing Execution Model specification in OpEntryPoint");
339 context,
static_cast<spirv::ExecutionModel
>(words[wordIndex++]));
340 if (wordIndex >= words.size()) {
341 return emitError(unknownLoc,
"missing <id> in OpEntryPoint");
344 auto fnID = words[wordIndex++];
348 auto parsedFunc = getFunction(fnID);
350 return emitError(unknownLoc,
"no function matching <id> ") << fnID;
352 if (parsedFunc.getName() != fnName) {
356 if (!parsedFunc.getName().starts_with(
"spirv_fn_"))
358 "function name mismatch between OpEntryPoint "
359 "and OpFunction with <id> ")
360 << fnID <<
": " << fnName <<
" vs. " << parsedFunc.getName();
361 parsedFunc.setName(fnName);
364 while (wordIndex < words.size()) {
365 auto arg = getGlobalVariable(words[wordIndex]);
367 return emitError(unknownLoc,
"undefined result <id> ")
368 << words[wordIndex] <<
" while decoding OpEntryPoint";
373 opBuilder.create<spirv::EntryPointOp>(
375 opBuilder.getArrayAttr(interface));
382 unsigned wordIndex = 0;
383 if (wordIndex >= words.size()) {
385 "missing function result <id> in OpExecutionMode");
388 auto fnID = words[wordIndex++];
389 auto fn = getFunction(fnID);
391 return emitError(unknownLoc,
"no function matching <id> ") << fnID;
394 if (wordIndex >= words.size()) {
395 return emitError(unknownLoc,
"missing Execution Mode in OpExecutionMode");
398 context,
static_cast<spirv::ExecutionMode
>(words[wordIndex++]));
402 while (wordIndex < words.size()) {
403 attrListElems.push_back(opBuilder.getI32IntegerAttr(words[wordIndex++]));
405 auto values = opBuilder.getArrayAttr(attrListElems);
406 opBuilder.create<spirv::ExecutionModeOp>(
415 if (operands.size() < 3) {
417 "OpFunctionCall must have at least 3 operands");
420 Type resultType = getType(operands[0]);
422 return emitError(unknownLoc,
"undefined result type from <id> ")
427 if (isVoidType(resultType))
428 resultType =
nullptr;
430 auto resultID = operands[1];
431 auto functionID = operands[2];
433 auto functionName = getFunctionSymbol(functionID);
436 for (
auto operand : llvm::drop_begin(operands, 3)) {
437 auto value = getValue(operand);
439 return emitError(unknownLoc,
"unknown <id> ")
440 << operand <<
" used by OpFunctionCall";
442 arguments.push_back(value);
445 auto opFunctionCall = opBuilder.create<spirv::FunctionCallOp>(
446 unknownLoc, resultType,
450 valueMap[resultID] = opFunctionCall.getResult(0);
458 size_t wordIndex = 0;
462 if (wordIndex < words.size()) {
463 auto arg = getValue(words[wordIndex]);
466 return emitError(unknownLoc,
"unknown result <id> : ")
470 operands.push_back(arg);
474 if (wordIndex < words.size()) {
475 auto arg = getValue(words[wordIndex]);
478 return emitError(unknownLoc,
"unknown result <id> : ")
482 operands.push_back(arg);
486 bool isAlignedAttr =
false;
488 if (wordIndex < words.size()) {
489 auto attrValue = words[wordIndex++];
490 auto attr = opBuilder.getAttr<spirv::MemoryAccessAttr>(
491 static_cast<spirv::MemoryAccess
>(attrValue));
492 attributes.push_back(
493 opBuilder.getNamedAttr(attributeName<MemoryAccess>(), attr));
494 isAlignedAttr = (attrValue == 2);
497 if (isAlignedAttr && wordIndex < words.size()) {
498 attributes.push_back(opBuilder.getNamedAttr(
499 "alignment", opBuilder.getI32IntegerAttr(words[wordIndex++])));
502 if (wordIndex < words.size()) {
503 auto attrValue = words[wordIndex++];
504 auto attr = opBuilder.getAttr<spirv::MemoryAccessAttr>(
505 static_cast<spirv::MemoryAccess
>(attrValue));
506 attributes.push_back(opBuilder.getNamedAttr(
"source_memory_access", attr));
509 if (wordIndex < words.size()) {
510 attributes.push_back(opBuilder.getNamedAttr(
511 "source_alignment", opBuilder.getI32IntegerAttr(words[wordIndex++])));
514 if (wordIndex != words.size()) {
516 "found more operands than expected when deserializing "
517 "spirv::CopyMemoryOp, only ")
518 << wordIndex <<
" of " << words.size() <<
" processed";
521 Location loc = createFileLineColLoc(opBuilder);
522 opBuilder.create<spirv::CopyMemoryOp>(loc, resultTypes, operands, attributes);
530 if (words.size() != 4) {
532 "expected 4 words in GenericCastToPtrExplicitOp"
538 uint32_t valueID = 0;
539 auto type = getType(words[0]);
542 return emitError(unknownLoc,
"unknown type result <id> : ") << words[0];
543 resultTypes.push_back(type);
547 auto arg = getValue(words[2]);
549 return emitError(unknownLoc,
"unknown result <id> : ") << words[2];
550 operands.push_back(arg);
552 Location loc = createFileLineColLoc(opBuilder);
554 loc, resultTypes, operands);
561 #define GET_DESERIALIZATION_FNS
562 #include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc"
static spirv::Opcode extractOpcode(uint32_t word)
Extracts the opcode from the given first word of a SPIR-V instruction.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
StringRef decodeStringLiteral(ArrayRef< uint32_t > words, unsigned &wordIndex)
Decodes a string literal in words starting at wordIndex.
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This class represents an efficient way to signal success or failure.
This represents an operation in an abstracted form, suitable for use with the builder APIs.