MLIR  18.0.0git
DeserializeOps.cpp
Go to the documentation of this file.
1 //===- DeserializeOps.cpp - MLIR SPIR-V Deserialization (Ops) -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the Deserializer methods for SPIR-V binary instructions.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Deserializer.h"
14 
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/Location.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/Support/Debug.h"
23 #include <optional>
24 
25 using namespace mlir;
26 
27 #define DEBUG_TYPE "spirv-deserialization"
28 
29 //===----------------------------------------------------------------------===//
30 // Utility Functions
31 //===----------------------------------------------------------------------===//
32 
33 /// Extracts the opcode from the given first word of a SPIR-V instruction.
34 static inline spirv::Opcode extractOpcode(uint32_t word) {
35  return static_cast<spirv::Opcode>(word & 0xffff);
36 }
37 
38 //===----------------------------------------------------------------------===//
39 // Instruction
40 //===----------------------------------------------------------------------===//
41 
42 Value spirv::Deserializer::getValue(uint32_t id) {
43  if (auto constInfo = getConstant(id)) {
44  // Materialize a `spirv.Constant` op at every use site.
45  return opBuilder.create<spirv::ConstantOp>(unknownLoc, constInfo->second,
46  constInfo->first);
47  }
48  if (auto varOp = getGlobalVariable(id)) {
49  auto addressOfOp = opBuilder.create<spirv::AddressOfOp>(
50  unknownLoc, varOp.getType(), SymbolRefAttr::get(varOp.getOperation()));
51  return addressOfOp.getPointer();
52  }
53  if (auto constOp = getSpecConstant(id)) {
54  auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
55  unknownLoc, constOp.getDefaultValue().getType(),
56  SymbolRefAttr::get(constOp.getOperation()));
57  return referenceOfOp.getReference();
58  }
59  if (auto constCompositeOp = getSpecConstantComposite(id)) {
60  auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
61  unknownLoc, constCompositeOp.getType(),
62  SymbolRefAttr::get(constCompositeOp.getOperation()));
63  return referenceOfOp.getReference();
64  }
65  if (auto specConstOperationInfo = getSpecConstantOperation(id)) {
66  return materializeSpecConstantOperation(
67  id, specConstOperationInfo->enclodesOpcode,
68  specConstOperationInfo->resultTypeID,
69  specConstOperationInfo->enclosedOpOperands);
70  }
71  if (auto undef = getUndefType(id)) {
72  return opBuilder.create<spirv::UndefOp>(unknownLoc, undef);
73  }
74  return valueMap.lookup(id);
75 }
76 
77 LogicalResult spirv::Deserializer::sliceInstruction(
78  spirv::Opcode &opcode, ArrayRef<uint32_t> &operands,
79  std::optional<spirv::Opcode> expectedOpcode) {
80  auto binarySize = binary.size();
81  if (curOffset >= binarySize) {
82  return emitError(unknownLoc, "expected ")
83  << (expectedOpcode ? spirv::stringifyOpcode(*expectedOpcode)
84  : "more")
85  << " instruction";
86  }
87 
88  // For each instruction, get its word count from the first word to slice it
89  // from the stream properly, and then dispatch to the instruction handler.
90 
91  uint32_t wordCount = binary[curOffset] >> 16;
92 
93  if (wordCount == 0)
94  return emitError(unknownLoc, "word count cannot be zero");
95 
96  uint32_t nextOffset = curOffset + wordCount;
97  if (nextOffset > binarySize)
98  return emitError(unknownLoc, "insufficient words for the last instruction");
99 
100  opcode = extractOpcode(binary[curOffset]);
101  operands = binary.slice(curOffset + 1, wordCount - 1);
102  curOffset = nextOffset;
103  return success();
104 }
105 
106 LogicalResult spirv::Deserializer::processInstruction(
107  spirv::Opcode opcode, ArrayRef<uint32_t> operands, bool deferInstructions) {
108  LLVM_DEBUG(logger.startLine() << "[inst] processing instruction "
109  << spirv::stringifyOpcode(opcode) << "\n");
110 
111  // First dispatch all the instructions whose opcode does not correspond to
112  // those that have a direct mirror in the SPIR-V dialect
113  switch (opcode) {
114  case spirv::Opcode::OpCapability:
115  return processCapability(operands);
116  case spirv::Opcode::OpExtension:
117  return processExtension(operands);
118  case spirv::Opcode::OpExtInst:
119  return processExtInst(operands);
120  case spirv::Opcode::OpExtInstImport:
121  return processExtInstImport(operands);
122  case spirv::Opcode::OpMemberName:
123  return processMemberName(operands);
124  case spirv::Opcode::OpMemoryModel:
125  return processMemoryModel(operands);
126  case spirv::Opcode::OpEntryPoint:
127  case spirv::Opcode::OpExecutionMode:
128  if (deferInstructions) {
129  deferredInstructions.emplace_back(opcode, operands);
130  return success();
131  }
132  break;
133  case spirv::Opcode::OpVariable:
134  if (isa<spirv::ModuleOp>(opBuilder.getBlock()->getParentOp())) {
135  return processGlobalVariable(operands);
136  }
137  break;
138  case spirv::Opcode::OpLine:
139  return processDebugLine(operands);
140  case spirv::Opcode::OpNoLine:
141  clearDebugLine();
142  return success();
143  case spirv::Opcode::OpName:
144  return processName(operands);
145  case spirv::Opcode::OpString:
146  return processDebugString(operands);
147  case spirv::Opcode::OpModuleProcessed:
148  case spirv::Opcode::OpSource:
149  case spirv::Opcode::OpSourceContinued:
150  case spirv::Opcode::OpSourceExtension:
151  // TODO: This is debug information embedded in the binary which should be
152  // translated into the spirv.module.
153  return success();
154  case spirv::Opcode::OpTypeVoid:
155  case spirv::Opcode::OpTypeBool:
156  case spirv::Opcode::OpTypeInt:
157  case spirv::Opcode::OpTypeFloat:
158  case spirv::Opcode::OpTypeVector:
159  case spirv::Opcode::OpTypeMatrix:
160  case spirv::Opcode::OpTypeArray:
161  case spirv::Opcode::OpTypeFunction:
162  case spirv::Opcode::OpTypeImage:
163  case spirv::Opcode::OpTypeSampledImage:
164  case spirv::Opcode::OpTypeRuntimeArray:
165  case spirv::Opcode::OpTypeStruct:
166  case spirv::Opcode::OpTypePointer:
167  case spirv::Opcode::OpTypeCooperativeMatrixKHR:
168  case spirv::Opcode::OpTypeCooperativeMatrixNV:
169  return processType(opcode, operands);
170  case spirv::Opcode::OpTypeForwardPointer:
171  return processTypeForwardPointer(operands);
172  case spirv::Opcode::OpTypeJointMatrixINTEL:
173  return processType(opcode, operands);
174  case spirv::Opcode::OpConstant:
175  return processConstant(operands, /*isSpec=*/false);
176  case spirv::Opcode::OpSpecConstant:
177  return processConstant(operands, /*isSpec=*/true);
178  case spirv::Opcode::OpConstantComposite:
179  return processConstantComposite(operands);
180  case spirv::Opcode::OpSpecConstantComposite:
181  return processSpecConstantComposite(operands);
182  case spirv::Opcode::OpSpecConstantOp:
183  return processSpecConstantOperation(operands);
184  case spirv::Opcode::OpConstantTrue:
185  return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/false);
186  case spirv::Opcode::OpSpecConstantTrue:
187  return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/true);
188  case spirv::Opcode::OpConstantFalse:
189  return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/false);
190  case spirv::Opcode::OpSpecConstantFalse:
191  return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/true);
192  case spirv::Opcode::OpConstantNull:
193  return processConstantNull(operands);
194  case spirv::Opcode::OpDecorate:
195  return processDecoration(operands);
196  case spirv::Opcode::OpMemberDecorate:
197  return processMemberDecoration(operands);
198  case spirv::Opcode::OpFunction:
199  return processFunction(operands);
200  case spirv::Opcode::OpLabel:
201  return processLabel(operands);
202  case spirv::Opcode::OpBranch:
203  return processBranch(operands);
204  case spirv::Opcode::OpBranchConditional:
205  return processBranchConditional(operands);
206  case spirv::Opcode::OpSelectionMerge:
207  return processSelectionMerge(operands);
208  case spirv::Opcode::OpLoopMerge:
209  return processLoopMerge(operands);
210  case spirv::Opcode::OpPhi:
211  return processPhi(operands);
212  case spirv::Opcode::OpUndef:
213  return processUndef(operands);
214  default:
215  break;
216  }
217  return dispatchToAutogenDeserialization(opcode, operands);
218 }
219 
220 LogicalResult spirv::Deserializer::processOpWithoutGrammarAttr(
221  ArrayRef<uint32_t> words, StringRef opName, bool hasResult,
222  unsigned numOperands) {
223  SmallVector<Type, 1> resultTypes;
224  uint32_t valueID = 0;
225 
226  size_t wordIndex = 0;
227  if (hasResult) {
228  if (wordIndex >= words.size())
229  return emitError(unknownLoc,
230  "expected result type <id> while deserializing for ")
231  << opName;
232 
233  // Decode the type <id>
234  auto type = getType(words[wordIndex]);
235  if (!type)
236  return emitError(unknownLoc, "unknown type result <id>: ")
237  << words[wordIndex];
238  resultTypes.push_back(type);
239  ++wordIndex;
240 
241  // Decode the result <id>
242  if (wordIndex >= words.size())
243  return emitError(unknownLoc,
244  "expected result <id> while deserializing for ")
245  << opName;
246  valueID = words[wordIndex];
247  ++wordIndex;
248  }
249 
250  SmallVector<Value, 4> operands;
252 
253  // Decode operands
254  size_t operandIndex = 0;
255  for (; operandIndex < numOperands && wordIndex < words.size();
256  ++operandIndex, ++wordIndex) {
257  auto arg = getValue(words[wordIndex]);
258  if (!arg)
259  return emitError(unknownLoc, "unknown result <id>: ") << words[wordIndex];
260  operands.push_back(arg);
261  }
262  if (operandIndex != numOperands) {
263  return emitError(
264  unknownLoc,
265  "found less operands than expected when deserializing for ")
266  << opName << "; only " << operandIndex << " of " << numOperands
267  << " processed";
268  }
269  if (wordIndex != words.size()) {
270  return emitError(
271  unknownLoc,
272  "found more operands than expected when deserializing for ")
273  << opName << "; only " << wordIndex << " of " << words.size()
274  << " processed";
275  }
276 
277  // Attach attributes from decorations
278  if (decorations.count(valueID)) {
279  auto attrs = decorations[valueID].getAttrs();
280  attributes.append(attrs.begin(), attrs.end());
281  }
282 
283  // Create the op and update bookkeeping maps
284  Location loc = createFileLineColLoc(opBuilder);
285  OperationState opState(loc, opName);
286  opState.addOperands(operands);
287  if (hasResult)
288  opState.addTypes(resultTypes);
289  opState.addAttributes(attributes);
290  Operation *op = opBuilder.create(opState);
291  if (hasResult)
292  valueMap[valueID] = op->getResult(0);
293 
294  if (op->hasTrait<OpTrait::IsTerminator>())
295  clearDebugLine();
296 
297  return success();
298 }
299 
300 LogicalResult spirv::Deserializer::processUndef(ArrayRef<uint32_t> operands) {
301  if (operands.size() != 2) {
302  return emitError(unknownLoc, "OpUndef instruction must have two operands");
303  }
304  auto type = getType(operands[0]);
305  if (!type) {
306  return emitError(unknownLoc, "unknown type <id> with OpUndef instruction");
307  }
308  undefMap[operands[1]] = type;
309  return success();
310 }
311 
312 LogicalResult spirv::Deserializer::processExtInst(ArrayRef<uint32_t> operands) {
313  if (operands.size() < 4) {
314  return emitError(unknownLoc,
315  "OpExtInst must have at least 4 operands, result type "
316  "<id>, result <id>, set <id> and instruction opcode");
317  }
318  if (!extendedInstSets.count(operands[2])) {
319  return emitError(unknownLoc, "undefined set <id> in OpExtInst");
320  }
321  SmallVector<uint32_t, 4> slicedOperands;
322  slicedOperands.append(operands.begin(), std::next(operands.begin(), 2));
323  slicedOperands.append(std::next(operands.begin(), 4), operands.end());
324  return dispatchToExtensionSetAutogenDeserialization(
325  extendedInstSets[operands[2]], operands[3], slicedOperands);
326 }
327 
328 namespace mlir {
329 namespace spirv {
330 
331 template <>
333 Deserializer::processOp<spirv::EntryPointOp>(ArrayRef<uint32_t> words) {
334  unsigned wordIndex = 0;
335  if (wordIndex >= words.size()) {
336  return emitError(unknownLoc,
337  "missing Execution Model specification in OpEntryPoint");
338  }
339  auto execModel = spirv::ExecutionModelAttr::get(
340  context, static_cast<spirv::ExecutionModel>(words[wordIndex++]));
341  if (wordIndex >= words.size()) {
342  return emitError(unknownLoc, "missing <id> in OpEntryPoint");
343  }
344  // Get the function <id>
345  auto fnID = words[wordIndex++];
346  // Get the function name
347  auto fnName = decodeStringLiteral(words, wordIndex);
348  // Verify that the function <id> matches the fnName
349  auto parsedFunc = getFunction(fnID);
350  if (!parsedFunc) {
351  return emitError(unknownLoc, "no function matching <id> ") << fnID;
352  }
353  if (parsedFunc.getName() != fnName) {
354  // The deserializer uses "spirv_fn_<id>" as the function name if the input
355  // SPIR-V blob does not contain a name for it. We should use a more clear
356  // indication for such case rather than relying on naming details.
357  if (!parsedFunc.getName().startswith("spirv_fn_"))
358  return emitError(unknownLoc,
359  "function name mismatch between OpEntryPoint "
360  "and OpFunction with <id> ")
361  << fnID << ": " << fnName << " vs. " << parsedFunc.getName();
362  parsedFunc.setName(fnName);
363  }
364  SmallVector<Attribute, 4> interface;
365  while (wordIndex < words.size()) {
366  auto arg = getGlobalVariable(words[wordIndex]);
367  if (!arg) {
368  return emitError(unknownLoc, "undefined result <id> ")
369  << words[wordIndex] << " while decoding OpEntryPoint";
370  }
371  interface.push_back(SymbolRefAttr::get(arg.getOperation()));
372  wordIndex++;
373  }
374  opBuilder.create<spirv::EntryPointOp>(
375  unknownLoc, execModel, SymbolRefAttr::get(opBuilder.getContext(), fnName),
376  opBuilder.getArrayAttr(interface));
377  return success();
378 }
379 
380 template <>
382 Deserializer::processOp<spirv::ExecutionModeOp>(ArrayRef<uint32_t> words) {
383  unsigned wordIndex = 0;
384  if (wordIndex >= words.size()) {
385  return emitError(unknownLoc,
386  "missing function result <id> in OpExecutionMode");
387  }
388  // Get the function <id> to get the name of the function
389  auto fnID = words[wordIndex++];
390  auto fn = getFunction(fnID);
391  if (!fn) {
392  return emitError(unknownLoc, "no function matching <id> ") << fnID;
393  }
394  // Get the Execution mode
395  if (wordIndex >= words.size()) {
396  return emitError(unknownLoc, "missing Execution Mode in OpExecutionMode");
397  }
398  auto execMode = spirv::ExecutionModeAttr::get(
399  context, static_cast<spirv::ExecutionMode>(words[wordIndex++]));
400 
401  // Get the values
402  SmallVector<Attribute, 4> attrListElems;
403  while (wordIndex < words.size()) {
404  attrListElems.push_back(opBuilder.getI32IntegerAttr(words[wordIndex++]));
405  }
406  auto values = opBuilder.getArrayAttr(attrListElems);
407  opBuilder.create<spirv::ExecutionModeOp>(
408  unknownLoc, SymbolRefAttr::get(opBuilder.getContext(), fn.getName()),
409  execMode, values);
410  return success();
411 }
412 
413 template <>
415 Deserializer::processOp<spirv::FunctionCallOp>(ArrayRef<uint32_t> operands) {
416  if (operands.size() < 3) {
417  return emitError(unknownLoc,
418  "OpFunctionCall must have at least 3 operands");
419  }
420 
421  Type resultType = getType(operands[0]);
422  if (!resultType) {
423  return emitError(unknownLoc, "undefined result type from <id> ")
424  << operands[0];
425  }
426 
427  // Use null type to mean no result type.
428  if (isVoidType(resultType))
429  resultType = nullptr;
430 
431  auto resultID = operands[1];
432  auto functionID = operands[2];
433 
434  auto functionName = getFunctionSymbol(functionID);
435 
436  SmallVector<Value, 4> arguments;
437  for (auto operand : llvm::drop_begin(operands, 3)) {
438  auto value = getValue(operand);
439  if (!value) {
440  return emitError(unknownLoc, "unknown <id> ")
441  << operand << " used by OpFunctionCall";
442  }
443  arguments.push_back(value);
444  }
445 
446  auto opFunctionCall = opBuilder.create<spirv::FunctionCallOp>(
447  unknownLoc, resultType,
448  SymbolRefAttr::get(opBuilder.getContext(), functionName), arguments);
449 
450  if (resultType)
451  valueMap[resultID] = opFunctionCall.getResult(0);
452  return success();
453 }
454 
455 template <>
457 Deserializer::processOp<spirv::CopyMemoryOp>(ArrayRef<uint32_t> words) {
458  SmallVector<Type, 1> resultTypes;
459  size_t wordIndex = 0;
460  SmallVector<Value, 4> operands;
462 
463  if (wordIndex < words.size()) {
464  auto arg = getValue(words[wordIndex]);
465 
466  if (!arg) {
467  return emitError(unknownLoc, "unknown result <id> : ")
468  << words[wordIndex];
469  }
470 
471  operands.push_back(arg);
472  wordIndex++;
473  }
474 
475  if (wordIndex < words.size()) {
476  auto arg = getValue(words[wordIndex]);
477 
478  if (!arg) {
479  return emitError(unknownLoc, "unknown result <id> : ")
480  << words[wordIndex];
481  }
482 
483  operands.push_back(arg);
484  wordIndex++;
485  }
486 
487  bool isAlignedAttr = false;
488 
489  if (wordIndex < words.size()) {
490  auto attrValue = words[wordIndex++];
491  auto attr = opBuilder.getAttr<spirv::MemoryAccessAttr>(
492  static_cast<spirv::MemoryAccess>(attrValue));
493  attributes.push_back(opBuilder.getNamedAttr("memory_access", attr));
494  isAlignedAttr = (attrValue == 2);
495  }
496 
497  if (isAlignedAttr && wordIndex < words.size()) {
498  attributes.push_back(opBuilder.getNamedAttr(
499  "alignment", opBuilder.getI32IntegerAttr(words[wordIndex++])));
500  }
501 
502  if (wordIndex < words.size()) {
503  auto attrValue = words[wordIndex++];
504  auto attr = opBuilder.getAttr<spirv::MemoryAccessAttr>(
505  static_cast<spirv::MemoryAccess>(attrValue));
506  attributes.push_back(opBuilder.getNamedAttr("source_memory_access", attr));
507  }
508 
509  if (wordIndex < words.size()) {
510  attributes.push_back(opBuilder.getNamedAttr(
511  "source_alignment", opBuilder.getI32IntegerAttr(words[wordIndex++])));
512  }
513 
514  if (wordIndex != words.size()) {
515  return emitError(unknownLoc,
516  "found more operands than expected when deserializing "
517  "spirv::CopyMemoryOp, only ")
518  << wordIndex << " of " << words.size() << " processed";
519  }
520 
521  Location loc = createFileLineColLoc(opBuilder);
522  opBuilder.create<spirv::CopyMemoryOp>(loc, resultTypes, operands, attributes);
523 
524  return success();
525 }
526 
527 template <>
528 LogicalResult Deserializer::processOp<spirv::GenericCastToPtrExplicitOp>(
529  ArrayRef<uint32_t> words) {
530  if (words.size() != 4) {
531  return emitError(unknownLoc,
532  "expected 4 words in GenericCastToPtrExplicitOp"
533  " but got : ")
534  << words.size();
535  }
536  SmallVector<Type, 1> resultTypes;
537  SmallVector<Value, 4> operands;
538  uint32_t valueID = 0;
539  auto type = getType(words[0]);
540 
541  if (!type)
542  return emitError(unknownLoc, "unknown type result <id> : ") << words[0];
543  resultTypes.push_back(type);
544 
545  valueID = words[1];
546 
547  auto arg = getValue(words[2]);
548  if (!arg)
549  return emitError(unknownLoc, "unknown result <id> : ") << words[2];
550  operands.push_back(arg);
551 
552  Location loc = createFileLineColLoc(opBuilder);
553  Operation *op = opBuilder.create<spirv::GenericCastToPtrExplicitOp>(
554  loc, resultTypes, operands);
555  valueMap[valueID] = op->getResult(0);
556  return success();
557 }
558 
559 // Pull in auto-generated Deserializer::dispatchToAutogenDeserialization() and
560 // various Deserializer::processOp<...>() specializations.
561 #define GET_DESERIALIZATION_FNS
562 #include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc"
563 
564 } // namespace spirv
565 } // namespace mlir
static spirv::Opcode extractOpcode(uint32_t word)
Extracts the opcode from the given first word of a SPIR-V instruction.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:63
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
This class provides the API for ops that are known to be terminators.
Definition: OpDefinition.h:762
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:728
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Definition: Operation.cpp:66
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
StringRef decodeStringLiteral(ArrayRef< uint32_t > words, unsigned &wordIndex)
Decodes a string literal in words starting at wordIndex.
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
This represents an operation in an abstracted form, suitable for use with the builder APIs.