MLIR 23.0.0git
DeserializeOps.cpp
Go to the documentation of this file.
1//===- DeserializeOps.cpp - MLIR SPIR-V Deserialization (Ops) -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the Deserializer methods for SPIR-V binary instructions.
10//
11//===----------------------------------------------------------------------===//
12
13#include "Deserializer.h"
14
17#include "mlir/IR/Builders.h"
18#include "mlir/IR/Location.h"
21#include "llvm/ADT/STLExtras.h"
22#include "llvm/ADT/SmallVector.h"
23#include "llvm/Support/Debug.h"
24#include <optional>
25
26using namespace mlir;
27
28#define DEBUG_TYPE "spirv-deserialization"
29
30//===----------------------------------------------------------------------===//
31// Utility Functions
32//===----------------------------------------------------------------------===//
33
34/// Extracts the opcode from the given first word of a SPIR-V instruction.
35static inline spirv::Opcode extractOpcode(uint32_t word) {
36 return static_cast<spirv::Opcode>(word & 0xffff);
37}
38
39/// Returns a NameLoc location from the given debug info string.
40static NameLoc getLocFromDebugInfoString(OpBuilder &builder, StringRef source) {
41 return NameLoc::get(builder.getStringAttr(source));
42}
43
44//===----------------------------------------------------------------------===//
45// Instruction
46//===----------------------------------------------------------------------===//
47
// Returns the Value bound to result <id> `id`.
// The dedicated getters are consulted first, in this order: constants,
// replicated composite constants, global variables, spec constants (scalar,
// composite, replicated composite), spec constant operations, undefs and ARM
// graph constants. On a hit, a fresh op is materialized at the current
// insertion point of `opBuilder`; nothing is cached in this function.
// Otherwise the result comes from valueMap, whose lookup() yields a null
// Value for an unknown <id>.
49 if (auto constInfo = getConstant(id)) {
50 // Materialize a `spirv.Constant` op at every use site.
51 Location loc = unknownLoc;
// Prefer a location attached by a DebugTensor instruction
// (processDebugInfoExtInst records it in constantLocMap).
52 if (LocationAttr locAttr = constantLocMap.lookup(id))
53 loc = Location(locAttr);
54 return spirv::ConstantOp::create(opBuilder, loc, constInfo->second,
55 constInfo->first);
56 }
// NOTE(review): the line completing this condition is not visible here;
// presumably it is `getConstantCompositeReplicate(id)) {`.
57 if (std::optional<std::pair<Attribute, Type>> constCompositeReplicateInfo =
59 return spirv::EXTConstantCompositeReplicateOp::create(
60 opBuilder, unknownLoc, constCompositeReplicateInfo->second,
61 constCompositeReplicateInfo->first);
62 }
// Module-scope symbols (global variables, spec constants) are not used
// directly; each use goes through a symbol-referencing op.
63 if (auto varOp = getGlobalVariable(id)) {
64 auto addressOfOp =
65 spirv::AddressOfOp::create(opBuilder, unknownLoc, varOp.getType(),
66 SymbolRefAttr::get(varOp.getOperation()));
67 return addressOfOp.getPointer();
68 }
69 if (auto constOp = getSpecConstant(id)) {
70 auto referenceOfOp = spirv::ReferenceOfOp::create(
71 opBuilder, unknownLoc, constOp.getDefaultValue().getType(),
72 SymbolRefAttr::get(constOp.getOperation()));
73 return referenceOfOp.getReference();
74 }
// NOTE(review): the initializer lines of the next two conditions are not
// visible here; presumably getSpecConstantComposite(id) and
// getSpecConstantCompositeReplicate(id).
75 if (SpecConstantCompositeOp specConstCompositeOp =
77 auto referenceOfOp = spirv::ReferenceOfOp::create(
78 opBuilder, unknownLoc, specConstCompositeOp.getType(),
79 SymbolRefAttr::get(specConstCompositeOp.getOperation()));
80 return referenceOfOp.getReference();
81 }
82 if (auto specConstCompositeReplicateOp =
84 auto referenceOfOp = spirv::ReferenceOfOp::create(
85 opBuilder, unknownLoc, specConstCompositeReplicateOp.getType(),
86 SymbolRefAttr::get(specConstCompositeReplicateOp.getOperation()));
87 return referenceOfOp.getReference();
88 }
// NOTE(review): the line opening this return is not visible here; presumably
// `return materializeSpecConstantOperation(`. The field name
// `enclodesOpcode` is misspelled where it is declared (outside this file).
89 if (auto specConstOperationInfo = getSpecConstantOperation(id)) {
91 id, specConstOperationInfo->enclodesOpcode,
92 specConstOperationInfo->resultTypeID,
93 specConstOperationInfo->enclosedOpOperands);
94 }
// OpUndef only records a type (see processUndef); each use gets its own
// spirv.Undef op.
95 if (auto undef = getUndefType(id)) {
96 return spirv::UndefOp::create(opBuilder, unknownLoc, undef);
97 }
98 if (std::optional<spirv::GraphConstantARMOpMaterializationInfo>
99 graphConstantARMInfo = getGraphConstantARM(id)) {
100 IntegerAttr graphConstantID = graphConstantARMInfo->graphConstantID;
101 Type resultType = graphConstantARMInfo->resultType;
102 return spirv::GraphConstantARMOp::create(opBuilder, unknownLoc, resultType,
103 graphConstantID);
104 }
// Ordinary SSA results; a null Value signals an unknown <id> to callers.
105 return valueMap.lookup(id);
106}
107
109 spirv::Opcode &opcode, ArrayRef<uint32_t> &operands,
110 std::optional<spirv::Opcode> expectedOpcode) {
// Slices the instruction at `curOffset` out of `binary`.
// - `opcode` receives the instruction's opcode.
// - `operands` receives a view of every word after the first; it is a slice
//   of `binary`, not a copy.
// - `curOffset` is advanced past the instruction.
// If the stream is already exhausted, the error names `expectedOpcode` when
// one is given.
111 auto binarySize = binary.size();
112 if (curOffset >= binarySize) {
113 return emitError(unknownLoc, "expected ")
114 << (expectedOpcode ? spirv::stringifyOpcode(*expectedOpcode)
115 : "more")
116 << " instruction";
117 }
118
119 // For each instruction, get its word count from the first word to slice it
120 // from the stream properly, and then dispatch to the instruction handler.
121
// The word count lives in the high 16 bits of the first word; extractOpcode
// takes the low 16 bits.
122 uint32_t wordCount = binary[curOffset] >> 16;
123
// A zero word count would never advance curOffset.
124 if (wordCount == 0)
125 return emitError(unknownLoc, "word count cannot be zero");
126
127 uint32_t nextOffset = curOffset + wordCount;
128 if (nextOffset > binarySize)
129 return emitError(unknownLoc, "insufficient words for the last instruction");
130
131 opcode = extractOpcode(binary[curOffset]);
132 operands = binary.slice(curOffset + 1, wordCount - 1);
133 curOffset = nextOffset;
134 return success();
135}
136
138 spirv::Opcode opcode, ArrayRef<uint32_t> &operands,
139 SmallVectorImpl<uint32_t> &mergedStorage) {
// Applies when `opcode` has a continuation opcode (SPV_INTEL_long_composites)
// and the following instruction(s) in the stream use it:
// - those continuations are consumed;
// - their operands are appended to a copy of `operands` kept in
//   `mergedStorage`;
// - `operands` is re-pointed at `mergedStorage`.
// Otherwise nothing is consumed or modified. The caller must keep
// `mergedStorage` alive for as long as it uses `operands`.
140 std::optional<spirv::Opcode> continuationOp = getContinuationOpcode(opcode);
141 if (!continuationOp)
142 return;
143
144 size_t binarySize = binary.size();
// Peeks at the next instruction without consuming it. The bounds checks
// mirror those in sliceInstruction.
145 auto isNextContinuation = [&]() {
146 if (curOffset >= binarySize)
147 return false;
148 uint32_t wordCount = binary[curOffset] >> 16;
149 if (wordCount == 0 || curOffset + wordCount > binarySize)
150 return false;
151 return extractOpcode(binary[curOffset]) == *continuationOp;
152 };
153
154 if (!isNextContinuation())
155 return;
156
157 mergedStorage.assign(operands);
158 do {
159 spirv::Opcode contOpcode;
160 ArrayRef<uint32_t> contOperands;
// Not expected to fail, since isNextContinuation() validated the header. On
// failure, `operands` is left pointing at the original (unmerged) words.
161 if (failed(sliceInstruction(contOpcode, contOperands, *continuationOp)))
162 return;
163 llvm::append_range(mergedStorage, contOperands);
164 } while (isNextContinuation());
165 operands = mergedStorage;
166}
167
169 spirv::Opcode opcode, ArrayRef<uint32_t> operands, bool deferInstructions) {
// Dispatches one SPIR-V instruction to its handler.
// - Instructions without a direct SPIR-V dialect mirror are handled in the
//   switch below.
// - Everything else, including the cases that `break` out of the switch,
//   goes to dispatchToAutogenDeserialization.
// - When `deferInstructions` is true, entry points, execution modes and graph
//   entry points are queued in deferredInstructions instead of processed.
//   DebugInfo ext-insts are queued the same way by processDebugInfoExtInst.
170 LLVM_DEBUG(logger.startLine() << "[inst] processing instruction "
171 << spirv::stringifyOpcode(opcode) << "\n");
172
// May re-point `operands` into mergedStorage, so mergedStorage must outlive
// every use of `operands` below.
173 SmallVector<uint32_t, 0> mergedStorage;
174 mergeLongCompositeContinuations(opcode, operands, mergedStorage);
175
176 // First dispatch all the instructions whose opcode does not correspond to
177 // those that have a direct mirror in the SPIR-V dialect
178 switch (opcode) {
179 case spirv::Opcode::OpCapability:
180 return processCapability(operands);
181 case spirv::Opcode::OpExtension:
182 return processExtension(operands);
// Only ext-insts whose set <id> (operands[2]) maps to the DebugInfo set go to
// processDebugInfoExtInst. The size guard avoids reading operands[2] on
// malformed input; processExtInst then reports the short operand list.
// NOTE(review): the line declaring `setIt` is not visible here.
183 case spirv::Opcode::OpExtInst: {
185 operands.size() >= 4 ? extendedInstSets.find(operands[2])
186 : extendedInstSets.end();
187 if (setIt != extendedInstSets.end() && setIt->second == extDebugInfo)
188 return processDebugInfoExtInst(operands, deferInstructions);
189 return processExtInst(operands);
190 }
191 case spirv::Opcode::OpExtInstImport:
192 return processExtInstImport(operands);
193 case spirv::Opcode::OpMemberName:
194 return processMemberName(operands);
195 case spirv::Opcode::OpMemoryModel:
196 return processMemoryModel(operands);
// When not deferred, these fall through to the autogenerated dispatcher.
// NOTE(review): if deferredInstructions stores the ArrayRef itself rather than
// a copy, the queued words must not alias the local mergedStorage. That is
// presumably true because these opcodes have no continuation form; confirm
// against getContinuationOpcode.
197 case spirv::Opcode::OpEntryPoint:
198 case spirv::Opcode::OpExecutionMode:
199 case spirv::Opcode::OpExecutionModeId:
200 if (deferInstructions) {
201 deferredInstructions.emplace_back(opcode, operands);
202 return success();
203 }
204 break;
// Only module-scope OpVariables become global variables. Function-local ones
// fall through to the autogenerated dispatcher.
205 case spirv::Opcode::OpVariable:
206 if (isa<spirv::ModuleOp>(opBuilder.getBlock()->getParentOp())) {
207 return processGlobalVariable(operands);
208 }
209 break;
210 case spirv::Opcode::OpLine:
211 return processDebugLine(operands);
// NOTE(review): a statement before this return is not visible here;
// presumably it calls clearDebugLine() to drop the active OpLine info.
212 case spirv::Opcode::OpNoLine:
214 return success();
215 case spirv::Opcode::OpName:
216 return processName(operands);
217 case spirv::Opcode::OpString:
218 return processDebugString(operands);
219 case spirv::Opcode::OpModuleProcessed:
220 case spirv::Opcode::OpSource:
221 case spirv::Opcode::OpSourceContinued:
222 case spirv::Opcode::OpSourceExtension:
223 // TODO: This is debug information embedded in the binary which should be
224 // translated into the spirv.module.
225 return success();
226 case spirv::Opcode::OpTypeVoid:
227 case spirv::Opcode::OpTypeBool:
228 case spirv::Opcode::OpTypeInt:
229 case spirv::Opcode::OpTypeFloat:
230 case spirv::Opcode::OpTypeVector:
231 case spirv::Opcode::OpTypeMatrix:
232 case spirv::Opcode::OpTypeArray:
233 case spirv::Opcode::OpTypeFunction:
234 case spirv::Opcode::OpTypeImage:
235 case spirv::Opcode::OpTypeSampler:
236 case spirv::Opcode::OpTypeNamedBarrier:
237 case spirv::Opcode::OpTypeSampledImage:
238 case spirv::Opcode::OpTypeRuntimeArray:
239 case spirv::Opcode::OpTypeStruct:
240 case spirv::Opcode::OpTypePointer:
241 case spirv::Opcode::OpTypeTensorARM:
242 case spirv::Opcode::OpTypeGraphARM:
243 case spirv::Opcode::OpTypeCooperativeMatrixKHR:
244 return processType(opcode, operands);
245 case spirv::Opcode::OpTypeForwardPointer:
246 return processTypeForwardPointer(operands);
247 case spirv::Opcode::OpConstant:
248 return processConstant(operands, /*isSpec=*/false);
249 case spirv::Opcode::OpSpecConstant:
250 return processConstant(operands, /*isSpec=*/true);
251 case spirv::Opcode::OpConstantComposite:
252 return processConstantComposite(operands);
// NOTE(review): the return lines for the two *CompositeReplicateEXT cases are
// not visible here. Presumably they call the matching
// process*CompositeReplicateEXT handlers.
253 case spirv::Opcode::OpConstantCompositeReplicateEXT:
255 case spirv::Opcode::OpSpecConstantComposite:
256 return processSpecConstantComposite(operands);
257 case spirv::Opcode::OpSpecConstantCompositeReplicateEXT:
259 case spirv::Opcode::OpSpecConstantOp:
260 return processSpecConstantOperation(operands);
261 case spirv::Opcode::OpConstantTrue:
262 return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/false);
263 case spirv::Opcode::OpSpecConstantTrue:
264 return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/true);
265 case spirv::Opcode::OpConstantFalse:
266 return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/false);
267 case spirv::Opcode::OpSpecConstantFalse:
268 return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/true);
269 case spirv::Opcode::OpConstantNull:
270 return processConstantNull(operands);
271 case spirv::Opcode::OpGraphConstantARM:
272 return processGraphConstantARM(operands);
273 case spirv::Opcode::OpDecorate:
274 case spirv::Opcode::OpDecorateId:
275 return processDecoration(operands);
276 case spirv::Opcode::OpMemberDecorate:
277 return processMemberDecoration(operands);
278 case spirv::Opcode::OpFunction:
279 return processFunction(operands);
280 case spirv::Opcode::OpGraphEntryPointARM:
281 if (deferInstructions) {
282 deferredInstructions.emplace_back(opcode, operands);
283 return success();
284 }
285 return processGraphEntryPointARM(operands);
286 case spirv::Opcode::OpGraphARM:
287 return processGraphARM(operands);
288 case spirv::Opcode::OpGraphSetOutputARM:
289 return processOpGraphSetOutputARM(operands);
290 case spirv::Opcode::OpGraphEndARM:
291 return processGraphEndARM(operands);
292 case spirv::Opcode::OpLabel:
293 return processLabel(operands);
294 case spirv::Opcode::OpBranch:
295 return processBranch(operands);
296 case spirv::Opcode::OpBranchConditional:
297 return processBranchConditional(operands);
298 case spirv::Opcode::OpSelectionMerge:
299 return processSelectionMerge(operands);
300 case spirv::Opcode::OpLoopMerge:
301 return processLoopMerge(operands);
302 case spirv::Opcode::OpPhi:
303 return processPhi(operands);
304 case spirv::Opcode::OpSwitch:
305 return processSwitch(operands);
306 case spirv::Opcode::OpUndef:
307 return processUndef(operands);
308 default:
309 break;
310 }
// Everything that mirrors a SPIR-V dialect op is handled by generated code.
311 return dispatchToAutogenDeserialization(opcode, operands);
312}
313
315 ArrayRef<uint32_t> words, StringRef opName, bool hasResult,
316 unsigned numOperands) {
// Generic deserializer for an op named `opName` that carries no grammar
// attributes. Expected layout of `words`:
// - when `hasResult` is set: result type <id>, then result <id>;
// - then exactly `numOperands` operand <id>s.
// Too few or too many words is an error. Decorations recorded for the result
// <id> become op attributes, and the result is registered in valueMap.
317 SmallVector<Type, 1> resultTypes;
318 uint32_t valueID = 0;
319
320 size_t wordIndex = 0;
321 if (hasResult) {
322 if (wordIndex >= words.size())
323 return emitError(unknownLoc,
324 "expected result type <id> while deserializing for ")
325 << opName;
326
327 // Decode the type <id>
328 auto type = getType(words[wordIndex]);
329 if (!type)
330 return emitError(unknownLoc, "unknown type result <id>: ")
331 << words[wordIndex];
332 resultTypes.push_back(type);
333 ++wordIndex;
334
335 // Decode the result <id>
336 if (wordIndex >= words.size())
337 return emitError(unknownLoc,
338 "expected result <id> while deserializing for ")
339 << opName;
340 valueID = words[wordIndex];
341 ++wordIndex;
342 }
343
344 SmallVector<Value, 4> operands;
// NOTE(review): `attributes`, used below, is declared on a line not visible
// here; presumably a SmallVector<NamedAttribute>.
346
347 // Decode operands
348 size_t operandIndex = 0;
349 for (; operandIndex < numOperands && wordIndex < words.size();
350 ++operandIndex, ++wordIndex) {
351 auto arg = getValue(words[wordIndex]);
352 if (!arg)
353 return emitError(unknownLoc, "unknown result <id>: ") << words[wordIndex];
354 operands.push_back(arg);
355 }
356 if (operandIndex != numOperands) {
357 return emitError(
358 unknownLoc,
359 "found less operands than expected when deserializing for ")
360 << opName << "; only " << operandIndex << " of " << numOperands
361 << " processed";
362 }
363 if (wordIndex != words.size()) {
364 return emitError(
365 unknownLoc,
366 "found more operands than expected when deserializing for ")
367 << opName << "; only " << wordIndex << " of " << words.size()
368 << " processed";
369 }
370
// When !hasResult, valueID is still 0 here, so this looks up decorations
// recorded for <id> 0.
371 // Attach attributes from decorations
372 if (decorations.count(valueID)) {
373 auto attrs = decorations[valueID].getAttrs();
374 attributes.append(attrs.begin(), attrs.end());
375 }
376
377 // Create the op and update bookkeeping maps
378 Location loc = createFileLineColLoc(opBuilder);
379 OperationState opState(loc, opName);
380 opState.addOperands(operands);
381 if (hasResult)
382 opState.addTypes(resultTypes);
383 opState.addAttributes(attributes);
384 Operation *op = opBuilder.create(opState);
385 if (hasResult)
386 valueMap[valueID] = op->getResult(0);
387
// NOTE(review): the lines between here and the return are not visible;
// presumably they clear the active OpLine info after terminator ops.
390
391 return success();
392}
393
// Handles OpUndef, whose two operands are a result type <id> and a result
// <id>. No op is created here: the type is recorded in undefMap. getValue()
// later creates a spirv.Undef op per use via getUndefType(), which presumably
// reads undefMap.
395 if (operands.size() != 2) {
396 return emitError(unknownLoc, "OpUndef instruction must have two operands");
397 }
398 auto type = getType(operands[0]);
399 if (!type) {
400 return emitError(unknownLoc, "unknown type <id> with OpUndef instruction");
401 }
402 undefMap[operands[1]] = type;
403 return success();
404}
405
// Handles an OpExtInst from the DebugInfo extended instruction set.
// Operand layout:
// - result type <id>, which must be OpTypeVoid;
// - result <id>;
// - set <id>;
// - instruction opcode;
// - instruction-specific words.
// No op is created. The instruction replaces the location of an existing
// graph, operation result(s) or tensor with a NameLoc built from a string in
// debugInfoMap (presumably filled from OpString by processDebugString).
// When `deferInstructions` is set, the instruction is queued instead;
// presumably so the entities it references exist by the time it is
// processed.
406LogicalResult
408 bool deferInstructions) {
409 if (deferInstructions) {
410 deferredInstructions.emplace_back(spirv::Opcode::OpExtInst, operands);
411 return success();
412 }
413
414 if (operands.size() < 4) {
415 return emitError(unknownLoc,
416 "OpExtInst must have at least 4 operands, result type "
417 "<id>, result <id>, set <id> and instruction opcode");
418 }
419
420 Type resultType = getType(operands[0]);
421 if (!resultType || !isVoidType(resultType))
422 return emitError(unknownLoc,
423 "DebugInfo instructions must have OpTypeVoid result type");
424
// Resolves a string <id> to a NameLoc, or emits an error if the string is
// unknown.
425 auto getDebugLoc = [&](uint32_t stringID) -> FailureOr<Location> {
427 debugInfoMap.find(stringID);
428 if (stringIt == debugInfoMap.end()) {
429 return emitError(unknownLoc, "undefined string <id> ")
430 << stringID << " in DebugInfo";
431 }
432 return Location(getLocFromDebugInfoString(opBuilder, stringIt->second));
433 };
434
// Validate before casting so the switch only sees known enumerators.
435 if (!spirv::isValidGraphDebugInfoExtInst(operands[3]))
436 return emitError(unknownLoc, "unknown DebugInfo instruction opcode: ")
437 << operands[3];
438
439 auto instructionID = static_cast<spirv::GraphDebugInfoExtInst>(operands[3]);
440 switch (instructionID) {
// NOTE(review): the `case` labels of the three arms below are not visible
// here; each arm is identified from its error messages.
// DebugGraph arm: operands[4] = graph <id>, operands[5] = string <id>.
442 if (operands.size() < 6)
443 return emitError(unknownLoc, "DebugGraph must have graph and string IDs");
444 uint32_t graphID = operands[4];
445 uint32_t stringID = operands[5];
447 graphMap.find(graphID);
448 if (graphIt == graphMap.end())
449 return emitError(unknownLoc, "undefined graph <id> ")
450 << graphID << " in DebugGraph";
451 FailureOr<Location> loc = getDebugLoc(stringID);
452 if (failed(loc))
453 return failure();
454 graphIt->second->setLoc(*loc);
455 break;
456 }
// DebugOperation arm: operands[5] = string <id>, operands[6..] = result <id>s
// to re-locate. operands[4] (graph <id>) is not read here.
458 if (operands.size() < 7)
459 return emitError(unknownLoc, "DebugOperation must have graph, string and "
460 "instruction IDs");
461 uint32_t stringID = operands[5];
462 FailureOr<Location> loc = getDebugLoc(stringID);
463 if (failed(loc))
464 return failure();
465 SmallVector<uint32_t> operationIDs;
466 operationIDs.append(std::next(operands.begin(), 6), operands.end());
467 for (uint32_t operationID : operationIDs) {
468 DenseMap<uint32_t, Value>::iterator valueIt = valueMap.find(operationID);
469 if (valueIt == valueMap.end())
470 return emitError(unknownLoc, "undefined operation <id> ")
471 << operationID << " in DebugOperation";
472 valueIt->second.setLoc(*loc);
473 }
474 break;
475 }
// DebugTensor arm: operands[4] = tensor <id>, operands[5] = string <id>.
477 if (operands.size() < 6)
478 return emitError(unknownLoc,
479 "DebugTensor must have tensor and string IDs");
480 uint32_t stringID = operands[5];
481 uint32_t tensorID = operands[4];
482 FailureOr<Location> loc = getDebugLoc(stringID);
483 if (failed(loc))
484 return failure();
// Constants are materialized per use by getValue(), so their location is
// stashed in constantLocMap rather than set on a Value.
485 if (constantMap.contains(tensorID)) {
486 constantLocMap[tensorID] = *loc;
487 break;
488 }
489 DenseMap<uint32_t, Value>::iterator valueIt = valueMap.find(tensorID);
490 if (valueIt == valueMap.end())
491 return emitError(unknownLoc, "undefined tensor <id> ")
492 << tensorID << " in DebugTensor";
493 valueIt->second.setLoc(*loc);
494 break;
495 }
496 }
497
498 return success();
499}
500
// Handles a non-DebugInfo OpExtInst.
// - Removes the set <id> and instruction number (operands[2] and
//   operands[3]) from the operand list.
// - Forwards the remaining words (result type <id>, result <id>, operands)
//   together with the set's name and the instruction number.
// NOTE(review): the line opening the return is not visible here; presumably
// `return dispatchToExtensionSetAutogenDeserialization(`.
502 if (operands.size() < 4) {
503 return emitError(unknownLoc,
504 "OpExtInst must have at least 4 operands, result type "
505 "<id>, result <id>, set <id> and instruction opcode");
506 }
// Checked first so the operator[] lookup below never inserts a new entry.
507 if (!extendedInstSets.count(operands[2])) {
508 return emitError(unknownLoc, "undefined set <id> in OpExtInst");
509 }
510 SmallVector<uint32_t, 4> slicedOperands;
511 slicedOperands.append(operands.begin(), std::next(operands.begin(), 2));
512 slicedOperands.append(std::next(operands.begin(), 4), operands.end());
514 extendedInstSets[operands[2]], operands[3], slicedOperands);
515}
516
517namespace mlir {
518namespace spirv {
519
/// Deserializes OpEntryPoint. Operand order: execution model, function <id>,
/// literal entry-point name, then interface <id>s.
/// - The function must already be known.
/// - If the function's name differs from the entry-point name, it is renamed
///   only when it still carries the "spirv_fn_" placeholder prefix; any other
///   mismatch is an error.
/// - Every interface <id> must name a global variable.
/// NOTE(review): the signature line is not visible here; presumably
/// processOp<spirv::EntryPointOp>(ArrayRef<uint32_t> words).
520template <>
521LogicalResult
523 unsigned wordIndex = 0;
524 if (wordIndex >= words.size()) {
525 return emitError(unknownLoc,
526 "missing Execution Model specification in OpEntryPoint");
527 }
// NOTE(review): the execution model word is cast to the enum without range
// validation.
528 auto execModel = spirv::ExecutionModelAttr::get(
529 context, static_cast<spirv::ExecutionModel>(words[wordIndex++]));
530 if (wordIndex >= words.size()) {
531 return emitError(unknownLoc, "missing <id> in OpEntryPoint");
532 }
533 // Get the function <id>
534 auto fnID = words[wordIndex++];
535 // Get the function name
// decodeStringLiteral presumably advances wordIndex past the literal's words.
536 auto fnName = decodeStringLiteral(words, wordIndex);
537 // Verify that the function <id> matches the fnName
538 auto parsedFunc = getFunction(fnID);
539 if (!parsedFunc) {
540 return emitError(unknownLoc, "no function matching <id> ") << fnID;
541 }
542 if (parsedFunc.getName() != fnName) {
543 // The deserializer uses "spirv_fn_<id>" as the function name if the input
544 // SPIR-V blob does not contain a name for it. We should use a more clear
545 // indication for such case rather than relying on naming details.
546 if (!parsedFunc.getName().starts_with("spirv_fn_"))
547 return emitError(unknownLoc,
548 "function name mismatch between OpEntryPoint "
549 "and OpFunction with <id> ")
550 << fnID << ": " << fnName << " vs. " << parsedFunc.getName();
551 parsedFunc.setName(fnName);
552 }
// NOTE(review): `interface`, filled below, is declared on a line not visible
// here; presumably a SmallVector<Attribute>.
554 while (wordIndex < words.size()) {
555 auto arg = getGlobalVariable(words[wordIndex]);
556 if (!arg) {
557 return emitError(unknownLoc, "undefined result <id> ")
558 << words[wordIndex] << " while decoding OpEntryPoint";
559 }
560 interface.push_back(SymbolRefAttr::get(arg.getOperation()));
561 wordIndex++;
562 }
563 spirv::EntryPointOp::create(
564 opBuilder, unknownLoc, execModel,
565 SymbolRefAttr::get(opBuilder.getContext(), fnName),
566 opBuilder.getArrayAttr(interface));
567 return success();
568}
569
/// Deserializes OpExecutionMode. Operand order: function <id>, execution
/// mode, then literal words.
/// - Each literal word becomes an i32 attribute in the `values` array of the
///   created ExecutionModeOp.
/// - The function is referenced by its current symbol name.
/// NOTE(review): the signature line is not visible here; presumably
/// processOp<spirv::ExecutionModeOp>(ArrayRef<uint32_t> words).
570template <>
571LogicalResult
573 unsigned wordIndex = 0;
574 if (wordIndex >= words.size()) {
575 return emitError(unknownLoc,
576 "missing function result <id> in OpExecutionMode");
577 }
578 // Get the function <id> to get the name of the function
579 auto fnID = words[wordIndex++];
580 auto fn = getFunction(fnID);
581 if (!fn) {
582 return emitError(unknownLoc, "no function matching <id> ") << fnID;
583 }
584 // Get the Execution mode
585 if (wordIndex >= words.size()) {
586 return emitError(unknownLoc, "missing Execution Mode in OpExecutionMode");
587 }
588 auto execMode = spirv::ExecutionModeAttr::get(
589 context, static_cast<spirv::ExecutionMode>(words[wordIndex++]));
590
591 // Get the values
592 SmallVector<Attribute, 4> attrListElems;
593 while (wordIndex < words.size()) {
594 attrListElems.push_back(opBuilder.getI32IntegerAttr(words[wordIndex++]));
595 }
596 auto values = opBuilder.getArrayAttr(attrListElems);
597 spirv::ExecutionModeOp::create(
598 opBuilder, unknownLoc,
599 SymbolRefAttr::get(opBuilder.getContext(), fn.getName()), execMode,
600 values);
601 return success();
602}
603
/// Deserializes OpExecutionModeId. It follows the OpExecutionMode handler,
/// with one difference: each trailing word is treated as an <id>.
/// - Each <id> is turned into a flat symbol reference through
///   getSpecConstantSymbol(), presumably naming a spec constant.
/// - The result is emitted as an ExecutionModeIdOp.
/// NOTE(review): the signature line is not visible here; presumably
/// processOp<spirv::ExecutionModeIdOp>(ArrayRef<uint32_t> words).
604template <>
605LogicalResult
607 unsigned wordIndex = 0;
608 unsigned const wordsSize = words.size();
609 if (wordIndex >= wordsSize)
610 return emitError(unknownLoc,
611 "missing function result <id> in OpExecutionModeId");
612
613 // Get the function <id> to get the name of the function.
614 uint32_t fnID = words[wordIndex++];
615 FuncOp fn = getFunction(fnID);
616 if (!fn)
617 return emitError(unknownLoc, "no function matching <id> ") << fnID;
618
619 // Get the Execution mode.
620 if (wordIndex >= wordsSize)
621 return emitError(unknownLoc, "missing Execution Mode in OpExecutionModeId");
622
623 ExecutionModeAttr execMode = spirv::ExecutionModeAttr::get(
624 context, static_cast<spirv::ExecutionMode>(words[wordIndex++]));
625
626 // Get the values.
627 SmallVector<Attribute, 4> attrListElems;
628 while (wordIndex < words.size()) {
629 std::string id = getSpecConstantSymbol(words[wordIndex++]);
630 attrListElems.push_back(FlatSymbolRefAttr::get(context, id));
631 }
632 ArrayAttr values = opBuilder.getArrayAttr(attrListElems);
633 spirv::ExecutionModeIdOp::create(
634 opBuilder, unknownLoc,
635 SymbolRefAttr::get(opBuilder.getContext(), fn.getName()), execMode,
636 values);
637 return success();
638}
639
/// Deserializes OpFunctionCall. Operand order: result type <id>, result <id>,
/// callee function <id>, then argument <id>s.
/// - The callee is referenced by the symbol name from getFunctionSymbol().
/// - A void result type yields a call with no result; in that case the
///   result <id> is not registered in valueMap.
/// NOTE(review): the signature line is not visible here; presumably
/// processOp<spirv::FunctionCallOp>(ArrayRef<uint32_t> operands).
640template <>
641LogicalResult
643 if (operands.size() < 3) {
644 return emitError(unknownLoc,
645 "OpFunctionCall must have at least 3 operands");
646 }
647
648 Type resultType = getType(operands[0]);
649 if (!resultType) {
650 return emitError(unknownLoc, "undefined result type from <id> ")
651 << operands[0];
652 }
653
654 // Use null type to mean no result type.
655 if (isVoidType(resultType))
656 resultType = nullptr;
657
658 auto resultID = operands[1];
659 auto functionID = operands[2];
660
661 auto functionName = getFunctionSymbol(functionID);
662
663 SmallVector<Value, 4> arguments;
664 for (auto operand : llvm::drop_begin(operands, 3)) {
665 auto value = getValue(operand);
666 if (!value) {
667 return emitError(unknownLoc, "unknown <id> ")
668 << operand << " used by OpFunctionCall";
669 }
670 arguments.push_back(value);
671 }
672
673 auto opFunctionCall = spirv::FunctionCallOp::create(
674 opBuilder, unknownLoc, resultType,
675 SymbolRefAttr::get(opBuilder.getContext(), functionName), arguments);
676
677 if (resultType)
678 valueMap[resultID] = opFunctionCall.getResult(0);
679 return success();
680}
681
682template <>
683LogicalResult
685 SmallVector<Type, 1> resultTypes;
686 size_t wordIndex = 0;
687 SmallVector<Value, 4> operands;
689
690 if (wordIndex < words.size()) {
691 auto arg = getValue(words[wordIndex]);
692
693 if (!arg) {
694 return emitError(unknownLoc, "unknown result <id> : ")
695 << words[wordIndex];
696 }
697
698 operands.push_back(arg);
699 wordIndex++;
700 }
701
702 if (wordIndex < words.size()) {
703 auto arg = getValue(words[wordIndex]);
704
705 if (!arg) {
706 return emitError(unknownLoc, "unknown result <id> : ")
707 << words[wordIndex];
708 }
709
710 operands.push_back(arg);
711 wordIndex++;
712 }
713
714 bool isAlignedAttr = false;
715
716 if (wordIndex < words.size()) {
717 auto attrValue = words[wordIndex++];
718 auto attr = opBuilder.getAttr<spirv::MemoryAccessAttr>(
719 static_cast<spirv::MemoryAccess>(attrValue));
720 attributes.push_back(
721 opBuilder.getNamedAttr(attributeName<MemoryAccess>(), attr));
722 isAlignedAttr = (attrValue == 2);
723 }
724
725 if (isAlignedAttr && wordIndex < words.size()) {
726 attributes.push_back(opBuilder.getNamedAttr(
727 "alignment", opBuilder.getI32IntegerAttr(words[wordIndex++])));
728 }
729
730 if (wordIndex < words.size()) {
731 auto attrValue = words[wordIndex++];
732 auto attr = opBuilder.getAttr<spirv::MemoryAccessAttr>(
733 static_cast<spirv::MemoryAccess>(attrValue));
734 attributes.push_back(opBuilder.getNamedAttr("source_memory_access", attr));
735 }
736
737 if (wordIndex < words.size()) {
738 attributes.push_back(opBuilder.getNamedAttr(
739 "source_alignment", opBuilder.getI32IntegerAttr(words[wordIndex++])));
740 }
741
742 if (wordIndex != words.size()) {
743 return emitError(unknownLoc,
744 "found more operands than expected when deserializing "
745 "spirv::CopyMemoryOp, only ")
746 << wordIndex << " of " << words.size() << " processed";
747 }
748
749 Location loc = createFileLineColLoc(opBuilder);
750 spirv::CopyMemoryOp::create(opBuilder, loc, resultTypes, operands,
751 attributes);
752
753 return success();
754}
755
/// Deserializes OpGenericCastToPtrExplicit, which must be exactly four words:
/// result type <id>, result <id>, pointer <id>, and a fourth word (presumably
/// the storage class operand).
/// - words[3] is counted but not otherwise read here.
/// - The new op's result is registered in valueMap.
/// NOTE(review): the line naming the specialization is not visible here;
/// presumably processOp<spirv::GenericCastToPtrExplicitOp>(.
756template <>
758 ArrayRef<uint32_t> words) {
759 if (words.size() != 4) {
760 return emitError(unknownLoc,
761 "expected 4 words in GenericCastToPtrExplicitOp"
762 " but got : ")
763 << words.size();
764 }
765 SmallVector<Type, 1> resultTypes;
766 SmallVector<Value, 4> operands;
767 uint32_t valueID = 0;
768 auto type = getType(words[0]);
769
770 if (!type)
771 return emitError(unknownLoc, "unknown type result <id> : ") << words[0];
772 resultTypes.push_back(type);
773
774 valueID = words[1];
775
776 auto arg = getValue(words[2]);
777 if (!arg)
778 return emitError(unknownLoc, "unknown result <id> : ") << words[2];
779 operands.push_back(arg);
780
781 Location loc = createFileLineColLoc(opBuilder);
782 Operation *op = spirv::GenericCastToPtrExplicitOp::create(
783 opBuilder, loc, resultTypes, operands);
784 valueMap[valueID] = op->getResult(0);
785 return success();
786}
787
788// Pull in auto-generated Deserializer::dispatchToAutogenDeserialization() and
789// various Deserializer::processOp<...>() specializations.
790#define GET_DESERIALIZATION_FNS
791#include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc"
792
793} // namespace spirv
794} // namespace mlir
return success()
static spirv::Opcode extractOpcode(uint32_t word)
Extracts the opcode from the given first word of a SPIR-V instruction.
static NameLoc getLocFromDebugInfoString(OpBuilder &builder, StringRef source)
Returns a NameLoc location from the given debug info string.
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:267
static FlatSymbolRefAttr get(StringAttr value)
Construct a symbol reference for the given value name.
Location objects represent source locations information in MLIR.
Definition Location.h:32
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:209
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:774
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:432
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, PropertyRef properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Definition Operation.cpp:65
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Value materializeSpecConstantOperation(uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID, ArrayRef< uint32_t > enclosedOpOperands)
Materializes/emits an OpSpecConstantOp instruction.
Value getValue(uint32_t id)
Get the Value associated with a result <id>.
LogicalResult processGlobalVariable(ArrayRef< uint32_t > operands)
Processes the OpVariable instructions at current offset into binary.
std::optional< SpecConstOperationMaterializationInfo > getSpecConstantOperation(uint32_t id)
Gets the info needed to materialize the spec constant operation op associated with the given <id>.
LogicalResult processDebugInfoExtInst(ArrayRef< uint32_t > operands, bool deferInstructions)
Processes a SPIR-V OpExtInst with given operands for a DebugInfo extension instruction.
LogicalResult processConstantNull(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpConstantNull instruction with the given operands.
LogicalResult processSpecConstantComposite(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSpecConstantComposite instruction with the given operands.
LogicalResult processInstruction(spirv::Opcode opcode, ArrayRef< uint32_t > operands, bool deferInstructions=true)
Processes a SPIR-V instruction with the given opcode and operands.
LogicalResult processBranchConditional(ArrayRef< uint32_t > operands)
spirv::GlobalVariableOp getGlobalVariable(uint32_t id)
Gets the global variable associated with a result <id> of OpVariable.
LogicalResult processGraphARM(ArrayRef< uint32_t > operands)
LogicalResult processLabel(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpLabel instruction with the given operands.
LogicalResult processExtInst(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpExtInst with given operands.
std::optional< spirv::GraphConstantARMOpMaterializationInfo > getGraphConstantARM(uint32_t id)
Gets the GraphConstantARM ID attribute and result type with the given result <id>.
std::optional< std::pair< Attribute, Type > > getConstant(uint32_t id)
Gets the constant's attribute and type associated with the given <id>.
LogicalResult processType(spirv::Opcode opcode, ArrayRef< uint32_t > operands)
Processes a SPIR-V type instruction with given opcode and operands and registers the type into module...
LogicalResult processLoopMerge(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpLoopMerge instruction with the given operands.
LogicalResult dispatchToExtensionSetAutogenDeserialization(StringRef extensionSetName, uint32_t instructionID, ArrayRef< uint32_t > words)
Dispatches the deserialization of extended instruction set operation based on the extended instructio...
LogicalResult sliceInstruction(spirv::Opcode &opcode, ArrayRef< uint32_t > &operands, std::optional< spirv::Opcode > expectedOpcode=std::nullopt)
Slices the first instruction out of binary and returns its opcode and operands via opcode and operand...
spirv::SpecConstantCompositeOp getSpecConstantComposite(uint32_t id)
Gets the composite specialization constant with the given result <id>.
spirv::EXTSpecConstantCompositeReplicateOp getSpecConstantCompositeReplicate(uint32_t id)
Gets the replicated composite specialization constant with the given result <id>.
LogicalResult processOp(ArrayRef< uint32_t > words)
Method to deserialize an operation in the SPIR-V dialect that is a mirror of an instruction in the SP...
Type getUndefType(uint32_t id)
Get the type associated with the result <id> of an OpUndef.
LogicalResult processSpecConstantCompositeReplicateEXT(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSpecConstantCompositeReplicateEXT instruction with the given operands.
LogicalResult processGraphEntryPointARM(ArrayRef< uint32_t > operands)
LogicalResult processFunction(ArrayRef< uint32_t > operands)
Creates a deserializer for the given SPIR-V binary module.
bool isVoidType(Type type) const
Returns true if the given type is for SPIR-V void type.
LogicalResult processDebugString(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpString instruction with the given operands.
LogicalResult processPhi(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpPhi instruction with the given operands.
void clearDebugLine()
Discontinues any source-level location information that might be active from a previous OpLine instru...
void mergeLongCompositeContinuations(spirv::Opcode opcode, ArrayRef< uint32_t > &operands, SmallVectorImpl< uint32_t > &mergedStorage)
If opcode is a SPV_INTEL_long_composites splittable opcode and the next binary instruction(s) are mat...
LogicalResult processTypeForwardPointer(ArrayRef< uint32_t > operands)
LogicalResult processSwitch(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSwitch instruction with the given operands.
LogicalResult dispatchToAutogenDeserialization(spirv::Opcode opcode, ArrayRef< uint32_t > words)
Method to dispatch to the specialized deserialization function for an operation in SPIR-V dialect tha...
LogicalResult processOpWithoutGrammarAttr(ArrayRef< uint32_t > words, StringRef opName, bool hasResult, unsigned numOperands)
Processes a SPIR-V instruction from the given operands.
LogicalResult processGraphEndARM(ArrayRef< uint32_t > operands)
LogicalResult processConstantComposite(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpConstantComposite instruction with the given operands.
LogicalResult processBranch(ArrayRef< uint32_t > operands)
std::optional< std::pair< Attribute, Type > > getConstantCompositeReplicate(uint32_t id)
Gets the replicated composite constant's attribute and type associated with the given <id>.
LogicalResult processUndef(ArrayRef< uint32_t > operands)
Processes an OpUndef instruction.
LogicalResult processSpecConstantOperation(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSpecConstantOp instruction with the given operands.
LogicalResult processConstant(ArrayRef< uint32_t > operands, bool isSpec)
Processes a SPIR-V Op{|Spec}Constant instruction with the given operands.
Location createFileLineColLoc(OpBuilder opBuilder)
Creates a FileLineColLoc with the OpLine location information.
LogicalResult processGraphConstantARM(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpGraphConstantARM instruction with the given operands.
LogicalResult processConstantBool(bool isTrue, ArrayRef< uint32_t > operands, bool isSpec)
Processes a SPIR-V Op{|Spec}Constant{True|False} instruction with the given operands.
spirv::SpecConstantOp getSpecConstant(uint32_t id)
Gets the specialization constant with the given result <id>.
LogicalResult processConstantCompositeReplicateEXT(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpConstantCompositeReplicateEXT instruction with the given operands.
LogicalResult processSelectionMerge(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSelectionMerge instruction with the given operands.
LogicalResult processOpGraphSetOutputARM(ArrayRef< uint32_t > operands)
LogicalResult processDebugLine(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpLine instruction with the given operands.
constexpr llvm::StringLiteral extDebugInfo
Extension set name for non-semantic graph debug info.
std::optional< spirv::Opcode > getContinuationOpcode(spirv::Opcode parent)
Returns the SPV_INTEL_long_composites continuation opcode that may follow parent, or std::nullopt if ...
constexpr bool isValidGraphDebugInfoExtInst(uint32_t opcode)
constexpr StringRef attributeName()
StringRef decodeStringLiteral(ArrayRef< uint32_t > words, unsigned &wordIndex)
Decodes a string literal in words starting at wordIndex.
GraphDebugInfoExtInst
Instruction opcodes in the NonSemantic.Graph.DebugInfo.1 extended instruction set.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:120
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addTypes(ArrayRef< Type > newTypes)