20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/Support/Debug.h"
27 #define DEBUG_TYPE "spirv-deserialization"
// Keep only the low 16 bits of `word` and cast them to spirv::Opcode.
35 return static_cast<spirv::Opcode
>(word & 0xffff);
// Resolves a SPIR-V result <id> to a Value.
//
// The lookups below are tried in the order written; the first one that
// succeeds decides what is returned. An <id> that matches none of them is
// looked up in valueMap at the end.
42 Value spirv::Deserializer::getValue(uint32_t
id) {
// Constant: a spirv::ConstantOp is built through opBuilder.
43 if (
auto constInfo = getConstant(
id)) {
45 return opBuilder.
create<spirv::ConstantOp>(unknownLoc, constInfo->second,
// Global variable: build a spirv::AddressOfOp and return its pointer result.
48 if (
auto varOp = getGlobalVariable(
id)) {
49 auto addressOfOp = opBuilder.
create<spirv::AddressOfOp>(
51 return addressOfOp.getPointer();
// Spec constant: build a spirv::ReferenceOfOp typed after the constant's
// default value and return its reference result.
53 if (
auto constOp = getSpecConstant(
id)) {
54 auto referenceOfOp = opBuilder.
create<spirv::ReferenceOfOp>(
55 unknownLoc, constOp.getDefaultValue().getType(),
57 return referenceOfOp.getReference();
// Spec constant composite: also a spirv::ReferenceOfOp, typed by the
// composite op itself.
59 if (
auto constCompositeOp = getSpecConstantComposite(
id)) {
60 auto referenceOfOp = opBuilder.
create<spirv::ReferenceOfOp>(
61 unknownLoc, constCompositeOp.getType(),
63 return referenceOfOp.getReference();
// Spec constant operation: delegated to materializeSpecConstantOperation.
// NOTE(review): `enclodesOpcode` looks like a misspelling of
// `enclosedOpcode`; the field is declared outside this view, so it is left
// as is here.
65 if (
auto specConstOperationInfo = getSpecConstantOperation(
id)) {
66 return materializeSpecConstantOperation(
67 id, specConstOperationInfo->enclodesOpcode,
68 specConstOperationInfo->resultTypeID,
69 specConstOperationInfo->enclosedOpOperands);
// Undef: getUndefType supplies the type used to build a spirv::UndefOp.
71 if (
auto undef = getUndefType(
id)) {
72 return opBuilder.
create<spirv::UndefOp>(unknownLoc, undef);
// Anything else: whatever valueMap holds for this <id>.
74 return valueMap.lookup(
id);
// Tail of a function that carves one instruction out of `binary`, starting at
// curOffset. The start of its signature is not visible here.
// `expectedOpcode` is optional; when present it is stringified into the
// out-of-range diagnostic below.
79 std::optional<spirv::Opcode> expectedOpcode) {
80 auto binarySize = binary.size();
// The read position is already at or past the end of the binary.
81 if (curOffset >= binarySize) {
83 << (expectedOpcode ? spirv::stringifyOpcode(*expectedOpcode)
// The word count is taken from the upper 16 bits of the word at curOffset.
91 uint32_t wordCount = binary[curOffset] >> 16;
94 return emitError(unknownLoc,
"word count cannot be zero");
// nextOffset is one past the last word of this instruction; it must not run
// beyond the end of the binary.
96 uint32_t nextOffset = curOffset + wordCount;
97 if (nextOffset > binarySize)
98 return emitError(unknownLoc,
"insufficient words for the last instruction");
// Operands are the wordCount - 1 words that follow the word at curOffset.
101 operands = binary.slice(curOffset + 1, wordCount - 1);
// Advance the read position past this instruction.
102 curOffset = nextOffset;
// Opcode dispatch: each case forwards `operands` to the handler for that
// opcode. The enclosing signature and the `switch` header are not visible in
// this listing.
//
// Debug trace naming the opcode being handled.
108 LLVM_DEBUG(logger.startLine() <<
"[inst] processing instruction "
109 << spirv::stringifyOpcode(opcode) <<
"\n");
// Module-level instructions, one handler each.
114 case spirv::Opcode::OpCapability:
115 return processCapability(operands);
116 case spirv::Opcode::OpExtension:
117 return processExtension(operands);
118 case spirv::Opcode::OpExtInst:
119 return processExtInst(operands);
120 case spirv::Opcode::OpExtInstImport:
121 return processExtInstImport(operands);
122 case spirv::Opcode::OpMemberName:
123 return processMemberName(operands);
124 case spirv::Opcode::OpMemoryModel:
125 return processMemoryModel(operands);
// OpEntryPoint and OpExecutionMode share a body: when deferInstructions is
// set, the (opcode, operands) pair is queued in deferredInstructions. The
// rest of this body is not visible here.
126 case spirv::Opcode::OpEntryPoint:
127 case spirv::Opcode::OpExecutionMode:
128 if (deferInstructions) {
129 deferredInstructions.emplace_back(opcode, operands);
// An OpVariable seen while the builder's current block is directly inside a
// spirv::ModuleOp is handled as a global variable. The other branch is not
// visible here.
133 case spirv::Opcode::OpVariable:
134 if (isa<spirv::ModuleOp>(opBuilder.getBlock()->getParentOp())) {
135 return processGlobalVariable(operands);
// Debug-information instructions. The bodies for OpNoLine and for the
// OpModuleProcessed/OpSource* group are not visible in this listing.
138 case spirv::Opcode::OpLine:
139 return processDebugLine(operands);
140 case spirv::Opcode::OpNoLine:
143 case spirv::Opcode::OpName:
144 return processName(operands);
145 case spirv::Opcode::OpString:
146 return processDebugString(operands);
147 case spirv::Opcode::OpModuleProcessed:
148 case spirv::Opcode::OpSource:
149 case spirv::Opcode::OpSourceContinued:
150 case spirv::Opcode::OpSourceExtension:
// Type declarations: all of these opcodes fall through to one processType
// call, which also receives the opcode.
154 case spirv::Opcode::OpTypeVoid:
155 case spirv::Opcode::OpTypeBool:
156 case spirv::Opcode::OpTypeInt:
157 case spirv::Opcode::OpTypeFloat:
158 case spirv::Opcode::OpTypeVector:
159 case spirv::Opcode::OpTypeMatrix:
160 case spirv::Opcode::OpTypeArray:
161 case spirv::Opcode::OpTypeFunction:
162 case spirv::Opcode::OpTypeImage:
163 case spirv::Opcode::OpTypeSampledImage:
164 case spirv::Opcode::OpTypeRuntimeArray:
165 case spirv::Opcode::OpTypeStruct:
166 case spirv::Opcode::OpTypePointer:
167 case spirv::Opcode::OpTypeCooperativeMatrixKHR:
168 case spirv::Opcode::OpTypeCooperativeMatrixNV:
169 return processType(opcode, operands);
170 case spirv::Opcode::OpTypeForwardPointer:
171 return processTypeForwardPointer(operands);
// Same processType call as the group above, written as a separate case.
172 case spirv::Opcode::OpTypeJointMatrixINTEL:
173 return processType(opcode, operands);
// Constants. In processConstant the boolean is `true` for the OpSpec* opcode
// and `false` for the plain one. In processConstantBool the first boolean
// follows True/False in the opcode name and the last follows Spec/non-Spec.
174 case spirv::Opcode::OpConstant:
175 return processConstant(operands,
false);
176 case spirv::Opcode::OpSpecConstant:
177 return processConstant(operands,
true);
178 case spirv::Opcode::OpConstantComposite:
179 return processConstantComposite(operands);
180 case spirv::Opcode::OpSpecConstantComposite:
181 return processSpecConstantComposite(operands);
182 case spirv::Opcode::OpSpecConstantOp:
183 return processSpecConstantOperation(operands);
184 case spirv::Opcode::OpConstantTrue:
185 return processConstantBool(
true, operands,
false);
186 case spirv::Opcode::OpSpecConstantTrue:
187 return processConstantBool(
true, operands,
true);
188 case spirv::Opcode::OpConstantFalse:
189 return processConstantBool(
false, operands,
false);
190 case spirv::Opcode::OpSpecConstantFalse:
191 return processConstantBool(
false, operands,
true);
192 case spirv::Opcode::OpConstantNull:
193 return processConstantNull(operands);
// Decorations.
194 case spirv::Opcode::OpDecorate:
195 return processDecoration(operands);
196 case spirv::Opcode::OpMemberDecorate:
197 return processMemberDecoration(operands);
// Functions, blocks and control flow.
198 case spirv::Opcode::OpFunction:
199 return processFunction(operands);
200 case spirv::Opcode::OpLabel:
201 return processLabel(operands);
202 case spirv::Opcode::OpBranch:
203 return processBranch(operands);
204 case spirv::Opcode::OpBranchConditional:
205 return processBranchConditional(operands);
206 case spirv::Opcode::OpSelectionMerge:
207 return processSelectionMerge(operands);
208 case spirv::Opcode::OpLoopMerge:
209 return processLoopMerge(operands);
210 case spirv::Opcode::OpPhi:
211 return processPhi(operands);
212 case spirv::Opcode::OpUndef:
213 return processUndef(operands);
// Opcodes that did not return above are handed to the auto-generated
// dispatcher.
217 return dispatchToAutogenDeserialization(opcode, operands);
// Generic deserialization path: reads an optional result type <id>, an
// optional result <id>, then `numOperands` value operands from `words`, and
// assembles an operation state from them. Returns a LogicalResult; each
// failure path visible here goes through emitError.
// Part of the parameter list and several statements are missing from this
// listing, including the conditions that guard the result-type and result-id
// sections.
220 LogicalResult spirv::Deserializer::processOpWithoutGrammarAttr(
222 unsigned numOperands) {
224 uint32_t valueID = 0;
// Cursor into `words`; everything below consumes words left to right.
226 size_t wordIndex = 0;
// Result type <id>: must be present and must resolve through getType.
228 if (wordIndex >= words.size())
230 "expected result type <id> while deserializing for ")
234 auto type = getType(words[wordIndex]);
236 return emitError(unknownLoc,
"unknown type result <id>: ")
238 resultTypes.push_back(type);
// Result <id>: remembered in valueID; it is used below to look up decorations.
242 if (wordIndex >= words.size())
244 "expected result <id> while deserializing for ")
246 valueID = words[wordIndex];
// Value operands: stop after numOperands of them or when the words run out,
// whichever comes first. Each word must resolve through getValue.
254 size_t operandIndex = 0;
255 for (; operandIndex < numOperands && wordIndex < words.size();
256 ++operandIndex, ++wordIndex) {
257 auto arg = getValue(words[wordIndex]);
259 return emitError(unknownLoc,
"unknown result <id>: ") << words[wordIndex];
260 operands.push_back(arg);
// The loop ended early: the words ran out before numOperands were read.
262 if (operandIndex != numOperands) {
265 "found less operands than expected when deserializing for ")
266 << opName <<
"; only " << operandIndex <<
" of " << numOperands
// Words are left over after all expected operands were read.
269 if (wordIndex != words.size()) {
272 "found more operands than expected when deserializing for ")
273 << opName <<
"; only " << wordIndex <<
" of " << words.size()
// Attributes recorded under this result <id> in `decorations` are copied
// onto the op.
278 if (decorations.count(valueID)) {
279 auto attrs = decorations[valueID].getAttrs();
280 attributes.append(attrs.begin(), attrs.end());
// Fill the operation state with the operands, result types and attributes
// collected above.
284 Location loc = createFileLineColLoc(opBuilder);
286 opState.addOperands(operands);
288 opState.addTypes(resultTypes);
289 opState.addAttributes(attributes);
// OpUndef handling (the signature is not visible in this listing).
// Exactly two operands are required.
301 if (operands.size() != 2) {
302 return emitError(unknownLoc,
"OpUndef instruction must have two operands");
// operands[0] is resolved as a type <id>.
304 auto type = getType(operands[0]);
306 return emitError(unknownLoc,
"unknown type <id> with OpUndef instruction");
// The type is remembered in undefMap under the <id> in operands[1].
308 undefMap[operands[1]] = type;
// OpExtInst handling (the signature is not visible in this listing).
// Per the diagnostic text below, the first four operands are: result type
// <id>, result <id>, set <id> and instruction opcode.
313 if (operands.size() < 4) {
315 "OpExtInst must have at least 4 operands, result type "
316 "<id>, result <id>, set <id> and instruction opcode");
// operands[2] must be a key of extendedInstSets.
318 if (!extendedInstSets.count(operands[2])) {
319 return emitError(unknownLoc,
"undefined set <id> in OpExtInst");
// Keep operands[0..1] and operands[4..end]. The set <id> and the instruction
// opcode are dropped from the list because they are passed separately in the
// call below.
322 slicedOperands.append(operands.begin(), std::next(operands.begin(), 2));
323 slicedOperands.append(std::next(operands.begin(), 4), operands.end());
324 return dispatchToExtensionSetAutogenDeserialization(
325 extendedInstSets[operands[2]], operands[3], slicedOperands);
// OpEntryPoint handling (the signature and several statements are not visible
// in this listing). `words` is consumed left to right through wordIndex.
334 unsigned wordIndex = 0;
// First word: the execution model, cast to spirv::ExecutionModel.
335 if (wordIndex >= words.size()) {
337 "missing Execution Model specification in OpEntryPoint");
340 context,
static_cast<spirv::ExecutionModel
>(words[wordIndex++]));
// Next word: the <id> of the function this entry point refers to.
341 if (wordIndex >= words.size()) {
342 return emitError(unknownLoc,
"missing <id> in OpEntryPoint");
345 auto fnID = words[wordIndex++];
// The function is looked up through getFunction; an unknown <id> is an error.
349 auto parsedFunc = getFunction(fnID);
351 return emitError(unknownLoc,
"no function matching <id> ") << fnID;
// When the function's current name differs from fnName, a name that does not
// start with "spirv_fn_" is reported as a mismatch; otherwise the function is
// renamed to fnName. (fnName is assigned in a line not visible here.)
// NOTE(review): StringRef::startswith is deprecated in favour of starts_with
// in recent LLVM releases - confirm which LLVM version this file targets.
353 if (parsedFunc.getName() != fnName) {
357 if (!parsedFunc.getName().startswith(
"spirv_fn_"))
359 "function name mismatch between OpEntryPoint "
360 "and OpFunction with <id> ")
361 << fnID <<
": " << fnName <<
" vs. " << parsedFunc.getName();
362 parsedFunc.setName(fnName);
// Every remaining word must resolve through getGlobalVariable.
365 while (wordIndex < words.size()) {
366 auto arg = getGlobalVariable(words[wordIndex]);
368 return emitError(unknownLoc,
"undefined result <id> ")
369 << words[wordIndex] <<
" while decoding OpEntryPoint";
// Create the spirv::EntryPointOp; its last visible argument is an array
// attribute built from `interface`.
374 opBuilder.create<spirv::EntryPointOp>(
376 opBuilder.getArrayAttr(interface));
// OpExecutionMode handling (the signature and several statements are not
// visible in this listing). `words` is consumed left to right through
// wordIndex.
383 unsigned wordIndex = 0;
// First word: the <id> of the function the execution mode applies to.
384 if (wordIndex >= words.size()) {
386 "missing function result <id> in OpExecutionMode");
389 auto fnID = words[wordIndex++];
390 auto fn = getFunction(fnID);
392 return emitError(unknownLoc,
"no function matching <id> ") << fnID;
// Next word: the execution mode, cast to spirv::ExecutionMode.
395 if (wordIndex >= words.size()) {
396 return emitError(unknownLoc,
"missing Execution Mode in OpExecutionMode");
399 context,
static_cast<spirv::ExecutionMode
>(words[wordIndex++]));
// Each remaining word becomes an i32 integer attribute; together they form
// the array attribute `values`.
403 while (wordIndex < words.size()) {
404 attrListElems.push_back(opBuilder.getI32IntegerAttr(words[wordIndex++]));
406 auto values = opBuilder.getArrayAttr(attrListElems);
// Create the spirv::ExecutionModeOp (its arguments are not visible here).
407 opBuilder.create<spirv::ExecutionModeOp>(
// OpFunctionCall handling (the signature and several statements are not
// visible in this listing).
// Operand positions used below: [0] result type <id>, [1] result <id>,
// [2] callee function <id>, [3..] call arguments.
416 if (operands.size() < 3) {
418 "OpFunctionCall must have at least 3 operands");
421 Type resultType = getType(operands[0]);
423 return emitError(unknownLoc,
"undefined result type from <id> ")
// A void result type is replaced by a null Type before the op is built.
428 if (isVoidType(resultType))
429 resultType =
nullptr;
431 auto resultID = operands[1];
432 auto functionID = operands[2];
// The callee is referred to by the symbol name obtained from
// getFunctionSymbol.
434 auto functionName = getFunctionSymbol(functionID);
// Every operand after the first three must resolve through getValue.
437 for (
auto operand : llvm::drop_begin(operands, 3)) {
438 auto value = getValue(operand);
440 return emitError(unknownLoc,
"unknown <id> ")
441 << operand <<
" used by OpFunctionCall";
443 arguments.push_back(value);
446 auto opFunctionCall = opBuilder.create<spirv::FunctionCallOp>(
447 unknownLoc, resultType,
// Map the result <id> to the call's first result.
// NOTE(review): resultType may have been nulled above for void calls; the
// line just before this one is missing from the listing - confirm that it
// guards this getResult(0) access.
451 valueMap[resultID] = opFunctionCall.getResult(0);
// spirv::CopyMemoryOp handling (the signature and several statements,
// including some wordIndex increments, are not visible in this listing).
// `words` is consumed left to right through wordIndex.
459 size_t wordIndex = 0;
// First value operand, resolved through getValue.
463 if (wordIndex < words.size()) {
464 auto arg = getValue(words[wordIndex]);
467 return emitError(unknownLoc,
"unknown result <id> : ")
471 operands.push_back(arg);
// Second value operand, resolved the same way.
475 if (wordIndex < words.size()) {
476 auto arg = getValue(words[wordIndex]);
479 return emitError(unknownLoc,
"unknown result <id> : ")
483 operands.push_back(arg);
// Set below when the first memory-access word equals 2; only then is an
// "alignment" word consumed.
487 bool isAlignedAttr =
false;
// Optional first memory-access word, stored as the "memory_access" attribute.
489 if (wordIndex < words.size()) {
490 auto attrValue = words[wordIndex++];
491 auto attr = opBuilder.getAttr<spirv::MemoryAccessAttr>(
492 static_cast<spirv::MemoryAccess
>(attrValue));
493 attributes.push_back(opBuilder.getNamedAttr(
"memory_access", attr));
// NOTE(review): 2 is presumably the SPIR-V Memory Operands `Aligned` bit
// (0x2). An exact equality test would not match a mask that combines Aligned
// with other bits - confirm this is intended.
494 isAlignedAttr = (attrValue == 2);
497 if (isAlignedAttr && wordIndex < words.size()) {
498 attributes.push_back(opBuilder.getNamedAttr(
499 "alignment", opBuilder.getI32IntegerAttr(words[wordIndex++])));
// Optional second memory-access word, stored as "source_memory_access".
502 if (wordIndex < words.size()) {
503 auto attrValue = words[wordIndex++];
504 auto attr = opBuilder.getAttr<spirv::MemoryAccessAttr>(
505 static_cast<spirv::MemoryAccess
>(attrValue));
506 attributes.push_back(opBuilder.getNamedAttr(
"source_memory_access", attr));
// As far as the visible condition shows, any further word is taken as
// "source_alignment" without inspecting the second access value, unlike
// "alignment" above.
509 if (wordIndex < words.size()) {
510 attributes.push_back(opBuilder.getNamedAttr(
511 "source_alignment", opBuilder.getI32IntegerAttr(words[wordIndex++])));
// Words left over at this point are reported as an error.
514 if (wordIndex != words.size()) {
516 "found more operands than expected when deserializing "
517 "spirv::CopyMemoryOp, only ")
518 << wordIndex <<
" of " << words.size() <<
" processed";
// Create the op from the collected result types, operands and attributes.
521 Location loc = createFileLineColLoc(opBuilder);
522 opBuilder.create<spirv::CopyMemoryOp>(loc, resultTypes, operands, attributes);
// GenericCastToPtrExplicitOp handling (the signature and several statements
// are not visible in this listing).
// Exactly four words are required.
530 if (words.size() != 4) {
532 "expected 4 words in GenericCastToPtrExplicitOp"
538 uint32_t valueID = 0;
// words[0]: result type <id>, resolved through getType.
539 auto type = getType(words[0]);
542 return emitError(unknownLoc,
"unknown type result <id> : ") << words[0];
543 resultTypes.push_back(type);
// words[2]: the single value operand, resolved through getValue. words[1]
// and words[3] are not read in the lines visible here.
547 auto arg = getValue(words[2]);
549 return emitError(unknownLoc,
"unknown result <id> : ") << words[2];
550 operands.push_back(arg);
// The op is created from the collected result types and operands.
552 Location loc = createFileLineColLoc(opBuilder);
554 loc, resultTypes, operands);
561 #define GET_DESERIALIZATION_FNS
562 #include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc"
static spirv::Opcode extractOpcode(uint32_t word)
Extracts the opcode from the given first word of a SPIR-V instruction.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
StringRef decodeStringLiteral(ArrayRef< uint32_t > words, unsigned &wordIndex)
Decodes a string literal in words starting at wordIndex.
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This class represents an efficient way to signal success or failure.
This represents an operation in an abstracted form, suitable for use with the builder APIs.