MLIR  16.0.0git
Deserializer.cpp
Go to the documentation of this file.
1 //===- Deserializer.cpp - MLIR SPIR-V Deserializer ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the SPIR-V binary to MLIR SPIR-V module deserializer.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Deserializer.h"
14 
20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/Location.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/Sequence.h"
26 #include "llvm/ADT/SmallVector.h"
27 #include "llvm/ADT/StringExtras.h"
28 #include "llvm/ADT/bit.h"
29 #include "llvm/Support/Debug.h"
30 #include "llvm/Support/SaveAndRestore.h"
31 #include "llvm/Support/raw_ostream.h"
32 
33 using namespace mlir;
34 
35 #define DEBUG_TYPE "spirv-deserialization"
36 
37 //===----------------------------------------------------------------------===//
38 // Utility Functions
39 //===----------------------------------------------------------------------===//
40 
41 /// Returns true if the given `block` is a function entry block.
42 static inline bool isFnEntryBlock(Block *block) {
43  return block->isEntryBlock() &&
44  isa_and_nonnull<spirv::FuncOp>(block->getParentOp());
45 }
46 
47 //===----------------------------------------------------------------------===//
48 // Deserializer Method Definitions
49 //===----------------------------------------------------------------------===//
50 
52  MLIRContext *context)
53  : binary(binary), context(context), unknownLoc(UnknownLoc::get(context)),
54  module(createModuleOp()), opBuilder(module->getRegion())
55 #ifndef NDEBUG
56  ,
57  logger(llvm::dbgs())
58 #endif
59 {
60 }
61 
63  LLVM_DEBUG({
64  logger.resetIndent();
65  logger.startLine()
66  << "//+++---------- start deserialization ----------+++//\n";
67  });
68 
69  if (failed(processHeader()))
70  return failure();
71 
72  spirv::Opcode opcode = spirv::Opcode::OpNop;
73  ArrayRef<uint32_t> operands;
74  auto binarySize = binary.size();
75  while (curOffset < binarySize) {
76  // Slice the next instruction out and populate `opcode` and `operands`.
77  // Internally this also updates `curOffset`.
78  if (failed(sliceInstruction(opcode, operands)))
79  return failure();
80 
81  if (failed(processInstruction(opcode, operands)))
82  return failure();
83  }
84 
85  assert(curOffset == binarySize &&
86  "deserializer should never index beyond the binary end");
87 
88  for (auto &deferred : deferredInstructions) {
89  if (failed(processInstruction(deferred.first, deferred.second, false))) {
90  return failure();
91  }
92  }
93 
94  attachVCETriple();
95 
96  LLVM_DEBUG(logger.startLine()
97  << "//+++-------- completed deserialization --------+++//\n");
98  return success();
99 }
100 
102  return std::move(module);
103 }
104 
105 //===----------------------------------------------------------------------===//
106 // Module structure
107 //===----------------------------------------------------------------------===//
108 
109 OwningOpRef<spirv::ModuleOp> spirv::Deserializer::createModuleOp() {
110  OpBuilder builder(context);
111  OperationState state(unknownLoc, spirv::ModuleOp::getOperationName());
112  spirv::ModuleOp::build(builder, state);
113  return cast<spirv::ModuleOp>(Operation::create(state));
114 }
115 
116 LogicalResult spirv::Deserializer::processHeader() {
117  if (binary.size() < spirv::kHeaderWordCount)
118  return emitError(unknownLoc,
119  "SPIR-V binary module must have a 5-word header");
120 
121  if (binary[0] != spirv::kMagicNumber)
122  return emitError(unknownLoc, "incorrect magic number");
123 
124  // Version number bytes: 0 | major number | minor number | 0
125  uint32_t majorVersion = (binary[1] << 8) >> 24;
126  uint32_t minorVersion = (binary[1] << 16) >> 24;
127  if (majorVersion == 1) {
128  switch (minorVersion) {
129 #define MIN_VERSION_CASE(v) \
130  case v: \
131  version = spirv::Version::V_1_##v; \
132  break
133 
134  MIN_VERSION_CASE(0);
135  MIN_VERSION_CASE(1);
136  MIN_VERSION_CASE(2);
137  MIN_VERSION_CASE(3);
138  MIN_VERSION_CASE(4);
139  MIN_VERSION_CASE(5);
140 #undef MIN_VERSION_CASE
141  default:
142  return emitError(unknownLoc, "unsupported SPIR-V minor version: ")
143  << minorVersion;
144  }
145  } else {
146  return emitError(unknownLoc, "unsupported SPIR-V major version: ")
147  << majorVersion;
148  }
149 
150  // TODO: generator number, bound, schema
151  curOffset = spirv::kHeaderWordCount;
152  return success();
153 }
154 
156 spirv::Deserializer::processCapability(ArrayRef<uint32_t> operands) {
157  if (operands.size() != 1)
158  return emitError(unknownLoc, "OpMemoryModel must have one parameter");
159 
160  auto cap = spirv::symbolizeCapability(operands[0]);
161  if (!cap)
162  return emitError(unknownLoc, "unknown capability: ") << operands[0];
163 
164  capabilities.insert(*cap);
165  return success();
166 }
167 
168 LogicalResult spirv::Deserializer::processExtension(ArrayRef<uint32_t> words) {
169  if (words.empty()) {
170  return emitError(
171  unknownLoc,
172  "OpExtension must have a literal string for the extension name");
173  }
174 
175  unsigned wordIndex = 0;
176  StringRef extName = decodeStringLiteral(words, wordIndex);
177  if (wordIndex != words.size())
178  return emitError(unknownLoc,
179  "unexpected trailing words in OpExtension instruction");
180  auto ext = spirv::symbolizeExtension(extName);
181  if (!ext)
182  return emitError(unknownLoc, "unknown extension: ") << extName;
183 
184  extensions.insert(*ext);
185  return success();
186 }
187 
189 spirv::Deserializer::processExtInstImport(ArrayRef<uint32_t> words) {
190  if (words.size() < 2) {
191  return emitError(unknownLoc,
192  "OpExtInstImport must have a result <id> and a literal "
193  "string for the extended instruction set name");
194  }
195 
196  unsigned wordIndex = 1;
197  extendedInstSets[words[0]] = decodeStringLiteral(words, wordIndex);
198  if (wordIndex != words.size()) {
199  return emitError(unknownLoc,
200  "unexpected trailing words in OpExtInstImport");
201  }
202  return success();
203 }
204 
205 void spirv::Deserializer::attachVCETriple() {
206  (*module)->setAttr(
207  spirv::ModuleOp::getVCETripleAttrName(),
208  spirv::VerCapExtAttr::get(version, capabilities.getArrayRef(),
209  extensions.getArrayRef(), context));
210 }
211 
213 spirv::Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
214  if (operands.size() != 2)
215  return emitError(unknownLoc, "OpMemoryModel must have two operands");
216 
217  (*module)->setAttr(
218  "addressing_model",
219  opBuilder.getAttr<spirv::AddressingModelAttr>(
220  static_cast<spirv::AddressingModel>(operands.front())));
221  (*module)->setAttr("memory_model",
222  opBuilder.getAttr<spirv::MemoryModelAttr>(
223  static_cast<spirv::MemoryModel>(operands.back())));
224 
225  return success();
226 }
227 
228 LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
229  // TODO: This function should also be auto-generated. For now, since only a
230  // few decorations are processed/handled in a meaningful manner, going with a
231  // manual implementation.
232  if (words.size() < 2) {
233  return emitError(
234  unknownLoc, "OpDecorate must have at least result <id> and Decoration");
235  }
236  auto decorationName =
237  stringifyDecoration(static_cast<spirv::Decoration>(words[1]));
238  if (decorationName.empty()) {
239  return emitError(unknownLoc, "invalid Decoration code : ") << words[1];
240  }
241  auto attrName = llvm::convertToSnakeFromCamelCase(decorationName);
242  auto symbol = opBuilder.getStringAttr(attrName);
243  switch (static_cast<spirv::Decoration>(words[1])) {
244  case spirv::Decoration::DescriptorSet:
245  case spirv::Decoration::Binding:
246  if (words.size() != 3) {
247  return emitError(unknownLoc, "OpDecorate with ")
248  << decorationName << " needs a single integer literal";
249  }
250  decorations[words[0]].set(
251  symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));
252  break;
253  case spirv::Decoration::BuiltIn:
254  if (words.size() != 3) {
255  return emitError(unknownLoc, "OpDecorate with ")
256  << decorationName << " needs a single integer literal";
257  }
258  decorations[words[0]].set(
259  symbol, opBuilder.getStringAttr(
260  stringifyBuiltIn(static_cast<spirv::BuiltIn>(words[2]))));
261  break;
262  case spirv::Decoration::ArrayStride:
263  if (words.size() != 3) {
264  return emitError(unknownLoc, "OpDecorate with ")
265  << decorationName << " needs a single integer literal";
266  }
267  typeDecorations[words[0]] = words[2];
268  break;
269  case spirv::Decoration::Aliased:
270  case spirv::Decoration::Block:
271  case spirv::Decoration::BufferBlock:
272  case spirv::Decoration::Flat:
273  case spirv::Decoration::NonReadable:
274  case spirv::Decoration::NonWritable:
275  case spirv::Decoration::NoPerspective:
276  case spirv::Decoration::Restrict:
277  case spirv::Decoration::RelaxedPrecision:
278  if (words.size() != 2) {
279  return emitError(unknownLoc, "OpDecoration with ")
280  << decorationName << "needs a single target <id>";
281  }
282  // Block decoration does not affect spirv.struct type, but is still stored
283  // for verification.
284  // TODO: Update StructType to contain this information since
285  // it is needed for many validation rules.
286  decorations[words[0]].set(symbol, opBuilder.getUnitAttr());
287  break;
288  case spirv::Decoration::Location:
289  case spirv::Decoration::SpecId:
290  if (words.size() != 3) {
291  return emitError(unknownLoc, "OpDecoration with ")
292  << decorationName << "needs a single integer literal";
293  }
294  decorations[words[0]].set(
295  symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));
296  break;
297  default:
298  return emitError(unknownLoc, "unhandled Decoration : '") << decorationName;
299  }
300  return success();
301 }
302 
304 spirv::Deserializer::processMemberDecoration(ArrayRef<uint32_t> words) {
305  // The binary layout of OpMemberDecorate is different comparing to OpDecorate
306  if (words.size() < 3) {
307  return emitError(unknownLoc,
308  "OpMemberDecorate must have at least 3 operands");
309  }
310 
311  auto decoration = static_cast<spirv::Decoration>(words[2]);
312  if (decoration == spirv::Decoration::Offset && words.size() != 4) {
313  return emitError(unknownLoc,
314  " missing offset specification in OpMemberDecorate with "
315  "Offset decoration");
316  }
317  ArrayRef<uint32_t> decorationOperands;
318  if (words.size() > 3) {
319  decorationOperands = words.slice(3);
320  }
321  memberDecorationMap[words[0]][words[1]][decoration] = decorationOperands;
322  return success();
323 }
324 
325 LogicalResult spirv::Deserializer::processMemberName(ArrayRef<uint32_t> words) {
326  if (words.size() < 3) {
327  return emitError(unknownLoc, "OpMemberName must have at least 3 operands");
328  }
329  unsigned wordIndex = 2;
330  auto name = decodeStringLiteral(words, wordIndex);
331  if (wordIndex != words.size()) {
332  return emitError(unknownLoc,
333  "unexpected trailing words in OpMemberName instruction");
334  }
335  memberNameMap[words[0]][words[1]] = name;
336  return success();
337 }
338 
340 spirv::Deserializer::processFunction(ArrayRef<uint32_t> operands) {
341  if (curFunction) {
342  return emitError(unknownLoc, "found function inside function");
343  }
344 
345  // Get the result type
346  if (operands.size() != 4) {
347  return emitError(unknownLoc, "OpFunction must have 4 parameters");
348  }
349  Type resultType = getType(operands[0]);
350  if (!resultType) {
351  return emitError(unknownLoc, "undefined result type from <id> ")
352  << operands[0];
353  }
354 
355  uint32_t fnID = operands[1];
356  if (funcMap.count(fnID)) {
357  return emitError(unknownLoc, "duplicate function definition/declaration");
358  }
359 
360  auto fnControl = spirv::symbolizeFunctionControl(operands[2]);
361  if (!fnControl) {
362  return emitError(unknownLoc, "unknown Function Control: ") << operands[2];
363  }
364 
365  Type fnType = getType(operands[3]);
366  if (!fnType || !fnType.isa<FunctionType>()) {
367  return emitError(unknownLoc, "unknown function type from <id> ")
368  << operands[3];
369  }
370  auto functionType = fnType.cast<FunctionType>();
371 
372  if ((isVoidType(resultType) && functionType.getNumResults() != 0) ||
373  (functionType.getNumResults() == 1 &&
374  functionType.getResult(0) != resultType)) {
375  return emitError(unknownLoc, "mismatch in function type ")
376  << functionType << " and return type " << resultType << " specified";
377  }
378 
379  std::string fnName = getFunctionSymbol(fnID);
380  auto funcOp = opBuilder.create<spirv::FuncOp>(
381  unknownLoc, fnName, functionType, fnControl.value());
382  curFunction = funcMap[fnID] = funcOp;
383  auto *entryBlock = funcOp.addEntryBlock();
384  LLVM_DEBUG({
385  logger.startLine()
386  << "//===-------------------------------------------===//\n";
387  logger.startLine() << "[fn] name: " << fnName << "\n";
388  logger.startLine() << "[fn] type: " << fnType << "\n";
389  logger.startLine() << "[fn] ID: " << fnID << "\n";
390  logger.startLine() << "[fn] entry block: " << entryBlock << "\n";
391  logger.indent();
392  });
393 
394  // Parse the op argument instructions
395  if (functionType.getNumInputs()) {
396  for (size_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
397  auto argType = functionType.getInput(i);
398  spirv::Opcode opcode = spirv::Opcode::OpNop;
399  ArrayRef<uint32_t> operands;
400  if (failed(sliceInstruction(opcode, operands,
401  spirv::Opcode::OpFunctionParameter))) {
402  return failure();
403  }
404  if (opcode != spirv::Opcode::OpFunctionParameter) {
405  return emitError(
406  unknownLoc,
407  "missing OpFunctionParameter instruction for argument ")
408  << i;
409  }
410  if (operands.size() != 2) {
411  return emitError(
412  unknownLoc,
413  "expected result type and result <id> for OpFunctionParameter");
414  }
415  auto argDefinedType = getType(operands[0]);
416  if (!argDefinedType || argDefinedType != argType) {
417  return emitError(unknownLoc,
418  "mismatch in argument type between function type "
419  "definition ")
420  << functionType << " and argument type definition "
421  << argDefinedType << " at argument " << i;
422  }
423  if (getValue(operands[1])) {
424  return emitError(unknownLoc, "duplicate definition of result <id> ")
425  << operands[1];
426  }
427  auto argValue = funcOp.getArgument(i);
428  valueMap[operands[1]] = argValue;
429  }
430  }
431 
432  // RAII guard to reset the insertion point to the module's region after
433  // deserializing the body of this function.
434  OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
435 
436  spirv::Opcode opcode = spirv::Opcode::OpNop;
437  ArrayRef<uint32_t> instOperands;
438 
439  // Special handling for the entry block. We need to make sure it starts with
440  // an OpLabel instruction. The entry block takes the same parameters as the
441  // function. All other blocks do not take any parameter. We have already
442  // created the entry block, here we need to register it to the correct label
443  // <id>.
444  if (failed(sliceInstruction(opcode, instOperands,
445  spirv::Opcode::OpFunctionEnd))) {
446  return failure();
447  }
448  if (opcode == spirv::Opcode::OpFunctionEnd) {
449  return processFunctionEnd(instOperands);
450  }
451  if (opcode != spirv::Opcode::OpLabel) {
452  return emitError(unknownLoc, "a basic block must start with OpLabel");
453  }
454  if (instOperands.size() != 1) {
455  return emitError(unknownLoc, "OpLabel should only have result <id>");
456  }
457  blockMap[instOperands[0]] = entryBlock;
458  if (failed(processLabel(instOperands))) {
459  return failure();
460  }
461 
462  // Then process all the other instructions in the function until we hit
463  // OpFunctionEnd.
464  while (succeeded(sliceInstruction(opcode, instOperands,
465  spirv::Opcode::OpFunctionEnd)) &&
466  opcode != spirv::Opcode::OpFunctionEnd) {
467  if (failed(processInstruction(opcode, instOperands))) {
468  return failure();
469  }
470  }
471  if (opcode != spirv::Opcode::OpFunctionEnd) {
472  return failure();
473  }
474 
475  return processFunctionEnd(instOperands);
476 }
477 
479 spirv::Deserializer::processFunctionEnd(ArrayRef<uint32_t> operands) {
480  // Process OpFunctionEnd.
481  if (!operands.empty()) {
482  return emitError(unknownLoc, "unexpected operands for OpFunctionEnd");
483  }
484 
485  // Wire up block arguments from OpPhi instructions.
486  // Put all structured control flow in spirv.mlir.selection/spirv.mlir.loop
487  // ops.
488  if (failed(wireUpBlockArgument()) || failed(structurizeControlFlow())) {
489  return failure();
490  }
491 
492  curBlock = nullptr;
493  curFunction = std::nullopt;
494 
495  LLVM_DEBUG({
496  logger.unindent();
497  logger.startLine()
498  << "//===-------------------------------------------===//\n";
499  });
500  return success();
501 }
502 
504 spirv::Deserializer::getConstant(uint32_t id) {
505  auto constIt = constantMap.find(id);
506  if (constIt == constantMap.end())
507  return std::nullopt;
508  return constIt->getSecond();
509 }
510 
512 spirv::Deserializer::getSpecConstantOperation(uint32_t id) {
513  auto constIt = specConstOperationMap.find(id);
514  if (constIt == specConstOperationMap.end())
515  return std::nullopt;
516  return constIt->getSecond();
517 }
518 
519 std::string spirv::Deserializer::getFunctionSymbol(uint32_t id) {
520  auto funcName = nameMap.lookup(id).str();
521  if (funcName.empty()) {
522  funcName = "spirv_fn_" + std::to_string(id);
523  }
524  return funcName;
525 }
526 
527 std::string spirv::Deserializer::getSpecConstantSymbol(uint32_t id) {
528  auto constName = nameMap.lookup(id).str();
529  if (constName.empty()) {
530  constName = "spirv_spec_const_" + std::to_string(id);
531  }
532  return constName;
533 }
534 
535 spirv::SpecConstantOp
536 spirv::Deserializer::createSpecConstant(Location loc, uint32_t resultID,
537  Attribute defaultValue) {
538  auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
539  auto op = opBuilder.create<spirv::SpecConstantOp>(unknownLoc, symName,
540  defaultValue);
541  if (decorations.count(resultID)) {
542  for (auto attr : decorations[resultID].getAttrs())
543  op->setAttr(attr.getName(), attr.getValue());
544  }
545  specConstMap[resultID] = op;
546  return op;
547 }
548 
550 spirv::Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
551  unsigned wordIndex = 0;
552  if (operands.size() < 3) {
553  return emitError(
554  unknownLoc,
555  "OpVariable needs at least 3 operands, type, <id> and storage class");
556  }
557 
558  // Result Type.
559  auto type = getType(operands[wordIndex]);
560  if (!type) {
561  return emitError(unknownLoc, "unknown result type <id> : ")
562  << operands[wordIndex];
563  }
564  auto ptrType = type.dyn_cast<spirv::PointerType>();
565  if (!ptrType) {
566  return emitError(unknownLoc,
567  "expected a result type <id> to be a spirv.ptr, found : ")
568  << type;
569  }
570  wordIndex++;
571 
572  // Result <id>.
573  auto variableID = operands[wordIndex];
574  auto variableName = nameMap.lookup(variableID).str();
575  if (variableName.empty()) {
576  variableName = "spirv_var_" + std::to_string(variableID);
577  }
578  wordIndex++;
579 
580  // Storage class.
581  auto storageClass = static_cast<spirv::StorageClass>(operands[wordIndex]);
582  if (ptrType.getStorageClass() != storageClass) {
583  return emitError(unknownLoc, "mismatch in storage class of pointer type ")
584  << type << " and that specified in OpVariable instruction : "
585  << stringifyStorageClass(storageClass);
586  }
587  wordIndex++;
588 
589  // Initializer.
590  FlatSymbolRefAttr initializer = nullptr;
591  if (wordIndex < operands.size()) {
592  auto initializerOp = getGlobalVariable(operands[wordIndex]);
593  if (!initializerOp) {
594  return emitError(unknownLoc, "unknown <id> ")
595  << operands[wordIndex] << "used as initializer";
596  }
597  wordIndex++;
598  initializer = SymbolRefAttr::get(initializerOp.getOperation());
599  }
600  if (wordIndex != operands.size()) {
601  return emitError(unknownLoc,
602  "found more operands than expected when deserializing "
603  "OpVariable instruction, only ")
604  << wordIndex << " of " << operands.size() << " processed";
605  }
606  auto loc = createFileLineColLoc(opBuilder);
607  auto varOp = opBuilder.create<spirv::GlobalVariableOp>(
608  loc, TypeAttr::get(type), opBuilder.getStringAttr(variableName),
609  initializer);
610 
611  // Decorations.
612  if (decorations.count(variableID)) {
613  for (auto attr : decorations[variableID].getAttrs())
614  varOp->setAttr(attr.getName(), attr.getValue());
615  }
616  globalVariableMap[variableID] = varOp;
617  return success();
618 }
619 
620 IntegerAttr spirv::Deserializer::getConstantInt(uint32_t id) {
621  auto constInfo = getConstant(id);
622  if (!constInfo) {
623  return nullptr;
624  }
625  return constInfo->first.dyn_cast<IntegerAttr>();
626 }
627 
628 LogicalResult spirv::Deserializer::processName(ArrayRef<uint32_t> operands) {
629  if (operands.size() < 2) {
630  return emitError(unknownLoc, "OpName needs at least 2 operands");
631  }
632  if (!nameMap.lookup(operands[0]).empty()) {
633  return emitError(unknownLoc, "duplicate name found for result <id> ")
634  << operands[0];
635  }
636  unsigned wordIndex = 1;
637  StringRef name = decodeStringLiteral(operands, wordIndex);
638  if (wordIndex != operands.size()) {
639  return emitError(unknownLoc,
640  "unexpected trailing words in OpName instruction");
641  }
642  nameMap[operands[0]] = name;
643  return success();
644 }
645 
646 //===----------------------------------------------------------------------===//
647 // Type
648 //===----------------------------------------------------------------------===//
649 
650 LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
651  ArrayRef<uint32_t> operands) {
652  if (operands.empty()) {
653  return emitError(unknownLoc, "type instruction with opcode ")
654  << spirv::stringifyOpcode(opcode) << " needs at least one <id>";
655  }
656 
657  /// TODO: Types might be forward declared in some instructions and need to be
658  /// handled appropriately.
659  if (typeMap.count(operands[0])) {
660  return emitError(unknownLoc, "duplicate definition for result <id> ")
661  << operands[0];
662  }
663 
664  switch (opcode) {
665  case spirv::Opcode::OpTypeVoid:
666  if (operands.size() != 1)
667  return emitError(unknownLoc, "OpTypeVoid must have no parameters");
668  typeMap[operands[0]] = opBuilder.getNoneType();
669  break;
670  case spirv::Opcode::OpTypeBool:
671  if (operands.size() != 1)
672  return emitError(unknownLoc, "OpTypeBool must have no parameters");
673  typeMap[operands[0]] = opBuilder.getI1Type();
674  break;
675  case spirv::Opcode::OpTypeInt: {
676  if (operands.size() != 3)
677  return emitError(
678  unknownLoc, "OpTypeInt must have bitwidth and signedness parameters");
679 
680  // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics
681  // to preserve or validate.
682  // 0 indicates unsigned, or no signedness semantics
683  // 1 indicates signed semantics."
684  //
685  // So we cannot differentiate signless and unsigned integers; always use
686  // signless semantics for such cases.
687  auto sign = operands[2] == 1 ? IntegerType::SignednessSemantics::Signed
688  : IntegerType::SignednessSemantics::Signless;
689  typeMap[operands[0]] = IntegerType::get(context, operands[1], sign);
690  } break;
691  case spirv::Opcode::OpTypeFloat: {
692  if (operands.size() != 2)
693  return emitError(unknownLoc, "OpTypeFloat must have bitwidth parameter");
694 
695  Type floatTy;
696  switch (operands[1]) {
697  case 16:
698  floatTy = opBuilder.getF16Type();
699  break;
700  case 32:
701  floatTy = opBuilder.getF32Type();
702  break;
703  case 64:
704  floatTy = opBuilder.getF64Type();
705  break;
706  default:
707  return emitError(unknownLoc, "unsupported OpTypeFloat bitwidth: ")
708  << operands[1];
709  }
710  typeMap[operands[0]] = floatTy;
711  } break;
712  case spirv::Opcode::OpTypeVector: {
713  if (operands.size() != 3) {
714  return emitError(
715  unknownLoc,
716  "OpTypeVector must have element type and count parameters");
717  }
718  Type elementTy = getType(operands[1]);
719  if (!elementTy) {
720  return emitError(unknownLoc, "OpTypeVector references undefined <id> ")
721  << operands[1];
722  }
723  typeMap[operands[0]] = VectorType::get({operands[2]}, elementTy);
724  } break;
725  case spirv::Opcode::OpTypePointer: {
726  return processOpTypePointer(operands);
727  } break;
728  case spirv::Opcode::OpTypeArray:
729  return processArrayType(operands);
730  case spirv::Opcode::OpTypeCooperativeMatrixNV:
731  return processCooperativeMatrixType(operands);
732  case spirv::Opcode::OpTypeFunction:
733  return processFunctionType(operands);
734  case spirv::Opcode::OpTypeJointMatrixINTEL:
735  return processJointMatrixType(operands);
736  case spirv::Opcode::OpTypeImage:
737  return processImageType(operands);
738  case spirv::Opcode::OpTypeSampledImage:
739  return processSampledImageType(operands);
740  case spirv::Opcode::OpTypeRuntimeArray:
741  return processRuntimeArrayType(operands);
742  case spirv::Opcode::OpTypeStruct:
743  return processStructType(operands);
744  case spirv::Opcode::OpTypeMatrix:
745  return processMatrixType(operands);
746  default:
747  return emitError(unknownLoc, "unhandled type instruction");
748  }
749  return success();
750 }
751 
753 spirv::Deserializer::processOpTypePointer(ArrayRef<uint32_t> operands) {
754  if (operands.size() != 3)
755  return emitError(unknownLoc, "OpTypePointer must have two parameters");
756 
757  auto pointeeType = getType(operands[2]);
758  if (!pointeeType)
759  return emitError(unknownLoc, "unknown OpTypePointer pointee type <id> ")
760  << operands[2];
761 
762  uint32_t typePointerID = operands[0];
763  auto storageClass = static_cast<spirv::StorageClass>(operands[1]);
764  typeMap[typePointerID] = spirv::PointerType::get(pointeeType, storageClass);
765 
766  for (auto *deferredStructIt = std::begin(deferredStructTypesInfos);
767  deferredStructIt != std::end(deferredStructTypesInfos);) {
768  for (auto *unresolvedMemberIt =
769  std::begin(deferredStructIt->unresolvedMemberTypes);
770  unresolvedMemberIt !=
771  std::end(deferredStructIt->unresolvedMemberTypes);) {
772  if (unresolvedMemberIt->first == typePointerID) {
773  // The newly constructed pointer type can resolve one of the
774  // deferred struct type members; update the memberTypes list and
775  // clean the unresolvedMemberTypes list accordingly.
776  deferredStructIt->memberTypes[unresolvedMemberIt->second] =
777  typeMap[typePointerID];
778  unresolvedMemberIt =
779  deferredStructIt->unresolvedMemberTypes.erase(unresolvedMemberIt);
780  } else {
781  ++unresolvedMemberIt;
782  }
783  }
784 
785  if (deferredStructIt->unresolvedMemberTypes.empty()) {
786  // All deferred struct type members are now resolved, set the struct body.
787  auto structType = deferredStructIt->deferredStructType;
788 
789  assert(structType && "expected a spirv::StructType");
790  assert(structType.isIdentified() && "expected an indentified struct");
791 
792  if (failed(structType.trySetBody(
793  deferredStructIt->memberTypes, deferredStructIt->offsetInfo,
794  deferredStructIt->memberDecorationsInfo)))
795  return failure();
796 
797  deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt);
798  } else {
799  ++deferredStructIt;
800  }
801  }
802 
803  return success();
804 }
805 
807 spirv::Deserializer::processArrayType(ArrayRef<uint32_t> operands) {
808  if (operands.size() != 3) {
809  return emitError(unknownLoc,
810  "OpTypeArray must have element type and count parameters");
811  }
812 
813  Type elementTy = getType(operands[1]);
814  if (!elementTy) {
815  return emitError(unknownLoc, "OpTypeArray references undefined <id> ")
816  << operands[1];
817  }
818 
819  unsigned count = 0;
820  // TODO: The count can also come frome a specialization constant.
821  auto countInfo = getConstant(operands[2]);
822  if (!countInfo) {
823  return emitError(unknownLoc, "OpTypeArray count <id> ")
824  << operands[2] << "can only come from normal constant right now";
825  }
826 
827  if (auto intVal = countInfo->first.dyn_cast<IntegerAttr>()) {
828  count = intVal.getValue().getZExtValue();
829  } else {
830  return emitError(unknownLoc, "OpTypeArray count must come from a "
831  "scalar integer constant instruction");
832  }
833 
834  typeMap[operands[0]] = spirv::ArrayType::get(
835  elementTy, count, typeDecorations.lookup(operands[0]));
836  return success();
837 }
838 
840 spirv::Deserializer::processFunctionType(ArrayRef<uint32_t> operands) {
841  assert(!operands.empty() && "No operands for processing function type");
842  if (operands.size() == 1) {
843  return emitError(unknownLoc, "missing return type for OpTypeFunction");
844  }
845  auto returnType = getType(operands[1]);
846  if (!returnType) {
847  return emitError(unknownLoc, "unknown return type in OpTypeFunction");
848  }
849  SmallVector<Type, 1> argTypes;
850  for (size_t i = 2, e = operands.size(); i < e; ++i) {
851  auto ty = getType(operands[i]);
852  if (!ty) {
853  return emitError(unknownLoc, "unknown argument type in OpTypeFunction");
854  }
855  argTypes.push_back(ty);
856  }
857  ArrayRef<Type> returnTypes;
858  if (!isVoidType(returnType)) {
859  returnTypes = llvm::makeArrayRef(returnType);
860  }
861  typeMap[operands[0]] = FunctionType::get(context, argTypes, returnTypes);
862  return success();
863 }
864 
866 spirv::Deserializer::processCooperativeMatrixType(ArrayRef<uint32_t> operands) {
867  if (operands.size() != 5) {
868  return emitError(unknownLoc, "OpTypeCooperativeMatrix must have element "
869  "type and row x column parameters");
870  }
871 
872  Type elementTy = getType(operands[1]);
873  if (!elementTy) {
874  return emitError(unknownLoc,
875  "OpTypeCooperativeMatrix references undefined <id> ")
876  << operands[1];
877  }
878 
879  auto scope = spirv::symbolizeScope(getConstantInt(operands[2]).getInt());
880  if (!scope) {
881  return emitError(unknownLoc,
882  "OpTypeCooperativeMatrix references undefined scope <id> ")
883  << operands[2];
884  }
885 
886  unsigned rows = getConstantInt(operands[3]).getInt();
887  unsigned columns = getConstantInt(operands[4]).getInt();
888 
889  typeMap[operands[0]] = spirv::CooperativeMatrixNVType::get(
890  elementTy, scope.value(), rows, columns);
891  return success();
892 }
893 
895 spirv::Deserializer::processJointMatrixType(ArrayRef<uint32_t> operands) {
896  if (operands.size() != 6) {
897  return emitError(unknownLoc, "OpTypeJointMatrix must have element "
898  "type and row x column parameters");
899  }
900 
901  Type elementTy = getType(operands[1]);
902  if (!elementTy) {
903  return emitError(unknownLoc, "OpTypeJointMatrix references undefined <id> ")
904  << operands[1];
905  }
906 
907  auto scope = spirv::symbolizeScope(getConstantInt(operands[5]).getInt());
908  if (!scope) {
909  return emitError(unknownLoc,
910  "OpTypeJointMatrix references undefined scope <id> ")
911  << operands[5];
912  }
913  auto matrixLayout =
914  spirv::symbolizeMatrixLayout(getConstantInt(operands[4]).getInt());
915  if (!matrixLayout) {
916  return emitError(unknownLoc,
917  "OpTypeJointMatrix references undefined scope <id> ")
918  << operands[4];
919  }
920  unsigned rows = getConstantInt(operands[2]).getInt();
921  unsigned columns = getConstantInt(operands[3]).getInt();
922 
923  typeMap[operands[0]] = spirv::JointMatrixINTELType::get(
924  elementTy, scope.value(), rows, columns, matrixLayout.value());
925  return success();
926 }
927 
929 spirv::Deserializer::processRuntimeArrayType(ArrayRef<uint32_t> operands) {
930  if (operands.size() != 2) {
931  return emitError(unknownLoc, "OpTypeRuntimeArray must have two operands");
932  }
933  Type memberType = getType(operands[1]);
934  if (!memberType) {
935  return emitError(unknownLoc,
936  "OpTypeRuntimeArray references undefined <id> ")
937  << operands[1];
938  }
939  typeMap[operands[0]] = spirv::RuntimeArrayType::get(
940  memberType, typeDecorations.lookup(operands[0]));
941  return success();
942 }
943 
945 spirv::Deserializer::processStructType(ArrayRef<uint32_t> operands) {
946  // TODO: Find a way to handle identified structs when debug info is stripped.
947 
948  if (operands.empty()) {
949  return emitError(unknownLoc, "OpTypeStruct must have at least result <id>");
950  }
951 
952  if (operands.size() == 1) {
953  // Handle empty struct.
954  typeMap[operands[0]] =
955  spirv::StructType::getEmpty(context, nameMap.lookup(operands[0]).str());
956  return success();
957  }
958 
959  // First element is operand ID, second element is member index in the struct.
960  SmallVector<std::pair<uint32_t, unsigned>, 0> unresolvedMemberTypes;
961  SmallVector<Type, 4> memberTypes;
962 
963  for (auto op : llvm::drop_begin(operands, 1)) {
964  Type memberType = getType(op);
965  bool typeForwardPtr = (typeForwardPointerIDs.count(op) != 0);
966 
967  if (!memberType && !typeForwardPtr)
968  return emitError(unknownLoc, "OpTypeStruct references undefined <id> ")
969  << op;
970 
971  if (!memberType)
972  unresolvedMemberTypes.emplace_back(op, memberTypes.size());
973 
974  memberTypes.push_back(memberType);
975  }
976 
979  if (memberDecorationMap.count(operands[0])) {
980  auto &allMemberDecorations = memberDecorationMap[operands[0]];
981  for (auto memberIndex : llvm::seq<uint32_t>(0, memberTypes.size())) {
982  if (allMemberDecorations.count(memberIndex)) {
983  for (auto &memberDecoration : allMemberDecorations[memberIndex]) {
984  // Check for offset.
985  if (memberDecoration.first == spirv::Decoration::Offset) {
986  // If offset info is empty, resize to the number of members;
987  if (offsetInfo.empty()) {
988  offsetInfo.resize(memberTypes.size());
989  }
990  offsetInfo[memberIndex] = memberDecoration.second[0];
991  } else {
992  if (!memberDecoration.second.empty()) {
993  memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/1,
994  memberDecoration.first,
995  memberDecoration.second[0]);
996  } else {
997  memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/0,
998  memberDecoration.first, 0);
999  }
1000  }
1001  }
1002  }
1003  }
1004  }
1005 
1006  uint32_t structID = operands[0];
1007  std::string structIdentifier = nameMap.lookup(structID).str();
1008 
1009  if (structIdentifier.empty()) {
1010  assert(unresolvedMemberTypes.empty() &&
1011  "didn't expect unresolved member types");
1012  typeMap[structID] =
1013  spirv::StructType::get(memberTypes, offsetInfo, memberDecorationsInfo);
1014  } else {
1015  auto structTy = spirv::StructType::getIdentified(context, structIdentifier);
1016  typeMap[structID] = structTy;
1017 
1018  if (!unresolvedMemberTypes.empty())
1019  deferredStructTypesInfos.push_back({structTy, unresolvedMemberTypes,
1020  memberTypes, offsetInfo,
1021  memberDecorationsInfo});
1022  else if (failed(structTy.trySetBody(memberTypes, offsetInfo,
1023  memberDecorationsInfo)))
1024  return failure();
1025  }
1026 
1027  // TODO: Update StructType to have member name as attribute as
1028  // well.
1029  return success();
1030 }
1031 
1033 spirv::Deserializer::processMatrixType(ArrayRef<uint32_t> operands) {
1034  if (operands.size() != 3) {
1035  // Three operands are needed: result_id, column_type, and column_count
1036  return emitError(unknownLoc, "OpTypeMatrix must have 3 operands"
1037  " (result_id, column_type, and column_count)");
1038  }
1039  // Matrix columns must be of vector type
1040  Type elementTy = getType(operands[1]);
1041  if (!elementTy) {
1042  return emitError(unknownLoc,
1043  "OpTypeMatrix references undefined column type.")
1044  << operands[1];
1045  }
1046 
1047  uint32_t colsCount = operands[2];
1048  typeMap[operands[0]] = spirv::MatrixType::get(elementTy, colsCount);
1049  return success();
1050 }
1051 
1053 spirv::Deserializer::processTypeForwardPointer(ArrayRef<uint32_t> operands) {
1054  if (operands.size() != 2)
1055  return emitError(unknownLoc,
1056  "OpTypeForwardPointer instruction must have two operands");
1057 
1058  typeForwardPointerIDs.insert(operands[0]);
1059  // TODO: Use the 2nd operand (Storage Class) to validate the OpTypePointer
1060  // instruction that defines the actual type.
1061 
1062  return success();
1063 }
1064 
1066 spirv::Deserializer::processImageType(ArrayRef<uint32_t> operands) {
1067  // TODO: Add support for Access Qualifier.
1068  if (operands.size() != 8)
1069  return emitError(
1070  unknownLoc,
1071  "OpTypeImage with non-eight operands are not supported yet");
1072 
1073  Type elementTy = getType(operands[1]);
1074  if (!elementTy)
1075  return emitError(unknownLoc, "OpTypeImage references undefined <id>: ")
1076  << operands[1];
1077 
1078  auto dim = spirv::symbolizeDim(operands[2]);
1079  if (!dim)
1080  return emitError(unknownLoc, "unknown Dim for OpTypeImage: ")
1081  << operands[2];
1082 
1083  auto depthInfo = spirv::symbolizeImageDepthInfo(operands[3]);
1084  if (!depthInfo)
1085  return emitError(unknownLoc, "unknown Depth for OpTypeImage: ")
1086  << operands[3];
1087 
1088  auto arrayedInfo = spirv::symbolizeImageArrayedInfo(operands[4]);
1089  if (!arrayedInfo)
1090  return emitError(unknownLoc, "unknown Arrayed for OpTypeImage: ")
1091  << operands[4];
1092 
1093  auto samplingInfo = spirv::symbolizeImageSamplingInfo(operands[5]);
1094  if (!samplingInfo)
1095  return emitError(unknownLoc, "unknown MS for OpTypeImage: ") << operands[5];
1096 
1097  auto samplerUseInfo = spirv::symbolizeImageSamplerUseInfo(operands[6]);
1098  if (!samplerUseInfo)
1099  return emitError(unknownLoc, "unknown Sampled for OpTypeImage: ")
1100  << operands[6];
1101 
1102  auto format = spirv::symbolizeImageFormat(operands[7]);
1103  if (!format)
1104  return emitError(unknownLoc, "unknown Format for OpTypeImage: ")
1105  << operands[7];
1106 
1107  typeMap[operands[0]] = spirv::ImageType::get(
1108  elementTy, dim.value(), depthInfo.value(), arrayedInfo.value(),
1109  samplingInfo.value(), samplerUseInfo.value(), format.value());
1110  return success();
1111 }
1112 
1114 spirv::Deserializer::processSampledImageType(ArrayRef<uint32_t> operands) {
1115  if (operands.size() != 2)
1116  return emitError(unknownLoc, "OpTypeSampledImage must have two operands");
1117 
1118  Type elementTy = getType(operands[1]);
1119  if (!elementTy)
1120  return emitError(unknownLoc,
1121  "OpTypeSampledImage references undefined <id>: ")
1122  << operands[1];
1123 
1124  typeMap[operands[0]] = spirv::SampledImageType::get(elementTy);
1125  return success();
1126 }
1127 
1128 //===----------------------------------------------------------------------===//
1129 // Constant
1130 //===----------------------------------------------------------------------===//
1131 
1132 LogicalResult spirv::Deserializer::processConstant(ArrayRef<uint32_t> operands,
1133  bool isSpec) {
1134  StringRef opname = isSpec ? "OpSpecConstant" : "OpConstant";
1135 
1136  if (operands.size() < 2) {
1137  return emitError(unknownLoc)
1138  << opname << " must have type <id> and result <id>";
1139  }
1140  if (operands.size() < 3) {
1141  return emitError(unknownLoc)
1142  << opname << " must have at least 1 more parameter";
1143  }
1144 
1145  Type resultType = getType(operands[0]);
1146  if (!resultType) {
1147  return emitError(unknownLoc, "undefined result type from <id> ")
1148  << operands[0];
1149  }
1150 
1151  auto checkOperandSizeForBitwidth = [&](unsigned bitwidth) -> LogicalResult {
1152  if (bitwidth == 64) {
1153  if (operands.size() == 4) {
1154  return success();
1155  }
1156  return emitError(unknownLoc)
1157  << opname << " should have 2 parameters for 64-bit values";
1158  }
1159  if (bitwidth <= 32) {
1160  if (operands.size() == 3) {
1161  return success();
1162  }
1163 
1164  return emitError(unknownLoc)
1165  << opname
1166  << " should have 1 parameter for values with no more than 32 bits";
1167  }
1168  return emitError(unknownLoc, "unsupported OpConstant bitwidth: ")
1169  << bitwidth;
1170  };
1171 
1172  auto resultID = operands[1];
1173 
1174  if (auto intType = resultType.dyn_cast<IntegerType>()) {
1175  auto bitwidth = intType.getWidth();
1176  if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1177  return failure();
1178  }
1179 
1180  APInt value;
1181  if (bitwidth == 64) {
1182  // 64-bit integers are represented with two SPIR-V words. According to
1183  // SPIR-V spec: "When the type’s bit width is larger than one word, the
1184  // literal’s low-order words appear first."
1185  struct DoubleWord {
1186  uint32_t word1;
1187  uint32_t word2;
1188  } words = {operands[2], operands[3]};
1189  value = APInt(64, llvm::bit_cast<uint64_t>(words), /*isSigned=*/true);
1190  } else if (bitwidth <= 32) {
1191  value = APInt(bitwidth, operands[2], /*isSigned=*/true);
1192  }
1193 
1194  auto attr = opBuilder.getIntegerAttr(intType, value);
1195 
1196  if (isSpec) {
1197  createSpecConstant(unknownLoc, resultID, attr);
1198  } else {
1199  // For normal constants, we just record the attribute (and its type) for
1200  // later materialization at use sites.
1201  constantMap.try_emplace(resultID, attr, intType);
1202  }
1203 
1204  return success();
1205  }
1206 
1207  if (auto floatType = resultType.dyn_cast<FloatType>()) {
1208  auto bitwidth = floatType.getWidth();
1209  if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1210  return failure();
1211  }
1212 
1213  APFloat value(0.f);
1214  if (floatType.isF64()) {
1215  // Double values are represented with two SPIR-V words. According to
1216  // SPIR-V spec: "When the type’s bit width is larger than one word, the
1217  // literal’s low-order words appear first."
1218  struct DoubleWord {
1219  uint32_t word1;
1220  uint32_t word2;
1221  } words = {operands[2], operands[3]};
1222  value = APFloat(llvm::bit_cast<double>(words));
1223  } else if (floatType.isF32()) {
1224  value = APFloat(llvm::bit_cast<float>(operands[2]));
1225  } else if (floatType.isF16()) {
1226  APInt data(16, operands[2]);
1227  value = APFloat(APFloat::IEEEhalf(), data);
1228  }
1229 
1230  auto attr = opBuilder.getFloatAttr(floatType, value);
1231  if (isSpec) {
1232  createSpecConstant(unknownLoc, resultID, attr);
1233  } else {
1234  // For normal constants, we just record the attribute (and its type) for
1235  // later materialization at use sites.
1236  constantMap.try_emplace(resultID, attr, floatType);
1237  }
1238 
1239  return success();
1240  }
1241 
1242  return emitError(unknownLoc, "OpConstant can only generate values of "
1243  "scalar integer or floating-point type");
1244 }
1245 
1246 LogicalResult spirv::Deserializer::processConstantBool(
1247  bool isTrue, ArrayRef<uint32_t> operands, bool isSpec) {
1248  if (operands.size() != 2) {
1249  return emitError(unknownLoc, "Op")
1250  << (isSpec ? "Spec" : "") << "Constant"
1251  << (isTrue ? "True" : "False")
1252  << " must have type <id> and result <id>";
1253  }
1254 
1255  auto attr = opBuilder.getBoolAttr(isTrue);
1256  auto resultID = operands[1];
1257  if (isSpec) {
1258  createSpecConstant(unknownLoc, resultID, attr);
1259  } else {
1260  // For normal constants, we just record the attribute (and its type) for
1261  // later materialization at use sites.
1262  constantMap.try_emplace(resultID, attr, opBuilder.getI1Type());
1263  }
1264 
1265  return success();
1266 }
1267 
1269 spirv::Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
1270  if (operands.size() < 2) {
1271  return emitError(unknownLoc,
1272  "OpConstantComposite must have type <id> and result <id>");
1273  }
1274  if (operands.size() < 3) {
1275  return emitError(unknownLoc,
1276  "OpConstantComposite must have at least 1 parameter");
1277  }
1278 
1279  Type resultType = getType(operands[0]);
1280  if (!resultType) {
1281  return emitError(unknownLoc, "undefined result type from <id> ")
1282  << operands[0];
1283  }
1284 
1285  SmallVector<Attribute, 4> elements;
1286  elements.reserve(operands.size() - 2);
1287  for (unsigned i = 2, e = operands.size(); i < e; ++i) {
1288  auto elementInfo = getConstant(operands[i]);
1289  if (!elementInfo) {
1290  return emitError(unknownLoc, "OpConstantComposite component <id> ")
1291  << operands[i] << " must come from a normal constant";
1292  }
1293  elements.push_back(elementInfo->first);
1294  }
1295 
1296  auto resultID = operands[1];
1297  if (auto vectorType = resultType.dyn_cast<VectorType>()) {
1298  auto attr = DenseElementsAttr::get(vectorType, elements);
1299  // For normal constants, we just record the attribute (and its type) for
1300  // later materialization at use sites.
1301  constantMap.try_emplace(resultID, attr, resultType);
1302  } else if (auto arrayType = resultType.dyn_cast<spirv::ArrayType>()) {
1303  auto attr = opBuilder.getArrayAttr(elements);
1304  constantMap.try_emplace(resultID, attr, resultType);
1305  } else {
1306  return emitError(unknownLoc, "unsupported OpConstantComposite type: ")
1307  << resultType;
1308  }
1309 
1310  return success();
1311 }
1312 
1314 spirv::Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) {
1315  if (operands.size() < 2) {
1316  return emitError(unknownLoc,
1317  "OpConstantComposite must have type <id> and result <id>");
1318  }
1319  if (operands.size() < 3) {
1320  return emitError(unknownLoc,
1321  "OpConstantComposite must have at least 1 parameter");
1322  }
1323 
1324  Type resultType = getType(operands[0]);
1325  if (!resultType) {
1326  return emitError(unknownLoc, "undefined result type from <id> ")
1327  << operands[0];
1328  }
1329 
1330  auto resultID = operands[1];
1331  auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
1332 
1333  SmallVector<Attribute, 4> elements;
1334  elements.reserve(operands.size() - 2);
1335  for (unsigned i = 2, e = operands.size(); i < e; ++i) {
1336  auto elementInfo = getSpecConstant(operands[i]);
1337  elements.push_back(SymbolRefAttr::get(elementInfo));
1338  }
1339 
1340  auto op = opBuilder.create<spirv::SpecConstantCompositeOp>(
1341  unknownLoc, TypeAttr::get(resultType), symName,
1342  opBuilder.getArrayAttr(elements));
1343  specConstCompositeMap[resultID] = op;
1344 
1345  return success();
1346 }
1347 
1349 spirv::Deserializer::processSpecConstantOperation(ArrayRef<uint32_t> operands) {
1350  if (operands.size() < 3)
1351  return emitError(unknownLoc, "OpConstantOperation must have type <id>, "
1352  "result <id>, and operand opcode");
1353 
1354  uint32_t resultTypeID = operands[0];
1355 
1356  if (!getType(resultTypeID))
1357  return emitError(unknownLoc, "undefined result type from <id> ")
1358  << resultTypeID;
1359 
1360  uint32_t resultID = operands[1];
1361  spirv::Opcode enclosedOpcode = static_cast<spirv::Opcode>(operands[2]);
1362  auto emplaceResult = specConstOperationMap.try_emplace(
1363  resultID,
1364  SpecConstOperationMaterializationInfo{
1365  enclosedOpcode, resultTypeID,
1366  SmallVector<uint32_t>{operands.begin() + 3, operands.end()}});
1367 
1368  if (!emplaceResult.second)
1369  return emitError(unknownLoc, "value with <id>: ")
1370  << resultID << " is probably defined before.";
1371 
1372  return success();
1373 }
1374 
1375 Value spirv::Deserializer::materializeSpecConstantOperation(
1376  uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID,
1377  ArrayRef<uint32_t> enclosedOpOperands) {
1378 
1379  Type resultType = getType(resultTypeID);
1380 
1381  // Instructions wrapped by OpSpecConstantOp need an ID for their
1382  // Deserializer::processOp<op_name>(...) to emit the corresponding SPIR-V
1383  // dialect wrapped op. For that purpose, a new value map is created and "fake"
1384  // ID in that map is assigned to the result of the enclosed instruction. Note
1385  // that there is no need to update this fake ID since we only need to
1386  // reference the created Value for the enclosed op from the spv::YieldOp
1387  // created later in this method (both of which are the only values in their
1388  // region: the SpecConstantOperation's region). If we encounter another
1389  // SpecConstantOperation in the module, we simply re-use the fake ID since the
1390  // previous Value assigned to it isn't visible in the current scope anyway.
1391  DenseMap<uint32_t, Value> newValueMap;
1392  llvm::SaveAndRestore valueMapGuard(valueMap, newValueMap);
1393  constexpr uint32_t fakeID = static_cast<uint32_t>(-3);
1394 
1395  SmallVector<uint32_t, 4> enclosedOpResultTypeAndOperands;
1396  enclosedOpResultTypeAndOperands.push_back(resultTypeID);
1397  enclosedOpResultTypeAndOperands.push_back(fakeID);
1398  enclosedOpResultTypeAndOperands.append(enclosedOpOperands.begin(),
1399  enclosedOpOperands.end());
1400 
1401  // Process enclosed instruction before creating the enclosing
1402  // specConstantOperation (and its region). This way, references to constants,
1403  // global variables, and spec constants will be materialized outside the new
1404  // op's region. For more info, see Deserializer::getValue's implementation.
1405  if (failed(
1406  processInstruction(enclosedOpcode, enclosedOpResultTypeAndOperands)))
1407  return Value();
1408 
1409  // Since the enclosed op is emitted in the current block, split it in a
1410  // separate new block.
1411  Block *enclosedBlock = curBlock->splitBlock(&curBlock->back());
1412 
1413  auto loc = createFileLineColLoc(opBuilder);
1414  auto specConstOperationOp =
1415  opBuilder.create<spirv::SpecConstantOperationOp>(loc, resultType);
1416 
1417  Region &body = specConstOperationOp.getBody();
1418  // Move the new block into SpecConstantOperation's body.
1419  body.getBlocks().splice(body.end(), curBlock->getParent()->getBlocks(),
1420  Region::iterator(enclosedBlock));
1421  Block &block = body.back();
1422 
1423  // RAII guard to reset the insertion point to the module's region after
1424  // deserializing the body of the specConstantOperation.
1425  OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
1426  opBuilder.setInsertionPointToEnd(&block);
1427 
1428  opBuilder.create<spirv::YieldOp>(loc, block.front().getResult(0));
1429  return specConstOperationOp.getResult();
1430 }
1431 
1433 spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
1434  if (operands.size() != 2) {
1435  return emitError(unknownLoc,
1436  "OpConstantNull must have type <id> and result <id>");
1437  }
1438 
1439  Type resultType = getType(operands[0]);
1440  if (!resultType) {
1441  return emitError(unknownLoc, "undefined result type from <id> ")
1442  << operands[0];
1443  }
1444 
1445  auto resultID = operands[1];
1446  if (resultType.isIntOrFloat() || resultType.isa<VectorType>()) {
1447  auto attr = opBuilder.getZeroAttr(resultType);
1448  // For normal constants, we just record the attribute (and its type) for
1449  // later materialization at use sites.
1450  constantMap.try_emplace(resultID, attr, resultType);
1451  return success();
1452  }
1453 
1454  return emitError(unknownLoc, "unsupported OpConstantNull type: ")
1455  << resultType;
1456 }
1457 
1458 //===----------------------------------------------------------------------===//
1459 // Control flow
1460 //===----------------------------------------------------------------------===//
1461 
1462 Block *spirv::Deserializer::getOrCreateBlock(uint32_t id) {
1463  if (auto *block = getBlock(id)) {
1464  LLVM_DEBUG(logger.startLine() << "[block] got exiting block for id = " << id
1465  << " @ " << block << "\n");
1466  return block;
1467  }
1468 
1469  // We don't know where this block will be placed finally (in a
1470  // spirv.mlir.selection or spirv.mlir.loop or function). Create it into the
1471  // function for now and sort out the proper place later.
1472  auto *block = curFunction->addBlock();
1473  LLVM_DEBUG(logger.startLine() << "[block] created block for id = " << id
1474  << " @ " << block << "\n");
1475  return blockMap[id] = block;
1476 }
1477 
1478 LogicalResult spirv::Deserializer::processBranch(ArrayRef<uint32_t> operands) {
1479  if (!curBlock) {
1480  return emitError(unknownLoc, "OpBranch must appear inside a block");
1481  }
1482 
1483  if (operands.size() != 1) {
1484  return emitError(unknownLoc, "OpBranch must take exactly one target label");
1485  }
1486 
1487  auto *target = getOrCreateBlock(operands[0]);
1488  auto loc = createFileLineColLoc(opBuilder);
1489  // The preceding instruction for the OpBranch instruction could be an
1490  // OpLoopMerge or an OpSelectionMerge instruction, in this case they will have
1491  // the same OpLine information.
1492  opBuilder.create<spirv::BranchOp>(loc, target);
1493 
1494  clearDebugLine();
1495  return success();
1496 }
1497 
1499 spirv::Deserializer::processBranchConditional(ArrayRef<uint32_t> operands) {
1500  if (!curBlock) {
1501  return emitError(unknownLoc,
1502  "OpBranchConditional must appear inside a block");
1503  }
1504 
1505  if (operands.size() != 3 && operands.size() != 5) {
1506  return emitError(unknownLoc,
1507  "OpBranchConditional must have condition, true label, "
1508  "false label, and optionally two branch weights");
1509  }
1510 
1511  auto condition = getValue(operands[0]);
1512  auto *trueBlock = getOrCreateBlock(operands[1]);
1513  auto *falseBlock = getOrCreateBlock(operands[2]);
1514 
1516  if (operands.size() == 5) {
1517  weights = std::make_pair(operands[3], operands[4]);
1518  }
1519  // The preceding instruction for the OpBranchConditional instruction could be
1520  // an OpSelectionMerge instruction, in this case they will have the same
1521  // OpLine information.
1522  auto loc = createFileLineColLoc(opBuilder);
1523  opBuilder.create<spirv::BranchConditionalOp>(
1524  loc, condition, trueBlock,
1525  /*trueArguments=*/ArrayRef<Value>(), falseBlock,
1526  /*falseArguments=*/ArrayRef<Value>(), weights);
1527 
1528  clearDebugLine();
1529  return success();
1530 }
1531 
1532 LogicalResult spirv::Deserializer::processLabel(ArrayRef<uint32_t> operands) {
1533  if (!curFunction) {
1534  return emitError(unknownLoc, "OpLabel must appear inside a function");
1535  }
1536 
1537  if (operands.size() != 1) {
1538  return emitError(unknownLoc, "OpLabel should only have result <id>");
1539  }
1540 
1541  auto labelID = operands[0];
1542  // We may have forward declared this block.
1543  auto *block = getOrCreateBlock(labelID);
1544  LLVM_DEBUG(logger.startLine()
1545  << "[block] populating block " << block << "\n");
1546  // If we have seen this block, make sure it was just a forward declaration.
1547  assert(block->empty() && "re-deserialize the same block!");
1548 
1549  opBuilder.setInsertionPointToStart(block);
1550  blockMap[labelID] = curBlock = block;
1551 
1552  return success();
1553 }
1554 
1556 spirv::Deserializer::processSelectionMerge(ArrayRef<uint32_t> operands) {
1557  if (!curBlock) {
1558  return emitError(unknownLoc, "OpSelectionMerge must appear in a block");
1559  }
1560 
1561  if (operands.size() < 2) {
1562  return emitError(
1563  unknownLoc,
1564  "OpSelectionMerge must specify merge target and selection control");
1565  }
1566 
1567  auto *mergeBlock = getOrCreateBlock(operands[0]);
1568  auto loc = createFileLineColLoc(opBuilder);
1569  auto selectionControl = operands[1];
1570 
1571  if (!blockMergeInfo.try_emplace(curBlock, loc, selectionControl, mergeBlock)
1572  .second) {
1573  return emitError(
1574  unknownLoc,
1575  "a block cannot have more than one OpSelectionMerge instruction");
1576  }
1577 
1578  return success();
1579 }
1580 
1582 spirv::Deserializer::processLoopMerge(ArrayRef<uint32_t> operands) {
1583  if (!curBlock) {
1584  return emitError(unknownLoc, "OpLoopMerge must appear in a block");
1585  }
1586 
1587  if (operands.size() < 3) {
1588  return emitError(unknownLoc, "OpLoopMerge must specify merge target, "
1589  "continue target and loop control");
1590  }
1591 
1592  auto *mergeBlock = getOrCreateBlock(operands[0]);
1593  auto *continueBlock = getOrCreateBlock(operands[1]);
1594  auto loc = createFileLineColLoc(opBuilder);
1595  uint32_t loopControl = operands[2];
1596 
1597  if (!blockMergeInfo
1598  .try_emplace(curBlock, loc, loopControl, mergeBlock, continueBlock)
1599  .second) {
1600  return emitError(
1601  unknownLoc,
1602  "a block cannot have more than one OpLoopMerge instruction");
1603  }
1604 
1605  return success();
1606 }
1607 
1608 LogicalResult spirv::Deserializer::processPhi(ArrayRef<uint32_t> operands) {
1609  if (!curBlock) {
1610  return emitError(unknownLoc, "OpPhi must appear in a block");
1611  }
1612 
1613  if (operands.size() < 4) {
1614  return emitError(unknownLoc, "OpPhi must specify result type, result <id>, "
1615  "and variable-parent pairs");
1616  }
1617 
1618  // Create a block argument for this OpPhi instruction.
1619  Type blockArgType = getType(operands[0]);
1620  BlockArgument blockArg = curBlock->addArgument(blockArgType, unknownLoc);
1621  valueMap[operands[1]] = blockArg;
1622  LLVM_DEBUG(logger.startLine()
1623  << "[phi] created block argument " << blockArg
1624  << " id = " << operands[1] << " of type " << blockArgType << "\n");
1625 
1626  // For each (value, predecessor) pair, insert the value to the predecessor's
1627  // blockPhiInfo entry so later we can fix the block argument there.
1628  for (unsigned i = 2, e = operands.size(); i < e; i += 2) {
1629  uint32_t value = operands[i];
1630  Block *predecessor = getOrCreateBlock(operands[i + 1]);
1631  std::pair<Block *, Block *> predecessorTargetPair{predecessor, curBlock};
1632  blockPhiInfo[predecessorTargetPair].push_back(value);
1633  LLVM_DEBUG(logger.startLine() << "[phi] predecessor @ " << predecessor
1634  << " with arg id = " << value << "\n");
1635  }
1636 
1637  return success();
1638 }
1639 
1640 namespace {
1641 /// A class for putting all blocks in a structured selection/loop in a
1642 /// spirv.mlir.selection/spirv.mlir.loop op.
1643 class ControlFlowStructurizer {
1644 public:
1645 #ifndef NDEBUG
1646  ControlFlowStructurizer(Location loc, uint32_t control,
1647  spirv::BlockMergeInfoMap &mergeInfo, Block *header,
1648  Block *merge, Block *cont,
1649  llvm::ScopedPrinter &logger)
1650  : location(loc), control(control), blockMergeInfo(mergeInfo),
1651  headerBlock(header), mergeBlock(merge), continueBlock(cont),
1652  logger(logger) {}
1653 #else
1654  ControlFlowStructurizer(Location loc, uint32_t control,
1655  spirv::BlockMergeInfoMap &mergeInfo, Block *header,
1656  Block *merge, Block *cont)
1657  : location(loc), control(control), blockMergeInfo(mergeInfo),
1658  headerBlock(header), mergeBlock(merge), continueBlock(cont) {}
1659 #endif
1660 
1661  /// Structurizes the loop at the given `headerBlock`.
1662  ///
1663  /// This method will create an spirv.mlir.loop op in the `mergeBlock` and move
1664  /// all blocks in the structured loop into the spirv.mlir.loop's region. All
1665  /// branches to the `headerBlock` will be redirected to the `mergeBlock`. This
1666  /// method will also update `mergeInfo` by remapping all blocks inside to the
1667  /// newly cloned ones inside structured control flow op's regions.
1668  LogicalResult structurize();
1669 
1670 private:
1671  /// Creates a new spirv.mlir.selection op at the beginning of the
1672  /// `mergeBlock`.
1673  spirv::SelectionOp createSelectionOp(uint32_t selectionControl);
1674 
1675  /// Creates a new spirv.mlir.loop op at the beginning of the `mergeBlock`.
1676  spirv::LoopOp createLoopOp(uint32_t loopControl);
1677 
1678  /// Collects all blocks reachable from `headerBlock` except `mergeBlock`.
1679  void collectBlocksInConstruct();
1680 
1681  Location location;
1682  uint32_t control;
1683 
1684  spirv::BlockMergeInfoMap &blockMergeInfo;
1685 
1686  Block *headerBlock;
1687  Block *mergeBlock;
1688  Block *continueBlock; // nullptr for spirv.mlir.selection
1689 
1690  SetVector<Block *> constructBlocks;
1691 
1692 #ifndef NDEBUG
1693  /// A logger used to emit information during the deserialzation process.
1694  llvm::ScopedPrinter &logger;
1695 #endif
1696 };
1697 } // namespace
1698 
1699 spirv::SelectionOp
1700 ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) {
1701  // Create a builder and set the insertion point to the beginning of the
1702  // merge block so that the newly created SelectionOp will be inserted there.
1703  OpBuilder builder(&mergeBlock->front());
1704 
1705  auto control = static_cast<spirv::SelectionControl>(selectionControl);
1706  auto selectionOp = builder.create<spirv::SelectionOp>(location, control);
1707  selectionOp.addMergeBlock();
1708 
1709  return selectionOp;
1710 }
1711 
1712 spirv::LoopOp ControlFlowStructurizer::createLoopOp(uint32_t loopControl) {
1713  // Create a builder and set the insertion point to the beginning of the
1714  // merge block so that the newly created LoopOp will be inserted there.
1715  OpBuilder builder(&mergeBlock->front());
1716 
1717  auto control = static_cast<spirv::LoopControl>(loopControl);
1718  auto loopOp = builder.create<spirv::LoopOp>(location, control);
1719  loopOp.addEntryAndMergeBlock();
1720 
1721  return loopOp;
1722 }
1723 
1724 void ControlFlowStructurizer::collectBlocksInConstruct() {
1725  assert(constructBlocks.empty() && "expected empty constructBlocks");
1726 
1727  // Put the header block in the work list first.
1728  constructBlocks.insert(headerBlock);
1729 
1730  // For each item in the work list, add its successors excluding the merge
1731  // block.
1732  for (unsigned i = 0; i < constructBlocks.size(); ++i) {
1733  for (auto *successor : constructBlocks[i]->getSuccessors())
1734  if (successor != mergeBlock)
1735  constructBlocks.insert(successor);
1736  }
1737 }
1738 
1739 LogicalResult ControlFlowStructurizer::structurize() {
1740  Operation *op = nullptr;
1741  bool isLoop = continueBlock != nullptr;
1742  if (isLoop) {
1743  if (auto loopOp = createLoopOp(control))
1744  op = loopOp.getOperation();
1745  } else {
1746  if (auto selectionOp = createSelectionOp(control))
1747  op = selectionOp.getOperation();
1748  }
1749  if (!op)
1750  return failure();
1751  Region &body = op->getRegion(0);
1752 
1753  BlockAndValueMapping mapper;
1754  // All references to the old merge block should be directed to the
1755  // selection/loop merge block in the SelectionOp/LoopOp's region.
1756  mapper.map(mergeBlock, &body.back());
1757 
1758  collectBlocksInConstruct();
1759 
1760  // We've identified all blocks belonging to the selection/loop's region. Now
1761  // need to "move" them into the selection/loop. Instead of really moving the
1762  // blocks, in the following we copy them and remap all values and branches.
1763  // This is because:
1764  // * Inserting a block into a region requires the block not in any region
1765  // before. But selections/loops can nest so we can create selection/loop ops
1766  // in a nested manner, which means some blocks may already be in a
1767  // selection/loop region when to be moved again.
1768  // * It's much trickier to fix up the branches into and out of the loop's
1769  // region: we need to treat not-moved blocks and moved blocks differently:
1770  // Not-moved blocks jumping to the loop header block need to jump to the
1771  // merge point containing the new loop op but not the loop continue block's
1772  // back edge. Moved blocks jumping out of the loop need to jump to the
1773  // merge block inside the loop region but not other not-moved blocks.
1774  // We cannot use replaceAllUsesWith clearly and it's harder to follow the
1775  // logic.
1776 
1777  // Create a corresponding block in the SelectionOp/LoopOp's region for each
1778  // block in this loop construct.
1779  OpBuilder builder(body);
1780  for (auto *block : constructBlocks) {
1781  // Create a block and insert it before the selection/loop merge block in the
1782  // SelectionOp/LoopOp's region.
1783  auto *newBlock = builder.createBlock(&body.back());
1784  mapper.map(block, newBlock);
1785  LLVM_DEBUG(logger.startLine() << "[cf] cloned block " << newBlock
1786  << " from block " << block << "\n");
1787  if (!isFnEntryBlock(block)) {
1788  for (BlockArgument blockArg : block->getArguments()) {
1789  auto newArg =
1790  newBlock->addArgument(blockArg.getType(), blockArg.getLoc());
1791  mapper.map(blockArg, newArg);
1792  LLVM_DEBUG(logger.startLine() << "[cf] remapped block argument "
1793  << blockArg << " to " << newArg << "\n");
1794  }
1795  } else {
1796  LLVM_DEBUG(logger.startLine()
1797  << "[cf] block " << block << " is a function entry block\n");
1798  }
1799 
1800  for (auto &op : *block)
1801  newBlock->push_back(op.clone(mapper));
1802  }
1803 
1804  // Go through all ops and remap the operands.
1805  auto remapOperands = [&](Operation *op) {
1806  for (auto &operand : op->getOpOperands())
1807  if (Value mappedOp = mapper.lookupOrNull(operand.get()))
1808  operand.set(mappedOp);
1809  for (auto &succOp : op->getBlockOperands())
1810  if (Block *mappedOp = mapper.lookupOrNull(succOp.get()))
1811  succOp.set(mappedOp);
1812  };
1813  for (auto &block : body)
1814  block.walk(remapOperands);
1815 
1816  // We have created the SelectionOp/LoopOp and "moved" all blocks belonging to
1817  // the selection/loop construct into its region. Next we need to fix the
1818  // connections between this new SelectionOp/LoopOp with existing blocks.
1819 
1820  // All existing incoming branches should go to the merge block, where the
1821  // SelectionOp/LoopOp resides right now.
1822  headerBlock->replaceAllUsesWith(mergeBlock);
1823 
1824  LLVM_DEBUG({
1825  logger.startLine() << "[cf] after cloning and fixing references:\n";
1826  headerBlock->getParentOp()->print(logger.getOStream());
1827  logger.startLine() << "\n";
1828  });
1829 
1830  if (isLoop) {
1831  if (!mergeBlock->args_empty()) {
1832  return mergeBlock->getParentOp()->emitError(
1833  "OpPhi in loop merge block unsupported");
1834  }
1835 
1836  // The loop header block may have block arguments. Since now we place the
1837  // loop op inside the old merge block, we need to make sure the old merge
1838  // block has the same block argument list.
1839  for (BlockArgument blockArg : headerBlock->getArguments())
1840  mergeBlock->addArgument(blockArg.getType(), blockArg.getLoc());
1841 
1842  // If the loop header block has block arguments, make sure the spirv.Branch
1843  // op matches.
1844  SmallVector<Value, 4> blockArgs;
1845  if (!headerBlock->args_empty())
1846  blockArgs = {mergeBlock->args_begin(), mergeBlock->args_end()};
1847 
1848  // The loop entry block should have a unconditional branch jumping to the
1849  // loop header block.
1850  builder.setInsertionPointToEnd(&body.front());
1851  builder.create<spirv::BranchOp>(location, mapper.lookupOrNull(headerBlock),
1852  ArrayRef<Value>(blockArgs));
1853  }
1854 
1855  // All the blocks cloned into the SelectionOp/LoopOp's region can now be
1856  // cleaned up.
1857  LLVM_DEBUG(logger.startLine() << "[cf] cleaning up blocks after clone\n");
1858  // First we need to drop all operands' references inside all blocks. This is
1859  // needed because we can have blocks referencing SSA values from one another.
1860  for (auto *block : constructBlocks)
1861  block->dropAllReferences();
1862 
1863  // Check that whether some op in the to-be-erased blocks still has uses. Those
1864  // uses come from blocks that won't be sinked into the SelectionOp/LoopOp's
1865  // region. We cannot handle such cases given that once a value is sinked into
1866  // the SelectionOp/LoopOp's region, there is no escape for it:
1867  // SelectionOp/LooOp does not support yield values right now.
1868  for (auto *block : constructBlocks) {
1869  for (Operation &op : *block)
1870  if (!op.use_empty())
1871  return op.emitOpError(
1872  "failed control flow structurization: it has uses outside of the "
1873  "enclosing selection/loop construct");
1874  }
1875 
1876  // Then erase all old blocks.
1877  for (auto *block : constructBlocks) {
1878  // We've cloned all blocks belonging to this construct into the structured
1879  // control flow op's region. Among these blocks, some may compose another
1880  // selection/loop. If so, they will be recorded within blockMergeInfo.
1881  // We need to update the pointers there to the newly remapped ones so we can
1882  // continue structurizing them later.
1883  // TODO: The asserts in the following assumes input SPIR-V blob forms
1884  // correctly nested selection/loop constructs. We should relax this and
1885  // support error cases better.
1886  auto it = blockMergeInfo.find(block);
1887  if (it != blockMergeInfo.end()) {
1888  // Use the original location for nested selection/loop ops.
1889  Location loc = it->second.loc;
1890 
1891  Block *newHeader = mapper.lookupOrNull(block);
1892  if (!newHeader)
1893  return emitError(loc, "failed control flow structurization: nested "
1894  "loop header block should be remapped!");
1895 
1896  Block *newContinue = it->second.continueBlock;
1897  if (newContinue) {
1898  newContinue = mapper.lookupOrNull(newContinue);
1899  if (!newContinue)
1900  return emitError(loc, "failed control flow structurization: nested "
1901  "loop continue block should be remapped!");
1902  }
1903 
1904  Block *newMerge = it->second.mergeBlock;
1905  if (Block *mappedTo = mapper.lookupOrNull(newMerge))
1906  newMerge = mappedTo;
1907 
1908  // The iterator should be erased before adding a new entry into
1909  // blockMergeInfo to avoid iterator invalidation.
1910  blockMergeInfo.erase(it);
1911  blockMergeInfo.try_emplace(newHeader, loc, it->second.control, newMerge,
1912  newContinue);
1913  }
1914 
1915  // The structured selection/loop's entry block does not have arguments.
1916  // If the function's header block is also part of the structured control
1917  // flow, we cannot just simply erase it because it may contain arguments
1918  // matching the function signature and used by the cloned blocks.
1919  if (isFnEntryBlock(block)) {
1920  LLVM_DEBUG(logger.startLine() << "[cf] changing entry block " << block
1921  << " to only contain a spirv.Branch op\n");
1922  // Still keep the function entry block for the potential block arguments,
1923  // but replace all ops inside with a branch to the merge block.
1924  block->clear();
1925  builder.setInsertionPointToEnd(block);
1926  builder.create<spirv::BranchOp>(location, mergeBlock);
1927  } else {
1928  LLVM_DEBUG(logger.startLine() << "[cf] erasing block " << block << "\n");
1929  block->erase();
1930  }
1931  }
1932 
1933  LLVM_DEBUG(logger.startLine()
1934  << "[cf] after structurizing construct with header block "
1935  << headerBlock << ":\n"
1936  << *op << "\n");
1937 
1938  return success();
1939 }
1940 
1941 LogicalResult spirv::Deserializer::wireUpBlockArgument() {
1942  LLVM_DEBUG({
1943  logger.startLine()
1944  << "//----- [phi] start wiring up block arguments -----//\n";
1945  logger.indent();
1946  });
1947 
1948  OpBuilder::InsertionGuard guard(opBuilder);
1949 
1950  for (const auto &info : blockPhiInfo) {
1951  Block *block = info.first.first;
1952  Block *target = info.first.second;
1953  const BlockPhiInfo &phiInfo = info.second;
1954  LLVM_DEBUG({
1955  logger.startLine() << "[phi] block " << block << "\n";
1956  logger.startLine() << "[phi] before creating block argument:\n";
1957  block->getParentOp()->print(logger.getOStream());
1958  logger.startLine() << "\n";
1959  });
1960 
1961  // Set insertion point to before this block's terminator early because we
1962  // may materialize ops via getValue() call.
1963  auto *op = block->getTerminator();
1964  opBuilder.setInsertionPoint(op);
1965 
1966  SmallVector<Value, 4> blockArgs;
1967  blockArgs.reserve(phiInfo.size());
1968  for (uint32_t valueId : phiInfo) {
1969  if (Value value = getValue(valueId)) {
1970  blockArgs.push_back(value);
1971  LLVM_DEBUG(logger.startLine() << "[phi] block argument " << value
1972  << " id = " << valueId << "\n");
1973  } else {
1974  return emitError(unknownLoc, "OpPhi references undefined value!");
1975  }
1976  }
1977 
1978  if (auto branchOp = dyn_cast<spirv::BranchOp>(op)) {
1979  // Replace the previous branch op with a new one with block arguments.
1980  opBuilder.create<spirv::BranchOp>(branchOp.getLoc(), branchOp.getTarget(),
1981  blockArgs);
1982  branchOp.erase();
1983  } else if (auto branchCondOp = dyn_cast<spirv::BranchConditionalOp>(op)) {
1984  assert((branchCondOp.getTrueBlock() == target ||
1985  branchCondOp.getFalseBlock() == target) &&
1986  "expected target to be either the true or false target");
1987  if (target == branchCondOp.getTrueTarget())
1988  opBuilder.create<spirv::BranchConditionalOp>(
1989  branchCondOp.getLoc(), branchCondOp.getCondition(), blockArgs,
1990  branchCondOp.getFalseBlockArguments(),
1991  branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueTarget(),
1992  branchCondOp.getFalseTarget());
1993  else
1994  opBuilder.create<spirv::BranchConditionalOp>(
1995  branchCondOp.getLoc(), branchCondOp.getCondition(),
1996  branchCondOp.getTrueBlockArguments(), blockArgs,
1997  branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueBlock(),
1998  branchCondOp.getFalseBlock());
1999 
2000  branchCondOp.erase();
2001  } else {
2002  return emitError(unknownLoc, "unimplemented terminator for Phi creation");
2003  }
2004 
2005  LLVM_DEBUG({
2006  logger.startLine() << "[phi] after creating block argument:\n";
2007  block->getParentOp()->print(logger.getOStream());
2008  logger.startLine() << "\n";
2009  });
2010  }
2011  blockPhiInfo.clear();
2012 
2013  LLVM_DEBUG({
2014  logger.unindent();
2015  logger.startLine()
2016  << "//--- [phi] completed wiring up block arguments ---//\n";
2017  });
2018  return success();
2019 }
2020 
2021 LogicalResult spirv::Deserializer::structurizeControlFlow() {
2022  LLVM_DEBUG({
2023  logger.startLine()
2024  << "//----- [cf] start structurizing control flow -----//\n";
2025  logger.indent();
2026  });
2027 
2028  while (!blockMergeInfo.empty()) {
2029  Block *headerBlock = blockMergeInfo.begin()->first;
2030  BlockMergeInfo mergeInfo = blockMergeInfo.begin()->second;
2031 
2032  LLVM_DEBUG({
2033  logger.startLine() << "[cf] header block " << headerBlock << ":\n";
2034  headerBlock->print(logger.getOStream());
2035  logger.startLine() << "\n";
2036  });
2037 
2038  auto *mergeBlock = mergeInfo.mergeBlock;
2039  assert(mergeBlock && "merge block cannot be nullptr");
2040  if (!mergeBlock->args_empty())
2041  return emitError(unknownLoc, "OpPhi in loop merge block unimplemented");
2042  LLVM_DEBUG({
2043  logger.startLine() << "[cf] merge block " << mergeBlock << ":\n";
2044  mergeBlock->print(logger.getOStream());
2045  logger.startLine() << "\n";
2046  });
2047 
2048  auto *continueBlock = mergeInfo.continueBlock;
2049  LLVM_DEBUG(if (continueBlock) {
2050  logger.startLine() << "[cf] continue block " << continueBlock << ":\n";
2051  continueBlock->print(logger.getOStream());
2052  logger.startLine() << "\n";
2053  });
2054  // Erase this case before calling into structurizer, who will update
2055  // blockMergeInfo.
2056  blockMergeInfo.erase(blockMergeInfo.begin());
2057  ControlFlowStructurizer structurizer(mergeInfo.loc, mergeInfo.control,
2058  blockMergeInfo, headerBlock,
2059  mergeBlock, continueBlock
2060 #ifndef NDEBUG
2061  ,
2062  logger
2063 #endif
2064  );
2065  if (failed(structurizer.structurize()))
2066  return failure();
2067  }
2068 
2069  LLVM_DEBUG({
2070  logger.unindent();
2071  logger.startLine()
2072  << "//--- [cf] completed structurizing control flow ---//\n";
2073  });
2074  return success();
2075 }
2076 
2077 //===----------------------------------------------------------------------===//
2078 // Debug
2079 //===----------------------------------------------------------------------===//
2080 
2081 Location spirv::Deserializer::createFileLineColLoc(OpBuilder opBuilder) {
2082  if (!debugLine)
2083  return unknownLoc;
2084 
2085  auto fileName = debugInfoMap.lookup(debugLine->fileID).str();
2086  if (fileName.empty())
2087  fileName = "<unknown>";
2088  return FileLineColLoc::get(opBuilder.getStringAttr(fileName), debugLine->line,
2089  debugLine->column);
2090 }
2091 
2093 spirv::Deserializer::processDebugLine(ArrayRef<uint32_t> operands) {
2094  // According to SPIR-V spec:
2095  // "This location information applies to the instructions physically
2096  // following this instruction, up to the first occurrence of any of the
2097  // following: the next end of block, the next OpLine instruction, or the next
2098  // OpNoLine instruction."
2099  if (operands.size() != 3)
2100  return emitError(unknownLoc, "OpLine must have 3 operands");
2101  debugLine = DebugLine{operands[0], operands[1], operands[2]};
2102  return success();
2103 }
2104 
2105 void spirv::Deserializer::clearDebugLine() { debugLine = std::nullopt; }
2106 
2108 spirv::Deserializer::processDebugString(ArrayRef<uint32_t> operands) {
2109  if (operands.size() < 2)
2110  return emitError(unknownLoc, "OpString needs at least 2 operands");
2111 
2112  if (!debugInfoMap.lookup(operands[0]).empty())
2113  return emitError(unknownLoc,
2114  "duplicate debug string found for result <id> ")
2115  << operands[0];
2116 
2117  unsigned wordIndex = 1;
2118  StringRef debugString = decodeStringLiteral(operands, wordIndex);
2119  if (wordIndex != operands.size())
2120  return emitError(unknownLoc,
2121  "unexpected trailing words in OpString instruction");
2122 
2123  debugInfoMap[operands[0]] = debugString;
2124  return success();
2125 }
static bool isFnEntryBlock(Block *block)
Returns true if the given block is a function entry block.
#define MIN_VERSION_CASE(v)
static constexpr const bool value
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
void map(Block *from, Block *to)
Inserts a new mapping for 'from' to 'to'.
Block * lookupOrNull(Block *from) const
Lookup a mapped value within the map.
This class represents an argument of a Block.
Definition: Value.h:296
Location getLoc() const
Return the location for this argument.
Definition: Value.h:311
Block represents an ordered list of Operations.
Definition: Block.h:30
bool empty()
Definition: Block.h:137
Operation & back()
Definition: Block.h:141
void erase()
Unlink this Block from its parent region and delete it.
Definition: Block.cpp:54
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Definition: Block.cpp:291
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:232
void print(raw_ostream &os)
BlockArgListType getArguments()
Definition: Block.h:76
Operation & front()
Definition: Block.h:142
iterator begin()
Definition: Block.h:132
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Definition: Block.cpp:35
void push_back(Operation *op)
Definition: Block.h:138
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:243
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
A symbol reference with a reference path containing a single element.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:64
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:56
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:300
This class helps build Operations.
Definition: Builders.h:198
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:31
bool use_empty()
Returns true if this operation has no uses.
Definition: Operation.h:633
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:324
Operation * clone(BlockAndValueMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Definition: Operation.cpp:558
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Definition: Operation.cpp:49
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:490
MutableArrayRef< BlockOperand > getBlockOperands()
Definition: Operation.h:499
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:300
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:512
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Block & back()
Definition: Region.h:64
BlockListType & getBlocks()
Definition: Region.h:45
BlockListType::iterator iterator
Definition: Region.h:52
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
U cast() const
Definition: Types.h:280
U dyn_cast() const
Definition: Types.h:270
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:89
bool isa() const
Definition: Types.h:260
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
Type getType() const
Return the type of this value.
Definition: Value.h:114
static ArrayType get(Type elementType, unsigned elementCount)
Definition: SPIRVTypes.cpp:50
static CooperativeMatrixNVType get(Type elementType, Scope scope, unsigned rows, unsigned columns)
Definition: SPIRVTypes.cpp:230
LogicalResult deserialize()
Deserializes the remembered SPIR-V binary module.
Deserializer(ArrayRef< uint32_t > binary, MLIRContext *context)
Creates a deserializer for the given SPIR-V binary module.
OwningOpRef< spirv::ModuleOp > collect()
Collects the final SPIR-V ModuleOp.
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
Definition: SPIRVTypes.h:164
static JointMatrixINTELType get(Type elementType, Scope scope, unsigned rows, unsigned columns, MatrixLayout matrixLayout)
Definition: SPIRVTypes.cpp:295
static MatrixType get(Type columnType, uint32_t columnCount)
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:476
static RuntimeArrayType get(Type elementType)
Definition: SPIRVTypes.cpp:533
static SampledImageType get(Type imageType)
Definition: SPIRVTypes.cpp:808
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Construct a literal StructType with at least one member.
static VerCapExtAttr get(Version version, ArrayRef< Capability > capabilities, ArrayRef< Extension > extensions, MLIRContext *context)
Gets a VerCapExtAttr instance.
Include the generated interface declarations.
Definition: CallGraph.h:229
constexpr uint32_t kMagicNumber
SPIR-V magic number.
StringRef decodeStringLiteral(ArrayRef< uint32_t > words, unsigned &wordIndex)
Decodes a string literal in words starting at wordIndex.
constexpr unsigned kHeaderWordCount
SPIR-V binary header word count.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
static std::string debugString(T &&op)
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
This represents an operation in an abstracted form, suitable for use with the builder APIs.