20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/Support/Debug.h"
27 #define DEBUG_TYPE "spirv-deserialization"
// Extracts the opcode from the first word of a SPIR-V instruction: the opcode
// is held in the low 16 bits of that word (hence the 0xffff mask).
35 return static_cast<spirv::Opcode
>(word & 0xffff);
/// Resolves the SPIR-V result <id> `id` to an MLIR Value.
///
/// Lookup order: constants, global variables, spec constants, spec-constant
/// composites, spec-constant operations, OpUndef types, and finally the plain
/// `valueMap`. For every category except the last, a fresh op is created
/// through `opBuilder` on each call. No value is cached for reuse.
/// NOTE(review): `valueMap.lookup(id)` presumably yields a null Value for
/// unknown ids (DenseMap-style lookup); confirm against valueMap's type.
42 Value spirv::Deserializer::getValue(uint32_t
id) {
// Plain constant: materialize a spirv::ConstantOp at the use site.
43 if (
auto constInfo = getConstant(
id)) {
45 return opBuilder.
create<spirv::ConstantOp>(unknownLoc, constInfo->second,
// Global variable: take its address via spirv::AddressOfOp.
48 if (
auto varOp = getGlobalVariable(
id)) {
49 auto addressOfOp = opBuilder.
create<spirv::AddressOfOp>(
51 return addressOfOp.getPointer();
// Scalar spec constant: reference it via spirv::ReferenceOfOp, typed by its
// default value.
53 if (
auto constOp = getSpecConstant(
id)) {
54 auto referenceOfOp = opBuilder.
create<spirv::ReferenceOfOp>(
55 unknownLoc, constOp.getDefaultValue().getType(),
57 return referenceOfOp.getReference();
// Spec-constant composite: likewise referenced via spirv::ReferenceOfOp.
59 if (
auto constCompositeOp = getSpecConstantComposite(
id)) {
60 auto referenceOfOp = opBuilder.
create<spirv::ReferenceOfOp>(
61 unknownLoc, constCompositeOp.getType(),
63 return referenceOfOp.getReference();
// Spec-constant operation: rebuild it from the recorded opcode, result type
// <id> and operands.
// NOTE(review): `enclodesOpcode` looks like a misspelling of "enclosed", but
// it must match the field declaration elsewhere; rename both together.
65 if (
auto specConstOperationInfo = getSpecConstantOperation(
id)) {
66 return materializeSpecConstantOperation(
67 id, specConstOperationInfo->enclodesOpcode,
68 specConstOperationInfo->resultTypeID,
69 specConstOperationInfo->enclosedOpOperands);
// OpUndef result: materialize a spirv::UndefOp of the recorded type.
71 if (
auto undef = getUndefType(
id)) {
72 return opBuilder.
create<spirv::UndefOp>(unknownLoc, undef);
// Everything else: previously recorded SSA values.
74 return valueMap.lookup(
id);
/// Slices the next instruction out of `binary`, starting at `curOffset`.
///
/// On success, `operands` holds the instruction's words after the first, and
/// `curOffset` is advanced past the instruction. A diagnostic is emitted when:
///  - the binary is already exhausted; the message names `expectedOpcode`
///    when one is given;
///  - the instruction's word count is zero;
///  - the instruction would extend past the end of the binary.
77 LogicalResult spirv::Deserializer::sliceInstruction(
79 std::optional<spirv::Opcode> expectedOpcode) {
80 auto binarySize = binary.size();
81 if (curOffset >= binarySize) {
83 << (expectedOpcode ? spirv::stringifyOpcode(*expectedOpcode)
// The instruction's total word count is in the upper 16 bits of its first word.
91 uint32_t wordCount = binary[curOffset] >> 16;
// A zero count would leave curOffset unchanged (nextOffset == curOffset below),
// so it is rejected.
94 return emitError(unknownLoc,
"word count cannot be zero");
96 uint32_t nextOffset = curOffset + wordCount;
97 if (nextOffset > binarySize)
98 return emitError(unknownLoc,
"insufficient words for the last instruction");
// Operands are every word after the first; then step to the next instruction.
101 operands = binary.slice(curOffset + 1, wordCount - 1);
102 curOffset = nextOffset;
/// Dispatches one already-sliced instruction to its handler, keyed on `opcode`.
///
/// Dedicated handlers cover:
///  - module-level instructions: capabilities, extensions, memory model,
///    entry points and execution modes;
///  - debug info, types and constants;
///  - decorations, functions, blocks, branches, merges, OpPhi and OpUndef.
/// Any opcode not matched here goes to the generated
/// `dispatchToAutogenDeserialization`.
106 LogicalResult spirv::Deserializer::processInstruction(
// Per-instruction trace, emitted only in debug builds under DEBUG_TYPE.
108 LLVM_DEBUG(logger.startLine() <<
"[inst] processing instruction "
109 << spirv::stringifyOpcode(opcode) <<
"\n");
114 case spirv::Opcode::OpCapability:
115 return processCapability(operands);
116 case spirv::Opcode::OpExtension:
117 return processExtension(operands);
118 case spirv::Opcode::OpExtInst:
119 return processExtInst(operands);
120 case spirv::Opcode::OpExtInstImport:
121 return processExtInstImport(operands);
122 case spirv::Opcode::OpMemberName:
123 return processMemberName(operands);
124 case spirv::Opcode::OpMemoryModel:
125 return processMemoryModel(operands);
// When `deferInstructions` is set, OpEntryPoint and OpExecutionMode are queued
// in `deferredInstructions` as (opcode, operands) instead of being handled
// now. The non-deferred path is on lines not shown here.
126 case spirv::Opcode::OpEntryPoint:
127 case spirv::Opcode::OpExecutionMode:
128 if (deferInstructions) {
129 deferredInstructions.emplace_back(opcode, operands);
// Only an OpVariable whose insertion block belongs directly to a
// spirv::ModuleOp is a global variable. The other path (presumably
// function-local variables) is on lines not shown here.
133 case spirv::Opcode::OpVariable:
134 if (isa<spirv::ModuleOp>(opBuilder.getBlock()->getParentOp())) {
135 return processGlobalVariable(operands);
// Debug-info instructions. The case bodies for OpNoLine and the source-level
// group below are not visible in this listing.
138 case spirv::Opcode::OpLine:
139 return processDebugLine(operands);
140 case spirv::Opcode::OpNoLine:
143 case spirv::Opcode::OpName:
144 return processName(operands);
145 case spirv::Opcode::OpString:
146 return processDebugString(operands);
147 case spirv::Opcode::OpModuleProcessed:
148 case spirv::Opcode::OpSource:
149 case spirv::Opcode::OpSourceContinued:
150 case spirv::Opcode::OpSourceExtension:
// All of these type opcodes share processType(), which is passed the opcode.
154 case spirv::Opcode::OpTypeVoid:
155 case spirv::Opcode::OpTypeBool:
156 case spirv::Opcode::OpTypeInt:
157 case spirv::Opcode::OpTypeFloat:
158 case spirv::Opcode::OpTypeVector:
159 case spirv::Opcode::OpTypeMatrix:
160 case spirv::Opcode::OpTypeArray:
161 case spirv::Opcode::OpTypeFunction:
162 case spirv::Opcode::OpTypeImage:
163 case spirv::Opcode::OpTypeSampledImage:
164 case spirv::Opcode::OpTypeRuntimeArray:
165 case spirv::Opcode::OpTypeStruct:
166 case spirv::Opcode::OpTypePointer:
167 case spirv::Opcode::OpTypeCooperativeMatrixKHR:
168 return processType(opcode, operands);
169 case spirv::Opcode::OpTypeForwardPointer:
170 return processTypeForwardPointer(operands);
// In processConstant and processConstantBool, the trailing bool selects a spec
// constant (true) or a regular constant (false). processConstantBool's first
// argument is the boolean value itself.
171 case spirv::Opcode::OpConstant:
172 return processConstant(operands,
false);
173 case spirv::Opcode::OpSpecConstant:
174 return processConstant(operands,
true);
175 case spirv::Opcode::OpConstantComposite:
176 return processConstantComposite(operands);
177 case spirv::Opcode::OpSpecConstantComposite:
178 return processSpecConstantComposite(operands);
179 case spirv::Opcode::OpSpecConstantOp:
180 return processSpecConstantOperation(operands);
181 case spirv::Opcode::OpConstantTrue:
182 return processConstantBool(
true, operands,
false);
183 case spirv::Opcode::OpSpecConstantTrue:
184 return processConstantBool(
true, operands,
true);
185 case spirv::Opcode::OpConstantFalse:
186 return processConstantBool(
false, operands,
false);
187 case spirv::Opcode::OpSpecConstantFalse:
188 return processConstantBool(
false, operands,
true);
189 case spirv::Opcode::OpConstantNull:
190 return processConstantNull(operands);
191 case spirv::Opcode::OpDecorate:
192 return processDecoration(operands);
193 case spirv::Opcode::OpMemberDecorate:
194 return processMemberDecoration(operands);
// Function and control-flow structure.
195 case spirv::Opcode::OpFunction:
196 return processFunction(operands);
197 case spirv::Opcode::OpLabel:
198 return processLabel(operands);
199 case spirv::Opcode::OpBranch:
200 return processBranch(operands);
201 case spirv::Opcode::OpBranchConditional:
202 return processBranchConditional(operands);
203 case spirv::Opcode::OpSelectionMerge:
204 return processSelectionMerge(operands);
205 case spirv::Opcode::OpLoopMerge:
206 return processLoopMerge(operands);
207 case spirv::Opcode::OpPhi:
208 return processPhi(operands);
209 case spirv::Opcode::OpUndef:
210 return processUndef(operands);
// Every other opcode goes to the generated per-op deserializers.
214 return dispatchToAutogenDeserialization(opcode, operands);
/// Generic fallback that builds an operation named `opName` directly from raw
/// instruction words. Visible steps:
///  - read a result type <id> and resolve it with getType();
///  - read a result <id> into `valueID`. Both reads are presumably gated by
///    flags on lines not shown here;
///  - read exactly `numOperands` operand <id>s, resolving each with getValue();
///  - reject word lists that are too short or too long;
///  - copy any decorations recorded for `valueID` into the op's attributes;
///  - fill an OperationState with the operands, result types and attributes.
217 LogicalResult spirv::Deserializer::processOpWithoutGrammarAttr(
219 unsigned numOperands) {
221 uint32_t valueID = 0;
223 size_t wordIndex = 0;
// Result type <id>: must be present and must name a known type.
225 if (wordIndex >= words.size())
227 "expected result type <id> while deserializing for ")
231 auto type =
getType(words[wordIndex]);
233 return emitError(unknownLoc,
"unknown type result <id>: ")
235 resultTypes.push_back(type);
// Result <id>: its value keys the decoration lookup below.
239 if (wordIndex >= words.size())
241 "expected result <id> while deserializing for ")
243 valueID = words[wordIndex];
// Operands: consume up to numOperands words, stopping early if words run out.
251 size_t operandIndex = 0;
252 for (; operandIndex < numOperands && wordIndex < words.size();
253 ++operandIndex, ++wordIndex) {
254 auto arg = getValue(words[wordIndex]);
256 return emitError(unknownLoc,
"unknown result <id>: ") << words[wordIndex];
257 operands.push_back(arg);
// The two diagnostics below count different things. The first reports
// operands consumed versus numOperands. The second reports total words
// consumed (including any result type/<id> words) versus words.size().
259 if (operandIndex != numOperands) {
262 "found less operands than expected when deserializing for ")
263 << opName <<
"; only " << operandIndex <<
" of " << numOperands
266 if (wordIndex != words.size()) {
269 "found more operands than expected when deserializing for ")
270 << opName <<
"; only " << wordIndex <<
" of " << words.size()
// Attach decorations previously recorded for this result <id>.
// NOTE(review): count() followed by operator[] does two lookups; a single
// find() would avoid the second one.
275 if (decorations.count(valueID)) {
276 auto attrs = decorations[valueID].getAttrs();
277 attributes.append(attrs.begin(), attrs.end());
// Location comes from the current debug-line state via createFileLineColLoc.
281 Location loc = createFileLineColLoc(opBuilder);
283 opState.addOperands(operands);
285 opState.addTypes(resultTypes);
286 opState.addAttributes(attributes);
// OpUndef handler. Expects exactly two operands: [0] result type <id> and
// [1] result <id>. No op is created here. The resolved type is stored in
// `undefMap` under the result <id>, and a spirv::UndefOp is materialized at
// each use. Presumably getValue() does this through getUndefType(), which reads
// undefMap; confirm.
298 if (operands.size() != 2) {
299 return emitError(unknownLoc,
"OpUndef instruction must have two operands");
301 auto type =
getType(operands[0]);
303 return emitError(unknownLoc,
"unknown type <id> with OpUndef instruction");
305 undefMap[operands[1]] = type;
// OpExtInst handler. Operand layout, as stated in the diagnostic below:
//   [0] result type <id>, [1] result <id>, [2] extended instruction set <id>,
//   [3] instruction opcode, [4..] the instruction's own operands.
// The set <id> must already be in `extendedInstSets` (presumably populated
// when OpExtInstImport is processed).
310 if (operands.size() < 4) {
312 "OpExtInst must have at least 4 operands, result type "
313 "<id>, result <id>, set <id> and instruction opcode");
315 if (!extendedInstSets.count(operands[2])) {
316 return emitError(unknownLoc,
"undefined set <id> in OpExtInst");
// Rebuild the operand list without [2] (set) and [3] (opcode). Those two are
// passed to the generated extension-set dispatcher as separate arguments.
319 slicedOperands.append(operands.begin(), std::next(operands.begin(), 2));
320 slicedOperands.append(std::next(operands.begin(), 4), operands.end());
321 return dispatchToExtensionSetAutogenDeserialization(
322 extendedInstSets[operands[2]], operands[3], slicedOperands);
// OpEntryPoint deserialization. Words are read in order:
//   <execution model> <function id> <entry-point name> <interface ids...>.
// The name decoding itself is on lines not shown here. It produces `fnName`.
331 unsigned wordIndex = 0;
332 if (wordIndex >= words.size()) {
334 "missing Execution Model specification in OpEntryPoint");
337 context,
static_cast<spirv::ExecutionModel
>(words[wordIndex++]));
338 if (wordIndex >= words.size()) {
339 return emitError(unknownLoc,
"missing <id> in OpEntryPoint");
342 auto fnID = words[wordIndex++];
// The referenced function must already have been deserialized.
346 auto parsedFunc = getFunction(fnID);
348 return emitError(unknownLoc,
"no function matching <id> ") << fnID;
350 if (parsedFunc.getName() != fnName) {
// A name mismatch is tolerated only when the function still has a
// "spirv_fn_"-prefixed name. Such names appear to be auto-generated
// placeholders; confirm. Any other mismatch is an error. In the tolerated
// case, the function is renamed to the entry-point name.
354 if (!parsedFunc.getName().starts_with(
"spirv_fn_"))
356 "function name mismatch between OpEntryPoint "
357 "and OpFunction with <id> ")
358 << fnID <<
": " << fnName <<
" vs. " << parsedFunc.getName();
359 parsedFunc.setName(fnName);
// Each remaining word must name a known global variable. These form the entry
// point's interface list, which is passed to EntryPointOp as an ArrayAttr.
362 while (wordIndex < words.size()) {
363 auto arg = getGlobalVariable(words[wordIndex]);
365 return emitError(unknownLoc,
"undefined result <id> ")
366 << words[wordIndex] <<
" while decoding OpEntryPoint";
371 opBuilder.create<spirv::EntryPointOp>(
373 opBuilder.getArrayAttr(interface));
// OpExecutionMode deserialization. Words are read in order:
//   <function id> <execution mode> <literal operands...>.
// The function must already be known.
380 unsigned wordIndex = 0;
381 if (wordIndex >= words.size()) {
383 "missing function result <id> in OpExecutionMode");
386 auto fnID = words[wordIndex++];
387 auto fn = getFunction(fnID);
389 return emitError(unknownLoc,
"no function matching <id> ") << fnID;
392 if (wordIndex >= words.size()) {
393 return emitError(unknownLoc,
"missing Execution Mode in OpExecutionMode");
396 context,
static_cast<spirv::ExecutionMode
>(words[wordIndex++]));
// Every trailing literal is kept as an i32 IntegerAttr. Together they form the
// ArrayAttr passed to spirv::ExecutionModeOp.
400 while (wordIndex < words.size()) {
401 attrListElems.push_back(opBuilder.getI32IntegerAttr(words[wordIndex++]));
403 auto values = opBuilder.getArrayAttr(attrListElems);
404 opBuilder.create<spirv::ExecutionModeOp>(
// OpFunctionCall deserialization. Operands are:
//   [0] result type <id>, [1] result <id>, [2] callee function <id>,
//   [3..] argument <id>s.
413 if (operands.size() < 3) {
415 "OpFunctionCall must have at least 3 operands");
420 return emitError(unknownLoc,
"undefined result type from <id> ")
// A void return type is represented as a null result type on the call op.
425 if (isVoidType(resultType))
426 resultType =
nullptr;
428 auto resultID = operands[1];
429 auto functionID = operands[2];
// The callee is referenced by symbol, not by value.
431 auto functionName = getFunctionSymbol(functionID);
// Resolve every argument <id>. An unknown <id> is a hard error.
434 for (
auto operand : llvm::drop_begin(operands, 3)) {
435 auto value = getValue(operand);
437 return emitError(unknownLoc,
"unknown <id> ")
438 << operand <<
" used by OpFunctionCall";
440 arguments.push_back(value);
443 auto opFunctionCall = opBuilder.create<spirv::FunctionCallOp>(
444 unknownLoc, resultType,
// Record the call result for later uses of the result <id>. A guard on a line
// not shown here presumably skips this for void calls; confirm.
448 valueMap[resultID] = opFunctionCall.getResult(0);
// OpCopyMemory deserialization. Words are read in order:
//   <target id> <source id> [memory access] [alignment]
//   [source memory access] [source alignment]
// Each optional field is consumed only if words remain. The first alignment
// is additionally consumed only when the first memory-access word equals 2.
456 size_t wordIndex = 0;
// First pointer operand (the target, per the SPIR-V OpCopyMemory layout).
460 if (wordIndex < words.size()) {
461 auto arg = getValue(words[wordIndex]);
464 return emitError(unknownLoc,
"unknown result <id> : ")
468 operands.push_back(arg);
// Second pointer operand (the source).
472 if (wordIndex < words.size()) {
473 auto arg = getValue(words[wordIndex]);
476 return emitError(unknownLoc,
"unknown result <id> : ")
480 operands.push_back(arg);
484 bool isAlignedAttr =
false;
// First memory-access mask, stored under attributeName<MemoryAccess>().
// NOTE(review): `attrValue == 2` tests the whole mask for equality with the
// Aligned bit (0x2 in the SPIR-V spec). A combined mask such as
// Volatile|Aligned (0x3) would leave its alignment literal unconsumed, and
// that literal would then be misread as "source_memory_access".
// `(attrValue & 2) != 0` may be what is intended; confirm.
486 if (wordIndex < words.size()) {
487 auto attrValue = words[wordIndex++];
488 auto attr = opBuilder.getAttr<spirv::MemoryAccessAttr>(
489 static_cast<spirv::MemoryAccess
>(attrValue));
490 attributes.push_back(
491 opBuilder.getNamedAttr(attributeName<MemoryAccess>(), attr));
492 isAlignedAttr = (attrValue == 2);
495 if (isAlignedAttr && wordIndex < words.size()) {
496 attributes.push_back(opBuilder.getNamedAttr(
497 "alignment", opBuilder.getI32IntegerAttr(words[wordIndex++])));
// Optional second memory-access mask, stored as "source_memory_access".
500 if (wordIndex < words.size()) {
501 auto attrValue = words[wordIndex++];
502 auto attr = opBuilder.getAttr<spirv::MemoryAccessAttr>(
503 static_cast<spirv::MemoryAccess
>(attrValue));
504 attributes.push_back(opBuilder.getNamedAttr(
"source_memory_access", attr));
// NOTE(review): the source alignment is taken whenever a word remains. Unlike
// the first alignment, it is not gated on the source mask having the Aligned
// bit.
507 if (wordIndex < words.size()) {
508 attributes.push_back(opBuilder.getNamedAttr(
509 "source_alignment", opBuilder.getI32IntegerAttr(words[wordIndex++])));
// Any words left over are an error.
512 if (wordIndex != words.size()) {
514 "found more operands than expected when deserializing "
515 "spirv::CopyMemoryOp, only ")
516 << wordIndex <<
" of " << words.size() <<
" processed";
519 Location loc = createFileLineColLoc(opBuilder);
520 opBuilder.create<spirv::CopyMemoryOp>(loc, resultTypes, operands, attributes);
/// OpGenericCastToPtrExplicit deserialization. The instruction must be exactly
/// four words. Visibly, words[0] is resolved as the result type and words[2]
/// as the pointer operand. The handling of words[1] (presumably the result
/// <id>) and words[3] (presumably the target storage class) is on lines not
/// shown here.
526 LogicalResult Deserializer::processOp<spirv::GenericCastToPtrExplicitOp>(
528 if (words.size() != 4) {
530 "expected 4 words in GenericCastToPtrExplicitOp"
536 uint32_t valueID = 0;
540 return emitError(unknownLoc,
"unknown type result <id> : ") << words[0];
541 resultTypes.push_back(type);
// The pointer being cast must already be a known value.
545 auto arg = getValue(words[2]);
547 return emitError(unknownLoc,
"unknown result <id> : ") << words[2];
548 operands.push_back(arg);
550 Location loc = createFileLineColLoc(opBuilder);
552 loc, resultTypes, operands);
559 #define GET_DESERIALIZATION_FNS
560 #include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc"
static spirv::Opcode extractOpcode(uint32_t word)
Extracts the opcode from the given first word of a SPIR-V instruction.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
StringRef decodeStringLiteral(ArrayRef< uint32_t > words, unsigned &wordIndex)
Decodes a string literal in words starting at wordIndex.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This represents an operation in an abstracted form, suitable for use with the builder APIs.