MLIR  20.0.0git
Deserializer.cpp
Go to the documentation of this file.
1 //===- Deserializer.cpp - MLIR SPIR-V Deserializer ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the SPIR-V binary to MLIR SPIR-V module deserializer.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Deserializer.h"
14 
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/IRMapping.h"
21 #include "mlir/IR/Location.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/Sequence.h"
25 #include "llvm/ADT/SmallVector.h"
26 #include "llvm/ADT/StringExtras.h"
27 #include "llvm/ADT/bit.h"
28 #include "llvm/Support/Debug.h"
29 #include "llvm/Support/SaveAndRestore.h"
30 #include "llvm/Support/raw_ostream.h"
31 #include <optional>
32 
33 using namespace mlir;
34 
35 #define DEBUG_TYPE "spirv-deserialization"
36 
37 //===----------------------------------------------------------------------===//
38 // Utility Functions
39 //===----------------------------------------------------------------------===//
40 
41 /// Returns true if the given `block` is a function entry block.
42 static inline bool isFnEntryBlock(Block *block) {
43  return block->isEntryBlock() &&
44  isa_and_nonnull<spirv::FuncOp>(block->getParentOp());
45 }
46 
47 //===----------------------------------------------------------------------===//
48 // Deserializer Method Definitions
49 //===----------------------------------------------------------------------===//
50 
52  MLIRContext *context)
53  : binary(binary), context(context), unknownLoc(UnknownLoc::get(context)),
54  module(createModuleOp()), opBuilder(module->getRegion())
55 #ifndef NDEBUG
56  ,
57  logger(llvm::dbgs())
58 #endif
59 {
60 }
61 
63  LLVM_DEBUG({
64  logger.resetIndent();
65  logger.startLine()
66  << "//+++---------- start deserialization ----------+++//\n";
67  });
68 
69  if (failed(processHeader()))
70  return failure();
71 
72  spirv::Opcode opcode = spirv::Opcode::OpNop;
73  ArrayRef<uint32_t> operands;
74  auto binarySize = binary.size();
75  while (curOffset < binarySize) {
76  // Slice the next instruction out and populate `opcode` and `operands`.
77  // Internally this also updates `curOffset`.
78  if (failed(sliceInstruction(opcode, operands)))
79  return failure();
80 
81  if (failed(processInstruction(opcode, operands)))
82  return failure();
83  }
84 
85  assert(curOffset == binarySize &&
86  "deserializer should never index beyond the binary end");
87 
88  for (auto &deferred : deferredInstructions) {
89  if (failed(processInstruction(deferred.first, deferred.second, false))) {
90  return failure();
91  }
92  }
93 
94  attachVCETriple();
95 
96  LLVM_DEBUG(logger.startLine()
97  << "//+++-------- completed deserialization --------+++//\n");
98  return success();
99 }
100 
102  return std::move(module);
103 }
104 
105 //===----------------------------------------------------------------------===//
106 // Module structure
107 //===----------------------------------------------------------------------===//
108 
109 OwningOpRef<spirv::ModuleOp> spirv::Deserializer::createModuleOp() {
110  OpBuilder builder(context);
111  OperationState state(unknownLoc, spirv::ModuleOp::getOperationName());
112  spirv::ModuleOp::build(builder, state);
113  return cast<spirv::ModuleOp>(Operation::create(state));
114 }
115 
116 LogicalResult spirv::Deserializer::processHeader() {
117  if (binary.size() < spirv::kHeaderWordCount)
118  return emitError(unknownLoc,
119  "SPIR-V binary module must have a 5-word header");
120 
121  if (binary[0] != spirv::kMagicNumber)
122  return emitError(unknownLoc, "incorrect magic number");
123 
124  // Version number bytes: 0 | major number | minor number | 0
125  uint32_t majorVersion = (binary[1] << 8) >> 24;
126  uint32_t minorVersion = (binary[1] << 16) >> 24;
127  if (majorVersion == 1) {
128  switch (minorVersion) {
129 #define MIN_VERSION_CASE(v) \
130  case v: \
131  version = spirv::Version::V_1_##v; \
132  break
133 
134  MIN_VERSION_CASE(0);
135  MIN_VERSION_CASE(1);
136  MIN_VERSION_CASE(2);
137  MIN_VERSION_CASE(3);
138  MIN_VERSION_CASE(4);
139  MIN_VERSION_CASE(5);
140 #undef MIN_VERSION_CASE
141  default:
142  return emitError(unknownLoc, "unsupported SPIR-V minor version: ")
143  << minorVersion;
144  }
145  } else {
146  return emitError(unknownLoc, "unsupported SPIR-V major version: ")
147  << majorVersion;
148  }
149 
150  // TODO: generator number, bound, schema
151  curOffset = spirv::kHeaderWordCount;
152  return success();
153 }
154 
155 LogicalResult
156 spirv::Deserializer::processCapability(ArrayRef<uint32_t> operands) {
157  if (operands.size() != 1)
158  return emitError(unknownLoc, "OpMemoryModel must have one parameter");
159 
160  auto cap = spirv::symbolizeCapability(operands[0]);
161  if (!cap)
162  return emitError(unknownLoc, "unknown capability: ") << operands[0];
163 
164  capabilities.insert(*cap);
165  return success();
166 }
167 
168 LogicalResult spirv::Deserializer::processExtension(ArrayRef<uint32_t> words) {
169  if (words.empty()) {
170  return emitError(
171  unknownLoc,
172  "OpExtension must have a literal string for the extension name");
173  }
174 
175  unsigned wordIndex = 0;
176  StringRef extName = decodeStringLiteral(words, wordIndex);
177  if (wordIndex != words.size())
178  return emitError(unknownLoc,
179  "unexpected trailing words in OpExtension instruction");
180  auto ext = spirv::symbolizeExtension(extName);
181  if (!ext)
182  return emitError(unknownLoc, "unknown extension: ") << extName;
183 
184  extensions.insert(*ext);
185  return success();
186 }
187 
188 LogicalResult
189 spirv::Deserializer::processExtInstImport(ArrayRef<uint32_t> words) {
190  if (words.size() < 2) {
191  return emitError(unknownLoc,
192  "OpExtInstImport must have a result <id> and a literal "
193  "string for the extended instruction set name");
194  }
195 
196  unsigned wordIndex = 1;
197  extendedInstSets[words[0]] = decodeStringLiteral(words, wordIndex);
198  if (wordIndex != words.size()) {
199  return emitError(unknownLoc,
200  "unexpected trailing words in OpExtInstImport");
201  }
202  return success();
203 }
204 
205 void spirv::Deserializer::attachVCETriple() {
206  (*module)->setAttr(
207  spirv::ModuleOp::getVCETripleAttrName(),
208  spirv::VerCapExtAttr::get(version, capabilities.getArrayRef(),
209  extensions.getArrayRef(), context));
210 }
211 
212 LogicalResult
213 spirv::Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
214  if (operands.size() != 2)
215  return emitError(unknownLoc, "OpMemoryModel must have two operands");
216 
217  (*module)->setAttr(
218  module->getAddressingModelAttrName(),
219  opBuilder.getAttr<spirv::AddressingModelAttr>(
220  static_cast<spirv::AddressingModel>(operands.front())));
221 
222  (*module)->setAttr(module->getMemoryModelAttrName(),
223  opBuilder.getAttr<spirv::MemoryModelAttr>(
224  static_cast<spirv::MemoryModel>(operands.back())));
225 
226  return success();
227 }
228 
229 LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
230  // TODO: This function should also be auto-generated. For now, since only a
231  // few decorations are processed/handled in a meaningful manner, going with a
232  // manual implementation.
233  if (words.size() < 2) {
234  return emitError(
235  unknownLoc, "OpDecorate must have at least result <id> and Decoration");
236  }
237  auto decorationName =
238  stringifyDecoration(static_cast<spirv::Decoration>(words[1]));
239  if (decorationName.empty()) {
240  return emitError(unknownLoc, "invalid Decoration code : ") << words[1];
241  }
242  auto symbol = getSymbolDecoration(decorationName);
243  switch (static_cast<spirv::Decoration>(words[1])) {
244  case spirv::Decoration::FPFastMathMode:
245  if (words.size() != 3) {
246  return emitError(unknownLoc, "OpDecorate with ")
247  << decorationName << " needs a single integer literal";
248  }
249  decorations[words[0]].set(
250  symbol, FPFastMathModeAttr::get(opBuilder.getContext(),
251  static_cast<FPFastMathMode>(words[2])));
252  break;
253  case spirv::Decoration::DescriptorSet:
254  case spirv::Decoration::Binding:
255  if (words.size() != 3) {
256  return emitError(unknownLoc, "OpDecorate with ")
257  << decorationName << " needs a single integer literal";
258  }
259  decorations[words[0]].set(
260  symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));
261  break;
262  case spirv::Decoration::BuiltIn:
263  if (words.size() != 3) {
264  return emitError(unknownLoc, "OpDecorate with ")
265  << decorationName << " needs a single integer literal";
266  }
267  decorations[words[0]].set(
268  symbol, opBuilder.getStringAttr(
269  stringifyBuiltIn(static_cast<spirv::BuiltIn>(words[2]))));
270  break;
271  case spirv::Decoration::ArrayStride:
272  if (words.size() != 3) {
273  return emitError(unknownLoc, "OpDecorate with ")
274  << decorationName << " needs a single integer literal";
275  }
276  typeDecorations[words[0]] = words[2];
277  break;
278  case spirv::Decoration::LinkageAttributes: {
279  if (words.size() < 4) {
280  return emitError(unknownLoc, "OpDecorate with ")
281  << decorationName
282  << " needs at least 1 string and 1 integer literal";
283  }
284  // LinkageAttributes has two parameters ["linkageName", linkageType]
285  // e.g., OpDecorate %imported_func LinkageAttributes "outside.func" Import
286  // "linkageName" is a stringliteral encoded as uint32_t,
287  // hence the size of name is variable length which results in words.size()
288  // being variable length, words.size() = 3 + strlen(name)/4 + 1 or
289  // 3 + ceildiv(strlen(name), 4).
290  unsigned wordIndex = 2;
291  auto linkageName = spirv::decodeStringLiteral(words, wordIndex).str();
292  auto linkageTypeAttr = opBuilder.getAttr<::mlir::spirv::LinkageTypeAttr>(
293  static_cast<::mlir::spirv::LinkageType>(words[wordIndex++]));
294  auto linkageAttr = opBuilder.getAttr<::mlir::spirv::LinkageAttributesAttr>(
295  StringAttr::get(context, linkageName), linkageTypeAttr);
296  decorations[words[0]].set(symbol, llvm::dyn_cast<Attribute>(linkageAttr));
297  break;
298  }
299  case spirv::Decoration::Aliased:
300  case spirv::Decoration::AliasedPointer:
301  case spirv::Decoration::Block:
302  case spirv::Decoration::BufferBlock:
303  case spirv::Decoration::Flat:
304  case spirv::Decoration::NonReadable:
305  case spirv::Decoration::NonWritable:
306  case spirv::Decoration::NoPerspective:
307  case spirv::Decoration::NoSignedWrap:
308  case spirv::Decoration::NoUnsignedWrap:
309  case spirv::Decoration::RelaxedPrecision:
310  case spirv::Decoration::Restrict:
311  case spirv::Decoration::RestrictPointer:
312  case spirv::Decoration::NoContraction:
313  if (words.size() != 2) {
314  return emitError(unknownLoc, "OpDecoration with ")
315  << decorationName << "needs a single target <id>";
316  }
317  // Block decoration does not affect spirv.struct type, but is still stored
318  // for verification.
319  // TODO: Update StructType to contain this information since
320  // it is needed for many validation rules.
321  decorations[words[0]].set(symbol, opBuilder.getUnitAttr());
322  break;
323  case spirv::Decoration::Location:
324  case spirv::Decoration::SpecId:
325  if (words.size() != 3) {
326  return emitError(unknownLoc, "OpDecoration with ")
327  << decorationName << "needs a single integer literal";
328  }
329  decorations[words[0]].set(
330  symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));
331  break;
332  default:
333  return emitError(unknownLoc, "unhandled Decoration : '") << decorationName;
334  }
335  return success();
336 }
337 
338 LogicalResult
339 spirv::Deserializer::processMemberDecoration(ArrayRef<uint32_t> words) {
340  // The binary layout of OpMemberDecorate is different comparing to OpDecorate
341  if (words.size() < 3) {
342  return emitError(unknownLoc,
343  "OpMemberDecorate must have at least 3 operands");
344  }
345 
346  auto decoration = static_cast<spirv::Decoration>(words[2]);
347  if (decoration == spirv::Decoration::Offset && words.size() != 4) {
348  return emitError(unknownLoc,
349  " missing offset specification in OpMemberDecorate with "
350  "Offset decoration");
351  }
352  ArrayRef<uint32_t> decorationOperands;
353  if (words.size() > 3) {
354  decorationOperands = words.slice(3);
355  }
356  memberDecorationMap[words[0]][words[1]][decoration] = decorationOperands;
357  return success();
358 }
359 
360 LogicalResult spirv::Deserializer::processMemberName(ArrayRef<uint32_t> words) {
361  if (words.size() < 3) {
362  return emitError(unknownLoc, "OpMemberName must have at least 3 operands");
363  }
364  unsigned wordIndex = 2;
365  auto name = decodeStringLiteral(words, wordIndex);
366  if (wordIndex != words.size()) {
367  return emitError(unknownLoc,
368  "unexpected trailing words in OpMemberName instruction");
369  }
370  memberNameMap[words[0]][words[1]] = name;
371  return success();
372 }
373 
374 LogicalResult spirv::Deserializer::setFunctionArgAttrs(
375  uint32_t argID, SmallVectorImpl<Attribute> &argAttrs, size_t argIndex) {
376  if (!decorations.contains(argID)) {
377  argAttrs[argIndex] = DictionaryAttr::get(context, {});
378  return success();
379  }
380 
381  spirv::DecorationAttr foundDecorationAttr;
382  for (NamedAttribute decAttr : decorations[argID]) {
383  for (auto decoration :
384  {spirv::Decoration::Aliased, spirv::Decoration::Restrict,
385  spirv::Decoration::AliasedPointer,
386  spirv::Decoration::RestrictPointer}) {
387 
388  if (decAttr.getName() !=
389  getSymbolDecoration(stringifyDecoration(decoration)))
390  continue;
391 
392  if (foundDecorationAttr)
393  return emitError(unknownLoc,
394  "more than one Aliased/Restrict decorations for "
395  "function argument with result <id> ")
396  << argID;
397 
398  foundDecorationAttr = spirv::DecorationAttr::get(context, decoration);
399  break;
400  }
401  }
402 
403  if (!foundDecorationAttr)
404  return emitError(unknownLoc, "unimplemented decoration support for "
405  "function argument with result <id> ")
406  << argID;
407 
408  NamedAttribute attr(StringAttr::get(context, spirv::DecorationAttr::name),
409  foundDecorationAttr);
410  argAttrs[argIndex] = DictionaryAttr::get(context, attr);
411  return success();
412 }
413 
414 LogicalResult
415 spirv::Deserializer::processFunction(ArrayRef<uint32_t> operands) {
416  if (curFunction) {
417  return emitError(unknownLoc, "found function inside function");
418  }
419 
420  // Get the result type
421  if (operands.size() != 4) {
422  return emitError(unknownLoc, "OpFunction must have 4 parameters");
423  }
424  Type resultType = getType(operands[0]);
425  if (!resultType) {
426  return emitError(unknownLoc, "undefined result type from <id> ")
427  << operands[0];
428  }
429 
430  uint32_t fnID = operands[1];
431  if (funcMap.count(fnID)) {
432  return emitError(unknownLoc, "duplicate function definition/declaration");
433  }
434 
435  auto fnControl = spirv::symbolizeFunctionControl(operands[2]);
436  if (!fnControl) {
437  return emitError(unknownLoc, "unknown Function Control: ") << operands[2];
438  }
439 
440  Type fnType = getType(operands[3]);
441  if (!fnType || !isa<FunctionType>(fnType)) {
442  return emitError(unknownLoc, "unknown function type from <id> ")
443  << operands[3];
444  }
445  auto functionType = cast<FunctionType>(fnType);
446 
447  if ((isVoidType(resultType) && functionType.getNumResults() != 0) ||
448  (functionType.getNumResults() == 1 &&
449  functionType.getResult(0) != resultType)) {
450  return emitError(unknownLoc, "mismatch in function type ")
451  << functionType << " and return type " << resultType << " specified";
452  }
453 
454  std::string fnName = getFunctionSymbol(fnID);
455  auto funcOp = opBuilder.create<spirv::FuncOp>(
456  unknownLoc, fnName, functionType, fnControl.value());
457  // Processing other function attributes.
458  if (decorations.count(fnID)) {
459  for (auto attr : decorations[fnID].getAttrs()) {
460  funcOp->setAttr(attr.getName(), attr.getValue());
461  }
462  }
463  curFunction = funcMap[fnID] = funcOp;
464  auto *entryBlock = funcOp.addEntryBlock();
465  LLVM_DEBUG({
466  logger.startLine()
467  << "//===-------------------------------------------===//\n";
468  logger.startLine() << "[fn] name: " << fnName << "\n";
469  logger.startLine() << "[fn] type: " << fnType << "\n";
470  logger.startLine() << "[fn] ID: " << fnID << "\n";
471  logger.startLine() << "[fn] entry block: " << entryBlock << "\n";
472  logger.indent();
473  });
474 
475  SmallVector<Attribute> argAttrs;
476  argAttrs.resize(functionType.getNumInputs());
477 
478  // Parse the op argument instructions
479  if (functionType.getNumInputs()) {
480  for (size_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
481  auto argType = functionType.getInput(i);
482  spirv::Opcode opcode = spirv::Opcode::OpNop;
483  ArrayRef<uint32_t> operands;
484  if (failed(sliceInstruction(opcode, operands,
485  spirv::Opcode::OpFunctionParameter))) {
486  return failure();
487  }
488  if (opcode != spirv::Opcode::OpFunctionParameter) {
489  return emitError(
490  unknownLoc,
491  "missing OpFunctionParameter instruction for argument ")
492  << i;
493  }
494  if (operands.size() != 2) {
495  return emitError(
496  unknownLoc,
497  "expected result type and result <id> for OpFunctionParameter");
498  }
499  auto argDefinedType = getType(operands[0]);
500  if (!argDefinedType || argDefinedType != argType) {
501  return emitError(unknownLoc,
502  "mismatch in argument type between function type "
503  "definition ")
504  << functionType << " and argument type definition "
505  << argDefinedType << " at argument " << i;
506  }
507  if (getValue(operands[1])) {
508  return emitError(unknownLoc, "duplicate definition of result <id> ")
509  << operands[1];
510  }
511  if (failed(setFunctionArgAttrs(operands[1], argAttrs, i))) {
512  return failure();
513  }
514 
515  auto argValue = funcOp.getArgument(i);
516  valueMap[operands[1]] = argValue;
517  }
518  }
519 
520  if (llvm::any_of(argAttrs, [](Attribute attr) {
521  auto argAttr = cast<DictionaryAttr>(attr);
522  return !argAttr.empty();
523  }))
524  funcOp.setArgAttrsAttr(ArrayAttr::get(context, argAttrs));
525 
526  // entryBlock is needed to access the arguments, Once that is done, we can
527  // erase the block for functions with 'Import' LinkageAttributes, since these
528  // are essentially function declarations, so they have no body.
529  auto linkageAttr = funcOp.getLinkageAttributes();
530  auto hasImportLinkage =
531  linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
532  spirv::LinkageType::Import);
533  if (hasImportLinkage)
534  funcOp.eraseBody();
535 
536  // RAII guard to reset the insertion point to the module's region after
537  // deserializing the body of this function.
538  OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
539 
540  spirv::Opcode opcode = spirv::Opcode::OpNop;
541  ArrayRef<uint32_t> instOperands;
542 
543  // Special handling for the entry block. We need to make sure it starts with
544  // an OpLabel instruction. The entry block takes the same parameters as the
545  // function. All other blocks do not take any parameter. We have already
546  // created the entry block, here we need to register it to the correct label
547  // <id>.
548  if (failed(sliceInstruction(opcode, instOperands,
549  spirv::Opcode::OpFunctionEnd))) {
550  return failure();
551  }
552  if (opcode == spirv::Opcode::OpFunctionEnd) {
553  return processFunctionEnd(instOperands);
554  }
555  if (opcode != spirv::Opcode::OpLabel) {
556  return emitError(unknownLoc, "a basic block must start with OpLabel");
557  }
558  if (instOperands.size() != 1) {
559  return emitError(unknownLoc, "OpLabel should only have result <id>");
560  }
561  blockMap[instOperands[0]] = entryBlock;
562  if (failed(processLabel(instOperands))) {
563  return failure();
564  }
565 
566  // Then process all the other instructions in the function until we hit
567  // OpFunctionEnd.
568  while (succeeded(sliceInstruction(opcode, instOperands,
569  spirv::Opcode::OpFunctionEnd)) &&
570  opcode != spirv::Opcode::OpFunctionEnd) {
571  if (failed(processInstruction(opcode, instOperands))) {
572  return failure();
573  }
574  }
575  if (opcode != spirv::Opcode::OpFunctionEnd) {
576  return failure();
577  }
578 
579  return processFunctionEnd(instOperands);
580 }
581 
582 LogicalResult
583 spirv::Deserializer::processFunctionEnd(ArrayRef<uint32_t> operands) {
584  // Process OpFunctionEnd.
585  if (!operands.empty()) {
586  return emitError(unknownLoc, "unexpected operands for OpFunctionEnd");
587  }
588 
589  // Wire up block arguments from OpPhi instructions.
590  // Put all structured control flow in spirv.mlir.selection/spirv.mlir.loop
591  // ops.
592  if (failed(wireUpBlockArgument()) || failed(structurizeControlFlow())) {
593  return failure();
594  }
595 
596  curBlock = nullptr;
597  curFunction = std::nullopt;
598 
599  LLVM_DEBUG({
600  logger.unindent();
601  logger.startLine()
602  << "//===-------------------------------------------===//\n";
603  });
604  return success();
605 }
606 
607 std::optional<std::pair<Attribute, Type>>
608 spirv::Deserializer::getConstant(uint32_t id) {
609  auto constIt = constantMap.find(id);
610  if (constIt == constantMap.end())
611  return std::nullopt;
612  return constIt->getSecond();
613 }
614 
615 std::optional<spirv::SpecConstOperationMaterializationInfo>
616 spirv::Deserializer::getSpecConstantOperation(uint32_t id) {
617  auto constIt = specConstOperationMap.find(id);
618  if (constIt == specConstOperationMap.end())
619  return std::nullopt;
620  return constIt->getSecond();
621 }
622 
623 std::string spirv::Deserializer::getFunctionSymbol(uint32_t id) {
624  auto funcName = nameMap.lookup(id).str();
625  if (funcName.empty()) {
626  funcName = "spirv_fn_" + std::to_string(id);
627  }
628  return funcName;
629 }
630 
631 std::string spirv::Deserializer::getSpecConstantSymbol(uint32_t id) {
632  auto constName = nameMap.lookup(id).str();
633  if (constName.empty()) {
634  constName = "spirv_spec_const_" + std::to_string(id);
635  }
636  return constName;
637 }
638 
639 spirv::SpecConstantOp
640 spirv::Deserializer::createSpecConstant(Location loc, uint32_t resultID,
641  TypedAttr defaultValue) {
642  auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
643  auto op = opBuilder.create<spirv::SpecConstantOp>(unknownLoc, symName,
644  defaultValue);
645  if (decorations.count(resultID)) {
646  for (auto attr : decorations[resultID].getAttrs())
647  op->setAttr(attr.getName(), attr.getValue());
648  }
649  specConstMap[resultID] = op;
650  return op;
651 }
652 
653 LogicalResult
654 spirv::Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
655  unsigned wordIndex = 0;
656  if (operands.size() < 3) {
657  return emitError(
658  unknownLoc,
659  "OpVariable needs at least 3 operands, type, <id> and storage class");
660  }
661 
662  // Result Type.
663  auto type = getType(operands[wordIndex]);
664  if (!type) {
665  return emitError(unknownLoc, "unknown result type <id> : ")
666  << operands[wordIndex];
667  }
668  auto ptrType = dyn_cast<spirv::PointerType>(type);
669  if (!ptrType) {
670  return emitError(unknownLoc,
671  "expected a result type <id> to be a spirv.ptr, found : ")
672  << type;
673  }
674  wordIndex++;
675 
676  // Result <id>.
677  auto variableID = operands[wordIndex];
678  auto variableName = nameMap.lookup(variableID).str();
679  if (variableName.empty()) {
680  variableName = "spirv_var_" + std::to_string(variableID);
681  }
682  wordIndex++;
683 
684  // Storage class.
685  auto storageClass = static_cast<spirv::StorageClass>(operands[wordIndex]);
686  if (ptrType.getStorageClass() != storageClass) {
687  return emitError(unknownLoc, "mismatch in storage class of pointer type ")
688  << type << " and that specified in OpVariable instruction : "
689  << stringifyStorageClass(storageClass);
690  }
691  wordIndex++;
692 
693  // Initializer.
694  FlatSymbolRefAttr initializer = nullptr;
695 
696  if (wordIndex < operands.size()) {
697  Operation *op = nullptr;
698 
699  if (auto initOp = getGlobalVariable(operands[wordIndex]))
700  op = initOp;
701  else if (auto initOp = getSpecConstant(operands[wordIndex]))
702  op = initOp;
703  else if (auto initOp = getSpecConstantComposite(operands[wordIndex]))
704  op = initOp;
705  else
706  return emitError(unknownLoc, "unknown <id> ")
707  << operands[wordIndex] << "used as initializer";
708 
709  initializer = SymbolRefAttr::get(op);
710  wordIndex++;
711  }
712  if (wordIndex != operands.size()) {
713  return emitError(unknownLoc,
714  "found more operands than expected when deserializing "
715  "OpVariable instruction, only ")
716  << wordIndex << " of " << operands.size() << " processed";
717  }
718  auto loc = createFileLineColLoc(opBuilder);
719  auto varOp = opBuilder.create<spirv::GlobalVariableOp>(
720  loc, TypeAttr::get(type), opBuilder.getStringAttr(variableName),
721  initializer);
722 
723  // Decorations.
724  if (decorations.count(variableID)) {
725  for (auto attr : decorations[variableID].getAttrs())
726  varOp->setAttr(attr.getName(), attr.getValue());
727  }
728  globalVariableMap[variableID] = varOp;
729  return success();
730 }
731 
732 IntegerAttr spirv::Deserializer::getConstantInt(uint32_t id) {
733  auto constInfo = getConstant(id);
734  if (!constInfo) {
735  return nullptr;
736  }
737  return dyn_cast<IntegerAttr>(constInfo->first);
738 }
739 
740 LogicalResult spirv::Deserializer::processName(ArrayRef<uint32_t> operands) {
741  if (operands.size() < 2) {
742  return emitError(unknownLoc, "OpName needs at least 2 operands");
743  }
744  if (!nameMap.lookup(operands[0]).empty()) {
745  return emitError(unknownLoc, "duplicate name found for result <id> ")
746  << operands[0];
747  }
748  unsigned wordIndex = 1;
749  StringRef name = decodeStringLiteral(operands, wordIndex);
750  if (wordIndex != operands.size()) {
751  return emitError(unknownLoc,
752  "unexpected trailing words in OpName instruction");
753  }
754  nameMap[operands[0]] = name;
755  return success();
756 }
757 
758 //===----------------------------------------------------------------------===//
759 // Type
760 //===----------------------------------------------------------------------===//
761 
762 LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
763  ArrayRef<uint32_t> operands) {
764  if (operands.empty()) {
765  return emitError(unknownLoc, "type instruction with opcode ")
766  << spirv::stringifyOpcode(opcode) << " needs at least one <id>";
767  }
768 
769  /// TODO: Types might be forward declared in some instructions and need to be
770  /// handled appropriately.
771  if (typeMap.count(operands[0])) {
772  return emitError(unknownLoc, "duplicate definition for result <id> ")
773  << operands[0];
774  }
775 
776  switch (opcode) {
777  case spirv::Opcode::OpTypeVoid:
778  if (operands.size() != 1)
779  return emitError(unknownLoc, "OpTypeVoid must have no parameters");
780  typeMap[operands[0]] = opBuilder.getNoneType();
781  break;
782  case spirv::Opcode::OpTypeBool:
783  if (operands.size() != 1)
784  return emitError(unknownLoc, "OpTypeBool must have no parameters");
785  typeMap[operands[0]] = opBuilder.getI1Type();
786  break;
787  case spirv::Opcode::OpTypeInt: {
788  if (operands.size() != 3)
789  return emitError(
790  unknownLoc, "OpTypeInt must have bitwidth and signedness parameters");
791 
792  // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics
793  // to preserve or validate.
794  // 0 indicates unsigned, or no signedness semantics
795  // 1 indicates signed semantics."
796  //
797  // So we cannot differentiate signless and unsigned integers; always use
798  // signless semantics for such cases.
799  auto sign = operands[2] == 1 ? IntegerType::SignednessSemantics::Signed
800  : IntegerType::SignednessSemantics::Signless;
801  typeMap[operands[0]] = IntegerType::get(context, operands[1], sign);
802  } break;
803  case spirv::Opcode::OpTypeFloat: {
804  if (operands.size() != 2)
805  return emitError(unknownLoc, "OpTypeFloat must have bitwidth parameter");
806 
807  Type floatTy;
808  switch (operands[1]) {
809  case 16:
810  floatTy = opBuilder.getF16Type();
811  break;
812  case 32:
813  floatTy = opBuilder.getF32Type();
814  break;
815  case 64:
816  floatTy = opBuilder.getF64Type();
817  break;
818  default:
819  return emitError(unknownLoc, "unsupported OpTypeFloat bitwidth: ")
820  << operands[1];
821  }
822  typeMap[operands[0]] = floatTy;
823  } break;
824  case spirv::Opcode::OpTypeVector: {
825  if (operands.size() != 3) {
826  return emitError(
827  unknownLoc,
828  "OpTypeVector must have element type and count parameters");
829  }
830  Type elementTy = getType(operands[1]);
831  if (!elementTy) {
832  return emitError(unknownLoc, "OpTypeVector references undefined <id> ")
833  << operands[1];
834  }
835  typeMap[operands[0]] = VectorType::get({operands[2]}, elementTy);
836  } break;
837  case spirv::Opcode::OpTypePointer: {
838  return processOpTypePointer(operands);
839  } break;
840  case spirv::Opcode::OpTypeArray:
841  return processArrayType(operands);
842  case spirv::Opcode::OpTypeCooperativeMatrixKHR:
843  return processCooperativeMatrixTypeKHR(operands);
844  case spirv::Opcode::OpTypeFunction:
845  return processFunctionType(operands);
846  case spirv::Opcode::OpTypeJointMatrixINTEL:
847  return processJointMatrixType(operands);
848  case spirv::Opcode::OpTypeImage:
849  return processImageType(operands);
850  case spirv::Opcode::OpTypeSampledImage:
851  return processSampledImageType(operands);
852  case spirv::Opcode::OpTypeRuntimeArray:
853  return processRuntimeArrayType(operands);
854  case spirv::Opcode::OpTypeStruct:
855  return processStructType(operands);
856  case spirv::Opcode::OpTypeMatrix:
857  return processMatrixType(operands);
858  default:
859  return emitError(unknownLoc, "unhandled type instruction");
860  }
861  return success();
862 }
863 
864 LogicalResult
865 spirv::Deserializer::processOpTypePointer(ArrayRef<uint32_t> operands) {
866  if (operands.size() != 3)
867  return emitError(unknownLoc, "OpTypePointer must have two parameters");
868 
869  auto pointeeType = getType(operands[2]);
870  if (!pointeeType)
871  return emitError(unknownLoc, "unknown OpTypePointer pointee type <id> ")
872  << operands[2];
873 
874  uint32_t typePointerID = operands[0];
875  auto storageClass = static_cast<spirv::StorageClass>(operands[1]);
876  typeMap[typePointerID] = spirv::PointerType::get(pointeeType, storageClass);
877 
878  for (auto *deferredStructIt = std::begin(deferredStructTypesInfos);
879  deferredStructIt != std::end(deferredStructTypesInfos);) {
880  for (auto *unresolvedMemberIt =
881  std::begin(deferredStructIt->unresolvedMemberTypes);
882  unresolvedMemberIt !=
883  std::end(deferredStructIt->unresolvedMemberTypes);) {
884  if (unresolvedMemberIt->first == typePointerID) {
885  // The newly constructed pointer type can resolve one of the
886  // deferred struct type members; update the memberTypes list and
887  // clean the unresolvedMemberTypes list accordingly.
888  deferredStructIt->memberTypes[unresolvedMemberIt->second] =
889  typeMap[typePointerID];
890  unresolvedMemberIt =
891  deferredStructIt->unresolvedMemberTypes.erase(unresolvedMemberIt);
892  } else {
893  ++unresolvedMemberIt;
894  }
895  }
896 
897  if (deferredStructIt->unresolvedMemberTypes.empty()) {
898  // All deferred struct type members are now resolved, set the struct body.
899  auto structType = deferredStructIt->deferredStructType;
900 
901  assert(structType && "expected a spirv::StructType");
902  assert(structType.isIdentified() && "expected an indentified struct");
903 
904  if (failed(structType.trySetBody(
905  deferredStructIt->memberTypes, deferredStructIt->offsetInfo,
906  deferredStructIt->memberDecorationsInfo)))
907  return failure();
908 
909  deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt);
910  } else {
911  ++deferredStructIt;
912  }
913  }
914 
915  return success();
916 }
917 
918 LogicalResult
919 spirv::Deserializer::processArrayType(ArrayRef<uint32_t> operands) {
920  if (operands.size() != 3) {
921  return emitError(unknownLoc,
922  "OpTypeArray must have element type and count parameters");
923  }
924 
925  Type elementTy = getType(operands[1]);
926  if (!elementTy) {
927  return emitError(unknownLoc, "OpTypeArray references undefined <id> ")
928  << operands[1];
929  }
930 
931  unsigned count = 0;
932  // TODO: The count can also come frome a specialization constant.
933  auto countInfo = getConstant(operands[2]);
934  if (!countInfo) {
935  return emitError(unknownLoc, "OpTypeArray count <id> ")
936  << operands[2] << "can only come from normal constant right now";
937  }
938 
939  if (auto intVal = dyn_cast<IntegerAttr>(countInfo->first)) {
940  count = intVal.getValue().getZExtValue();
941  } else {
942  return emitError(unknownLoc, "OpTypeArray count must come from a "
943  "scalar integer constant instruction");
944  }
945 
946  typeMap[operands[0]] = spirv::ArrayType::get(
947  elementTy, count, typeDecorations.lookup(operands[0]));
948  return success();
949 }
950 
951 LogicalResult
952 spirv::Deserializer::processFunctionType(ArrayRef<uint32_t> operands) {
953  assert(!operands.empty() && "No operands for processing function type");
954  if (operands.size() == 1) {
955  return emitError(unknownLoc, "missing return type for OpTypeFunction");
956  }
957  auto returnType = getType(operands[1]);
958  if (!returnType) {
959  return emitError(unknownLoc, "unknown return type in OpTypeFunction");
960  }
961  SmallVector<Type, 1> argTypes;
962  for (size_t i = 2, e = operands.size(); i < e; ++i) {
963  auto ty = getType(operands[i]);
964  if (!ty) {
965  return emitError(unknownLoc, "unknown argument type in OpTypeFunction");
966  }
967  argTypes.push_back(ty);
968  }
969  ArrayRef<Type> returnTypes;
970  if (!isVoidType(returnType)) {
971  returnTypes = llvm::ArrayRef(returnType);
972  }
973  typeMap[operands[0]] = FunctionType::get(context, argTypes, returnTypes);
974  return success();
975 }
976 
977 LogicalResult spirv::Deserializer::processCooperativeMatrixTypeKHR(
978  ArrayRef<uint32_t> operands) {
979  if (operands.size() != 6) {
980  return emitError(unknownLoc,
981  "OpTypeCooperativeMatrixKHR must have element type, "
982  "scope, row and column parameters, and use");
983  }
984 
985  Type elementTy = getType(operands[1]);
986  if (!elementTy) {
987  return emitError(unknownLoc,
988  "OpTypeCooperativeMatrixKHR references undefined <id> ")
989  << operands[1];
990  }
991 
992  std::optional<spirv::Scope> scope =
993  spirv::symbolizeScope(getConstantInt(operands[2]).getInt());
994  if (!scope) {
995  return emitError(
996  unknownLoc,
997  "OpTypeCooperativeMatrixKHR references undefined scope <id> ")
998  << operands[2];
999  }
1000 
1001  unsigned rows = getConstantInt(operands[3]).getInt();
1002  unsigned columns = getConstantInt(operands[4]).getInt();
1003 
1004  std::optional<spirv::CooperativeMatrixUseKHR> use =
1005  spirv::symbolizeCooperativeMatrixUseKHR(
1006  getConstantInt(operands[5]).getInt());
1007  if (!use) {
1008  return emitError(
1009  unknownLoc,
1010  "OpTypeCooperativeMatrixKHR references undefined use <id> ")
1011  << operands[5];
1012  }
1013 
1014  typeMap[operands[0]] =
1015  spirv::CooperativeMatrixType::get(elementTy, rows, columns, *scope, *use);
1016  return success();
1017 }
1018 
1019 LogicalResult
1020 spirv::Deserializer::processJointMatrixType(ArrayRef<uint32_t> operands) {
1021  if (operands.size() != 6) {
1022  return emitError(unknownLoc, "OpTypeJointMatrix must have element "
1023  "type and row x column parameters");
1024  }
1025 
1026  Type elementTy = getType(operands[1]);
1027  if (!elementTy) {
1028  return emitError(unknownLoc, "OpTypeJointMatrix references undefined <id> ")
1029  << operands[1];
1030  }
1031 
1032  auto scope = spirv::symbolizeScope(getConstantInt(operands[5]).getInt());
1033  if (!scope) {
1034  return emitError(unknownLoc,
1035  "OpTypeJointMatrix references undefined scope <id> ")
1036  << operands[5];
1037  }
1038  auto matrixLayout =
1039  spirv::symbolizeMatrixLayout(getConstantInt(operands[4]).getInt());
1040  if (!matrixLayout) {
1041  return emitError(unknownLoc,
1042  "OpTypeJointMatrix references undefined scope <id> ")
1043  << operands[4];
1044  }
1045  unsigned rows = getConstantInt(operands[2]).getInt();
1046  unsigned columns = getConstantInt(operands[3]).getInt();
1047 
1048  typeMap[operands[0]] = spirv::JointMatrixINTELType::get(
1049  elementTy, scope.value(), rows, columns, matrixLayout.value());
1050  return success();
1051 }
1052 
1053 LogicalResult
1054 spirv::Deserializer::processRuntimeArrayType(ArrayRef<uint32_t> operands) {
1055  if (operands.size() != 2) {
1056  return emitError(unknownLoc, "OpTypeRuntimeArray must have two operands");
1057  }
1058  Type memberType = getType(operands[1]);
1059  if (!memberType) {
1060  return emitError(unknownLoc,
1061  "OpTypeRuntimeArray references undefined <id> ")
1062  << operands[1];
1063  }
1064  typeMap[operands[0]] = spirv::RuntimeArrayType::get(
1065  memberType, typeDecorations.lookup(operands[0]));
1066  return success();
1067 }
1068 
1069 LogicalResult
1070 spirv::Deserializer::processStructType(ArrayRef<uint32_t> operands) {
1071  // TODO: Find a way to handle identified structs when debug info is stripped.
1072 
1073  if (operands.empty()) {
1074  return emitError(unknownLoc, "OpTypeStruct must have at least result <id>");
1075  }
1076 
1077  if (operands.size() == 1) {
1078  // Handle empty struct.
1079  typeMap[operands[0]] =
1080  spirv::StructType::getEmpty(context, nameMap.lookup(operands[0]).str());
1081  return success();
1082  }
1083 
1084  // First element is operand ID, second element is member index in the struct.
1085  SmallVector<std::pair<uint32_t, unsigned>, 0> unresolvedMemberTypes;
1086  SmallVector<Type, 4> memberTypes;
1087 
1088  for (auto op : llvm::drop_begin(operands, 1)) {
1089  Type memberType = getType(op);
1090  bool typeForwardPtr = (typeForwardPointerIDs.count(op) != 0);
1091 
1092  if (!memberType && !typeForwardPtr)
1093  return emitError(unknownLoc, "OpTypeStruct references undefined <id> ")
1094  << op;
1095 
1096  if (!memberType)
1097  unresolvedMemberTypes.emplace_back(op, memberTypes.size());
1098 
1099  memberTypes.push_back(memberType);
1100  }
1101 
1104  if (memberDecorationMap.count(operands[0])) {
1105  auto &allMemberDecorations = memberDecorationMap[operands[0]];
1106  for (auto memberIndex : llvm::seq<uint32_t>(0, memberTypes.size())) {
1107  if (allMemberDecorations.count(memberIndex)) {
1108  for (auto &memberDecoration : allMemberDecorations[memberIndex]) {
1109  // Check for offset.
1110  if (memberDecoration.first == spirv::Decoration::Offset) {
1111  // If offset info is empty, resize to the number of members;
1112  if (offsetInfo.empty()) {
1113  offsetInfo.resize(memberTypes.size());
1114  }
1115  offsetInfo[memberIndex] = memberDecoration.second[0];
1116  } else {
1117  if (!memberDecoration.second.empty()) {
1118  memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/1,
1119  memberDecoration.first,
1120  memberDecoration.second[0]);
1121  } else {
1122  memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/0,
1123  memberDecoration.first, 0);
1124  }
1125  }
1126  }
1127  }
1128  }
1129  }
1130 
1131  uint32_t structID = operands[0];
1132  std::string structIdentifier = nameMap.lookup(structID).str();
1133 
1134  if (structIdentifier.empty()) {
1135  assert(unresolvedMemberTypes.empty() &&
1136  "didn't expect unresolved member types");
1137  typeMap[structID] =
1138  spirv::StructType::get(memberTypes, offsetInfo, memberDecorationsInfo);
1139  } else {
1140  auto structTy = spirv::StructType::getIdentified(context, structIdentifier);
1141  typeMap[structID] = structTy;
1142 
1143  if (!unresolvedMemberTypes.empty())
1144  deferredStructTypesInfos.push_back({structTy, unresolvedMemberTypes,
1145  memberTypes, offsetInfo,
1146  memberDecorationsInfo});
1147  else if (failed(structTy.trySetBody(memberTypes, offsetInfo,
1148  memberDecorationsInfo)))
1149  return failure();
1150  }
1151 
1152  // TODO: Update StructType to have member name as attribute as
1153  // well.
1154  return success();
1155 }
1156 
1157 LogicalResult
1158 spirv::Deserializer::processMatrixType(ArrayRef<uint32_t> operands) {
1159  if (operands.size() != 3) {
1160  // Three operands are needed: result_id, column_type, and column_count
1161  return emitError(unknownLoc, "OpTypeMatrix must have 3 operands"
1162  " (result_id, column_type, and column_count)");
1163  }
1164  // Matrix columns must be of vector type
1165  Type elementTy = getType(operands[1]);
1166  if (!elementTy) {
1167  return emitError(unknownLoc,
1168  "OpTypeMatrix references undefined column type.")
1169  << operands[1];
1170  }
1171 
1172  uint32_t colsCount = operands[2];
1173  typeMap[operands[0]] = spirv::MatrixType::get(elementTy, colsCount);
1174  return success();
1175 }
1176 
1177 LogicalResult
1178 spirv::Deserializer::processTypeForwardPointer(ArrayRef<uint32_t> operands) {
1179  if (operands.size() != 2)
1180  return emitError(unknownLoc,
1181  "OpTypeForwardPointer instruction must have two operands");
1182 
1183  typeForwardPointerIDs.insert(operands[0]);
1184  // TODO: Use the 2nd operand (Storage Class) to validate the OpTypePointer
1185  // instruction that defines the actual type.
1186 
1187  return success();
1188 }
1189 
1190 LogicalResult
1191 spirv::Deserializer::processImageType(ArrayRef<uint32_t> operands) {
1192  // TODO: Add support for Access Qualifier.
1193  if (operands.size() != 8)
1194  return emitError(
1195  unknownLoc,
1196  "OpTypeImage with non-eight operands are not supported yet");
1197 
1198  Type elementTy = getType(operands[1]);
1199  if (!elementTy)
1200  return emitError(unknownLoc, "OpTypeImage references undefined <id>: ")
1201  << operands[1];
1202 
1203  auto dim = spirv::symbolizeDim(operands[2]);
1204  if (!dim)
1205  return emitError(unknownLoc, "unknown Dim for OpTypeImage: ")
1206  << operands[2];
1207 
1208  auto depthInfo = spirv::symbolizeImageDepthInfo(operands[3]);
1209  if (!depthInfo)
1210  return emitError(unknownLoc, "unknown Depth for OpTypeImage: ")
1211  << operands[3];
1212 
1213  auto arrayedInfo = spirv::symbolizeImageArrayedInfo(operands[4]);
1214  if (!arrayedInfo)
1215  return emitError(unknownLoc, "unknown Arrayed for OpTypeImage: ")
1216  << operands[4];
1217 
1218  auto samplingInfo = spirv::symbolizeImageSamplingInfo(operands[5]);
1219  if (!samplingInfo)
1220  return emitError(unknownLoc, "unknown MS for OpTypeImage: ") << operands[5];
1221 
1222  auto samplerUseInfo = spirv::symbolizeImageSamplerUseInfo(operands[6]);
1223  if (!samplerUseInfo)
1224  return emitError(unknownLoc, "unknown Sampled for OpTypeImage: ")
1225  << operands[6];
1226 
1227  auto format = spirv::symbolizeImageFormat(operands[7]);
1228  if (!format)
1229  return emitError(unknownLoc, "unknown Format for OpTypeImage: ")
1230  << operands[7];
1231 
1232  typeMap[operands[0]] = spirv::ImageType::get(
1233  elementTy, dim.value(), depthInfo.value(), arrayedInfo.value(),
1234  samplingInfo.value(), samplerUseInfo.value(), format.value());
1235  return success();
1236 }
1237 
1238 LogicalResult
1239 spirv::Deserializer::processSampledImageType(ArrayRef<uint32_t> operands) {
1240  if (operands.size() != 2)
1241  return emitError(unknownLoc, "OpTypeSampledImage must have two operands");
1242 
1243  Type elementTy = getType(operands[1]);
1244  if (!elementTy)
1245  return emitError(unknownLoc,
1246  "OpTypeSampledImage references undefined <id>: ")
1247  << operands[1];
1248 
1249  typeMap[operands[0]] = spirv::SampledImageType::get(elementTy);
1250  return success();
1251 }
1252 
1253 //===----------------------------------------------------------------------===//
1254 // Constant
1255 //===----------------------------------------------------------------------===//
1256 
1257 LogicalResult spirv::Deserializer::processConstant(ArrayRef<uint32_t> operands,
1258  bool isSpec) {
1259  StringRef opname = isSpec ? "OpSpecConstant" : "OpConstant";
1260 
1261  if (operands.size() < 2) {
1262  return emitError(unknownLoc)
1263  << opname << " must have type <id> and result <id>";
1264  }
1265  if (operands.size() < 3) {
1266  return emitError(unknownLoc)
1267  << opname << " must have at least 1 more parameter";
1268  }
1269 
1270  Type resultType = getType(operands[0]);
1271  if (!resultType) {
1272  return emitError(unknownLoc, "undefined result type from <id> ")
1273  << operands[0];
1274  }
1275 
1276  auto checkOperandSizeForBitwidth = [&](unsigned bitwidth) -> LogicalResult {
1277  if (bitwidth == 64) {
1278  if (operands.size() == 4) {
1279  return success();
1280  }
1281  return emitError(unknownLoc)
1282  << opname << " should have 2 parameters for 64-bit values";
1283  }
1284  if (bitwidth <= 32) {
1285  if (operands.size() == 3) {
1286  return success();
1287  }
1288 
1289  return emitError(unknownLoc)
1290  << opname
1291  << " should have 1 parameter for values with no more than 32 bits";
1292  }
1293  return emitError(unknownLoc, "unsupported OpConstant bitwidth: ")
1294  << bitwidth;
1295  };
1296 
1297  auto resultID = operands[1];
1298 
1299  if (auto intType = dyn_cast<IntegerType>(resultType)) {
1300  auto bitwidth = intType.getWidth();
1301  if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1302  return failure();
1303  }
1304 
1305  APInt value;
1306  if (bitwidth == 64) {
1307  // 64-bit integers are represented with two SPIR-V words. According to
1308  // SPIR-V spec: "When the type’s bit width is larger than one word, the
1309  // literal’s low-order words appear first."
1310  struct DoubleWord {
1311  uint32_t word1;
1312  uint32_t word2;
1313  } words = {operands[2], operands[3]};
1314  value = APInt(64, llvm::bit_cast<uint64_t>(words), /*isSigned=*/true);
1315  } else if (bitwidth <= 32) {
1316  value = APInt(bitwidth, operands[2], /*isSigned=*/true);
1317  }
1318 
1319  auto attr = opBuilder.getIntegerAttr(intType, value);
1320 
1321  if (isSpec) {
1322  createSpecConstant(unknownLoc, resultID, attr);
1323  } else {
1324  // For normal constants, we just record the attribute (and its type) for
1325  // later materialization at use sites.
1326  constantMap.try_emplace(resultID, attr, intType);
1327  }
1328 
1329  return success();
1330  }
1331 
1332  if (auto floatType = dyn_cast<FloatType>(resultType)) {
1333  auto bitwidth = floatType.getWidth();
1334  if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1335  return failure();
1336  }
1337 
1338  APFloat value(0.f);
1339  if (floatType.isF64()) {
1340  // Double values are represented with two SPIR-V words. According to
1341  // SPIR-V spec: "When the type’s bit width is larger than one word, the
1342  // literal’s low-order words appear first."
1343  struct DoubleWord {
1344  uint32_t word1;
1345  uint32_t word2;
1346  } words = {operands[2], operands[3]};
1347  value = APFloat(llvm::bit_cast<double>(words));
1348  } else if (floatType.isF32()) {
1349  value = APFloat(llvm::bit_cast<float>(operands[2]));
1350  } else if (floatType.isF16()) {
1351  APInt data(16, operands[2]);
1352  value = APFloat(APFloat::IEEEhalf(), data);
1353  }
1354 
1355  auto attr = opBuilder.getFloatAttr(floatType, value);
1356  if (isSpec) {
1357  createSpecConstant(unknownLoc, resultID, attr);
1358  } else {
1359  // For normal constants, we just record the attribute (and its type) for
1360  // later materialization at use sites.
1361  constantMap.try_emplace(resultID, attr, floatType);
1362  }
1363 
1364  return success();
1365  }
1366 
1367  return emitError(unknownLoc, "OpConstant can only generate values of "
1368  "scalar integer or floating-point type");
1369 }
1370 
1371 LogicalResult spirv::Deserializer::processConstantBool(
1372  bool isTrue, ArrayRef<uint32_t> operands, bool isSpec) {
1373  if (operands.size() != 2) {
1374  return emitError(unknownLoc, "Op")
1375  << (isSpec ? "Spec" : "") << "Constant"
1376  << (isTrue ? "True" : "False")
1377  << " must have type <id> and result <id>";
1378  }
1379 
1380  auto attr = opBuilder.getBoolAttr(isTrue);
1381  auto resultID = operands[1];
1382  if (isSpec) {
1383  createSpecConstant(unknownLoc, resultID, attr);
1384  } else {
1385  // For normal constants, we just record the attribute (and its type) for
1386  // later materialization at use sites.
1387  constantMap.try_emplace(resultID, attr, opBuilder.getI1Type());
1388  }
1389 
1390  return success();
1391 }
1392 
1393 LogicalResult
1394 spirv::Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
1395  if (operands.size() < 2) {
1396  return emitError(unknownLoc,
1397  "OpConstantComposite must have type <id> and result <id>");
1398  }
1399  if (operands.size() < 3) {
1400  return emitError(unknownLoc,
1401  "OpConstantComposite must have at least 1 parameter");
1402  }
1403 
1404  Type resultType = getType(operands[0]);
1405  if (!resultType) {
1406  return emitError(unknownLoc, "undefined result type from <id> ")
1407  << operands[0];
1408  }
1409 
1410  SmallVector<Attribute, 4> elements;
1411  elements.reserve(operands.size() - 2);
1412  for (unsigned i = 2, e = operands.size(); i < e; ++i) {
1413  auto elementInfo = getConstant(operands[i]);
1414  if (!elementInfo) {
1415  return emitError(unknownLoc, "OpConstantComposite component <id> ")
1416  << operands[i] << " must come from a normal constant";
1417  }
1418  elements.push_back(elementInfo->first);
1419  }
1420 
1421  auto resultID = operands[1];
1422  if (auto vectorType = dyn_cast<VectorType>(resultType)) {
1423  auto attr = DenseElementsAttr::get(vectorType, elements);
1424  // For normal constants, we just record the attribute (and its type) for
1425  // later materialization at use sites.
1426  constantMap.try_emplace(resultID, attr, resultType);
1427  } else if (auto arrayType = dyn_cast<spirv::ArrayType>(resultType)) {
1428  auto attr = opBuilder.getArrayAttr(elements);
1429  constantMap.try_emplace(resultID, attr, resultType);
1430  } else {
1431  return emitError(unknownLoc, "unsupported OpConstantComposite type: ")
1432  << resultType;
1433  }
1434 
1435  return success();
1436 }
1437 
1438 LogicalResult
1439 spirv::Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) {
1440  if (operands.size() < 2) {
1441  return emitError(unknownLoc,
1442  "OpConstantComposite must have type <id> and result <id>");
1443  }
1444  if (operands.size() < 3) {
1445  return emitError(unknownLoc,
1446  "OpConstantComposite must have at least 1 parameter");
1447  }
1448 
1449  Type resultType = getType(operands[0]);
1450  if (!resultType) {
1451  return emitError(unknownLoc, "undefined result type from <id> ")
1452  << operands[0];
1453  }
1454 
1455  auto resultID = operands[1];
1456  auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
1457 
1458  SmallVector<Attribute, 4> elements;
1459  elements.reserve(operands.size() - 2);
1460  for (unsigned i = 2, e = operands.size(); i < e; ++i) {
1461  auto elementInfo = getSpecConstant(operands[i]);
1462  elements.push_back(SymbolRefAttr::get(elementInfo));
1463  }
1464 
1465  auto op = opBuilder.create<spirv::SpecConstantCompositeOp>(
1466  unknownLoc, TypeAttr::get(resultType), symName,
1467  opBuilder.getArrayAttr(elements));
1468  specConstCompositeMap[resultID] = op;
1469 
1470  return success();
1471 }
1472 
1473 LogicalResult
1474 spirv::Deserializer::processSpecConstantOperation(ArrayRef<uint32_t> operands) {
1475  if (operands.size() < 3)
1476  return emitError(unknownLoc, "OpConstantOperation must have type <id>, "
1477  "result <id>, and operand opcode");
1478 
1479  uint32_t resultTypeID = operands[0];
1480 
1481  if (!getType(resultTypeID))
1482  return emitError(unknownLoc, "undefined result type from <id> ")
1483  << resultTypeID;
1484 
1485  uint32_t resultID = operands[1];
1486  spirv::Opcode enclosedOpcode = static_cast<spirv::Opcode>(operands[2]);
1487  auto emplaceResult = specConstOperationMap.try_emplace(
1488  resultID,
1489  SpecConstOperationMaterializationInfo{
1490  enclosedOpcode, resultTypeID,
1491  SmallVector<uint32_t>{operands.begin() + 3, operands.end()}});
1492 
1493  if (!emplaceResult.second)
1494  return emitError(unknownLoc, "value with <id>: ")
1495  << resultID << " is probably defined before.";
1496 
1497  return success();
1498 }
1499 
/// Materializes a recorded OpSpecConstantOp as a spirv.SpecConstantOperation
/// op and returns that op's result.
///
/// `enclosedOpcode` and `enclosedOpOperands` describe the wrapped
/// instruction, and `resultTypeID` is the <id> of its result type. `resultID`
/// (the OpSpecConstantOp's own <id>) is not used in this body. Returns a null
/// Value if deserializing the enclosed instruction fails.
Value spirv::Deserializer::materializeSpecConstantOperation(
    uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID,
    ArrayRef<uint32_t> enclosedOpOperands) {

  Type resultType = getType(resultTypeID);

  // Instructions wrapped by OpSpecConstantOp need an ID for their
  // Deserializer::processOp<op_name>(...) to emit the corresponding SPIR-V
  // dialect wrapped op. For that purpose, a new value map is created and "fake"
  // ID in that map is assigned to the result of the enclosed instruction. Note
  // that there is no need to update this fake ID since we only need to
  // reference the created Value for the enclosed op from the spirv::YieldOp
  // created later in this method (both of which are the only values in their
  // region: the SpecConstantOperation's region). If we encounter another
  // SpecConstantOperation in the module, we simply re-use the fake ID since the
  // previous Value assigned to it isn't visible in the current scope anyway.
  //
  // `valueMapGuard` swaps in the empty `newValueMap` and restores the original
  // `valueMap` when it goes out of scope.
  DenseMap<uint32_t, Value> newValueMap;
  llvm::SaveAndRestore valueMapGuard(valueMap, newValueMap);
  // NOTE(review): presumably chosen so it cannot collide with a real <id>.
  constexpr uint32_t fakeID = static_cast<uint32_t>(-3);

  // Rebuild a regular instruction operand list: [result type, result <id>,
  // operands...].
  SmallVector<uint32_t, 4> enclosedOpResultTypeAndOperands;
  enclosedOpResultTypeAndOperands.push_back(resultTypeID);
  enclosedOpResultTypeAndOperands.push_back(fakeID);
  enclosedOpResultTypeAndOperands.append(enclosedOpOperands.begin(),
                                         enclosedOpOperands.end());

  // Process enclosed instruction before creating the enclosing
  // specConstantOperation (and its region). This way, references to constants,
  // global variables, and spec constants will be materialized outside the new
  // op's region. For more info, see Deserializer::getValue's implementation.
  if (failed(
          processInstruction(enclosedOpcode, enclosedOpResultTypeAndOperands)))
    return Value();

  // Since the enclosed op is emitted in the current block, split it in a
  // separate new block. This assumes the enclosed op is the last op that
  // processInstruction appended to `curBlock`.
  Block *enclosedBlock = curBlock->splitBlock(&curBlock->back());

  // The SpecConstantOperation op is created at the builder's current
  // insertion point.
  auto loc = createFileLineColLoc(opBuilder);
  auto specConstOperationOp =
      opBuilder.create<spirv::SpecConstantOperationOp>(loc, resultType);

  Region &body = specConstOperationOp.getBody();
  // Move the new block into SpecConstantOperation's body.
  body.getBlocks().splice(body.end(), curBlock->getParent()->getBlocks(),
                          Region::iterator(enclosedBlock));
  Block &block = body.back();

  // RAII guard that restores the builder's previous insertion point once the
  // yield below has been emitted.
  OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
  opBuilder.setInsertionPointToEnd(&block);

  // The moved block holds only the enclosed op; yield its first result.
  opBuilder.create<spirv::YieldOp>(loc, block.front().getResult(0));
  return specConstOperationOp.getResult();
}
1556 
1557 LogicalResult
1558 spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
1559  if (operands.size() != 2) {
1560  return emitError(unknownLoc,
1561  "OpConstantNull must have type <id> and result <id>");
1562  }
1563 
1564  Type resultType = getType(operands[0]);
1565  if (!resultType) {
1566  return emitError(unknownLoc, "undefined result type from <id> ")
1567  << operands[0];
1568  }
1569 
1570  auto resultID = operands[1];
1571  if (resultType.isIntOrFloat() || isa<VectorType>(resultType)) {
1572  auto attr = opBuilder.getZeroAttr(resultType);
1573  // For normal constants, we just record the attribute (and its type) for
1574  // later materialization at use sites.
1575  constantMap.try_emplace(resultID, attr, resultType);
1576  return success();
1577  }
1578 
1579  return emitError(unknownLoc, "unsupported OpConstantNull type: ")
1580  << resultType;
1581 }
1582 
1583 //===----------------------------------------------------------------------===//
1584 // Control flow
1585 //===----------------------------------------------------------------------===//
1586 
1587 Block *spirv::Deserializer::getOrCreateBlock(uint32_t id) {
1588  if (auto *block = getBlock(id)) {
1589  LLVM_DEBUG(logger.startLine() << "[block] got exiting block for id = " << id
1590  << " @ " << block << "\n");
1591  return block;
1592  }
1593 
1594  // We don't know where this block will be placed finally (in a
1595  // spirv.mlir.selection or spirv.mlir.loop or function). Create it into the
1596  // function for now and sort out the proper place later.
1597  auto *block = curFunction->addBlock();
1598  LLVM_DEBUG(logger.startLine() << "[block] created block for id = " << id
1599  << " @ " << block << "\n");
1600  return blockMap[id] = block;
1601 }
1602 
1603 LogicalResult spirv::Deserializer::processBranch(ArrayRef<uint32_t> operands) {
1604  if (!curBlock) {
1605  return emitError(unknownLoc, "OpBranch must appear inside a block");
1606  }
1607 
1608  if (operands.size() != 1) {
1609  return emitError(unknownLoc, "OpBranch must take exactly one target label");
1610  }
1611 
1612  auto *target = getOrCreateBlock(operands[0]);
1613  auto loc = createFileLineColLoc(opBuilder);
1614  // The preceding instruction for the OpBranch instruction could be an
1615  // OpLoopMerge or an OpSelectionMerge instruction, in this case they will have
1616  // the same OpLine information.
1617  opBuilder.create<spirv::BranchOp>(loc, target);
1618 
1619  clearDebugLine();
1620  return success();
1621 }
1622 
1623 LogicalResult
1624 spirv::Deserializer::processBranchConditional(ArrayRef<uint32_t> operands) {
1625  if (!curBlock) {
1626  return emitError(unknownLoc,
1627  "OpBranchConditional must appear inside a block");
1628  }
1629 
1630  if (operands.size() != 3 && operands.size() != 5) {
1631  return emitError(unknownLoc,
1632  "OpBranchConditional must have condition, true label, "
1633  "false label, and optionally two branch weights");
1634  }
1635 
1636  auto condition = getValue(operands[0]);
1637  auto *trueBlock = getOrCreateBlock(operands[1]);
1638  auto *falseBlock = getOrCreateBlock(operands[2]);
1639 
1640  std::optional<std::pair<uint32_t, uint32_t>> weights;
1641  if (operands.size() == 5) {
1642  weights = std::make_pair(operands[3], operands[4]);
1643  }
1644  // The preceding instruction for the OpBranchConditional instruction could be
1645  // an OpSelectionMerge instruction, in this case they will have the same
1646  // OpLine information.
1647  auto loc = createFileLineColLoc(opBuilder);
1648  opBuilder.create<spirv::BranchConditionalOp>(
1649  loc, condition, trueBlock,
1650  /*trueArguments=*/ArrayRef<Value>(), falseBlock,
1651  /*falseArguments=*/ArrayRef<Value>(), weights);
1652 
1653  clearDebugLine();
1654  return success();
1655 }
1656 
1657 LogicalResult spirv::Deserializer::processLabel(ArrayRef<uint32_t> operands) {
1658  if (!curFunction) {
1659  return emitError(unknownLoc, "OpLabel must appear inside a function");
1660  }
1661 
1662  if (operands.size() != 1) {
1663  return emitError(unknownLoc, "OpLabel should only have result <id>");
1664  }
1665 
1666  auto labelID = operands[0];
1667  // We may have forward declared this block.
1668  auto *block = getOrCreateBlock(labelID);
1669  LLVM_DEBUG(logger.startLine()
1670  << "[block] populating block " << block << "\n");
1671  // If we have seen this block, make sure it was just a forward declaration.
1672  assert(block->empty() && "re-deserialize the same block!");
1673 
1674  opBuilder.setInsertionPointToStart(block);
1675  blockMap[labelID] = curBlock = block;
1676 
1677  return success();
1678 }
1679 
1680 LogicalResult
1681 spirv::Deserializer::processSelectionMerge(ArrayRef<uint32_t> operands) {
1682  if (!curBlock) {
1683  return emitError(unknownLoc, "OpSelectionMerge must appear in a block");
1684  }
1685 
1686  if (operands.size() < 2) {
1687  return emitError(
1688  unknownLoc,
1689  "OpSelectionMerge must specify merge target and selection control");
1690  }
1691 
1692  auto *mergeBlock = getOrCreateBlock(operands[0]);
1693  auto loc = createFileLineColLoc(opBuilder);
1694  auto selectionControl = operands[1];
1695 
1696  if (!blockMergeInfo.try_emplace(curBlock, loc, selectionControl, mergeBlock)
1697  .second) {
1698  return emitError(
1699  unknownLoc,
1700  "a block cannot have more than one OpSelectionMerge instruction");
1701  }
1702 
1703  return success();
1704 }
1705 
1706 LogicalResult
1707 spirv::Deserializer::processLoopMerge(ArrayRef<uint32_t> operands) {
1708  if (!curBlock) {
1709  return emitError(unknownLoc, "OpLoopMerge must appear in a block");
1710  }
1711 
1712  if (operands.size() < 3) {
1713  return emitError(unknownLoc, "OpLoopMerge must specify merge target, "
1714  "continue target and loop control");
1715  }
1716 
1717  auto *mergeBlock = getOrCreateBlock(operands[0]);
1718  auto *continueBlock = getOrCreateBlock(operands[1]);
1719  auto loc = createFileLineColLoc(opBuilder);
1720  uint32_t loopControl = operands[2];
1721 
1722  if (!blockMergeInfo
1723  .try_emplace(curBlock, loc, loopControl, mergeBlock, continueBlock)
1724  .second) {
1725  return emitError(
1726  unknownLoc,
1727  "a block cannot have more than one OpLoopMerge instruction");
1728  }
1729 
1730  return success();
1731 }
1732 
1733 LogicalResult spirv::Deserializer::processPhi(ArrayRef<uint32_t> operands) {
1734  if (!curBlock) {
1735  return emitError(unknownLoc, "OpPhi must appear in a block");
1736  }
1737 
1738  if (operands.size() < 4) {
1739  return emitError(unknownLoc, "OpPhi must specify result type, result <id>, "
1740  "and variable-parent pairs");
1741  }
1742 
1743  // Create a block argument for this OpPhi instruction.
1744  Type blockArgType = getType(operands[0]);
1745  BlockArgument blockArg = curBlock->addArgument(blockArgType, unknownLoc);
1746  valueMap[operands[1]] = blockArg;
1747  LLVM_DEBUG(logger.startLine()
1748  << "[phi] created block argument " << blockArg
1749  << " id = " << operands[1] << " of type " << blockArgType << "\n");
1750 
1751  // For each (value, predecessor) pair, insert the value to the predecessor's
1752  // blockPhiInfo entry so later we can fix the block argument there.
1753  for (unsigned i = 2, e = operands.size(); i < e; i += 2) {
1754  uint32_t value = operands[i];
1755  Block *predecessor = getOrCreateBlock(operands[i + 1]);
1756  std::pair<Block *, Block *> predecessorTargetPair{predecessor, curBlock};
1757  blockPhiInfo[predecessorTargetPair].push_back(value);
1758  LLVM_DEBUG(logger.startLine() << "[phi] predecessor @ " << predecessor
1759  << " with arg id = " << value << "\n");
1760  }
1761 
1762  return success();
1763 }
1764 
1765 namespace {
1766 /// A class for putting all blocks in a structured selection/loop in a
1767 /// spirv.mlir.selection/spirv.mlir.loop op.
1768 class ControlFlowStructurizer {
1769 public:
1770 #ifndef NDEBUG
1771  ControlFlowStructurizer(Location loc, uint32_t control,
1772  spirv::BlockMergeInfoMap &mergeInfo, Block *header,
1773  Block *merge, Block *cont,
1774  llvm::ScopedPrinter &logger)
1775  : location(loc), control(control), blockMergeInfo(mergeInfo),
1776  headerBlock(header), mergeBlock(merge), continueBlock(cont),
1777  logger(logger) {}
1778 #else
1779  ControlFlowStructurizer(Location loc, uint32_t control,
1780  spirv::BlockMergeInfoMap &mergeInfo, Block *header,
1781  Block *merge, Block *cont)
1782  : location(loc), control(control), blockMergeInfo(mergeInfo),
1783  headerBlock(header), mergeBlock(merge), continueBlock(cont) {}
1784 #endif
1785 
1786  /// Structurizes the loop at the given `headerBlock`.
1787  ///
1788  /// This method will create an spirv.mlir.loop op in the `mergeBlock` and move
1789  /// all blocks in the structured loop into the spirv.mlir.loop's region. All
1790  /// branches to the `headerBlock` will be redirected to the `mergeBlock`. This
1791  /// method will also update `mergeInfo` by remapping all blocks inside to the
1792  /// newly cloned ones inside structured control flow op's regions.
1793  LogicalResult structurize();
1794 
1795 private:
1796  /// Creates a new spirv.mlir.selection op at the beginning of the
1797  /// `mergeBlock`.
1798  spirv::SelectionOp createSelectionOp(uint32_t selectionControl);
1799 
1800  /// Creates a new spirv.mlir.loop op at the beginning of the `mergeBlock`.
1801  spirv::LoopOp createLoopOp(uint32_t loopControl);
1802 
1803  /// Collects all blocks reachable from `headerBlock` except `mergeBlock`.
1804  void collectBlocksInConstruct();
1805 
1806  Location location;
1807  uint32_t control;
1808 
1809  spirv::BlockMergeInfoMap &blockMergeInfo;
1810 
1811  Block *headerBlock;
1812  Block *mergeBlock;
1813  Block *continueBlock; // nullptr for spirv.mlir.selection
1814 
1815  SetVector<Block *> constructBlocks;
1816 
1817 #ifndef NDEBUG
1818  /// A logger used to emit information during the deserialzation process.
1819  llvm::ScopedPrinter &logger;
1820 #endif
1821 };
1822 } // namespace
1823 
1824 spirv::SelectionOp
1825 ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) {
1826  // Create a builder and set the insertion point to the beginning of the
1827  // merge block so that the newly created SelectionOp will be inserted there.
1828  OpBuilder builder(&mergeBlock->front());
1829 
1830  auto control = static_cast<spirv::SelectionControl>(selectionControl);
1831  auto selectionOp = builder.create<spirv::SelectionOp>(location, control);
1832  selectionOp.addMergeBlock(builder);
1833 
1834  return selectionOp;
1835 }
1836 
1837 spirv::LoopOp ControlFlowStructurizer::createLoopOp(uint32_t loopControl) {
1838  // Create a builder and set the insertion point to the beginning of the
1839  // merge block so that the newly created LoopOp will be inserted there.
1840  OpBuilder builder(&mergeBlock->front());
1841 
1842  auto control = static_cast<spirv::LoopControl>(loopControl);
1843  auto loopOp = builder.create<spirv::LoopOp>(location, control);
1844  loopOp.addEntryAndMergeBlock(builder);
1845 
1846  return loopOp;
1847 }
1848 
1849 void ControlFlowStructurizer::collectBlocksInConstruct() {
1850  assert(constructBlocks.empty() && "expected empty constructBlocks");
1851 
1852  // Put the header block in the work list first.
1853  constructBlocks.insert(headerBlock);
1854 
1855  // For each item in the work list, add its successors excluding the merge
1856  // block.
1857  for (unsigned i = 0; i < constructBlocks.size(); ++i) {
1858  for (auto *successor : constructBlocks[i]->getSuccessors())
1859  if (successor != mergeBlock)
1860  constructBlocks.insert(successor);
1861  }
1862 }
1863 
1864 LogicalResult ControlFlowStructurizer::structurize() {
1865  Operation *op = nullptr;
1866  bool isLoop = continueBlock != nullptr;
1867  if (isLoop) {
1868  if (auto loopOp = createLoopOp(control))
1869  op = loopOp.getOperation();
1870  } else {
1871  if (auto selectionOp = createSelectionOp(control))
1872  op = selectionOp.getOperation();
1873  }
1874  if (!op)
1875  return failure();
1876  Region &body = op->getRegion(0);
1877 
1878  IRMapping mapper;
1879  // All references to the old merge block should be directed to the
1880  // selection/loop merge block in the SelectionOp/LoopOp's region.
1881  mapper.map(mergeBlock, &body.back());
1882 
1883  collectBlocksInConstruct();
1884 
1885  // We've identified all blocks belonging to the selection/loop's region. Now
1886  // need to "move" them into the selection/loop. Instead of really moving the
1887  // blocks, in the following we copy them and remap all values and branches.
1888  // This is because:
1889  // * Inserting a block into a region requires the block not in any region
1890  // before. But selections/loops can nest so we can create selection/loop ops
1891  // in a nested manner, which means some blocks may already be in a
1892  // selection/loop region when to be moved again.
1893  // * It's much trickier to fix up the branches into and out of the loop's
1894  // region: we need to treat not-moved blocks and moved blocks differently:
1895  // Not-moved blocks jumping to the loop header block need to jump to the
1896  // merge point containing the new loop op but not the loop continue block's
1897  // back edge. Moved blocks jumping out of the loop need to jump to the
1898  // merge block inside the loop region but not other not-moved blocks.
1899  // We cannot use replaceAllUsesWith clearly and it's harder to follow the
1900  // logic.
1901 
1902  // Create a corresponding block in the SelectionOp/LoopOp's region for each
1903  // block in this loop construct.
1904  OpBuilder builder(body);
1905  for (auto *block : constructBlocks) {
1906  // Create a block and insert it before the selection/loop merge block in the
1907  // SelectionOp/LoopOp's region.
1908  auto *newBlock = builder.createBlock(&body.back());
1909  mapper.map(block, newBlock);
1910  LLVM_DEBUG(logger.startLine() << "[cf] cloned block " << newBlock
1911  << " from block " << block << "\n");
1912  if (!isFnEntryBlock(block)) {
1913  for (BlockArgument blockArg : block->getArguments()) {
1914  auto newArg =
1915  newBlock->addArgument(blockArg.getType(), blockArg.getLoc());
1916  mapper.map(blockArg, newArg);
1917  LLVM_DEBUG(logger.startLine() << "[cf] remapped block argument "
1918  << blockArg << " to " << newArg << "\n");
1919  }
1920  } else {
1921  LLVM_DEBUG(logger.startLine()
1922  << "[cf] block " << block << " is a function entry block\n");
1923  }
1924 
1925  for (auto &op : *block)
1926  newBlock->push_back(op.clone(mapper));
1927  }
1928 
1929  // Go through all ops and remap the operands.
1930  auto remapOperands = [&](Operation *op) {
1931  for (auto &operand : op->getOpOperands())
1932  if (Value mappedOp = mapper.lookupOrNull(operand.get()))
1933  operand.set(mappedOp);
1934  for (auto &succOp : op->getBlockOperands())
1935  if (Block *mappedOp = mapper.lookupOrNull(succOp.get()))
1936  succOp.set(mappedOp);
1937  };
1938  for (auto &block : body)
1939  block.walk(remapOperands);
1940 
1941  // We have created the SelectionOp/LoopOp and "moved" all blocks belonging to
1942  // the selection/loop construct into its region. Next we need to fix the
1943  // connections between this new SelectionOp/LoopOp with existing blocks.
1944 
1945  // All existing incoming branches should go to the merge block, where the
1946  // SelectionOp/LoopOp resides right now.
1947  headerBlock->replaceAllUsesWith(mergeBlock);
1948 
1949  LLVM_DEBUG({
1950  logger.startLine() << "[cf] after cloning and fixing references:\n";
1951  headerBlock->getParentOp()->print(logger.getOStream());
1952  logger.startLine() << "\n";
1953  });
1954 
1955  if (isLoop) {
1956  if (!mergeBlock->args_empty()) {
1957  return mergeBlock->getParentOp()->emitError(
1958  "OpPhi in loop merge block unsupported");
1959  }
1960 
1961  // The loop header block may have block arguments. Since now we place the
1962  // loop op inside the old merge block, we need to make sure the old merge
1963  // block has the same block argument list.
1964  for (BlockArgument blockArg : headerBlock->getArguments())
1965  mergeBlock->addArgument(blockArg.getType(), blockArg.getLoc());
1966 
1967  // If the loop header block has block arguments, make sure the spirv.Branch
1968  // op matches.
1969  SmallVector<Value, 4> blockArgs;
1970  if (!headerBlock->args_empty())
1971  blockArgs = {mergeBlock->args_begin(), mergeBlock->args_end()};
1972 
1973  // The loop entry block should have a unconditional branch jumping to the
1974  // loop header block.
1975  builder.setInsertionPointToEnd(&body.front());
1976  builder.create<spirv::BranchOp>(location, mapper.lookupOrNull(headerBlock),
1977  ArrayRef<Value>(blockArgs));
1978  }
1979 
1980  // All the blocks cloned into the SelectionOp/LoopOp's region can now be
1981  // cleaned up.
1982  LLVM_DEBUG(logger.startLine() << "[cf] cleaning up blocks after clone\n");
1983  // First we need to drop all operands' references inside all blocks. This is
1984  // needed because we can have blocks referencing SSA values from one another.
1985  for (auto *block : constructBlocks)
1986  block->dropAllReferences();
1987 
1988  // Check that whether some op in the to-be-erased blocks still has uses. Those
1989  // uses come from blocks that won't be sinked into the SelectionOp/LoopOp's
1990  // region. We cannot handle such cases given that once a value is sinked into
1991  // the SelectionOp/LoopOp's region, there is no escape for it:
1992  // SelectionOp/LooOp does not support yield values right now.
1993  for (auto *block : constructBlocks) {
1994  for (Operation &op : *block)
1995  if (!op.use_empty())
1996  return op.emitOpError(
1997  "failed control flow structurization: it has uses outside of the "
1998  "enclosing selection/loop construct");
1999  }
2000 
2001  // Then erase all old blocks.
2002  for (auto *block : constructBlocks) {
2003  // We've cloned all blocks belonging to this construct into the structured
2004  // control flow op's region. Among these blocks, some may compose another
2005  // selection/loop. If so, they will be recorded within blockMergeInfo.
2006  // We need to update the pointers there to the newly remapped ones so we can
2007  // continue structurizing them later.
2008  // TODO: The asserts in the following assumes input SPIR-V blob forms
2009  // correctly nested selection/loop constructs. We should relax this and
2010  // support error cases better.
2011  auto it = blockMergeInfo.find(block);
2012  if (it != blockMergeInfo.end()) {
2013  // Use the original location for nested selection/loop ops.
2014  Location loc = it->second.loc;
2015 
2016  Block *newHeader = mapper.lookupOrNull(block);
2017  if (!newHeader)
2018  return emitError(loc, "failed control flow structurization: nested "
2019  "loop header block should be remapped!");
2020 
2021  Block *newContinue = it->second.continueBlock;
2022  if (newContinue) {
2023  newContinue = mapper.lookupOrNull(newContinue);
2024  if (!newContinue)
2025  return emitError(loc, "failed control flow structurization: nested "
2026  "loop continue block should be remapped!");
2027  }
2028 
2029  Block *newMerge = it->second.mergeBlock;
2030  if (Block *mappedTo = mapper.lookupOrNull(newMerge))
2031  newMerge = mappedTo;
2032 
2033  // The iterator should be erased before adding a new entry into
2034  // blockMergeInfo to avoid iterator invalidation.
2035  blockMergeInfo.erase(it);
2036  blockMergeInfo.try_emplace(newHeader, loc, it->second.control, newMerge,
2037  newContinue);
2038  }
2039 
2040  // The structured selection/loop's entry block does not have arguments.
2041  // If the function's header block is also part of the structured control
2042  // flow, we cannot just simply erase it because it may contain arguments
2043  // matching the function signature and used by the cloned blocks.
2044  if (isFnEntryBlock(block)) {
2045  LLVM_DEBUG(logger.startLine() << "[cf] changing entry block " << block
2046  << " to only contain a spirv.Branch op\n");
2047  // Still keep the function entry block for the potential block arguments,
2048  // but replace all ops inside with a branch to the merge block.
2049  block->clear();
2050  builder.setInsertionPointToEnd(block);
2051  builder.create<spirv::BranchOp>(location, mergeBlock);
2052  } else {
2053  LLVM_DEBUG(logger.startLine() << "[cf] erasing block " << block << "\n");
2054  block->erase();
2055  }
2056  }
2057 
2058  LLVM_DEBUG(logger.startLine()
2059  << "[cf] after structurizing construct with header block "
2060  << headerBlock << ":\n"
2061  << *op << "\n");
2062 
2063  return success();
2064 }
2065 
2066 LogicalResult spirv::Deserializer::wireUpBlockArgument() {
2067  LLVM_DEBUG({
2068  logger.startLine()
2069  << "//----- [phi] start wiring up block arguments -----//\n";
2070  logger.indent();
2071  });
2072 
2073  OpBuilder::InsertionGuard guard(opBuilder);
2074 
2075  for (const auto &info : blockPhiInfo) {
2076  Block *block = info.first.first;
2077  Block *target = info.first.second;
2078  const BlockPhiInfo &phiInfo = info.second;
2079  LLVM_DEBUG({
2080  logger.startLine() << "[phi] block " << block << "\n";
2081  logger.startLine() << "[phi] before creating block argument:\n";
2082  block->getParentOp()->print(logger.getOStream());
2083  logger.startLine() << "\n";
2084  });
2085 
2086  // Set insertion point to before this block's terminator early because we
2087  // may materialize ops via getValue() call.
2088  auto *op = block->getTerminator();
2089  opBuilder.setInsertionPoint(op);
2090 
2091  SmallVector<Value, 4> blockArgs;
2092  blockArgs.reserve(phiInfo.size());
2093  for (uint32_t valueId : phiInfo) {
2094  if (Value value = getValue(valueId)) {
2095  blockArgs.push_back(value);
2096  LLVM_DEBUG(logger.startLine() << "[phi] block argument " << value
2097  << " id = " << valueId << "\n");
2098  } else {
2099  return emitError(unknownLoc, "OpPhi references undefined value!");
2100  }
2101  }
2102 
2103  if (auto branchOp = dyn_cast<spirv::BranchOp>(op)) {
2104  // Replace the previous branch op with a new one with block arguments.
2105  opBuilder.create<spirv::BranchOp>(branchOp.getLoc(), branchOp.getTarget(),
2106  blockArgs);
2107  branchOp.erase();
2108  } else if (auto branchCondOp = dyn_cast<spirv::BranchConditionalOp>(op)) {
2109  assert((branchCondOp.getTrueBlock() == target ||
2110  branchCondOp.getFalseBlock() == target) &&
2111  "expected target to be either the true or false target");
2112  if (target == branchCondOp.getTrueTarget())
2113  opBuilder.create<spirv::BranchConditionalOp>(
2114  branchCondOp.getLoc(), branchCondOp.getCondition(), blockArgs,
2115  branchCondOp.getFalseBlockArguments(),
2116  branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueTarget(),
2117  branchCondOp.getFalseTarget());
2118  else
2119  opBuilder.create<spirv::BranchConditionalOp>(
2120  branchCondOp.getLoc(), branchCondOp.getCondition(),
2121  branchCondOp.getTrueBlockArguments(), blockArgs,
2122  branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueBlock(),
2123  branchCondOp.getFalseBlock());
2124 
2125  branchCondOp.erase();
2126  } else {
2127  return emitError(unknownLoc, "unimplemented terminator for Phi creation");
2128  }
2129 
2130  LLVM_DEBUG({
2131  logger.startLine() << "[phi] after creating block argument:\n";
2132  block->getParentOp()->print(logger.getOStream());
2133  logger.startLine() << "\n";
2134  });
2135  }
2136  blockPhiInfo.clear();
2137 
2138  LLVM_DEBUG({
2139  logger.unindent();
2140  logger.startLine()
2141  << "//--- [phi] completed wiring up block arguments ---//\n";
2142  });
2143  return success();
2144 }
2145 
2146 LogicalResult spirv::Deserializer::structurizeControlFlow() {
2147  LLVM_DEBUG({
2148  logger.startLine()
2149  << "//----- [cf] start structurizing control flow -----//\n";
2150  logger.indent();
2151  });
2152 
2153  while (!blockMergeInfo.empty()) {
2154  Block *headerBlock = blockMergeInfo.begin()->first;
2155  BlockMergeInfo mergeInfo = blockMergeInfo.begin()->second;
2156 
2157  LLVM_DEBUG({
2158  logger.startLine() << "[cf] header block " << headerBlock << ":\n";
2159  headerBlock->print(logger.getOStream());
2160  logger.startLine() << "\n";
2161  });
2162 
2163  auto *mergeBlock = mergeInfo.mergeBlock;
2164  assert(mergeBlock && "merge block cannot be nullptr");
2165  if (!mergeBlock->args_empty())
2166  return emitError(unknownLoc, "OpPhi in loop merge block unimplemented");
2167  LLVM_DEBUG({
2168  logger.startLine() << "[cf] merge block " << mergeBlock << ":\n";
2169  mergeBlock->print(logger.getOStream());
2170  logger.startLine() << "\n";
2171  });
2172 
2173  auto *continueBlock = mergeInfo.continueBlock;
2174  LLVM_DEBUG(if (continueBlock) {
2175  logger.startLine() << "[cf] continue block " << continueBlock << ":\n";
2176  continueBlock->print(logger.getOStream());
2177  logger.startLine() << "\n";
2178  });
2179  // Erase this case before calling into structurizer, who will update
2180  // blockMergeInfo.
2181  blockMergeInfo.erase(blockMergeInfo.begin());
2182  ControlFlowStructurizer structurizer(mergeInfo.loc, mergeInfo.control,
2183  blockMergeInfo, headerBlock,
2184  mergeBlock, continueBlock
2185 #ifndef NDEBUG
2186  ,
2187  logger
2188 #endif
2189  );
2190  if (failed(structurizer.structurize()))
2191  return failure();
2192  }
2193 
2194  LLVM_DEBUG({
2195  logger.unindent();
2196  logger.startLine()
2197  << "//--- [cf] completed structurizing control flow ---//\n";
2198  });
2199  return success();
2200 }
2201 
2202 //===----------------------------------------------------------------------===//
2203 // Debug
2204 //===----------------------------------------------------------------------===//
2205 
2206 Location spirv::Deserializer::createFileLineColLoc(OpBuilder opBuilder) {
2207  if (!debugLine)
2208  return unknownLoc;
2209 
2210  auto fileName = debugInfoMap.lookup(debugLine->fileID).str();
2211  if (fileName.empty())
2212  fileName = "<unknown>";
2213  return FileLineColLoc::get(opBuilder.getStringAttr(fileName), debugLine->line,
2214  debugLine->column);
2215 }
2216 
2217 LogicalResult
2218 spirv::Deserializer::processDebugLine(ArrayRef<uint32_t> operands) {
2219  // According to SPIR-V spec:
2220  // "This location information applies to the instructions physically
2221  // following this instruction, up to the first occurrence of any of the
2222  // following: the next end of block, the next OpLine instruction, or the next
2223  // OpNoLine instruction."
2224  if (operands.size() != 3)
2225  return emitError(unknownLoc, "OpLine must have 3 operands");
2226  debugLine = DebugLine{operands[0], operands[1], operands[2]};
2227  return success();
2228 }
2229 
2230 void spirv::Deserializer::clearDebugLine() { debugLine = std::nullopt; }
2231 
2232 LogicalResult
2233 spirv::Deserializer::processDebugString(ArrayRef<uint32_t> operands) {
2234  if (operands.size() < 2)
2235  return emitError(unknownLoc, "OpString needs at least 2 operands");
2236 
2237  if (!debugInfoMap.lookup(operands[0]).empty())
2238  return emitError(unknownLoc,
2239  "duplicate debug string found for result <id> ")
2240  << operands[0];
2241 
2242  unsigned wordIndex = 1;
2243  StringRef debugString = decodeStringLiteral(operands, wordIndex);
2244  if (wordIndex != operands.size())
2245  return emitError(unknownLoc,
2246  "unexpected trailing words in OpString instruction");
2247 
2248  debugInfoMap[operands[0]] = debugString;
2249  return success();
2250 }
static bool isLoop(Operation *op)
Returns true if the given operation represents a loop by testing whether it implements the LoopLikeOp...
static bool isFnEntryBlock(Block *block)
Returns true if the given block is a function entry block.
#define MIN_VERSION_CASE(v)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
int64_t rows
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
Location getLoc() const
Return the location for this argument.
Definition: Value.h:334
Block represents an ordered list of Operations.
Definition: Block.h:31
bool empty()
Definition: Block.h:146
Operation & back()
Definition: Block.h:150
void erase()
Unlink this Block from its parent region and delete it.
Definition: Block.cpp:65
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Definition: Block.cpp:307
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
void print(raw_ostream &os)
BlockArgListType getArguments()
Definition: Block.h:85
Operation & front()
Definition: Block.h:151
iterator begin()
Definition: Block.h:141
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Definition: Block.cpp:35
void push_back(Operation *op)
Definition: Block.h:147
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:273
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
A symbol reference with a reference path containing a single element.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
auto lookupOrNull(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:58
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:207
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:351
This class helps build Operations.
Definition: Builders.h:210
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool use_empty()
Returns true if this operation has no uses.
Definition: Operation.h:848
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Definition: Operation.cpp:717
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Definition: Operation.cpp:67
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:577
MutableArrayRef< BlockOperand > getBlockOperands()
Definition: Operation.h:691
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:378
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Block & back()
Definition: Region.h:64
BlockListType & getBlocks()
Definition: Region.h:45
BlockListType::iterator iterator
Definition: Region.h:52
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:120
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
static ArrayType get(Type elementType, unsigned elementCount)
Definition: SPIRVTypes.cpp:52
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
Definition: SPIRVTypes.cpp:228
LogicalResult deserialize()
Deserializes the remembered SPIR-V binary module.
Deserializer(ArrayRef< uint32_t > binary, MLIRContext *context)
Creates a deserializer for the given SPIR-V binary module.
OwningOpRef< spirv::ModuleOp > collect()
Collects the final SPIR-V ModuleOp.
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
Definition: SPIRVTypes.h:169
static JointMatrixINTELType get(Type elementType, Scope scope, unsigned rows, unsigned columns, MatrixLayout matrixLayout)
Definition: SPIRVTypes.cpp:298
static MatrixType get(Type columnType, uint32_t columnCount)
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:481
static RuntimeArrayType get(Type elementType)
Definition: SPIRVTypes.cpp:538
static SampledImageType get(Type imageType)
Definition: SPIRVTypes.cpp:807
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Construct a literal StructType with at least one member.
static VerCapExtAttr get(Version version, ArrayRef< Capability > capabilities, ArrayRef< Extension > extensions, MLIRContext *context)
Gets a VerCapExtAttr instance.
Include the generated interface declarations.
Definition: CallGraph.h:229
constexpr uint32_t kMagicNumber
SPIR-V magic number.
StringRef decodeStringLiteral(ArrayRef< uint32_t > words, unsigned &wordIndex)
Decodes a string literal in words starting at wordIndex.
constexpr unsigned kHeaderWordCount
SPIR-V binary header word count.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
static std::string debugString(T &&op)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This represents an operation in an abstracted form, suitable for use with the builder APIs.