MLIR  15.0.0git
Deserializer.cpp
Go to the documentation of this file.
1 //===- Deserializer.cpp - MLIR SPIR-V Deserializer ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the SPIR-V binary to MLIR SPIR-V module deserializer.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Deserializer.h"
14 
20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/Location.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/Sequence.h"
26 #include "llvm/ADT/SmallVector.h"
27 #include "llvm/ADT/StringExtras.h"
28 #include "llvm/ADT/bit.h"
29 #include "llvm/Support/Debug.h"
30 #include "llvm/Support/SaveAndRestore.h"
31 #include "llvm/Support/raw_ostream.h"
32 
33 using namespace mlir;
34 
35 #define DEBUG_TYPE "spirv-deserialization"
36 
37 //===----------------------------------------------------------------------===//
38 // Utility Functions
39 //===----------------------------------------------------------------------===//
40 
41 /// Returns true if the given `block` is a function entry block.
42 static inline bool isFnEntryBlock(Block *block) {
43  return block->isEntryBlock() &&
44  isa_and_nonnull<spirv::FuncOp>(block->getParentOp());
45 }
46 
47 //===----------------------------------------------------------------------===//
48 // Deserializer Method Definitions
49 //===----------------------------------------------------------------------===//
50 
52  MLIRContext *context)
53  : binary(binary), context(context), unknownLoc(UnknownLoc::get(context)),
54  module(createModuleOp()), opBuilder(module->getRegion())
55 #ifndef NDEBUG
56  ,
57  logger(llvm::dbgs())
58 #endif
59 {
60 }
61 
63  LLVM_DEBUG({
64  logger.resetIndent();
65  logger.startLine()
66  << "//+++---------- start deserialization ----------+++//\n";
67  });
68 
69  if (failed(processHeader()))
70  return failure();
71 
72  spirv::Opcode opcode = spirv::Opcode::OpNop;
73  ArrayRef<uint32_t> operands;
74  auto binarySize = binary.size();
75  while (curOffset < binarySize) {
76  // Slice the next instruction out and populate `opcode` and `operands`.
77  // Internally this also updates `curOffset`.
78  if (failed(sliceInstruction(opcode, operands)))
79  return failure();
80 
81  if (failed(processInstruction(opcode, operands)))
82  return failure();
83  }
84 
85  assert(curOffset == binarySize &&
86  "deserializer should never index beyond the binary end");
87 
88  for (auto &deferred : deferredInstructions) {
89  if (failed(processInstruction(deferred.first, deferred.second, false))) {
90  return failure();
91  }
92  }
93 
94  attachVCETriple();
95 
96  LLVM_DEBUG(logger.startLine()
97  << "//+++-------- completed deserialization --------+++//\n");
98  return success();
99 }
100 
102  return std::move(module);
103 }
104 
105 //===----------------------------------------------------------------------===//
106 // Module structure
107 //===----------------------------------------------------------------------===//
108 
109 OwningOpRef<spirv::ModuleOp> spirv::Deserializer::createModuleOp() {
110  OpBuilder builder(context);
111  OperationState state(unknownLoc, spirv::ModuleOp::getOperationName());
112  spirv::ModuleOp::build(builder, state);
113  return cast<spirv::ModuleOp>(Operation::create(state));
114 }
115 
116 LogicalResult spirv::Deserializer::processHeader() {
117  if (binary.size() < spirv::kHeaderWordCount)
118  return emitError(unknownLoc,
119  "SPIR-V binary module must have a 5-word header");
120 
121  if (binary[0] != spirv::kMagicNumber)
122  return emitError(unknownLoc, "incorrect magic number");
123 
124  // Version number bytes: 0 | major number | minor number | 0
125  uint32_t majorVersion = (binary[1] << 8) >> 24;
126  uint32_t minorVersion = (binary[1] << 16) >> 24;
127  if (majorVersion == 1) {
128  switch (minorVersion) {
129 #define MIN_VERSION_CASE(v) \
130  case v: \
131  version = spirv::Version::V_1_##v; \
132  break
133 
134  MIN_VERSION_CASE(0);
135  MIN_VERSION_CASE(1);
136  MIN_VERSION_CASE(2);
137  MIN_VERSION_CASE(3);
138  MIN_VERSION_CASE(4);
139  MIN_VERSION_CASE(5);
140 #undef MIN_VERSION_CASE
141  default:
142  return emitError(unknownLoc, "unsupported SPIR-V minor version: ")
143  << minorVersion;
144  }
145  } else {
146  return emitError(unknownLoc, "unsupported SPIR-V major version: ")
147  << majorVersion;
148  }
149 
150  // TODO: generator number, bound, schema
151  curOffset = spirv::kHeaderWordCount;
152  return success();
153 }
154 
156 spirv::Deserializer::processCapability(ArrayRef<uint32_t> operands) {
157  if (operands.size() != 1)
158  return emitError(unknownLoc, "OpMemoryModel must have one parameter");
159 
160  auto cap = spirv::symbolizeCapability(operands[0]);
161  if (!cap)
162  return emitError(unknownLoc, "unknown capability: ") << operands[0];
163 
164  capabilities.insert(*cap);
165  return success();
166 }
167 
168 LogicalResult spirv::Deserializer::processExtension(ArrayRef<uint32_t> words) {
169  if (words.empty()) {
170  return emitError(
171  unknownLoc,
172  "OpExtension must have a literal string for the extension name");
173  }
174 
175  unsigned wordIndex = 0;
176  StringRef extName = decodeStringLiteral(words, wordIndex);
177  if (wordIndex != words.size())
178  return emitError(unknownLoc,
179  "unexpected trailing words in OpExtension instruction");
180  auto ext = spirv::symbolizeExtension(extName);
181  if (!ext)
182  return emitError(unknownLoc, "unknown extension: ") << extName;
183 
184  extensions.insert(*ext);
185  return success();
186 }
187 
189 spirv::Deserializer::processExtInstImport(ArrayRef<uint32_t> words) {
190  if (words.size() < 2) {
191  return emitError(unknownLoc,
192  "OpExtInstImport must have a result <id> and a literal "
193  "string for the extended instruction set name");
194  }
195 
196  unsigned wordIndex = 1;
197  extendedInstSets[words[0]] = decodeStringLiteral(words, wordIndex);
198  if (wordIndex != words.size()) {
199  return emitError(unknownLoc,
200  "unexpected trailing words in OpExtInstImport");
201  }
202  return success();
203 }
204 
205 void spirv::Deserializer::attachVCETriple() {
206  (*module)->setAttr(
207  spirv::ModuleOp::getVCETripleAttrName(),
208  spirv::VerCapExtAttr::get(version, capabilities.getArrayRef(),
209  extensions.getArrayRef(), context));
210 }
211 
213 spirv::Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
214  if (operands.size() != 2)
215  return emitError(unknownLoc, "OpMemoryModel must have two operands");
216 
217  (*module)->setAttr(
218  "addressing_model",
219  opBuilder.getI32IntegerAttr(llvm::bit_cast<int32_t>(operands.front())));
220  (*module)->setAttr(
221  "memory_model",
222  opBuilder.getI32IntegerAttr(llvm::bit_cast<int32_t>(operands.back())));
223 
224  return success();
225 }
226 
227 LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
228  // TODO: This function should also be auto-generated. For now, since only a
229  // few decorations are processed/handled in a meaningful manner, going with a
230  // manual implementation.
231  if (words.size() < 2) {
232  return emitError(
233  unknownLoc, "OpDecorate must have at least result <id> and Decoration");
234  }
235  auto decorationName =
236  stringifyDecoration(static_cast<spirv::Decoration>(words[1]));
237  if (decorationName.empty()) {
238  return emitError(unknownLoc, "invalid Decoration code : ") << words[1];
239  }
240  auto attrName = llvm::convertToSnakeFromCamelCase(decorationName);
241  auto symbol = opBuilder.getStringAttr(attrName);
242  switch (static_cast<spirv::Decoration>(words[1])) {
243  case spirv::Decoration::DescriptorSet:
244  case spirv::Decoration::Binding:
245  if (words.size() != 3) {
246  return emitError(unknownLoc, "OpDecorate with ")
247  << decorationName << " needs a single integer literal";
248  }
249  decorations[words[0]].set(
250  symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));
251  break;
252  case spirv::Decoration::BuiltIn:
253  if (words.size() != 3) {
254  return emitError(unknownLoc, "OpDecorate with ")
255  << decorationName << " needs a single integer literal";
256  }
257  decorations[words[0]].set(
258  symbol, opBuilder.getStringAttr(
259  stringifyBuiltIn(static_cast<spirv::BuiltIn>(words[2]))));
260  break;
261  case spirv::Decoration::ArrayStride:
262  if (words.size() != 3) {
263  return emitError(unknownLoc, "OpDecorate with ")
264  << decorationName << " needs a single integer literal";
265  }
266  typeDecorations[words[0]] = words[2];
267  break;
268  case spirv::Decoration::Aliased:
269  case spirv::Decoration::Block:
270  case spirv::Decoration::BufferBlock:
271  case spirv::Decoration::Flat:
272  case spirv::Decoration::NonReadable:
273  case spirv::Decoration::NonWritable:
274  case spirv::Decoration::NoPerspective:
275  case spirv::Decoration::Restrict:
276  case spirv::Decoration::RelaxedPrecision:
277  if (words.size() != 2) {
278  return emitError(unknownLoc, "OpDecoration with ")
279  << decorationName << "needs a single target <id>";
280  }
281  // Block decoration does not affect spv.struct type, but is still stored for
282  // verification.
283  // TODO: Update StructType to contain this information since
284  // it is needed for many validation rules.
285  decorations[words[0]].set(symbol, opBuilder.getUnitAttr());
286  break;
287  case spirv::Decoration::Location:
288  case spirv::Decoration::SpecId:
289  if (words.size() != 3) {
290  return emitError(unknownLoc, "OpDecoration with ")
291  << decorationName << "needs a single integer literal";
292  }
293  decorations[words[0]].set(
294  symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));
295  break;
296  default:
297  return emitError(unknownLoc, "unhandled Decoration : '") << decorationName;
298  }
299  return success();
300 }
301 
303 spirv::Deserializer::processMemberDecoration(ArrayRef<uint32_t> words) {
304  // The binary layout of OpMemberDecorate is different comparing to OpDecorate
305  if (words.size() < 3) {
306  return emitError(unknownLoc,
307  "OpMemberDecorate must have at least 3 operands");
308  }
309 
310  auto decoration = static_cast<spirv::Decoration>(words[2]);
311  if (decoration == spirv::Decoration::Offset && words.size() != 4) {
312  return emitError(unknownLoc,
313  " missing offset specification in OpMemberDecorate with "
314  "Offset decoration");
315  }
316  ArrayRef<uint32_t> decorationOperands;
317  if (words.size() > 3) {
318  decorationOperands = words.slice(3);
319  }
320  memberDecorationMap[words[0]][words[1]][decoration] = decorationOperands;
321  return success();
322 }
323 
324 LogicalResult spirv::Deserializer::processMemberName(ArrayRef<uint32_t> words) {
325  if (words.size() < 3) {
326  return emitError(unknownLoc, "OpMemberName must have at least 3 operands");
327  }
328  unsigned wordIndex = 2;
329  auto name = decodeStringLiteral(words, wordIndex);
330  if (wordIndex != words.size()) {
331  return emitError(unknownLoc,
332  "unexpected trailing words in OpMemberName instruction");
333  }
334  memberNameMap[words[0]][words[1]] = name;
335  return success();
336 }
337 
339 spirv::Deserializer::processFunction(ArrayRef<uint32_t> operands) {
340  if (curFunction) {
341  return emitError(unknownLoc, "found function inside function");
342  }
343 
344  // Get the result type
345  if (operands.size() != 4) {
346  return emitError(unknownLoc, "OpFunction must have 4 parameters");
347  }
348  Type resultType = getType(operands[0]);
349  if (!resultType) {
350  return emitError(unknownLoc, "undefined result type from <id> ")
351  << operands[0];
352  }
353 
354  uint32_t fnID = operands[1];
355  if (funcMap.count(fnID)) {
356  return emitError(unknownLoc, "duplicate function definition/declaration");
357  }
358 
359  auto fnControl = spirv::symbolizeFunctionControl(operands[2]);
360  if (!fnControl) {
361  return emitError(unknownLoc, "unknown Function Control: ") << operands[2];
362  }
363 
364  Type fnType = getType(operands[3]);
365  if (!fnType || !fnType.isa<FunctionType>()) {
366  return emitError(unknownLoc, "unknown function type from <id> ")
367  << operands[3];
368  }
369  auto functionType = fnType.cast<FunctionType>();
370 
371  if ((isVoidType(resultType) && functionType.getNumResults() != 0) ||
372  (functionType.getNumResults() == 1 &&
373  functionType.getResult(0) != resultType)) {
374  return emitError(unknownLoc, "mismatch in function type ")
375  << functionType << " and return type " << resultType << " specified";
376  }
377 
378  std::string fnName = getFunctionSymbol(fnID);
379  auto funcOp = opBuilder.create<spirv::FuncOp>(
380  unknownLoc, fnName, functionType, fnControl.getValue());
381  curFunction = funcMap[fnID] = funcOp;
382  auto *entryBlock = funcOp.addEntryBlock();
383  LLVM_DEBUG({
384  logger.startLine()
385  << "//===-------------------------------------------===//\n";
386  logger.startLine() << "[fn] name: " << fnName << "\n";
387  logger.startLine() << "[fn] type: " << fnType << "\n";
388  logger.startLine() << "[fn] ID: " << fnID << "\n";
389  logger.startLine() << "[fn] entry block: " << entryBlock << "\n";
390  logger.indent();
391  });
392 
393  // Parse the op argument instructions
394  if (functionType.getNumInputs()) {
395  for (size_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
396  auto argType = functionType.getInput(i);
397  spirv::Opcode opcode = spirv::Opcode::OpNop;
398  ArrayRef<uint32_t> operands;
399  if (failed(sliceInstruction(opcode, operands,
400  spirv::Opcode::OpFunctionParameter))) {
401  return failure();
402  }
403  if (opcode != spirv::Opcode::OpFunctionParameter) {
404  return emitError(
405  unknownLoc,
406  "missing OpFunctionParameter instruction for argument ")
407  << i;
408  }
409  if (operands.size() != 2) {
410  return emitError(
411  unknownLoc,
412  "expected result type and result <id> for OpFunctionParameter");
413  }
414  auto argDefinedType = getType(operands[0]);
415  if (!argDefinedType || argDefinedType != argType) {
416  return emitError(unknownLoc,
417  "mismatch in argument type between function type "
418  "definition ")
419  << functionType << " and argument type definition "
420  << argDefinedType << " at argument " << i;
421  }
422  if (getValue(operands[1])) {
423  return emitError(unknownLoc, "duplicate definition of result <id> ")
424  << operands[1];
425  }
426  auto argValue = funcOp.getArgument(i);
427  valueMap[operands[1]] = argValue;
428  }
429  }
430 
431  // RAII guard to reset the insertion point to the module's region after
432  // deserializing the body of this function.
433  OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
434 
435  spirv::Opcode opcode = spirv::Opcode::OpNop;
436  ArrayRef<uint32_t> instOperands;
437 
438  // Special handling for the entry block. We need to make sure it starts with
439  // an OpLabel instruction. The entry block takes the same parameters as the
440  // function. All other blocks do not take any parameter. We have already
441  // created the entry block, here we need to register it to the correct label
442  // <id>.
443  if (failed(sliceInstruction(opcode, instOperands,
444  spirv::Opcode::OpFunctionEnd))) {
445  return failure();
446  }
447  if (opcode == spirv::Opcode::OpFunctionEnd) {
448  return processFunctionEnd(instOperands);
449  }
450  if (opcode != spirv::Opcode::OpLabel) {
451  return emitError(unknownLoc, "a basic block must start with OpLabel");
452  }
453  if (instOperands.size() != 1) {
454  return emitError(unknownLoc, "OpLabel should only have result <id>");
455  }
456  blockMap[instOperands[0]] = entryBlock;
457  if (failed(processLabel(instOperands))) {
458  return failure();
459  }
460 
461  // Then process all the other instructions in the function until we hit
462  // OpFunctionEnd.
463  while (succeeded(sliceInstruction(opcode, instOperands,
464  spirv::Opcode::OpFunctionEnd)) &&
465  opcode != spirv::Opcode::OpFunctionEnd) {
466  if (failed(processInstruction(opcode, instOperands))) {
467  return failure();
468  }
469  }
470  if (opcode != spirv::Opcode::OpFunctionEnd) {
471  return failure();
472  }
473 
474  return processFunctionEnd(instOperands);
475 }
476 
478 spirv::Deserializer::processFunctionEnd(ArrayRef<uint32_t> operands) {
479  // Process OpFunctionEnd.
480  if (!operands.empty()) {
481  return emitError(unknownLoc, "unexpected operands for OpFunctionEnd");
482  }
483 
484  // Wire up block arguments from OpPhi instructions.
485  // Put all structured control flow in spv.mlir.selection/spv.mlir.loop ops.
486  if (failed(wireUpBlockArgument()) || failed(structurizeControlFlow())) {
487  return failure();
488  }
489 
490  curBlock = nullptr;
491  curFunction = llvm::None;
492 
493  LLVM_DEBUG({
494  logger.unindent();
495  logger.startLine()
496  << "//===-------------------------------------------===//\n";
497  });
498  return success();
499 }
500 
502 spirv::Deserializer::getConstant(uint32_t id) {
503  auto constIt = constantMap.find(id);
504  if (constIt == constantMap.end())
505  return llvm::None;
506  return constIt->getSecond();
507 }
508 
510 spirv::Deserializer::getSpecConstantOperation(uint32_t id) {
511  auto constIt = specConstOperationMap.find(id);
512  if (constIt == specConstOperationMap.end())
513  return llvm::None;
514  return constIt->getSecond();
515 }
516 
517 std::string spirv::Deserializer::getFunctionSymbol(uint32_t id) {
518  auto funcName = nameMap.lookup(id).str();
519  if (funcName.empty()) {
520  funcName = "spirv_fn_" + std::to_string(id);
521  }
522  return funcName;
523 }
524 
525 std::string spirv::Deserializer::getSpecConstantSymbol(uint32_t id) {
526  auto constName = nameMap.lookup(id).str();
527  if (constName.empty()) {
528  constName = "spirv_spec_const_" + std::to_string(id);
529  }
530  return constName;
531 }
532 
533 spirv::SpecConstantOp
534 spirv::Deserializer::createSpecConstant(Location loc, uint32_t resultID,
535  Attribute defaultValue) {
536  auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
537  auto op = opBuilder.create<spirv::SpecConstantOp>(unknownLoc, symName,
538  defaultValue);
539  if (decorations.count(resultID)) {
540  for (auto attr : decorations[resultID].getAttrs())
541  op->setAttr(attr.getName(), attr.getValue());
542  }
543  specConstMap[resultID] = op;
544  return op;
545 }
546 
548 spirv::Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
549  unsigned wordIndex = 0;
550  if (operands.size() < 3) {
551  return emitError(
552  unknownLoc,
553  "OpVariable needs at least 3 operands, type, <id> and storage class");
554  }
555 
556  // Result Type.
557  auto type = getType(operands[wordIndex]);
558  if (!type) {
559  return emitError(unknownLoc, "unknown result type <id> : ")
560  << operands[wordIndex];
561  }
562  auto ptrType = type.dyn_cast<spirv::PointerType>();
563  if (!ptrType) {
564  return emitError(unknownLoc,
565  "expected a result type <id> to be a spv.ptr, found : ")
566  << type;
567  }
568  wordIndex++;
569 
570  // Result <id>.
571  auto variableID = operands[wordIndex];
572  auto variableName = nameMap.lookup(variableID).str();
573  if (variableName.empty()) {
574  variableName = "spirv_var_" + std::to_string(variableID);
575  }
576  wordIndex++;
577 
578  // Storage class.
579  auto storageClass = static_cast<spirv::StorageClass>(operands[wordIndex]);
580  if (ptrType.getStorageClass() != storageClass) {
581  return emitError(unknownLoc, "mismatch in storage class of pointer type ")
582  << type << " and that specified in OpVariable instruction : "
583  << stringifyStorageClass(storageClass);
584  }
585  wordIndex++;
586 
587  // Initializer.
588  FlatSymbolRefAttr initializer = nullptr;
589  if (wordIndex < operands.size()) {
590  auto initializerOp = getGlobalVariable(operands[wordIndex]);
591  if (!initializerOp) {
592  return emitError(unknownLoc, "unknown <id> ")
593  << operands[wordIndex] << "used as initializer";
594  }
595  wordIndex++;
596  initializer = SymbolRefAttr::get(initializerOp.getOperation());
597  }
598  if (wordIndex != operands.size()) {
599  return emitError(unknownLoc,
600  "found more operands than expected when deserializing "
601  "OpVariable instruction, only ")
602  << wordIndex << " of " << operands.size() << " processed";
603  }
604  auto loc = createFileLineColLoc(opBuilder);
605  auto varOp = opBuilder.create<spirv::GlobalVariableOp>(
606  loc, TypeAttr::get(type), opBuilder.getStringAttr(variableName),
607  initializer);
608 
609  // Decorations.
610  if (decorations.count(variableID)) {
611  for (auto attr : decorations[variableID].getAttrs())
612  varOp->setAttr(attr.getName(), attr.getValue());
613  }
614  globalVariableMap[variableID] = varOp;
615  return success();
616 }
617 
618 IntegerAttr spirv::Deserializer::getConstantInt(uint32_t id) {
619  auto constInfo = getConstant(id);
620  if (!constInfo) {
621  return nullptr;
622  }
623  return constInfo->first.dyn_cast<IntegerAttr>();
624 }
625 
626 LogicalResult spirv::Deserializer::processName(ArrayRef<uint32_t> operands) {
627  if (operands.size() < 2) {
628  return emitError(unknownLoc, "OpName needs at least 2 operands");
629  }
630  if (!nameMap.lookup(operands[0]).empty()) {
631  return emitError(unknownLoc, "duplicate name found for result <id> ")
632  << operands[0];
633  }
634  unsigned wordIndex = 1;
635  StringRef name = decodeStringLiteral(operands, wordIndex);
636  if (wordIndex != operands.size()) {
637  return emitError(unknownLoc,
638  "unexpected trailing words in OpName instruction");
639  }
640  nameMap[operands[0]] = name;
641  return success();
642 }
643 
644 //===----------------------------------------------------------------------===//
645 // Type
646 //===----------------------------------------------------------------------===//
647 
648 LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
649  ArrayRef<uint32_t> operands) {
650  if (operands.empty()) {
651  return emitError(unknownLoc, "type instruction with opcode ")
652  << spirv::stringifyOpcode(opcode) << " needs at least one <id>";
653  }
654 
655  /// TODO: Types might be forward declared in some instructions and need to be
656  /// handled appropriately.
657  if (typeMap.count(operands[0])) {
658  return emitError(unknownLoc, "duplicate definition for result <id> ")
659  << operands[0];
660  }
661 
662  switch (opcode) {
663  case spirv::Opcode::OpTypeVoid:
664  if (operands.size() != 1)
665  return emitError(unknownLoc, "OpTypeVoid must have no parameters");
666  typeMap[operands[0]] = opBuilder.getNoneType();
667  break;
668  case spirv::Opcode::OpTypeBool:
669  if (operands.size() != 1)
670  return emitError(unknownLoc, "OpTypeBool must have no parameters");
671  typeMap[operands[0]] = opBuilder.getI1Type();
672  break;
673  case spirv::Opcode::OpTypeInt: {
674  if (operands.size() != 3)
675  return emitError(
676  unknownLoc, "OpTypeInt must have bitwidth and signedness parameters");
677 
678  // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics
679  // to preserve or validate.
680  // 0 indicates unsigned, or no signedness semantics
681  // 1 indicates signed semantics."
682  //
683  // So we cannot differentiate signless and unsigned integers; always use
684  // signless semantics for such cases.
685  auto sign = operands[2] == 1 ? IntegerType::SignednessSemantics::Signed
686  : IntegerType::SignednessSemantics::Signless;
687  typeMap[operands[0]] = IntegerType::get(context, operands[1], sign);
688  } break;
689  case spirv::Opcode::OpTypeFloat: {
690  if (operands.size() != 2)
691  return emitError(unknownLoc, "OpTypeFloat must have bitwidth parameter");
692 
693  Type floatTy;
694  switch (operands[1]) {
695  case 16:
696  floatTy = opBuilder.getF16Type();
697  break;
698  case 32:
699  floatTy = opBuilder.getF32Type();
700  break;
701  case 64:
702  floatTy = opBuilder.getF64Type();
703  break;
704  default:
705  return emitError(unknownLoc, "unsupported OpTypeFloat bitwidth: ")
706  << operands[1];
707  }
708  typeMap[operands[0]] = floatTy;
709  } break;
710  case spirv::Opcode::OpTypeVector: {
711  if (operands.size() != 3) {
712  return emitError(
713  unknownLoc,
714  "OpTypeVector must have element type and count parameters");
715  }
716  Type elementTy = getType(operands[1]);
717  if (!elementTy) {
718  return emitError(unknownLoc, "OpTypeVector references undefined <id> ")
719  << operands[1];
720  }
721  typeMap[operands[0]] = VectorType::get({operands[2]}, elementTy);
722  } break;
723  case spirv::Opcode::OpTypePointer: {
724  return processOpTypePointer(operands);
725  } break;
726  case spirv::Opcode::OpTypeArray:
727  return processArrayType(operands);
728  case spirv::Opcode::OpTypeCooperativeMatrixNV:
729  return processCooperativeMatrixType(operands);
730  case spirv::Opcode::OpTypeFunction:
731  return processFunctionType(operands);
732  case spirv::Opcode::OpTypeImage:
733  return processImageType(operands);
734  case spirv::Opcode::OpTypeSampledImage:
735  return processSampledImageType(operands);
736  case spirv::Opcode::OpTypeRuntimeArray:
737  return processRuntimeArrayType(operands);
738  case spirv::Opcode::OpTypeStruct:
739  return processStructType(operands);
740  case spirv::Opcode::OpTypeMatrix:
741  return processMatrixType(operands);
742  default:
743  return emitError(unknownLoc, "unhandled type instruction");
744  }
745  return success();
746 }
747 
749 spirv::Deserializer::processOpTypePointer(ArrayRef<uint32_t> operands) {
750  if (operands.size() != 3)
751  return emitError(unknownLoc, "OpTypePointer must have two parameters");
752 
753  auto pointeeType = getType(operands[2]);
754  if (!pointeeType)
755  return emitError(unknownLoc, "unknown OpTypePointer pointee type <id> ")
756  << operands[2];
757 
758  uint32_t typePointerID = operands[0];
759  auto storageClass = static_cast<spirv::StorageClass>(operands[1]);
760  typeMap[typePointerID] = spirv::PointerType::get(pointeeType, storageClass);
761 
762  for (auto *deferredStructIt = std::begin(deferredStructTypesInfos);
763  deferredStructIt != std::end(deferredStructTypesInfos);) {
764  for (auto *unresolvedMemberIt =
765  std::begin(deferredStructIt->unresolvedMemberTypes);
766  unresolvedMemberIt !=
767  std::end(deferredStructIt->unresolvedMemberTypes);) {
768  if (unresolvedMemberIt->first == typePointerID) {
769  // The newly constructed pointer type can resolve one of the
770  // deferred struct type members; update the memberTypes list and
771  // clean the unresolvedMemberTypes list accordingly.
772  deferredStructIt->memberTypes[unresolvedMemberIt->second] =
773  typeMap[typePointerID];
774  unresolvedMemberIt =
775  deferredStructIt->unresolvedMemberTypes.erase(unresolvedMemberIt);
776  } else {
777  ++unresolvedMemberIt;
778  }
779  }
780 
781  if (deferredStructIt->unresolvedMemberTypes.empty()) {
782  // All deferred struct type members are now resolved, set the struct body.
783  auto structType = deferredStructIt->deferredStructType;
784 
785  assert(structType && "expected a spirv::StructType");
786  assert(structType.isIdentified() && "expected an indentified struct");
787 
788  if (failed(structType.trySetBody(
789  deferredStructIt->memberTypes, deferredStructIt->offsetInfo,
790  deferredStructIt->memberDecorationsInfo)))
791  return failure();
792 
793  deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt);
794  } else {
795  ++deferredStructIt;
796  }
797  }
798 
799  return success();
800 }
801 
803 spirv::Deserializer::processArrayType(ArrayRef<uint32_t> operands) {
804  if (operands.size() != 3) {
805  return emitError(unknownLoc,
806  "OpTypeArray must have element type and count parameters");
807  }
808 
809  Type elementTy = getType(operands[1]);
810  if (!elementTy) {
811  return emitError(unknownLoc, "OpTypeArray references undefined <id> ")
812  << operands[1];
813  }
814 
815  unsigned count = 0;
816  // TODO: The count can also come frome a specialization constant.
817  auto countInfo = getConstant(operands[2]);
818  if (!countInfo) {
819  return emitError(unknownLoc, "OpTypeArray count <id> ")
820  << operands[2] << "can only come from normal constant right now";
821  }
822 
823  if (auto intVal = countInfo->first.dyn_cast<IntegerAttr>()) {
824  count = intVal.getValue().getZExtValue();
825  } else {
826  return emitError(unknownLoc, "OpTypeArray count must come from a "
827  "scalar integer constant instruction");
828  }
829 
830  typeMap[operands[0]] = spirv::ArrayType::get(
831  elementTy, count, typeDecorations.lookup(operands[0]));
832  return success();
833 }
834 
836 spirv::Deserializer::processFunctionType(ArrayRef<uint32_t> operands) {
837  assert(!operands.empty() && "No operands for processing function type");
838  if (operands.size() == 1) {
839  return emitError(unknownLoc, "missing return type for OpTypeFunction");
840  }
841  auto returnType = getType(operands[1]);
842  if (!returnType) {
843  return emitError(unknownLoc, "unknown return type in OpTypeFunction");
844  }
845  SmallVector<Type, 1> argTypes;
846  for (size_t i = 2, e = operands.size(); i < e; ++i) {
847  auto ty = getType(operands[i]);
848  if (!ty) {
849  return emitError(unknownLoc, "unknown argument type in OpTypeFunction");
850  }
851  argTypes.push_back(ty);
852  }
853  ArrayRef<Type> returnTypes;
854  if (!isVoidType(returnType)) {
855  returnTypes = llvm::makeArrayRef(returnType);
856  }
857  typeMap[operands[0]] = FunctionType::get(context, argTypes, returnTypes);
858  return success();
859 }
860 
862 spirv::Deserializer::processCooperativeMatrixType(ArrayRef<uint32_t> operands) {
863  if (operands.size() != 5) {
864  return emitError(unknownLoc, "OpTypeCooperativeMatrix must have element "
865  "type and row x column parameters");
866  }
867 
868  Type elementTy = getType(operands[1]);
869  if (!elementTy) {
870  return emitError(unknownLoc,
871  "OpTypeCooperativeMatrix references undefined <id> ")
872  << operands[1];
873  }
874 
875  auto scope = spirv::symbolizeScope(getConstantInt(operands[2]).getInt());
876  if (!scope) {
877  return emitError(unknownLoc,
878  "OpTypeCooperativeMatrix references undefined scope <id> ")
879  << operands[2];
880  }
881 
882  unsigned rows = getConstantInt(operands[3]).getInt();
883  unsigned columns = getConstantInt(operands[4]).getInt();
884 
885  typeMap[operands[0]] = spirv::CooperativeMatrixNVType::get(
886  elementTy, scope.getValue(), rows, columns);
887  return success();
888 }
889 
891 spirv::Deserializer::processRuntimeArrayType(ArrayRef<uint32_t> operands) {
892  if (operands.size() != 2) {
893  return emitError(unknownLoc, "OpTypeRuntimeArray must have two operands");
894  }
895  Type memberType = getType(operands[1]);
896  if (!memberType) {
897  return emitError(unknownLoc,
898  "OpTypeRuntimeArray references undefined <id> ")
899  << operands[1];
900  }
901  typeMap[operands[0]] = spirv::RuntimeArrayType::get(
902  memberType, typeDecorations.lookup(operands[0]));
903  return success();
904 }
905 
907 spirv::Deserializer::processStructType(ArrayRef<uint32_t> operands) {
908  // TODO: Find a way to handle identified structs when debug info is stripped.
909 
910  if (operands.empty()) {
911  return emitError(unknownLoc, "OpTypeStruct must have at least result <id>");
912  }
913 
914  if (operands.size() == 1) {
915  // Handle empty struct.
916  typeMap[operands[0]] =
917  spirv::StructType::getEmpty(context, nameMap.lookup(operands[0]).str());
918  return success();
919  }
920 
921  // First element is operand ID, second element is member index in the struct.
922  SmallVector<std::pair<uint32_t, unsigned>, 0> unresolvedMemberTypes;
923  SmallVector<Type, 4> memberTypes;
924 
925  for (auto op : llvm::drop_begin(operands, 1)) {
926  Type memberType = getType(op);
927  bool typeForwardPtr = (typeForwardPointerIDs.count(op) != 0);
928 
929  if (!memberType && !typeForwardPtr)
930  return emitError(unknownLoc, "OpTypeStruct references undefined <id> ")
931  << op;
932 
933  if (!memberType)
934  unresolvedMemberTypes.emplace_back(op, memberTypes.size());
935 
936  memberTypes.push_back(memberType);
937  }
938 
941  if (memberDecorationMap.count(operands[0])) {
942  auto &allMemberDecorations = memberDecorationMap[operands[0]];
943  for (auto memberIndex : llvm::seq<uint32_t>(0, memberTypes.size())) {
944  if (allMemberDecorations.count(memberIndex)) {
945  for (auto &memberDecoration : allMemberDecorations[memberIndex]) {
946  // Check for offset.
947  if (memberDecoration.first == spirv::Decoration::Offset) {
948  // If offset info is empty, resize to the number of members;
949  if (offsetInfo.empty()) {
950  offsetInfo.resize(memberTypes.size());
951  }
952  offsetInfo[memberIndex] = memberDecoration.second[0];
953  } else {
954  if (!memberDecoration.second.empty()) {
955  memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/1,
956  memberDecoration.first,
957  memberDecoration.second[0]);
958  } else {
959  memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/0,
960  memberDecoration.first, 0);
961  }
962  }
963  }
964  }
965  }
966  }
967 
968  uint32_t structID = operands[0];
969  std::string structIdentifier = nameMap.lookup(structID).str();
970 
971  if (structIdentifier.empty()) {
972  assert(unresolvedMemberTypes.empty() &&
973  "didn't expect unresolved member types");
974  typeMap[structID] =
975  spirv::StructType::get(memberTypes, offsetInfo, memberDecorationsInfo);
976  } else {
977  auto structTy = spirv::StructType::getIdentified(context, structIdentifier);
978  typeMap[structID] = structTy;
979 
980  if (!unresolvedMemberTypes.empty())
981  deferredStructTypesInfos.push_back({structTy, unresolvedMemberTypes,
982  memberTypes, offsetInfo,
983  memberDecorationsInfo});
984  else if (failed(structTy.trySetBody(memberTypes, offsetInfo,
985  memberDecorationsInfo)))
986  return failure();
987  }
988 
989  // TODO: Update StructType to have member name as attribute as
990  // well.
991  return success();
992 }
993 
995 spirv::Deserializer::processMatrixType(ArrayRef<uint32_t> operands) {
996  if (operands.size() != 3) {
997  // Three operands are needed: result_id, column_type, and column_count
998  return emitError(unknownLoc, "OpTypeMatrix must have 3 operands"
999  " (result_id, column_type, and column_count)");
1000  }
1001  // Matrix columns must be of vector type
1002  Type elementTy = getType(operands[1]);
1003  if (!elementTy) {
1004  return emitError(unknownLoc,
1005  "OpTypeMatrix references undefined column type.")
1006  << operands[1];
1007  }
1008 
1009  uint32_t colsCount = operands[2];
1010  typeMap[operands[0]] = spirv::MatrixType::get(elementTy, colsCount);
1011  return success();
1012 }
1013 
1015 spirv::Deserializer::processTypeForwardPointer(ArrayRef<uint32_t> operands) {
1016  if (operands.size() != 2)
1017  return emitError(unknownLoc,
1018  "OpTypeForwardPointer instruction must have two operands");
1019 
1020  typeForwardPointerIDs.insert(operands[0]);
1021  // TODO: Use the 2nd operand (Storage Class) to validate the OpTypePointer
1022  // instruction that defines the actual type.
1023 
1024  return success();
1025 }
1026 
1028 spirv::Deserializer::processImageType(ArrayRef<uint32_t> operands) {
1029  // TODO: Add support for Access Qualifier.
1030  if (operands.size() != 8)
1031  return emitError(
1032  unknownLoc,
1033  "OpTypeImage with non-eight operands are not supported yet");
1034 
1035  Type elementTy = getType(operands[1]);
1036  if (!elementTy)
1037  return emitError(unknownLoc, "OpTypeImage references undefined <id>: ")
1038  << operands[1];
1039 
1040  auto dim = spirv::symbolizeDim(operands[2]);
1041  if (!dim)
1042  return emitError(unknownLoc, "unknown Dim for OpTypeImage: ")
1043  << operands[2];
1044 
1045  auto depthInfo = spirv::symbolizeImageDepthInfo(operands[3]);
1046  if (!depthInfo)
1047  return emitError(unknownLoc, "unknown Depth for OpTypeImage: ")
1048  << operands[3];
1049 
1050  auto arrayedInfo = spirv::symbolizeImageArrayedInfo(operands[4]);
1051  if (!arrayedInfo)
1052  return emitError(unknownLoc, "unknown Arrayed for OpTypeImage: ")
1053  << operands[4];
1054 
1055  auto samplingInfo = spirv::symbolizeImageSamplingInfo(operands[5]);
1056  if (!samplingInfo)
1057  return emitError(unknownLoc, "unknown MS for OpTypeImage: ") << operands[5];
1058 
1059  auto samplerUseInfo = spirv::symbolizeImageSamplerUseInfo(operands[6]);
1060  if (!samplerUseInfo)
1061  return emitError(unknownLoc, "unknown Sampled for OpTypeImage: ")
1062  << operands[6];
1063 
1064  auto format = spirv::symbolizeImageFormat(operands[7]);
1065  if (!format)
1066  return emitError(unknownLoc, "unknown Format for OpTypeImage: ")
1067  << operands[7];
1068 
1069  typeMap[operands[0]] = spirv::ImageType::get(
1070  elementTy, dim.getValue(), depthInfo.getValue(), arrayedInfo.getValue(),
1071  samplingInfo.getValue(), samplerUseInfo.getValue(), format.getValue());
1072  return success();
1073 }
1074 
1076 spirv::Deserializer::processSampledImageType(ArrayRef<uint32_t> operands) {
1077  if (operands.size() != 2)
1078  return emitError(unknownLoc, "OpTypeSampledImage must have two operands");
1079 
1080  Type elementTy = getType(operands[1]);
1081  if (!elementTy)
1082  return emitError(unknownLoc,
1083  "OpTypeSampledImage references undefined <id>: ")
1084  << operands[1];
1085 
1086  typeMap[operands[0]] = spirv::SampledImageType::get(elementTy);
1087  return success();
1088 }
1089 
1090 //===----------------------------------------------------------------------===//
1091 // Constant
1092 //===----------------------------------------------------------------------===//
1093 
1094 LogicalResult spirv::Deserializer::processConstant(ArrayRef<uint32_t> operands,
1095  bool isSpec) {
1096  StringRef opname = isSpec ? "OpSpecConstant" : "OpConstant";
1097 
1098  if (operands.size() < 2) {
1099  return emitError(unknownLoc)
1100  << opname << " must have type <id> and result <id>";
1101  }
1102  if (operands.size() < 3) {
1103  return emitError(unknownLoc)
1104  << opname << " must have at least 1 more parameter";
1105  }
1106 
1107  Type resultType = getType(operands[0]);
1108  if (!resultType) {
1109  return emitError(unknownLoc, "undefined result type from <id> ")
1110  << operands[0];
1111  }
1112 
1113  auto checkOperandSizeForBitwidth = [&](unsigned bitwidth) -> LogicalResult {
1114  if (bitwidth == 64) {
1115  if (operands.size() == 4) {
1116  return success();
1117  }
1118  return emitError(unknownLoc)
1119  << opname << " should have 2 parameters for 64-bit values";
1120  }
1121  if (bitwidth <= 32) {
1122  if (operands.size() == 3) {
1123  return success();
1124  }
1125 
1126  return emitError(unknownLoc)
1127  << opname
1128  << " should have 1 parameter for values with no more than 32 bits";
1129  }
1130  return emitError(unknownLoc, "unsupported OpConstant bitwidth: ")
1131  << bitwidth;
1132  };
1133 
1134  auto resultID = operands[1];
1135 
1136  if (auto intType = resultType.dyn_cast<IntegerType>()) {
1137  auto bitwidth = intType.getWidth();
1138  if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1139  return failure();
1140  }
1141 
1142  APInt value;
1143  if (bitwidth == 64) {
1144  // 64-bit integers are represented with two SPIR-V words. According to
1145  // SPIR-V spec: "When the type’s bit width is larger than one word, the
1146  // literal’s low-order words appear first."
1147  struct DoubleWord {
1148  uint32_t word1;
1149  uint32_t word2;
1150  } words = {operands[2], operands[3]};
1151  value = APInt(64, llvm::bit_cast<uint64_t>(words), /*isSigned=*/true);
1152  } else if (bitwidth <= 32) {
1153  value = APInt(bitwidth, operands[2], /*isSigned=*/true);
1154  }
1155 
1156  auto attr = opBuilder.getIntegerAttr(intType, value);
1157 
1158  if (isSpec) {
1159  createSpecConstant(unknownLoc, resultID, attr);
1160  } else {
1161  // For normal constants, we just record the attribute (and its type) for
1162  // later materialization at use sites.
1163  constantMap.try_emplace(resultID, attr, intType);
1164  }
1165 
1166  return success();
1167  }
1168 
1169  if (auto floatType = resultType.dyn_cast<FloatType>()) {
1170  auto bitwidth = floatType.getWidth();
1171  if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1172  return failure();
1173  }
1174 
1175  APFloat value(0.f);
1176  if (floatType.isF64()) {
1177  // Double values are represented with two SPIR-V words. According to
1178  // SPIR-V spec: "When the type’s bit width is larger than one word, the
1179  // literal’s low-order words appear first."
1180  struct DoubleWord {
1181  uint32_t word1;
1182  uint32_t word2;
1183  } words = {operands[2], operands[3]};
1184  value = APFloat(llvm::bit_cast<double>(words));
1185  } else if (floatType.isF32()) {
1186  value = APFloat(llvm::bit_cast<float>(operands[2]));
1187  } else if (floatType.isF16()) {
1188  APInt data(16, operands[2]);
1189  value = APFloat(APFloat::IEEEhalf(), data);
1190  }
1191 
1192  auto attr = opBuilder.getFloatAttr(floatType, value);
1193  if (isSpec) {
1194  createSpecConstant(unknownLoc, resultID, attr);
1195  } else {
1196  // For normal constants, we just record the attribute (and its type) for
1197  // later materialization at use sites.
1198  constantMap.try_emplace(resultID, attr, floatType);
1199  }
1200 
1201  return success();
1202  }
1203 
1204  return emitError(unknownLoc, "OpConstant can only generate values of "
1205  "scalar integer or floating-point type");
1206 }
1207 
1208 LogicalResult spirv::Deserializer::processConstantBool(
1209  bool isTrue, ArrayRef<uint32_t> operands, bool isSpec) {
1210  if (operands.size() != 2) {
1211  return emitError(unknownLoc, "Op")
1212  << (isSpec ? "Spec" : "") << "Constant"
1213  << (isTrue ? "True" : "False")
1214  << " must have type <id> and result <id>";
1215  }
1216 
1217  auto attr = opBuilder.getBoolAttr(isTrue);
1218  auto resultID = operands[1];
1219  if (isSpec) {
1220  createSpecConstant(unknownLoc, resultID, attr);
1221  } else {
1222  // For normal constants, we just record the attribute (and its type) for
1223  // later materialization at use sites.
1224  constantMap.try_emplace(resultID, attr, opBuilder.getI1Type());
1225  }
1226 
1227  return success();
1228 }
1229 
1231 spirv::Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
1232  if (operands.size() < 2) {
1233  return emitError(unknownLoc,
1234  "OpConstantComposite must have type <id> and result <id>");
1235  }
1236  if (operands.size() < 3) {
1237  return emitError(unknownLoc,
1238  "OpConstantComposite must have at least 1 parameter");
1239  }
1240 
1241  Type resultType = getType(operands[0]);
1242  if (!resultType) {
1243  return emitError(unknownLoc, "undefined result type from <id> ")
1244  << operands[0];
1245  }
1246 
1247  SmallVector<Attribute, 4> elements;
1248  elements.reserve(operands.size() - 2);
1249  for (unsigned i = 2, e = operands.size(); i < e; ++i) {
1250  auto elementInfo = getConstant(operands[i]);
1251  if (!elementInfo) {
1252  return emitError(unknownLoc, "OpConstantComposite component <id> ")
1253  << operands[i] << " must come from a normal constant";
1254  }
1255  elements.push_back(elementInfo->first);
1256  }
1257 
1258  auto resultID = operands[1];
1259  if (auto vectorType = resultType.dyn_cast<VectorType>()) {
1260  auto attr = DenseElementsAttr::get(vectorType, elements);
1261  // For normal constants, we just record the attribute (and its type) for
1262  // later materialization at use sites.
1263  constantMap.try_emplace(resultID, attr, resultType);
1264  } else if (auto arrayType = resultType.dyn_cast<spirv::ArrayType>()) {
1265  auto attr = opBuilder.getArrayAttr(elements);
1266  constantMap.try_emplace(resultID, attr, resultType);
1267  } else {
1268  return emitError(unknownLoc, "unsupported OpConstantComposite type: ")
1269  << resultType;
1270  }
1271 
1272  return success();
1273 }
1274 
1276 spirv::Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) {
1277  if (operands.size() < 2) {
1278  return emitError(unknownLoc,
1279  "OpConstantComposite must have type <id> and result <id>");
1280  }
1281  if (operands.size() < 3) {
1282  return emitError(unknownLoc,
1283  "OpConstantComposite must have at least 1 parameter");
1284  }
1285 
1286  Type resultType = getType(operands[0]);
1287  if (!resultType) {
1288  return emitError(unknownLoc, "undefined result type from <id> ")
1289  << operands[0];
1290  }
1291 
1292  auto resultID = operands[1];
1293  auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
1294 
1295  SmallVector<Attribute, 4> elements;
1296  elements.reserve(operands.size() - 2);
1297  for (unsigned i = 2, e = operands.size(); i < e; ++i) {
1298  auto elementInfo = getSpecConstant(operands[i]);
1299  elements.push_back(SymbolRefAttr::get(elementInfo));
1300  }
1301 
1302  auto op = opBuilder.create<spirv::SpecConstantCompositeOp>(
1303  unknownLoc, TypeAttr::get(resultType), symName,
1304  opBuilder.getArrayAttr(elements));
1305  specConstCompositeMap[resultID] = op;
1306 
1307  return success();
1308 }
1309 
1311 spirv::Deserializer::processSpecConstantOperation(ArrayRef<uint32_t> operands) {
1312  if (operands.size() < 3)
1313  return emitError(unknownLoc, "OpConstantOperation must have type <id>, "
1314  "result <id>, and operand opcode");
1315 
1316  uint32_t resultTypeID = operands[0];
1317 
1318  if (!getType(resultTypeID))
1319  return emitError(unknownLoc, "undefined result type from <id> ")
1320  << resultTypeID;
1321 
1322  uint32_t resultID = operands[1];
1323  spirv::Opcode enclosedOpcode = static_cast<spirv::Opcode>(operands[2]);
1324  auto emplaceResult = specConstOperationMap.try_emplace(
1325  resultID,
1327  enclosedOpcode, resultTypeID,
1328  SmallVector<uint32_t>{operands.begin() + 3, operands.end()}});
1329 
1330  if (!emplaceResult.second)
1331  return emitError(unknownLoc, "value with <id>: ")
1332  << resultID << " is probably defined before.";
1333 
1334  return success();
1335 }
1336 
/// Materializes an OpSpecConstantOp as a spirv::SpecConstantOperationOp.
///
/// The enclosed instruction (opcode `enclosedOpcode`, result type
/// `resultTypeID`, operands `enclosedOpOperands`) is deserialized through
/// processInstruction using a placeholder result <id>. It is then moved into
/// the new op's single-block region and terminated with a spirv::YieldOp of
/// its first result. `resultID` is not referenced in this body.
///
/// Returns the SpecConstantOperationOp's result, or a null Value if the
/// enclosed instruction fails to deserialize.
1337 Value spirv::Deserializer::materializeSpecConstantOperation(
1338  uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID,
1339  ArrayRef<uint32_t> enclosedOpOperands) {
1340 
1341  Type resultType = getType(resultTypeID);
1342 
1343  // Instructions wrapped by OpSpecConstantOp need an ID for their
1344  // Deserializer::processOp<op_name>(...) to emit the corresponding SPIR-V
1345  // dialect wrapped op. For that purpose, a new value map is created and "fake"
1346  // ID in that map is assigned to the result of the enclosed instruction. Note
1347  // that there is no need to update this fake ID since we only need to
1348  // reference the created Value for the enclosed op from the spv::YieldOp
1349  // created later in this method (both of which are the only values in their
1350  // region: the SpecConstantOperation's region). If we encounter another
1351  // SpecConstantOperation in the module, we simply re-use the fake ID since the
1352  // previous Value assigned to it isn't visible in the current scope anyway.
1353  DenseMap<uint32_t, Value> newValueMap;
 // `valueMap` is replaced by the (empty) `newValueMap` for the rest of this
 // function; SaveAndRestore puts the original map back on every return path.
1354  llvm::SaveAndRestore<DenseMap<uint32_t, Value>> valueMapGuard(valueMap,
1355  newValueMap);
1356  constexpr uint32_t fakeID = static_cast<uint32_t>(-3);
1357 
 // Operand list for the enclosed instruction: result type <id>, placeholder
 // result <id>, then the wrapped instruction's own operands.
1358  SmallVector<uint32_t, 4> enclosedOpResultTypeAndOperands;
1359  enclosedOpResultTypeAndOperands.push_back(resultTypeID);
1360  enclosedOpResultTypeAndOperands.push_back(fakeID);
1361  enclosedOpResultTypeAndOperands.append(enclosedOpOperands.begin(),
1362  enclosedOpOperands.end());
1363 
1364  // Process enclosed instruction before creating the enclosing
1365  // specConstantOperation (and its region). This way, references to constants,
1366  // global variables, and spec constants will be materialized outside the new
1367  // op's region. For more info, see Deserializer::getValue's implementation.
1368  if (failed(
1369  processInstruction(enclosedOpcode, enclosedOpResultTypeAndOperands)))
1370  return Value();
1371 
1372  // Since the enclosed op is emitted in the current block, split it in a
1373  // separate new block.
 // NOTE(review): splitting at curBlock->back() assumes the enclosed
 // instruction was emitted as the last operation of curBlock (anything it
 // references having been materialized before it) — confirm for opcodes whose
 // processing creates more than one op.
1374  Block *enclosedBlock = curBlock->splitBlock(&curBlock->back());
1375 
1376  auto loc = createFileLineColLoc(opBuilder);
1377  auto specConstOperationOp =
1378  opBuilder.create<spirv::SpecConstantOperationOp>(loc, resultType);
1379 
1380  Region &body = specConstOperationOp.body();
1381  // Move the new block into SpecConstantOperation's body.
1382  body.getBlocks().splice(body.end(), curBlock->getParent()->getBlocks(),
1383  Region::iterator(enclosedBlock));
1384  Block &block = body.back();
1385 
1386  // RAII guard to reset the insertion point to the module's region after
1387  // deserializing the body of the specConstantOperation.
1388  OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
1389  opBuilder.setInsertionPointToEnd(&block);
1390 
 // `block` holds only the enclosed op (moved in by the split above); yield its
 // first result as the value of the SpecConstantOperationOp.
1391  opBuilder.create<spirv::YieldOp>(loc, block.front().getResult(0));
1392  return specConstOperationOp.getResult();
1393 }
1394 
1396 spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
1397  if (operands.size() != 2) {
1398  return emitError(unknownLoc,
1399  "OpConstantNull must have type <id> and result <id>");
1400  }
1401 
1402  Type resultType = getType(operands[0]);
1403  if (!resultType) {
1404  return emitError(unknownLoc, "undefined result type from <id> ")
1405  << operands[0];
1406  }
1407 
1408  auto resultID = operands[1];
1409  if (resultType.isIntOrFloat() || resultType.isa<VectorType>()) {
1410  auto attr = opBuilder.getZeroAttr(resultType);
1411  // For normal constants, we just record the attribute (and its type) for
1412  // later materialization at use sites.
1413  constantMap.try_emplace(resultID, attr, resultType);
1414  return success();
1415  }
1416 
1417  return emitError(unknownLoc, "unsupported OpConstantNull type: ")
1418  << resultType;
1419 }
1420 
1421 //===----------------------------------------------------------------------===//
1422 // Control flow
1423 //===----------------------------------------------------------------------===//
1424 
1425 Block *spirv::Deserializer::getOrCreateBlock(uint32_t id) {
1426  if (auto *block = getBlock(id)) {
1427  LLVM_DEBUG(logger.startLine() << "[block] got exiting block for id = " << id
1428  << " @ " << block << "\n");
1429  return block;
1430  }
1431 
1432  // We don't know where this block will be placed finally (in a
1433  // spv.mlir.selection or spv.mlir.loop or function). Create it into the
1434  // function for now and sort out the proper place later.
1435  auto *block = curFunction->addBlock();
1436  LLVM_DEBUG(logger.startLine() << "[block] created block for id = " << id
1437  << " @ " << block << "\n");
1438  return blockMap[id] = block;
1439 }
1440 
1441 LogicalResult spirv::Deserializer::processBranch(ArrayRef<uint32_t> operands) {
1442  if (!curBlock) {
1443  return emitError(unknownLoc, "OpBranch must appear inside a block");
1444  }
1445 
1446  if (operands.size() != 1) {
1447  return emitError(unknownLoc, "OpBranch must take exactly one target label");
1448  }
1449 
1450  auto *target = getOrCreateBlock(operands[0]);
1451  auto loc = createFileLineColLoc(opBuilder);
1452  // The preceding instruction for the OpBranch instruction could be an
1453  // OpLoopMerge or an OpSelectionMerge instruction, in this case they will have
1454  // the same OpLine information.
1455  opBuilder.create<spirv::BranchOp>(loc, target);
1456 
1457  clearDebugLine();
1458  return success();
1459 }
1460 
1462 spirv::Deserializer::processBranchConditional(ArrayRef<uint32_t> operands) {
1463  if (!curBlock) {
1464  return emitError(unknownLoc,
1465  "OpBranchConditional must appear inside a block");
1466  }
1467 
1468  if (operands.size() != 3 && operands.size() != 5) {
1469  return emitError(unknownLoc,
1470  "OpBranchConditional must have condition, true label, "
1471  "false label, and optionally two branch weights");
1472  }
1473 
1474  auto condition = getValue(operands[0]);
1475  auto *trueBlock = getOrCreateBlock(operands[1]);
1476  auto *falseBlock = getOrCreateBlock(operands[2]);
1477 
1479  if (operands.size() == 5) {
1480  weights = std::make_pair(operands[3], operands[4]);
1481  }
1482  // The preceding instruction for the OpBranchConditional instruction could be
1483  // an OpSelectionMerge instruction, in this case they will have the same
1484  // OpLine information.
1485  auto loc = createFileLineColLoc(opBuilder);
1486  opBuilder.create<spirv::BranchConditionalOp>(
1487  loc, condition, trueBlock,
1488  /*trueArguments=*/ArrayRef<Value>(), falseBlock,
1489  /*falseArguments=*/ArrayRef<Value>(), weights);
1490 
1491  clearDebugLine();
1492  return success();
1493 }
1494 
1495 LogicalResult spirv::Deserializer::processLabel(ArrayRef<uint32_t> operands) {
1496  if (!curFunction) {
1497  return emitError(unknownLoc, "OpLabel must appear inside a function");
1498  }
1499 
1500  if (operands.size() != 1) {
1501  return emitError(unknownLoc, "OpLabel should only have result <id>");
1502  }
1503 
1504  auto labelID = operands[0];
1505  // We may have forward declared this block.
1506  auto *block = getOrCreateBlock(labelID);
1507  LLVM_DEBUG(logger.startLine()
1508  << "[block] populating block " << block << "\n");
1509  // If we have seen this block, make sure it was just a forward declaration.
1510  assert(block->empty() && "re-deserialize the same block!");
1511 
1512  opBuilder.setInsertionPointToStart(block);
1513  blockMap[labelID] = curBlock = block;
1514 
1515  return success();
1516 }
1517 
1519 spirv::Deserializer::processSelectionMerge(ArrayRef<uint32_t> operands) {
1520  if (!curBlock) {
1521  return emitError(unknownLoc, "OpSelectionMerge must appear in a block");
1522  }
1523 
1524  if (operands.size() < 2) {
1525  return emitError(
1526  unknownLoc,
1527  "OpSelectionMerge must specify merge target and selection control");
1528  }
1529 
1530  auto *mergeBlock = getOrCreateBlock(operands[0]);
1531  auto loc = createFileLineColLoc(opBuilder);
1532  auto selectionControl = operands[1];
1533 
1534  if (!blockMergeInfo.try_emplace(curBlock, loc, selectionControl, mergeBlock)
1535  .second) {
1536  return emitError(
1537  unknownLoc,
1538  "a block cannot have more than one OpSelectionMerge instruction");
1539  }
1540 
1541  return success();
1542 }
1543 
1545 spirv::Deserializer::processLoopMerge(ArrayRef<uint32_t> operands) {
1546  if (!curBlock) {
1547  return emitError(unknownLoc, "OpLoopMerge must appear in a block");
1548  }
1549 
1550  if (operands.size() < 3) {
1551  return emitError(unknownLoc, "OpLoopMerge must specify merge target, "
1552  "continue target and loop control");
1553  }
1554 
1555  auto *mergeBlock = getOrCreateBlock(operands[0]);
1556  auto *continueBlock = getOrCreateBlock(operands[1]);
1557  auto loc = createFileLineColLoc(opBuilder);
1558  uint32_t loopControl = operands[2];
1559 
1560  if (!blockMergeInfo
1561  .try_emplace(curBlock, loc, loopControl, mergeBlock, continueBlock)
1562  .second) {
1563  return emitError(
1564  unknownLoc,
1565  "a block cannot have more than one OpLoopMerge instruction");
1566  }
1567 
1568  return success();
1569 }
1570 
1571 LogicalResult spirv::Deserializer::processPhi(ArrayRef<uint32_t> operands) {
1572  if (!curBlock) {
1573  return emitError(unknownLoc, "OpPhi must appear in a block");
1574  }
1575 
1576  if (operands.size() < 4) {
1577  return emitError(unknownLoc, "OpPhi must specify result type, result <id>, "
1578  "and variable-parent pairs");
1579  }
1580 
1581  // Create a block argument for this OpPhi instruction.
1582  Type blockArgType = getType(operands[0]);
1583  BlockArgument blockArg = curBlock->addArgument(blockArgType, unknownLoc);
1584  valueMap[operands[1]] = blockArg;
1585  LLVM_DEBUG(logger.startLine()
1586  << "[phi] created block argument " << blockArg
1587  << " id = " << operands[1] << " of type " << blockArgType << "\n");
1588 
1589  // For each (value, predecessor) pair, insert the value to the predecessor's
1590  // blockPhiInfo entry so later we can fix the block argument there.
1591  for (unsigned i = 2, e = operands.size(); i < e; i += 2) {
1592  uint32_t value = operands[i];
1593  Block *predecessor = getOrCreateBlock(operands[i + 1]);
1594  std::pair<Block *, Block *> predecessorTargetPair{predecessor, curBlock};
1595  blockPhiInfo[predecessorTargetPair].push_back(value);
1596  LLVM_DEBUG(logger.startLine() << "[phi] predecessor @ " << predecessor
1597  << " with arg id = " << value << "\n");
1598  }
1599 
1600  return success();
1601 }
1602 
1603 namespace {
1604 /// A class for putting all blocks in a structured selection/loop in a
1605 /// spv.mlir.selection/spv.mlir.loop op.
1606 class ControlFlowStructurizer {
1607 public:
 /// Creates a structurizer for the construct headed by `header` whose merge
 /// block is `merge`. `cont` is the loop continue block, or nullptr when the
 /// construct is a selection. `control` is the raw selection/loop control
 /// value, `loc` is the location given to the created op, and `mergeInfo` is
 /// the deserializer's block merge info map. Debug builds also take a `logger`.
1608 #ifndef NDEBUG
1609  ControlFlowStructurizer(Location loc, uint32_t control,
1610  spirv::BlockMergeInfoMap &mergeInfo, Block *header,
1611  Block *merge, Block *cont,
1612  llvm::ScopedPrinter &logger)
1613  : location(loc), control(control), blockMergeInfo(mergeInfo),
1614  headerBlock(header), mergeBlock(merge), continueBlock(cont),
1615  logger(logger) {}
1616 #else
1617  ControlFlowStructurizer(Location loc, uint32_t control,
1618  spirv::BlockMergeInfoMap &mergeInfo, Block *header,
1619  Block *merge, Block *cont)
1620  : location(loc), control(control), blockMergeInfo(mergeInfo),
1621  headerBlock(header), mergeBlock(merge), continueBlock(cont) {}
1622 #endif
1623 
1624  /// Structurizes the selection or loop construct at the given `headerBlock`.
1625  ///
1626  /// This method will create an spv.mlir.selection/spv.mlir.loop op in the
1627  /// `mergeBlock` and move all blocks in the construct into the op's region.
1628  /// All branches to the `headerBlock` will be redirected to the `mergeBlock`.
1629  /// This method will also update `mergeInfo` by remapping all blocks inside to
1630  /// the newly cloned ones inside structured control flow op's regions.
1631  LogicalResult structurize();
1632 
1633 private:
1634  /// Creates a new spv.mlir.selection op at the beginning of the `mergeBlock`.
1635  spirv::SelectionOp createSelectionOp(uint32_t selectionControl);
1636 
1637  /// Creates a new spv.mlir.loop op at the beginning of the `mergeBlock`.
1638  spirv::LoopOp createLoopOp(uint32_t loopControl);
1639 
1640  /// Collects `headerBlock` and all blocks reachable from it without passing
1641  /// through `mergeBlock` into `constructBlocks`.
1641  void collectBlocksInConstruct();
1642 
 /// Location given to the created selection/loop op.
1643  Location location;
 /// Raw selection/loop control value, cast to the op's control enum on use.
1644  uint32_t control;
1645 
1646  spirv::BlockMergeInfoMap &blockMergeInfo;
1647 
1648  Block *headerBlock;
1649  Block *mergeBlock;
1650  Block *continueBlock; // nullptr for spv.mlir.selection
1651 
 /// Blocks of the construct in discovery order, `headerBlock` first.
1652  SetVector<Block *> constructBlocks;
1653 
1654 #ifndef NDEBUG
1655  /// A logger used to emit information during the deserialization process.
1656  llvm::ScopedPrinter &logger;
1657 #endif
1658 };
1659 } // namespace
1660 
1661 spirv::SelectionOp
1662 ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) {
1663  // Create a builder and set the insertion point to the beginning of the
1664  // merge block so that the newly created SelectionOp will be inserted there.
1665  OpBuilder builder(&mergeBlock->front());
1666 
1667  auto control = static_cast<spirv::SelectionControl>(selectionControl);
1668  auto selectionOp = builder.create<spirv::SelectionOp>(location, control);
1669  selectionOp.addMergeBlock();
1670 
1671  return selectionOp;
1672 }
1673 
1674 spirv::LoopOp ControlFlowStructurizer::createLoopOp(uint32_t loopControl) {
1675  // Create a builder and set the insertion point to the beginning of the
1676  // merge block so that the newly created LoopOp will be inserted there.
1677  OpBuilder builder(&mergeBlock->front());
1678 
1679  auto control = static_cast<spirv::LoopControl>(loopControl);
1680  auto loopOp = builder.create<spirv::LoopOp>(location, control);
1681  loopOp.addEntryAndMergeBlock();
1682 
1683  return loopOp;
1684 }
1685 
1686 void ControlFlowStructurizer::collectBlocksInConstruct() {
1687  assert(constructBlocks.empty() && "expected empty constructBlocks");
1688 
1689  // Put the header block in the work list first.
1690  constructBlocks.insert(headerBlock);
1691 
1692  // For each item in the work list, add its successors excluding the merge
1693  // block.
1694  for (unsigned i = 0; i < constructBlocks.size(); ++i) {
1695  for (auto *successor : constructBlocks[i]->getSuccessors())
1696  if (successor != mergeBlock)
1697  constructBlocks.insert(successor);
1698  }
1699 }
1700 
1701 LogicalResult ControlFlowStructurizer::structurize() {
1702  Operation *op = nullptr;
1703  bool isLoop = continueBlock != nullptr;
1704  if (isLoop) {
1705  if (auto loopOp = createLoopOp(control))
1706  op = loopOp.getOperation();
1707  } else {
1708  if (auto selectionOp = createSelectionOp(control))
1709  op = selectionOp.getOperation();
1710  }
1711  if (!op)
1712  return failure();
1713  Region &body = op->getRegion(0);
1714 
1715  BlockAndValueMapping mapper;
1716  // All references to the old merge block should be directed to the
1717  // selection/loop merge block in the SelectionOp/LoopOp's region.
1718  mapper.map(mergeBlock, &body.back());
1719 
1720  collectBlocksInConstruct();
1721 
1722  // We've identified all blocks belonging to the selection/loop's region. Now
1723  // need to "move" them into the selection/loop. Instead of really moving the
1724  // blocks, in the following we copy them and remap all values and branches.
1725  // This is because:
1726  // * Inserting a block into a region requires the block not in any region
1727  // before. But selections/loops can nest so we can create selection/loop ops
1728  // in a nested manner, which means some blocks may already be in a
1729  // selection/loop region when to be moved again.
1730  // * It's much trickier to fix up the branches into and out of the loop's
1731  // region: we need to treat not-moved blocks and moved blocks differently:
1732  // Not-moved blocks jumping to the loop header block need to jump to the
1733  // merge point containing the new loop op but not the loop continue block's
1734  // back edge. Moved blocks jumping out of the loop need to jump to the
1735  // merge block inside the loop region but not other not-moved blocks.
1736  // We cannot use replaceAllUsesWith clearly and it's harder to follow the
1737  // logic.
1738 
1739  // Create a corresponding block in the SelectionOp/LoopOp's region for each
1740  // block in this loop construct.
1741  OpBuilder builder(body);
1742  for (auto *block : constructBlocks) {
1743  // Create a block and insert it before the selection/loop merge block in the
1744  // SelectionOp/LoopOp's region.
1745  auto *newBlock = builder.createBlock(&body.back());
1746  mapper.map(block, newBlock);
1747  LLVM_DEBUG(logger.startLine() << "[cf] cloned block " << newBlock
1748  << " from block " << block << "\n");
1749  if (!isFnEntryBlock(block)) {
1750  for (BlockArgument blockArg : block->getArguments()) {
1751  auto newArg =
1752  newBlock->addArgument(blockArg.getType(), blockArg.getLoc());
1753  mapper.map(blockArg, newArg);
1754  LLVM_DEBUG(logger.startLine() << "[cf] remapped block argument "
1755  << blockArg << " to " << newArg << "\n");
1756  }
1757  } else {
1758  LLVM_DEBUG(logger.startLine()
1759  << "[cf] block " << block << " is a function entry block\n");
1760  }
1761 
1762  for (auto &op : *block)
1763  newBlock->push_back(op.clone(mapper));
1764  }
1765 
1766  // Go through all ops and remap the operands.
1767  auto remapOperands = [&](Operation *op) {
1768  for (auto &operand : op->getOpOperands())
1769  if (Value mappedOp = mapper.lookupOrNull(operand.get()))
1770  operand.set(mappedOp);
1771  for (auto &succOp : op->getBlockOperands())
1772  if (Block *mappedOp = mapper.lookupOrNull(succOp.get()))
1773  succOp.set(mappedOp);
1774  };
1775  for (auto &block : body)
1776  block.walk(remapOperands);
1777 
1778  // We have created the SelectionOp/LoopOp and "moved" all blocks belonging to
1779  // the selection/loop construct into its region. Next we need to fix the
1780  // connections between this new SelectionOp/LoopOp with existing blocks.
1781 
1782  // All existing incoming branches should go to the merge block, where the
1783  // SelectionOp/LoopOp resides right now.
1784  headerBlock->replaceAllUsesWith(mergeBlock);
1785 
1786  LLVM_DEBUG({
1787  logger.startLine() << "[cf] after cloning and fixing references:\n";
1788  headerBlock->getParentOp()->print(logger.getOStream());
1789  logger.startLine() << "\n";
1790  });
1791 
1792  if (isLoop) {
1793  if (!mergeBlock->args_empty()) {
1794  return mergeBlock->getParentOp()->emitError(
1795  "OpPhi in loop merge block unsupported");
1796  }
1797 
1798  // The loop header block may have block arguments. Since now we place the
1799  // loop op inside the old merge block, we need to make sure the old merge
1800  // block has the same block argument list.
1801  for (BlockArgument blockArg : headerBlock->getArguments())
1802  mergeBlock->addArgument(blockArg.getType(), blockArg.getLoc());
1803 
1804  // If the loop header block has block arguments, make sure the spv.Branch op
1805  // matches.
1806  SmallVector<Value, 4> blockArgs;
1807  if (!headerBlock->args_empty())
1808  blockArgs = {mergeBlock->args_begin(), mergeBlock->args_end()};
1809 
1810  // The loop entry block should have a unconditional branch jumping to the
1811  // loop header block.
1812  builder.setInsertionPointToEnd(&body.front());
1813  builder.create<spirv::BranchOp>(location, mapper.lookupOrNull(headerBlock),
1814  ArrayRef<Value>(blockArgs));
1815  }
1816 
1817  // All the blocks cloned into the SelectionOp/LoopOp's region can now be
1818  // cleaned up.
1819  LLVM_DEBUG(logger.startLine() << "[cf] cleaning up blocks after clone\n");
1820  // First we need to drop all operands' references inside all blocks. This is
1821  // needed because we can have blocks referencing SSA values from one another.
1822  for (auto *block : constructBlocks)
1823  block->dropAllReferences();
1824 
1825  // Check that whether some op in the to-be-erased blocks still has uses. Those
1826  // uses come from blocks that won't be sinked into the SelectionOp/LoopOp's
1827  // region. We cannot handle such cases given that once a value is sinked into
1828  // the SelectionOp/LoopOp's region, there is no escape for it:
1829  // SelectionOp/LooOp does not support yield values right now.
1830  for (auto *block : constructBlocks) {
1831  for (Operation &op : *block)
1832  if (!op.use_empty())
1833  return op.emitOpError(
1834  "failed control flow structurization: it has uses outside of the "
1835  "enclosing selection/loop construct");
1836  }
1837 
1838  // Then erase all old blocks.
1839  for (auto *block : constructBlocks) {
1840  // We've cloned all blocks belonging to this construct into the structured
1841  // control flow op's region. Among these blocks, some may compose another
1842  // selection/loop. If so, they will be recorded within blockMergeInfo.
1843  // We need to update the pointers there to the newly remapped ones so we can
1844  // continue structurizing them later.
1845  // TODO: The asserts in the following assumes input SPIR-V blob forms
1846  // correctly nested selection/loop constructs. We should relax this and
1847  // support error cases better.
1848  auto it = blockMergeInfo.find(block);
1849  if (it != blockMergeInfo.end()) {
1850  // Use the original location for nested selection/loop ops.
1851  Location loc = it->second.loc;
1852 
1853  Block *newHeader = mapper.lookupOrNull(block);
1854  if (!newHeader)
1855  return emitError(loc, "failed control flow structurization: nested "
1856  "loop header block should be remapped!");
1857 
1858  Block *newContinue = it->second.continueBlock;
1859  if (newContinue) {
1860  newContinue = mapper.lookupOrNull(newContinue);
1861  if (!newContinue)
1862  return emitError(loc, "failed control flow structurization: nested "
1863  "loop continue block should be remapped!");
1864  }
1865 
1866  Block *newMerge = it->second.mergeBlock;
1867  if (Block *mappedTo = mapper.lookupOrNull(newMerge))
1868  newMerge = mappedTo;
1869 
1870  // The iterator should be erased before adding a new entry into
1871  // blockMergeInfo to avoid iterator invalidation.
1872  blockMergeInfo.erase(it);
1873  blockMergeInfo.try_emplace(newHeader, loc, it->second.control, newMerge,
1874  newContinue);
1875  }
1876 
1877  // The structured selection/loop's entry block does not have arguments.
1878  // If the function's header block is also part of the structured control
1879  // flow, we cannot just simply erase it because it may contain arguments
1880  // matching the function signature and used by the cloned blocks.
1881  if (isFnEntryBlock(block)) {
1882  LLVM_DEBUG(logger.startLine() << "[cf] changing entry block " << block
1883  << " to only contain a spv.Branch op\n");
1884  // Still keep the function entry block for the potential block arguments,
1885  // but replace all ops inside with a branch to the merge block.
1886  block->clear();
1887  builder.setInsertionPointToEnd(block);
1888  builder.create<spirv::BranchOp>(location, mergeBlock);
1889  } else {
1890  LLVM_DEBUG(logger.startLine() << "[cf] erasing block " << block << "\n");
1891  block->erase();
1892  }
1893  }
1894 
1895  LLVM_DEBUG(logger.startLine()
1896  << "[cf] after structurizing construct with header block "
1897  << headerBlock << ":\n"
1898  << *op << "\n");
1899 
1900  return success();
1901 }
1902 
1903 LogicalResult spirv::Deserializer::wireUpBlockArgument() {
1904  LLVM_DEBUG({
1905  logger.startLine()
1906  << "//----- [phi] start wiring up block arguments -----//\n";
1907  logger.indent();
1908  });
1909 
1910  OpBuilder::InsertionGuard guard(opBuilder);
1911 
1912  for (const auto &info : blockPhiInfo) {
1913  Block *block = info.first.first;
1914  Block *target = info.first.second;
1915  const BlockPhiInfo &phiInfo = info.second;
1916  LLVM_DEBUG({
1917  logger.startLine() << "[phi] block " << block << "\n";
1918  logger.startLine() << "[phi] before creating block argument:\n";
1919  block->getParentOp()->print(logger.getOStream());
1920  logger.startLine() << "\n";
1921  });
1922 
1923  // Set insertion point to before this block's terminator early because we
1924  // may materialize ops via getValue() call.
1925  auto *op = block->getTerminator();
1926  opBuilder.setInsertionPoint(op);
1927 
1928  SmallVector<Value, 4> blockArgs;
1929  blockArgs.reserve(phiInfo.size());
1930  for (uint32_t valueId : phiInfo) {
1931  if (Value value = getValue(valueId)) {
1932  blockArgs.push_back(value);
1933  LLVM_DEBUG(logger.startLine() << "[phi] block argument " << value
1934  << " id = " << valueId << "\n");
1935  } else {
1936  return emitError(unknownLoc, "OpPhi references undefined value!");
1937  }
1938  }
1939 
1940  if (auto branchOp = dyn_cast<spirv::BranchOp>(op)) {
1941  // Replace the previous branch op with a new one with block arguments.
1942  opBuilder.create<spirv::BranchOp>(branchOp.getLoc(), branchOp.getTarget(),
1943  blockArgs);
1944  branchOp.erase();
1945  } else if (auto branchCondOp = dyn_cast<spirv::BranchConditionalOp>(op)) {
1946  assert((branchCondOp.getTrueBlock() == target ||
1947  branchCondOp.getFalseBlock() == target) &&
1948  "expected target to be either the true or false target");
1949  if (target == branchCondOp.trueTarget())
1950  opBuilder.create<spirv::BranchConditionalOp>(
1951  branchCondOp.getLoc(), branchCondOp.condition(), blockArgs,
1952  branchCondOp.getFalseBlockArguments(),
1953  branchCondOp.branch_weightsAttr(), branchCondOp.trueTarget(),
1954  branchCondOp.falseTarget());
1955  else
1956  opBuilder.create<spirv::BranchConditionalOp>(
1957  branchCondOp.getLoc(), branchCondOp.condition(),
1958  branchCondOp.getTrueBlockArguments(), blockArgs,
1959  branchCondOp.branch_weightsAttr(), branchCondOp.getTrueBlock(),
1960  branchCondOp.getFalseBlock());
1961 
1962  branchCondOp.erase();
1963  } else {
1964  return emitError(unknownLoc, "unimplemented terminator for Phi creation");
1965  }
1966 
1967  LLVM_DEBUG({
1968  logger.startLine() << "[phi] after creating block argument:\n";
1969  block->getParentOp()->print(logger.getOStream());
1970  logger.startLine() << "\n";
1971  });
1972  }
1973  blockPhiInfo.clear();
1974 
1975  LLVM_DEBUG({
1976  logger.unindent();
1977  logger.startLine()
1978  << "//--- [phi] completed wiring up block arguments ---//\n";
1979  });
1980  return success();
1981 }
1982 
1983 LogicalResult spirv::Deserializer::structurizeControlFlow() {
1984  LLVM_DEBUG({
1985  logger.startLine()
1986  << "//----- [cf] start structurizing control flow -----//\n";
1987  logger.indent();
1988  });
1989 
1990  while (!blockMergeInfo.empty()) {
1991  Block *headerBlock = blockMergeInfo.begin()->first;
1992  BlockMergeInfo mergeInfo = blockMergeInfo.begin()->second;
1993 
1994  LLVM_DEBUG({
1995  logger.startLine() << "[cf] header block " << headerBlock << ":\n";
1996  headerBlock->print(logger.getOStream());
1997  logger.startLine() << "\n";
1998  });
1999 
2000  auto *mergeBlock = mergeInfo.mergeBlock;
2001  assert(mergeBlock && "merge block cannot be nullptr");
2002  if (!mergeBlock->args_empty())
2003  return emitError(unknownLoc, "OpPhi in loop merge block unimplemented");
2004  LLVM_DEBUG({
2005  logger.startLine() << "[cf] merge block " << mergeBlock << ":\n";
2006  mergeBlock->print(logger.getOStream());
2007  logger.startLine() << "\n";
2008  });
2009 
2010  auto *continueBlock = mergeInfo.continueBlock;
2011  LLVM_DEBUG(if (continueBlock) {
2012  logger.startLine() << "[cf] continue block " << continueBlock << ":\n";
2013  continueBlock->print(logger.getOStream());
2014  logger.startLine() << "\n";
2015  });
2016  // Erase this case before calling into structurizer, who will update
2017  // blockMergeInfo.
2018  blockMergeInfo.erase(blockMergeInfo.begin());
2019  ControlFlowStructurizer structurizer(mergeInfo.loc, mergeInfo.control,
2020  blockMergeInfo, headerBlock,
2021  mergeBlock, continueBlock
2022 #ifndef NDEBUG
2023  ,
2024  logger
2025 #endif
2026  );
2027  if (failed(structurizer.structurize()))
2028  return failure();
2029  }
2030 
2031  LLVM_DEBUG({
2032  logger.unindent();
2033  logger.startLine()
2034  << "//--- [cf] completed structurizing control flow ---//\n";
2035  });
2036  return success();
2037 }
2038 
2039 //===----------------------------------------------------------------------===//
2040 // Debug
2041 //===----------------------------------------------------------------------===//
2042 
2043 Location spirv::Deserializer::createFileLineColLoc(OpBuilder opBuilder) {
2044  if (!debugLine)
2045  return unknownLoc;
2046 
2047  auto fileName = debugInfoMap.lookup(debugLine->fileID).str();
2048  if (fileName.empty())
2049  fileName = "<unknown>";
2050  return FileLineColLoc::get(opBuilder.getStringAttr(fileName), debugLine->line,
2051  debugLine->column);
2052 }
2053 
2055 spirv::Deserializer::processDebugLine(ArrayRef<uint32_t> operands) {
2056  // According to SPIR-V spec:
2057  // "This location information applies to the instructions physically
2058  // following this instruction, up to the first occurrence of any of the
2059  // following: the next end of block, the next OpLine instruction, or the next
2060  // OpNoLine instruction."
2061  if (operands.size() != 3)
2062  return emitError(unknownLoc, "OpLine must have 3 operands");
2063  debugLine = DebugLine{operands[0], operands[1], operands[2]};
2064  return success();
2065 }
2066 
2067 void spirv::Deserializer::clearDebugLine() { debugLine = llvm::None; }
2068 
2070 spirv::Deserializer::processDebugString(ArrayRef<uint32_t> operands) {
2071  if (operands.size() < 2)
2072  return emitError(unknownLoc, "OpString needs at least 2 operands");
2073 
2074  if (!debugInfoMap.lookup(operands[0]).empty())
2075  return emitError(unknownLoc,
2076  "duplicate debug string found for result <id> ")
2077  << operands[0];
2078 
2079  unsigned wordIndex = 1;
2080  StringRef debugString = decodeStringLiteral(operands, wordIndex);
2081  if (wordIndex != operands.size())
2082  return emitError(unknownLoc,
2083  "unexpected trailing words in OpString instruction");
2084 
2085  debugInfoMap[operands[0]] = debugString;
2086  return success();
2087 }
TODO: Remove this file when SCCP and integer range analysis have been ported to the new framework...
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, ArrayRef< NamedAttribute > attributes, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Definition: Operation.cpp:27
static bool isFnEntryBlock(Block *block)
Returns true if the given block is a function entry block.
This class contains a list of basic blocks and a link to the parent operation it is attached to...
Definition: Region.h:26
iterator begin()
Definition: Block.h:134
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition: CallGraph.h:221
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
Operation & back()
Definition: Block.h:143
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
BlockListType & getBlocks()
Definition: Region.h:45
Block represents an ordered list of Operations.
Definition: Block.h:29
A symbol reference with a reference path containing a single element.
LogicalResult deserialize()
Deserializes the remembered SPIR-V binary module.
constexpr unsigned kHeaderWordCount
SPIR-V binary header word count.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
FloatType getF16Type()
Definition: Builders.cpp:38
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
BlockListType::iterator iterator
Definition: Region.h:52
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:87
NoneType getNoneType()
Definition: Builders.cpp:75
FloatType getF32Type()
Definition: Builders.cpp:40
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value...
Definition: LogicalResult.h:68
static RuntimeArrayType get(Type elementType)
Definition: SPIRVTypes.cpp:448
Operation & front()
Definition: Block.h:144
static constexpr const bool value
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:424
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
void map(Block *from, Block *to)
Inserts a new mapping for 'from' to 'to'.
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Construct a literal StructType with at least one member.
Definition: SPIRVTypes.cpp:952
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:307
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:148
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
void erase()
Unlink this Block from its parent region and delete it.
Definition: Block.cpp:54
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:380
static SampledImageType get(Type imageType)
Definition: SPIRVTypes.cpp:723
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
static CooperativeMatrixNVType get(Type elementType, Scope scope, unsigned rows, unsigned columns)
Definition: SPIRVTypes.cpp:222
static ArrayType get(Type elementType, unsigned elementCount)
Definition: SPIRVTypes.cpp:48
A struct for containing OpLine instruction information.
Definition: Deserializer.h:52
U dyn_cast() const
Definition: Types.h:256
Operation * clone(BlockAndValueMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Definition: Operation.cpp:564
Block * lookupOrNull(Block *from) const
Lookup a mapped value within the map.
UnitAttr getUnitAttr()
Definition: Builders.cpp:85
Attributes are known-constant values of operations.
Definition: Attributes.h:24
Block & back()
Definition: Region.h:64
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:331
IntegerType getI1Type()
Definition: Builders.cpp:50
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Definition: Block.cpp:35
This represents an operation in an abstracted form, suitable for use with the builder APIs...
MutableArrayRef< BlockOperand > getBlockOperands()
Definition: Operation.h:493
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
Definition: SPIRVTypes.cpp:975
This class represents an argument of a Block.
Definition: Value.h:300
void print(raw_ostream &os, const OpPrintingFlags &flags=llvm::None)
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:391
bool use_empty()
Returns true if this operation has no uses.
Definition: Operation.h:627
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
constexpr uint32_t kMagicNumber
SPIR-V magic number.
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:230
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:286
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:402
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
void push_back(Operation *op)
Definition: Block.h:140
OwningOpRef< spirv::ModuleOp > collect()
Collects the final SPIR-V ModuleOp.
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
Definition: SPIRVTypes.h:163
FloatType getF64Type()
Definition: Builders.cpp:42
static VectorType vectorType(CodeGen &codegen, Type etp)
Constructs vector type.
StringRef decodeStringLiteral(ArrayRef< uint32_t > words, unsigned &wordIndex)
Decodes a string literal in words starting at wordIndex.
static std::string debugString(T &&op)
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
Deserializer(ArrayRef< uint32_t > binary, MLIRContext *context)
Creates a deserializer for the given SPIR-V binary module.
static VerCapExtAttr get(Version version, ArrayRef< Capability > capabilities, ArrayRef< Extension > extensions, MLIRContext *context)
Gets a VerCapExtAttr instance.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:374
#define MIN_VERSION_CASE(v)
A struct for containing a header block's merge and continue targets.
Definition: Deserializer.h:37
static MatrixType get(Type columnType, uint32_t columnCount)
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
Definition: SPIRVTypes.cpp:965
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=llvm::None, ArrayRef< Location > locs=llvm::None)
Add new block with 'argTypes' arguments and set the insertion point to the end of it...
Definition: Builders.cpp:353
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers...
Definition: Operation.cpp:518
bool isa() const
Definition: Types.h:246
void print(raw_ostream &os)
A struct that collects the info needed to materialize/emit a SpecConstantOperation op...
Definition: Deserializer.h:100
This class helps build Operations.
Definition: Builders.h:184
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:201
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:484
U cast() const
Definition: Types.h:262
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Definition: Block.cpp:289