20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/Support/Debug.h"
27 #define DEBUG_TYPE "spirv-deserialization"
// Body of extractOpcode(uint32_t word): keeps the low 16 bits of the first
// instruction word, which hold the opcode (the high 16 bits hold the word
// count, as read in sliceInstruction).
35 return static_cast<spirv::Opcode
>(word & 0xffff);
// Returns the Value for the SPIR-V result <id> `id`.
//
// For the kinds of <id> checked below, a fresh op is created through
// opBuilder every time this function is called:
//   - constants                      -> spirv::ConstantOp
//   - replicated constant composites -> spirv::EXTConstantCompositeReplicateOp
//   - global variables               -> spirv::AddressOfOp
//   - spec constants, spec constant composites (incl. replicated)
//                                    -> spirv::ReferenceOfOp
//   - OpSpecConstantOp results       -> materializeSpecConstantOperation(...)
//   - OpUndef results                -> spirv::UndefOp
// Only when none of these lookups match is `id` resolved through valueMap.
// NOTE(review): valueMap.lookup presumably yields a null Value for an
// unknown id; callers in this file appear to report "unknown result <id>"
// in that case — confirm.
42 Value spirv::Deserializer::getValue(uint32_t
id) {
43 if (
auto constInfo = getConstant(
id)) {
45 return spirv::ConstantOp::create(opBuilder, unknownLoc, constInfo->second,
48 if (std::optional<std::pair<Attribute, Type>> constCompositeReplicateInfo =
49 getConstantCompositeReplicate(
id)) {
50 return spirv::EXTConstantCompositeReplicateOp::create(
51 opBuilder, unknownLoc, constCompositeReplicateInfo->second,
52 constCompositeReplicateInfo->first);
// Global variables are referenced through a spirv.AddressOf of the variable.
54 if (
auto varOp = getGlobalVariable(
id)) {
56 spirv::AddressOfOp::create(opBuilder, unknownLoc, varOp.getType(),
58 return addressOfOp.getPointer();
// Spec constants (scalar, composite, replicated composite) are referenced
// through spirv.ReferenceOf.
60 if (
auto constOp = getSpecConstant(
id)) {
61 auto referenceOfOp = spirv::ReferenceOfOp::create(
62 opBuilder, unknownLoc, constOp.getDefaultValue().getType(),
64 return referenceOfOp.getReference();
66 if (SpecConstantCompositeOp specConstCompositeOp =
67 getSpecConstantComposite(
id)) {
68 auto referenceOfOp = spirv::ReferenceOfOp::create(
69 opBuilder, unknownLoc, specConstCompositeOp.getType(),
71 return referenceOfOp.getReference();
73 if (
auto specConstCompositeReplicateOp =
74 getSpecConstantCompositeReplicate(
id)) {
75 auto referenceOfOp = spirv::ReferenceOfOp::create(
76 opBuilder, unknownLoc, specConstCompositeReplicateOp.getType(),
78 return referenceOfOp.getReference();
// `enclodesOpcode` (sic) is the field name as declared outside this view.
80 if (
auto specConstOperationInfo = getSpecConstantOperation(
id)) {
81 return materializeSpecConstantOperation(
82 id, specConstOperationInfo->enclodesOpcode,
83 specConstOperationInfo->resultTypeID,
84 specConstOperationInfo->enclosedOpOperands);
86 if (
auto undef = getUndefType(
id)) {
87 return spirv::UndefOp::create(opBuilder, unknownLoc, undef);
// Fallback: values produced by already-deserialized ops.
89 return valueMap.lookup(
id);
// Slices the next instruction out of `binary`, starting at curOffset.
// The first word packs the word count in its high 16 bits; the remaining
// wordCount - 1 words become `operands`, and curOffset is advanced past the
// whole instruction. Fails when the binary is already exhausted (the error
// names `expectedOpcode` if one was given), when the word count is zero, or
// when the instruction would run past the end of the binary.
92 LogicalResult spirv::Deserializer::sliceInstruction(
94 std::optional<spirv::Opcode> expectedOpcode) {
95 auto binarySize = binary.size();
96 if (curOffset >= binarySize) {
98 << (expectedOpcode ? spirv::stringifyOpcode(*expectedOpcode)
// Word count lives in the high 16 bits of the instruction's first word.
106 uint32_t wordCount = binary[curOffset] >> 16;
109 return emitError(unknownLoc,
"word count cannot be zero");
111 uint32_t nextOffset = curOffset + wordCount;
112 if (nextOffset > binarySize)
113 return emitError(unknownLoc,
"insufficient words for the last instruction");
// Operands exclude the leading opcode/word-count word.
116 operands = binary.slice(curOffset + 1, wordCount - 1);
117 curOffset = nextOffset;
// Dispatches one already-sliced instruction (`opcode` + `operands`) to its
// handler. Module-level, debug, type, constant, decoration, function and
// control-flow opcodes have dedicated process* handlers in the switch;
// opcodes not handled there fall through to the autogenerated
// dispatchToAutogenDeserialization.
121 LogicalResult spirv::Deserializer::processInstruction(
123 LLVM_DEBUG(logger.startLine() <<
"[inst] processing instruction "
124 << spirv::stringifyOpcode(opcode) <<
"\n");
129 case spirv::Opcode::OpCapability:
130 return processCapability(operands);
131 case spirv::Opcode::OpExtension:
132 return processExtension(operands);
133 case spirv::Opcode::OpExtInst:
134 return processExtInst(operands);
135 case spirv::Opcode::OpExtInstImport:
136 return processExtInstImport(operands);
137 case spirv::Opcode::OpMemberName:
138 return processMemberName(operands);
139 case spirv::Opcode::OpMemoryModel:
140 return processMemoryModel(operands);
// When deferInstructions is set, entry points and execution modes are queued
// in deferredInstructions instead of being handled immediately.
141 case spirv::Opcode::OpEntryPoint:
142 case spirv::Opcode::OpExecutionMode:
143 if (deferInstructions) {
144 deferredInstructions.emplace_back(opcode, operands);
// An OpVariable whose enclosing block belongs to the spirv.module is a global.
148 case spirv::Opcode::OpVariable:
149 if (isa<spirv::ModuleOp>(opBuilder.getBlock()->getParentOp())) {
150 return processGlobalVariable(operands);
153 case spirv::Opcode::OpLine:
154 return processDebugLine(operands);
155 case spirv::Opcode::OpNoLine:
158 case spirv::Opcode::OpName:
159 return processName(operands);
160 case spirv::Opcode::OpString:
161 return processDebugString(operands);
162 case spirv::Opcode::OpModuleProcessed:
163 case spirv::Opcode::OpSource:
164 case spirv::Opcode::OpSourceContinued:
165 case spirv::Opcode::OpSourceExtension:
// All type-declaring opcodes share processType, which receives the opcode.
169 case spirv::Opcode::OpTypeVoid:
170 case spirv::Opcode::OpTypeBool:
171 case spirv::Opcode::OpTypeInt:
172 case spirv::Opcode::OpTypeFloat:
173 case spirv::Opcode::OpTypeVector:
174 case spirv::Opcode::OpTypeMatrix:
175 case spirv::Opcode::OpTypeArray:
176 case spirv::Opcode::OpTypeFunction:
177 case spirv::Opcode::OpTypeImage:
178 case spirv::Opcode::OpTypeSampledImage:
179 case spirv::Opcode::OpTypeRuntimeArray:
180 case spirv::Opcode::OpTypeStruct:
181 case spirv::Opcode::OpTypePointer:
182 case spirv::Opcode::OpTypeTensorARM:
183 case spirv::Opcode::OpTypeCooperativeMatrixKHR:
184 return processType(opcode, operands);
185 case spirv::Opcode::OpTypeForwardPointer:
186 return processTypeForwardPointer(operands);
// The boolean argument of processConstant / last argument of
// processConstantBool selects spec-constant (true) vs. normal constant.
187 case spirv::Opcode::OpConstant:
188 return processConstant(operands,
false);
189 case spirv::Opcode::OpSpecConstant:
190 return processConstant(operands,
true);
191 case spirv::Opcode::OpConstantComposite:
192 return processConstantComposite(operands);
193 case spirv::Opcode::OpConstantCompositeReplicateEXT:
194 return processConstantCompositeReplicateEXT(operands);
195 case spirv::Opcode::OpSpecConstantComposite:
196 return processSpecConstantComposite(operands);
197 case spirv::Opcode::OpSpecConstantCompositeReplicateEXT:
198 return processSpecConstantCompositeReplicateEXT(operands);
199 case spirv::Opcode::OpSpecConstantOp:
200 return processSpecConstantOperation(operands);
201 case spirv::Opcode::OpConstantTrue:
202 return processConstantBool(
true, operands,
false);
203 case spirv::Opcode::OpSpecConstantTrue:
204 return processConstantBool(
true, operands,
true);
205 case spirv::Opcode::OpConstantFalse:
206 return processConstantBool(
false, operands,
false);
207 case spirv::Opcode::OpSpecConstantFalse:
208 return processConstantBool(
false, operands,
true);
209 case spirv::Opcode::OpConstantNull:
210 return processConstantNull(operands);
211 case spirv::Opcode::OpDecorate:
212 return processDecoration(operands);
213 case spirv::Opcode::OpMemberDecorate:
214 return processMemberDecoration(operands);
215 case spirv::Opcode::OpFunction:
216 return processFunction(operands);
217 case spirv::Opcode::OpLabel:
218 return processLabel(operands);
219 case spirv::Opcode::OpBranch:
220 return processBranch(operands);
221 case spirv::Opcode::OpBranchConditional:
222 return processBranchConditional(operands);
223 case spirv::Opcode::OpSelectionMerge:
224 return processSelectionMerge(operands);
225 case spirv::Opcode::OpLoopMerge:
226 return processLoopMerge(operands);
227 case spirv::Opcode::OpPhi:
228 return processPhi(operands);
229 case spirv::Opcode::OpUndef:
230 return processUndef(operands);
// Everything not handled above uses the generated per-op deserializers.
234 return dispatchToAutogenDeserialization(opcode, operands);
// Generic deserializer for an op named `opName` whose words are laid out as
// [result type <id>] [result <id>] operand <id>..., with exactly
// `numOperands` operand <id>s expected. The result-type / result <id>
// parts appear to be optional (their guarding conditions are outside this
// view — confirm). Decorations recorded for the result <id> are attached
// as attributes, and the op is built from an OperationState at the
// location returned by createFileLineColLoc.
237 LogicalResult spirv::Deserializer::processOpWithoutGrammarAttr(
239 unsigned numOperands) {
241 uint32_t valueID = 0;
243 size_t wordIndex = 0;
245 if (wordIndex >= words.size())
247 "expected result type <id> while deserializing for ")
251 auto type =
getType(words[wordIndex]);
253 return emitError(unknownLoc,
"unknown type result <id>: ")
255 resultTypes.push_back(type);
259 if (wordIndex >= words.size())
261 "expected result <id> while deserializing for ")
263 valueID = words[wordIndex];
// Resolve up to numOperands operand <id>s; each must map to a known value.
271 size_t operandIndex = 0;
272 for (; operandIndex < numOperands && wordIndex < words.size();
273 ++operandIndex, ++wordIndex) {
274 auto arg = getValue(words[wordIndex]);
276 return emitError(unknownLoc,
"unknown result <id>: ") << words[wordIndex];
277 operands.push_back(arg);
// Too few words to supply every expected operand.
279 if (operandIndex != numOperands) {
282 "found less operands than expected when deserializing for ")
283 << opName <<
"; only " << operandIndex <<
" of " << numOperands
// Words left over after all expected operands were consumed.
286 if (wordIndex != words.size()) {
289 "found more operands than expected when deserializing for ")
290 << opName <<
"; only " << wordIndex <<
" of " << words.size()
// Decorations recorded for the result <id> become attributes of the op.
295 if (decorations.count(valueID)) {
296 auto attrs = decorations[valueID].getAttrs();
297 attributes.append(attrs.begin(), attrs.end());
301 Location loc = createFileLineColLoc(opBuilder);
303 opState.addOperands(operands);
305 opState.addTypes(resultTypes);
306 opState.addAttributes(attributes);
// OpUndef handler (function header outside this view): expects exactly
// [result type <id>, result <id>] and records the resolved result type in
// undefMap under the result <id>. The visible lines create no op here;
// getValue creates a spirv::UndefOp via getUndefType, presumably backed by
// undefMap — confirm.
318 if (operands.size() != 2) {
319 return emitError(unknownLoc,
"OpUndef instruction must have two operands");
321 auto type =
getType(operands[0]);
323 return emitError(unknownLoc,
"unknown type <id> with OpUndef instruction");
325 undefMap[operands[1]] = type;
// OpExtInst handler (function header outside this view): operands are
// [result type <id>, result <id>, set <id>, instruction opcode, operand...].
// The set <id> must name an extended instruction set registered in
// extendedInstSets. The set <id> and opcode words are dropped, and the
// remaining operands are forwarded to the per-set generated deserializer.
330 if (operands.size() < 4) {
332 "OpExtInst must have at least 4 operands, result type "
333 "<id>, result <id>, set <id> and instruction opcode");
335 if (!extendedInstSets.count(operands[2])) {
336 return emitError(unknownLoc,
"undefined set <id> in OpExtInst");
// Keep result type <id> and result <id>, skip set <id> and opcode.
339 slicedOperands.append(operands.begin(), std::next(operands.begin(), 2));
340 slicedOperands.append(std::next(operands.begin(), 4), operands.end());
341 return dispatchToExtensionSetAutogenDeserialization(
342 extendedInstSets[operands[2]], operands[3], slicedOperands);
// OpEntryPoint handler (function header outside this view): words are
// [execution model, function <id>, entry point name, interface <id>...].
// The name is decoded in lines not visible here — presumably via
// decodeStringLiteral; confirm. If the function parsed for <id> has a
// different name, it is renamed to the entry point name only when its
// current name starts with "spirv_fn_"; otherwise this is an error. Each
// interface <id> must be a global variable. A spirv::EntryPointOp is then
// created.
351 unsigned wordIndex = 0;
352 if (wordIndex >= words.size()) {
354 "missing Execution Model specification in OpEntryPoint");
357 context,
static_cast<spirv::ExecutionModel
>(words[wordIndex++]));
358 if (wordIndex >= words.size()) {
359 return emitError(unknownLoc,
"missing <id> in OpEntryPoint");
362 auto fnID = words[wordIndex++];
366 auto parsedFunc = getFunction(fnID);
368 return emitError(unknownLoc,
"no function matching <id> ") << fnID;
370 if (parsedFunc.getName() != fnName) {
// Only names prefixed with "spirv_fn_" may be replaced by the entry point
// name; any other mismatch is reported as an error.
374 if (!parsedFunc.getName().starts_with(
"spirv_fn_"))
376 "function name mismatch between OpEntryPoint "
377 "and OpFunction with <id> ")
378 << fnID <<
": " << fnName <<
" vs. " << parsedFunc.getName();
379 parsedFunc.setName(fnName);
// Remaining words are interface <id>s, each of which must be a global variable.
382 while (wordIndex < words.size()) {
383 auto arg = getGlobalVariable(words[wordIndex]);
385 return emitError(unknownLoc,
"undefined result <id> ")
386 << words[wordIndex] <<
" while decoding OpEntryPoint";
391 spirv::EntryPointOp::create(
392 opBuilder, unknownLoc, execModel,
394 opBuilder.getArrayAttr(interface));
// OpExecutionMode handler (function header outside this view): words are
// [function <id>, execution mode, literal...]. The function must already be
// known. Every trailing literal word is wrapped as an i32 integer attribute,
// and the resulting array is passed to the created spirv::ExecutionModeOp.
401 unsigned wordIndex = 0;
402 if (wordIndex >= words.size()) {
404 "missing function result <id> in OpExecutionMode");
407 auto fnID = words[wordIndex++];
408 auto fn = getFunction(fnID);
410 return emitError(unknownLoc,
"no function matching <id> ") << fnID;
413 if (wordIndex >= words.size()) {
414 return emitError(unknownLoc,
"missing Execution Mode in OpExecutionMode");
417 context,
static_cast<spirv::ExecutionMode
>(words[wordIndex++]));
// Remaining words are the execution mode's literal operands.
421 while (wordIndex < words.size()) {
422 attrListElems.push_back(opBuilder.getI32IntegerAttr(words[wordIndex++]));
424 auto values = opBuilder.getArrayAttr(attrListElems);
425 spirv::ExecutionModeOp::create(
426 opBuilder, unknownLoc,
// OpFunctionCall handler (function header outside this view): operands are
// [result type <id>, result <id>, function <id>, argument <id>...]. A void
// result type is replaced by nullptr, so the call is built without a
// result type. Each argument <id> must resolve through getValue. The call's
// result 0 is recorded in valueMap under the result <id>. Whether that
// store is guarded for void calls is decided by lines outside this view —
// confirm.
435 if (operands.size() < 3) {
437 "OpFunctionCall must have at least 3 operands");
442 return emitError(unknownLoc,
"undefined result type from <id> ")
447 if (isVoidType(resultType))
448 resultType =
nullptr;
450 auto resultID = operands[1];
451 auto functionID = operands[2];
453 auto functionName = getFunctionSymbol(functionID);
// Operands after the first three are the call arguments.
456 for (
auto operand : llvm::drop_begin(operands, 3)) {
457 auto value = getValue(operand);
459 return emitError(unknownLoc,
"unknown <id> ")
460 << operand <<
" used by OpFunctionCall";
462 arguments.push_back(value);
465 auto opFunctionCall = spirv::FunctionCallOp::create(
466 opBuilder, unknownLoc, resultType,
470 valueMap[resultID] = opFunctionCall.getResult(0);
// CopyMemoryOp handler (function header outside this view; identified by its
// error message): words are
//   [target <id>, source <id>,
//    (memory access [alignment])?, (source memory access [source alignment])?].
// The optional words are stored as the memory-access, "alignment",
// "source_memory_access" and "source_alignment" attributes. Any words left
// over are reported as an error.
478 size_t wordIndex = 0;
482 if (wordIndex < words.size()) {
483 auto arg = getValue(words[wordIndex]);
486 return emitError(unknownLoc,
"unknown result <id> : ")
490 operands.push_back(arg);
494 if (wordIndex < words.size()) {
495 auto arg = getValue(words[wordIndex]);
498 return emitError(unknownLoc,
"unknown result <id> : ")
502 operands.push_back(arg);
506 bool isAlignedAttr =
false;
// Optional target memory access mask.
508 if (wordIndex < words.size()) {
509 auto attrValue = words[wordIndex++];
510 auto attr = opBuilder.getAttr<spirv::MemoryAccessAttr>(
511 static_cast<spirv::MemoryAccess
>(attrValue));
512 attributes.push_back(
513 opBuilder.getNamedAttr(attributeName<MemoryAccess>(), attr));
// NOTE(review): SPIR-V Memory Operands are a bitmask (Aligned = 0x2), so an
// equality test misses combined masks such as Volatile|Aligned (0x3). The
// alignment literal would then be misread as the source memory access.
// Consider `(attrValue & 2) != 0` — confirm against the serializer.
514 isAlignedAttr = (attrValue == 2);
517 if (isAlignedAttr && wordIndex < words.size()) {
518 attributes.push_back(opBuilder.getNamedAttr(
519 "alignment", opBuilder.getI32IntegerAttr(words[wordIndex++])));
// Optional source memory access mask.
522 if (wordIndex < words.size()) {
523 auto attrValue = words[wordIndex++];
524 auto attr = opBuilder.getAttr<spirv::MemoryAccessAttr>(
525 static_cast<spirv::MemoryAccess
>(attrValue));
526 attributes.push_back(opBuilder.getNamedAttr(
"source_memory_access", attr));
// NOTE(review): unlike the target alignment above, this reads a source
// alignment whenever a word remains, without checking the Aligned bit of
// the source memory access — confirm this is intended.
529 if (wordIndex < words.size()) {
530 attributes.push_back(opBuilder.getNamedAttr(
531 "source_alignment", opBuilder.getI32IntegerAttr(words[wordIndex++])));
534 if (wordIndex != words.size()) {
536 "found more operands than expected when deserializing "
537 "spirv::CopyMemoryOp, only ")
538 << wordIndex <<
" of " << words.size() <<
" processed";
541 Location loc = createFileLineColLoc(opBuilder);
542 spirv::CopyMemoryOp::create(opBuilder, loc, resultTypes, operands,
// Deserializes OpGenericCastToPtrExplicit, which must have exactly 4 words:
// [result type <id>, result <id>, pointer <id>, word 3].
// Word 3 is not read in the visible lines; presumably it is the target
// storage class — confirm. words[0] must be a known type and words[2] a
// known value. The op is created at createFileLineColLoc(opBuilder).
549 LogicalResult Deserializer::processOp<spirv::GenericCastToPtrExplicitOp>(
551 if (words.size() != 4) {
553 "expected 4 words in GenericCastToPtrExplicitOp"
559 uint32_t valueID = 0;
563 return emitError(unknownLoc,
"unknown type result <id> : ") << words[0];
564 resultTypes.push_back(type);
568 auto arg = getValue(words[2]);
570 return emitError(unknownLoc,
"unknown result <id> : ") << words[2];
571 operands.push_back(arg);
573 Location loc = createFileLineColLoc(opBuilder);
574 Operation *op = spirv::GenericCastToPtrExplicitOp::create(
575 opBuilder, loc, resultTypes, operands);
582 #define GET_DESERIALIZATION_FNS
583 #include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc"
static spirv::Opcode extractOpcode(uint32_t word)
Extracts the opcode from the given first word of a SPIR-V instruction.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
StringRef decodeStringLiteral(ArrayRef< uint32_t > words, unsigned &wordIndex)
Decodes a string literal in words starting at wordIndex.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This represents an operation in an abstracted form, suitable for use with the builder APIs.