MLIR  20.0.0git
DeserializeOps.cpp
Go to the documentation of this file.
1 //===- DeserializeOps.cpp - MLIR SPIR-V Deserialization (Ops) -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the Deserializer methods for SPIR-V binary instructions.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Deserializer.h"
14 
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/Location.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/Support/Debug.h"
23 #include <optional>
24 
25 using namespace mlir;
26 
27 #define DEBUG_TYPE "spirv-deserialization"
28 
29 //===----------------------------------------------------------------------===//
30 // Utility Functions
31 //===----------------------------------------------------------------------===//
32 
33 /// Extracts the opcode from the given first word of a SPIR-V instruction.
34 static inline spirv::Opcode extractOpcode(uint32_t word) {
35  return static_cast<spirv::Opcode>(word & 0xffff);
36 }
37 
38 //===----------------------------------------------------------------------===//
39 // Instruction
40 //===----------------------------------------------------------------------===//
41 
42 Value spirv::Deserializer::getValue(uint32_t id) {
43  if (auto constInfo = getConstant(id)) {
44  // Materialize a `spirv.Constant` op at every use site.
45  return opBuilder.create<spirv::ConstantOp>(unknownLoc, constInfo->second,
46  constInfo->first);
47  }
48  if (auto varOp = getGlobalVariable(id)) {
49  auto addressOfOp = opBuilder.create<spirv::AddressOfOp>(
50  unknownLoc, varOp.getType(), SymbolRefAttr::get(varOp.getOperation()));
51  return addressOfOp.getPointer();
52  }
53  if (auto constOp = getSpecConstant(id)) {
54  auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
55  unknownLoc, constOp.getDefaultValue().getType(),
56  SymbolRefAttr::get(constOp.getOperation()));
57  return referenceOfOp.getReference();
58  }
59  if (auto constCompositeOp = getSpecConstantComposite(id)) {
60  auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
61  unknownLoc, constCompositeOp.getType(),
62  SymbolRefAttr::get(constCompositeOp.getOperation()));
63  return referenceOfOp.getReference();
64  }
65  if (auto specConstOperationInfo = getSpecConstantOperation(id)) {
66  return materializeSpecConstantOperation(
67  id, specConstOperationInfo->enclodesOpcode,
68  specConstOperationInfo->resultTypeID,
69  specConstOperationInfo->enclosedOpOperands);
70  }
71  if (auto undef = getUndefType(id)) {
72  return opBuilder.create<spirv::UndefOp>(unknownLoc, undef);
73  }
74  return valueMap.lookup(id);
75 }
76 
77 LogicalResult spirv::Deserializer::sliceInstruction(
78  spirv::Opcode &opcode, ArrayRef<uint32_t> &operands,
79  std::optional<spirv::Opcode> expectedOpcode) {
80  auto binarySize = binary.size();
81  if (curOffset >= binarySize) {
82  return emitError(unknownLoc, "expected ")
83  << (expectedOpcode ? spirv::stringifyOpcode(*expectedOpcode)
84  : "more")
85  << " instruction";
86  }
87 
88  // For each instruction, get its word count from the first word to slice it
89  // from the stream properly, and then dispatch to the instruction handler.
90 
91  uint32_t wordCount = binary[curOffset] >> 16;
92 
93  if (wordCount == 0)
94  return emitError(unknownLoc, "word count cannot be zero");
95 
96  uint32_t nextOffset = curOffset + wordCount;
97  if (nextOffset > binarySize)
98  return emitError(unknownLoc, "insufficient words for the last instruction");
99 
100  opcode = extractOpcode(binary[curOffset]);
101  operands = binary.slice(curOffset + 1, wordCount - 1);
102  curOffset = nextOffset;
103  return success();
104 }
105 
106 LogicalResult spirv::Deserializer::processInstruction(
107  spirv::Opcode opcode, ArrayRef<uint32_t> operands, bool deferInstructions) {
108  LLVM_DEBUG(logger.startLine() << "[inst] processing instruction "
109  << spirv::stringifyOpcode(opcode) << "\n");
110 
111  // First dispatch all the instructions whose opcode does not correspond to
112  // those that have a direct mirror in the SPIR-V dialect
113  switch (opcode) {
114  case spirv::Opcode::OpCapability:
115  return processCapability(operands);
116  case spirv::Opcode::OpExtension:
117  return processExtension(operands);
118  case spirv::Opcode::OpExtInst:
119  return processExtInst(operands);
120  case spirv::Opcode::OpExtInstImport:
121  return processExtInstImport(operands);
122  case spirv::Opcode::OpMemberName:
123  return processMemberName(operands);
124  case spirv::Opcode::OpMemoryModel:
125  return processMemoryModel(operands);
126  case spirv::Opcode::OpEntryPoint:
127  case spirv::Opcode::OpExecutionMode:
128  if (deferInstructions) {
129  deferredInstructions.emplace_back(opcode, operands);
130  return success();
131  }
132  break;
133  case spirv::Opcode::OpVariable:
134  if (isa<spirv::ModuleOp>(opBuilder.getBlock()->getParentOp())) {
135  return processGlobalVariable(operands);
136  }
137  break;
138  case spirv::Opcode::OpLine:
139  return processDebugLine(operands);
140  case spirv::Opcode::OpNoLine:
141  clearDebugLine();
142  return success();
143  case spirv::Opcode::OpName:
144  return processName(operands);
145  case spirv::Opcode::OpString:
146  return processDebugString(operands);
147  case spirv::Opcode::OpModuleProcessed:
148  case spirv::Opcode::OpSource:
149  case spirv::Opcode::OpSourceContinued:
150  case spirv::Opcode::OpSourceExtension:
151  // TODO: This is debug information embedded in the binary which should be
152  // translated into the spirv.module.
153  return success();
154  case spirv::Opcode::OpTypeVoid:
155  case spirv::Opcode::OpTypeBool:
156  case spirv::Opcode::OpTypeInt:
157  case spirv::Opcode::OpTypeFloat:
158  case spirv::Opcode::OpTypeVector:
159  case spirv::Opcode::OpTypeMatrix:
160  case spirv::Opcode::OpTypeArray:
161  case spirv::Opcode::OpTypeFunction:
162  case spirv::Opcode::OpTypeImage:
163  case spirv::Opcode::OpTypeSampledImage:
164  case spirv::Opcode::OpTypeRuntimeArray:
165  case spirv::Opcode::OpTypeStruct:
166  case spirv::Opcode::OpTypePointer:
167  case spirv::Opcode::OpTypeCooperativeMatrixKHR:
168  return processType(opcode, operands);
169  case spirv::Opcode::OpTypeForwardPointer:
170  return processTypeForwardPointer(operands);
171  case spirv::Opcode::OpTypeJointMatrixINTEL:
172  return processType(opcode, operands);
173  case spirv::Opcode::OpConstant:
174  return processConstant(operands, /*isSpec=*/false);
175  case spirv::Opcode::OpSpecConstant:
176  return processConstant(operands, /*isSpec=*/true);
177  case spirv::Opcode::OpConstantComposite:
178  return processConstantComposite(operands);
179  case spirv::Opcode::OpSpecConstantComposite:
180  return processSpecConstantComposite(operands);
181  case spirv::Opcode::OpSpecConstantOp:
182  return processSpecConstantOperation(operands);
183  case spirv::Opcode::OpConstantTrue:
184  return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/false);
185  case spirv::Opcode::OpSpecConstantTrue:
186  return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/true);
187  case spirv::Opcode::OpConstantFalse:
188  return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/false);
189  case spirv::Opcode::OpSpecConstantFalse:
190  return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/true);
191  case spirv::Opcode::OpConstantNull:
192  return processConstantNull(operands);
193  case spirv::Opcode::OpDecorate:
194  return processDecoration(operands);
195  case spirv::Opcode::OpMemberDecorate:
196  return processMemberDecoration(operands);
197  case spirv::Opcode::OpFunction:
198  return processFunction(operands);
199  case spirv::Opcode::OpLabel:
200  return processLabel(operands);
201  case spirv::Opcode::OpBranch:
202  return processBranch(operands);
203  case spirv::Opcode::OpBranchConditional:
204  return processBranchConditional(operands);
205  case spirv::Opcode::OpSelectionMerge:
206  return processSelectionMerge(operands);
207  case spirv::Opcode::OpLoopMerge:
208  return processLoopMerge(operands);
209  case spirv::Opcode::OpPhi:
210  return processPhi(operands);
211  case spirv::Opcode::OpUndef:
212  return processUndef(operands);
213  default:
214  break;
215  }
216  return dispatchToAutogenDeserialization(opcode, operands);
217 }
218 
219 LogicalResult spirv::Deserializer::processOpWithoutGrammarAttr(
220  ArrayRef<uint32_t> words, StringRef opName, bool hasResult,
221  unsigned numOperands) {
222  SmallVector<Type, 1> resultTypes;
223  uint32_t valueID = 0;
224 
225  size_t wordIndex = 0;
226  if (hasResult) {
227  if (wordIndex >= words.size())
228  return emitError(unknownLoc,
229  "expected result type <id> while deserializing for ")
230  << opName;
231 
232  // Decode the type <id>
233  auto type = getType(words[wordIndex]);
234  if (!type)
235  return emitError(unknownLoc, "unknown type result <id>: ")
236  << words[wordIndex];
237  resultTypes.push_back(type);
238  ++wordIndex;
239 
240  // Decode the result <id>
241  if (wordIndex >= words.size())
242  return emitError(unknownLoc,
243  "expected result <id> while deserializing for ")
244  << opName;
245  valueID = words[wordIndex];
246  ++wordIndex;
247  }
248 
249  SmallVector<Value, 4> operands;
251 
252  // Decode operands
253  size_t operandIndex = 0;
254  for (; operandIndex < numOperands && wordIndex < words.size();
255  ++operandIndex, ++wordIndex) {
256  auto arg = getValue(words[wordIndex]);
257  if (!arg)
258  return emitError(unknownLoc, "unknown result <id>: ") << words[wordIndex];
259  operands.push_back(arg);
260  }
261  if (operandIndex != numOperands) {
262  return emitError(
263  unknownLoc,
264  "found less operands than expected when deserializing for ")
265  << opName << "; only " << operandIndex << " of " << numOperands
266  << " processed";
267  }
268  if (wordIndex != words.size()) {
269  return emitError(
270  unknownLoc,
271  "found more operands than expected when deserializing for ")
272  << opName << "; only " << wordIndex << " of " << words.size()
273  << " processed";
274  }
275 
276  // Attach attributes from decorations
277  if (decorations.count(valueID)) {
278  auto attrs = decorations[valueID].getAttrs();
279  attributes.append(attrs.begin(), attrs.end());
280  }
281 
282  // Create the op and update bookkeeping maps
283  Location loc = createFileLineColLoc(opBuilder);
284  OperationState opState(loc, opName);
285  opState.addOperands(operands);
286  if (hasResult)
287  opState.addTypes(resultTypes);
288  opState.addAttributes(attributes);
289  Operation *op = opBuilder.create(opState);
290  if (hasResult)
291  valueMap[valueID] = op->getResult(0);
292 
293  if (op->hasTrait<OpTrait::IsTerminator>())
294  clearDebugLine();
295 
296  return success();
297 }
298 
299 LogicalResult spirv::Deserializer::processUndef(ArrayRef<uint32_t> operands) {
300  if (operands.size() != 2) {
301  return emitError(unknownLoc, "OpUndef instruction must have two operands");
302  }
303  auto type = getType(operands[0]);
304  if (!type) {
305  return emitError(unknownLoc, "unknown type <id> with OpUndef instruction");
306  }
307  undefMap[operands[1]] = type;
308  return success();
309 }
310 
311 LogicalResult spirv::Deserializer::processExtInst(ArrayRef<uint32_t> operands) {
312  if (operands.size() < 4) {
313  return emitError(unknownLoc,
314  "OpExtInst must have at least 4 operands, result type "
315  "<id>, result <id>, set <id> and instruction opcode");
316  }
317  if (!extendedInstSets.count(operands[2])) {
318  return emitError(unknownLoc, "undefined set <id> in OpExtInst");
319  }
320  SmallVector<uint32_t, 4> slicedOperands;
321  slicedOperands.append(operands.begin(), std::next(operands.begin(), 2));
322  slicedOperands.append(std::next(operands.begin(), 4), operands.end());
323  return dispatchToExtensionSetAutogenDeserialization(
324  extendedInstSets[operands[2]], operands[3], slicedOperands);
325 }
326 
327 namespace mlir {
328 namespace spirv {
329 
330 template <>
331 LogicalResult
332 Deserializer::processOp<spirv::EntryPointOp>(ArrayRef<uint32_t> words) {
333  unsigned wordIndex = 0;
334  if (wordIndex >= words.size()) {
335  return emitError(unknownLoc,
336  "missing Execution Model specification in OpEntryPoint");
337  }
338  auto execModel = spirv::ExecutionModelAttr::get(
339  context, static_cast<spirv::ExecutionModel>(words[wordIndex++]));
340  if (wordIndex >= words.size()) {
341  return emitError(unknownLoc, "missing <id> in OpEntryPoint");
342  }
343  // Get the function <id>
344  auto fnID = words[wordIndex++];
345  // Get the function name
346  auto fnName = decodeStringLiteral(words, wordIndex);
347  // Verify that the function <id> matches the fnName
348  auto parsedFunc = getFunction(fnID);
349  if (!parsedFunc) {
350  return emitError(unknownLoc, "no function matching <id> ") << fnID;
351  }
352  if (parsedFunc.getName() != fnName) {
353  // The deserializer uses "spirv_fn_<id>" as the function name if the input
354  // SPIR-V blob does not contain a name for it. We should use a more clear
355  // indication for such case rather than relying on naming details.
356  if (!parsedFunc.getName().starts_with("spirv_fn_"))
357  return emitError(unknownLoc,
358  "function name mismatch between OpEntryPoint "
359  "and OpFunction with <id> ")
360  << fnID << ": " << fnName << " vs. " << parsedFunc.getName();
361  parsedFunc.setName(fnName);
362  }
363  SmallVector<Attribute, 4> interface;
364  while (wordIndex < words.size()) {
365  auto arg = getGlobalVariable(words[wordIndex]);
366  if (!arg) {
367  return emitError(unknownLoc, "undefined result <id> ")
368  << words[wordIndex] << " while decoding OpEntryPoint";
369  }
370  interface.push_back(SymbolRefAttr::get(arg.getOperation()));
371  wordIndex++;
372  }
373  opBuilder.create<spirv::EntryPointOp>(
374  unknownLoc, execModel, SymbolRefAttr::get(opBuilder.getContext(), fnName),
375  opBuilder.getArrayAttr(interface));
376  return success();
377 }
378 
379 template <>
380 LogicalResult
381 Deserializer::processOp<spirv::ExecutionModeOp>(ArrayRef<uint32_t> words) {
382  unsigned wordIndex = 0;
383  if (wordIndex >= words.size()) {
384  return emitError(unknownLoc,
385  "missing function result <id> in OpExecutionMode");
386  }
387  // Get the function <id> to get the name of the function
388  auto fnID = words[wordIndex++];
389  auto fn = getFunction(fnID);
390  if (!fn) {
391  return emitError(unknownLoc, "no function matching <id> ") << fnID;
392  }
393  // Get the Execution mode
394  if (wordIndex >= words.size()) {
395  return emitError(unknownLoc, "missing Execution Mode in OpExecutionMode");
396  }
397  auto execMode = spirv::ExecutionModeAttr::get(
398  context, static_cast<spirv::ExecutionMode>(words[wordIndex++]));
399 
400  // Get the values
401  SmallVector<Attribute, 4> attrListElems;
402  while (wordIndex < words.size()) {
403  attrListElems.push_back(opBuilder.getI32IntegerAttr(words[wordIndex++]));
404  }
405  auto values = opBuilder.getArrayAttr(attrListElems);
406  opBuilder.create<spirv::ExecutionModeOp>(
407  unknownLoc, SymbolRefAttr::get(opBuilder.getContext(), fn.getName()),
408  execMode, values);
409  return success();
410 }
411 
412 template <>
413 LogicalResult
414 Deserializer::processOp<spirv::FunctionCallOp>(ArrayRef<uint32_t> operands) {
415  if (operands.size() < 3) {
416  return emitError(unknownLoc,
417  "OpFunctionCall must have at least 3 operands");
418  }
419 
420  Type resultType = getType(operands[0]);
421  if (!resultType) {
422  return emitError(unknownLoc, "undefined result type from <id> ")
423  << operands[0];
424  }
425 
426  // Use null type to mean no result type.
427  if (isVoidType(resultType))
428  resultType = nullptr;
429 
430  auto resultID = operands[1];
431  auto functionID = operands[2];
432 
433  auto functionName = getFunctionSymbol(functionID);
434 
435  SmallVector<Value, 4> arguments;
436  for (auto operand : llvm::drop_begin(operands, 3)) {
437  auto value = getValue(operand);
438  if (!value) {
439  return emitError(unknownLoc, "unknown <id> ")
440  << operand << " used by OpFunctionCall";
441  }
442  arguments.push_back(value);
443  }
444 
445  auto opFunctionCall = opBuilder.create<spirv::FunctionCallOp>(
446  unknownLoc, resultType,
447  SymbolRefAttr::get(opBuilder.getContext(), functionName), arguments);
448 
449  if (resultType)
450  valueMap[resultID] = opFunctionCall.getResult(0);
451  return success();
452 }
453 
454 template <>
455 LogicalResult
456 Deserializer::processOp<spirv::CopyMemoryOp>(ArrayRef<uint32_t> words) {
457  SmallVector<Type, 1> resultTypes;
458  size_t wordIndex = 0;
459  SmallVector<Value, 4> operands;
461 
462  if (wordIndex < words.size()) {
463  auto arg = getValue(words[wordIndex]);
464 
465  if (!arg) {
466  return emitError(unknownLoc, "unknown result <id> : ")
467  << words[wordIndex];
468  }
469 
470  operands.push_back(arg);
471  wordIndex++;
472  }
473 
474  if (wordIndex < words.size()) {
475  auto arg = getValue(words[wordIndex]);
476 
477  if (!arg) {
478  return emitError(unknownLoc, "unknown result <id> : ")
479  << words[wordIndex];
480  }
481 
482  operands.push_back(arg);
483  wordIndex++;
484  }
485 
486  bool isAlignedAttr = false;
487 
488  if (wordIndex < words.size()) {
489  auto attrValue = words[wordIndex++];
490  auto attr = opBuilder.getAttr<spirv::MemoryAccessAttr>(
491  static_cast<spirv::MemoryAccess>(attrValue));
492  attributes.push_back(
493  opBuilder.getNamedAttr(attributeName<MemoryAccess>(), attr));
494  isAlignedAttr = (attrValue == 2);
495  }
496 
497  if (isAlignedAttr && wordIndex < words.size()) {
498  attributes.push_back(opBuilder.getNamedAttr(
499  "alignment", opBuilder.getI32IntegerAttr(words[wordIndex++])));
500  }
501 
502  if (wordIndex < words.size()) {
503  auto attrValue = words[wordIndex++];
504  auto attr = opBuilder.getAttr<spirv::MemoryAccessAttr>(
505  static_cast<spirv::MemoryAccess>(attrValue));
506  attributes.push_back(opBuilder.getNamedAttr("source_memory_access", attr));
507  }
508 
509  if (wordIndex < words.size()) {
510  attributes.push_back(opBuilder.getNamedAttr(
511  "source_alignment", opBuilder.getI32IntegerAttr(words[wordIndex++])));
512  }
513 
514  if (wordIndex != words.size()) {
515  return emitError(unknownLoc,
516  "found more operands than expected when deserializing "
517  "spirv::CopyMemoryOp, only ")
518  << wordIndex << " of " << words.size() << " processed";
519  }
520 
521  Location loc = createFileLineColLoc(opBuilder);
522  opBuilder.create<spirv::CopyMemoryOp>(loc, resultTypes, operands, attributes);
523 
524  return success();
525 }
526 
527 template <>
528 LogicalResult Deserializer::processOp<spirv::GenericCastToPtrExplicitOp>(
529  ArrayRef<uint32_t> words) {
530  if (words.size() != 4) {
531  return emitError(unknownLoc,
532  "expected 4 words in GenericCastToPtrExplicitOp"
533  " but got : ")
534  << words.size();
535  }
536  SmallVector<Type, 1> resultTypes;
537  SmallVector<Value, 4> operands;
538  uint32_t valueID = 0;
539  auto type = getType(words[0]);
540 
541  if (!type)
542  return emitError(unknownLoc, "unknown type result <id> : ") << words[0];
543  resultTypes.push_back(type);
544 
545  valueID = words[1];
546 
547  auto arg = getValue(words[2]);
548  if (!arg)
549  return emitError(unknownLoc, "unknown result <id> : ") << words[2];
550  operands.push_back(arg);
551 
552  Location loc = createFileLineColLoc(opBuilder);
553  Operation *op = opBuilder.create<spirv::GenericCastToPtrExplicitOp>(
554  loc, resultTypes, operands);
555  valueMap[valueID] = op->getResult(0);
556  return success();
557 }
558 
559 // Pull in auto-generated Deserializer::dispatchToAutogenDeserialization() and
560 // various Deserializer::processOp<...>() specializations.
561 #define GET_DESERIALIZATION_FNS
562 #include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc"
563 
564 } // namespace spirv
565 } // namespace mlir
static spirv::Opcode extractOpcode(uint32_t word)
Extracts the opcode from the given first word of a SPIR-V instruction.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:63
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:468
This class provides the API for ops that are known to be terminators.
Definition: OpDefinition.h:764
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:745
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Definition: Operation.cpp:67
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
StringRef decodeStringLiteral(ArrayRef< uint32_t > words, unsigned &wordIndex)
Decodes a string literal in words starting at wordIndex.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This represents an operation in an abstracted form, suitable for use with the builder APIs.