MLIR  19.0.0git
Deserializer.cpp
Go to the documentation of this file.
1 //===- Deserializer.cpp - MLIR SPIR-V Deserializer ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the SPIR-V binary to MLIR SPIR-V module deserializer.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Deserializer.h"
14 
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/IRMapping.h"
21 #include "mlir/IR/Location.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/Sequence.h"
26 #include "llvm/ADT/SmallVector.h"
27 #include "llvm/ADT/StringExtras.h"
28 #include "llvm/ADT/bit.h"
29 #include "llvm/Support/Debug.h"
30 #include "llvm/Support/SaveAndRestore.h"
31 #include "llvm/Support/raw_ostream.h"
32 #include <optional>
33 
34 using namespace mlir;
35 
36 #define DEBUG_TYPE "spirv-deserialization"
37 
38 //===----------------------------------------------------------------------===//
39 // Utility Functions
40 //===----------------------------------------------------------------------===//
41 
42 /// Returns true if the given `block` is a function entry block.
43 static inline bool isFnEntryBlock(Block *block) {
44  return block->isEntryBlock() &&
45  isa_and_nonnull<spirv::FuncOp>(block->getParentOp());
46 }
47 
48 //===----------------------------------------------------------------------===//
49 // Deserializer Method Definitions
50 //===----------------------------------------------------------------------===//
51 
53  MLIRContext *context)
54  : binary(binary), context(context), unknownLoc(UnknownLoc::get(context)),
55  module(createModuleOp()), opBuilder(module->getRegion())
56 #ifndef NDEBUG
57  ,
58  logger(llvm::dbgs())
59 #endif
60 {
61 }
62 
64  LLVM_DEBUG({
65  logger.resetIndent();
66  logger.startLine()
67  << "//+++---------- start deserialization ----------+++//\n";
68  });
69 
70  if (failed(processHeader()))
71  return failure();
72 
73  spirv::Opcode opcode = spirv::Opcode::OpNop;
74  ArrayRef<uint32_t> operands;
75  auto binarySize = binary.size();
76  while (curOffset < binarySize) {
77  // Slice the next instruction out and populate `opcode` and `operands`.
78  // Internally this also updates `curOffset`.
79  if (failed(sliceInstruction(opcode, operands)))
80  return failure();
81 
82  if (failed(processInstruction(opcode, operands)))
83  return failure();
84  }
85 
86  assert(curOffset == binarySize &&
87  "deserializer should never index beyond the binary end");
88 
89  for (auto &deferred : deferredInstructions) {
90  if (failed(processInstruction(deferred.first, deferred.second, false))) {
91  return failure();
92  }
93  }
94 
95  attachVCETriple();
96 
97  LLVM_DEBUG(logger.startLine()
98  << "//+++-------- completed deserialization --------+++//\n");
99  return success();
100 }
101 
103  return std::move(module);
104 }
105 
106 //===----------------------------------------------------------------------===//
107 // Module structure
108 //===----------------------------------------------------------------------===//
109 
110 OwningOpRef<spirv::ModuleOp> spirv::Deserializer::createModuleOp() {
111  OpBuilder builder(context);
112  OperationState state(unknownLoc, spirv::ModuleOp::getOperationName());
113  spirv::ModuleOp::build(builder, state);
114  return cast<spirv::ModuleOp>(Operation::create(state));
115 }
116 
117 LogicalResult spirv::Deserializer::processHeader() {
118  if (binary.size() < spirv::kHeaderWordCount)
119  return emitError(unknownLoc,
120  "SPIR-V binary module must have a 5-word header");
121 
122  if (binary[0] != spirv::kMagicNumber)
123  return emitError(unknownLoc, "incorrect magic number");
124 
125  // Version number bytes: 0 | major number | minor number | 0
126  uint32_t majorVersion = (binary[1] << 8) >> 24;
127  uint32_t minorVersion = (binary[1] << 16) >> 24;
128  if (majorVersion == 1) {
129  switch (minorVersion) {
130 #define MIN_VERSION_CASE(v) \
131  case v: \
132  version = spirv::Version::V_1_##v; \
133  break
134 
135  MIN_VERSION_CASE(0);
136  MIN_VERSION_CASE(1);
137  MIN_VERSION_CASE(2);
138  MIN_VERSION_CASE(3);
139  MIN_VERSION_CASE(4);
140  MIN_VERSION_CASE(5);
141 #undef MIN_VERSION_CASE
142  default:
143  return emitError(unknownLoc, "unsupported SPIR-V minor version: ")
144  << minorVersion;
145  }
146  } else {
147  return emitError(unknownLoc, "unsupported SPIR-V major version: ")
148  << majorVersion;
149  }
150 
151  // TODO: generator number, bound, schema
152  curOffset = spirv::kHeaderWordCount;
153  return success();
154 }
155 
157 spirv::Deserializer::processCapability(ArrayRef<uint32_t> operands) {
158  if (operands.size() != 1)
159  return emitError(unknownLoc, "OpMemoryModel must have one parameter");
160 
161  auto cap = spirv::symbolizeCapability(operands[0]);
162  if (!cap)
163  return emitError(unknownLoc, "unknown capability: ") << operands[0];
164 
165  capabilities.insert(*cap);
166  return success();
167 }
168 
169 LogicalResult spirv::Deserializer::processExtension(ArrayRef<uint32_t> words) {
170  if (words.empty()) {
171  return emitError(
172  unknownLoc,
173  "OpExtension must have a literal string for the extension name");
174  }
175 
176  unsigned wordIndex = 0;
177  StringRef extName = decodeStringLiteral(words, wordIndex);
178  if (wordIndex != words.size())
179  return emitError(unknownLoc,
180  "unexpected trailing words in OpExtension instruction");
181  auto ext = spirv::symbolizeExtension(extName);
182  if (!ext)
183  return emitError(unknownLoc, "unknown extension: ") << extName;
184 
185  extensions.insert(*ext);
186  return success();
187 }
188 
190 spirv::Deserializer::processExtInstImport(ArrayRef<uint32_t> words) {
191  if (words.size() < 2) {
192  return emitError(unknownLoc,
193  "OpExtInstImport must have a result <id> and a literal "
194  "string for the extended instruction set name");
195  }
196 
197  unsigned wordIndex = 1;
198  extendedInstSets[words[0]] = decodeStringLiteral(words, wordIndex);
199  if (wordIndex != words.size()) {
200  return emitError(unknownLoc,
201  "unexpected trailing words in OpExtInstImport");
202  }
203  return success();
204 }
205 
206 void spirv::Deserializer::attachVCETriple() {
207  (*module)->setAttr(
208  spirv::ModuleOp::getVCETripleAttrName(),
209  spirv::VerCapExtAttr::get(version, capabilities.getArrayRef(),
210  extensions.getArrayRef(), context));
211 }
212 
214 spirv::Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
215  if (operands.size() != 2)
216  return emitError(unknownLoc, "OpMemoryModel must have two operands");
217 
218  (*module)->setAttr(
219  module->getAddressingModelAttrName(),
220  opBuilder.getAttr<spirv::AddressingModelAttr>(
221  static_cast<spirv::AddressingModel>(operands.front())));
222 
223  (*module)->setAttr(module->getMemoryModelAttrName(),
224  opBuilder.getAttr<spirv::MemoryModelAttr>(
225  static_cast<spirv::MemoryModel>(operands.back())));
226 
227  return success();
228 }
229 
230 LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
231  // TODO: This function should also be auto-generated. For now, since only a
232  // few decorations are processed/handled in a meaningful manner, going with a
233  // manual implementation.
234  if (words.size() < 2) {
235  return emitError(
236  unknownLoc, "OpDecorate must have at least result <id> and Decoration");
237  }
238  auto decorationName =
239  stringifyDecoration(static_cast<spirv::Decoration>(words[1]));
240  if (decorationName.empty()) {
241  return emitError(unknownLoc, "invalid Decoration code : ") << words[1];
242  }
243  auto symbol = getSymbolDecoration(decorationName);
244  switch (static_cast<spirv::Decoration>(words[1])) {
245  case spirv::Decoration::FPFastMathMode:
246  if (words.size() != 3) {
247  return emitError(unknownLoc, "OpDecorate with ")
248  << decorationName << " needs a single integer literal";
249  }
250  decorations[words[0]].set(
251  symbol, FPFastMathModeAttr::get(opBuilder.getContext(),
252  static_cast<FPFastMathMode>(words[2])));
253  break;
254  case spirv::Decoration::DescriptorSet:
255  case spirv::Decoration::Binding:
256  if (words.size() != 3) {
257  return emitError(unknownLoc, "OpDecorate with ")
258  << decorationName << " needs a single integer literal";
259  }
260  decorations[words[0]].set(
261  symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));
262  break;
263  case spirv::Decoration::BuiltIn:
264  if (words.size() != 3) {
265  return emitError(unknownLoc, "OpDecorate with ")
266  << decorationName << " needs a single integer literal";
267  }
268  decorations[words[0]].set(
269  symbol, opBuilder.getStringAttr(
270  stringifyBuiltIn(static_cast<spirv::BuiltIn>(words[2]))));
271  break;
272  case spirv::Decoration::ArrayStride:
273  if (words.size() != 3) {
274  return emitError(unknownLoc, "OpDecorate with ")
275  << decorationName << " needs a single integer literal";
276  }
277  typeDecorations[words[0]] = words[2];
278  break;
279  case spirv::Decoration::LinkageAttributes: {
280  if (words.size() < 4) {
281  return emitError(unknownLoc, "OpDecorate with ")
282  << decorationName
283  << " needs at least 1 string and 1 integer literal";
284  }
285  // LinkageAttributes has two parameters ["linkageName", linkageType]
286  // e.g., OpDecorate %imported_func LinkageAttributes "outside.func" Import
287  // "linkageName" is a stringliteral encoded as uint32_t,
288  // hence the size of name is variable length which results in words.size()
289  // being variable length, words.size() = 3 + strlen(name)/4 + 1 or
290  // 3 + ceildiv(strlen(name), 4).
291  unsigned wordIndex = 2;
292  auto linkageName = spirv::decodeStringLiteral(words, wordIndex).str();
293  auto linkageTypeAttr = opBuilder.getAttr<::mlir::spirv::LinkageTypeAttr>(
294  static_cast<::mlir::spirv::LinkageType>(words[wordIndex++]));
295  auto linkageAttr = opBuilder.getAttr<::mlir::spirv::LinkageAttributesAttr>(
296  StringAttr::get(context, linkageName), linkageTypeAttr);
297  decorations[words[0]].set(symbol, llvm::dyn_cast<Attribute>(linkageAttr));
298  break;
299  }
300  case spirv::Decoration::Aliased:
301  case spirv::Decoration::AliasedPointer:
302  case spirv::Decoration::Block:
303  case spirv::Decoration::BufferBlock:
304  case spirv::Decoration::Flat:
305  case spirv::Decoration::NonReadable:
306  case spirv::Decoration::NonWritable:
307  case spirv::Decoration::NoPerspective:
308  case spirv::Decoration::NoSignedWrap:
309  case spirv::Decoration::NoUnsignedWrap:
310  case spirv::Decoration::RelaxedPrecision:
311  case spirv::Decoration::Restrict:
312  case spirv::Decoration::RestrictPointer:
313  case spirv::Decoration::NoContraction:
314  if (words.size() != 2) {
315  return emitError(unknownLoc, "OpDecoration with ")
316  << decorationName << "needs a single target <id>";
317  }
318  // Block decoration does not affect spirv.struct type, but is still stored
319  // for verification.
320  // TODO: Update StructType to contain this information since
321  // it is needed for many validation rules.
322  decorations[words[0]].set(symbol, opBuilder.getUnitAttr());
323  break;
324  case spirv::Decoration::Location:
325  case spirv::Decoration::SpecId:
326  if (words.size() != 3) {
327  return emitError(unknownLoc, "OpDecoration with ")
328  << decorationName << "needs a single integer literal";
329  }
330  decorations[words[0]].set(
331  symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));
332  break;
333  default:
334  return emitError(unknownLoc, "unhandled Decoration : '") << decorationName;
335  }
336  return success();
337 }
338 
340 spirv::Deserializer::processMemberDecoration(ArrayRef<uint32_t> words) {
341  // The binary layout of OpMemberDecorate is different comparing to OpDecorate
342  if (words.size() < 3) {
343  return emitError(unknownLoc,
344  "OpMemberDecorate must have at least 3 operands");
345  }
346 
347  auto decoration = static_cast<spirv::Decoration>(words[2]);
348  if (decoration == spirv::Decoration::Offset && words.size() != 4) {
349  return emitError(unknownLoc,
350  " missing offset specification in OpMemberDecorate with "
351  "Offset decoration");
352  }
353  ArrayRef<uint32_t> decorationOperands;
354  if (words.size() > 3) {
355  decorationOperands = words.slice(3);
356  }
357  memberDecorationMap[words[0]][words[1]][decoration] = decorationOperands;
358  return success();
359 }
360 
361 LogicalResult spirv::Deserializer::processMemberName(ArrayRef<uint32_t> words) {
362  if (words.size() < 3) {
363  return emitError(unknownLoc, "OpMemberName must have at least 3 operands");
364  }
365  unsigned wordIndex = 2;
366  auto name = decodeStringLiteral(words, wordIndex);
367  if (wordIndex != words.size()) {
368  return emitError(unknownLoc,
369  "unexpected trailing words in OpMemberName instruction");
370  }
371  memberNameMap[words[0]][words[1]] = name;
372  return success();
373 }
374 
375 LogicalResult spirv::Deserializer::setFunctionArgAttrs(
376  uint32_t argID, SmallVectorImpl<Attribute> &argAttrs, size_t argIndex) {
377  if (!decorations.contains(argID)) {
378  argAttrs[argIndex] = DictionaryAttr::get(context, {});
379  return success();
380  }
381 
382  spirv::DecorationAttr foundDecorationAttr;
383  for (NamedAttribute decAttr : decorations[argID]) {
384  for (auto decoration :
385  {spirv::Decoration::Aliased, spirv::Decoration::Restrict,
386  spirv::Decoration::AliasedPointer,
387  spirv::Decoration::RestrictPointer}) {
388 
389  if (decAttr.getName() !=
390  getSymbolDecoration(stringifyDecoration(decoration)))
391  continue;
392 
393  if (foundDecorationAttr)
394  return emitError(unknownLoc,
395  "more than one Aliased/Restrict decorations for "
396  "function argument with result <id> ")
397  << argID;
398 
399  foundDecorationAttr = spirv::DecorationAttr::get(context, decoration);
400  break;
401  }
402  }
403 
404  if (!foundDecorationAttr)
405  return emitError(unknownLoc, "unimplemented decoration support for "
406  "function argument with result <id> ")
407  << argID;
408 
409  NamedAttribute attr(StringAttr::get(context, spirv::DecorationAttr::name),
410  foundDecorationAttr);
411  argAttrs[argIndex] = DictionaryAttr::get(context, attr);
412  return success();
413 }
414 
416 spirv::Deserializer::processFunction(ArrayRef<uint32_t> operands) {
417  if (curFunction) {
418  return emitError(unknownLoc, "found function inside function");
419  }
420 
421  // Get the result type
422  if (operands.size() != 4) {
423  return emitError(unknownLoc, "OpFunction must have 4 parameters");
424  }
425  Type resultType = getType(operands[0]);
426  if (!resultType) {
427  return emitError(unknownLoc, "undefined result type from <id> ")
428  << operands[0];
429  }
430 
431  uint32_t fnID = operands[1];
432  if (funcMap.count(fnID)) {
433  return emitError(unknownLoc, "duplicate function definition/declaration");
434  }
435 
436  auto fnControl = spirv::symbolizeFunctionControl(operands[2]);
437  if (!fnControl) {
438  return emitError(unknownLoc, "unknown Function Control: ") << operands[2];
439  }
440 
441  Type fnType = getType(operands[3]);
442  if (!fnType || !isa<FunctionType>(fnType)) {
443  return emitError(unknownLoc, "unknown function type from <id> ")
444  << operands[3];
445  }
446  auto functionType = cast<FunctionType>(fnType);
447 
448  if ((isVoidType(resultType) && functionType.getNumResults() != 0) ||
449  (functionType.getNumResults() == 1 &&
450  functionType.getResult(0) != resultType)) {
451  return emitError(unknownLoc, "mismatch in function type ")
452  << functionType << " and return type " << resultType << " specified";
453  }
454 
455  std::string fnName = getFunctionSymbol(fnID);
456  auto funcOp = opBuilder.create<spirv::FuncOp>(
457  unknownLoc, fnName, functionType, fnControl.value());
458  // Processing other function attributes.
459  if (decorations.count(fnID)) {
460  for (auto attr : decorations[fnID].getAttrs()) {
461  funcOp->setAttr(attr.getName(), attr.getValue());
462  }
463  }
464  curFunction = funcMap[fnID] = funcOp;
465  auto *entryBlock = funcOp.addEntryBlock();
466  LLVM_DEBUG({
467  logger.startLine()
468  << "//===-------------------------------------------===//\n";
469  logger.startLine() << "[fn] name: " << fnName << "\n";
470  logger.startLine() << "[fn] type: " << fnType << "\n";
471  logger.startLine() << "[fn] ID: " << fnID << "\n";
472  logger.startLine() << "[fn] entry block: " << entryBlock << "\n";
473  logger.indent();
474  });
475 
476  SmallVector<Attribute> argAttrs;
477  argAttrs.resize(functionType.getNumInputs());
478 
479  // Parse the op argument instructions
480  if (functionType.getNumInputs()) {
481  for (size_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
482  auto argType = functionType.getInput(i);
483  spirv::Opcode opcode = spirv::Opcode::OpNop;
484  ArrayRef<uint32_t> operands;
485  if (failed(sliceInstruction(opcode, operands,
486  spirv::Opcode::OpFunctionParameter))) {
487  return failure();
488  }
489  if (opcode != spirv::Opcode::OpFunctionParameter) {
490  return emitError(
491  unknownLoc,
492  "missing OpFunctionParameter instruction for argument ")
493  << i;
494  }
495  if (operands.size() != 2) {
496  return emitError(
497  unknownLoc,
498  "expected result type and result <id> for OpFunctionParameter");
499  }
500  auto argDefinedType = getType(operands[0]);
501  if (!argDefinedType || argDefinedType != argType) {
502  return emitError(unknownLoc,
503  "mismatch in argument type between function type "
504  "definition ")
505  << functionType << " and argument type definition "
506  << argDefinedType << " at argument " << i;
507  }
508  if (getValue(operands[1])) {
509  return emitError(unknownLoc, "duplicate definition of result <id> ")
510  << operands[1];
511  }
512  if (failed(setFunctionArgAttrs(operands[1], argAttrs, i))) {
513  return failure();
514  }
515 
516  auto argValue = funcOp.getArgument(i);
517  valueMap[operands[1]] = argValue;
518  }
519  }
520 
521  if (llvm::any_of(argAttrs, [](Attribute attr) {
522  auto argAttr = cast<DictionaryAttr>(attr);
523  return !argAttr.empty();
524  }))
525  funcOp.setArgAttrsAttr(ArrayAttr::get(context, argAttrs));
526 
527  // entryBlock is needed to access the arguments, Once that is done, we can
528  // erase the block for functions with 'Import' LinkageAttributes, since these
529  // are essentially function declarations, so they have no body.
530  auto linkageAttr = funcOp.getLinkageAttributes();
531  auto hasImportLinkage =
532  linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
533  spirv::LinkageType::Import);
534  if (hasImportLinkage)
535  funcOp.eraseBody();
536 
537  // RAII guard to reset the insertion point to the module's region after
538  // deserializing the body of this function.
539  OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
540 
541  spirv::Opcode opcode = spirv::Opcode::OpNop;
542  ArrayRef<uint32_t> instOperands;
543 
544  // Special handling for the entry block. We need to make sure it starts with
545  // an OpLabel instruction. The entry block takes the same parameters as the
546  // function. All other blocks do not take any parameter. We have already
547  // created the entry block, here we need to register it to the correct label
548  // <id>.
549  if (failed(sliceInstruction(opcode, instOperands,
550  spirv::Opcode::OpFunctionEnd))) {
551  return failure();
552  }
553  if (opcode == spirv::Opcode::OpFunctionEnd) {
554  return processFunctionEnd(instOperands);
555  }
556  if (opcode != spirv::Opcode::OpLabel) {
557  return emitError(unknownLoc, "a basic block must start with OpLabel");
558  }
559  if (instOperands.size() != 1) {
560  return emitError(unknownLoc, "OpLabel should only have result <id>");
561  }
562  blockMap[instOperands[0]] = entryBlock;
563  if (failed(processLabel(instOperands))) {
564  return failure();
565  }
566 
567  // Then process all the other instructions in the function until we hit
568  // OpFunctionEnd.
569  while (succeeded(sliceInstruction(opcode, instOperands,
570  spirv::Opcode::OpFunctionEnd)) &&
571  opcode != spirv::Opcode::OpFunctionEnd) {
572  if (failed(processInstruction(opcode, instOperands))) {
573  return failure();
574  }
575  }
576  if (opcode != spirv::Opcode::OpFunctionEnd) {
577  return failure();
578  }
579 
580  return processFunctionEnd(instOperands);
581 }
582 
584 spirv::Deserializer::processFunctionEnd(ArrayRef<uint32_t> operands) {
585  // Process OpFunctionEnd.
586  if (!operands.empty()) {
587  return emitError(unknownLoc, "unexpected operands for OpFunctionEnd");
588  }
589 
590  // Wire up block arguments from OpPhi instructions.
591  // Put all structured control flow in spirv.mlir.selection/spirv.mlir.loop
592  // ops.
593  if (failed(wireUpBlockArgument()) || failed(structurizeControlFlow())) {
594  return failure();
595  }
596 
597  curBlock = nullptr;
598  curFunction = std::nullopt;
599 
600  LLVM_DEBUG({
601  logger.unindent();
602  logger.startLine()
603  << "//===-------------------------------------------===//\n";
604  });
605  return success();
606 }
607 
608 std::optional<std::pair<Attribute, Type>>
609 spirv::Deserializer::getConstant(uint32_t id) {
610  auto constIt = constantMap.find(id);
611  if (constIt == constantMap.end())
612  return std::nullopt;
613  return constIt->getSecond();
614 }
615 
616 std::optional<spirv::SpecConstOperationMaterializationInfo>
617 spirv::Deserializer::getSpecConstantOperation(uint32_t id) {
618  auto constIt = specConstOperationMap.find(id);
619  if (constIt == specConstOperationMap.end())
620  return std::nullopt;
621  return constIt->getSecond();
622 }
623 
624 std::string spirv::Deserializer::getFunctionSymbol(uint32_t id) {
625  auto funcName = nameMap.lookup(id).str();
626  if (funcName.empty()) {
627  funcName = "spirv_fn_" + std::to_string(id);
628  }
629  return funcName;
630 }
631 
632 std::string spirv::Deserializer::getSpecConstantSymbol(uint32_t id) {
633  auto constName = nameMap.lookup(id).str();
634  if (constName.empty()) {
635  constName = "spirv_spec_const_" + std::to_string(id);
636  }
637  return constName;
638 }
639 
640 spirv::SpecConstantOp
641 spirv::Deserializer::createSpecConstant(Location loc, uint32_t resultID,
642  TypedAttr defaultValue) {
643  auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
644  auto op = opBuilder.create<spirv::SpecConstantOp>(unknownLoc, symName,
645  defaultValue);
646  if (decorations.count(resultID)) {
647  for (auto attr : decorations[resultID].getAttrs())
648  op->setAttr(attr.getName(), attr.getValue());
649  }
650  specConstMap[resultID] = op;
651  return op;
652 }
653 
655 spirv::Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
656  unsigned wordIndex = 0;
657  if (operands.size() < 3) {
658  return emitError(
659  unknownLoc,
660  "OpVariable needs at least 3 operands, type, <id> and storage class");
661  }
662 
663  // Result Type.
664  auto type = getType(operands[wordIndex]);
665  if (!type) {
666  return emitError(unknownLoc, "unknown result type <id> : ")
667  << operands[wordIndex];
668  }
669  auto ptrType = dyn_cast<spirv::PointerType>(type);
670  if (!ptrType) {
671  return emitError(unknownLoc,
672  "expected a result type <id> to be a spirv.ptr, found : ")
673  << type;
674  }
675  wordIndex++;
676 
677  // Result <id>.
678  auto variableID = operands[wordIndex];
679  auto variableName = nameMap.lookup(variableID).str();
680  if (variableName.empty()) {
681  variableName = "spirv_var_" + std::to_string(variableID);
682  }
683  wordIndex++;
684 
685  // Storage class.
686  auto storageClass = static_cast<spirv::StorageClass>(operands[wordIndex]);
687  if (ptrType.getStorageClass() != storageClass) {
688  return emitError(unknownLoc, "mismatch in storage class of pointer type ")
689  << type << " and that specified in OpVariable instruction : "
690  << stringifyStorageClass(storageClass);
691  }
692  wordIndex++;
693 
694  // Initializer.
695  FlatSymbolRefAttr initializer = nullptr;
696 
697  if (wordIndex < operands.size()) {
698  Operation *op = nullptr;
699 
700  if (auto initOp = getGlobalVariable(operands[wordIndex]))
701  op = initOp;
702  else if (auto initOp = getSpecConstant(operands[wordIndex]))
703  op = initOp;
704  else if (auto initOp = getSpecConstantComposite(operands[wordIndex]))
705  op = initOp;
706  else
707  return emitError(unknownLoc, "unknown <id> ")
708  << operands[wordIndex] << "used as initializer";
709 
710  initializer = SymbolRefAttr::get(op);
711  wordIndex++;
712  }
713  if (wordIndex != operands.size()) {
714  return emitError(unknownLoc,
715  "found more operands than expected when deserializing "
716  "OpVariable instruction, only ")
717  << wordIndex << " of " << operands.size() << " processed";
718  }
719  auto loc = createFileLineColLoc(opBuilder);
720  auto varOp = opBuilder.create<spirv::GlobalVariableOp>(
721  loc, TypeAttr::get(type), opBuilder.getStringAttr(variableName),
722  initializer);
723 
724  // Decorations.
725  if (decorations.count(variableID)) {
726  for (auto attr : decorations[variableID].getAttrs())
727  varOp->setAttr(attr.getName(), attr.getValue());
728  }
729  globalVariableMap[variableID] = varOp;
730  return success();
731 }
732 
733 IntegerAttr spirv::Deserializer::getConstantInt(uint32_t id) {
734  auto constInfo = getConstant(id);
735  if (!constInfo) {
736  return nullptr;
737  }
738  return dyn_cast<IntegerAttr>(constInfo->first);
739 }
740 
741 LogicalResult spirv::Deserializer::processName(ArrayRef<uint32_t> operands) {
742  if (operands.size() < 2) {
743  return emitError(unknownLoc, "OpName needs at least 2 operands");
744  }
745  if (!nameMap.lookup(operands[0]).empty()) {
746  return emitError(unknownLoc, "duplicate name found for result <id> ")
747  << operands[0];
748  }
749  unsigned wordIndex = 1;
750  StringRef name = decodeStringLiteral(operands, wordIndex);
751  if (wordIndex != operands.size()) {
752  return emitError(unknownLoc,
753  "unexpected trailing words in OpName instruction");
754  }
755  nameMap[operands[0]] = name;
756  return success();
757 }
758 
759 //===----------------------------------------------------------------------===//
760 // Type
761 //===----------------------------------------------------------------------===//
762 
763 LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
764  ArrayRef<uint32_t> operands) {
765  if (operands.empty()) {
766  return emitError(unknownLoc, "type instruction with opcode ")
767  << spirv::stringifyOpcode(opcode) << " needs at least one <id>";
768  }
769 
770  /// TODO: Types might be forward declared in some instructions and need to be
771  /// handled appropriately.
772  if (typeMap.count(operands[0])) {
773  return emitError(unknownLoc, "duplicate definition for result <id> ")
774  << operands[0];
775  }
776 
777  switch (opcode) {
778  case spirv::Opcode::OpTypeVoid:
779  if (operands.size() != 1)
780  return emitError(unknownLoc, "OpTypeVoid must have no parameters");
781  typeMap[operands[0]] = opBuilder.getNoneType();
782  break;
783  case spirv::Opcode::OpTypeBool:
784  if (operands.size() != 1)
785  return emitError(unknownLoc, "OpTypeBool must have no parameters");
786  typeMap[operands[0]] = opBuilder.getI1Type();
787  break;
788  case spirv::Opcode::OpTypeInt: {
789  if (operands.size() != 3)
790  return emitError(
791  unknownLoc, "OpTypeInt must have bitwidth and signedness parameters");
792 
793  // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics
794  // to preserve or validate.
795  // 0 indicates unsigned, or no signedness semantics
796  // 1 indicates signed semantics."
797  //
798  // So we cannot differentiate signless and unsigned integers; always use
799  // signless semantics for such cases.
800  auto sign = operands[2] == 1 ? IntegerType::SignednessSemantics::Signed
801  : IntegerType::SignednessSemantics::Signless;
802  typeMap[operands[0]] = IntegerType::get(context, operands[1], sign);
803  } break;
804  case spirv::Opcode::OpTypeFloat: {
805  if (operands.size() != 2)
806  return emitError(unknownLoc, "OpTypeFloat must have bitwidth parameter");
807 
808  Type floatTy;
809  switch (operands[1]) {
810  case 16:
811  floatTy = opBuilder.getF16Type();
812  break;
813  case 32:
814  floatTy = opBuilder.getF32Type();
815  break;
816  case 64:
817  floatTy = opBuilder.getF64Type();
818  break;
819  default:
820  return emitError(unknownLoc, "unsupported OpTypeFloat bitwidth: ")
821  << operands[1];
822  }
823  typeMap[operands[0]] = floatTy;
824  } break;
825  case spirv::Opcode::OpTypeVector: {
826  if (operands.size() != 3) {
827  return emitError(
828  unknownLoc,
829  "OpTypeVector must have element type and count parameters");
830  }
831  Type elementTy = getType(operands[1]);
832  if (!elementTy) {
833  return emitError(unknownLoc, "OpTypeVector references undefined <id> ")
834  << operands[1];
835  }
836  typeMap[operands[0]] = VectorType::get({operands[2]}, elementTy);
837  } break;
838  case spirv::Opcode::OpTypePointer: {
839  return processOpTypePointer(operands);
840  } break;
841  case spirv::Opcode::OpTypeArray:
842  return processArrayType(operands);
843  case spirv::Opcode::OpTypeCooperativeMatrixKHR:
844  return processCooperativeMatrixTypeKHR(operands);
845  case spirv::Opcode::OpTypeFunction:
846  return processFunctionType(operands);
847  case spirv::Opcode::OpTypeJointMatrixINTEL:
848  return processJointMatrixType(operands);
849  case spirv::Opcode::OpTypeImage:
850  return processImageType(operands);
851  case spirv::Opcode::OpTypeSampledImage:
852  return processSampledImageType(operands);
853  case spirv::Opcode::OpTypeRuntimeArray:
854  return processRuntimeArrayType(operands);
855  case spirv::Opcode::OpTypeStruct:
856  return processStructType(operands);
857  case spirv::Opcode::OpTypeMatrix:
858  return processMatrixType(operands);
859  default:
860  return emitError(unknownLoc, "unhandled type instruction");
861  }
862  return success();
863 }
864 
866 spirv::Deserializer::processOpTypePointer(ArrayRef<uint32_t> operands) {
867  if (operands.size() != 3)
868  return emitError(unknownLoc, "OpTypePointer must have two parameters");
869 
870  auto pointeeType = getType(operands[2]);
871  if (!pointeeType)
872  return emitError(unknownLoc, "unknown OpTypePointer pointee type <id> ")
873  << operands[2];
874 
875  uint32_t typePointerID = operands[0];
876  auto storageClass = static_cast<spirv::StorageClass>(operands[1]);
877  typeMap[typePointerID] = spirv::PointerType::get(pointeeType, storageClass);
878 
879  for (auto *deferredStructIt = std::begin(deferredStructTypesInfos);
880  deferredStructIt != std::end(deferredStructTypesInfos);) {
881  for (auto *unresolvedMemberIt =
882  std::begin(deferredStructIt->unresolvedMemberTypes);
883  unresolvedMemberIt !=
884  std::end(deferredStructIt->unresolvedMemberTypes);) {
885  if (unresolvedMemberIt->first == typePointerID) {
886  // The newly constructed pointer type can resolve one of the
887  // deferred struct type members; update the memberTypes list and
888  // clean the unresolvedMemberTypes list accordingly.
889  deferredStructIt->memberTypes[unresolvedMemberIt->second] =
890  typeMap[typePointerID];
891  unresolvedMemberIt =
892  deferredStructIt->unresolvedMemberTypes.erase(unresolvedMemberIt);
893  } else {
894  ++unresolvedMemberIt;
895  }
896  }
897 
898  if (deferredStructIt->unresolvedMemberTypes.empty()) {
899  // All deferred struct type members are now resolved, set the struct body.
900  auto structType = deferredStructIt->deferredStructType;
901 
902  assert(structType && "expected a spirv::StructType");
903  assert(structType.isIdentified() && "expected an indentified struct");
904 
905  if (failed(structType.trySetBody(
906  deferredStructIt->memberTypes, deferredStructIt->offsetInfo,
907  deferredStructIt->memberDecorationsInfo)))
908  return failure();
909 
910  deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt);
911  } else {
912  ++deferredStructIt;
913  }
914  }
915 
916  return success();
917 }
918 
920 spirv::Deserializer::processArrayType(ArrayRef<uint32_t> operands) {
921  if (operands.size() != 3) {
922  return emitError(unknownLoc,
923  "OpTypeArray must have element type and count parameters");
924  }
925 
926  Type elementTy = getType(operands[1]);
927  if (!elementTy) {
928  return emitError(unknownLoc, "OpTypeArray references undefined <id> ")
929  << operands[1];
930  }
931 
932  unsigned count = 0;
933  // TODO: The count can also come frome a specialization constant.
934  auto countInfo = getConstant(operands[2]);
935  if (!countInfo) {
936  return emitError(unknownLoc, "OpTypeArray count <id> ")
937  << operands[2] << "can only come from normal constant right now";
938  }
939 
940  if (auto intVal = dyn_cast<IntegerAttr>(countInfo->first)) {
941  count = intVal.getValue().getZExtValue();
942  } else {
943  return emitError(unknownLoc, "OpTypeArray count must come from a "
944  "scalar integer constant instruction");
945  }
946 
947  typeMap[operands[0]] = spirv::ArrayType::get(
948  elementTy, count, typeDecorations.lookup(operands[0]));
949  return success();
950 }
951 
953 spirv::Deserializer::processFunctionType(ArrayRef<uint32_t> operands) {
954  assert(!operands.empty() && "No operands for processing function type");
955  if (operands.size() == 1) {
956  return emitError(unknownLoc, "missing return type for OpTypeFunction");
957  }
958  auto returnType = getType(operands[1]);
959  if (!returnType) {
960  return emitError(unknownLoc, "unknown return type in OpTypeFunction");
961  }
962  SmallVector<Type, 1> argTypes;
963  for (size_t i = 2, e = operands.size(); i < e; ++i) {
964  auto ty = getType(operands[i]);
965  if (!ty) {
966  return emitError(unknownLoc, "unknown argument type in OpTypeFunction");
967  }
968  argTypes.push_back(ty);
969  }
970  ArrayRef<Type> returnTypes;
971  if (!isVoidType(returnType)) {
972  returnTypes = llvm::ArrayRef(returnType);
973  }
974  typeMap[operands[0]] = FunctionType::get(context, argTypes, returnTypes);
975  return success();
976 }
977 
978 LogicalResult spirv::Deserializer::processCooperativeMatrixTypeKHR(
979  ArrayRef<uint32_t> operands) {
980  if (operands.size() != 6) {
981  return emitError(unknownLoc,
982  "OpTypeCooperativeMatrixKHR must have element type, "
983  "scope, row and column parameters, and use");
984  }
985 
986  Type elementTy = getType(operands[1]);
987  if (!elementTy) {
988  return emitError(unknownLoc,
989  "OpTypeCooperativeMatrixKHR references undefined <id> ")
990  << operands[1];
991  }
992 
993  std::optional<spirv::Scope> scope =
994  spirv::symbolizeScope(getConstantInt(operands[2]).getInt());
995  if (!scope) {
996  return emitError(
997  unknownLoc,
998  "OpTypeCooperativeMatrixKHR references undefined scope <id> ")
999  << operands[2];
1000  }
1001 
1002  unsigned rows = getConstantInt(operands[3]).getInt();
1003  unsigned columns = getConstantInt(operands[4]).getInt();
1004 
1005  std::optional<spirv::CooperativeMatrixUseKHR> use =
1006  spirv::symbolizeCooperativeMatrixUseKHR(
1007  getConstantInt(operands[5]).getInt());
1008  if (!use) {
1009  return emitError(
1010  unknownLoc,
1011  "OpTypeCooperativeMatrixKHR references undefined use <id> ")
1012  << operands[5];
1013  }
1014 
1015  typeMap[operands[0]] =
1016  spirv::CooperativeMatrixType::get(elementTy, rows, columns, *scope, *use);
1017  return success();
1018 }
1019 
1021 spirv::Deserializer::processJointMatrixType(ArrayRef<uint32_t> operands) {
1022  if (operands.size() != 6) {
1023  return emitError(unknownLoc, "OpTypeJointMatrix must have element "
1024  "type and row x column parameters");
1025  }
1026 
1027  Type elementTy = getType(operands[1]);
1028  if (!elementTy) {
1029  return emitError(unknownLoc, "OpTypeJointMatrix references undefined <id> ")
1030  << operands[1];
1031  }
1032 
1033  auto scope = spirv::symbolizeScope(getConstantInt(operands[5]).getInt());
1034  if (!scope) {
1035  return emitError(unknownLoc,
1036  "OpTypeJointMatrix references undefined scope <id> ")
1037  << operands[5];
1038  }
1039  auto matrixLayout =
1040  spirv::symbolizeMatrixLayout(getConstantInt(operands[4]).getInt());
1041  if (!matrixLayout) {
1042  return emitError(unknownLoc,
1043  "OpTypeJointMatrix references undefined scope <id> ")
1044  << operands[4];
1045  }
1046  unsigned rows = getConstantInt(operands[2]).getInt();
1047  unsigned columns = getConstantInt(operands[3]).getInt();
1048 
1049  typeMap[operands[0]] = spirv::JointMatrixINTELType::get(
1050  elementTy, scope.value(), rows, columns, matrixLayout.value());
1051  return success();
1052 }
1053 
1055 spirv::Deserializer::processRuntimeArrayType(ArrayRef<uint32_t> operands) {
1056  if (operands.size() != 2) {
1057  return emitError(unknownLoc, "OpTypeRuntimeArray must have two operands");
1058  }
1059  Type memberType = getType(operands[1]);
1060  if (!memberType) {
1061  return emitError(unknownLoc,
1062  "OpTypeRuntimeArray references undefined <id> ")
1063  << operands[1];
1064  }
1065  typeMap[operands[0]] = spirv::RuntimeArrayType::get(
1066  memberType, typeDecorations.lookup(operands[0]));
1067  return success();
1068 }
1069 
1071 spirv::Deserializer::processStructType(ArrayRef<uint32_t> operands) {
1072  // TODO: Find a way to handle identified structs when debug info is stripped.
1073 
1074  if (operands.empty()) {
1075  return emitError(unknownLoc, "OpTypeStruct must have at least result <id>");
1076  }
1077 
1078  if (operands.size() == 1) {
1079  // Handle empty struct.
1080  typeMap[operands[0]] =
1081  spirv::StructType::getEmpty(context, nameMap.lookup(operands[0]).str());
1082  return success();
1083  }
1084 
1085  // First element is operand ID, second element is member index in the struct.
1086  SmallVector<std::pair<uint32_t, unsigned>, 0> unresolvedMemberTypes;
1087  SmallVector<Type, 4> memberTypes;
1088 
1089  for (auto op : llvm::drop_begin(operands, 1)) {
1090  Type memberType = getType(op);
1091  bool typeForwardPtr = (typeForwardPointerIDs.count(op) != 0);
1092 
1093  if (!memberType && !typeForwardPtr)
1094  return emitError(unknownLoc, "OpTypeStruct references undefined <id> ")
1095  << op;
1096 
1097  if (!memberType)
1098  unresolvedMemberTypes.emplace_back(op, memberTypes.size());
1099 
1100  memberTypes.push_back(memberType);
1101  }
1102 
1105  if (memberDecorationMap.count(operands[0])) {
1106  auto &allMemberDecorations = memberDecorationMap[operands[0]];
1107  for (auto memberIndex : llvm::seq<uint32_t>(0, memberTypes.size())) {
1108  if (allMemberDecorations.count(memberIndex)) {
1109  for (auto &memberDecoration : allMemberDecorations[memberIndex]) {
1110  // Check for offset.
1111  if (memberDecoration.first == spirv::Decoration::Offset) {
1112  // If offset info is empty, resize to the number of members;
1113  if (offsetInfo.empty()) {
1114  offsetInfo.resize(memberTypes.size());
1115  }
1116  offsetInfo[memberIndex] = memberDecoration.second[0];
1117  } else {
1118  if (!memberDecoration.second.empty()) {
1119  memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/1,
1120  memberDecoration.first,
1121  memberDecoration.second[0]);
1122  } else {
1123  memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/0,
1124  memberDecoration.first, 0);
1125  }
1126  }
1127  }
1128  }
1129  }
1130  }
1131 
1132  uint32_t structID = operands[0];
1133  std::string structIdentifier = nameMap.lookup(structID).str();
1134 
1135  if (structIdentifier.empty()) {
1136  assert(unresolvedMemberTypes.empty() &&
1137  "didn't expect unresolved member types");
1138  typeMap[structID] =
1139  spirv::StructType::get(memberTypes, offsetInfo, memberDecorationsInfo);
1140  } else {
1141  auto structTy = spirv::StructType::getIdentified(context, structIdentifier);
1142  typeMap[structID] = structTy;
1143 
1144  if (!unresolvedMemberTypes.empty())
1145  deferredStructTypesInfos.push_back({structTy, unresolvedMemberTypes,
1146  memberTypes, offsetInfo,
1147  memberDecorationsInfo});
1148  else if (failed(structTy.trySetBody(memberTypes, offsetInfo,
1149  memberDecorationsInfo)))
1150  return failure();
1151  }
1152 
1153  // TODO: Update StructType to have member name as attribute as
1154  // well.
1155  return success();
1156 }
1157 
1159 spirv::Deserializer::processMatrixType(ArrayRef<uint32_t> operands) {
1160  if (operands.size() != 3) {
1161  // Three operands are needed: result_id, column_type, and column_count
1162  return emitError(unknownLoc, "OpTypeMatrix must have 3 operands"
1163  " (result_id, column_type, and column_count)");
1164  }
1165  // Matrix columns must be of vector type
1166  Type elementTy = getType(operands[1]);
1167  if (!elementTy) {
1168  return emitError(unknownLoc,
1169  "OpTypeMatrix references undefined column type.")
1170  << operands[1];
1171  }
1172 
1173  uint32_t colsCount = operands[2];
1174  typeMap[operands[0]] = spirv::MatrixType::get(elementTy, colsCount);
1175  return success();
1176 }
1177 
1179 spirv::Deserializer::processTypeForwardPointer(ArrayRef<uint32_t> operands) {
1180  if (operands.size() != 2)
1181  return emitError(unknownLoc,
1182  "OpTypeForwardPointer instruction must have two operands");
1183 
1184  typeForwardPointerIDs.insert(operands[0]);
1185  // TODO: Use the 2nd operand (Storage Class) to validate the OpTypePointer
1186  // instruction that defines the actual type.
1187 
1188  return success();
1189 }
1190 
1192 spirv::Deserializer::processImageType(ArrayRef<uint32_t> operands) {
1193  // TODO: Add support for Access Qualifier.
1194  if (operands.size() != 8)
1195  return emitError(
1196  unknownLoc,
1197  "OpTypeImage with non-eight operands are not supported yet");
1198 
1199  Type elementTy = getType(operands[1]);
1200  if (!elementTy)
1201  return emitError(unknownLoc, "OpTypeImage references undefined <id>: ")
1202  << operands[1];
1203 
1204  auto dim = spirv::symbolizeDim(operands[2]);
1205  if (!dim)
1206  return emitError(unknownLoc, "unknown Dim for OpTypeImage: ")
1207  << operands[2];
1208 
1209  auto depthInfo = spirv::symbolizeImageDepthInfo(operands[3]);
1210  if (!depthInfo)
1211  return emitError(unknownLoc, "unknown Depth for OpTypeImage: ")
1212  << operands[3];
1213 
1214  auto arrayedInfo = spirv::symbolizeImageArrayedInfo(operands[4]);
1215  if (!arrayedInfo)
1216  return emitError(unknownLoc, "unknown Arrayed for OpTypeImage: ")
1217  << operands[4];
1218 
1219  auto samplingInfo = spirv::symbolizeImageSamplingInfo(operands[5]);
1220  if (!samplingInfo)
1221  return emitError(unknownLoc, "unknown MS for OpTypeImage: ") << operands[5];
1222 
1223  auto samplerUseInfo = spirv::symbolizeImageSamplerUseInfo(operands[6]);
1224  if (!samplerUseInfo)
1225  return emitError(unknownLoc, "unknown Sampled for OpTypeImage: ")
1226  << operands[6];
1227 
1228  auto format = spirv::symbolizeImageFormat(operands[7]);
1229  if (!format)
1230  return emitError(unknownLoc, "unknown Format for OpTypeImage: ")
1231  << operands[7];
1232 
1233  typeMap[operands[0]] = spirv::ImageType::get(
1234  elementTy, dim.value(), depthInfo.value(), arrayedInfo.value(),
1235  samplingInfo.value(), samplerUseInfo.value(), format.value());
1236  return success();
1237 }
1238 
1240 spirv::Deserializer::processSampledImageType(ArrayRef<uint32_t> operands) {
1241  if (operands.size() != 2)
1242  return emitError(unknownLoc, "OpTypeSampledImage must have two operands");
1243 
1244  Type elementTy = getType(operands[1]);
1245  if (!elementTy)
1246  return emitError(unknownLoc,
1247  "OpTypeSampledImage references undefined <id>: ")
1248  << operands[1];
1249 
1250  typeMap[operands[0]] = spirv::SampledImageType::get(elementTy);
1251  return success();
1252 }
1253 
1254 //===----------------------------------------------------------------------===//
1255 // Constant
1256 //===----------------------------------------------------------------------===//
1257 
1258 LogicalResult spirv::Deserializer::processConstant(ArrayRef<uint32_t> operands,
1259  bool isSpec) {
1260  StringRef opname = isSpec ? "OpSpecConstant" : "OpConstant";
1261 
1262  if (operands.size() < 2) {
1263  return emitError(unknownLoc)
1264  << opname << " must have type <id> and result <id>";
1265  }
1266  if (operands.size() < 3) {
1267  return emitError(unknownLoc)
1268  << opname << " must have at least 1 more parameter";
1269  }
1270 
1271  Type resultType = getType(operands[0]);
1272  if (!resultType) {
1273  return emitError(unknownLoc, "undefined result type from <id> ")
1274  << operands[0];
1275  }
1276 
1277  auto checkOperandSizeForBitwidth = [&](unsigned bitwidth) -> LogicalResult {
1278  if (bitwidth == 64) {
1279  if (operands.size() == 4) {
1280  return success();
1281  }
1282  return emitError(unknownLoc)
1283  << opname << " should have 2 parameters for 64-bit values";
1284  }
1285  if (bitwidth <= 32) {
1286  if (operands.size() == 3) {
1287  return success();
1288  }
1289 
1290  return emitError(unknownLoc)
1291  << opname
1292  << " should have 1 parameter for values with no more than 32 bits";
1293  }
1294  return emitError(unknownLoc, "unsupported OpConstant bitwidth: ")
1295  << bitwidth;
1296  };
1297 
1298  auto resultID = operands[1];
1299 
1300  if (auto intType = dyn_cast<IntegerType>(resultType)) {
1301  auto bitwidth = intType.getWidth();
1302  if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1303  return failure();
1304  }
1305 
1306  APInt value;
1307  if (bitwidth == 64) {
1308  // 64-bit integers are represented with two SPIR-V words. According to
1309  // SPIR-V spec: "When the type’s bit width is larger than one word, the
1310  // literal’s low-order words appear first."
1311  struct DoubleWord {
1312  uint32_t word1;
1313  uint32_t word2;
1314  } words = {operands[2], operands[3]};
1315  value = APInt(64, llvm::bit_cast<uint64_t>(words), /*isSigned=*/true);
1316  } else if (bitwidth <= 32) {
1317  value = APInt(bitwidth, operands[2], /*isSigned=*/true);
1318  }
1319 
1320  auto attr = opBuilder.getIntegerAttr(intType, value);
1321 
1322  if (isSpec) {
1323  createSpecConstant(unknownLoc, resultID, attr);
1324  } else {
1325  // For normal constants, we just record the attribute (and its type) for
1326  // later materialization at use sites.
1327  constantMap.try_emplace(resultID, attr, intType);
1328  }
1329 
1330  return success();
1331  }
1332 
1333  if (auto floatType = dyn_cast<FloatType>(resultType)) {
1334  auto bitwidth = floatType.getWidth();
1335  if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1336  return failure();
1337  }
1338 
1339  APFloat value(0.f);
1340  if (floatType.isF64()) {
1341  // Double values are represented with two SPIR-V words. According to
1342  // SPIR-V spec: "When the type’s bit width is larger than one word, the
1343  // literal’s low-order words appear first."
1344  struct DoubleWord {
1345  uint32_t word1;
1346  uint32_t word2;
1347  } words = {operands[2], operands[3]};
1348  value = APFloat(llvm::bit_cast<double>(words));
1349  } else if (floatType.isF32()) {
1350  value = APFloat(llvm::bit_cast<float>(operands[2]));
1351  } else if (floatType.isF16()) {
1352  APInt data(16, operands[2]);
1353  value = APFloat(APFloat::IEEEhalf(), data);
1354  }
1355 
1356  auto attr = opBuilder.getFloatAttr(floatType, value);
1357  if (isSpec) {
1358  createSpecConstant(unknownLoc, resultID, attr);
1359  } else {
1360  // For normal constants, we just record the attribute (and its type) for
1361  // later materialization at use sites.
1362  constantMap.try_emplace(resultID, attr, floatType);
1363  }
1364 
1365  return success();
1366  }
1367 
1368  return emitError(unknownLoc, "OpConstant can only generate values of "
1369  "scalar integer or floating-point type");
1370 }
1371 
1372 LogicalResult spirv::Deserializer::processConstantBool(
1373  bool isTrue, ArrayRef<uint32_t> operands, bool isSpec) {
1374  if (operands.size() != 2) {
1375  return emitError(unknownLoc, "Op")
1376  << (isSpec ? "Spec" : "") << "Constant"
1377  << (isTrue ? "True" : "False")
1378  << " must have type <id> and result <id>";
1379  }
1380 
1381  auto attr = opBuilder.getBoolAttr(isTrue);
1382  auto resultID = operands[1];
1383  if (isSpec) {
1384  createSpecConstant(unknownLoc, resultID, attr);
1385  } else {
1386  // For normal constants, we just record the attribute (and its type) for
1387  // later materialization at use sites.
1388  constantMap.try_emplace(resultID, attr, opBuilder.getI1Type());
1389  }
1390 
1391  return success();
1392 }
1393 
1395 spirv::Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
1396  if (operands.size() < 2) {
1397  return emitError(unknownLoc,
1398  "OpConstantComposite must have type <id> and result <id>");
1399  }
1400  if (operands.size() < 3) {
1401  return emitError(unknownLoc,
1402  "OpConstantComposite must have at least 1 parameter");
1403  }
1404 
1405  Type resultType = getType(operands[0]);
1406  if (!resultType) {
1407  return emitError(unknownLoc, "undefined result type from <id> ")
1408  << operands[0];
1409  }
1410 
1411  SmallVector<Attribute, 4> elements;
1412  elements.reserve(operands.size() - 2);
1413  for (unsigned i = 2, e = operands.size(); i < e; ++i) {
1414  auto elementInfo = getConstant(operands[i]);
1415  if (!elementInfo) {
1416  return emitError(unknownLoc, "OpConstantComposite component <id> ")
1417  << operands[i] << " must come from a normal constant";
1418  }
1419  elements.push_back(elementInfo->first);
1420  }
1421 
1422  auto resultID = operands[1];
1423  if (auto vectorType = dyn_cast<VectorType>(resultType)) {
1424  auto attr = DenseElementsAttr::get(vectorType, elements);
1425  // For normal constants, we just record the attribute (and its type) for
1426  // later materialization at use sites.
1427  constantMap.try_emplace(resultID, attr, resultType);
1428  } else if (auto arrayType = dyn_cast<spirv::ArrayType>(resultType)) {
1429  auto attr = opBuilder.getArrayAttr(elements);
1430  constantMap.try_emplace(resultID, attr, resultType);
1431  } else {
1432  return emitError(unknownLoc, "unsupported OpConstantComposite type: ")
1433  << resultType;
1434  }
1435 
1436  return success();
1437 }
1438 
1440 spirv::Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) {
1441  if (operands.size() < 2) {
1442  return emitError(unknownLoc,
1443  "OpConstantComposite must have type <id> and result <id>");
1444  }
1445  if (operands.size() < 3) {
1446  return emitError(unknownLoc,
1447  "OpConstantComposite must have at least 1 parameter");
1448  }
1449 
1450  Type resultType = getType(operands[0]);
1451  if (!resultType) {
1452  return emitError(unknownLoc, "undefined result type from <id> ")
1453  << operands[0];
1454  }
1455 
1456  auto resultID = operands[1];
1457  auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
1458 
1459  SmallVector<Attribute, 4> elements;
1460  elements.reserve(operands.size() - 2);
1461  for (unsigned i = 2, e = operands.size(); i < e; ++i) {
1462  auto elementInfo = getSpecConstant(operands[i]);
1463  elements.push_back(SymbolRefAttr::get(elementInfo));
1464  }
1465 
1466  auto op = opBuilder.create<spirv::SpecConstantCompositeOp>(
1467  unknownLoc, TypeAttr::get(resultType), symName,
1468  opBuilder.getArrayAttr(elements));
1469  specConstCompositeMap[resultID] = op;
1470 
1471  return success();
1472 }
1473 
1475 spirv::Deserializer::processSpecConstantOperation(ArrayRef<uint32_t> operands) {
1476  if (operands.size() < 3)
1477  return emitError(unknownLoc, "OpConstantOperation must have type <id>, "
1478  "result <id>, and operand opcode");
1479 
1480  uint32_t resultTypeID = operands[0];
1481 
1482  if (!getType(resultTypeID))
1483  return emitError(unknownLoc, "undefined result type from <id> ")
1484  << resultTypeID;
1485 
1486  uint32_t resultID = operands[1];
1487  spirv::Opcode enclosedOpcode = static_cast<spirv::Opcode>(operands[2]);
1488  auto emplaceResult = specConstOperationMap.try_emplace(
1489  resultID,
1490  SpecConstOperationMaterializationInfo{
1491  enclosedOpcode, resultTypeID,
1492  SmallVector<uint32_t>{operands.begin() + 3, operands.end()}});
1493 
1494  if (!emplaceResult.second)
1495  return emitError(unknownLoc, "value with <id>: ")
1496  << resultID << " is probably defined before.";
1497 
1498  return success();
1499 }
1500 
1501 Value spirv::Deserializer::materializeSpecConstantOperation(
1502  uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID,
1503  ArrayRef<uint32_t> enclosedOpOperands) {
1504 
1505  Type resultType = getType(resultTypeID);
1506 
1507  // Instructions wrapped by OpSpecConstantOp need an ID for their
1508  // Deserializer::processOp<op_name>(...) to emit the corresponding SPIR-V
1509  // dialect wrapped op. For that purpose, a new value map is created and "fake"
1510  // ID in that map is assigned to the result of the enclosed instruction. Note
1511  // that there is no need to update this fake ID since we only need to
1512  // reference the created Value for the enclosed op from the spv::YieldOp
1513  // created later in this method (both of which are the only values in their
1514  // region: the SpecConstantOperation's region). If we encounter another
1515  // SpecConstantOperation in the module, we simply re-use the fake ID since the
1516  // previous Value assigned to it isn't visible in the current scope anyway.
1517  DenseMap<uint32_t, Value> newValueMap;
1518  llvm::SaveAndRestore valueMapGuard(valueMap, newValueMap);
1519  constexpr uint32_t fakeID = static_cast<uint32_t>(-3);
1520 
1521  SmallVector<uint32_t, 4> enclosedOpResultTypeAndOperands;
1522  enclosedOpResultTypeAndOperands.push_back(resultTypeID);
1523  enclosedOpResultTypeAndOperands.push_back(fakeID);
1524  enclosedOpResultTypeAndOperands.append(enclosedOpOperands.begin(),
1525  enclosedOpOperands.end());
1526 
1527  // Process enclosed instruction before creating the enclosing
1528  // specConstantOperation (and its region). This way, references to constants,
1529  // global variables, and spec constants will be materialized outside the new
1530  // op's region. For more info, see Deserializer::getValue's implementation.
1531  if (failed(
1532  processInstruction(enclosedOpcode, enclosedOpResultTypeAndOperands)))
1533  return Value();
1534 
1535  // Since the enclosed op is emitted in the current block, split it in a
1536  // separate new block.
1537  Block *enclosedBlock = curBlock->splitBlock(&curBlock->back());
1538 
1539  auto loc = createFileLineColLoc(opBuilder);
1540  auto specConstOperationOp =
1541  opBuilder.create<spirv::SpecConstantOperationOp>(loc, resultType);
1542 
1543  Region &body = specConstOperationOp.getBody();
1544  // Move the new block into SpecConstantOperation's body.
1545  body.getBlocks().splice(body.end(), curBlock->getParent()->getBlocks(),
1546  Region::iterator(enclosedBlock));
1547  Block &block = body.back();
1548 
1549  // RAII guard to reset the insertion point to the module's region after
1550  // deserializing the body of the specConstantOperation.
1551  OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
1552  opBuilder.setInsertionPointToEnd(&block);
1553 
1554  opBuilder.create<spirv::YieldOp>(loc, block.front().getResult(0));
1555  return specConstOperationOp.getResult();
1556 }
1557 
1559 spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
1560  if (operands.size() != 2) {
1561  return emitError(unknownLoc,
1562  "OpConstantNull must have type <id> and result <id>");
1563  }
1564 
1565  Type resultType = getType(operands[0]);
1566  if (!resultType) {
1567  return emitError(unknownLoc, "undefined result type from <id> ")
1568  << operands[0];
1569  }
1570 
1571  auto resultID = operands[1];
1572  if (resultType.isIntOrFloat() || isa<VectorType>(resultType)) {
1573  auto attr = opBuilder.getZeroAttr(resultType);
1574  // For normal constants, we just record the attribute (and its type) for
1575  // later materialization at use sites.
1576  constantMap.try_emplace(resultID, attr, resultType);
1577  return success();
1578  }
1579 
1580  return emitError(unknownLoc, "unsupported OpConstantNull type: ")
1581  << resultType;
1582 }
1583 
1584 //===----------------------------------------------------------------------===//
1585 // Control flow
1586 //===----------------------------------------------------------------------===//
1587 
1588 Block *spirv::Deserializer::getOrCreateBlock(uint32_t id) {
1589  if (auto *block = getBlock(id)) {
1590  LLVM_DEBUG(logger.startLine() << "[block] got exiting block for id = " << id
1591  << " @ " << block << "\n");
1592  return block;
1593  }
1594 
1595  // We don't know where this block will be placed finally (in a
1596  // spirv.mlir.selection or spirv.mlir.loop or function). Create it into the
1597  // function for now and sort out the proper place later.
1598  auto *block = curFunction->addBlock();
1599  LLVM_DEBUG(logger.startLine() << "[block] created block for id = " << id
1600  << " @ " << block << "\n");
1601  return blockMap[id] = block;
1602 }
1603 
1604 LogicalResult spirv::Deserializer::processBranch(ArrayRef<uint32_t> operands) {
1605  if (!curBlock) {
1606  return emitError(unknownLoc, "OpBranch must appear inside a block");
1607  }
1608 
1609  if (operands.size() != 1) {
1610  return emitError(unknownLoc, "OpBranch must take exactly one target label");
1611  }
1612 
1613  auto *target = getOrCreateBlock(operands[0]);
1614  auto loc = createFileLineColLoc(opBuilder);
1615  // The preceding instruction for the OpBranch instruction could be an
1616  // OpLoopMerge or an OpSelectionMerge instruction, in this case they will have
1617  // the same OpLine information.
1618  opBuilder.create<spirv::BranchOp>(loc, target);
1619 
1620  clearDebugLine();
1621  return success();
1622 }
1623 
1625 spirv::Deserializer::processBranchConditional(ArrayRef<uint32_t> operands) {
1626  if (!curBlock) {
1627  return emitError(unknownLoc,
1628  "OpBranchConditional must appear inside a block");
1629  }
1630 
1631  if (operands.size() != 3 && operands.size() != 5) {
1632  return emitError(unknownLoc,
1633  "OpBranchConditional must have condition, true label, "
1634  "false label, and optionally two branch weights");
1635  }
1636 
1637  auto condition = getValue(operands[0]);
1638  auto *trueBlock = getOrCreateBlock(operands[1]);
1639  auto *falseBlock = getOrCreateBlock(operands[2]);
1640 
1641  std::optional<std::pair<uint32_t, uint32_t>> weights;
1642  if (operands.size() == 5) {
1643  weights = std::make_pair(operands[3], operands[4]);
1644  }
1645  // The preceding instruction for the OpBranchConditional instruction could be
1646  // an OpSelectionMerge instruction, in this case they will have the same
1647  // OpLine information.
1648  auto loc = createFileLineColLoc(opBuilder);
1649  opBuilder.create<spirv::BranchConditionalOp>(
1650  loc, condition, trueBlock,
1651  /*trueArguments=*/ArrayRef<Value>(), falseBlock,
1652  /*falseArguments=*/ArrayRef<Value>(), weights);
1653 
1654  clearDebugLine();
1655  return success();
1656 }
1657 
1658 LogicalResult spirv::Deserializer::processLabel(ArrayRef<uint32_t> operands) {
1659  if (!curFunction) {
1660  return emitError(unknownLoc, "OpLabel must appear inside a function");
1661  }
1662 
1663  if (operands.size() != 1) {
1664  return emitError(unknownLoc, "OpLabel should only have result <id>");
1665  }
1666 
1667  auto labelID = operands[0];
1668  // We may have forward declared this block.
1669  auto *block = getOrCreateBlock(labelID);
1670  LLVM_DEBUG(logger.startLine()
1671  << "[block] populating block " << block << "\n");
1672  // If we have seen this block, make sure it was just a forward declaration.
1673  assert(block->empty() && "re-deserialize the same block!");
1674 
1675  opBuilder.setInsertionPointToStart(block);
1676  blockMap[labelID] = curBlock = block;
1677 
1678  return success();
1679 }
1680 
1682 spirv::Deserializer::processSelectionMerge(ArrayRef<uint32_t> operands) {
1683  if (!curBlock) {
1684  return emitError(unknownLoc, "OpSelectionMerge must appear in a block");
1685  }
1686 
1687  if (operands.size() < 2) {
1688  return emitError(
1689  unknownLoc,
1690  "OpSelectionMerge must specify merge target and selection control");
1691  }
1692 
1693  auto *mergeBlock = getOrCreateBlock(operands[0]);
1694  auto loc = createFileLineColLoc(opBuilder);
1695  auto selectionControl = operands[1];
1696 
1697  if (!blockMergeInfo.try_emplace(curBlock, loc, selectionControl, mergeBlock)
1698  .second) {
1699  return emitError(
1700  unknownLoc,
1701  "a block cannot have more than one OpSelectionMerge instruction");
1702  }
1703 
1704  return success();
1705 }
1706 
1708 spirv::Deserializer::processLoopMerge(ArrayRef<uint32_t> operands) {
1709  if (!curBlock) {
1710  return emitError(unknownLoc, "OpLoopMerge must appear in a block");
1711  }
1712 
1713  if (operands.size() < 3) {
1714  return emitError(unknownLoc, "OpLoopMerge must specify merge target, "
1715  "continue target and loop control");
1716  }
1717 
1718  auto *mergeBlock = getOrCreateBlock(operands[0]);
1719  auto *continueBlock = getOrCreateBlock(operands[1]);
1720  auto loc = createFileLineColLoc(opBuilder);
1721  uint32_t loopControl = operands[2];
1722 
1723  if (!blockMergeInfo
1724  .try_emplace(curBlock, loc, loopControl, mergeBlock, continueBlock)
1725  .second) {
1726  return emitError(
1727  unknownLoc,
1728  "a block cannot have more than one OpLoopMerge instruction");
1729  }
1730 
1731  return success();
1732 }
1733 
1734 LogicalResult spirv::Deserializer::processPhi(ArrayRef<uint32_t> operands) {
1735  if (!curBlock) {
1736  return emitError(unknownLoc, "OpPhi must appear in a block");
1737  }
1738 
1739  if (operands.size() < 4) {
1740  return emitError(unknownLoc, "OpPhi must specify result type, result <id>, "
1741  "and variable-parent pairs");
1742  }
1743 
1744  // Create a block argument for this OpPhi instruction.
1745  Type blockArgType = getType(operands[0]);
1746  BlockArgument blockArg = curBlock->addArgument(blockArgType, unknownLoc);
1747  valueMap[operands[1]] = blockArg;
1748  LLVM_DEBUG(logger.startLine()
1749  << "[phi] created block argument " << blockArg
1750  << " id = " << operands[1] << " of type " << blockArgType << "\n");
1751 
1752  // For each (value, predecessor) pair, insert the value to the predecessor's
1753  // blockPhiInfo entry so later we can fix the block argument there.
1754  for (unsigned i = 2, e = operands.size(); i < e; i += 2) {
1755  uint32_t value = operands[i];
1756  Block *predecessor = getOrCreateBlock(operands[i + 1]);
1757  std::pair<Block *, Block *> predecessorTargetPair{predecessor, curBlock};
1758  blockPhiInfo[predecessorTargetPair].push_back(value);
1759  LLVM_DEBUG(logger.startLine() << "[phi] predecessor @ " << predecessor
1760  << " with arg id = " << value << "\n");
1761  }
1762 
1763  return success();
1764 }
1765 
1766 namespace {
1767 /// A class for putting all blocks in a structured selection/loop in a
1768 /// spirv.mlir.selection/spirv.mlir.loop op.
1769 class ControlFlowStructurizer {
1770 public:
1771 #ifndef NDEBUG
1772  ControlFlowStructurizer(Location loc, uint32_t control,
1773  spirv::BlockMergeInfoMap &mergeInfo, Block *header,
1774  Block *merge, Block *cont,
1775  llvm::ScopedPrinter &logger)
1776  : location(loc), control(control), blockMergeInfo(mergeInfo),
1777  headerBlock(header), mergeBlock(merge), continueBlock(cont),
1778  logger(logger) {}
1779 #else
1780  ControlFlowStructurizer(Location loc, uint32_t control,
1781  spirv::BlockMergeInfoMap &mergeInfo, Block *header,
1782  Block *merge, Block *cont)
1783  : location(loc), control(control), blockMergeInfo(mergeInfo),
1784  headerBlock(header), mergeBlock(merge), continueBlock(cont) {}
1785 #endif
1786 
1787  /// Structurizes the loop at the given `headerBlock`.
1788  ///
1789  /// This method will create an spirv.mlir.loop op in the `mergeBlock` and move
1790  /// all blocks in the structured loop into the spirv.mlir.loop's region. All
1791  /// branches to the `headerBlock` will be redirected to the `mergeBlock`. This
1792  /// method will also update `mergeInfo` by remapping all blocks inside to the
1793  /// newly cloned ones inside structured control flow op's regions.
1794  LogicalResult structurize();
1795 
1796 private:
1797  /// Creates a new spirv.mlir.selection op at the beginning of the
1798  /// `mergeBlock`.
1799  spirv::SelectionOp createSelectionOp(uint32_t selectionControl);
1800 
1801  /// Creates a new spirv.mlir.loop op at the beginning of the `mergeBlock`.
1802  spirv::LoopOp createLoopOp(uint32_t loopControl);
1803 
1804  /// Collects all blocks reachable from `headerBlock` except `mergeBlock`.
1805  void collectBlocksInConstruct();
1806 
1807  Location location;
1808  uint32_t control;
1809 
1810  spirv::BlockMergeInfoMap &blockMergeInfo;
1811 
1812  Block *headerBlock;
1813  Block *mergeBlock;
1814  Block *continueBlock; // nullptr for spirv.mlir.selection
1815 
1816  SetVector<Block *> constructBlocks;
1817 
1818 #ifndef NDEBUG
1819  /// A logger used to emit information during the deserialzation process.
1820  llvm::ScopedPrinter &logger;
1821 #endif
1822 };
1823 } // namespace
1824 
1825 spirv::SelectionOp
1826 ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) {
1827  // Create a builder and set the insertion point to the beginning of the
1828  // merge block so that the newly created SelectionOp will be inserted there.
1829  OpBuilder builder(&mergeBlock->front());
1830 
1831  auto control = static_cast<spirv::SelectionControl>(selectionControl);
1832  auto selectionOp = builder.create<spirv::SelectionOp>(location, control);
1833  selectionOp.addMergeBlock(builder);
1834 
1835  return selectionOp;
1836 }
1837 
1838 spirv::LoopOp ControlFlowStructurizer::createLoopOp(uint32_t loopControl) {
1839  // Create a builder and set the insertion point to the beginning of the
1840  // merge block so that the newly created LoopOp will be inserted there.
1841  OpBuilder builder(&mergeBlock->front());
1842 
1843  auto control = static_cast<spirv::LoopControl>(loopControl);
1844  auto loopOp = builder.create<spirv::LoopOp>(location, control);
1845  loopOp.addEntryAndMergeBlock(builder);
1846 
1847  return loopOp;
1848 }
1849 
1850 void ControlFlowStructurizer::collectBlocksInConstruct() {
1851  assert(constructBlocks.empty() && "expected empty constructBlocks");
1852 
1853  // Put the header block in the work list first.
1854  constructBlocks.insert(headerBlock);
1855 
1856  // For each item in the work list, add its successors excluding the merge
1857  // block.
1858  for (unsigned i = 0; i < constructBlocks.size(); ++i) {
1859  for (auto *successor : constructBlocks[i]->getSuccessors())
1860  if (successor != mergeBlock)
1861  constructBlocks.insert(successor);
1862  }
1863 }
1864 
1865 LogicalResult ControlFlowStructurizer::structurize() {
1866  Operation *op = nullptr;
1867  bool isLoop = continueBlock != nullptr;
1868  if (isLoop) {
1869  if (auto loopOp = createLoopOp(control))
1870  op = loopOp.getOperation();
1871  } else {
1872  if (auto selectionOp = createSelectionOp(control))
1873  op = selectionOp.getOperation();
1874  }
1875  if (!op)
1876  return failure();
1877  Region &body = op->getRegion(0);
1878 
1879  IRMapping mapper;
1880  // All references to the old merge block should be directed to the
1881  // selection/loop merge block in the SelectionOp/LoopOp's region.
1882  mapper.map(mergeBlock, &body.back());
1883 
1884  collectBlocksInConstruct();
1885 
1886  // We've identified all blocks belonging to the selection/loop's region. Now
1887  // need to "move" them into the selection/loop. Instead of really moving the
1888  // blocks, in the following we copy them and remap all values and branches.
1889  // This is because:
1890  // * Inserting a block into a region requires the block not in any region
1891  // before. But selections/loops can nest so we can create selection/loop ops
1892  // in a nested manner, which means some blocks may already be in a
1893  // selection/loop region when to be moved again.
1894  // * It's much trickier to fix up the branches into and out of the loop's
1895  // region: we need to treat not-moved blocks and moved blocks differently:
1896  // Not-moved blocks jumping to the loop header block need to jump to the
1897  // merge point containing the new loop op but not the loop continue block's
1898  // back edge. Moved blocks jumping out of the loop need to jump to the
1899  // merge block inside the loop region but not other not-moved blocks.
1900  // We cannot use replaceAllUsesWith clearly and it's harder to follow the
1901  // logic.
1902 
1903  // Create a corresponding block in the SelectionOp/LoopOp's region for each
1904  // block in this loop construct.
1905  OpBuilder builder(body);
1906  for (auto *block : constructBlocks) {
1907  // Create a block and insert it before the selection/loop merge block in the
1908  // SelectionOp/LoopOp's region.
1909  auto *newBlock = builder.createBlock(&body.back());
1910  mapper.map(block, newBlock);
1911  LLVM_DEBUG(logger.startLine() << "[cf] cloned block " << newBlock
1912  << " from block " << block << "\n");
1913  if (!isFnEntryBlock(block)) {
1914  for (BlockArgument blockArg : block->getArguments()) {
1915  auto newArg =
1916  newBlock->addArgument(blockArg.getType(), blockArg.getLoc());
1917  mapper.map(blockArg, newArg);
1918  LLVM_DEBUG(logger.startLine() << "[cf] remapped block argument "
1919  << blockArg << " to " << newArg << "\n");
1920  }
1921  } else {
1922  LLVM_DEBUG(logger.startLine()
1923  << "[cf] block " << block << " is a function entry block\n");
1924  }
1925 
1926  for (auto &op : *block)
1927  newBlock->push_back(op.clone(mapper));
1928  }
1929 
1930  // Go through all ops and remap the operands.
1931  auto remapOperands = [&](Operation *op) {
1932  for (auto &operand : op->getOpOperands())
1933  if (Value mappedOp = mapper.lookupOrNull(operand.get()))
1934  operand.set(mappedOp);
1935  for (auto &succOp : op->getBlockOperands())
1936  if (Block *mappedOp = mapper.lookupOrNull(succOp.get()))
1937  succOp.set(mappedOp);
1938  };
1939  for (auto &block : body)
1940  block.walk(remapOperands);
1941 
1942  // We have created the SelectionOp/LoopOp and "moved" all blocks belonging to
1943  // the selection/loop construct into its region. Next we need to fix the
1944  // connections between this new SelectionOp/LoopOp with existing blocks.
1945 
1946  // All existing incoming branches should go to the merge block, where the
1947  // SelectionOp/LoopOp resides right now.
1948  headerBlock->replaceAllUsesWith(mergeBlock);
1949 
1950  LLVM_DEBUG({
1951  logger.startLine() << "[cf] after cloning and fixing references:\n";
1952  headerBlock->getParentOp()->print(logger.getOStream());
1953  logger.startLine() << "\n";
1954  });
1955 
1956  if (isLoop) {
1957  if (!mergeBlock->args_empty()) {
1958  return mergeBlock->getParentOp()->emitError(
1959  "OpPhi in loop merge block unsupported");
1960  }
1961 
1962  // The loop header block may have block arguments. Since now we place the
1963  // loop op inside the old merge block, we need to make sure the old merge
1964  // block has the same block argument list.
1965  for (BlockArgument blockArg : headerBlock->getArguments())
1966  mergeBlock->addArgument(blockArg.getType(), blockArg.getLoc());
1967 
1968  // If the loop header block has block arguments, make sure the spirv.Branch
1969  // op matches.
1970  SmallVector<Value, 4> blockArgs;
1971  if (!headerBlock->args_empty())
1972  blockArgs = {mergeBlock->args_begin(), mergeBlock->args_end()};
1973 
1974  // The loop entry block should have a unconditional branch jumping to the
1975  // loop header block.
1976  builder.setInsertionPointToEnd(&body.front());
1977  builder.create<spirv::BranchOp>(location, mapper.lookupOrNull(headerBlock),
1978  ArrayRef<Value>(blockArgs));
1979  }
1980 
1981  // All the blocks cloned into the SelectionOp/LoopOp's region can now be
1982  // cleaned up.
1983  LLVM_DEBUG(logger.startLine() << "[cf] cleaning up blocks after clone\n");
1984  // First we need to drop all operands' references inside all blocks. This is
1985  // needed because we can have blocks referencing SSA values from one another.
1986  for (auto *block : constructBlocks)
1987  block->dropAllReferences();
1988 
1989  // Check that whether some op in the to-be-erased blocks still has uses. Those
1990  // uses come from blocks that won't be sinked into the SelectionOp/LoopOp's
1991  // region. We cannot handle such cases given that once a value is sinked into
1992  // the SelectionOp/LoopOp's region, there is no escape for it:
1993  // SelectionOp/LooOp does not support yield values right now.
1994  for (auto *block : constructBlocks) {
1995  for (Operation &op : *block)
1996  if (!op.use_empty())
1997  return op.emitOpError(
1998  "failed control flow structurization: it has uses outside of the "
1999  "enclosing selection/loop construct");
2000  }
2001 
2002  // Then erase all old blocks.
2003  for (auto *block : constructBlocks) {
2004  // We've cloned all blocks belonging to this construct into the structured
2005  // control flow op's region. Among these blocks, some may compose another
2006  // selection/loop. If so, they will be recorded within blockMergeInfo.
2007  // We need to update the pointers there to the newly remapped ones so we can
2008  // continue structurizing them later.
2009  // TODO: The asserts in the following assumes input SPIR-V blob forms
2010  // correctly nested selection/loop constructs. We should relax this and
2011  // support error cases better.
2012  auto it = blockMergeInfo.find(block);
2013  if (it != blockMergeInfo.end()) {
2014  // Use the original location for nested selection/loop ops.
2015  Location loc = it->second.loc;
2016 
2017  Block *newHeader = mapper.lookupOrNull(block);
2018  if (!newHeader)
2019  return emitError(loc, "failed control flow structurization: nested "
2020  "loop header block should be remapped!");
2021 
2022  Block *newContinue = it->second.continueBlock;
2023  if (newContinue) {
2024  newContinue = mapper.lookupOrNull(newContinue);
2025  if (!newContinue)
2026  return emitError(loc, "failed control flow structurization: nested "
2027  "loop continue block should be remapped!");
2028  }
2029 
2030  Block *newMerge = it->second.mergeBlock;
2031  if (Block *mappedTo = mapper.lookupOrNull(newMerge))
2032  newMerge = mappedTo;
2033 
2034  // The iterator should be erased before adding a new entry into
2035  // blockMergeInfo to avoid iterator invalidation.
2036  blockMergeInfo.erase(it);
2037  blockMergeInfo.try_emplace(newHeader, loc, it->second.control, newMerge,
2038  newContinue);
2039  }
2040 
2041  // The structured selection/loop's entry block does not have arguments.
2042  // If the function's header block is also part of the structured control
2043  // flow, we cannot just simply erase it because it may contain arguments
2044  // matching the function signature and used by the cloned blocks.
2045  if (isFnEntryBlock(block)) {
2046  LLVM_DEBUG(logger.startLine() << "[cf] changing entry block " << block
2047  << " to only contain a spirv.Branch op\n");
2048  // Still keep the function entry block for the potential block arguments,
2049  // but replace all ops inside with a branch to the merge block.
2050  block->clear();
2051  builder.setInsertionPointToEnd(block);
2052  builder.create<spirv::BranchOp>(location, mergeBlock);
2053  } else {
2054  LLVM_DEBUG(logger.startLine() << "[cf] erasing block " << block << "\n");
2055  block->erase();
2056  }
2057  }
2058 
2059  LLVM_DEBUG(logger.startLine()
2060  << "[cf] after structurizing construct with header block "
2061  << headerBlock << ":\n"
2062  << *op << "\n");
2063 
2064  return success();
2065 }
2066 
2067 LogicalResult spirv::Deserializer::wireUpBlockArgument() {
2068  LLVM_DEBUG({
2069  logger.startLine()
2070  << "//----- [phi] start wiring up block arguments -----//\n";
2071  logger.indent();
2072  });
2073 
2074  OpBuilder::InsertionGuard guard(opBuilder);
2075 
2076  for (const auto &info : blockPhiInfo) {
2077  Block *block = info.first.first;
2078  Block *target = info.first.second;
2079  const BlockPhiInfo &phiInfo = info.second;
2080  LLVM_DEBUG({
2081  logger.startLine() << "[phi] block " << block << "\n";
2082  logger.startLine() << "[phi] before creating block argument:\n";
2083  block->getParentOp()->print(logger.getOStream());
2084  logger.startLine() << "\n";
2085  });
2086 
2087  // Set insertion point to before this block's terminator early because we
2088  // may materialize ops via getValue() call.
2089  auto *op = block->getTerminator();
2090  opBuilder.setInsertionPoint(op);
2091 
2092  SmallVector<Value, 4> blockArgs;
2093  blockArgs.reserve(phiInfo.size());
2094  for (uint32_t valueId : phiInfo) {
2095  if (Value value = getValue(valueId)) {
2096  blockArgs.push_back(value);
2097  LLVM_DEBUG(logger.startLine() << "[phi] block argument " << value
2098  << " id = " << valueId << "\n");
2099  } else {
2100  return emitError(unknownLoc, "OpPhi references undefined value!");
2101  }
2102  }
2103 
2104  if (auto branchOp = dyn_cast<spirv::BranchOp>(op)) {
2105  // Replace the previous branch op with a new one with block arguments.
2106  opBuilder.create<spirv::BranchOp>(branchOp.getLoc(), branchOp.getTarget(),
2107  blockArgs);
2108  branchOp.erase();
2109  } else if (auto branchCondOp = dyn_cast<spirv::BranchConditionalOp>(op)) {
2110  assert((branchCondOp.getTrueBlock() == target ||
2111  branchCondOp.getFalseBlock() == target) &&
2112  "expected target to be either the true or false target");
2113  if (target == branchCondOp.getTrueTarget())
2114  opBuilder.create<spirv::BranchConditionalOp>(
2115  branchCondOp.getLoc(), branchCondOp.getCondition(), blockArgs,
2116  branchCondOp.getFalseBlockArguments(),
2117  branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueTarget(),
2118  branchCondOp.getFalseTarget());
2119  else
2120  opBuilder.create<spirv::BranchConditionalOp>(
2121  branchCondOp.getLoc(), branchCondOp.getCondition(),
2122  branchCondOp.getTrueBlockArguments(), blockArgs,
2123  branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueBlock(),
2124  branchCondOp.getFalseBlock());
2125 
2126  branchCondOp.erase();
2127  } else {
2128  return emitError(unknownLoc, "unimplemented terminator for Phi creation");
2129  }
2130 
2131  LLVM_DEBUG({
2132  logger.startLine() << "[phi] after creating block argument:\n";
2133  block->getParentOp()->print(logger.getOStream());
2134  logger.startLine() << "\n";
2135  });
2136  }
2137  blockPhiInfo.clear();
2138 
2139  LLVM_DEBUG({
2140  logger.unindent();
2141  logger.startLine()
2142  << "//--- [phi] completed wiring up block arguments ---//\n";
2143  });
2144  return success();
2145 }
2146 
2147 LogicalResult spirv::Deserializer::structurizeControlFlow() {
2148  LLVM_DEBUG({
2149  logger.startLine()
2150  << "//----- [cf] start structurizing control flow -----//\n";
2151  logger.indent();
2152  });
2153 
2154  while (!blockMergeInfo.empty()) {
2155  Block *headerBlock = blockMergeInfo.begin()->first;
2156  BlockMergeInfo mergeInfo = blockMergeInfo.begin()->second;
2157 
2158  LLVM_DEBUG({
2159  logger.startLine() << "[cf] header block " << headerBlock << ":\n";
2160  headerBlock->print(logger.getOStream());
2161  logger.startLine() << "\n";
2162  });
2163 
2164  auto *mergeBlock = mergeInfo.mergeBlock;
2165  assert(mergeBlock && "merge block cannot be nullptr");
2166  if (!mergeBlock->args_empty())
2167  return emitError(unknownLoc, "OpPhi in loop merge block unimplemented");
2168  LLVM_DEBUG({
2169  logger.startLine() << "[cf] merge block " << mergeBlock << ":\n";
2170  mergeBlock->print(logger.getOStream());
2171  logger.startLine() << "\n";
2172  });
2173 
2174  auto *continueBlock = mergeInfo.continueBlock;
2175  LLVM_DEBUG(if (continueBlock) {
2176  logger.startLine() << "[cf] continue block " << continueBlock << ":\n";
2177  continueBlock->print(logger.getOStream());
2178  logger.startLine() << "\n";
2179  });
2180  // Erase this case before calling into structurizer, who will update
2181  // blockMergeInfo.
2182  blockMergeInfo.erase(blockMergeInfo.begin());
2183  ControlFlowStructurizer structurizer(mergeInfo.loc, mergeInfo.control,
2184  blockMergeInfo, headerBlock,
2185  mergeBlock, continueBlock
2186 #ifndef NDEBUG
2187  ,
2188  logger
2189 #endif
2190  );
2191  if (failed(structurizer.structurize()))
2192  return failure();
2193  }
2194 
2195  LLVM_DEBUG({
2196  logger.unindent();
2197  logger.startLine()
2198  << "//--- [cf] completed structurizing control flow ---//\n";
2199  });
2200  return success();
2201 }
2202 
2203 //===----------------------------------------------------------------------===//
2204 // Debug
2205 //===----------------------------------------------------------------------===//
2206 
2207 Location spirv::Deserializer::createFileLineColLoc(OpBuilder opBuilder) {
2208  if (!debugLine)
2209  return unknownLoc;
2210 
2211  auto fileName = debugInfoMap.lookup(debugLine->fileID).str();
2212  if (fileName.empty())
2213  fileName = "<unknown>";
2214  return FileLineColLoc::get(opBuilder.getStringAttr(fileName), debugLine->line,
2215  debugLine->column);
2216 }
2217 
2219 spirv::Deserializer::processDebugLine(ArrayRef<uint32_t> operands) {
2220  // According to SPIR-V spec:
2221  // "This location information applies to the instructions physically
2222  // following this instruction, up to the first occurrence of any of the
2223  // following: the next end of block, the next OpLine instruction, or the next
2224  // OpNoLine instruction."
2225  if (operands.size() != 3)
2226  return emitError(unknownLoc, "OpLine must have 3 operands");
2227  debugLine = DebugLine{operands[0], operands[1], operands[2]};
2228  return success();
2229 }
2230 
2231 void spirv::Deserializer::clearDebugLine() { debugLine = std::nullopt; }
2232 
2234 spirv::Deserializer::processDebugString(ArrayRef<uint32_t> operands) {
2235  if (operands.size() < 2)
2236  return emitError(unknownLoc, "OpString needs at least 2 operands");
2237 
2238  if (!debugInfoMap.lookup(operands[0]).empty())
2239  return emitError(unknownLoc,
2240  "duplicate debug string found for result <id> ")
2241  << operands[0];
2242 
2243  unsigned wordIndex = 1;
2244  StringRef debugString = decodeStringLiteral(operands, wordIndex);
2245  if (wordIndex != operands.size())
2246  return emitError(unknownLoc,
2247  "unexpected trailing words in OpString instruction");
2248 
2249  debugInfoMap[operands[0]] = debugString;
2250  return success();
2251 }
static bool isLoop(Operation *op)
Returns true if the given operation represents a loop by testing whether it implements the LoopLikeOp...
static bool isFnEntryBlock(Block *block)
Returns true if the given block is a function entry block.
#define MIN_VERSION_CASE(v)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
Location getLoc() const
Return the location for this argument.
Definition: Value.h:334
Block represents an ordered list of Operations.
Definition: Block.h:31
bool empty()
Definition: Block.h:146
Operation & back()
Definition: Block.h:150
void erase()
Unlink this Block from its parent region and delete it.
Definition: Block.cpp:65
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Definition: Block.cpp:307
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
void print(raw_ostream &os)
BlockArgListType getArguments()
Definition: Block.h:85
Operation & front()
Definition: Block.h:151
iterator begin()
Definition: Block.h:141
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Definition: Block.cpp:35
void push_back(Operation *op)
Definition: Block.h:147
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:269
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
A symbol reference with a reference path containing a single element.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
auto lookupOrNull(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:58
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:207
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
This class helps build Operations.
Definition: Builders.h:209
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool use_empty()
Returns true if this operation has no uses.
Definition: Operation.h:848
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Definition: Operation.cpp:717
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Definition: Operation.cpp:67
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:577
MutableArrayRef< BlockOperand > getBlockOperands()
Definition: Operation.h:691
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:378
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Block & back()
Definition: Region.h:64
BlockListType & getBlocks()
Definition: Region.h:45
BlockListType::iterator iterator
Definition: Region.h:52
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:119
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
static ArrayType get(Type elementType, unsigned elementCount)
Definition: SPIRVTypes.cpp:52
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
Definition: SPIRVTypes.cpp:228
LogicalResult deserialize()
Deserializes the remembered SPIR-V binary module.
Deserializer(ArrayRef< uint32_t > binary, MLIRContext *context)
Creates a deserializer for the given SPIR-V binary module.
OwningOpRef< spirv::ModuleOp > collect()
Collects the final SPIR-V ModuleOp.
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
Definition: SPIRVTypes.h:169
static JointMatrixINTELType get(Type elementType, Scope scope, unsigned rows, unsigned columns, MatrixLayout matrixLayout)
Definition: SPIRVTypes.cpp:298
static MatrixType get(Type columnType, uint32_t columnCount)
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:481
static RuntimeArrayType get(Type elementType)
Definition: SPIRVTypes.cpp:538
static SampledImageType get(Type imageType)
Definition: SPIRVTypes.cpp:807
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Construct a literal StructType with at least one member.
static VerCapExtAttr get(Version version, ArrayRef< Capability > capabilities, ArrayRef< Extension > extensions, MLIRContext *context)
Gets a VerCapExtAttr instance.
Include the generated interface declarations.
Definition: CallGraph.h:229
constexpr uint32_t kMagicNumber
SPIR-V magic number.
StringRef decodeStringLiteral(ArrayRef< uint32_t > words, unsigned &wordIndex)
Decodes a string literal in words starting at wordIndex.
constexpr unsigned kHeaderWordCount
SPIR-V binary header word count.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
static std::string debugString(T &&op)
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
This represents an operation in an abstracted form, suitable for use with the builder APIs.