MLIR  18.0.0git
Deserializer.cpp
Go to the documentation of this file.
1 //===- Deserializer.cpp - MLIR SPIR-V Deserializer ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the SPIR-V binary to MLIR SPIR-V module deserializer.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Deserializer.h"
14 
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/IRMapping.h"
21 #include "mlir/IR/Location.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/Sequence.h"
26 #include "llvm/ADT/SmallVector.h"
27 #include "llvm/ADT/StringExtras.h"
28 #include "llvm/ADT/bit.h"
29 #include "llvm/Support/Debug.h"
30 #include "llvm/Support/SaveAndRestore.h"
31 #include "llvm/Support/raw_ostream.h"
32 #include <optional>
33 
34 using namespace mlir;
35 
36 #define DEBUG_TYPE "spirv-deserialization"
37 
38 //===----------------------------------------------------------------------===//
39 // Utility Functions
40 //===----------------------------------------------------------------------===//
41 
42 /// Returns true if the given `block` is a function entry block.
43 static inline bool isFnEntryBlock(Block *block) {
44  return block->isEntryBlock() &&
45  isa_and_nonnull<spirv::FuncOp>(block->getParentOp());
46 }
47 
48 //===----------------------------------------------------------------------===//
49 // Deserializer Method Definitions
50 //===----------------------------------------------------------------------===//
51 
53  MLIRContext *context)
54  : binary(binary), context(context), unknownLoc(UnknownLoc::get(context)),
55  module(createModuleOp()), opBuilder(module->getRegion())
56 #ifndef NDEBUG
57  ,
58  logger(llvm::dbgs())
59 #endif
60 {
61 }
62 
64  LLVM_DEBUG({
65  logger.resetIndent();
66  logger.startLine()
67  << "//+++---------- start deserialization ----------+++//\n";
68  });
69 
70  if (failed(processHeader()))
71  return failure();
72 
73  spirv::Opcode opcode = spirv::Opcode::OpNop;
74  ArrayRef<uint32_t> operands;
75  auto binarySize = binary.size();
76  while (curOffset < binarySize) {
77  // Slice the next instruction out and populate `opcode` and `operands`.
78  // Internally this also updates `curOffset`.
79  if (failed(sliceInstruction(opcode, operands)))
80  return failure();
81 
82  if (failed(processInstruction(opcode, operands)))
83  return failure();
84  }
85 
86  assert(curOffset == binarySize &&
87  "deserializer should never index beyond the binary end");
88 
89  for (auto &deferred : deferredInstructions) {
90  if (failed(processInstruction(deferred.first, deferred.second, false))) {
91  return failure();
92  }
93  }
94 
95  attachVCETriple();
96 
97  LLVM_DEBUG(logger.startLine()
98  << "//+++-------- completed deserialization --------+++//\n");
99  return success();
100 }
101 
103  return std::move(module);
104 }
105 
106 //===----------------------------------------------------------------------===//
107 // Module structure
108 //===----------------------------------------------------------------------===//
109 
110 OwningOpRef<spirv::ModuleOp> spirv::Deserializer::createModuleOp() {
111  OpBuilder builder(context);
112  OperationState state(unknownLoc, spirv::ModuleOp::getOperationName());
113  spirv::ModuleOp::build(builder, state);
114  return cast<spirv::ModuleOp>(Operation::create(state));
115 }
116 
117 LogicalResult spirv::Deserializer::processHeader() {
118  if (binary.size() < spirv::kHeaderWordCount)
119  return emitError(unknownLoc,
120  "SPIR-V binary module must have a 5-word header");
121 
122  if (binary[0] != spirv::kMagicNumber)
123  return emitError(unknownLoc, "incorrect magic number");
124 
125  // Version number bytes: 0 | major number | minor number | 0
126  uint32_t majorVersion = (binary[1] << 8) >> 24;
127  uint32_t minorVersion = (binary[1] << 16) >> 24;
128  if (majorVersion == 1) {
129  switch (minorVersion) {
130 #define MIN_VERSION_CASE(v) \
131  case v: \
132  version = spirv::Version::V_1_##v; \
133  break
134 
135  MIN_VERSION_CASE(0);
136  MIN_VERSION_CASE(1);
137  MIN_VERSION_CASE(2);
138  MIN_VERSION_CASE(3);
139  MIN_VERSION_CASE(4);
140  MIN_VERSION_CASE(5);
141 #undef MIN_VERSION_CASE
142  default:
143  return emitError(unknownLoc, "unsupported SPIR-V minor version: ")
144  << minorVersion;
145  }
146  } else {
147  return emitError(unknownLoc, "unsupported SPIR-V major version: ")
148  << majorVersion;
149  }
150 
151  // TODO: generator number, bound, schema
152  curOffset = spirv::kHeaderWordCount;
153  return success();
154 }
155 
157 spirv::Deserializer::processCapability(ArrayRef<uint32_t> operands) {
158  if (operands.size() != 1)
159  return emitError(unknownLoc, "OpMemoryModel must have one parameter");
160 
161  auto cap = spirv::symbolizeCapability(operands[0]);
162  if (!cap)
163  return emitError(unknownLoc, "unknown capability: ") << operands[0];
164 
165  capabilities.insert(*cap);
166  return success();
167 }
168 
169 LogicalResult spirv::Deserializer::processExtension(ArrayRef<uint32_t> words) {
170  if (words.empty()) {
171  return emitError(
172  unknownLoc,
173  "OpExtension must have a literal string for the extension name");
174  }
175 
176  unsigned wordIndex = 0;
177  StringRef extName = decodeStringLiteral(words, wordIndex);
178  if (wordIndex != words.size())
179  return emitError(unknownLoc,
180  "unexpected trailing words in OpExtension instruction");
181  auto ext = spirv::symbolizeExtension(extName);
182  if (!ext)
183  return emitError(unknownLoc, "unknown extension: ") << extName;
184 
185  extensions.insert(*ext);
186  return success();
187 }
188 
190 spirv::Deserializer::processExtInstImport(ArrayRef<uint32_t> words) {
191  if (words.size() < 2) {
192  return emitError(unknownLoc,
193  "OpExtInstImport must have a result <id> and a literal "
194  "string for the extended instruction set name");
195  }
196 
197  unsigned wordIndex = 1;
198  extendedInstSets[words[0]] = decodeStringLiteral(words, wordIndex);
199  if (wordIndex != words.size()) {
200  return emitError(unknownLoc,
201  "unexpected trailing words in OpExtInstImport");
202  }
203  return success();
204 }
205 
206 void spirv::Deserializer::attachVCETriple() {
207  (*module)->setAttr(
208  spirv::ModuleOp::getVCETripleAttrName(),
209  spirv::VerCapExtAttr::get(version, capabilities.getArrayRef(),
210  extensions.getArrayRef(), context));
211 }
212 
214 spirv::Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
215  if (operands.size() != 2)
216  return emitError(unknownLoc, "OpMemoryModel must have two operands");
217 
218  (*module)->setAttr(
219  "addressing_model",
220  opBuilder.getAttr<spirv::AddressingModelAttr>(
221  static_cast<spirv::AddressingModel>(operands.front())));
222  (*module)->setAttr("memory_model",
223  opBuilder.getAttr<spirv::MemoryModelAttr>(
224  static_cast<spirv::MemoryModel>(operands.back())));
225 
226  return success();
227 }
228 
229 LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
230  // TODO: This function should also be auto-generated. For now, since only a
231  // few decorations are processed/handled in a meaningful manner, going with a
232  // manual implementation.
233  if (words.size() < 2) {
234  return emitError(
235  unknownLoc, "OpDecorate must have at least result <id> and Decoration");
236  }
237  auto decorationName =
238  stringifyDecoration(static_cast<spirv::Decoration>(words[1]));
239  if (decorationName.empty()) {
240  return emitError(unknownLoc, "invalid Decoration code : ") << words[1];
241  }
242  auto attrName = llvm::convertToSnakeFromCamelCase(decorationName);
243  auto symbol = opBuilder.getStringAttr(attrName);
244  switch (static_cast<spirv::Decoration>(words[1])) {
245  case spirv::Decoration::FPFastMathMode:
246  if (words.size() != 3) {
247  return emitError(unknownLoc, "OpDecorate with ")
248  << decorationName << " needs a single integer literal";
249  }
250  decorations[words[0]].set(
251  symbol, FPFastMathModeAttr::get(opBuilder.getContext(),
252  static_cast<FPFastMathMode>(words[2])));
253  break;
254  case spirv::Decoration::DescriptorSet:
255  case spirv::Decoration::Binding:
256  if (words.size() != 3) {
257  return emitError(unknownLoc, "OpDecorate with ")
258  << decorationName << " needs a single integer literal";
259  }
260  decorations[words[0]].set(
261  symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));
262  break;
263  case spirv::Decoration::BuiltIn:
264  if (words.size() != 3) {
265  return emitError(unknownLoc, "OpDecorate with ")
266  << decorationName << " needs a single integer literal";
267  }
268  decorations[words[0]].set(
269  symbol, opBuilder.getStringAttr(
270  stringifyBuiltIn(static_cast<spirv::BuiltIn>(words[2]))));
271  break;
272  case spirv::Decoration::ArrayStride:
273  if (words.size() != 3) {
274  return emitError(unknownLoc, "OpDecorate with ")
275  << decorationName << " needs a single integer literal";
276  }
277  typeDecorations[words[0]] = words[2];
278  break;
279  case spirv::Decoration::LinkageAttributes: {
280  if (words.size() < 4) {
281  return emitError(unknownLoc, "OpDecorate with ")
282  << decorationName
283  << " needs at least 1 string and 1 integer literal";
284  }
285  // LinkageAttributes has two parameters ["linkageName", linkageType]
286  // e.g., OpDecorate %imported_func LinkageAttributes "outside.func" Import
287  // "linkageName" is a stringliteral encoded as uint32_t,
288  // hence the size of name is variable length which results in words.size()
289  // being variable length, words.size() = 3 + strlen(name)/4 + 1 or
290  // 3 + ceildiv(strlen(name), 4).
291  unsigned wordIndex = 2;
292  auto linkageName = spirv::decodeStringLiteral(words, wordIndex).str();
293  auto linkageTypeAttr = opBuilder.getAttr<::mlir::spirv::LinkageTypeAttr>(
294  static_cast<::mlir::spirv::LinkageType>(words[wordIndex++]));
295  auto linkageAttr = opBuilder.getAttr<::mlir::spirv::LinkageAttributesAttr>(
296  linkageName, linkageTypeAttr);
297  decorations[words[0]].set(symbol, llvm::dyn_cast<Attribute>(linkageAttr));
298  break;
299  }
300  case spirv::Decoration::Aliased:
301  case spirv::Decoration::Block:
302  case spirv::Decoration::BufferBlock:
303  case spirv::Decoration::Flat:
304  case spirv::Decoration::NonReadable:
305  case spirv::Decoration::NonWritable:
306  case spirv::Decoration::NoPerspective:
307  case spirv::Decoration::NoSignedWrap:
308  case spirv::Decoration::NoUnsignedWrap:
309  case spirv::Decoration::RelaxedPrecision:
310  case spirv::Decoration::Restrict:
311  if (words.size() != 2) {
312  return emitError(unknownLoc, "OpDecoration with ")
313  << decorationName << "needs a single target <id>";
314  }
315  // Block decoration does not affect spirv.struct type, but is still stored
316  // for verification.
317  // TODO: Update StructType to contain this information since
318  // it is needed for many validation rules.
319  decorations[words[0]].set(symbol, opBuilder.getUnitAttr());
320  break;
321  case spirv::Decoration::Location:
322  case spirv::Decoration::SpecId:
323  if (words.size() != 3) {
324  return emitError(unknownLoc, "OpDecoration with ")
325  << decorationName << "needs a single integer literal";
326  }
327  decorations[words[0]].set(
328  symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));
329  break;
330  default:
331  return emitError(unknownLoc, "unhandled Decoration : '") << decorationName;
332  }
333  return success();
334 }
335 
337 spirv::Deserializer::processMemberDecoration(ArrayRef<uint32_t> words) {
338  // The binary layout of OpMemberDecorate is different comparing to OpDecorate
339  if (words.size() < 3) {
340  return emitError(unknownLoc,
341  "OpMemberDecorate must have at least 3 operands");
342  }
343 
344  auto decoration = static_cast<spirv::Decoration>(words[2]);
345  if (decoration == spirv::Decoration::Offset && words.size() != 4) {
346  return emitError(unknownLoc,
347  " missing offset specification in OpMemberDecorate with "
348  "Offset decoration");
349  }
350  ArrayRef<uint32_t> decorationOperands;
351  if (words.size() > 3) {
352  decorationOperands = words.slice(3);
353  }
354  memberDecorationMap[words[0]][words[1]][decoration] = decorationOperands;
355  return success();
356 }
357 
358 LogicalResult spirv::Deserializer::processMemberName(ArrayRef<uint32_t> words) {
359  if (words.size() < 3) {
360  return emitError(unknownLoc, "OpMemberName must have at least 3 operands");
361  }
362  unsigned wordIndex = 2;
363  auto name = decodeStringLiteral(words, wordIndex);
364  if (wordIndex != words.size()) {
365  return emitError(unknownLoc,
366  "unexpected trailing words in OpMemberName instruction");
367  }
368  memberNameMap[words[0]][words[1]] = name;
369  return success();
370 }
371 
373 spirv::Deserializer::processFunction(ArrayRef<uint32_t> operands) {
374  if (curFunction) {
375  return emitError(unknownLoc, "found function inside function");
376  }
377 
378  // Get the result type
379  if (operands.size() != 4) {
380  return emitError(unknownLoc, "OpFunction must have 4 parameters");
381  }
382  Type resultType = getType(operands[0]);
383  if (!resultType) {
384  return emitError(unknownLoc, "undefined result type from <id> ")
385  << operands[0];
386  }
387 
388  uint32_t fnID = operands[1];
389  if (funcMap.count(fnID)) {
390  return emitError(unknownLoc, "duplicate function definition/declaration");
391  }
392 
393  auto fnControl = spirv::symbolizeFunctionControl(operands[2]);
394  if (!fnControl) {
395  return emitError(unknownLoc, "unknown Function Control: ") << operands[2];
396  }
397 
398  Type fnType = getType(operands[3]);
399  if (!fnType || !isa<FunctionType>(fnType)) {
400  return emitError(unknownLoc, "unknown function type from <id> ")
401  << operands[3];
402  }
403  auto functionType = cast<FunctionType>(fnType);
404 
405  if ((isVoidType(resultType) && functionType.getNumResults() != 0) ||
406  (functionType.getNumResults() == 1 &&
407  functionType.getResult(0) != resultType)) {
408  return emitError(unknownLoc, "mismatch in function type ")
409  << functionType << " and return type " << resultType << " specified";
410  }
411 
412  std::string fnName = getFunctionSymbol(fnID);
413  auto funcOp = opBuilder.create<spirv::FuncOp>(
414  unknownLoc, fnName, functionType, fnControl.value());
415  // Processing other function attributes.
416  if (decorations.count(fnID)) {
417  for (auto attr : decorations[fnID].getAttrs()) {
418  funcOp->setAttr(attr.getName(), attr.getValue());
419  }
420  }
421  curFunction = funcMap[fnID] = funcOp;
422  auto *entryBlock = funcOp.addEntryBlock();
423  LLVM_DEBUG({
424  logger.startLine()
425  << "//===-------------------------------------------===//\n";
426  logger.startLine() << "[fn] name: " << fnName << "\n";
427  logger.startLine() << "[fn] type: " << fnType << "\n";
428  logger.startLine() << "[fn] ID: " << fnID << "\n";
429  logger.startLine() << "[fn] entry block: " << entryBlock << "\n";
430  logger.indent();
431  });
432 
433  // Parse the op argument instructions
434  if (functionType.getNumInputs()) {
435  for (size_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
436  auto argType = functionType.getInput(i);
437  spirv::Opcode opcode = spirv::Opcode::OpNop;
438  ArrayRef<uint32_t> operands;
439  if (failed(sliceInstruction(opcode, operands,
440  spirv::Opcode::OpFunctionParameter))) {
441  return failure();
442  }
443  if (opcode != spirv::Opcode::OpFunctionParameter) {
444  return emitError(
445  unknownLoc,
446  "missing OpFunctionParameter instruction for argument ")
447  << i;
448  }
449  if (operands.size() != 2) {
450  return emitError(
451  unknownLoc,
452  "expected result type and result <id> for OpFunctionParameter");
453  }
454  auto argDefinedType = getType(operands[0]);
455  if (!argDefinedType || argDefinedType != argType) {
456  return emitError(unknownLoc,
457  "mismatch in argument type between function type "
458  "definition ")
459  << functionType << " and argument type definition "
460  << argDefinedType << " at argument " << i;
461  }
462  if (getValue(operands[1])) {
463  return emitError(unknownLoc, "duplicate definition of result <id> ")
464  << operands[1];
465  }
466  auto argValue = funcOp.getArgument(i);
467  valueMap[operands[1]] = argValue;
468  }
469  }
470 
471  // entryBlock is needed to access the arguments, Once that is done, we can
472  // erase the block for functions with 'Import' LinkageAttributes, since these
473  // are essentially function declarations, so they have no body.
474  auto linkageAttr = funcOp.getLinkageAttributes();
475  auto hasImportLinkage =
476  linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
477  spirv::LinkageType::Import);
478  if (hasImportLinkage)
479  funcOp.eraseBody();
480 
481  // RAII guard to reset the insertion point to the module's region after
482  // deserializing the body of this function.
483  OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
484 
485  spirv::Opcode opcode = spirv::Opcode::OpNop;
486  ArrayRef<uint32_t> instOperands;
487 
488  // Special handling for the entry block. We need to make sure it starts with
489  // an OpLabel instruction. The entry block takes the same parameters as the
490  // function. All other blocks do not take any parameter. We have already
491  // created the entry block, here we need to register it to the correct label
492  // <id>.
493  if (failed(sliceInstruction(opcode, instOperands,
494  spirv::Opcode::OpFunctionEnd))) {
495  return failure();
496  }
497  if (opcode == spirv::Opcode::OpFunctionEnd) {
498  return processFunctionEnd(instOperands);
499  }
500  if (opcode != spirv::Opcode::OpLabel) {
501  return emitError(unknownLoc, "a basic block must start with OpLabel");
502  }
503  if (instOperands.size() != 1) {
504  return emitError(unknownLoc, "OpLabel should only have result <id>");
505  }
506  blockMap[instOperands[0]] = entryBlock;
507  if (failed(processLabel(instOperands))) {
508  return failure();
509  }
510 
511  // Then process all the other instructions in the function until we hit
512  // OpFunctionEnd.
513  while (succeeded(sliceInstruction(opcode, instOperands,
514  spirv::Opcode::OpFunctionEnd)) &&
515  opcode != spirv::Opcode::OpFunctionEnd) {
516  if (failed(processInstruction(opcode, instOperands))) {
517  return failure();
518  }
519  }
520  if (opcode != spirv::Opcode::OpFunctionEnd) {
521  return failure();
522  }
523 
524  return processFunctionEnd(instOperands);
525 }
526 
528 spirv::Deserializer::processFunctionEnd(ArrayRef<uint32_t> operands) {
529  // Process OpFunctionEnd.
530  if (!operands.empty()) {
531  return emitError(unknownLoc, "unexpected operands for OpFunctionEnd");
532  }
533 
534  // Wire up block arguments from OpPhi instructions.
535  // Put all structured control flow in spirv.mlir.selection/spirv.mlir.loop
536  // ops.
537  if (failed(wireUpBlockArgument()) || failed(structurizeControlFlow())) {
538  return failure();
539  }
540 
541  curBlock = nullptr;
542  curFunction = std::nullopt;
543 
544  LLVM_DEBUG({
545  logger.unindent();
546  logger.startLine()
547  << "//===-------------------------------------------===//\n";
548  });
549  return success();
550 }
551 
552 std::optional<std::pair<Attribute, Type>>
553 spirv::Deserializer::getConstant(uint32_t id) {
554  auto constIt = constantMap.find(id);
555  if (constIt == constantMap.end())
556  return std::nullopt;
557  return constIt->getSecond();
558 }
559 
560 std::optional<spirv::SpecConstOperationMaterializationInfo>
561 spirv::Deserializer::getSpecConstantOperation(uint32_t id) {
562  auto constIt = specConstOperationMap.find(id);
563  if (constIt == specConstOperationMap.end())
564  return std::nullopt;
565  return constIt->getSecond();
566 }
567 
568 std::string spirv::Deserializer::getFunctionSymbol(uint32_t id) {
569  auto funcName = nameMap.lookup(id).str();
570  if (funcName.empty()) {
571  funcName = "spirv_fn_" + std::to_string(id);
572  }
573  return funcName;
574 }
575 
576 std::string spirv::Deserializer::getSpecConstantSymbol(uint32_t id) {
577  auto constName = nameMap.lookup(id).str();
578  if (constName.empty()) {
579  constName = "spirv_spec_const_" + std::to_string(id);
580  }
581  return constName;
582 }
583 
584 spirv::SpecConstantOp
585 spirv::Deserializer::createSpecConstant(Location loc, uint32_t resultID,
586  TypedAttr defaultValue) {
587  auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
588  auto op = opBuilder.create<spirv::SpecConstantOp>(unknownLoc, symName,
589  defaultValue);
590  if (decorations.count(resultID)) {
591  for (auto attr : decorations[resultID].getAttrs())
592  op->setAttr(attr.getName(), attr.getValue());
593  }
594  specConstMap[resultID] = op;
595  return op;
596 }
597 
599 spirv::Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
600  unsigned wordIndex = 0;
601  if (operands.size() < 3) {
602  return emitError(
603  unknownLoc,
604  "OpVariable needs at least 3 operands, type, <id> and storage class");
605  }
606 
607  // Result Type.
608  auto type = getType(operands[wordIndex]);
609  if (!type) {
610  return emitError(unknownLoc, "unknown result type <id> : ")
611  << operands[wordIndex];
612  }
613  auto ptrType = dyn_cast<spirv::PointerType>(type);
614  if (!ptrType) {
615  return emitError(unknownLoc,
616  "expected a result type <id> to be a spirv.ptr, found : ")
617  << type;
618  }
619  wordIndex++;
620 
621  // Result <id>.
622  auto variableID = operands[wordIndex];
623  auto variableName = nameMap.lookup(variableID).str();
624  if (variableName.empty()) {
625  variableName = "spirv_var_" + std::to_string(variableID);
626  }
627  wordIndex++;
628 
629  // Storage class.
630  auto storageClass = static_cast<spirv::StorageClass>(operands[wordIndex]);
631  if (ptrType.getStorageClass() != storageClass) {
632  return emitError(unknownLoc, "mismatch in storage class of pointer type ")
633  << type << " and that specified in OpVariable instruction : "
634  << stringifyStorageClass(storageClass);
635  }
636  wordIndex++;
637 
638  // Initializer.
639  FlatSymbolRefAttr initializer = nullptr;
640  if (wordIndex < operands.size()) {
641  auto initializerOp = getGlobalVariable(operands[wordIndex]);
642  if (!initializerOp) {
643  return emitError(unknownLoc, "unknown <id> ")
644  << operands[wordIndex] << "used as initializer";
645  }
646  wordIndex++;
647  initializer = SymbolRefAttr::get(initializerOp.getOperation());
648  }
649  if (wordIndex != operands.size()) {
650  return emitError(unknownLoc,
651  "found more operands than expected when deserializing "
652  "OpVariable instruction, only ")
653  << wordIndex << " of " << operands.size() << " processed";
654  }
655  auto loc = createFileLineColLoc(opBuilder);
656  auto varOp = opBuilder.create<spirv::GlobalVariableOp>(
657  loc, TypeAttr::get(type), opBuilder.getStringAttr(variableName),
658  initializer);
659 
660  // Decorations.
661  if (decorations.count(variableID)) {
662  for (auto attr : decorations[variableID].getAttrs())
663  varOp->setAttr(attr.getName(), attr.getValue());
664  }
665  globalVariableMap[variableID] = varOp;
666  return success();
667 }
668 
669 IntegerAttr spirv::Deserializer::getConstantInt(uint32_t id) {
670  auto constInfo = getConstant(id);
671  if (!constInfo) {
672  return nullptr;
673  }
674  return dyn_cast<IntegerAttr>(constInfo->first);
675 }
676 
677 LogicalResult spirv::Deserializer::processName(ArrayRef<uint32_t> operands) {
678  if (operands.size() < 2) {
679  return emitError(unknownLoc, "OpName needs at least 2 operands");
680  }
681  if (!nameMap.lookup(operands[0]).empty()) {
682  return emitError(unknownLoc, "duplicate name found for result <id> ")
683  << operands[0];
684  }
685  unsigned wordIndex = 1;
686  StringRef name = decodeStringLiteral(operands, wordIndex);
687  if (wordIndex != operands.size()) {
688  return emitError(unknownLoc,
689  "unexpected trailing words in OpName instruction");
690  }
691  nameMap[operands[0]] = name;
692  return success();
693 }
694 
695 //===----------------------------------------------------------------------===//
696 // Type
697 //===----------------------------------------------------------------------===//
698 
699 LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
700  ArrayRef<uint32_t> operands) {
701  if (operands.empty()) {
702  return emitError(unknownLoc, "type instruction with opcode ")
703  << spirv::stringifyOpcode(opcode) << " needs at least one <id>";
704  }
705 
706  /// TODO: Types might be forward declared in some instructions and need to be
707  /// handled appropriately.
708  if (typeMap.count(operands[0])) {
709  return emitError(unknownLoc, "duplicate definition for result <id> ")
710  << operands[0];
711  }
712 
713  switch (opcode) {
714  case spirv::Opcode::OpTypeVoid:
715  if (operands.size() != 1)
716  return emitError(unknownLoc, "OpTypeVoid must have no parameters");
717  typeMap[operands[0]] = opBuilder.getNoneType();
718  break;
719  case spirv::Opcode::OpTypeBool:
720  if (operands.size() != 1)
721  return emitError(unknownLoc, "OpTypeBool must have no parameters");
722  typeMap[operands[0]] = opBuilder.getI1Type();
723  break;
724  case spirv::Opcode::OpTypeInt: {
725  if (operands.size() != 3)
726  return emitError(
727  unknownLoc, "OpTypeInt must have bitwidth and signedness parameters");
728 
729  // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics
730  // to preserve or validate.
731  // 0 indicates unsigned, or no signedness semantics
732  // 1 indicates signed semantics."
733  //
734  // So we cannot differentiate signless and unsigned integers; always use
735  // signless semantics for such cases.
736  auto sign = operands[2] == 1 ? IntegerType::SignednessSemantics::Signed
737  : IntegerType::SignednessSemantics::Signless;
738  typeMap[operands[0]] = IntegerType::get(context, operands[1], sign);
739  } break;
740  case spirv::Opcode::OpTypeFloat: {
741  if (operands.size() != 2)
742  return emitError(unknownLoc, "OpTypeFloat must have bitwidth parameter");
743 
744  Type floatTy;
745  switch (operands[1]) {
746  case 16:
747  floatTy = opBuilder.getF16Type();
748  break;
749  case 32:
750  floatTy = opBuilder.getF32Type();
751  break;
752  case 64:
753  floatTy = opBuilder.getF64Type();
754  break;
755  default:
756  return emitError(unknownLoc, "unsupported OpTypeFloat bitwidth: ")
757  << operands[1];
758  }
759  typeMap[operands[0]] = floatTy;
760  } break;
761  case spirv::Opcode::OpTypeVector: {
762  if (operands.size() != 3) {
763  return emitError(
764  unknownLoc,
765  "OpTypeVector must have element type and count parameters");
766  }
767  Type elementTy = getType(operands[1]);
768  if (!elementTy) {
769  return emitError(unknownLoc, "OpTypeVector references undefined <id> ")
770  << operands[1];
771  }
772  typeMap[operands[0]] = VectorType::get({operands[2]}, elementTy);
773  } break;
774  case spirv::Opcode::OpTypePointer: {
775  return processOpTypePointer(operands);
776  } break;
777  case spirv::Opcode::OpTypeArray:
778  return processArrayType(operands);
779  case spirv::Opcode::OpTypeCooperativeMatrixKHR:
780  return processCooperativeMatrixTypeKHR(operands);
781  case spirv::Opcode::OpTypeCooperativeMatrixNV:
782  return processCooperativeMatrixTypeNV(operands);
783  case spirv::Opcode::OpTypeFunction:
784  return processFunctionType(operands);
785  case spirv::Opcode::OpTypeJointMatrixINTEL:
786  return processJointMatrixType(operands);
787  case spirv::Opcode::OpTypeImage:
788  return processImageType(operands);
789  case spirv::Opcode::OpTypeSampledImage:
790  return processSampledImageType(operands);
791  case spirv::Opcode::OpTypeRuntimeArray:
792  return processRuntimeArrayType(operands);
793  case spirv::Opcode::OpTypeStruct:
794  return processStructType(operands);
795  case spirv::Opcode::OpTypeMatrix:
796  return processMatrixType(operands);
797  default:
798  return emitError(unknownLoc, "unhandled type instruction");
799  }
800  return success();
801 }
802 
804 spirv::Deserializer::processOpTypePointer(ArrayRef<uint32_t> operands) {
805  if (operands.size() != 3)
806  return emitError(unknownLoc, "OpTypePointer must have two parameters");
807 
808  auto pointeeType = getType(operands[2]);
809  if (!pointeeType)
810  return emitError(unknownLoc, "unknown OpTypePointer pointee type <id> ")
811  << operands[2];
812 
813  uint32_t typePointerID = operands[0];
814  auto storageClass = static_cast<spirv::StorageClass>(operands[1]);
815  typeMap[typePointerID] = spirv::PointerType::get(pointeeType, storageClass);
816 
817  for (auto *deferredStructIt = std::begin(deferredStructTypesInfos);
818  deferredStructIt != std::end(deferredStructTypesInfos);) {
819  for (auto *unresolvedMemberIt =
820  std::begin(deferredStructIt->unresolvedMemberTypes);
821  unresolvedMemberIt !=
822  std::end(deferredStructIt->unresolvedMemberTypes);) {
823  if (unresolvedMemberIt->first == typePointerID) {
824  // The newly constructed pointer type can resolve one of the
825  // deferred struct type members; update the memberTypes list and
826  // clean the unresolvedMemberTypes list accordingly.
827  deferredStructIt->memberTypes[unresolvedMemberIt->second] =
828  typeMap[typePointerID];
829  unresolvedMemberIt =
830  deferredStructIt->unresolvedMemberTypes.erase(unresolvedMemberIt);
831  } else {
832  ++unresolvedMemberIt;
833  }
834  }
835 
836  if (deferredStructIt->unresolvedMemberTypes.empty()) {
837  // All deferred struct type members are now resolved, set the struct body.
838  auto structType = deferredStructIt->deferredStructType;
839 
840  assert(structType && "expected a spirv::StructType");
841  assert(structType.isIdentified() && "expected an indentified struct");
842 
843  if (failed(structType.trySetBody(
844  deferredStructIt->memberTypes, deferredStructIt->offsetInfo,
845  deferredStructIt->memberDecorationsInfo)))
846  return failure();
847 
848  deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt);
849  } else {
850  ++deferredStructIt;
851  }
852  }
853 
854  return success();
855 }
856 
858 spirv::Deserializer::processArrayType(ArrayRef<uint32_t> operands) {
859  if (operands.size() != 3) {
860  return emitError(unknownLoc,
861  "OpTypeArray must have element type and count parameters");
862  }
863 
864  Type elementTy = getType(operands[1]);
865  if (!elementTy) {
866  return emitError(unknownLoc, "OpTypeArray references undefined <id> ")
867  << operands[1];
868  }
869 
870  unsigned count = 0;
871  // TODO: The count can also come frome a specialization constant.
872  auto countInfo = getConstant(operands[2]);
873  if (!countInfo) {
874  return emitError(unknownLoc, "OpTypeArray count <id> ")
875  << operands[2] << "can only come from normal constant right now";
876  }
877 
878  if (auto intVal = dyn_cast<IntegerAttr>(countInfo->first)) {
879  count = intVal.getValue().getZExtValue();
880  } else {
881  return emitError(unknownLoc, "OpTypeArray count must come from a "
882  "scalar integer constant instruction");
883  }
884 
885  typeMap[operands[0]] = spirv::ArrayType::get(
886  elementTy, count, typeDecorations.lookup(operands[0]));
887  return success();
888 }
889 
891 spirv::Deserializer::processFunctionType(ArrayRef<uint32_t> operands) {
892  assert(!operands.empty() && "No operands for processing function type");
893  if (operands.size() == 1) {
894  return emitError(unknownLoc, "missing return type for OpTypeFunction");
895  }
896  auto returnType = getType(operands[1]);
897  if (!returnType) {
898  return emitError(unknownLoc, "unknown return type in OpTypeFunction");
899  }
900  SmallVector<Type, 1> argTypes;
901  for (size_t i = 2, e = operands.size(); i < e; ++i) {
902  auto ty = getType(operands[i]);
903  if (!ty) {
904  return emitError(unknownLoc, "unknown argument type in OpTypeFunction");
905  }
906  argTypes.push_back(ty);
907  }
908  ArrayRef<Type> returnTypes;
909  if (!isVoidType(returnType)) {
910  returnTypes = llvm::ArrayRef(returnType);
911  }
912  typeMap[operands[0]] = FunctionType::get(context, argTypes, returnTypes);
913  return success();
914 }
915 
916 LogicalResult spirv::Deserializer::processCooperativeMatrixTypeKHR(
917  ArrayRef<uint32_t> operands) {
918  if (operands.size() != 6) {
919  return emitError(unknownLoc,
920  "OpTypeCooperativeMatrixKHR must have element type, "
921  "scope, row and column parameters, and use");
922  }
923 
924  Type elementTy = getType(operands[1]);
925  if (!elementTy) {
926  return emitError(unknownLoc,
927  "OpTypeCooperativeMatrixKHR references undefined <id> ")
928  << operands[1];
929  }
930 
931  std::optional<spirv::Scope> scope =
932  spirv::symbolizeScope(getConstantInt(operands[2]).getInt());
933  if (!scope) {
934  return emitError(
935  unknownLoc,
936  "OpTypeCooperativeMatrixKHR references undefined scope <id> ")
937  << operands[2];
938  }
939 
940  unsigned rows = getConstantInt(operands[3]).getInt();
941  unsigned columns = getConstantInt(operands[4]).getInt();
942 
943  std::optional<spirv::CooperativeMatrixUseKHR> use =
944  spirv::symbolizeCooperativeMatrixUseKHR(
945  getConstantInt(operands[5]).getInt());
946  if (!use) {
947  return emitError(
948  unknownLoc,
949  "OpTypeCooperativeMatrixKHR references undefined use <id> ")
950  << operands[5];
951  }
952 
953  typeMap[operands[0]] =
954  spirv::CooperativeMatrixType::get(elementTy, rows, columns, *scope, *use);
955  return success();
956 }
957 
958 LogicalResult spirv::Deserializer::processCooperativeMatrixTypeNV(
959  ArrayRef<uint32_t> operands) {
960  if (operands.size() != 5) {
961  return emitError(unknownLoc, "OpTypeCooperativeMatrixNV must have element "
962  "type and row x column parameters");
963  }
964 
965  Type elementTy = getType(operands[1]);
966  if (!elementTy) {
967  return emitError(unknownLoc,
968  "OpTypeCooperativeMatrixNV references undefined <id> ")
969  << operands[1];
970  }
971 
972  std::optional<spirv::Scope> scope =
973  spirv::symbolizeScope(getConstantInt(operands[2]).getInt());
974  if (!scope) {
975  return emitError(
976  unknownLoc,
977  "OpTypeCooperativeMatrixNV references undefined scope <id> ")
978  << operands[2];
979  }
980 
981  unsigned rows = getConstantInt(operands[3]).getInt();
982  unsigned columns = getConstantInt(operands[4]).getInt();
983 
984  typeMap[operands[0]] =
985  spirv::CooperativeMatrixNVType::get(elementTy, *scope, rows, columns);
986  return success();
987 }
988 
990 spirv::Deserializer::processJointMatrixType(ArrayRef<uint32_t> operands) {
991  if (operands.size() != 6) {
992  return emitError(unknownLoc, "OpTypeJointMatrix must have element "
993  "type and row x column parameters");
994  }
995 
996  Type elementTy = getType(operands[1]);
997  if (!elementTy) {
998  return emitError(unknownLoc, "OpTypeJointMatrix references undefined <id> ")
999  << operands[1];
1000  }
1001 
1002  auto scope = spirv::symbolizeScope(getConstantInt(operands[5]).getInt());
1003  if (!scope) {
1004  return emitError(unknownLoc,
1005  "OpTypeJointMatrix references undefined scope <id> ")
1006  << operands[5];
1007  }
1008  auto matrixLayout =
1009  spirv::symbolizeMatrixLayout(getConstantInt(operands[4]).getInt());
1010  if (!matrixLayout) {
1011  return emitError(unknownLoc,
1012  "OpTypeJointMatrix references undefined scope <id> ")
1013  << operands[4];
1014  }
1015  unsigned rows = getConstantInt(operands[2]).getInt();
1016  unsigned columns = getConstantInt(operands[3]).getInt();
1017 
1018  typeMap[operands[0]] = spirv::JointMatrixINTELType::get(
1019  elementTy, scope.value(), rows, columns, matrixLayout.value());
1020  return success();
1021 }
1022 
1024 spirv::Deserializer::processRuntimeArrayType(ArrayRef<uint32_t> operands) {
1025  if (operands.size() != 2) {
1026  return emitError(unknownLoc, "OpTypeRuntimeArray must have two operands");
1027  }
1028  Type memberType = getType(operands[1]);
1029  if (!memberType) {
1030  return emitError(unknownLoc,
1031  "OpTypeRuntimeArray references undefined <id> ")
1032  << operands[1];
1033  }
1034  typeMap[operands[0]] = spirv::RuntimeArrayType::get(
1035  memberType, typeDecorations.lookup(operands[0]));
1036  return success();
1037 }
1038 
1040 spirv::Deserializer::processStructType(ArrayRef<uint32_t> operands) {
1041  // TODO: Find a way to handle identified structs when debug info is stripped.
1042 
1043  if (operands.empty()) {
1044  return emitError(unknownLoc, "OpTypeStruct must have at least result <id>");
1045  }
1046 
1047  if (operands.size() == 1) {
1048  // Handle empty struct.
1049  typeMap[operands[0]] =
1050  spirv::StructType::getEmpty(context, nameMap.lookup(operands[0]).str());
1051  return success();
1052  }
1053 
1054  // First element is operand ID, second element is member index in the struct.
1055  SmallVector<std::pair<uint32_t, unsigned>, 0> unresolvedMemberTypes;
1056  SmallVector<Type, 4> memberTypes;
1057 
1058  for (auto op : llvm::drop_begin(operands, 1)) {
1059  Type memberType = getType(op);
1060  bool typeForwardPtr = (typeForwardPointerIDs.count(op) != 0);
1061 
1062  if (!memberType && !typeForwardPtr)
1063  return emitError(unknownLoc, "OpTypeStruct references undefined <id> ")
1064  << op;
1065 
1066  if (!memberType)
1067  unresolvedMemberTypes.emplace_back(op, memberTypes.size());
1068 
1069  memberTypes.push_back(memberType);
1070  }
1071 
1074  if (memberDecorationMap.count(operands[0])) {
1075  auto &allMemberDecorations = memberDecorationMap[operands[0]];
1076  for (auto memberIndex : llvm::seq<uint32_t>(0, memberTypes.size())) {
1077  if (allMemberDecorations.count(memberIndex)) {
1078  for (auto &memberDecoration : allMemberDecorations[memberIndex]) {
1079  // Check for offset.
1080  if (memberDecoration.first == spirv::Decoration::Offset) {
1081  // If offset info is empty, resize to the number of members;
1082  if (offsetInfo.empty()) {
1083  offsetInfo.resize(memberTypes.size());
1084  }
1085  offsetInfo[memberIndex] = memberDecoration.second[0];
1086  } else {
1087  if (!memberDecoration.second.empty()) {
1088  memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/1,
1089  memberDecoration.first,
1090  memberDecoration.second[0]);
1091  } else {
1092  memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/0,
1093  memberDecoration.first, 0);
1094  }
1095  }
1096  }
1097  }
1098  }
1099  }
1100 
1101  uint32_t structID = operands[0];
1102  std::string structIdentifier = nameMap.lookup(structID).str();
1103 
1104  if (structIdentifier.empty()) {
1105  assert(unresolvedMemberTypes.empty() &&
1106  "didn't expect unresolved member types");
1107  typeMap[structID] =
1108  spirv::StructType::get(memberTypes, offsetInfo, memberDecorationsInfo);
1109  } else {
1110  auto structTy = spirv::StructType::getIdentified(context, structIdentifier);
1111  typeMap[structID] = structTy;
1112 
1113  if (!unresolvedMemberTypes.empty())
1114  deferredStructTypesInfos.push_back({structTy, unresolvedMemberTypes,
1115  memberTypes, offsetInfo,
1116  memberDecorationsInfo});
1117  else if (failed(structTy.trySetBody(memberTypes, offsetInfo,
1118  memberDecorationsInfo)))
1119  return failure();
1120  }
1121 
1122  // TODO: Update StructType to have member name as attribute as
1123  // well.
1124  return success();
1125 }
1126 
1128 spirv::Deserializer::processMatrixType(ArrayRef<uint32_t> operands) {
1129  if (operands.size() != 3) {
1130  // Three operands are needed: result_id, column_type, and column_count
1131  return emitError(unknownLoc, "OpTypeMatrix must have 3 operands"
1132  " (result_id, column_type, and column_count)");
1133  }
1134  // Matrix columns must be of vector type
1135  Type elementTy = getType(operands[1]);
1136  if (!elementTy) {
1137  return emitError(unknownLoc,
1138  "OpTypeMatrix references undefined column type.")
1139  << operands[1];
1140  }
1141 
1142  uint32_t colsCount = operands[2];
1143  typeMap[operands[0]] = spirv::MatrixType::get(elementTy, colsCount);
1144  return success();
1145 }
1146 
1148 spirv::Deserializer::processTypeForwardPointer(ArrayRef<uint32_t> operands) {
1149  if (operands.size() != 2)
1150  return emitError(unknownLoc,
1151  "OpTypeForwardPointer instruction must have two operands");
1152 
1153  typeForwardPointerIDs.insert(operands[0]);
1154  // TODO: Use the 2nd operand (Storage Class) to validate the OpTypePointer
1155  // instruction that defines the actual type.
1156 
1157  return success();
1158 }
1159 
1161 spirv::Deserializer::processImageType(ArrayRef<uint32_t> operands) {
1162  // TODO: Add support for Access Qualifier.
1163  if (operands.size() != 8)
1164  return emitError(
1165  unknownLoc,
1166  "OpTypeImage with non-eight operands are not supported yet");
1167 
1168  Type elementTy = getType(operands[1]);
1169  if (!elementTy)
1170  return emitError(unknownLoc, "OpTypeImage references undefined <id>: ")
1171  << operands[1];
1172 
1173  auto dim = spirv::symbolizeDim(operands[2]);
1174  if (!dim)
1175  return emitError(unknownLoc, "unknown Dim for OpTypeImage: ")
1176  << operands[2];
1177 
1178  auto depthInfo = spirv::symbolizeImageDepthInfo(operands[3]);
1179  if (!depthInfo)
1180  return emitError(unknownLoc, "unknown Depth for OpTypeImage: ")
1181  << operands[3];
1182 
1183  auto arrayedInfo = spirv::symbolizeImageArrayedInfo(operands[4]);
1184  if (!arrayedInfo)
1185  return emitError(unknownLoc, "unknown Arrayed for OpTypeImage: ")
1186  << operands[4];
1187 
1188  auto samplingInfo = spirv::symbolizeImageSamplingInfo(operands[5]);
1189  if (!samplingInfo)
1190  return emitError(unknownLoc, "unknown MS for OpTypeImage: ") << operands[5];
1191 
1192  auto samplerUseInfo = spirv::symbolizeImageSamplerUseInfo(operands[6]);
1193  if (!samplerUseInfo)
1194  return emitError(unknownLoc, "unknown Sampled for OpTypeImage: ")
1195  << operands[6];
1196 
1197  auto format = spirv::symbolizeImageFormat(operands[7]);
1198  if (!format)
1199  return emitError(unknownLoc, "unknown Format for OpTypeImage: ")
1200  << operands[7];
1201 
1202  typeMap[operands[0]] = spirv::ImageType::get(
1203  elementTy, dim.value(), depthInfo.value(), arrayedInfo.value(),
1204  samplingInfo.value(), samplerUseInfo.value(), format.value());
1205  return success();
1206 }
1207 
1209 spirv::Deserializer::processSampledImageType(ArrayRef<uint32_t> operands) {
1210  if (operands.size() != 2)
1211  return emitError(unknownLoc, "OpTypeSampledImage must have two operands");
1212 
1213  Type elementTy = getType(operands[1]);
1214  if (!elementTy)
1215  return emitError(unknownLoc,
1216  "OpTypeSampledImage references undefined <id>: ")
1217  << operands[1];
1218 
1219  typeMap[operands[0]] = spirv::SampledImageType::get(elementTy);
1220  return success();
1221 }
1222 
1223 //===----------------------------------------------------------------------===//
1224 // Constant
1225 //===----------------------------------------------------------------------===//
1226 
1227 LogicalResult spirv::Deserializer::processConstant(ArrayRef<uint32_t> operands,
1228  bool isSpec) {
1229  StringRef opname = isSpec ? "OpSpecConstant" : "OpConstant";
1230 
1231  if (operands.size() < 2) {
1232  return emitError(unknownLoc)
1233  << opname << " must have type <id> and result <id>";
1234  }
1235  if (operands.size() < 3) {
1236  return emitError(unknownLoc)
1237  << opname << " must have at least 1 more parameter";
1238  }
1239 
1240  Type resultType = getType(operands[0]);
1241  if (!resultType) {
1242  return emitError(unknownLoc, "undefined result type from <id> ")
1243  << operands[0];
1244  }
1245 
1246  auto checkOperandSizeForBitwidth = [&](unsigned bitwidth) -> LogicalResult {
1247  if (bitwidth == 64) {
1248  if (operands.size() == 4) {
1249  return success();
1250  }
1251  return emitError(unknownLoc)
1252  << opname << " should have 2 parameters for 64-bit values";
1253  }
1254  if (bitwidth <= 32) {
1255  if (operands.size() == 3) {
1256  return success();
1257  }
1258 
1259  return emitError(unknownLoc)
1260  << opname
1261  << " should have 1 parameter for values with no more than 32 bits";
1262  }
1263  return emitError(unknownLoc, "unsupported OpConstant bitwidth: ")
1264  << bitwidth;
1265  };
1266 
1267  auto resultID = operands[1];
1268 
1269  if (auto intType = dyn_cast<IntegerType>(resultType)) {
1270  auto bitwidth = intType.getWidth();
1271  if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1272  return failure();
1273  }
1274 
1275  APInt value;
1276  if (bitwidth == 64) {
1277  // 64-bit integers are represented with two SPIR-V words. According to
1278  // SPIR-V spec: "When the type’s bit width is larger than one word, the
1279  // literal’s low-order words appear first."
1280  struct DoubleWord {
1281  uint32_t word1;
1282  uint32_t word2;
1283  } words = {operands[2], operands[3]};
1284  value = APInt(64, llvm::bit_cast<uint64_t>(words), /*isSigned=*/true);
1285  } else if (bitwidth <= 32) {
1286  value = APInt(bitwidth, operands[2], /*isSigned=*/true);
1287  }
1288 
1289  auto attr = opBuilder.getIntegerAttr(intType, value);
1290 
1291  if (isSpec) {
1292  createSpecConstant(unknownLoc, resultID, attr);
1293  } else {
1294  // For normal constants, we just record the attribute (and its type) for
1295  // later materialization at use sites.
1296  constantMap.try_emplace(resultID, attr, intType);
1297  }
1298 
1299  return success();
1300  }
1301 
1302  if (auto floatType = dyn_cast<FloatType>(resultType)) {
1303  auto bitwidth = floatType.getWidth();
1304  if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1305  return failure();
1306  }
1307 
1308  APFloat value(0.f);
1309  if (floatType.isF64()) {
1310  // Double values are represented with two SPIR-V words. According to
1311  // SPIR-V spec: "When the type’s bit width is larger than one word, the
1312  // literal’s low-order words appear first."
1313  struct DoubleWord {
1314  uint32_t word1;
1315  uint32_t word2;
1316  } words = {operands[2], operands[3]};
1317  value = APFloat(llvm::bit_cast<double>(words));
1318  } else if (floatType.isF32()) {
1319  value = APFloat(llvm::bit_cast<float>(operands[2]));
1320  } else if (floatType.isF16()) {
1321  APInt data(16, operands[2]);
1322  value = APFloat(APFloat::IEEEhalf(), data);
1323  }
1324 
1325  auto attr = opBuilder.getFloatAttr(floatType, value);
1326  if (isSpec) {
1327  createSpecConstant(unknownLoc, resultID, attr);
1328  } else {
1329  // For normal constants, we just record the attribute (and its type) for
1330  // later materialization at use sites.
1331  constantMap.try_emplace(resultID, attr, floatType);
1332  }
1333 
1334  return success();
1335  }
1336 
1337  return emitError(unknownLoc, "OpConstant can only generate values of "
1338  "scalar integer or floating-point type");
1339 }
1340 
1341 LogicalResult spirv::Deserializer::processConstantBool(
1342  bool isTrue, ArrayRef<uint32_t> operands, bool isSpec) {
1343  if (operands.size() != 2) {
1344  return emitError(unknownLoc, "Op")
1345  << (isSpec ? "Spec" : "") << "Constant"
1346  << (isTrue ? "True" : "False")
1347  << " must have type <id> and result <id>";
1348  }
1349 
1350  auto attr = opBuilder.getBoolAttr(isTrue);
1351  auto resultID = operands[1];
1352  if (isSpec) {
1353  createSpecConstant(unknownLoc, resultID, attr);
1354  } else {
1355  // For normal constants, we just record the attribute (and its type) for
1356  // later materialization at use sites.
1357  constantMap.try_emplace(resultID, attr, opBuilder.getI1Type());
1358  }
1359 
1360  return success();
1361 }
1362 
1364 spirv::Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
1365  if (operands.size() < 2) {
1366  return emitError(unknownLoc,
1367  "OpConstantComposite must have type <id> and result <id>");
1368  }
1369  if (operands.size() < 3) {
1370  return emitError(unknownLoc,
1371  "OpConstantComposite must have at least 1 parameter");
1372  }
1373 
1374  Type resultType = getType(operands[0]);
1375  if (!resultType) {
1376  return emitError(unknownLoc, "undefined result type from <id> ")
1377  << operands[0];
1378  }
1379 
1380  SmallVector<Attribute, 4> elements;
1381  elements.reserve(operands.size() - 2);
1382  for (unsigned i = 2, e = operands.size(); i < e; ++i) {
1383  auto elementInfo = getConstant(operands[i]);
1384  if (!elementInfo) {
1385  return emitError(unknownLoc, "OpConstantComposite component <id> ")
1386  << operands[i] << " must come from a normal constant";
1387  }
1388  elements.push_back(elementInfo->first);
1389  }
1390 
1391  auto resultID = operands[1];
1392  if (auto vectorType = dyn_cast<VectorType>(resultType)) {
1393  auto attr = DenseElementsAttr::get(vectorType, elements);
1394  // For normal constants, we just record the attribute (and its type) for
1395  // later materialization at use sites.
1396  constantMap.try_emplace(resultID, attr, resultType);
1397  } else if (auto arrayType = dyn_cast<spirv::ArrayType>(resultType)) {
1398  auto attr = opBuilder.getArrayAttr(elements);
1399  constantMap.try_emplace(resultID, attr, resultType);
1400  } else {
1401  return emitError(unknownLoc, "unsupported OpConstantComposite type: ")
1402  << resultType;
1403  }
1404 
1405  return success();
1406 }
1407 
1409 spirv::Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) {
1410  if (operands.size() < 2) {
1411  return emitError(unknownLoc,
1412  "OpConstantComposite must have type <id> and result <id>");
1413  }
1414  if (operands.size() < 3) {
1415  return emitError(unknownLoc,
1416  "OpConstantComposite must have at least 1 parameter");
1417  }
1418 
1419  Type resultType = getType(operands[0]);
1420  if (!resultType) {
1421  return emitError(unknownLoc, "undefined result type from <id> ")
1422  << operands[0];
1423  }
1424 
1425  auto resultID = operands[1];
1426  auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
1427 
1428  SmallVector<Attribute, 4> elements;
1429  elements.reserve(operands.size() - 2);
1430  for (unsigned i = 2, e = operands.size(); i < e; ++i) {
1431  auto elementInfo = getSpecConstant(operands[i]);
1432  elements.push_back(SymbolRefAttr::get(elementInfo));
1433  }
1434 
1435  auto op = opBuilder.create<spirv::SpecConstantCompositeOp>(
1436  unknownLoc, TypeAttr::get(resultType), symName,
1437  opBuilder.getArrayAttr(elements));
1438  specConstCompositeMap[resultID] = op;
1439 
1440  return success();
1441 }
1442 
1444 spirv::Deserializer::processSpecConstantOperation(ArrayRef<uint32_t> operands) {
1445  if (operands.size() < 3)
1446  return emitError(unknownLoc, "OpConstantOperation must have type <id>, "
1447  "result <id>, and operand opcode");
1448 
1449  uint32_t resultTypeID = operands[0];
1450 
1451  if (!getType(resultTypeID))
1452  return emitError(unknownLoc, "undefined result type from <id> ")
1453  << resultTypeID;
1454 
1455  uint32_t resultID = operands[1];
1456  spirv::Opcode enclosedOpcode = static_cast<spirv::Opcode>(operands[2]);
1457  auto emplaceResult = specConstOperationMap.try_emplace(
1458  resultID,
1459  SpecConstOperationMaterializationInfo{
1460  enclosedOpcode, resultTypeID,
1461  SmallVector<uint32_t>{operands.begin() + 3, operands.end()}});
1462 
1463  if (!emplaceResult.second)
1464  return emitError(unknownLoc, "value with <id>: ")
1465  << resultID << " is probably defined before.";
1466 
1467  return success();
1468 }
1469 
1470 Value spirv::Deserializer::materializeSpecConstantOperation(
1471  uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID,
1472  ArrayRef<uint32_t> enclosedOpOperands) {
1473 
1474  Type resultType = getType(resultTypeID);
1475 
1476  // Instructions wrapped by OpSpecConstantOp need an ID for their
1477  // Deserializer::processOp<op_name>(...) to emit the corresponding SPIR-V
1478  // dialect wrapped op. For that purpose, a new value map is created and "fake"
1479  // ID in that map is assigned to the result of the enclosed instruction. Note
1480  // that there is no need to update this fake ID since we only need to
1481  // reference the created Value for the enclosed op from the spv::YieldOp
1482  // created later in this method (both of which are the only values in their
1483  // region: the SpecConstantOperation's region). If we encounter another
1484  // SpecConstantOperation in the module, we simply re-use the fake ID since the
1485  // previous Value assigned to it isn't visible in the current scope anyway.
1486  DenseMap<uint32_t, Value> newValueMap;
1487  llvm::SaveAndRestore valueMapGuard(valueMap, newValueMap);
1488  constexpr uint32_t fakeID = static_cast<uint32_t>(-3);
1489 
1490  SmallVector<uint32_t, 4> enclosedOpResultTypeAndOperands;
1491  enclosedOpResultTypeAndOperands.push_back(resultTypeID);
1492  enclosedOpResultTypeAndOperands.push_back(fakeID);
1493  enclosedOpResultTypeAndOperands.append(enclosedOpOperands.begin(),
1494  enclosedOpOperands.end());
1495 
1496  // Process enclosed instruction before creating the enclosing
1497  // specConstantOperation (and its region). This way, references to constants,
1498  // global variables, and spec constants will be materialized outside the new
1499  // op's region. For more info, see Deserializer::getValue's implementation.
1500  if (failed(
1501  processInstruction(enclosedOpcode, enclosedOpResultTypeAndOperands)))
1502  return Value();
1503 
1504  // Since the enclosed op is emitted in the current block, split it in a
1505  // separate new block.
1506  Block *enclosedBlock = curBlock->splitBlock(&curBlock->back());
1507 
1508  auto loc = createFileLineColLoc(opBuilder);
1509  auto specConstOperationOp =
1510  opBuilder.create<spirv::SpecConstantOperationOp>(loc, resultType);
1511 
1512  Region &body = specConstOperationOp.getBody();
1513  // Move the new block into SpecConstantOperation's body.
1514  body.getBlocks().splice(body.end(), curBlock->getParent()->getBlocks(),
1515  Region::iterator(enclosedBlock));
1516  Block &block = body.back();
1517 
1518  // RAII guard to reset the insertion point to the module's region after
1519  // deserializing the body of the specConstantOperation.
1520  OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
1521  opBuilder.setInsertionPointToEnd(&block);
1522 
1523  opBuilder.create<spirv::YieldOp>(loc, block.front().getResult(0));
1524  return specConstOperationOp.getResult();
1525 }
1526 
1528 spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
1529  if (operands.size() != 2) {
1530  return emitError(unknownLoc,
1531  "OpConstantNull must have type <id> and result <id>");
1532  }
1533 
1534  Type resultType = getType(operands[0]);
1535  if (!resultType) {
1536  return emitError(unknownLoc, "undefined result type from <id> ")
1537  << operands[0];
1538  }
1539 
1540  auto resultID = operands[1];
1541  if (resultType.isIntOrFloat() || isa<VectorType>(resultType)) {
1542  auto attr = opBuilder.getZeroAttr(resultType);
1543  // For normal constants, we just record the attribute (and its type) for
1544  // later materialization at use sites.
1545  constantMap.try_emplace(resultID, attr, resultType);
1546  return success();
1547  }
1548 
1549  return emitError(unknownLoc, "unsupported OpConstantNull type: ")
1550  << resultType;
1551 }
1552 
1553 //===----------------------------------------------------------------------===//
1554 // Control flow
1555 //===----------------------------------------------------------------------===//
1556 
1557 Block *spirv::Deserializer::getOrCreateBlock(uint32_t id) {
1558  if (auto *block = getBlock(id)) {
1559  LLVM_DEBUG(logger.startLine() << "[block] got exiting block for id = " << id
1560  << " @ " << block << "\n");
1561  return block;
1562  }
1563 
1564  // We don't know where this block will be placed finally (in a
1565  // spirv.mlir.selection or spirv.mlir.loop or function). Create it into the
1566  // function for now and sort out the proper place later.
1567  auto *block = curFunction->addBlock();
1568  LLVM_DEBUG(logger.startLine() << "[block] created block for id = " << id
1569  << " @ " << block << "\n");
1570  return blockMap[id] = block;
1571 }
1572 
1573 LogicalResult spirv::Deserializer::processBranch(ArrayRef<uint32_t> operands) {
1574  if (!curBlock) {
1575  return emitError(unknownLoc, "OpBranch must appear inside a block");
1576  }
1577 
1578  if (operands.size() != 1) {
1579  return emitError(unknownLoc, "OpBranch must take exactly one target label");
1580  }
1581 
1582  auto *target = getOrCreateBlock(operands[0]);
1583  auto loc = createFileLineColLoc(opBuilder);
1584  // The preceding instruction for the OpBranch instruction could be an
1585  // OpLoopMerge or an OpSelectionMerge instruction, in this case they will have
1586  // the same OpLine information.
1587  opBuilder.create<spirv::BranchOp>(loc, target);
1588 
1589  clearDebugLine();
1590  return success();
1591 }
1592 
1594 spirv::Deserializer::processBranchConditional(ArrayRef<uint32_t> operands) {
1595  if (!curBlock) {
1596  return emitError(unknownLoc,
1597  "OpBranchConditional must appear inside a block");
1598  }
1599 
1600  if (operands.size() != 3 && operands.size() != 5) {
1601  return emitError(unknownLoc,
1602  "OpBranchConditional must have condition, true label, "
1603  "false label, and optionally two branch weights");
1604  }
1605 
1606  auto condition = getValue(operands[0]);
1607  auto *trueBlock = getOrCreateBlock(operands[1]);
1608  auto *falseBlock = getOrCreateBlock(operands[2]);
1609 
1610  std::optional<std::pair<uint32_t, uint32_t>> weights;
1611  if (operands.size() == 5) {
1612  weights = std::make_pair(operands[3], operands[4]);
1613  }
1614  // The preceding instruction for the OpBranchConditional instruction could be
1615  // an OpSelectionMerge instruction, in this case they will have the same
1616  // OpLine information.
1617  auto loc = createFileLineColLoc(opBuilder);
1618  opBuilder.create<spirv::BranchConditionalOp>(
1619  loc, condition, trueBlock,
1620  /*trueArguments=*/ArrayRef<Value>(), falseBlock,
1621  /*falseArguments=*/ArrayRef<Value>(), weights);
1622 
1623  clearDebugLine();
1624  return success();
1625 }
1626 
1627 LogicalResult spirv::Deserializer::processLabel(ArrayRef<uint32_t> operands) {
1628  if (!curFunction) {
1629  return emitError(unknownLoc, "OpLabel must appear inside a function");
1630  }
1631 
1632  if (operands.size() != 1) {
1633  return emitError(unknownLoc, "OpLabel should only have result <id>");
1634  }
1635 
1636  auto labelID = operands[0];
1637  // We may have forward declared this block.
1638  auto *block = getOrCreateBlock(labelID);
1639  LLVM_DEBUG(logger.startLine()
1640  << "[block] populating block " << block << "\n");
1641  // If we have seen this block, make sure it was just a forward declaration.
1642  assert(block->empty() && "re-deserialize the same block!");
1643 
1644  opBuilder.setInsertionPointToStart(block);
1645  blockMap[labelID] = curBlock = block;
1646 
1647  return success();
1648 }
1649 
1651 spirv::Deserializer::processSelectionMerge(ArrayRef<uint32_t> operands) {
1652  if (!curBlock) {
1653  return emitError(unknownLoc, "OpSelectionMerge must appear in a block");
1654  }
1655 
1656  if (operands.size() < 2) {
1657  return emitError(
1658  unknownLoc,
1659  "OpSelectionMerge must specify merge target and selection control");
1660  }
1661 
1662  auto *mergeBlock = getOrCreateBlock(operands[0]);
1663  auto loc = createFileLineColLoc(opBuilder);
1664  auto selectionControl = operands[1];
1665 
1666  if (!blockMergeInfo.try_emplace(curBlock, loc, selectionControl, mergeBlock)
1667  .second) {
1668  return emitError(
1669  unknownLoc,
1670  "a block cannot have more than one OpSelectionMerge instruction");
1671  }
1672 
1673  return success();
1674 }
1675 
1677 spirv::Deserializer::processLoopMerge(ArrayRef<uint32_t> operands) {
1678  if (!curBlock) {
1679  return emitError(unknownLoc, "OpLoopMerge must appear in a block");
1680  }
1681 
1682  if (operands.size() < 3) {
1683  return emitError(unknownLoc, "OpLoopMerge must specify merge target, "
1684  "continue target and loop control");
1685  }
1686 
1687  auto *mergeBlock = getOrCreateBlock(operands[0]);
1688  auto *continueBlock = getOrCreateBlock(operands[1]);
1689  auto loc = createFileLineColLoc(opBuilder);
1690  uint32_t loopControl = operands[2];
1691 
1692  if (!blockMergeInfo
1693  .try_emplace(curBlock, loc, loopControl, mergeBlock, continueBlock)
1694  .second) {
1695  return emitError(
1696  unknownLoc,
1697  "a block cannot have more than one OpLoopMerge instruction");
1698  }
1699 
1700  return success();
1701 }
1702 
1703 LogicalResult spirv::Deserializer::processPhi(ArrayRef<uint32_t> operands) {
1704  if (!curBlock) {
1705  return emitError(unknownLoc, "OpPhi must appear in a block");
1706  }
1707 
1708  if (operands.size() < 4) {
1709  return emitError(unknownLoc, "OpPhi must specify result type, result <id>, "
1710  "and variable-parent pairs");
1711  }
1712 
1713  // Create a block argument for this OpPhi instruction.
1714  Type blockArgType = getType(operands[0]);
1715  BlockArgument blockArg = curBlock->addArgument(blockArgType, unknownLoc);
1716  valueMap[operands[1]] = blockArg;
1717  LLVM_DEBUG(logger.startLine()
1718  << "[phi] created block argument " << blockArg
1719  << " id = " << operands[1] << " of type " << blockArgType << "\n");
1720 
1721  // For each (value, predecessor) pair, insert the value to the predecessor's
1722  // blockPhiInfo entry so later we can fix the block argument there.
1723  for (unsigned i = 2, e = operands.size(); i < e; i += 2) {
1724  uint32_t value = operands[i];
1725  Block *predecessor = getOrCreateBlock(operands[i + 1]);
1726  std::pair<Block *, Block *> predecessorTargetPair{predecessor, curBlock};
1727  blockPhiInfo[predecessorTargetPair].push_back(value);
1728  LLVM_DEBUG(logger.startLine() << "[phi] predecessor @ " << predecessor
1729  << " with arg id = " << value << "\n");
1730  }
1731 
1732  return success();
1733 }
1734 
1735 namespace {
1736 /// A class for putting all blocks in a structured selection/loop in a
1737 /// spirv.mlir.selection/spirv.mlir.loop op.
1738 class ControlFlowStructurizer {
1739 public:
1740 #ifndef NDEBUG
1741  ControlFlowStructurizer(Location loc, uint32_t control,
1742  spirv::BlockMergeInfoMap &mergeInfo, Block *header,
1743  Block *merge, Block *cont,
1744  llvm::ScopedPrinter &logger)
1745  : location(loc), control(control), blockMergeInfo(mergeInfo),
1746  headerBlock(header), mergeBlock(merge), continueBlock(cont),
1747  logger(logger) {}
1748 #else
1749  ControlFlowStructurizer(Location loc, uint32_t control,
1750  spirv::BlockMergeInfoMap &mergeInfo, Block *header,
1751  Block *merge, Block *cont)
1752  : location(loc), control(control), blockMergeInfo(mergeInfo),
1753  headerBlock(header), mergeBlock(merge), continueBlock(cont) {}
1754 #endif
1755 
1756  /// Structurizes the loop at the given `headerBlock`.
1757  ///
1758  /// This method will create an spirv.mlir.loop op in the `mergeBlock` and move
1759  /// all blocks in the structured loop into the spirv.mlir.loop's region. All
1760  /// branches to the `headerBlock` will be redirected to the `mergeBlock`. This
1761  /// method will also update `mergeInfo` by remapping all blocks inside to the
1762  /// newly cloned ones inside structured control flow op's regions.
1763  LogicalResult structurize();
1764 
1765 private:
1766  /// Creates a new spirv.mlir.selection op at the beginning of the
1767  /// `mergeBlock`.
1768  spirv::SelectionOp createSelectionOp(uint32_t selectionControl);
1769 
1770  /// Creates a new spirv.mlir.loop op at the beginning of the `mergeBlock`.
1771  spirv::LoopOp createLoopOp(uint32_t loopControl);
1772 
1773  /// Collects all blocks reachable from `headerBlock` except `mergeBlock`.
1774  void collectBlocksInConstruct();
1775 
1776  Location location;
1777  uint32_t control;
1778 
1779  spirv::BlockMergeInfoMap &blockMergeInfo;
1780 
1781  Block *headerBlock;
1782  Block *mergeBlock;
1783  Block *continueBlock; // nullptr for spirv.mlir.selection
1784 
1785  SetVector<Block *> constructBlocks;
1786 
1787 #ifndef NDEBUG
1788  /// A logger used to emit information during the deserialzation process.
1789  llvm::ScopedPrinter &logger;
1790 #endif
1791 };
1792 } // namespace
1793 
1794 spirv::SelectionOp
1795 ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) {
1796  // Create a builder and set the insertion point to the beginning of the
1797  // merge block so that the newly created SelectionOp will be inserted there.
1798  OpBuilder builder(&mergeBlock->front());
1799 
1800  auto control = static_cast<spirv::SelectionControl>(selectionControl);
1801  auto selectionOp = builder.create<spirv::SelectionOp>(location, control);
1802  selectionOp.addMergeBlock();
1803 
1804  return selectionOp;
1805 }
1806 
1807 spirv::LoopOp ControlFlowStructurizer::createLoopOp(uint32_t loopControl) {
1808  // Create a builder and set the insertion point to the beginning of the
1809  // merge block so that the newly created LoopOp will be inserted there.
1810  OpBuilder builder(&mergeBlock->front());
1811 
1812  auto control = static_cast<spirv::LoopControl>(loopControl);
1813  auto loopOp = builder.create<spirv::LoopOp>(location, control);
1814  loopOp.addEntryAndMergeBlock();
1815 
1816  return loopOp;
1817 }
1818 
1819 void ControlFlowStructurizer::collectBlocksInConstruct() {
1820  assert(constructBlocks.empty() && "expected empty constructBlocks");
1821 
1822  // Put the header block in the work list first.
1823  constructBlocks.insert(headerBlock);
1824 
1825  // For each item in the work list, add its successors excluding the merge
1826  // block.
1827  for (unsigned i = 0; i < constructBlocks.size(); ++i) {
1828  for (auto *successor : constructBlocks[i]->getSuccessors())
1829  if (successor != mergeBlock)
1830  constructBlocks.insert(successor);
1831  }
1832 }
1833 
1834 LogicalResult ControlFlowStructurizer::structurize() {
1835  Operation *op = nullptr;
1836  bool isLoop = continueBlock != nullptr;
1837  if (isLoop) {
1838  if (auto loopOp = createLoopOp(control))
1839  op = loopOp.getOperation();
1840  } else {
1841  if (auto selectionOp = createSelectionOp(control))
1842  op = selectionOp.getOperation();
1843  }
1844  if (!op)
1845  return failure();
1846  Region &body = op->getRegion(0);
1847 
1848  IRMapping mapper;
1849  // All references to the old merge block should be directed to the
1850  // selection/loop merge block in the SelectionOp/LoopOp's region.
1851  mapper.map(mergeBlock, &body.back());
1852 
1853  collectBlocksInConstruct();
1854 
1855  // We've identified all blocks belonging to the selection/loop's region. Now
1856  // need to "move" them into the selection/loop. Instead of really moving the
1857  // blocks, in the following we copy them and remap all values and branches.
1858  // This is because:
1859  // * Inserting a block into a region requires the block not in any region
1860  // before. But selections/loops can nest so we can create selection/loop ops
1861  // in a nested manner, which means some blocks may already be in a
1862  // selection/loop region when to be moved again.
1863  // * It's much trickier to fix up the branches into and out of the loop's
1864  // region: we need to treat not-moved blocks and moved blocks differently:
1865  // Not-moved blocks jumping to the loop header block need to jump to the
1866  // merge point containing the new loop op but not the loop continue block's
1867  // back edge. Moved blocks jumping out of the loop need to jump to the
1868  // merge block inside the loop region but not other not-moved blocks.
1869  // We cannot use replaceAllUsesWith clearly and it's harder to follow the
1870  // logic.
1871 
1872  // Create a corresponding block in the SelectionOp/LoopOp's region for each
1873  // block in this loop construct.
1874  OpBuilder builder(body);
1875  for (auto *block : constructBlocks) {
1876  // Create a block and insert it before the selection/loop merge block in the
1877  // SelectionOp/LoopOp's region.
1878  auto *newBlock = builder.createBlock(&body.back());
1879  mapper.map(block, newBlock);
1880  LLVM_DEBUG(logger.startLine() << "[cf] cloned block " << newBlock
1881  << " from block " << block << "\n");
1882  if (!isFnEntryBlock(block)) {
1883  for (BlockArgument blockArg : block->getArguments()) {
1884  auto newArg =
1885  newBlock->addArgument(blockArg.getType(), blockArg.getLoc());
1886  mapper.map(blockArg, newArg);
1887  LLVM_DEBUG(logger.startLine() << "[cf] remapped block argument "
1888  << blockArg << " to " << newArg << "\n");
1889  }
1890  } else {
1891  LLVM_DEBUG(logger.startLine()
1892  << "[cf] block " << block << " is a function entry block\n");
1893  }
1894 
1895  for (auto &op : *block)
1896  newBlock->push_back(op.clone(mapper));
1897  }
1898 
1899  // Go through all ops and remap the operands.
1900  auto remapOperands = [&](Operation *op) {
1901  for (auto &operand : op->getOpOperands())
1902  if (Value mappedOp = mapper.lookupOrNull(operand.get()))
1903  operand.set(mappedOp);
1904  for (auto &succOp : op->getBlockOperands())
1905  if (Block *mappedOp = mapper.lookupOrNull(succOp.get()))
1906  succOp.set(mappedOp);
1907  };
1908  for (auto &block : body)
1909  block.walk(remapOperands);
1910 
1911  // We have created the SelectionOp/LoopOp and "moved" all blocks belonging to
1912  // the selection/loop construct into its region. Next we need to fix the
1913  // connections between this new SelectionOp/LoopOp with existing blocks.
1914 
1915  // All existing incoming branches should go to the merge block, where the
1916  // SelectionOp/LoopOp resides right now.
1917  headerBlock->replaceAllUsesWith(mergeBlock);
1918 
1919  LLVM_DEBUG({
1920  logger.startLine() << "[cf] after cloning and fixing references:\n";
1921  headerBlock->getParentOp()->print(logger.getOStream());
1922  logger.startLine() << "\n";
1923  });
1924 
1925  if (isLoop) {
1926  if (!mergeBlock->args_empty()) {
1927  return mergeBlock->getParentOp()->emitError(
1928  "OpPhi in loop merge block unsupported");
1929  }
1930 
1931  // The loop header block may have block arguments. Since now we place the
1932  // loop op inside the old merge block, we need to make sure the old merge
1933  // block has the same block argument list.
1934  for (BlockArgument blockArg : headerBlock->getArguments())
1935  mergeBlock->addArgument(blockArg.getType(), blockArg.getLoc());
1936 
1937  // If the loop header block has block arguments, make sure the spirv.Branch
1938  // op matches.
1939  SmallVector<Value, 4> blockArgs;
1940  if (!headerBlock->args_empty())
1941  blockArgs = {mergeBlock->args_begin(), mergeBlock->args_end()};
1942 
1943  // The loop entry block should have a unconditional branch jumping to the
1944  // loop header block.
1945  builder.setInsertionPointToEnd(&body.front());
1946  builder.create<spirv::BranchOp>(location, mapper.lookupOrNull(headerBlock),
1947  ArrayRef<Value>(blockArgs));
1948  }
1949 
1950  // All the blocks cloned into the SelectionOp/LoopOp's region can now be
1951  // cleaned up.
1952  LLVM_DEBUG(logger.startLine() << "[cf] cleaning up blocks after clone\n");
1953  // First we need to drop all operands' references inside all blocks. This is
1954  // needed because we can have blocks referencing SSA values from one another.
1955  for (auto *block : constructBlocks)
1956  block->dropAllReferences();
1957 
1958  // Check that whether some op in the to-be-erased blocks still has uses. Those
1959  // uses come from blocks that won't be sinked into the SelectionOp/LoopOp's
1960  // region. We cannot handle such cases given that once a value is sinked into
1961  // the SelectionOp/LoopOp's region, there is no escape for it:
1962  // SelectionOp/LooOp does not support yield values right now.
1963  for (auto *block : constructBlocks) {
1964  for (Operation &op : *block)
1965  if (!op.use_empty())
1966  return op.emitOpError(
1967  "failed control flow structurization: it has uses outside of the "
1968  "enclosing selection/loop construct");
1969  }
1970 
1971  // Then erase all old blocks.
1972  for (auto *block : constructBlocks) {
1973  // We've cloned all blocks belonging to this construct into the structured
1974  // control flow op's region. Among these blocks, some may compose another
1975  // selection/loop. If so, they will be recorded within blockMergeInfo.
1976  // We need to update the pointers there to the newly remapped ones so we can
1977  // continue structurizing them later.
1978  // TODO: The asserts in the following assumes input SPIR-V blob forms
1979  // correctly nested selection/loop constructs. We should relax this and
1980  // support error cases better.
1981  auto it = blockMergeInfo.find(block);
1982  if (it != blockMergeInfo.end()) {
1983  // Use the original location for nested selection/loop ops.
1984  Location loc = it->second.loc;
1985 
1986  Block *newHeader = mapper.lookupOrNull(block);
1987  if (!newHeader)
1988  return emitError(loc, "failed control flow structurization: nested "
1989  "loop header block should be remapped!");
1990 
1991  Block *newContinue = it->second.continueBlock;
1992  if (newContinue) {
1993  newContinue = mapper.lookupOrNull(newContinue);
1994  if (!newContinue)
1995  return emitError(loc, "failed control flow structurization: nested "
1996  "loop continue block should be remapped!");
1997  }
1998 
1999  Block *newMerge = it->second.mergeBlock;
2000  if (Block *mappedTo = mapper.lookupOrNull(newMerge))
2001  newMerge = mappedTo;
2002 
2003  // The iterator should be erased before adding a new entry into
2004  // blockMergeInfo to avoid iterator invalidation.
2005  blockMergeInfo.erase(it);
2006  blockMergeInfo.try_emplace(newHeader, loc, it->second.control, newMerge,
2007  newContinue);
2008  }
2009 
2010  // The structured selection/loop's entry block does not have arguments.
2011  // If the function's header block is also part of the structured control
2012  // flow, we cannot just simply erase it because it may contain arguments
2013  // matching the function signature and used by the cloned blocks.
2014  if (isFnEntryBlock(block)) {
2015  LLVM_DEBUG(logger.startLine() << "[cf] changing entry block " << block
2016  << " to only contain a spirv.Branch op\n");
2017  // Still keep the function entry block for the potential block arguments,
2018  // but replace all ops inside with a branch to the merge block.
2019  block->clear();
2020  builder.setInsertionPointToEnd(block);
2021  builder.create<spirv::BranchOp>(location, mergeBlock);
2022  } else {
2023  LLVM_DEBUG(logger.startLine() << "[cf] erasing block " << block << "\n");
2024  block->erase();
2025  }
2026  }
2027 
2028  LLVM_DEBUG(logger.startLine()
2029  << "[cf] after structurizing construct with header block "
2030  << headerBlock << ":\n"
2031  << *op << "\n");
2032 
2033  return success();
2034 }
2035 
2036 LogicalResult spirv::Deserializer::wireUpBlockArgument() {
2037  LLVM_DEBUG({
2038  logger.startLine()
2039  << "//----- [phi] start wiring up block arguments -----//\n";
2040  logger.indent();
2041  });
2042 
2043  OpBuilder::InsertionGuard guard(opBuilder);
2044 
2045  for (const auto &info : blockPhiInfo) {
2046  Block *block = info.first.first;
2047  Block *target = info.first.second;
2048  const BlockPhiInfo &phiInfo = info.second;
2049  LLVM_DEBUG({
2050  logger.startLine() << "[phi] block " << block << "\n";
2051  logger.startLine() << "[phi] before creating block argument:\n";
2052  block->getParentOp()->print(logger.getOStream());
2053  logger.startLine() << "\n";
2054  });
2055 
2056  // Set insertion point to before this block's terminator early because we
2057  // may materialize ops via getValue() call.
2058  auto *op = block->getTerminator();
2059  opBuilder.setInsertionPoint(op);
2060 
2061  SmallVector<Value, 4> blockArgs;
2062  blockArgs.reserve(phiInfo.size());
2063  for (uint32_t valueId : phiInfo) {
2064  if (Value value = getValue(valueId)) {
2065  blockArgs.push_back(value);
2066  LLVM_DEBUG(logger.startLine() << "[phi] block argument " << value
2067  << " id = " << valueId << "\n");
2068  } else {
2069  return emitError(unknownLoc, "OpPhi references undefined value!");
2070  }
2071  }
2072 
2073  if (auto branchOp = dyn_cast<spirv::BranchOp>(op)) {
2074  // Replace the previous branch op with a new one with block arguments.
2075  opBuilder.create<spirv::BranchOp>(branchOp.getLoc(), branchOp.getTarget(),
2076  blockArgs);
2077  branchOp.erase();
2078  } else if (auto branchCondOp = dyn_cast<spirv::BranchConditionalOp>(op)) {
2079  assert((branchCondOp.getTrueBlock() == target ||
2080  branchCondOp.getFalseBlock() == target) &&
2081  "expected target to be either the true or false target");
2082  if (target == branchCondOp.getTrueTarget())
2083  opBuilder.create<spirv::BranchConditionalOp>(
2084  branchCondOp.getLoc(), branchCondOp.getCondition(), blockArgs,
2085  branchCondOp.getFalseBlockArguments(),
2086  branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueTarget(),
2087  branchCondOp.getFalseTarget());
2088  else
2089  opBuilder.create<spirv::BranchConditionalOp>(
2090  branchCondOp.getLoc(), branchCondOp.getCondition(),
2091  branchCondOp.getTrueBlockArguments(), blockArgs,
2092  branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueBlock(),
2093  branchCondOp.getFalseBlock());
2094 
2095  branchCondOp.erase();
2096  } else {
2097  return emitError(unknownLoc, "unimplemented terminator for Phi creation");
2098  }
2099 
2100  LLVM_DEBUG({
2101  logger.startLine() << "[phi] after creating block argument:\n";
2102  block->getParentOp()->print(logger.getOStream());
2103  logger.startLine() << "\n";
2104  });
2105  }
2106  blockPhiInfo.clear();
2107 
2108  LLVM_DEBUG({
2109  logger.unindent();
2110  logger.startLine()
2111  << "//--- [phi] completed wiring up block arguments ---//\n";
2112  });
2113  return success();
2114 }
2115 
2116 LogicalResult spirv::Deserializer::structurizeControlFlow() {
2117  LLVM_DEBUG({
2118  logger.startLine()
2119  << "//----- [cf] start structurizing control flow -----//\n";
2120  logger.indent();
2121  });
2122 
2123  while (!blockMergeInfo.empty()) {
2124  Block *headerBlock = blockMergeInfo.begin()->first;
2125  BlockMergeInfo mergeInfo = blockMergeInfo.begin()->second;
2126 
2127  LLVM_DEBUG({
2128  logger.startLine() << "[cf] header block " << headerBlock << ":\n";
2129  headerBlock->print(logger.getOStream());
2130  logger.startLine() << "\n";
2131  });
2132 
2133  auto *mergeBlock = mergeInfo.mergeBlock;
2134  assert(mergeBlock && "merge block cannot be nullptr");
2135  if (!mergeBlock->args_empty())
2136  return emitError(unknownLoc, "OpPhi in loop merge block unimplemented");
2137  LLVM_DEBUG({
2138  logger.startLine() << "[cf] merge block " << mergeBlock << ":\n";
2139  mergeBlock->print(logger.getOStream());
2140  logger.startLine() << "\n";
2141  });
2142 
2143  auto *continueBlock = mergeInfo.continueBlock;
2144  LLVM_DEBUG(if (continueBlock) {
2145  logger.startLine() << "[cf] continue block " << continueBlock << ":\n";
2146  continueBlock->print(logger.getOStream());
2147  logger.startLine() << "\n";
2148  });
2149  // Erase this case before calling into structurizer, who will update
2150  // blockMergeInfo.
2151  blockMergeInfo.erase(blockMergeInfo.begin());
2152  ControlFlowStructurizer structurizer(mergeInfo.loc, mergeInfo.control,
2153  blockMergeInfo, headerBlock,
2154  mergeBlock, continueBlock
2155 #ifndef NDEBUG
2156  ,
2157  logger
2158 #endif
2159  );
2160  if (failed(structurizer.structurize()))
2161  return failure();
2162  }
2163 
2164  LLVM_DEBUG({
2165  logger.unindent();
2166  logger.startLine()
2167  << "//--- [cf] completed structurizing control flow ---//\n";
2168  });
2169  return success();
2170 }
2171 
2172 //===----------------------------------------------------------------------===//
2173 // Debug
2174 //===----------------------------------------------------------------------===//
2175 
2176 Location spirv::Deserializer::createFileLineColLoc(OpBuilder opBuilder) {
2177  if (!debugLine)
2178  return unknownLoc;
2179 
2180  auto fileName = debugInfoMap.lookup(debugLine->fileID).str();
2181  if (fileName.empty())
2182  fileName = "<unknown>";
2183  return FileLineColLoc::get(opBuilder.getStringAttr(fileName), debugLine->line,
2184  debugLine->column);
2185 }
2186 
2188 spirv::Deserializer::processDebugLine(ArrayRef<uint32_t> operands) {
2189  // According to SPIR-V spec:
2190  // "This location information applies to the instructions physically
2191  // following this instruction, up to the first occurrence of any of the
2192  // following: the next end of block, the next OpLine instruction, or the next
2193  // OpNoLine instruction."
2194  if (operands.size() != 3)
2195  return emitError(unknownLoc, "OpLine must have 3 operands");
2196  debugLine = DebugLine{operands[0], operands[1], operands[2]};
2197  return success();
2198 }
2199 
2200 void spirv::Deserializer::clearDebugLine() { debugLine = std::nullopt; }
2201 
2203 spirv::Deserializer::processDebugString(ArrayRef<uint32_t> operands) {
2204  if (operands.size() < 2)
2205  return emitError(unknownLoc, "OpString needs at least 2 operands");
2206 
2207  if (!debugInfoMap.lookup(operands[0]).empty())
2208  return emitError(unknownLoc,
2209  "duplicate debug string found for result <id> ")
2210  << operands[0];
2211 
2212  unsigned wordIndex = 1;
2213  StringRef debugString = decodeStringLiteral(operands, wordIndex);
2214  if (wordIndex != operands.size())
2215  return emitError(unknownLoc,
2216  "unexpected trailing words in OpString instruction");
2217 
2218  debugInfoMap[operands[0]] = debugString;
2219  return success();
2220 }
static bool isFnEntryBlock(Block *block)
Returns true if the given block is a function entry block.
#define MIN_VERSION_CASE(v)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
This class represents an argument of a Block.
Definition: Value.h:315
Location getLoc() const
Return the location for this argument.
Definition: Value.h:330
Block represents an ordered list of Operations.
Definition: Block.h:30
bool empty()
Definition: Block.h:141
Operation & back()
Definition: Block.h:145
void erase()
Unlink this Block from its parent region and delete it.
Definition: Block.cpp:60
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Definition: Block.cpp:302
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:238
void print(raw_ostream &os)
BlockArgListType getArguments()
Definition: Block.h:80
Operation & front()
Definition: Block.h:146
iterator begin()
Definition: Block.h:136
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Definition: Block.cpp:35
void push_back(Operation *op)
Definition: Block.h:142
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:269
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
A symbol reference with a reference path containing a single element.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
auto lookupOrNull(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:58
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:333
This class helps build Operations.
Definition: Builders.h:206
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool use_empty()
Returns true if this operation has no uses.
Definition: Operation.h:831
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Definition: Operation.cpp:686
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Definition: Operation.cpp:66
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:665
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:560
MutableArrayRef< BlockOperand > getBlockOperands()
Definition: Operation.h:674
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:378
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:640
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Block & back()
Definition: Region.h:64
BlockListType & getBlocks()
Definition: Region.h:45
BlockListType::iterator iterator
Definition: Region.h:52
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:117
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
static ArrayType get(Type elementType, unsigned elementCount)
Definition: SPIRVTypes.cpp:52
static CooperativeMatrixNVType get(Type elementType, Scope scope, unsigned rows, unsigned columns)
Definition: SPIRVTypes.cpp:297
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
Definition: SPIRVTypes.cpp:229
LogicalResult deserialize()
Deserializes the remembered SPIR-V binary module.
Deserializer(ArrayRef< uint32_t > binary, MLIRContext *context)
Creates a deserializer for the given SPIR-V binary module.
OwningOpRef< spirv::ModuleOp > collect()
Collects the final SPIR-V ModuleOp.
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
Definition: SPIRVTypes.h:170
static JointMatrixINTELType get(Type elementType, Scope scope, unsigned rows, unsigned columns, MatrixLayout matrixLayout)
Definition: SPIRVTypes.cpp:363
static MatrixType get(Type columnType, uint32_t columnCount)
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:546
static RuntimeArrayType get(Type elementType)
Definition: SPIRVTypes.cpp:603
static SampledImageType get(Type imageType)
Definition: SPIRVTypes.cpp:872
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Construct a literal StructType with at least one member.
static VerCapExtAttr get(Version version, ArrayRef< Capability > capabilities, ArrayRef< Extension > extensions, MLIRContext *context)
Gets a VerCapExtAttr instance.
Include the generated interface declarations.
Definition: CallGraph.h:229
constexpr uint32_t kMagicNumber
SPIR-V magic number.
StringRef decodeStringLiteral(ArrayRef< uint32_t > words, unsigned &wordIndex)
Decodes a string literal in words starting at wordIndex.
constexpr unsigned kHeaderWordCount
SPIR-V binary header word count.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
static std::string debugString(T &&op)
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
This represents an operation in an abstracted form, suitable for use with the builder APIs.