MLIR  20.0.0git
DeserializeOps.cpp
Go to the documentation of this file.
1 //===- DeserializeOps.cpp - MLIR SPIR-V Deserialization (Ops) -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the Deserializer methods for SPIR-V binary instructions.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Deserializer.h"
14 
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/Location.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/Support/Debug.h"
23 #include <optional>
24 
25 using namespace mlir;
26 
27 #define DEBUG_TYPE "spirv-deserialization"
28 
29 //===----------------------------------------------------------------------===//
30 // Utility Functions
31 //===----------------------------------------------------------------------===//
32 
33 /// Extracts the opcode from the given first word of a SPIR-V instruction.
34 static inline spirv::Opcode extractOpcode(uint32_t word) {
35  return static_cast<spirv::Opcode>(word & 0xffff);
36 }
37 
38 //===----------------------------------------------------------------------===//
39 // Instruction
40 //===----------------------------------------------------------------------===//
41 
42 Value spirv::Deserializer::getValue(uint32_t id) {
43  if (auto constInfo = getConstant(id)) {
44  // Materialize a `spirv.Constant` op at every use site.
45  return opBuilder.create<spirv::ConstantOp>(unknownLoc, constInfo->second,
46  constInfo->first);
47  }
48  if (auto varOp = getGlobalVariable(id)) {
49  auto addressOfOp = opBuilder.create<spirv::AddressOfOp>(
50  unknownLoc, varOp.getType(), SymbolRefAttr::get(varOp.getOperation()));
51  return addressOfOp.getPointer();
52  }
53  if (auto constOp = getSpecConstant(id)) {
54  auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
55  unknownLoc, constOp.getDefaultValue().getType(),
56  SymbolRefAttr::get(constOp.getOperation()));
57  return referenceOfOp.getReference();
58  }
59  if (auto constCompositeOp = getSpecConstantComposite(id)) {
60  auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
61  unknownLoc, constCompositeOp.getType(),
62  SymbolRefAttr::get(constCompositeOp.getOperation()));
63  return referenceOfOp.getReference();
64  }
65  if (auto specConstOperationInfo = getSpecConstantOperation(id)) {
66  return materializeSpecConstantOperation(
67  id, specConstOperationInfo->enclodesOpcode,
68  specConstOperationInfo->resultTypeID,
69  specConstOperationInfo->enclosedOpOperands);
70  }
71  if (auto undef = getUndefType(id)) {
72  return opBuilder.create<spirv::UndefOp>(unknownLoc, undef);
73  }
74  return valueMap.lookup(id);
75 }
76 
77 LogicalResult spirv::Deserializer::sliceInstruction(
78  spirv::Opcode &opcode, ArrayRef<uint32_t> &operands,
79  std::optional<spirv::Opcode> expectedOpcode) {
80  auto binarySize = binary.size();
81  if (curOffset >= binarySize) {
82  return emitError(unknownLoc, "expected ")
83  << (expectedOpcode ? spirv::stringifyOpcode(*expectedOpcode)
84  : "more")
85  << " instruction";
86  }
87 
88  // For each instruction, get its word count from the first word to slice it
89  // from the stream properly, and then dispatch to the instruction handler.
90 
91  uint32_t wordCount = binary[curOffset] >> 16;
92 
93  if (wordCount == 0)
94  return emitError(unknownLoc, "word count cannot be zero");
95 
96  uint32_t nextOffset = curOffset + wordCount;
97  if (nextOffset > binarySize)
98  return emitError(unknownLoc, "insufficient words for the last instruction");
99 
100  opcode = extractOpcode(binary[curOffset]);
101  operands = binary.slice(curOffset + 1, wordCount - 1);
102  curOffset = nextOffset;
103  return success();
104 }
105 
106 LogicalResult spirv::Deserializer::processInstruction(
107  spirv::Opcode opcode, ArrayRef<uint32_t> operands, bool deferInstructions) {
108  LLVM_DEBUG(logger.startLine() << "[inst] processing instruction "
109  << spirv::stringifyOpcode(opcode) << "\n");
110 
111  // First dispatch all the instructions whose opcode does not correspond to
112  // those that have a direct mirror in the SPIR-V dialect
113  switch (opcode) {
114  case spirv::Opcode::OpCapability:
115  return processCapability(operands);
116  case spirv::Opcode::OpExtension:
117  return processExtension(operands);
118  case spirv::Opcode::OpExtInst:
119  return processExtInst(operands);
120  case spirv::Opcode::OpExtInstImport:
121  return processExtInstImport(operands);
122  case spirv::Opcode::OpMemberName:
123  return processMemberName(operands);
124  case spirv::Opcode::OpMemoryModel:
125  return processMemoryModel(operands);
126  case spirv::Opcode::OpEntryPoint:
127  case spirv::Opcode::OpExecutionMode:
128  if (deferInstructions) {
129  deferredInstructions.emplace_back(opcode, operands);
130  return success();
131  }
132  break;
133  case spirv::Opcode::OpVariable:
134  if (isa<spirv::ModuleOp>(opBuilder.getBlock()->getParentOp())) {
135  return processGlobalVariable(operands);
136  }
137  break;
138  case spirv::Opcode::OpLine:
139  return processDebugLine(operands);
140  case spirv::Opcode::OpNoLine:
141  clearDebugLine();
142  return success();
143  case spirv::Opcode::OpName:
144  return processName(operands);
145  case spirv::Opcode::OpString:
146  return processDebugString(operands);
147  case spirv::Opcode::OpModuleProcessed:
148  case spirv::Opcode::OpSource:
149  case spirv::Opcode::OpSourceContinued:
150  case spirv::Opcode::OpSourceExtension:
151  // TODO: This is debug information embedded in the binary which should be
152  // translated into the spirv.module.
153  return success();
154  case spirv::Opcode::OpTypeVoid:
155  case spirv::Opcode::OpTypeBool:
156  case spirv::Opcode::OpTypeInt:
157  case spirv::Opcode::OpTypeFloat:
158  case spirv::Opcode::OpTypeVector:
159  case spirv::Opcode::OpTypeMatrix:
160  case spirv::Opcode::OpTypeArray:
161  case spirv::Opcode::OpTypeFunction:
162  case spirv::Opcode::OpTypeImage:
163  case spirv::Opcode::OpTypeSampledImage:
164  case spirv::Opcode::OpTypeRuntimeArray:
165  case spirv::Opcode::OpTypeStruct:
166  case spirv::Opcode::OpTypePointer:
167  case spirv::Opcode::OpTypeCooperativeMatrixKHR:
168  return processType(opcode, operands);
169  case spirv::Opcode::OpTypeForwardPointer:
170  return processTypeForwardPointer(operands);
171  case spirv::Opcode::OpConstant:
172  return processConstant(operands, /*isSpec=*/false);
173  case spirv::Opcode::OpSpecConstant:
174  return processConstant(operands, /*isSpec=*/true);
175  case spirv::Opcode::OpConstantComposite:
176  return processConstantComposite(operands);
177  case spirv::Opcode::OpSpecConstantComposite:
178  return processSpecConstantComposite(operands);
179  case spirv::Opcode::OpSpecConstantOp:
180  return processSpecConstantOperation(operands);
181  case spirv::Opcode::OpConstantTrue:
182  return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/false);
183  case spirv::Opcode::OpSpecConstantTrue:
184  return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/true);
185  case spirv::Opcode::OpConstantFalse:
186  return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/false);
187  case spirv::Opcode::OpSpecConstantFalse:
188  return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/true);
189  case spirv::Opcode::OpConstantNull:
190  return processConstantNull(operands);
191  case spirv::Opcode::OpDecorate:
192  return processDecoration(operands);
193  case spirv::Opcode::OpMemberDecorate:
194  return processMemberDecoration(operands);
195  case spirv::Opcode::OpFunction:
196  return processFunction(operands);
197  case spirv::Opcode::OpLabel:
198  return processLabel(operands);
199  case spirv::Opcode::OpBranch:
200  return processBranch(operands);
201  case spirv::Opcode::OpBranchConditional:
202  return processBranchConditional(operands);
203  case spirv::Opcode::OpSelectionMerge:
204  return processSelectionMerge(operands);
205  case spirv::Opcode::OpLoopMerge:
206  return processLoopMerge(operands);
207  case spirv::Opcode::OpPhi:
208  return processPhi(operands);
209  case spirv::Opcode::OpUndef:
210  return processUndef(operands);
211  default:
212  break;
213  }
214  return dispatchToAutogenDeserialization(opcode, operands);
215 }
216 
217 LogicalResult spirv::Deserializer::processOpWithoutGrammarAttr(
218  ArrayRef<uint32_t> words, StringRef opName, bool hasResult,
219  unsigned numOperands) {
220  SmallVector<Type, 1> resultTypes;
221  uint32_t valueID = 0;
222 
223  size_t wordIndex = 0;
224  if (hasResult) {
225  if (wordIndex >= words.size())
226  return emitError(unknownLoc,
227  "expected result type <id> while deserializing for ")
228  << opName;
229 
230  // Decode the type <id>
231  auto type = getType(words[wordIndex]);
232  if (!type)
233  return emitError(unknownLoc, "unknown type result <id>: ")
234  << words[wordIndex];
235  resultTypes.push_back(type);
236  ++wordIndex;
237 
238  // Decode the result <id>
239  if (wordIndex >= words.size())
240  return emitError(unknownLoc,
241  "expected result <id> while deserializing for ")
242  << opName;
243  valueID = words[wordIndex];
244  ++wordIndex;
245  }
246 
247  SmallVector<Value, 4> operands;
249 
250  // Decode operands
251  size_t operandIndex = 0;
252  for (; operandIndex < numOperands && wordIndex < words.size();
253  ++operandIndex, ++wordIndex) {
254  auto arg = getValue(words[wordIndex]);
255  if (!arg)
256  return emitError(unknownLoc, "unknown result <id>: ") << words[wordIndex];
257  operands.push_back(arg);
258  }
259  if (operandIndex != numOperands) {
260  return emitError(
261  unknownLoc,
262  "found less operands than expected when deserializing for ")
263  << opName << "; only " << operandIndex << " of " << numOperands
264  << " processed";
265  }
266  if (wordIndex != words.size()) {
267  return emitError(
268  unknownLoc,
269  "found more operands than expected when deserializing for ")
270  << opName << "; only " << wordIndex << " of " << words.size()
271  << " processed";
272  }
273 
274  // Attach attributes from decorations
275  if (decorations.count(valueID)) {
276  auto attrs = decorations[valueID].getAttrs();
277  attributes.append(attrs.begin(), attrs.end());
278  }
279 
280  // Create the op and update bookkeeping maps
281  Location loc = createFileLineColLoc(opBuilder);
282  OperationState opState(loc, opName);
283  opState.addOperands(operands);
284  if (hasResult)
285  opState.addTypes(resultTypes);
286  opState.addAttributes(attributes);
287  Operation *op = opBuilder.create(opState);
288  if (hasResult)
289  valueMap[valueID] = op->getResult(0);
290 
291  if (op->hasTrait<OpTrait::IsTerminator>())
292  clearDebugLine();
293 
294  return success();
295 }
296 
297 LogicalResult spirv::Deserializer::processUndef(ArrayRef<uint32_t> operands) {
298  if (operands.size() != 2) {
299  return emitError(unknownLoc, "OpUndef instruction must have two operands");
300  }
301  auto type = getType(operands[0]);
302  if (!type) {
303  return emitError(unknownLoc, "unknown type <id> with OpUndef instruction");
304  }
305  undefMap[operands[1]] = type;
306  return success();
307 }
308 
309 LogicalResult spirv::Deserializer::processExtInst(ArrayRef<uint32_t> operands) {
310  if (operands.size() < 4) {
311  return emitError(unknownLoc,
312  "OpExtInst must have at least 4 operands, result type "
313  "<id>, result <id>, set <id> and instruction opcode");
314  }
315  if (!extendedInstSets.count(operands[2])) {
316  return emitError(unknownLoc, "undefined set <id> in OpExtInst");
317  }
318  SmallVector<uint32_t, 4> slicedOperands;
319  slicedOperands.append(operands.begin(), std::next(operands.begin(), 2));
320  slicedOperands.append(std::next(operands.begin(), 4), operands.end());
321  return dispatchToExtensionSetAutogenDeserialization(
322  extendedInstSets[operands[2]], operands[3], slicedOperands);
323 }
324 
325 namespace mlir {
326 namespace spirv {
327 
328 template <>
329 LogicalResult
330 Deserializer::processOp<spirv::EntryPointOp>(ArrayRef<uint32_t> words) {
331  unsigned wordIndex = 0;
332  if (wordIndex >= words.size()) {
333  return emitError(unknownLoc,
334  "missing Execution Model specification in OpEntryPoint");
335  }
336  auto execModel = spirv::ExecutionModelAttr::get(
337  context, static_cast<spirv::ExecutionModel>(words[wordIndex++]));
338  if (wordIndex >= words.size()) {
339  return emitError(unknownLoc, "missing <id> in OpEntryPoint");
340  }
341  // Get the function <id>
342  auto fnID = words[wordIndex++];
343  // Get the function name
344  auto fnName = decodeStringLiteral(words, wordIndex);
345  // Verify that the function <id> matches the fnName
346  auto parsedFunc = getFunction(fnID);
347  if (!parsedFunc) {
348  return emitError(unknownLoc, "no function matching <id> ") << fnID;
349  }
350  if (parsedFunc.getName() != fnName) {
351  // The deserializer uses "spirv_fn_<id>" as the function name if the input
352  // SPIR-V blob does not contain a name for it. We should use a more clear
353  // indication for such case rather than relying on naming details.
354  if (!parsedFunc.getName().starts_with("spirv_fn_"))
355  return emitError(unknownLoc,
356  "function name mismatch between OpEntryPoint "
357  "and OpFunction with <id> ")
358  << fnID << ": " << fnName << " vs. " << parsedFunc.getName();
359  parsedFunc.setName(fnName);
360  }
361  SmallVector<Attribute, 4> interface;
362  while (wordIndex < words.size()) {
363  auto arg = getGlobalVariable(words[wordIndex]);
364  if (!arg) {
365  return emitError(unknownLoc, "undefined result <id> ")
366  << words[wordIndex] << " while decoding OpEntryPoint";
367  }
368  interface.push_back(SymbolRefAttr::get(arg.getOperation()));
369  wordIndex++;
370  }
371  opBuilder.create<spirv::EntryPointOp>(
372  unknownLoc, execModel, SymbolRefAttr::get(opBuilder.getContext(), fnName),
373  opBuilder.getArrayAttr(interface));
374  return success();
375 }
376 
377 template <>
378 LogicalResult
379 Deserializer::processOp<spirv::ExecutionModeOp>(ArrayRef<uint32_t> words) {
380  unsigned wordIndex = 0;
381  if (wordIndex >= words.size()) {
382  return emitError(unknownLoc,
383  "missing function result <id> in OpExecutionMode");
384  }
385  // Get the function <id> to get the name of the function
386  auto fnID = words[wordIndex++];
387  auto fn = getFunction(fnID);
388  if (!fn) {
389  return emitError(unknownLoc, "no function matching <id> ") << fnID;
390  }
391  // Get the Execution mode
392  if (wordIndex >= words.size()) {
393  return emitError(unknownLoc, "missing Execution Mode in OpExecutionMode");
394  }
395  auto execMode = spirv::ExecutionModeAttr::get(
396  context, static_cast<spirv::ExecutionMode>(words[wordIndex++]));
397 
398  // Get the values
399  SmallVector<Attribute, 4> attrListElems;
400  while (wordIndex < words.size()) {
401  attrListElems.push_back(opBuilder.getI32IntegerAttr(words[wordIndex++]));
402  }
403  auto values = opBuilder.getArrayAttr(attrListElems);
404  opBuilder.create<spirv::ExecutionModeOp>(
405  unknownLoc, SymbolRefAttr::get(opBuilder.getContext(), fn.getName()),
406  execMode, values);
407  return success();
408 }
409 
410 template <>
411 LogicalResult
412 Deserializer::processOp<spirv::FunctionCallOp>(ArrayRef<uint32_t> operands) {
413  if (operands.size() < 3) {
414  return emitError(unknownLoc,
415  "OpFunctionCall must have at least 3 operands");
416  }
417 
418  Type resultType = getType(operands[0]);
419  if (!resultType) {
420  return emitError(unknownLoc, "undefined result type from <id> ")
421  << operands[0];
422  }
423 
424  // Use null type to mean no result type.
425  if (isVoidType(resultType))
426  resultType = nullptr;
427 
428  auto resultID = operands[1];
429  auto functionID = operands[2];
430 
431  auto functionName = getFunctionSymbol(functionID);
432 
433  SmallVector<Value, 4> arguments;
434  for (auto operand : llvm::drop_begin(operands, 3)) {
435  auto value = getValue(operand);
436  if (!value) {
437  return emitError(unknownLoc, "unknown <id> ")
438  << operand << " used by OpFunctionCall";
439  }
440  arguments.push_back(value);
441  }
442 
443  auto opFunctionCall = opBuilder.create<spirv::FunctionCallOp>(
444  unknownLoc, resultType,
445  SymbolRefAttr::get(opBuilder.getContext(), functionName), arguments);
446 
447  if (resultType)
448  valueMap[resultID] = opFunctionCall.getResult(0);
449  return success();
450 }
451 
452 template <>
453 LogicalResult
454 Deserializer::processOp<spirv::CopyMemoryOp>(ArrayRef<uint32_t> words) {
455  SmallVector<Type, 1> resultTypes;
456  size_t wordIndex = 0;
457  SmallVector<Value, 4> operands;
459 
460  if (wordIndex < words.size()) {
461  auto arg = getValue(words[wordIndex]);
462 
463  if (!arg) {
464  return emitError(unknownLoc, "unknown result <id> : ")
465  << words[wordIndex];
466  }
467 
468  operands.push_back(arg);
469  wordIndex++;
470  }
471 
472  if (wordIndex < words.size()) {
473  auto arg = getValue(words[wordIndex]);
474 
475  if (!arg) {
476  return emitError(unknownLoc, "unknown result <id> : ")
477  << words[wordIndex];
478  }
479 
480  operands.push_back(arg);
481  wordIndex++;
482  }
483 
484  bool isAlignedAttr = false;
485 
486  if (wordIndex < words.size()) {
487  auto attrValue = words[wordIndex++];
488  auto attr = opBuilder.getAttr<spirv::MemoryAccessAttr>(
489  static_cast<spirv::MemoryAccess>(attrValue));
490  attributes.push_back(
491  opBuilder.getNamedAttr(attributeName<MemoryAccess>(), attr));
492  isAlignedAttr = (attrValue == 2);
493  }
494 
495  if (isAlignedAttr && wordIndex < words.size()) {
496  attributes.push_back(opBuilder.getNamedAttr(
497  "alignment", opBuilder.getI32IntegerAttr(words[wordIndex++])));
498  }
499 
500  if (wordIndex < words.size()) {
501  auto attrValue = words[wordIndex++];
502  auto attr = opBuilder.getAttr<spirv::MemoryAccessAttr>(
503  static_cast<spirv::MemoryAccess>(attrValue));
504  attributes.push_back(opBuilder.getNamedAttr("source_memory_access", attr));
505  }
506 
507  if (wordIndex < words.size()) {
508  attributes.push_back(opBuilder.getNamedAttr(
509  "source_alignment", opBuilder.getI32IntegerAttr(words[wordIndex++])));
510  }
511 
512  if (wordIndex != words.size()) {
513  return emitError(unknownLoc,
514  "found more operands than expected when deserializing "
515  "spirv::CopyMemoryOp, only ")
516  << wordIndex << " of " << words.size() << " processed";
517  }
518 
519  Location loc = createFileLineColLoc(opBuilder);
520  opBuilder.create<spirv::CopyMemoryOp>(loc, resultTypes, operands, attributes);
521 
522  return success();
523 }
524 
525 template <>
526 LogicalResult Deserializer::processOp<spirv::GenericCastToPtrExplicitOp>(
527  ArrayRef<uint32_t> words) {
528  if (words.size() != 4) {
529  return emitError(unknownLoc,
530  "expected 4 words in GenericCastToPtrExplicitOp"
531  " but got : ")
532  << words.size();
533  }
534  SmallVector<Type, 1> resultTypes;
535  SmallVector<Value, 4> operands;
536  uint32_t valueID = 0;
537  auto type = getType(words[0]);
538 
539  if (!type)
540  return emitError(unknownLoc, "unknown type result <id> : ") << words[0];
541  resultTypes.push_back(type);
542 
543  valueID = words[1];
544 
545  auto arg = getValue(words[2]);
546  if (!arg)
547  return emitError(unknownLoc, "unknown result <id> : ") << words[2];
548  operands.push_back(arg);
549 
550  Location loc = createFileLineColLoc(opBuilder);
551  Operation *op = opBuilder.create<spirv::GenericCastToPtrExplicitOp>(
552  loc, resultTypes, operands);
553  valueMap[valueID] = op->getResult(0);
554  return success();
555 }
556 
557 // Pull in auto-generated Deserializer::dispatchToAutogenDeserialization() and
558 // various Deserializer::processOp<...>() specializations.
559 #define GET_DESERIALIZATION_FNS
560 #include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc"
561 
562 } // namespace spirv
563 } // namespace mlir
static spirv::Opcode extractOpcode(uint32_t word)
Extracts the opcode from the given first word of a SPIR-V instruction.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
This class provides the API for ops that are known to be terminators.
Definition: OpDefinition.h:764
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:745
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Definition: Operation.cpp:67
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
StringRef decodeStringLiteral(ArrayRef< uint32_t > words, unsigned &wordIndex)
Decodes a string literal in words starting at wordIndex.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This represents an operation in an abstracted form, suitable for use with the builder APIs.