20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/Support/Debug.h"
27 #define DEBUG_TYPE "spirv-deserialization"
// Keeps only the low 16 bits of `word` and reinterprets them as a
// spirv::Opcode.
35 return static_cast<spirv::Opcode
>(word & 0xffff);
/// Resolves the result <id> `id` to a Value. The lookups below are tried in
/// the order written; the first one that succeeds builds the corresponding op
/// through opBuilder (always with unknownLoc) and returns from the function.
/// Ids that match none of them are looked up in valueMap.
42 Value spirv::Deserializer::getValue(uint32_t
id) {
// Constant: create a spirv::ConstantOp from the recorded constant info.
43 if (
auto constInfo = getConstant(
id)) {
45 return spirv::ConstantOp::create(opBuilder, unknownLoc, constInfo->second,
// Replicated composite constant: the info is a pair<Attribute, Type>, and the
// Type (second) is passed to create() ahead of the Attribute (first).
48 if (std::optional<std::pair<Attribute, Type>> constCompositeReplicateInfo =
49 getConstantCompositeReplicate(
id)) {
50 return spirv::EXTConstantCompositeReplicateOp::create(
51 opBuilder, unknownLoc, constCompositeReplicateInfo->second,
52 constCompositeReplicateInfo->first);
// Global variable: create a spirv::AddressOfOp using the variable's type and
// return that op's getPointer() value.
54 if (
auto varOp = getGlobalVariable(
id)) {
56 spirv::AddressOfOp::create(opBuilder, unknownLoc, varOp.getType(),
58 return addressOfOp.getPointer();
// Spec constant: create a spirv::ReferenceOfOp typed after the spec
// constant's default value.
60 if (
auto constOp = getSpecConstant(
id)) {
61 auto referenceOfOp = spirv::ReferenceOfOp::create(
62 opBuilder, unknownLoc, constOp.getDefaultValue().getType(),
64 return referenceOfOp.getReference();
// Spec constant composite: same ReferenceOfOp pattern, typed by the
// composite op's own type.
66 if (SpecConstantCompositeOp specConstCompositeOp =
67 getSpecConstantComposite(
id)) {
68 auto referenceOfOp = spirv::ReferenceOfOp::create(
69 opBuilder, unknownLoc, specConstCompositeOp.getType(),
71 return referenceOfOp.getReference();
// Replicated spec constant composite: again a ReferenceOfOp typed by the op.
73 if (
auto specConstCompositeReplicateOp =
74 getSpecConstantCompositeReplicate(
id)) {
75 auto referenceOfOp = spirv::ReferenceOfOp::create(
76 opBuilder, unknownLoc, specConstCompositeReplicateOp.getType(),
78 return referenceOfOp.getReference();
// Spec constant operation: delegated to materializeSpecConstantOperation with
// the recorded opcode, result type <id> and enclosed operands.
80 if (
auto specConstOperationInfo = getSpecConstantOperation(
id)) {
81 return materializeSpecConstantOperation(
82 id, specConstOperationInfo->enclodesOpcode,
83 specConstOperationInfo->resultTypeID,
84 specConstOperationInfo->enclosedOpOperands);
// Undef: create a spirv::UndefOp of the type returned by getUndefType.
86 if (
auto undef = getUndefType(
id)) {
87 return spirv::UndefOp::create(opBuilder, unknownLoc, undef);
// ARM graph constant: create a spirv::GraphConstantARMOp from the recorded
// result type and graph constant id.
89 if (std::optional<spirv::GraphConstantARMOpMaterializationInfo>
90 graphConstantARMInfo = getGraphConstantARM(
id)) {
91 IntegerAttr graphConstantID = graphConstantARMInfo->graphConstantID;
92 Type resultType = graphConstantARMInfo->resultType;
93 return spirv::GraphConstantARMOp::create(opBuilder, unknownLoc, resultType,
// None of the cases above matched: fall back to valueMap.
// NOTE(review): presumably a null Value comes back for an unknown id --
// confirm against valueMap's declared type.
96 return valueMap.lookup(
id);
/// Cuts the instruction that starts at curOffset out of `binary`: the words
/// following the instruction's first word are stored in `operands`, and
/// curOffset is advanced to the start of the next instruction. Errors are
/// reported through emitError at unknownLoc.
99 LogicalResult spirv::Deserializer::sliceInstruction(
101 std::optional<spirv::Opcode> expectedOpcode) {
102 auto binarySize = binary.size();
// Nothing left to read. The diagnostic names the expected opcode when the
// caller supplied one.
103 if (curOffset >= binarySize) {
104 return emitError(unknownLoc,
"expected ")
105 << (expectedOpcode ? spirv::stringifyOpcode(*expectedOpcode)
// The word count is taken from the upper 16 bits of the first word.
113 uint32_t wordCount = binary[curOffset] >> 16;
116 return emitError(unknownLoc,
"word count cannot be zero");
// The instruction must fit entirely inside the remaining binary.
118 uint32_t nextOffset = curOffset + wordCount;
119 if (nextOffset > binarySize)
120 return emitError(unknownLoc,
"insufficient words for the last instruction");
// Skip the first word; the other wordCount - 1 words are the operands.
123 operands = binary.slice(curOffset + 1, wordCount - 1);
124 curOffset = nextOffset;
/// Dispatches one instruction, identified by `opcode`, to the matching
/// process* handler and returns that handler's result. The final return
/// forwards `opcode` and `operands` to dispatchToAutogenDeserialization.
128 LogicalResult spirv::Deserializer::processInstruction(
130 LLVM_DEBUG(logger.startLine() <<
"[inst] processing instruction "
131 << spirv::stringifyOpcode(opcode) <<
"\n");
136 case spirv::Opcode::OpCapability:
137 return processCapability(operands);
138 case spirv::Opcode::OpExtension:
139 return processExtension(operands);
140 case spirv::Opcode::OpExtInst:
141 return processExtInst(operands);
142 case spirv::Opcode::OpExtInstImport:
143 return processExtInstImport(operands);
144 case spirv::Opcode::OpMemberName:
145 return processMemberName(operands);
146 case spirv::Opcode::OpMemoryModel:
147 return processMemoryModel(operands);
// OpEntryPoint and OpExecutionMode share one case body: while
// deferInstructions is set, the (opcode, operands) pair is queued in
// deferredInstructions.
148 case spirv::Opcode::OpEntryPoint:
149 case spirv::Opcode::OpExecutionMode:
150 if (deferInstructions) {
151 deferredInstructions.emplace_back(opcode, operands);
// An OpVariable is handled as a global variable when the parent op of the
// builder's current block is a spirv::ModuleOp.
155 case spirv::Opcode::OpVariable:
156 if (isa<spirv::ModuleOp>(opBuilder.getBlock()->getParentOp())) {
157 return processGlobalVariable(operands);
160 case spirv::Opcode::OpLine:
161 return processDebugLine(operands);
162 case spirv::Opcode::OpNoLine:
165 case spirv::Opcode::OpName:
166 return processName(operands);
167 case spirv::Opcode::OpString:
168 return processDebugString(operands);
169 case spirv::Opcode::OpModuleProcessed:
170 case spirv::Opcode::OpSource:
171 case spirv::Opcode::OpSourceContinued:
172 case spirv::Opcode::OpSourceExtension:
// Every OpType* opcode listed here goes through processType, which also
// receives the opcode.
176 case spirv::Opcode::OpTypeVoid:
177 case spirv::Opcode::OpTypeBool:
178 case spirv::Opcode::OpTypeInt:
179 case spirv::Opcode::OpTypeFloat:
180 case spirv::Opcode::OpTypeVector:
181 case spirv::Opcode::OpTypeMatrix:
182 case spirv::Opcode::OpTypeArray:
183 case spirv::Opcode::OpTypeFunction:
184 case spirv::Opcode::OpTypeImage:
185 case spirv::Opcode::OpTypeSampledImage:
186 case spirv::Opcode::OpTypeRuntimeArray:
187 case spirv::Opcode::OpTypeStruct:
188 case spirv::Opcode::OpTypePointer:
189 case spirv::Opcode::OpTypeTensorARM:
190 case spirv::Opcode::OpTypeGraphARM:
191 case spirv::Opcode::OpTypeCooperativeMatrixKHR:
192 return processType(opcode, operands);
193 case spirv::Opcode::OpTypeForwardPointer:
194 return processTypeForwardPointer(operands);
// OpConstant and OpSpecConstant share processConstant; the second argument is
// false for the plain form and true for the spec form.
195 case spirv::Opcode::OpConstant:
196 return processConstant(operands,
false);
197 case spirv::Opcode::OpSpecConstant:
198 return processConstant(operands,
true);
199 case spirv::Opcode::OpConstantComposite:
200 return processConstantComposite(operands);
201 case spirv::Opcode::OpConstantCompositeReplicateEXT:
202 return processConstantCompositeReplicateEXT(operands);
203 case spirv::Opcode::OpSpecConstantComposite:
204 return processSpecConstantComposite(operands);
205 case spirv::Opcode::OpSpecConstantCompositeReplicateEXT:
206 return processSpecConstantCompositeReplicateEXT(operands);
207 case spirv::Opcode::OpSpecConstantOp:
208 return processSpecConstantOperation(operands);
// The four boolean constant opcodes share processConstantBool. The first
// argument is the boolean value (true for the *True opcodes, false for the
// *False ones); the last is true only for the OpSpecConstant* variants.
209 case spirv::Opcode::OpConstantTrue:
210 return processConstantBool(
true, operands,
false);
211 case spirv::Opcode::OpSpecConstantTrue:
212 return processConstantBool(
true, operands,
true);
213 case spirv::Opcode::OpConstantFalse:
214 return processConstantBool(
false, operands,
false);
215 case spirv::Opcode::OpSpecConstantFalse:
216 return processConstantBool(
false, operands,
true);
217 case spirv::Opcode::OpConstantNull:
218 return processConstantNull(operands);
219 case spirv::Opcode::OpGraphConstantARM:
220 return processGraphConstantARM(operands);
221 case spirv::Opcode::OpDecorate:
222 return processDecoration(operands);
223 case spirv::Opcode::OpMemberDecorate:
224 return processMemberDecoration(operands);
225 case spirv::Opcode::OpFunction:
226 return processFunction(operands);
// OpGraphEntryPointARM is queued in deferredInstructions while
// deferInstructions is set, like OpEntryPoint above.
227 case spirv::Opcode::OpGraphEntryPointARM:
228 if (deferInstructions) {
229 deferredInstructions.emplace_back(opcode, operands);
232 return processGraphEntryPointARM(operands);
233 case spirv::Opcode::OpGraphARM:
234 return processGraphARM(operands);
235 case spirv::Opcode::OpGraphSetOutputARM:
236 return processOpGraphSetOutputARM(operands);
237 case spirv::Opcode::OpGraphEndARM:
238 return processGraphEndARM(operands);
239 case spirv::Opcode::OpLabel:
240 return processLabel(operands);
241 case spirv::Opcode::OpBranch:
242 return processBranch(operands);
243 case spirv::Opcode::OpBranchConditional:
244 return processBranchConditional(operands);
245 case spirv::Opcode::OpSelectionMerge:
246 return processSelectionMerge(operands);
247 case spirv::Opcode::OpLoopMerge:
248 return processLoopMerge(operands);
249 case spirv::Opcode::OpPhi:
250 return processPhi(operands);
251 case spirv::Opcode::OpUndef:
252 return processUndef(operands);
// Anything not returned above is handed to the generated dispatcher.
256 return dispatchToAutogenDeserialization(opcode, operands);
/// Generic word-driven deserialization. In order, the code below reads a
/// result type <id> (resolved with getType) and a result <id> from `words`,
/// resolves up to numOperands operand <id>s through getValue, checks that the
/// word count matches exactly, appends any decorations recorded for the
/// result <id>, and adds operands, result types and attributes to opState.
259 LogicalResult spirv::Deserializer::processOpWithoutGrammarAttr(
261 unsigned numOperands) {
263 uint32_t valueID = 0;
265 size_t wordIndex = 0;
// Result type <id>: must be present in `words` and known to getType.
267 if (wordIndex >= words.size())
269 "expected result type <id> while deserializing for ")
273 auto type =
getType(words[wordIndex]);
275 return emitError(unknownLoc,
"unknown type result <id>: ")
277 resultTypes.push_back(type);
// Result <id>: remembered in valueID for the decoration lookup further down.
281 if (wordIndex >= words.size())
283 "expected result <id> while deserializing for ")
285 valueID = words[wordIndex];
// Operand <id>s: consume one word per operand until either numOperands
// operands have been resolved or the words run out.
293 size_t operandIndex = 0;
294 for (; operandIndex < numOperands && wordIndex < words.size();
295 ++operandIndex, ++wordIndex) {
296 auto arg = getValue(words[wordIndex]);
298 return emitError(unknownLoc,
"unknown result <id>: ") << words[wordIndex];
299 operands.push_back(arg);
// The loop stopped early: `words` ran out before numOperands were read.
301 if (operandIndex != numOperands) {
304 "found less operands than expected when deserializing for ")
305 << opName <<
"; only " << operandIndex <<
" of " << numOperands
// Words are left over after all expected operands were read.
308 if (wordIndex != words.size()) {
311 "found more operands than expected when deserializing for ")
312 << opName <<
"; only " << wordIndex <<
" of " << words.size()
// Carry over any decoration attributes recorded for the result <id>.
317 if (decorations.count(valueID)) {
318 auto attrs = decorations[valueID].getAttrs();
319 attributes.append(attrs.begin(), attrs.end());
// Assemble the operation state from the pieces collected above.
323 Location loc = createFileLineColLoc(opBuilder);
325 opState.addOperands(operands);
327 opState.addTypes(resultTypes);
328 opState.addAttributes(attributes);
// OpUndef must carry exactly two words: operands[0] is resolved with getType
// and operands[1] is the key under which that type is stored in undefMap.
340 if (operands.size() != 2) {
341 return emitError(unknownLoc,
"OpUndef instruction must have two operands");
343 auto type =
getType(operands[0]);
345 return emitError(unknownLoc,
"unknown type <id> with OpUndef instruction");
// Only the type is recorded here.
// NOTE(review): presumably getUndefType() reads undefMap when the value is
// materialized later -- confirm.
347 undefMap[operands[1]] = type;
// Word layout relied on below, as spelled out by the diagnostic: result type
// <id>, result <id>, set <id>, instruction opcode, then any further words.
352 if (operands.size() < 4) {
354 "OpExtInst must have at least 4 operands, result type "
355 "<id>, result <id>, set <id> and instruction opcode");
// The set <id> (operands[2]) must already be present in extendedInstSets.
357 if (!extendedInstSets.count(operands[2])) {
358 return emitError(unknownLoc,
"undefined set <id> in OpExtInst");
// Keep words 0-1 and words 4 onward. The set <id> and the instruction opcode
// are dropped from the list because they are passed to the dispatcher as
// separate arguments.
361 slicedOperands.append(operands.begin(), std::next(operands.begin(), 2));
362 slicedOperands.append(std::next(operands.begin(), 4), operands.end());
363 return dispatchToExtensionSetAutogenDeserialization(
364 extendedInstSets[operands[2]], operands[3], slicedOperands);
// `words` is consumed front to back: first the execution model, then the
// function <id>; every word remaining at the loop below is an interface <id>.
373 unsigned wordIndex = 0;
374 if (wordIndex >= words.size()) {
376 "missing Execution Model specification in OpEntryPoint");
379 context,
static_cast<spirv::ExecutionModel
>(words[wordIndex++]));
380 if (wordIndex >= words.size()) {
381 return emitError(unknownLoc,
"missing <id> in OpEntryPoint");
// The entry point must refer to a function known to getFunction.
384 auto fnID = words[wordIndex++];
388 auto parsedFunc = getFunction(fnID);
390 return emitError(unknownLoc,
"no function matching <id> ") << fnID;
// When the names differ, a mismatch is reported unless the function's current
// name starts with "spirv_fn_"; setName(fnName) then gives the function the
// entry point's name.
// NOTE(review): "spirv_fn_" is presumably a generated placeholder prefix --
// confirm.
392 if (parsedFunc.getName() != fnName) {
396 if (!parsedFunc.getName().starts_with(
"spirv_fn_"))
398 "function name mismatch between OpEntryPoint "
399 "and OpFunction with <id> ")
400 << fnID <<
": " << fnName <<
" vs. " << parsedFunc.getName();
401 parsedFunc.setName(fnName);
// Each remaining word must resolve through getGlobalVariable.
404 while (wordIndex < words.size()) {
405 auto arg = getGlobalVariable(words[wordIndex]);
407 return emitError(unknownLoc,
"undefined result <id> ")
408 << words[wordIndex] <<
" while decoding OpEntryPoint";
// Create the spirv::EntryPointOp; `interface` is wrapped in an array attribute.
413 spirv::EntryPointOp::create(
414 opBuilder, unknownLoc, execModel,
416 opBuilder.getArrayAttr(interface));
// `words` is consumed front to back: function <id>, execution mode, then any
// number of extra words.
423 unsigned wordIndex = 0;
424 if (wordIndex >= words.size()) {
426 "missing function result <id> in OpExecutionMode");
// The first word must name a function known to getFunction.
429 auto fnID = words[wordIndex++];
430 auto fn = getFunction(fnID);
432 return emitError(unknownLoc,
"no function matching <id> ") << fnID;
// The second word is cast to spirv::ExecutionMode.
435 if (wordIndex >= words.size()) {
436 return emitError(unknownLoc,
"missing Execution Mode in OpExecutionMode");
439 context,
static_cast<spirv::ExecutionMode
>(words[wordIndex++]));
// Every remaining word becomes one i32 integer attribute.
443 while (wordIndex < words.size()) {
444 attrListElems.push_back(opBuilder.getI32IntegerAttr(words[wordIndex++]));
// The collected attributes are bundled into a single array attribute.
446 auto values = opBuilder.getArrayAttr(attrListElems);
447 spirv::ExecutionModeOp::create(
448 opBuilder, unknownLoc,
// Word layout used below: operands[1] is the result <id>, operands[2] the
// callee function <id>, and words from index 3 on are the call arguments.
457 if (operands.size() < 3) {
459 "OpFunctionCall must have at least 3 operands");
464 return emitError(unknownLoc,
"undefined result type from <id> ")
// A void result type is replaced by a null type before the op is created.
469 if (isVoidType(resultType))
470 resultType =
nullptr;
472 auto resultID = operands[1];
473 auto functionID = operands[2];
// The callee is referenced by the symbol name looked up for its <id>.
475 auto functionName = getFunctionSymbol(functionID);
// Resolve every argument <id> through getValue; an unresolved <id> is an
// error.
478 for (
auto operand : llvm::drop_begin(operands, 3)) {
479 auto value = getValue(operand);
481 return emitError(unknownLoc,
"unknown <id> ")
482 << operand <<
" used by OpFunctionCall";
484 arguments.push_back(value);
487 auto opFunctionCall = spirv::FunctionCallOp::create(
488 opBuilder, unknownLoc, resultType,
// Record the call's first result under the result <id>.
492 valueMap[resultID] = opFunctionCall.getResult(0);
// `words` is consumed front to back: two <id>s resolved through getValue, then
// optional memory-access words, each followed by an optional alignment word.
500 size_t wordIndex = 0;
// First <id> operand.
504 if (wordIndex < words.size()) {
505 auto arg = getValue(words[wordIndex]);
508 return emitError(unknownLoc,
"unknown result <id> : ")
512 operands.push_back(arg);
// Second <id> operand.
516 if (wordIndex < words.size()) {
517 auto arg = getValue(words[wordIndex]);
520 return emitError(unknownLoc,
"unknown result <id> : ")
524 operands.push_back(arg);
528 bool isAlignedAttr =
false;
// Optional first memory-access word, stored under the name returned by
// attributeName<MemoryAccess>().
530 if (wordIndex < words.size()) {
531 auto attrValue = words[wordIndex++];
532 auto attr = opBuilder.getAttr<spirv::MemoryAccessAttr>(
533 static_cast<spirv::MemoryAccess
>(attrValue));
534 attributes.push_back(
535 opBuilder.getNamedAttr(attributeName<MemoryAccess>(), attr));
// NOTE(review): this is an exact comparison with 2. If MemoryAccess is a
// bitmask, a combined value such as 3 leaves isAlignedAttr false and the
// "alignment" word is not read below -- confirm this is intended.
536 isAlignedAttr = (attrValue == 2);
// The "alignment" word is read only when the access value was exactly 2.
539 if (isAlignedAttr && wordIndex < words.size()) {
540 attributes.push_back(opBuilder.getNamedAttr(
541 "alignment", opBuilder.getI32IntegerAttr(words[wordIndex++])));
// Optional second memory-access word, stored as "source_memory_access".
544 if (wordIndex < words.size()) {
545 auto attrValue = words[wordIndex++];
546 auto attr = opBuilder.getAttr<spirv::MemoryAccessAttr>(
547 static_cast<spirv::MemoryAccess
>(attrValue));
548 attributes.push_back(opBuilder.getNamedAttr(
"source_memory_access", attr));
// NOTE(review): unlike "alignment" above, "source_alignment" is read whenever
// a word remains, with no check of the source access value -- confirm.
551 if (wordIndex < words.size()) {
552 attributes.push_back(opBuilder.getNamedAttr(
553 "source_alignment", opBuilder.getI32IntegerAttr(words[wordIndex++])));
// Any word still unread at this point is an error.
556 if (wordIndex != words.size()) {
558 "found more operands than expected when deserializing "
559 "spirv::CopyMemoryOp, only ")
560 << wordIndex <<
" of " << words.size() <<
" processed";
563 Location loc = createFileLineColLoc(opBuilder);
564 spirv::CopyMemoryOp::create(opBuilder, loc, resultTypes, operands,
/// Deserializes a spirv::GenericCastToPtrExplicitOp from exactly four words.
/// words[0] is reported as the result type <id> and words[2] is the operand
/// <id> resolved through getValue.
/// NOTE(review): words[1] and words[3] are presumably the result <id> and a
/// storage class -- confirm against the full definition.
571 LogicalResult Deserializer::processOp<spirv::GenericCastToPtrExplicitOp>(
573 if (words.size() != 4) {
575 "expected 4 words in GenericCastToPtrExplicitOp"
581 uint32_t valueID = 0;
585 return emitError(unknownLoc,
"unknown type result <id> : ") << words[0];
586 resultTypes.push_back(type);
// The single operand comes from words[2]; an unresolved <id> is an error.
590 auto arg = getValue(words[2]);
592 return emitError(unknownLoc,
"unknown result <id> : ") << words[2];
593 operands.push_back(arg);
// Create the op from the collected result types and operands.
595 Location loc = createFileLineColLoc(opBuilder);
596 Operation *op = spirv::GenericCastToPtrExplicitOp::create(
597 opBuilder, loc, resultTypes, operands);
604 #define GET_DESERIALIZATION_FNS
605 #include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc"
static spirv::Opcode extractOpcode(uint32_t word)
Extracts the opcode from the given first word of a SPIR-V instruction.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g. hasTrait<OperandsAreSignlessIntegerLike>().
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
StringRef decodeStringLiteral(ArrayRef< uint32_t > words, unsigned &wordIndex)
Decodes a string literal in words starting at wordIndex.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.
This represents an operation in an abstracted form, suitable for use with the builder APIs.