MLIR  14.0.0git
DeserializeOps.cpp
Go to the documentation of this file.
1 //===- DeserializeOps.cpp - MLIR SPIR-V Deserialization (Ops) -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the Deserializer methods for SPIR-V binary instructions.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Deserializer.h"
14 
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/Location.h"
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include "llvm/Support/Debug.h"
22 
23 using namespace mlir;
24 
25 #define DEBUG_TYPE "spirv-deserialization"
26 
27 //===----------------------------------------------------------------------===//
28 // Utility Functions
29 //===----------------------------------------------------------------------===//
30 
31 /// Extracts the opcode from the given first word of a SPIR-V instruction.
32 static inline spirv::Opcode extractOpcode(uint32_t word) {
33  return static_cast<spirv::Opcode>(word & 0xffff);
34 }
35 
36 //===----------------------------------------------------------------------===//
37 // Instruction
38 //===----------------------------------------------------------------------===//
39 
40 Value spirv::Deserializer::getValue(uint32_t id) {
41  if (auto constInfo = getConstant(id)) {
42  // Materialize a `spv.Constant` op at every use site.
43  return opBuilder.create<spirv::ConstantOp>(unknownLoc, constInfo->second,
44  constInfo->first);
45  }
46  if (auto varOp = getGlobalVariable(id)) {
47  auto addressOfOp = opBuilder.create<spirv::AddressOfOp>(
48  unknownLoc, varOp.type(), SymbolRefAttr::get(varOp.getOperation()));
49  return addressOfOp.pointer();
50  }
51  if (auto constOp = getSpecConstant(id)) {
52  auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
53  unknownLoc, constOp.default_value().getType(),
54  SymbolRefAttr::get(constOp.getOperation()));
55  return referenceOfOp.reference();
56  }
57  if (auto constCompositeOp = getSpecConstantComposite(id)) {
58  auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
59  unknownLoc, constCompositeOp.type(),
60  SymbolRefAttr::get(constCompositeOp.getOperation()));
61  return referenceOfOp.reference();
62  }
63  if (auto specConstOperationInfo = getSpecConstantOperation(id)) {
64  return materializeSpecConstantOperation(
65  id, specConstOperationInfo->enclodesOpcode,
66  specConstOperationInfo->resultTypeID,
67  specConstOperationInfo->enclosedOpOperands);
68  }
69  if (auto undef = getUndefType(id)) {
70  return opBuilder.create<spirv::UndefOp>(unknownLoc, undef);
71  }
72  return valueMap.lookup(id);
73 }
74 
76 spirv::Deserializer::sliceInstruction(spirv::Opcode &opcode,
77  ArrayRef<uint32_t> &operands,
78  Optional<spirv::Opcode> expectedOpcode) {
79  auto binarySize = binary.size();
80  if (curOffset >= binarySize) {
81  return emitError(unknownLoc, "expected ")
82  << (expectedOpcode ? spirv::stringifyOpcode(*expectedOpcode)
83  : "more")
84  << " instruction";
85  }
86 
87  // For each instruction, get its word count from the first word to slice it
88  // from the stream properly, and then dispatch to the instruction handler.
89 
90  uint32_t wordCount = binary[curOffset] >> 16;
91 
92  if (wordCount == 0)
93  return emitError(unknownLoc, "word count cannot be zero");
94 
95  uint32_t nextOffset = curOffset + wordCount;
96  if (nextOffset > binarySize)
97  return emitError(unknownLoc, "insufficient words for the last instruction");
98 
99  opcode = extractOpcode(binary[curOffset]);
100  operands = binary.slice(curOffset + 1, wordCount - 1);
101  curOffset = nextOffset;
102  return success();
103 }
104 
105 LogicalResult spirv::Deserializer::processInstruction(
106  spirv::Opcode opcode, ArrayRef<uint32_t> operands, bool deferInstructions) {
107  LLVM_DEBUG(logger.startLine() << "[inst] processing instruction "
108  << spirv::stringifyOpcode(opcode) << "\n");
109 
110  // First dispatch all the instructions whose opcode does not correspond to
111  // those that have a direct mirror in the SPIR-V dialect
112  switch (opcode) {
113  case spirv::Opcode::OpCapability:
114  return processCapability(operands);
115  case spirv::Opcode::OpExtension:
116  return processExtension(operands);
117  case spirv::Opcode::OpExtInst:
118  return processExtInst(operands);
119  case spirv::Opcode::OpExtInstImport:
120  return processExtInstImport(operands);
121  case spirv::Opcode::OpMemberName:
122  return processMemberName(operands);
123  case spirv::Opcode::OpMemoryModel:
124  return processMemoryModel(operands);
125  case spirv::Opcode::OpEntryPoint:
126  case spirv::Opcode::OpExecutionMode:
127  if (deferInstructions) {
128  deferredInstructions.emplace_back(opcode, operands);
129  return success();
130  }
131  break;
132  case spirv::Opcode::OpVariable:
133  if (isa<spirv::ModuleOp>(opBuilder.getBlock()->getParentOp())) {
134  return processGlobalVariable(operands);
135  }
136  break;
137  case spirv::Opcode::OpLine:
138  return processDebugLine(operands);
139  case spirv::Opcode::OpNoLine:
140  clearDebugLine();
141  return success();
142  case spirv::Opcode::OpName:
143  return processName(operands);
144  case spirv::Opcode::OpString:
145  return processDebugString(operands);
146  case spirv::Opcode::OpModuleProcessed:
147  case spirv::Opcode::OpSource:
148  case spirv::Opcode::OpSourceContinued:
149  case spirv::Opcode::OpSourceExtension:
150  // TODO: This is debug information embedded in the binary which should be
151  // translated into the spv.module.
152  return success();
153  case spirv::Opcode::OpTypeVoid:
154  case spirv::Opcode::OpTypeBool:
155  case spirv::Opcode::OpTypeInt:
156  case spirv::Opcode::OpTypeFloat:
157  case spirv::Opcode::OpTypeVector:
158  case spirv::Opcode::OpTypeMatrix:
159  case spirv::Opcode::OpTypeArray:
160  case spirv::Opcode::OpTypeFunction:
161  case spirv::Opcode::OpTypeImage:
162  case spirv::Opcode::OpTypeSampledImage:
163  case spirv::Opcode::OpTypeRuntimeArray:
164  case spirv::Opcode::OpTypeStruct:
165  case spirv::Opcode::OpTypePointer:
166  case spirv::Opcode::OpTypeCooperativeMatrixNV:
167  return processType(opcode, operands);
168  case spirv::Opcode::OpTypeForwardPointer:
169  return processTypeForwardPointer(operands);
170  case spirv::Opcode::OpConstant:
171  return processConstant(operands, /*isSpec=*/false);
172  case spirv::Opcode::OpSpecConstant:
173  return processConstant(operands, /*isSpec=*/true);
174  case spirv::Opcode::OpConstantComposite:
175  return processConstantComposite(operands);
176  case spirv::Opcode::OpSpecConstantComposite:
177  return processSpecConstantComposite(operands);
178  case spirv::Opcode::OpSpecConstantOp:
179  return processSpecConstantOperation(operands);
180  case spirv::Opcode::OpConstantTrue:
181  return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/false);
182  case spirv::Opcode::OpSpecConstantTrue:
183  return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/true);
184  case spirv::Opcode::OpConstantFalse:
185  return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/false);
186  case spirv::Opcode::OpSpecConstantFalse:
187  return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/true);
188  case spirv::Opcode::OpConstantNull:
189  return processConstantNull(operands);
190  case spirv::Opcode::OpDecorate:
191  return processDecoration(operands);
192  case spirv::Opcode::OpMemberDecorate:
193  return processMemberDecoration(operands);
194  case spirv::Opcode::OpFunction:
195  return processFunction(operands);
196  case spirv::Opcode::OpLabel:
197  return processLabel(operands);
198  case spirv::Opcode::OpBranch:
199  return processBranch(operands);
200  case spirv::Opcode::OpBranchConditional:
201  return processBranchConditional(operands);
202  case spirv::Opcode::OpSelectionMerge:
203  return processSelectionMerge(operands);
204  case spirv::Opcode::OpLoopMerge:
205  return processLoopMerge(operands);
206  case spirv::Opcode::OpPhi:
207  return processPhi(operands);
208  case spirv::Opcode::OpUndef:
209  return processUndef(operands);
210  default:
211  break;
212  }
213  return dispatchToAutogenDeserialization(opcode, operands);
214 }
215 
216 LogicalResult spirv::Deserializer::processOpWithoutGrammarAttr(
217  ArrayRef<uint32_t> words, StringRef opName, bool hasResult,
218  unsigned numOperands) {
219  SmallVector<Type, 1> resultTypes;
220  uint32_t valueID = 0;
221 
222  size_t wordIndex = 0;
223  if (hasResult) {
224  if (wordIndex >= words.size())
225  return emitError(unknownLoc,
226  "expected result type <id> while deserializing for ")
227  << opName;
228 
229  // Decode the type <id>
230  auto type = getType(words[wordIndex]);
231  if (!type)
232  return emitError(unknownLoc, "unknown type result <id>: ")
233  << words[wordIndex];
234  resultTypes.push_back(type);
235  ++wordIndex;
236 
237  // Decode the result <id>
238  if (wordIndex >= words.size())
239  return emitError(unknownLoc,
240  "expected result <id> while deserializing for ")
241  << opName;
242  valueID = words[wordIndex];
243  ++wordIndex;
244  }
245 
246  SmallVector<Value, 4> operands;
248 
249  // Decode operands
250  size_t operandIndex = 0;
251  for (; operandIndex < numOperands && wordIndex < words.size();
252  ++operandIndex, ++wordIndex) {
253  auto arg = getValue(words[wordIndex]);
254  if (!arg)
255  return emitError(unknownLoc, "unknown result <id>: ") << words[wordIndex];
256  operands.push_back(arg);
257  }
258  if (operandIndex != numOperands) {
259  return emitError(
260  unknownLoc,
261  "found less operands than expected when deserializing for ")
262  << opName << "; only " << operandIndex << " of " << numOperands
263  << " processed";
264  }
265  if (wordIndex != words.size()) {
266  return emitError(
267  unknownLoc,
268  "found more operands than expected when deserializing for ")
269  << opName << "; only " << wordIndex << " of " << words.size()
270  << " processed";
271  }
272 
273  // Attach attributes from decorations
274  if (decorations.count(valueID)) {
275  auto attrs = decorations[valueID].getAttrs();
276  attributes.append(attrs.begin(), attrs.end());
277  }
278 
279  // Create the op and update bookkeeping maps
280  Location loc = createFileLineColLoc(opBuilder);
281  OperationState opState(loc, opName);
282  opState.addOperands(operands);
283  if (hasResult)
284  opState.addTypes(resultTypes);
285  opState.addAttributes(attributes);
286  Operation *op = opBuilder.createOperation(opState);
287  if (hasResult)
288  valueMap[valueID] = op->getResult(0);
289 
290  if (op->hasTrait<OpTrait::IsTerminator>())
291  clearDebugLine();
292 
293  return success();
294 }
295 
296 LogicalResult spirv::Deserializer::processUndef(ArrayRef<uint32_t> operands) {
297  if (operands.size() != 2) {
298  return emitError(unknownLoc, "OpUndef instruction must have two operands");
299  }
300  auto type = getType(operands[0]);
301  if (!type) {
302  return emitError(unknownLoc, "unknown type <id> with OpUndef instruction");
303  }
304  undefMap[operands[1]] = type;
305  return success();
306 }
307 
308 LogicalResult spirv::Deserializer::processExtInst(ArrayRef<uint32_t> operands) {
309  if (operands.size() < 4) {
310  return emitError(unknownLoc,
311  "OpExtInst must have at least 4 operands, result type "
312  "<id>, result <id>, set <id> and instruction opcode");
313  }
314  if (!extendedInstSets.count(operands[2])) {
315  return emitError(unknownLoc, "undefined set <id> in OpExtInst");
316  }
317  SmallVector<uint32_t, 4> slicedOperands;
318  slicedOperands.append(operands.begin(), std::next(operands.begin(), 2));
319  slicedOperands.append(std::next(operands.begin(), 4), operands.end());
320  return dispatchToExtensionSetAutogenDeserialization(
321  extendedInstSets[operands[2]], operands[3], slicedOperands);
322 }
323 
324 namespace mlir {
325 namespace spirv {
326 
327 template <>
329 Deserializer::processOp<spirv::EntryPointOp>(ArrayRef<uint32_t> words) {
330  unsigned wordIndex = 0;
331  if (wordIndex >= words.size()) {
332  return emitError(unknownLoc,
333  "missing Execution Model specification in OpEntryPoint");
334  }
335  auto execModel = spirv::ExecutionModelAttr::get(
336  context, static_cast<spirv::ExecutionModel>(words[wordIndex++]));
337  if (wordIndex >= words.size()) {
338  return emitError(unknownLoc, "missing <id> in OpEntryPoint");
339  }
340  // Get the function <id>
341  auto fnID = words[wordIndex++];
342  // Get the function name
343  auto fnName = decodeStringLiteral(words, wordIndex);
344  // Verify that the function <id> matches the fnName
345  auto parsedFunc = getFunction(fnID);
346  if (!parsedFunc) {
347  return emitError(unknownLoc, "no function matching <id> ") << fnID;
348  }
349  if (parsedFunc.getName() != fnName) {
350  return emitError(unknownLoc, "function name mismatch between OpEntryPoint "
351  "and OpFunction with <id> ")
352  << fnID << ": " << fnName << " vs. " << parsedFunc.getName();
353  }
354  SmallVector<Attribute, 4> interface;
355  while (wordIndex < words.size()) {
356  auto arg = getGlobalVariable(words[wordIndex]);
357  if (!arg) {
358  return emitError(unknownLoc, "undefined result <id> ")
359  << words[wordIndex] << " while decoding OpEntryPoint";
360  }
361  interface.push_back(SymbolRefAttr::get(arg.getOperation()));
362  wordIndex++;
363  }
364  opBuilder.create<spirv::EntryPointOp>(
365  unknownLoc, execModel, SymbolRefAttr::get(opBuilder.getContext(), fnName),
366  opBuilder.getArrayAttr(interface));
367  return success();
368 }
369 
370 template <>
372 Deserializer::processOp<spirv::ExecutionModeOp>(ArrayRef<uint32_t> words) {
373  unsigned wordIndex = 0;
374  if (wordIndex >= words.size()) {
375  return emitError(unknownLoc,
376  "missing function result <id> in OpExecutionMode");
377  }
378  // Get the function <id> to get the name of the function
379  auto fnID = words[wordIndex++];
380  auto fn = getFunction(fnID);
381  if (!fn) {
382  return emitError(unknownLoc, "no function matching <id> ") << fnID;
383  }
384  // Get the Execution mode
385  if (wordIndex >= words.size()) {
386  return emitError(unknownLoc, "missing Execution Mode in OpExecutionMode");
387  }
388  auto execMode = spirv::ExecutionModeAttr::get(
389  context, static_cast<spirv::ExecutionMode>(words[wordIndex++]));
390 
391  // Get the values
392  SmallVector<Attribute, 4> attrListElems;
393  while (wordIndex < words.size()) {
394  attrListElems.push_back(opBuilder.getI32IntegerAttr(words[wordIndex++]));
395  }
396  auto values = opBuilder.getArrayAttr(attrListElems);
397  opBuilder.create<spirv::ExecutionModeOp>(
398  unknownLoc, SymbolRefAttr::get(opBuilder.getContext(), fn.getName()),
399  execMode, values);
400  return success();
401 }
402 
403 template <>
405 Deserializer::processOp<spirv::ControlBarrierOp>(ArrayRef<uint32_t> operands) {
406  if (operands.size() != 3) {
407  return emitError(
408  unknownLoc,
409  "OpControlBarrier must have execution scope <id>, memory scope <id> "
410  "and memory semantics <id>");
411  }
412 
414  for (auto operand : operands) {
415  auto argAttr = getConstantInt(operand);
416  if (!argAttr) {
417  return emitError(unknownLoc,
418  "expected 32-bit integer constant from <id> ")
419  << operand << " for OpControlBarrier";
420  }
421  argAttrs.push_back(argAttr);
422  }
423 
424  opBuilder.create<spirv::ControlBarrierOp>(
425  unknownLoc, argAttrs[0].cast<spirv::ScopeAttr>(),
426  argAttrs[1].cast<spirv::ScopeAttr>(),
427  argAttrs[2].cast<spirv::MemorySemanticsAttr>());
428 
429  return success();
430 }
431 
432 template <>
434 Deserializer::processOp<spirv::FunctionCallOp>(ArrayRef<uint32_t> operands) {
435  if (operands.size() < 3) {
436  return emitError(unknownLoc,
437  "OpFunctionCall must have at least 3 operands");
438  }
439 
440  Type resultType = getType(operands[0]);
441  if (!resultType) {
442  return emitError(unknownLoc, "undefined result type from <id> ")
443  << operands[0];
444  }
445 
446  // Use null type to mean no result type.
447  if (isVoidType(resultType))
448  resultType = nullptr;
449 
450  auto resultID = operands[1];
451  auto functionID = operands[2];
452 
453  auto functionName = getFunctionSymbol(functionID);
454 
455  SmallVector<Value, 4> arguments;
456  for (auto operand : llvm::drop_begin(operands, 3)) {
457  auto value = getValue(operand);
458  if (!value) {
459  return emitError(unknownLoc, "unknown <id> ")
460  << operand << " used by OpFunctionCall";
461  }
462  arguments.push_back(value);
463  }
464 
465  auto opFunctionCall = opBuilder.create<spirv::FunctionCallOp>(
466  unknownLoc, resultType,
467  SymbolRefAttr::get(opBuilder.getContext(), functionName), arguments);
468 
469  if (resultType)
470  valueMap[resultID] = opFunctionCall.getResult(0);
471  return success();
472 }
473 
474 template <>
476 Deserializer::processOp<spirv::MemoryBarrierOp>(ArrayRef<uint32_t> operands) {
477  if (operands.size() != 2) {
478  return emitError(unknownLoc, "OpMemoryBarrier must have memory scope <id> "
479  "and memory semantics <id>");
480  }
481 
483  for (auto operand : operands) {
484  auto argAttr = getConstantInt(operand);
485  if (!argAttr) {
486  return emitError(unknownLoc,
487  "expected 32-bit integer constant from <id> ")
488  << operand << " for OpMemoryBarrier";
489  }
490  argAttrs.push_back(argAttr);
491  }
492 
493  opBuilder.create<spirv::MemoryBarrierOp>(
494  unknownLoc, argAttrs[0].cast<spirv::ScopeAttr>(),
495  argAttrs[1].cast<spirv::MemorySemanticsAttr>());
496  return success();
497 }
498 
499 template <>
501 Deserializer::processOp<spirv::CopyMemoryOp>(ArrayRef<uint32_t> words) {
502  SmallVector<Type, 1> resultTypes;
503  size_t wordIndex = 0;
504  SmallVector<Value, 4> operands;
506 
507  if (wordIndex < words.size()) {
508  auto arg = getValue(words[wordIndex]);
509 
510  if (!arg) {
511  return emitError(unknownLoc, "unknown result <id> : ")
512  << words[wordIndex];
513  }
514 
515  operands.push_back(arg);
516  wordIndex++;
517  }
518 
519  if (wordIndex < words.size()) {
520  auto arg = getValue(words[wordIndex]);
521 
522  if (!arg) {
523  return emitError(unknownLoc, "unknown result <id> : ")
524  << words[wordIndex];
525  }
526 
527  operands.push_back(arg);
528  wordIndex++;
529  }
530 
531  bool isAlignedAttr = false;
532 
533  if (wordIndex < words.size()) {
534  auto attrValue = words[wordIndex++];
535  attributes.push_back(opBuilder.getNamedAttr(
536  "memory_access", opBuilder.getI32IntegerAttr(attrValue)));
537  isAlignedAttr = (attrValue == 2);
538  }
539 
540  if (isAlignedAttr && wordIndex < words.size()) {
541  attributes.push_back(opBuilder.getNamedAttr(
542  "alignment", opBuilder.getI32IntegerAttr(words[wordIndex++])));
543  }
544 
545  if (wordIndex < words.size()) {
546  attributes.push_back(opBuilder.getNamedAttr(
547  "source_memory_access",
548  opBuilder.getI32IntegerAttr(words[wordIndex++])));
549  }
550 
551  if (wordIndex < words.size()) {
552  attributes.push_back(opBuilder.getNamedAttr(
553  "source_alignment", opBuilder.getI32IntegerAttr(words[wordIndex++])));
554  }
555 
556  if (wordIndex != words.size()) {
557  return emitError(unknownLoc,
558  "found more operands than expected when deserializing "
559  "spirv::CopyMemoryOp, only ")
560  << wordIndex << " of " << words.size() << " processed";
561  }
562 
563  Location loc = createFileLineColLoc(opBuilder);
564  opBuilder.create<spirv::CopyMemoryOp>(loc, resultTypes, operands, attributes);
565 
566  return success();
567 }
568 
569 // Pull in auto-generated Deserializer::dispatchToAutogenDeserialization() and
570 // various Deserializer::processOp<...>() specializations.
571 #define GET_DESERIALIZATION_FNS
572 #include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc"
573 
574 } // namespace spirv
575 } // namespace mlir
Include the generated interface declarations.
OpTy create(Location location, Args &&...args)
Create an operation of specific op type at the current insertion point.
Definition: Builders.h:430
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
This class provides the API for ops that are known to be terminators.
Definition: OpDefinition.h:676
static constexpr const bool value
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
Block * getBlock() const
Returns the current block of the builder.
Definition: Builders.h:379
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
void addOperands(ValueRange newOperands)
Operation * createOperation(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:380
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:470
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:276
void addTypes(ArrayRef< Type > newTypes)
This represents an operation in an abstracted form, suitable for use with the builder APIs...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:84
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
StringRef decodeStringLiteral(ArrayRef< uint32_t > words, unsigned &wordIndex)
Decodes a string literal in words starting at wordIndex.
static spirv::Opcode extractOpcode(uint32_t word)
Extracts the opcode from the given first word of a SPIR-V instruction.