MLIR  22.0.0git
Deserializer.cpp
Go to the documentation of this file.
1 //===- Deserializer.cpp - MLIR SPIR-V Deserializer ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the SPIR-V binary to MLIR SPIR-V module deserializer.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Deserializer.h"
14 
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/IRMapping.h"
21 #include "mlir/IR/Location.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/Sequence.h"
25 #include "llvm/ADT/SmallVector.h"
26 #include "llvm/ADT/StringExtras.h"
27 #include "llvm/ADT/bit.h"
28 #include "llvm/Support/Debug.h"
29 #include "llvm/Support/SaveAndRestore.h"
30 #include "llvm/Support/raw_ostream.h"
31 #include <optional>
32 
33 using namespace mlir;
34 
35 #define DEBUG_TYPE "spirv-deserialization"
36 
37 //===----------------------------------------------------------------------===//
38 // Utility Functions
39 //===----------------------------------------------------------------------===//
40 
41 /// Returns true if the given `block` is a function entry block.
42 static inline bool isFnEntryBlock(Block *block) {
43  return block->isEntryBlock() &&
44  isa_and_nonnull<spirv::FuncOp>(block->getParentOp());
45 }
46 
47 //===----------------------------------------------------------------------===//
48 // Deserializer Method Definitions
49 //===----------------------------------------------------------------------===//
50 
52  MLIRContext *context,
54  : binary(binary), context(context), unknownLoc(UnknownLoc::get(context)),
55  module(createModuleOp()), opBuilder(module->getRegion()), options(options)
56 #ifndef NDEBUG
57  ,
58  logger(llvm::dbgs())
59 #endif
60 {
61 }
62 
64  LLVM_DEBUG({
65  logger.resetIndent();
66  logger.startLine()
67  << "//+++---------- start deserialization ----------+++//\n";
68  });
69 
70  if (failed(processHeader()))
71  return failure();
72 
73  spirv::Opcode opcode = spirv::Opcode::OpNop;
74  ArrayRef<uint32_t> operands;
75  auto binarySize = binary.size();
76  while (curOffset < binarySize) {
77  // Slice the next instruction out and populate `opcode` and `operands`.
78  // Internally this also updates `curOffset`.
79  if (failed(sliceInstruction(opcode, operands)))
80  return failure();
81 
82  if (failed(processInstruction(opcode, operands)))
83  return failure();
84  }
85 
86  assert(curOffset == binarySize &&
87  "deserializer should never index beyond the binary end");
88 
89  for (auto &deferred : deferredInstructions) {
90  if (failed(processInstruction(deferred.first, deferred.second, false))) {
91  return failure();
92  }
93  }
94 
95  attachVCETriple();
96 
97  LLVM_DEBUG(logger.startLine()
98  << "//+++-------- completed deserialization --------+++//\n");
99  return success();
100 }
101 
103  return std::move(module);
104 }
105 
106 //===----------------------------------------------------------------------===//
107 // Module structure
108 //===----------------------------------------------------------------------===//
109 
110 OwningOpRef<spirv::ModuleOp> spirv::Deserializer::createModuleOp() {
111  OpBuilder builder(context);
112  OperationState state(unknownLoc, spirv::ModuleOp::getOperationName());
113  spirv::ModuleOp::build(builder, state);
114  return cast<spirv::ModuleOp>(Operation::create(state));
115 }
116 
117 LogicalResult spirv::Deserializer::processHeader() {
118  if (binary.size() < spirv::kHeaderWordCount)
119  return emitError(unknownLoc,
120  "SPIR-V binary module must have a 5-word header");
121 
122  if (binary[0] != spirv::kMagicNumber)
123  return emitError(unknownLoc, "incorrect magic number");
124 
125  // Version number bytes: 0 | major number | minor number | 0
126  uint32_t majorVersion = (binary[1] << 8) >> 24;
127  uint32_t minorVersion = (binary[1] << 16) >> 24;
128  if (majorVersion == 1) {
129  switch (minorVersion) {
130 #define MIN_VERSION_CASE(v) \
131  case v: \
132  version = spirv::Version::V_1_##v; \
133  break
134 
135  MIN_VERSION_CASE(0);
136  MIN_VERSION_CASE(1);
137  MIN_VERSION_CASE(2);
138  MIN_VERSION_CASE(3);
139  MIN_VERSION_CASE(4);
140  MIN_VERSION_CASE(5);
141  MIN_VERSION_CASE(6);
142 #undef MIN_VERSION_CASE
143  default:
144  return emitError(unknownLoc, "unsupported SPIR-V minor version: ")
145  << minorVersion;
146  }
147  } else {
148  return emitError(unknownLoc, "unsupported SPIR-V major version: ")
149  << majorVersion;
150  }
151 
152  // TODO: generator number, bound, schema
153  curOffset = spirv::kHeaderWordCount;
154  return success();
155 }
156 
157 LogicalResult
158 spirv::Deserializer::processCapability(ArrayRef<uint32_t> operands) {
159  if (operands.size() != 1)
160  return emitError(unknownLoc, "OpCapability must have one parameter");
161 
162  auto cap = spirv::symbolizeCapability(operands[0]);
163  if (!cap)
164  return emitError(unknownLoc, "unknown capability: ") << operands[0];
165 
166  capabilities.insert(*cap);
167  return success();
168 }
169 
170 LogicalResult spirv::Deserializer::processExtension(ArrayRef<uint32_t> words) {
171  if (words.empty()) {
172  return emitError(
173  unknownLoc,
174  "OpExtension must have a literal string for the extension name");
175  }
176 
177  unsigned wordIndex = 0;
178  StringRef extName = decodeStringLiteral(words, wordIndex);
179  if (wordIndex != words.size())
180  return emitError(unknownLoc,
181  "unexpected trailing words in OpExtension instruction");
182  auto ext = spirv::symbolizeExtension(extName);
183  if (!ext)
184  return emitError(unknownLoc, "unknown extension: ") << extName;
185 
186  extensions.insert(*ext);
187  return success();
188 }
189 
190 LogicalResult
191 spirv::Deserializer::processExtInstImport(ArrayRef<uint32_t> words) {
192  if (words.size() < 2) {
193  return emitError(unknownLoc,
194  "OpExtInstImport must have a result <id> and a literal "
195  "string for the extended instruction set name");
196  }
197 
198  unsigned wordIndex = 1;
199  extendedInstSets[words[0]] = decodeStringLiteral(words, wordIndex);
200  if (wordIndex != words.size()) {
201  return emitError(unknownLoc,
202  "unexpected trailing words in OpExtInstImport");
203  }
204  return success();
205 }
206 
207 void spirv::Deserializer::attachVCETriple() {
208  (*module)->setAttr(
209  spirv::ModuleOp::getVCETripleAttrName(),
210  spirv::VerCapExtAttr::get(version, capabilities.getArrayRef(),
211  extensions.getArrayRef(), context));
212 }
213 
214 LogicalResult
215 spirv::Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
216  if (operands.size() != 2)
217  return emitError(unknownLoc, "OpMemoryModel must have two operands");
218 
219  (*module)->setAttr(
220  module->getAddressingModelAttrName(),
221  opBuilder.getAttr<spirv::AddressingModelAttr>(
222  static_cast<spirv::AddressingModel>(operands.front())));
223 
224  (*module)->setAttr(module->getMemoryModelAttrName(),
225  opBuilder.getAttr<spirv::MemoryModelAttr>(
226  static_cast<spirv::MemoryModel>(operands.back())));
227 
228  return success();
229 }
230 
231 template <typename AttrTy, typename EnumAttrTy, typename EnumTy>
233  Location loc, OpBuilder &opBuilder,
235  StringAttr symbol, StringRef decorationName, StringRef cacheControlKind) {
236  if (words.size() != 4) {
237  return emitError(loc, "OpDecoration with ")
238  << decorationName << "needs a cache control integer literal and a "
239  << cacheControlKind << " cache control literal";
240  }
241  unsigned cacheLevel = words[2];
242  auto cacheControlAttr = static_cast<EnumTy>(words[3]);
243  auto value = opBuilder.getAttr<AttrTy>(cacheLevel, cacheControlAttr);
245  if (auto attrList =
246  llvm::dyn_cast_or_null<ArrayAttr>(decorations[words[0]].get(symbol)))
247  llvm::append_range(attrs, attrList);
248  attrs.push_back(value);
249  decorations[words[0]].set(symbol, opBuilder.getArrayAttr(attrs));
250  return success();
251 }
252 
253 LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
254  // TODO: This function should also be auto-generated. For now, since only a
255  // few decorations are processed/handled in a meaningful manner, going with a
256  // manual implementation.
257  if (words.size() < 2) {
258  return emitError(
259  unknownLoc, "OpDecorate must have at least result <id> and Decoration");
260  }
261  auto decorationName =
262  stringifyDecoration(static_cast<spirv::Decoration>(words[1]));
263  if (decorationName.empty()) {
264  return emitError(unknownLoc, "invalid Decoration code : ") << words[1];
265  }
266  auto symbol = getSymbolDecoration(decorationName);
267  switch (static_cast<spirv::Decoration>(words[1])) {
268  case spirv::Decoration::FPFastMathMode:
269  if (words.size() != 3) {
270  return emitError(unknownLoc, "OpDecorate with ")
271  << decorationName << " needs a single integer literal";
272  }
273  decorations[words[0]].set(
274  symbol, FPFastMathModeAttr::get(opBuilder.getContext(),
275  static_cast<FPFastMathMode>(words[2])));
276  break;
277  case spirv::Decoration::FPRoundingMode:
278  if (words.size() != 3) {
279  return emitError(unknownLoc, "OpDecorate with ")
280  << decorationName << " needs a single integer literal";
281  }
282  decorations[words[0]].set(
283  symbol, FPRoundingModeAttr::get(opBuilder.getContext(),
284  static_cast<FPRoundingMode>(words[2])));
285  break;
286  case spirv::Decoration::DescriptorSet:
287  case spirv::Decoration::Binding:
288  if (words.size() != 3) {
289  return emitError(unknownLoc, "OpDecorate with ")
290  << decorationName << " needs a single integer literal";
291  }
292  decorations[words[0]].set(
293  symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));
294  break;
295  case spirv::Decoration::BuiltIn:
296  if (words.size() != 3) {
297  return emitError(unknownLoc, "OpDecorate with ")
298  << decorationName << " needs a single integer literal";
299  }
300  decorations[words[0]].set(
301  symbol, opBuilder.getStringAttr(
302  stringifyBuiltIn(static_cast<spirv::BuiltIn>(words[2]))));
303  break;
304  case spirv::Decoration::ArrayStride:
305  if (words.size() != 3) {
306  return emitError(unknownLoc, "OpDecorate with ")
307  << decorationName << " needs a single integer literal";
308  }
309  typeDecorations[words[0]] = words[2];
310  break;
311  case spirv::Decoration::LinkageAttributes: {
312  if (words.size() < 4) {
313  return emitError(unknownLoc, "OpDecorate with ")
314  << decorationName
315  << " needs at least 1 string and 1 integer literal";
316  }
317  // LinkageAttributes has two parameters ["linkageName", linkageType]
318  // e.g., OpDecorate %imported_func LinkageAttributes "outside.func" Import
319  // "linkageName" is a stringliteral encoded as uint32_t,
320  // hence the size of name is variable length which results in words.size()
321  // being variable length, words.size() = 3 + strlen(name)/4 + 1 or
322  // 3 + ceildiv(strlen(name), 4).
323  unsigned wordIndex = 2;
324  auto linkageName = spirv::decodeStringLiteral(words, wordIndex).str();
325  auto linkageTypeAttr = opBuilder.getAttr<::mlir::spirv::LinkageTypeAttr>(
326  static_cast<::mlir::spirv::LinkageType>(words[wordIndex++]));
327  auto linkageAttr = opBuilder.getAttr<::mlir::spirv::LinkageAttributesAttr>(
328  StringAttr::get(context, linkageName), linkageTypeAttr);
329  decorations[words[0]].set(symbol, llvm::dyn_cast<Attribute>(linkageAttr));
330  break;
331  }
332  case spirv::Decoration::Aliased:
333  case spirv::Decoration::AliasedPointer:
334  case spirv::Decoration::Block:
335  case spirv::Decoration::BufferBlock:
336  case spirv::Decoration::Flat:
337  case spirv::Decoration::NonReadable:
338  case spirv::Decoration::NonWritable:
339  case spirv::Decoration::NoPerspective:
340  case spirv::Decoration::NoSignedWrap:
341  case spirv::Decoration::NoUnsignedWrap:
342  case spirv::Decoration::RelaxedPrecision:
343  case spirv::Decoration::Restrict:
344  case spirv::Decoration::RestrictPointer:
345  case spirv::Decoration::NoContraction:
346  case spirv::Decoration::Constant:
347  case spirv::Decoration::Invariant:
348  case spirv::Decoration::Patch:
349  if (words.size() != 2) {
350  return emitError(unknownLoc, "OpDecoration with ")
351  << decorationName << "needs a single target <id>";
352  }
353  decorations[words[0]].set(symbol, opBuilder.getUnitAttr());
354  break;
355  case spirv::Decoration::Location:
356  case spirv::Decoration::SpecId:
357  if (words.size() != 3) {
358  return emitError(unknownLoc, "OpDecoration with ")
359  << decorationName << "needs a single integer literal";
360  }
361  decorations[words[0]].set(
362  symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));
363  break;
364  case spirv::Decoration::CacheControlLoadINTEL: {
365  LogicalResult res = deserializeCacheControlDecoration<
366  CacheControlLoadINTELAttr, LoadCacheControlAttr, LoadCacheControl>(
367  unknownLoc, opBuilder, decorations, words, symbol, decorationName,
368  "load");
369  if (failed(res))
370  return res;
371  break;
372  }
373  case spirv::Decoration::CacheControlStoreINTEL: {
374  LogicalResult res = deserializeCacheControlDecoration<
375  CacheControlStoreINTELAttr, StoreCacheControlAttr, StoreCacheControl>(
376  unknownLoc, opBuilder, decorations, words, symbol, decorationName,
377  "store");
378  if (failed(res))
379  return res;
380  break;
381  }
382  default:
383  return emitError(unknownLoc, "unhandled Decoration : '") << decorationName;
384  }
385  return success();
386 }
387 
388 LogicalResult
389 spirv::Deserializer::processMemberDecoration(ArrayRef<uint32_t> words) {
390  // The binary layout of OpMemberDecorate is different comparing to OpDecorate
391  if (words.size() < 3) {
392  return emitError(unknownLoc,
393  "OpMemberDecorate must have at least 3 operands");
394  }
395 
396  auto decoration = static_cast<spirv::Decoration>(words[2]);
397  if (decoration == spirv::Decoration::Offset && words.size() != 4) {
398  return emitError(unknownLoc,
399  " missing offset specification in OpMemberDecorate with "
400  "Offset decoration");
401  }
402  ArrayRef<uint32_t> decorationOperands;
403  if (words.size() > 3) {
404  decorationOperands = words.slice(3);
405  }
406  memberDecorationMap[words[0]][words[1]][decoration] = decorationOperands;
407  return success();
408 }
409 
410 LogicalResult spirv::Deserializer::processMemberName(ArrayRef<uint32_t> words) {
411  if (words.size() < 3) {
412  return emitError(unknownLoc, "OpMemberName must have at least 3 operands");
413  }
414  unsigned wordIndex = 2;
415  auto name = decodeStringLiteral(words, wordIndex);
416  if (wordIndex != words.size()) {
417  return emitError(unknownLoc,
418  "unexpected trailing words in OpMemberName instruction");
419  }
420  memberNameMap[words[0]][words[1]] = name;
421  return success();
422 }
423 
424 LogicalResult spirv::Deserializer::setFunctionArgAttrs(
425  uint32_t argID, SmallVectorImpl<Attribute> &argAttrs, size_t argIndex) {
426  if (!decorations.contains(argID)) {
427  argAttrs[argIndex] = DictionaryAttr::get(context, {});
428  return success();
429  }
430 
431  spirv::DecorationAttr foundDecorationAttr;
432  for (NamedAttribute decAttr : decorations[argID]) {
433  for (auto decoration :
434  {spirv::Decoration::Aliased, spirv::Decoration::Restrict,
435  spirv::Decoration::AliasedPointer,
436  spirv::Decoration::RestrictPointer}) {
437 
438  if (decAttr.getName() !=
439  getSymbolDecoration(stringifyDecoration(decoration)))
440  continue;
441 
442  if (foundDecorationAttr)
443  return emitError(unknownLoc,
444  "more than one Aliased/Restrict decorations for "
445  "function argument with result <id> ")
446  << argID;
447 
448  foundDecorationAttr = spirv::DecorationAttr::get(context, decoration);
449  break;
450  }
451 
452  if (decAttr.getName() == getSymbolDecoration(stringifyDecoration(
453  spirv::Decoration::RelaxedPrecision))) {
454  // TODO: Current implementation supports only one decoration per function
455  // parameter so RelaxedPrecision cannot be applied at the same time as,
456  // for example, Aliased/Restrict/etc. This should be relaxed to allow any
457  // combination of decoration allowed by the spec to be supported.
458  if (foundDecorationAttr)
459  return emitError(unknownLoc, "already found a decoration for function "
460  "argument with result <id> ")
461  << argID;
462 
463  foundDecorationAttr = spirv::DecorationAttr::get(
464  context, spirv::Decoration::RelaxedPrecision);
465  }
466  }
467 
468  if (!foundDecorationAttr)
469  return emitError(unknownLoc, "unimplemented decoration support for "
470  "function argument with result <id> ")
471  << argID;
472 
473  NamedAttribute attr(StringAttr::get(context, spirv::DecorationAttr::name),
474  foundDecorationAttr);
475  argAttrs[argIndex] = DictionaryAttr::get(context, attr);
476  return success();
477 }
478 
479 LogicalResult
480 spirv::Deserializer::processFunction(ArrayRef<uint32_t> operands) {
481  if (curFunction) {
482  return emitError(unknownLoc, "found function inside function");
483  }
484 
485  // Get the result type
486  if (operands.size() != 4) {
487  return emitError(unknownLoc, "OpFunction must have 4 parameters");
488  }
489  Type resultType = getType(operands[0]);
490  if (!resultType) {
491  return emitError(unknownLoc, "undefined result type from <id> ")
492  << operands[0];
493  }
494 
495  uint32_t fnID = operands[1];
496  if (funcMap.count(fnID)) {
497  return emitError(unknownLoc, "duplicate function definition/declaration");
498  }
499 
500  auto fnControl = spirv::symbolizeFunctionControl(operands[2]);
501  if (!fnControl) {
502  return emitError(unknownLoc, "unknown Function Control: ") << operands[2];
503  }
504 
505  Type fnType = getType(operands[3]);
506  if (!fnType || !isa<FunctionType>(fnType)) {
507  return emitError(unknownLoc, "unknown function type from <id> ")
508  << operands[3];
509  }
510  auto functionType = cast<FunctionType>(fnType);
511 
512  if ((isVoidType(resultType) && functionType.getNumResults() != 0) ||
513  (functionType.getNumResults() == 1 &&
514  functionType.getResult(0) != resultType)) {
515  return emitError(unknownLoc, "mismatch in function type ")
516  << functionType << " and return type " << resultType << " specified";
517  }
518 
519  std::string fnName = getFunctionSymbol(fnID);
520  auto funcOp = spirv::FuncOp::create(opBuilder, unknownLoc, fnName,
521  functionType, fnControl.value());
522  // Processing other function attributes.
523  if (decorations.count(fnID)) {
524  for (auto attr : decorations[fnID].getAttrs()) {
525  funcOp->setAttr(attr.getName(), attr.getValue());
526  }
527  }
528  curFunction = funcMap[fnID] = funcOp;
529  auto *entryBlock = funcOp.addEntryBlock();
530  LLVM_DEBUG({
531  logger.startLine()
532  << "//===-------------------------------------------===//\n";
533  logger.startLine() << "[fn] name: " << fnName << "\n";
534  logger.startLine() << "[fn] type: " << fnType << "\n";
535  logger.startLine() << "[fn] ID: " << fnID << "\n";
536  logger.startLine() << "[fn] entry block: " << entryBlock << "\n";
537  logger.indent();
538  });
539 
540  SmallVector<Attribute> argAttrs;
541  argAttrs.resize(functionType.getNumInputs());
542 
543  // Parse the op argument instructions
544  if (functionType.getNumInputs()) {
545  for (size_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
546  auto argType = functionType.getInput(i);
547  spirv::Opcode opcode = spirv::Opcode::OpNop;
548  ArrayRef<uint32_t> operands;
549  if (failed(sliceInstruction(opcode, operands,
550  spirv::Opcode::OpFunctionParameter))) {
551  return failure();
552  }
553  if (opcode != spirv::Opcode::OpFunctionParameter) {
554  return emitError(
555  unknownLoc,
556  "missing OpFunctionParameter instruction for argument ")
557  << i;
558  }
559  if (operands.size() != 2) {
560  return emitError(
561  unknownLoc,
562  "expected result type and result <id> for OpFunctionParameter");
563  }
564  auto argDefinedType = getType(operands[0]);
565  if (!argDefinedType || argDefinedType != argType) {
566  return emitError(unknownLoc,
567  "mismatch in argument type between function type "
568  "definition ")
569  << functionType << " and argument type definition "
570  << argDefinedType << " at argument " << i;
571  }
572  if (getValue(operands[1])) {
573  return emitError(unknownLoc, "duplicate definition of result <id> ")
574  << operands[1];
575  }
576  if (failed(setFunctionArgAttrs(operands[1], argAttrs, i))) {
577  return failure();
578  }
579 
580  auto argValue = funcOp.getArgument(i);
581  valueMap[operands[1]] = argValue;
582  }
583  }
584 
585  if (llvm::any_of(argAttrs, [](Attribute attr) {
586  auto argAttr = cast<DictionaryAttr>(attr);
587  return !argAttr.empty();
588  }))
589  funcOp.setArgAttrsAttr(ArrayAttr::get(context, argAttrs));
590 
591  // entryBlock is needed to access the arguments, Once that is done, we can
592  // erase the block for functions with 'Import' LinkageAttributes, since these
593  // are essentially function declarations, so they have no body.
594  auto linkageAttr = funcOp.getLinkageAttributes();
595  auto hasImportLinkage =
596  linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
597  spirv::LinkageType::Import);
598  if (hasImportLinkage)
599  funcOp.eraseBody();
600 
601  // RAII guard to reset the insertion point to the module's region after
602  // deserializing the body of this function.
603  OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
604 
605  spirv::Opcode opcode = spirv::Opcode::OpNop;
606  ArrayRef<uint32_t> instOperands;
607 
608  // Special handling for the entry block. We need to make sure it starts with
609  // an OpLabel instruction. The entry block takes the same parameters as the
610  // function. All other blocks do not take any parameter. We have already
611  // created the entry block, here we need to register it to the correct label
612  // <id>.
613  if (failed(sliceInstruction(opcode, instOperands,
614  spirv::Opcode::OpFunctionEnd))) {
615  return failure();
616  }
617  if (opcode == spirv::Opcode::OpFunctionEnd) {
618  return processFunctionEnd(instOperands);
619  }
620  if (opcode != spirv::Opcode::OpLabel) {
621  return emitError(unknownLoc, "a basic block must start with OpLabel");
622  }
623  if (instOperands.size() != 1) {
624  return emitError(unknownLoc, "OpLabel should only have result <id>");
625  }
626  blockMap[instOperands[0]] = entryBlock;
627  if (failed(processLabel(instOperands))) {
628  return failure();
629  }
630 
631  // Then process all the other instructions in the function until we hit
632  // OpFunctionEnd.
633  while (succeeded(sliceInstruction(opcode, instOperands,
634  spirv::Opcode::OpFunctionEnd)) &&
635  opcode != spirv::Opcode::OpFunctionEnd) {
636  if (failed(processInstruction(opcode, instOperands))) {
637  return failure();
638  }
639  }
640  if (opcode != spirv::Opcode::OpFunctionEnd) {
641  return failure();
642  }
643 
644  return processFunctionEnd(instOperands);
645 }
646 
647 LogicalResult
648 spirv::Deserializer::processFunctionEnd(ArrayRef<uint32_t> operands) {
649  // Process OpFunctionEnd.
650  if (!operands.empty()) {
651  return emitError(unknownLoc, "unexpected operands for OpFunctionEnd");
652  }
653 
654  // Wire up block arguments from OpPhi instructions.
655  // Put all structured control flow in spirv.mlir.selection/spirv.mlir.loop
656  // ops.
657  if (failed(wireUpBlockArgument()) || failed(structurizeControlFlow())) {
658  return failure();
659  }
660 
661  curBlock = nullptr;
662  curFunction = std::nullopt;
663 
664  LLVM_DEBUG({
665  logger.unindent();
666  logger.startLine()
667  << "//===-------------------------------------------===//\n";
668  });
669  return success();
670 }
671 
672 std::optional<std::pair<Attribute, Type>>
673 spirv::Deserializer::getConstant(uint32_t id) {
674  auto constIt = constantMap.find(id);
675  if (constIt == constantMap.end())
676  return std::nullopt;
677  return constIt->getSecond();
678 }
679 
680 std::optional<std::pair<Attribute, Type>>
681 spirv::Deserializer::getConstantCompositeReplicate(uint32_t id) {
682  if (auto it = constantCompositeReplicateMap.find(id);
683  it != constantCompositeReplicateMap.end())
684  return it->second;
685  return std::nullopt;
686 }
687 
688 std::optional<spirv::SpecConstOperationMaterializationInfo>
689 spirv::Deserializer::getSpecConstantOperation(uint32_t id) {
690  auto constIt = specConstOperationMap.find(id);
691  if (constIt == specConstOperationMap.end())
692  return std::nullopt;
693  return constIt->getSecond();
694 }
695 
696 std::string spirv::Deserializer::getFunctionSymbol(uint32_t id) {
697  auto funcName = nameMap.lookup(id).str();
698  if (funcName.empty()) {
699  funcName = "spirv_fn_" + std::to_string(id);
700  }
701  return funcName;
702 }
703 
704 std::string spirv::Deserializer::getSpecConstantSymbol(uint32_t id) {
705  auto constName = nameMap.lookup(id).str();
706  if (constName.empty()) {
707  constName = "spirv_spec_const_" + std::to_string(id);
708  }
709  return constName;
710 }
711 
712 spirv::SpecConstantOp
713 spirv::Deserializer::createSpecConstant(Location loc, uint32_t resultID,
714  TypedAttr defaultValue) {
715  auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
716  auto op = spirv::SpecConstantOp::create(opBuilder, unknownLoc, symName,
717  defaultValue);
718  if (decorations.count(resultID)) {
719  for (auto attr : decorations[resultID].getAttrs())
720  op->setAttr(attr.getName(), attr.getValue());
721  }
722  specConstMap[resultID] = op;
723  return op;
724 }
725 
726 LogicalResult
727 spirv::Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
728  unsigned wordIndex = 0;
729  if (operands.size() < 3) {
730  return emitError(
731  unknownLoc,
732  "OpVariable needs at least 3 operands, type, <id> and storage class");
733  }
734 
735  // Result Type.
736  auto type = getType(operands[wordIndex]);
737  if (!type) {
738  return emitError(unknownLoc, "unknown result type <id> : ")
739  << operands[wordIndex];
740  }
741  auto ptrType = dyn_cast<spirv::PointerType>(type);
742  if (!ptrType) {
743  return emitError(unknownLoc,
744  "expected a result type <id> to be a spirv.ptr, found : ")
745  << type;
746  }
747  wordIndex++;
748 
749  // Result <id>.
750  auto variableID = operands[wordIndex];
751  auto variableName = nameMap.lookup(variableID).str();
752  if (variableName.empty()) {
753  variableName = "spirv_var_" + std::to_string(variableID);
754  }
755  wordIndex++;
756 
757  // Storage class.
758  auto storageClass = static_cast<spirv::StorageClass>(operands[wordIndex]);
759  if (ptrType.getStorageClass() != storageClass) {
760  return emitError(unknownLoc, "mismatch in storage class of pointer type ")
761  << type << " and that specified in OpVariable instruction : "
762  << stringifyStorageClass(storageClass);
763  }
764  wordIndex++;
765 
766  // Initializer.
767  FlatSymbolRefAttr initializer = nullptr;
768 
769  if (wordIndex < operands.size()) {
770  Operation *op = nullptr;
771 
772  if (auto initOp = getGlobalVariable(operands[wordIndex]))
773  op = initOp;
774  else if (auto initOp = getSpecConstant(operands[wordIndex]))
775  op = initOp;
776  else if (auto initOp = getSpecConstantComposite(operands[wordIndex]))
777  op = initOp;
778  else
779  return emitError(unknownLoc, "unknown <id> ")
780  << operands[wordIndex] << "used as initializer";
781 
782  initializer = SymbolRefAttr::get(op);
783  wordIndex++;
784  }
785  if (wordIndex != operands.size()) {
786  return emitError(unknownLoc,
787  "found more operands than expected when deserializing "
788  "OpVariable instruction, only ")
789  << wordIndex << " of " << operands.size() << " processed";
790  }
791  auto loc = createFileLineColLoc(opBuilder);
792  auto varOp = spirv::GlobalVariableOp::create(
793  opBuilder, loc, TypeAttr::get(type),
794  opBuilder.getStringAttr(variableName), initializer);
795 
796  // Decorations.
797  if (decorations.count(variableID)) {
798  for (auto attr : decorations[variableID].getAttrs())
799  varOp->setAttr(attr.getName(), attr.getValue());
800  }
801  globalVariableMap[variableID] = varOp;
802  return success();
803 }
804 
805 IntegerAttr spirv::Deserializer::getConstantInt(uint32_t id) {
806  auto constInfo = getConstant(id);
807  if (!constInfo) {
808  return nullptr;
809  }
810  return dyn_cast<IntegerAttr>(constInfo->first);
811 }
812 
813 LogicalResult spirv::Deserializer::processName(ArrayRef<uint32_t> operands) {
814  if (operands.size() < 2) {
815  return emitError(unknownLoc, "OpName needs at least 2 operands");
816  }
817  if (!nameMap.lookup(operands[0]).empty()) {
818  return emitError(unknownLoc, "duplicate name found for result <id> ")
819  << operands[0];
820  }
821  unsigned wordIndex = 1;
822  StringRef name = decodeStringLiteral(operands, wordIndex);
823  if (wordIndex != operands.size()) {
824  return emitError(unknownLoc,
825  "unexpected trailing words in OpName instruction");
826  }
827  nameMap[operands[0]] = name;
828  return success();
829 }
830 
831 //===----------------------------------------------------------------------===//
832 // Type
833 //===----------------------------------------------------------------------===//
834 
835 LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
836  ArrayRef<uint32_t> operands) {
837  if (operands.empty()) {
838  return emitError(unknownLoc, "type instruction with opcode ")
839  << spirv::stringifyOpcode(opcode) << " needs at least one <id>";
840  }
841 
842  /// TODO: Types might be forward declared in some instructions and need to be
843  /// handled appropriately.
844  if (typeMap.count(operands[0])) {
845  return emitError(unknownLoc, "duplicate definition for result <id> ")
846  << operands[0];
847  }
848 
849  switch (opcode) {
850  case spirv::Opcode::OpTypeVoid:
851  if (operands.size() != 1)
852  return emitError(unknownLoc, "OpTypeVoid must have no parameters");
853  typeMap[operands[0]] = opBuilder.getNoneType();
854  break;
855  case spirv::Opcode::OpTypeBool:
856  if (operands.size() != 1)
857  return emitError(unknownLoc, "OpTypeBool must have no parameters");
858  typeMap[operands[0]] = opBuilder.getI1Type();
859  break;
860  case spirv::Opcode::OpTypeInt: {
861  if (operands.size() != 3)
862  return emitError(
863  unknownLoc, "OpTypeInt must have bitwidth and signedness parameters");
864 
865  // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics
866  // to preserve or validate.
867  // 0 indicates unsigned, or no signedness semantics
868  // 1 indicates signed semantics."
869  //
870  // So we cannot differentiate signless and unsigned integers; always use
871  // signless semantics for such cases.
872  auto sign = operands[2] == 1 ? IntegerType::SignednessSemantics::Signed
873  : IntegerType::SignednessSemantics::Signless;
874  typeMap[operands[0]] = IntegerType::get(context, operands[1], sign);
875  } break;
876  case spirv::Opcode::OpTypeFloat: {
877  if (operands.size() != 2 && operands.size() != 3)
878  return emitError(unknownLoc,
879  "OpTypeFloat expects either 2 operands (type, bitwidth) "
880  "or 3 operands (type, bitwidth, encoding), but got ")
881  << operands.size();
882  uint32_t bitWidth = operands[1];
883 
884  Type floatTy;
885  switch (bitWidth) {
886  case 16:
887  floatTy = opBuilder.getF16Type();
888  break;
889  case 32:
890  floatTy = opBuilder.getF32Type();
891  break;
892  case 64:
893  floatTy = opBuilder.getF64Type();
894  break;
895  default:
896  return emitError(unknownLoc, "unsupported OpTypeFloat bitwidth: ")
897  << bitWidth;
898  }
899 
900  if (operands.size() == 3) {
901  if (spirv::FPEncoding(operands[2]) != spirv::FPEncoding::BFloat16KHR)
902  return emitError(unknownLoc, "unsupported OpTypeFloat FP encoding: ")
903  << operands[2];
904  if (bitWidth != 16)
905  return emitError(unknownLoc,
906  "invalid OpTypeFloat bitwidth for bfloat16 encoding: ")
907  << bitWidth << " (expected 16)";
908  floatTy = opBuilder.getBF16Type();
909  }
910 
911  typeMap[operands[0]] = floatTy;
912  } break;
913  case spirv::Opcode::OpTypeVector: {
914  if (operands.size() != 3) {
915  return emitError(
916  unknownLoc,
917  "OpTypeVector must have element type and count parameters");
918  }
919  Type elementTy = getType(operands[1]);
920  if (!elementTy) {
921  return emitError(unknownLoc, "OpTypeVector references undefined <id> ")
922  << operands[1];
923  }
924  typeMap[operands[0]] = VectorType::get({operands[2]}, elementTy);
925  } break;
926  case spirv::Opcode::OpTypePointer: {
927  return processOpTypePointer(operands);
928  } break;
929  case spirv::Opcode::OpTypeArray:
930  return processArrayType(operands);
931  case spirv::Opcode::OpTypeCooperativeMatrixKHR:
932  return processCooperativeMatrixTypeKHR(operands);
933  case spirv::Opcode::OpTypeFunction:
934  return processFunctionType(operands);
935  case spirv::Opcode::OpTypeImage:
936  return processImageType(operands);
937  case spirv::Opcode::OpTypeSampledImage:
938  return processSampledImageType(operands);
939  case spirv::Opcode::OpTypeRuntimeArray:
940  return processRuntimeArrayType(operands);
941  case spirv::Opcode::OpTypeStruct:
942  return processStructType(operands);
943  case spirv::Opcode::OpTypeMatrix:
944  return processMatrixType(operands);
945  case spirv::Opcode::OpTypeTensorARM:
946  return processTensorARMType(operands);
947  default:
948  return emitError(unknownLoc, "unhandled type instruction");
949  }
950  return success();
951 }
952 
953 LogicalResult
954 spirv::Deserializer::processOpTypePointer(ArrayRef<uint32_t> operands) {
955  if (operands.size() != 3)
956  return emitError(unknownLoc, "OpTypePointer must have two parameters");
957 
958  auto pointeeType = getType(operands[2]);
959  if (!pointeeType)
960  return emitError(unknownLoc, "unknown OpTypePointer pointee type <id> ")
961  << operands[2];
962 
963  uint32_t typePointerID = operands[0];
964  auto storageClass = static_cast<spirv::StorageClass>(operands[1]);
965  typeMap[typePointerID] = spirv::PointerType::get(pointeeType, storageClass);
966 
967  for (auto *deferredStructIt = std::begin(deferredStructTypesInfos);
968  deferredStructIt != std::end(deferredStructTypesInfos);) {
969  for (auto *unresolvedMemberIt =
970  std::begin(deferredStructIt->unresolvedMemberTypes);
971  unresolvedMemberIt !=
972  std::end(deferredStructIt->unresolvedMemberTypes);) {
973  if (unresolvedMemberIt->first == typePointerID) {
974  // The newly constructed pointer type can resolve one of the
975  // deferred struct type members; update the memberTypes list and
976  // clean the unresolvedMemberTypes list accordingly.
977  deferredStructIt->memberTypes[unresolvedMemberIt->second] =
978  typeMap[typePointerID];
979  unresolvedMemberIt =
980  deferredStructIt->unresolvedMemberTypes.erase(unresolvedMemberIt);
981  } else {
982  ++unresolvedMemberIt;
983  }
984  }
985 
986  if (deferredStructIt->unresolvedMemberTypes.empty()) {
987  // All deferred struct type members are now resolved, set the struct body.
988  auto structType = deferredStructIt->deferredStructType;
989 
990  assert(structType && "expected a spirv::StructType");
991  assert(structType.isIdentified() && "expected an indentified struct");
992 
993  if (failed(structType.trySetBody(
994  deferredStructIt->memberTypes, deferredStructIt->offsetInfo,
995  deferredStructIt->memberDecorationsInfo,
996  deferredStructIt->structDecorationsInfo)))
997  return failure();
998 
999  deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt);
1000  } else {
1001  ++deferredStructIt;
1002  }
1003  }
1004 
1005  return success();
1006 }
1007 
1008 LogicalResult
1009 spirv::Deserializer::processArrayType(ArrayRef<uint32_t> operands) {
1010  if (operands.size() != 3) {
1011  return emitError(unknownLoc,
1012  "OpTypeArray must have element type and count parameters");
1013  }
1014 
1015  Type elementTy = getType(operands[1]);
1016  if (!elementTy) {
1017  return emitError(unknownLoc, "OpTypeArray references undefined <id> ")
1018  << operands[1];
1019  }
1020 
1021  unsigned count = 0;
1022  // TODO: The count can also come frome a specialization constant.
1023  auto countInfo = getConstant(operands[2]);
1024  if (!countInfo) {
1025  return emitError(unknownLoc, "OpTypeArray count <id> ")
1026  << operands[2] << "can only come from normal constant right now";
1027  }
1028 
1029  if (auto intVal = dyn_cast<IntegerAttr>(countInfo->first)) {
1030  count = intVal.getValue().getZExtValue();
1031  } else {
1032  return emitError(unknownLoc, "OpTypeArray count must come from a "
1033  "scalar integer constant instruction");
1034  }
1035 
1036  typeMap[operands[0]] = spirv::ArrayType::get(
1037  elementTy, count, typeDecorations.lookup(operands[0]));
1038  return success();
1039 }
1040 
1041 LogicalResult
1042 spirv::Deserializer::processFunctionType(ArrayRef<uint32_t> operands) {
1043  assert(!operands.empty() && "No operands for processing function type");
1044  if (operands.size() == 1) {
1045  return emitError(unknownLoc, "missing return type for OpTypeFunction");
1046  }
1047  auto returnType = getType(operands[1]);
1048  if (!returnType) {
1049  return emitError(unknownLoc, "unknown return type in OpTypeFunction");
1050  }
1051  SmallVector<Type, 1> argTypes;
1052  for (size_t i = 2, e = operands.size(); i < e; ++i) {
1053  auto ty = getType(operands[i]);
1054  if (!ty) {
1055  return emitError(unknownLoc, "unknown argument type in OpTypeFunction");
1056  }
1057  argTypes.push_back(ty);
1058  }
1059  ArrayRef<Type> returnTypes;
1060  if (!isVoidType(returnType)) {
1061  returnTypes = llvm::ArrayRef(returnType);
1062  }
1063  typeMap[operands[0]] = FunctionType::get(context, argTypes, returnTypes);
1064  return success();
1065 }
1066 
1067 LogicalResult spirv::Deserializer::processCooperativeMatrixTypeKHR(
1068  ArrayRef<uint32_t> operands) {
1069  if (operands.size() != 6) {
1070  return emitError(unknownLoc,
1071  "OpTypeCooperativeMatrixKHR must have element type, "
1072  "scope, row and column parameters, and use");
1073  }
1074 
1075  Type elementTy = getType(operands[1]);
1076  if (!elementTy) {
1077  return emitError(unknownLoc,
1078  "OpTypeCooperativeMatrixKHR references undefined <id> ")
1079  << operands[1];
1080  }
1081 
1082  std::optional<spirv::Scope> scope =
1083  spirv::symbolizeScope(getConstantInt(operands[2]).getInt());
1084  if (!scope) {
1085  return emitError(
1086  unknownLoc,
1087  "OpTypeCooperativeMatrixKHR references undefined scope <id> ")
1088  << operands[2];
1089  }
1090 
1091  IntegerAttr rowsAttr = getConstantInt(operands[3]);
1092  IntegerAttr columnsAttr = getConstantInt(operands[4]);
1093  IntegerAttr useAttr = getConstantInt(operands[5]);
1094 
1095  if (!rowsAttr)
1096  return emitError(unknownLoc, "OpTypeCooperativeMatrixKHR `Rows` references "
1097  "undefined constant <id> ")
1098  << operands[3];
1099 
1100  if (!columnsAttr)
1101  return emitError(unknownLoc, "OpTypeCooperativeMatrixKHR `Columns` "
1102  "references undefined constant <id> ")
1103  << operands[4];
1104 
1105  if (!useAttr)
1106  return emitError(unknownLoc, "OpTypeCooperativeMatrixKHR `Use` references "
1107  "undefined constant <id> ")
1108  << operands[5];
1109 
1110  unsigned rows = rowsAttr.getInt();
1111  unsigned columns = columnsAttr.getInt();
1112 
1113  std::optional<spirv::CooperativeMatrixUseKHR> use =
1114  spirv::symbolizeCooperativeMatrixUseKHR(useAttr.getInt());
1115  if (!use) {
1116  return emitError(
1117  unknownLoc,
1118  "OpTypeCooperativeMatrixKHR references undefined use <id> ")
1119  << operands[5];
1120  }
1121 
1122  typeMap[operands[0]] =
1123  spirv::CooperativeMatrixType::get(elementTy, rows, columns, *scope, *use);
1124  return success();
1125 }
1126 
1127 LogicalResult
1128 spirv::Deserializer::processRuntimeArrayType(ArrayRef<uint32_t> operands) {
1129  if (operands.size() != 2) {
1130  return emitError(unknownLoc, "OpTypeRuntimeArray must have two operands");
1131  }
1132  Type memberType = getType(operands[1]);
1133  if (!memberType) {
1134  return emitError(unknownLoc,
1135  "OpTypeRuntimeArray references undefined <id> ")
1136  << operands[1];
1137  }
1138  typeMap[operands[0]] = spirv::RuntimeArrayType::get(
1139  memberType, typeDecorations.lookup(operands[0]));
1140  return success();
1141 }
1142 
1143 LogicalResult
1144 spirv::Deserializer::processStructType(ArrayRef<uint32_t> operands) {
1145  // TODO: Find a way to handle identified structs when debug info is stripped.
1146 
1147  if (operands.empty()) {
1148  return emitError(unknownLoc, "OpTypeStruct must have at least result <id>");
1149  }
1150 
1151  if (operands.size() == 1) {
1152  // Handle empty struct.
1153  typeMap[operands[0]] =
1154  spirv::StructType::getEmpty(context, nameMap.lookup(operands[0]).str());
1155  return success();
1156  }
1157 
1158  // First element is operand ID, second element is member index in the struct.
1159  SmallVector<std::pair<uint32_t, unsigned>, 0> unresolvedMemberTypes;
1160  SmallVector<Type, 4> memberTypes;
1161 
1162  for (auto op : llvm::drop_begin(operands, 1)) {
1163  Type memberType = getType(op);
1164  bool typeForwardPtr = (typeForwardPointerIDs.count(op) != 0);
1165 
1166  if (!memberType && !typeForwardPtr)
1167  return emitError(unknownLoc, "OpTypeStruct references undefined <id> ")
1168  << op;
1169 
1170  if (!memberType)
1171  unresolvedMemberTypes.emplace_back(op, memberTypes.size());
1172 
1173  memberTypes.push_back(memberType);
1174  }
1175 
1178  if (memberDecorationMap.count(operands[0])) {
1179  auto &allMemberDecorations = memberDecorationMap[operands[0]];
1180  for (auto memberIndex : llvm::seq<uint32_t>(0, memberTypes.size())) {
1181  if (allMemberDecorations.count(memberIndex)) {
1182  for (auto &memberDecoration : allMemberDecorations[memberIndex]) {
1183  // Check for offset.
1184  if (memberDecoration.first == spirv::Decoration::Offset) {
1185  // If offset info is empty, resize to the number of members;
1186  if (offsetInfo.empty()) {
1187  offsetInfo.resize(memberTypes.size());
1188  }
1189  offsetInfo[memberIndex] = memberDecoration.second[0];
1190  } else {
1191  auto intType = mlir::IntegerType::get(context, 32);
1192  if (!memberDecoration.second.empty()) {
1193  memberDecorationsInfo.emplace_back(
1194  memberIndex, memberDecoration.first,
1195  IntegerAttr::get(intType, memberDecoration.second[0]));
1196  } else {
1197  memberDecorationsInfo.emplace_back(
1198  memberIndex, memberDecoration.first, UnitAttr::get(context));
1199  }
1200  }
1201  }
1202  }
1203  }
1204  }
1205 
1207  if (decorations.count(operands[0])) {
1208  NamedAttrList &allDecorations = decorations[operands[0]];
1209  for (NamedAttribute &decorationAttr : allDecorations) {
1210  std::optional<spirv::Decoration> decoration = spirv::symbolizeDecoration(
1211  llvm::convertToCamelFromSnakeCase(decorationAttr.getName(), true));
1212  assert(decoration.has_value());
1213  structDecorationsInfo.emplace_back(decoration.value(),
1214  decorationAttr.getValue());
1215  }
1216  }
1217 
1218  uint32_t structID = operands[0];
1219  std::string structIdentifier = nameMap.lookup(structID).str();
1220 
1221  if (structIdentifier.empty()) {
1222  assert(unresolvedMemberTypes.empty() &&
1223  "didn't expect unresolved member types");
1224  typeMap[structID] = spirv::StructType::get(
1225  memberTypes, offsetInfo, memberDecorationsInfo, structDecorationsInfo);
1226  } else {
1227  auto structTy = spirv::StructType::getIdentified(context, structIdentifier);
1228  typeMap[structID] = structTy;
1229 
1230  if (!unresolvedMemberTypes.empty())
1231  deferredStructTypesInfos.push_back(
1232  {structTy, unresolvedMemberTypes, memberTypes, offsetInfo,
1233  memberDecorationsInfo, structDecorationsInfo});
1234  else if (failed(structTy.trySetBody(memberTypes, offsetInfo,
1235  memberDecorationsInfo,
1236  structDecorationsInfo)))
1237  return failure();
1238  }
1239 
1240  // TODO: Update StructType to have member name as attribute as
1241  // well.
1242  return success();
1243 }
1244 
1245 LogicalResult
1246 spirv::Deserializer::processMatrixType(ArrayRef<uint32_t> operands) {
1247  if (operands.size() != 3) {
1248  // Three operands are needed: result_id, column_type, and column_count
1249  return emitError(unknownLoc, "OpTypeMatrix must have 3 operands"
1250  " (result_id, column_type, and column_count)");
1251  }
1252  // Matrix columns must be of vector type
1253  Type elementTy = getType(operands[1]);
1254  if (!elementTy) {
1255  return emitError(unknownLoc,
1256  "OpTypeMatrix references undefined column type.")
1257  << operands[1];
1258  }
1259 
1260  uint32_t colsCount = operands[2];
1261  typeMap[operands[0]] = spirv::MatrixType::get(elementTy, colsCount);
1262  return success();
1263 }
1264 
1265 LogicalResult
1266 spirv::Deserializer::processTensorARMType(ArrayRef<uint32_t> operands) {
1267  unsigned size = operands.size();
1268  if (size < 2 || size > 4)
1269  return emitError(unknownLoc, "OpTypeTensorARM must have 2-4 operands "
1270  "(result_id, element_type, (rank), (shape)) ")
1271  << size;
1272 
1273  Type elementTy = getType(operands[1]);
1274  if (!elementTy)
1275  return emitError(unknownLoc,
1276  "OpTypeTensorARM references undefined element type ")
1277  << operands[1];
1278 
1279  if (size == 2) {
1280  typeMap[operands[0]] = TensorArmType::get({}, elementTy);
1281  return success();
1282  }
1283 
1284  IntegerAttr rankAttr = getConstantInt(operands[2]);
1285  if (!rankAttr)
1286  return emitError(unknownLoc, "OpTypeTensorARM rank must come from a "
1287  "scalar integer constant instruction");
1288  unsigned rank = rankAttr.getValue().getZExtValue();
1289  if (size == 3) {
1290  SmallVector<int64_t, 4> shape(rank, ShapedType::kDynamic);
1291  typeMap[operands[0]] = TensorArmType::get(shape, elementTy);
1292  return success();
1293  }
1294 
1295  std::optional<std::pair<Attribute, Type>> shapeInfo =
1296  getConstant(operands[3]);
1297  if (!shapeInfo)
1298  return emitError(unknownLoc, "OpTypeTensorARM shape must come from a "
1299  "constant instruction of type OpTypeArray");
1300 
1301  ArrayAttr shapeArrayAttr = llvm::dyn_cast<ArrayAttr>(shapeInfo->first);
1303  for (auto dimAttr : shapeArrayAttr.getValue()) {
1304  auto dimIntAttr = llvm::dyn_cast<IntegerAttr>(dimAttr);
1305  if (!dimIntAttr)
1306  return emitError(unknownLoc, "OpTypeTensorARM shape has an invalid "
1307  "dimension size");
1308  shape.push_back(dimIntAttr.getValue().getSExtValue());
1309  }
1310  typeMap[operands[0]] = TensorArmType::get(shape, elementTy);
1311  return success();
1312 }
1313 
1314 LogicalResult
1315 spirv::Deserializer::processTypeForwardPointer(ArrayRef<uint32_t> operands) {
1316  if (operands.size() != 2)
1317  return emitError(unknownLoc,
1318  "OpTypeForwardPointer instruction must have two operands");
1319 
1320  typeForwardPointerIDs.insert(operands[0]);
1321  // TODO: Use the 2nd operand (Storage Class) to validate the OpTypePointer
1322  // instruction that defines the actual type.
1323 
1324  return success();
1325 }
1326 
1327 LogicalResult
1328 spirv::Deserializer::processImageType(ArrayRef<uint32_t> operands) {
1329  // TODO: Add support for Access Qualifier.
1330  if (operands.size() != 8)
1331  return emitError(
1332  unknownLoc,
1333  "OpTypeImage with non-eight operands are not supported yet");
1334 
1335  Type elementTy = getType(operands[1]);
1336  if (!elementTy)
1337  return emitError(unknownLoc, "OpTypeImage references undefined <id>: ")
1338  << operands[1];
1339 
1340  auto dim = spirv::symbolizeDim(operands[2]);
1341  if (!dim)
1342  return emitError(unknownLoc, "unknown Dim for OpTypeImage: ")
1343  << operands[2];
1344 
1345  auto depthInfo = spirv::symbolizeImageDepthInfo(operands[3]);
1346  if (!depthInfo)
1347  return emitError(unknownLoc, "unknown Depth for OpTypeImage: ")
1348  << operands[3];
1349 
1350  auto arrayedInfo = spirv::symbolizeImageArrayedInfo(operands[4]);
1351  if (!arrayedInfo)
1352  return emitError(unknownLoc, "unknown Arrayed for OpTypeImage: ")
1353  << operands[4];
1354 
1355  auto samplingInfo = spirv::symbolizeImageSamplingInfo(operands[5]);
1356  if (!samplingInfo)
1357  return emitError(unknownLoc, "unknown MS for OpTypeImage: ") << operands[5];
1358 
1359  auto samplerUseInfo = spirv::symbolizeImageSamplerUseInfo(operands[6]);
1360  if (!samplerUseInfo)
1361  return emitError(unknownLoc, "unknown Sampled for OpTypeImage: ")
1362  << operands[6];
1363 
1364  auto format = spirv::symbolizeImageFormat(operands[7]);
1365  if (!format)
1366  return emitError(unknownLoc, "unknown Format for OpTypeImage: ")
1367  << operands[7];
1368 
1369  typeMap[operands[0]] = spirv::ImageType::get(
1370  elementTy, dim.value(), depthInfo.value(), arrayedInfo.value(),
1371  samplingInfo.value(), samplerUseInfo.value(), format.value());
1372  return success();
1373 }
1374 
1375 LogicalResult
1376 spirv::Deserializer::processSampledImageType(ArrayRef<uint32_t> operands) {
1377  if (operands.size() != 2)
1378  return emitError(unknownLoc, "OpTypeSampledImage must have two operands");
1379 
1380  Type elementTy = getType(operands[1]);
1381  if (!elementTy)
1382  return emitError(unknownLoc,
1383  "OpTypeSampledImage references undefined <id>: ")
1384  << operands[1];
1385 
1386  typeMap[operands[0]] = spirv::SampledImageType::get(elementTy);
1387  return success();
1388 }
1389 
1390 //===----------------------------------------------------------------------===//
1391 // Constant
1392 //===----------------------------------------------------------------------===//
1393 
1394 LogicalResult spirv::Deserializer::processConstant(ArrayRef<uint32_t> operands,
1395  bool isSpec) {
1396  StringRef opname = isSpec ? "OpSpecConstant" : "OpConstant";
1397 
1398  if (operands.size() < 2) {
1399  return emitError(unknownLoc)
1400  << opname << " must have type <id> and result <id>";
1401  }
1402  if (operands.size() < 3) {
1403  return emitError(unknownLoc)
1404  << opname << " must have at least 1 more parameter";
1405  }
1406 
1407  Type resultType = getType(operands[0]);
1408  if (!resultType) {
1409  return emitError(unknownLoc, "undefined result type from <id> ")
1410  << operands[0];
1411  }
1412 
1413  auto checkOperandSizeForBitwidth = [&](unsigned bitwidth) -> LogicalResult {
1414  if (bitwidth == 64) {
1415  if (operands.size() == 4) {
1416  return success();
1417  }
1418  return emitError(unknownLoc)
1419  << opname << " should have 2 parameters for 64-bit values";
1420  }
1421  if (bitwidth <= 32) {
1422  if (operands.size() == 3) {
1423  return success();
1424  }
1425 
1426  return emitError(unknownLoc)
1427  << opname
1428  << " should have 1 parameter for values with no more than 32 bits";
1429  }
1430  return emitError(unknownLoc, "unsupported OpConstant bitwidth: ")
1431  << bitwidth;
1432  };
1433 
1434  auto resultID = operands[1];
1435 
1436  if (auto intType = dyn_cast<IntegerType>(resultType)) {
1437  auto bitwidth = intType.getWidth();
1438  if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1439  return failure();
1440  }
1441 
1442  APInt value;
1443  if (bitwidth == 64) {
1444  // 64-bit integers are represented with two SPIR-V words. According to
1445  // SPIR-V spec: "When the type’s bit width is larger than one word, the
1446  // literal’s low-order words appear first."
1447  struct DoubleWord {
1448  uint32_t word1;
1449  uint32_t word2;
1450  } words = {operands[2], operands[3]};
1451  value = APInt(64, llvm::bit_cast<uint64_t>(words), /*isSigned=*/true);
1452  } else if (bitwidth <= 32) {
1453  value = APInt(bitwidth, operands[2], /*isSigned=*/true,
1454  /*implicitTrunc=*/true);
1455  }
1456 
1457  auto attr = opBuilder.getIntegerAttr(intType, value);
1458 
1459  if (isSpec) {
1460  createSpecConstant(unknownLoc, resultID, attr);
1461  } else {
1462  // For normal constants, we just record the attribute (and its type) for
1463  // later materialization at use sites.
1464  constantMap.try_emplace(resultID, attr, intType);
1465  }
1466 
1467  return success();
1468  }
1469 
1470  if (auto floatType = dyn_cast<FloatType>(resultType)) {
1471  auto bitwidth = floatType.getWidth();
1472  if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1473  return failure();
1474  }
1475 
1476  APFloat value(0.f);
1477  if (floatType.isF64()) {
1478  // Double values are represented with two SPIR-V words. According to
1479  // SPIR-V spec: "When the type’s bit width is larger than one word, the
1480  // literal’s low-order words appear first."
1481  struct DoubleWord {
1482  uint32_t word1;
1483  uint32_t word2;
1484  } words = {operands[2], operands[3]};
1485  value = APFloat(llvm::bit_cast<double>(words));
1486  } else if (floatType.isF32()) {
1487  value = APFloat(llvm::bit_cast<float>(operands[2]));
1488  } else if (floatType.isF16()) {
1489  APInt data(16, operands[2]);
1490  value = APFloat(APFloat::IEEEhalf(), data);
1491  } else if (floatType.isBF16()) {
1492  APInt data(16, operands[2]);
1493  value = APFloat(APFloat::BFloat(), data);
1494  }
1495 
1496  auto attr = opBuilder.getFloatAttr(floatType, value);
1497  if (isSpec) {
1498  createSpecConstant(unknownLoc, resultID, attr);
1499  } else {
1500  // For normal constants, we just record the attribute (and its type) for
1501  // later materialization at use sites.
1502  constantMap.try_emplace(resultID, attr, floatType);
1503  }
1504 
1505  return success();
1506  }
1507 
1508  return emitError(unknownLoc, "OpConstant can only generate values of "
1509  "scalar integer or floating-point type");
1510 }
1511 
1512 LogicalResult spirv::Deserializer::processConstantBool(
1513  bool isTrue, ArrayRef<uint32_t> operands, bool isSpec) {
1514  if (operands.size() != 2) {
1515  return emitError(unknownLoc, "Op")
1516  << (isSpec ? "Spec" : "") << "Constant"
1517  << (isTrue ? "True" : "False")
1518  << " must have type <id> and result <id>";
1519  }
1520 
1521  auto attr = opBuilder.getBoolAttr(isTrue);
1522  auto resultID = operands[1];
1523  if (isSpec) {
1524  createSpecConstant(unknownLoc, resultID, attr);
1525  } else {
1526  // For normal constants, we just record the attribute (and its type) for
1527  // later materialization at use sites.
1528  constantMap.try_emplace(resultID, attr, opBuilder.getI1Type());
1529  }
1530 
1531  return success();
1532 }
1533 
1534 LogicalResult
1535 spirv::Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
1536  if (operands.size() < 2) {
1537  return emitError(unknownLoc,
1538  "OpConstantComposite must have type <id> and result <id>");
1539  }
1540  if (operands.size() < 3) {
1541  return emitError(unknownLoc,
1542  "OpConstantComposite must have at least 1 parameter");
1543  }
1544 
1545  Type resultType = getType(operands[0]);
1546  if (!resultType) {
1547  return emitError(unknownLoc, "undefined result type from <id> ")
1548  << operands[0];
1549  }
1550 
1551  SmallVector<Attribute, 4> elements;
1552  elements.reserve(operands.size() - 2);
1553  for (unsigned i = 2, e = operands.size(); i < e; ++i) {
1554  auto elementInfo = getConstant(operands[i]);
1555  if (!elementInfo) {
1556  return emitError(unknownLoc, "OpConstantComposite component <id> ")
1557  << operands[i] << " must come from a normal constant";
1558  }
1559  elements.push_back(elementInfo->first);
1560  }
1561 
1562  auto resultID = operands[1];
1563  if (auto tensorType = dyn_cast<TensorArmType>(resultType)) {
1564  SmallVector<Attribute> flattenedElems;
1565  for (Attribute element : elements) {
1566  if (auto denseElemAttr = dyn_cast<DenseElementsAttr>(element)) {
1567  for (auto value : denseElemAttr.getValues<Attribute>())
1568  flattenedElems.push_back(value);
1569  } else {
1570  flattenedElems.push_back(element);
1571  }
1572  }
1573  auto attr = DenseElementsAttr::get(tensorType, flattenedElems);
1574  constantMap.try_emplace(resultID, attr, tensorType);
1575  } else if (auto shapedType = dyn_cast<ShapedType>(resultType)) {
1576  auto attr = DenseElementsAttr::get(shapedType, elements);
1577  // For normal constants, we just record the attribute (and its type) for
1578  // later materialization at use sites.
1579  constantMap.try_emplace(resultID, attr, shapedType);
1580  } else if (auto arrayType = dyn_cast<spirv::ArrayType>(resultType)) {
1581  auto attr = opBuilder.getArrayAttr(elements);
1582  constantMap.try_emplace(resultID, attr, resultType);
1583  } else {
1584  return emitError(unknownLoc, "unsupported OpConstantComposite type: ")
1585  << resultType;
1586  }
1587 
1588  return success();
1589 }
1590 
1591 LogicalResult spirv::Deserializer::processConstantCompositeReplicateEXT(
1592  ArrayRef<uint32_t> operands) {
1593  if (operands.size() != 3) {
1594  return emitError(
1595  unknownLoc,
1596  "OpConstantCompositeReplicateEXT expects 3 operands but found ")
1597  << operands.size();
1598  }
1599 
1600  Type resultType = getType(operands[0]);
1601  if (!resultType) {
1602  return emitError(unknownLoc, "undefined result type from <id> ")
1603  << operands[0];
1604  }
1605 
1606  auto compositeType = dyn_cast<CompositeType>(resultType);
1607  if (!compositeType) {
1608  return emitError(unknownLoc,
1609  "result type from <id> is not a composite type")
1610  << operands[0];
1611  }
1612 
1613  uint32_t resultID = operands[1];
1614  uint32_t constantID = operands[2];
1615 
1616  std::optional<std::pair<Attribute, Type>> constantInfo =
1617  getConstant(constantID);
1618  if (constantInfo.has_value()) {
1619  constantCompositeReplicateMap.try_emplace(
1620  resultID, constantInfo.value().first, resultType);
1621  return success();
1622  }
1623 
1624  std::optional<std::pair<Attribute, Type>> replicatedConstantCompositeInfo =
1625  getConstantCompositeReplicate(constantID);
1626  if (replicatedConstantCompositeInfo.has_value()) {
1627  constantCompositeReplicateMap.try_emplace(
1628  resultID, replicatedConstantCompositeInfo.value().first, resultType);
1629  return success();
1630  }
1631 
1632  return emitError(unknownLoc, "OpConstantCompositeReplicateEXT operand <id> ")
1633  << constantID
1634  << " must come from a normal constant or a "
1635  "OpConstantCompositeReplicateEXT";
1636 }
1637 
1638 LogicalResult
1639 spirv::Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) {
1640  if (operands.size() < 2) {
1641  return emitError(
1642  unknownLoc,
1643  "OpSpecConstantComposite must have type <id> and result <id>");
1644  }
1645  if (operands.size() < 3) {
1646  return emitError(unknownLoc,
1647  "OpSpecConstantComposite must have at least 1 parameter");
1648  }
1649 
1650  Type resultType = getType(operands[0]);
1651  if (!resultType) {
1652  return emitError(unknownLoc, "undefined result type from <id> ")
1653  << operands[0];
1654  }
1655 
1656  auto resultID = operands[1];
1657  auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
1658 
1659  SmallVector<Attribute, 4> elements;
1660  elements.reserve(operands.size() - 2);
1661  for (unsigned i = 2, e = operands.size(); i < e; ++i) {
1662  auto elementInfo = getSpecConstant(operands[i]);
1663  elements.push_back(SymbolRefAttr::get(elementInfo));
1664  }
1665 
1666  auto op = spirv::SpecConstantCompositeOp::create(
1667  opBuilder, unknownLoc, TypeAttr::get(resultType), symName,
1668  opBuilder.getArrayAttr(elements));
1669  specConstCompositeMap[resultID] = op;
1670 
1671  return success();
1672 }
1673 
1674 LogicalResult spirv::Deserializer::processSpecConstantCompositeReplicateEXT(
1675  ArrayRef<uint32_t> operands) {
1676  if (operands.size() != 3) {
1677  return emitError(unknownLoc, "OpSpecConstantCompositeReplicateEXT expects "
1678  "3 operands but found ")
1679  << operands.size();
1680  }
1681 
1682  Type resultType = getType(operands[0]);
1683  if (!resultType) {
1684  return emitError(unknownLoc, "undefined result type from <id> ")
1685  << operands[0];
1686  }
1687 
1688  auto compositeType = dyn_cast<CompositeType>(resultType);
1689  if (!compositeType) {
1690  return emitError(unknownLoc,
1691  "result type from <id> is not a composite type")
1692  << operands[0];
1693  }
1694 
1695  uint32_t resultID = operands[1];
1696 
1697  auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
1698  spirv::SpecConstantOp constituentSpecConstantOp =
1699  getSpecConstant(operands[2]);
1700  auto op = spirv::EXTSpecConstantCompositeReplicateOp::create(
1701  opBuilder, unknownLoc, TypeAttr::get(resultType), symName,
1702  SymbolRefAttr::get(constituentSpecConstantOp));
1703 
1704  specConstCompositeReplicateMap[resultID] = op;
1705 
1706  return success();
1707 }
1708 
1709 LogicalResult
1710 spirv::Deserializer::processSpecConstantOperation(ArrayRef<uint32_t> operands) {
1711  if (operands.size() < 3)
1712  return emitError(unknownLoc, "OpConstantOperation must have type <id>, "
1713  "result <id>, and operand opcode");
1714 
1715  uint32_t resultTypeID = operands[0];
1716 
1717  if (!getType(resultTypeID))
1718  return emitError(unknownLoc, "undefined result type from <id> ")
1719  << resultTypeID;
1720 
1721  uint32_t resultID = operands[1];
1722  spirv::Opcode enclosedOpcode = static_cast<spirv::Opcode>(operands[2]);
1723  auto emplaceResult = specConstOperationMap.try_emplace(
1724  resultID,
1725  SpecConstOperationMaterializationInfo{
1726  enclosedOpcode, resultTypeID,
1727  SmallVector<uint32_t>{operands.begin() + 3, operands.end()}});
1728 
1729  if (!emplaceResult.second)
1730  return emitError(unknownLoc, "value with <id>: ")
1731  << resultID << " is probably defined before.";
1732 
1733  return success();
1734 }
1735 
1736 Value spirv::Deserializer::materializeSpecConstantOperation(
1737  uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID,
1738  ArrayRef<uint32_t> enclosedOpOperands) {
1739 
1740  Type resultType = getType(resultTypeID);
1741 
1742  // Instructions wrapped by OpSpecConstantOp need an ID for their
1743  // Deserializer::processOp<op_name>(...) to emit the corresponding SPIR-V
1744  // dialect wrapped op. For that purpose, a new value map is created and "fake"
1745  // ID in that map is assigned to the result of the enclosed instruction. Note
1746  // that there is no need to update this fake ID since we only need to
1747  // reference the created Value for the enclosed op from the spv::YieldOp
1748  // created later in this method (both of which are the only values in their
1749  // region: the SpecConstantOperation's region). If we encounter another
1750  // SpecConstantOperation in the module, we simply re-use the fake ID since the
1751  // previous Value assigned to it isn't visible in the current scope anyway.
1752  DenseMap<uint32_t, Value> newValueMap;
1753  llvm::SaveAndRestore valueMapGuard(valueMap, newValueMap);
1754  constexpr uint32_t fakeID = static_cast<uint32_t>(-3);
1755 
1756  SmallVector<uint32_t, 4> enclosedOpResultTypeAndOperands;
1757  enclosedOpResultTypeAndOperands.push_back(resultTypeID);
1758  enclosedOpResultTypeAndOperands.push_back(fakeID);
1759  enclosedOpResultTypeAndOperands.append(enclosedOpOperands.begin(),
1760  enclosedOpOperands.end());
1761 
1762  // Process enclosed instruction before creating the enclosing
1763  // specConstantOperation (and its region). This way, references to constants,
1764  // global variables, and spec constants will be materialized outside the new
1765  // op's region. For more info, see Deserializer::getValue's implementation.
1766  if (failed(
1767  processInstruction(enclosedOpcode, enclosedOpResultTypeAndOperands)))
1768  return Value();
1769 
1770  // Since the enclosed op is emitted in the current block, split it in a
1771  // separate new block.
1772  Block *enclosedBlock = curBlock->splitBlock(&curBlock->back());
1773 
1774  auto loc = createFileLineColLoc(opBuilder);
1775  auto specConstOperationOp =
1776  spirv::SpecConstantOperationOp::create(opBuilder, loc, resultType);
1777 
1778  Region &body = specConstOperationOp.getBody();
1779  // Move the new block into SpecConstantOperation's body.
1780  body.getBlocks().splice(body.end(), curBlock->getParent()->getBlocks(),
1781  Region::iterator(enclosedBlock));
1782  Block &block = body.back();
1783 
1784  // RAII guard to reset the insertion point to the module's region after
1785  // deserializing the body of the specConstantOperation.
1786  OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
1787  opBuilder.setInsertionPointToEnd(&block);
1788 
1789  spirv::YieldOp::create(opBuilder, loc, block.front().getResult(0));
1790  return specConstOperationOp.getResult();
1791 }
1792 
1793 LogicalResult
1794 spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
1795  if (operands.size() != 2) {
1796  return emitError(unknownLoc,
1797  "OpConstantNull must only have type <id> and result <id>");
1798  }
1799 
1800  Type resultType = getType(operands[0]);
1801  if (!resultType) {
1802  return emitError(unknownLoc, "undefined result type from <id> ")
1803  << operands[0];
1804  }
1805 
1806  auto resultID = operands[1];
1807  Attribute attr;
1808  if (resultType.isIntOrFloat() || isa<VectorType>(resultType)) {
1809  attr = opBuilder.getZeroAttr(resultType);
1810  } else if (auto tensorType = dyn_cast<TensorArmType>(resultType)) {
1811  if (auto element = opBuilder.getZeroAttr(tensorType.getElementType()))
1812  attr = DenseElementsAttr::get(tensorType, element);
1813  }
1814 
1815  if (attr) {
1816  // For normal constants, we just record the attribute (and its type) for
1817  // later materialization at use sites.
1818  constantMap.try_emplace(resultID, attr, resultType);
1819  return success();
1820  }
1821 
1822  return emitError(unknownLoc, "unsupported OpConstantNull type: ")
1823  << resultType;
1824 }
1825 
1826 //===----------------------------------------------------------------------===//
1827 // Control flow
1828 //===----------------------------------------------------------------------===//
1829 
1830 Block *spirv::Deserializer::getOrCreateBlock(uint32_t id) {
1831  if (auto *block = getBlock(id)) {
1832  LLVM_DEBUG(logger.startLine() << "[block] got exiting block for id = " << id
1833  << " @ " << block << "\n");
1834  return block;
1835  }
1836 
1837  // We don't know where this block will be placed finally (in a
1838  // spirv.mlir.selection or spirv.mlir.loop or function). Create it into the
1839  // function for now and sort out the proper place later.
1840  auto *block = curFunction->addBlock();
1841  LLVM_DEBUG(logger.startLine() << "[block] created block for id = " << id
1842  << " @ " << block << "\n");
1843  return blockMap[id] = block;
1844 }
1845 
1846 LogicalResult spirv::Deserializer::processBranch(ArrayRef<uint32_t> operands) {
1847  if (!curBlock) {
1848  return emitError(unknownLoc, "OpBranch must appear inside a block");
1849  }
1850 
1851  if (operands.size() != 1) {
1852  return emitError(unknownLoc, "OpBranch must take exactly one target label");
1853  }
1854 
1855  auto *target = getOrCreateBlock(operands[0]);
1856  auto loc = createFileLineColLoc(opBuilder);
1857  // The preceding instruction for the OpBranch instruction could be an
1858  // OpLoopMerge or an OpSelectionMerge instruction, in this case they will have
1859  // the same OpLine information.
1860  spirv::BranchOp::create(opBuilder, loc, target);
1861 
1862  clearDebugLine();
1863  return success();
1864 }
1865 
1866 LogicalResult
1867 spirv::Deserializer::processBranchConditional(ArrayRef<uint32_t> operands) {
1868  if (!curBlock) {
1869  return emitError(unknownLoc,
1870  "OpBranchConditional must appear inside a block");
1871  }
1872 
1873  if (operands.size() != 3 && operands.size() != 5) {
1874  return emitError(unknownLoc,
1875  "OpBranchConditional must have condition, true label, "
1876  "false label, and optionally two branch weights");
1877  }
1878 
1879  auto condition = getValue(operands[0]);
1880  auto *trueBlock = getOrCreateBlock(operands[1]);
1881  auto *falseBlock = getOrCreateBlock(operands[2]);
1882 
1883  std::optional<std::pair<uint32_t, uint32_t>> weights;
1884  if (operands.size() == 5) {
1885  weights = std::make_pair(operands[3], operands[4]);
1886  }
1887  // The preceding instruction for the OpBranchConditional instruction could be
1888  // an OpSelectionMerge instruction, in this case they will have the same
1889  // OpLine information.
1890  auto loc = createFileLineColLoc(opBuilder);
1891  spirv::BranchConditionalOp::create(
1892  opBuilder, loc, condition, trueBlock,
1893  /*trueArguments=*/ArrayRef<Value>(), falseBlock,
1894  /*falseArguments=*/ArrayRef<Value>(), weights);
1895 
1896  clearDebugLine();
1897  return success();
1898 }
1899 
1900 LogicalResult spirv::Deserializer::processLabel(ArrayRef<uint32_t> operands) {
1901  if (!curFunction) {
1902  return emitError(unknownLoc, "OpLabel must appear inside a function");
1903  }
1904 
1905  if (operands.size() != 1) {
1906  return emitError(unknownLoc, "OpLabel should only have result <id>");
1907  }
1908 
1909  auto labelID = operands[0];
1910  // We may have forward declared this block.
1911  auto *block = getOrCreateBlock(labelID);
1912  LLVM_DEBUG(logger.startLine()
1913  << "[block] populating block " << block << "\n");
1914  // If we have seen this block, make sure it was just a forward declaration.
1915  assert(block->empty() && "re-deserialize the same block!");
1916 
1917  opBuilder.setInsertionPointToStart(block);
1918  blockMap[labelID] = curBlock = block;
1919 
1920  return success();
1921 }
1922 
1923 LogicalResult
1924 spirv::Deserializer::processSelectionMerge(ArrayRef<uint32_t> operands) {
1925  if (!curBlock) {
1926  return emitError(unknownLoc, "OpSelectionMerge must appear in a block");
1927  }
1928 
1929  if (operands.size() < 2) {
1930  return emitError(
1931  unknownLoc,
1932  "OpSelectionMerge must specify merge target and selection control");
1933  }
1934 
1935  auto *mergeBlock = getOrCreateBlock(operands[0]);
1936  auto loc = createFileLineColLoc(opBuilder);
1937  auto selectionControl = operands[1];
1938 
1939  if (!blockMergeInfo.try_emplace(curBlock, loc, selectionControl, mergeBlock)
1940  .second) {
1941  return emitError(
1942  unknownLoc,
1943  "a block cannot have more than one OpSelectionMerge instruction");
1944  }
1945 
1946  return success();
1947 }
1948 
1949 LogicalResult
1950 spirv::Deserializer::processLoopMerge(ArrayRef<uint32_t> operands) {
1951  if (!curBlock) {
1952  return emitError(unknownLoc, "OpLoopMerge must appear in a block");
1953  }
1954 
1955  if (operands.size() < 3) {
1956  return emitError(unknownLoc, "OpLoopMerge must specify merge target, "
1957  "continue target and loop control");
1958  }
1959 
1960  auto *mergeBlock = getOrCreateBlock(operands[0]);
1961  auto *continueBlock = getOrCreateBlock(operands[1]);
1962  auto loc = createFileLineColLoc(opBuilder);
1963  uint32_t loopControl = operands[2];
1964 
1965  if (!blockMergeInfo
1966  .try_emplace(curBlock, loc, loopControl, mergeBlock, continueBlock)
1967  .second) {
1968  return emitError(
1969  unknownLoc,
1970  "a block cannot have more than one OpLoopMerge instruction");
1971  }
1972 
1973  return success();
1974 }
1975 
1976 LogicalResult spirv::Deserializer::processPhi(ArrayRef<uint32_t> operands) {
1977  if (!curBlock) {
1978  return emitError(unknownLoc, "OpPhi must appear in a block");
1979  }
1980 
1981  if (operands.size() < 4) {
1982  return emitError(unknownLoc, "OpPhi must specify result type, result <id>, "
1983  "and variable-parent pairs");
1984  }
1985 
1986  // Create a block argument for this OpPhi instruction.
1987  Type blockArgType = getType(operands[0]);
1988  BlockArgument blockArg = curBlock->addArgument(blockArgType, unknownLoc);
1989  valueMap[operands[1]] = blockArg;
1990  LLVM_DEBUG(logger.startLine()
1991  << "[phi] created block argument " << blockArg
1992  << " id = " << operands[1] << " of type " << blockArgType << "\n");
1993 
1994  // For each (value, predecessor) pair, insert the value to the predecessor's
1995  // blockPhiInfo entry so later we can fix the block argument there.
1996  for (unsigned i = 2, e = operands.size(); i < e; i += 2) {
1997  uint32_t value = operands[i];
1998  Block *predecessor = getOrCreateBlock(operands[i + 1]);
1999  std::pair<Block *, Block *> predecessorTargetPair{predecessor, curBlock};
2000  blockPhiInfo[predecessorTargetPair].push_back(value);
2001  LLVM_DEBUG(logger.startLine() << "[phi] predecessor @ " << predecessor
2002  << " with arg id = " << value << "\n");
2003  }
2004 
2005  return success();
2006 }
2007 
2008 namespace {
2009 /// A class for putting all blocks in a structured selection/loop in a
2010 /// spirv.mlir.selection/spirv.mlir.loop op.
2011 class ControlFlowStructurizer {
2012 public:
2013 #ifndef NDEBUG
2014  ControlFlowStructurizer(Location loc, uint32_t control,
2015  spirv::BlockMergeInfoMap &mergeInfo, Block *header,
2016  Block *merge, Block *cont,
2017  llvm::ScopedPrinter &logger)
2018  : location(loc), control(control), blockMergeInfo(mergeInfo),
2019  headerBlock(header), mergeBlock(merge), continueBlock(cont),
2020  logger(logger) {}
2021 #else
2022  ControlFlowStructurizer(Location loc, uint32_t control,
2023  spirv::BlockMergeInfoMap &mergeInfo, Block *header,
2024  Block *merge, Block *cont)
2025  : location(loc), control(control), blockMergeInfo(mergeInfo),
2026  headerBlock(header), mergeBlock(merge), continueBlock(cont) {}
2027 #endif
2028 
2029  /// Structurizes the loop at the given `headerBlock`.
2030  ///
2031  /// This method will create an spirv.mlir.loop op in the `mergeBlock` and move
2032  /// all blocks in the structured loop into the spirv.mlir.loop's region. All
2033  /// branches to the `headerBlock` will be redirected to the `mergeBlock`. This
2034  /// method will also update `mergeInfo` by remapping all blocks inside to the
2035  /// newly cloned ones inside structured control flow op's regions.
2036  LogicalResult structurize();
2037 
2038 private:
2039  /// Creates a new spirv.mlir.selection op at the beginning of the
2040  /// `mergeBlock`.
2041  spirv::SelectionOp createSelectionOp(uint32_t selectionControl);
2042 
2043  /// Creates a new spirv.mlir.loop op at the beginning of the `mergeBlock`.
2044  spirv::LoopOp createLoopOp(uint32_t loopControl);
2045 
2046  /// Collects all blocks reachable from `headerBlock` except `mergeBlock`.
2047  void collectBlocksInConstruct();
2048 
2049  Location location;
2050  uint32_t control;
2051 
2052  spirv::BlockMergeInfoMap &blockMergeInfo;
2053 
2054  Block *headerBlock;
2055  Block *mergeBlock;
2056  Block *continueBlock; // nullptr for spirv.mlir.selection
2057 
2058  SetVector<Block *> constructBlocks;
2059 
2060 #ifndef NDEBUG
2061  /// A logger used to emit information during the deserialzation process.
2062  llvm::ScopedPrinter &logger;
2063 #endif
2064 };
2065 } // namespace
2066 
2067 spirv::SelectionOp
2068 ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) {
2069  // Create a builder and set the insertion point to the beginning of the
2070  // merge block so that the newly created SelectionOp will be inserted there.
2071  OpBuilder builder(&mergeBlock->front());
2072 
2073  auto control = static_cast<spirv::SelectionControl>(selectionControl);
2074  auto selectionOp = spirv::SelectionOp::create(builder, location, control);
2075  selectionOp.addMergeBlock(builder);
2076 
2077  return selectionOp;
2078 }
2079 
2080 spirv::LoopOp ControlFlowStructurizer::createLoopOp(uint32_t loopControl) {
2081  // Create a builder and set the insertion point to the beginning of the
2082  // merge block so that the newly created LoopOp will be inserted there.
2083  OpBuilder builder(&mergeBlock->front());
2084 
2085  auto control = static_cast<spirv::LoopControl>(loopControl);
2086  auto loopOp = spirv::LoopOp::create(builder, location, control);
2087  loopOp.addEntryAndMergeBlock(builder);
2088 
2089  return loopOp;
2090 }
2091 
2092 void ControlFlowStructurizer::collectBlocksInConstruct() {
2093  assert(constructBlocks.empty() && "expected empty constructBlocks");
2094 
2095  // Put the header block in the work list first.
2096  constructBlocks.insert(headerBlock);
2097 
2098  // For each item in the work list, add its successors excluding the merge
2099  // block.
2100  for (unsigned i = 0; i < constructBlocks.size(); ++i) {
2101  for (auto *successor : constructBlocks[i]->getSuccessors())
2102  if (successor != mergeBlock)
2103  constructBlocks.insert(successor);
2104  }
2105 }
2106 
2107 LogicalResult ControlFlowStructurizer::structurize() {
2108  Operation *op = nullptr;
2109  bool isLoop = continueBlock != nullptr;
2110  if (isLoop) {
2111  if (auto loopOp = createLoopOp(control))
2112  op = loopOp.getOperation();
2113  } else {
2114  if (auto selectionOp = createSelectionOp(control))
2115  op = selectionOp.getOperation();
2116  }
2117  if (!op)
2118  return failure();
2119  Region &body = op->getRegion(0);
2120 
2121  IRMapping mapper;
2122  // All references to the old merge block should be directed to the
2123  // selection/loop merge block in the SelectionOp/LoopOp's region.
2124  mapper.map(mergeBlock, &body.back());
2125 
2126  collectBlocksInConstruct();
2127 
2128  // We've identified all blocks belonging to the selection/loop's region. Now
2129  // need to "move" them into the selection/loop. Instead of really moving the
2130  // blocks, in the following we copy them and remap all values and branches.
2131  // This is because:
2132  // * Inserting a block into a region requires the block not in any region
2133  // before. But selections/loops can nest so we can create selection/loop ops
2134  // in a nested manner, which means some blocks may already be in a
2135  // selection/loop region when to be moved again.
2136  // * It's much trickier to fix up the branches into and out of the loop's
2137  // region: we need to treat not-moved blocks and moved blocks differently:
2138  // Not-moved blocks jumping to the loop header block need to jump to the
2139  // merge point containing the new loop op but not the loop continue block's
2140  // back edge. Moved blocks jumping out of the loop need to jump to the
2141  // merge block inside the loop region but not other not-moved blocks.
2142  // We cannot use replaceAllUsesWith clearly and it's harder to follow the
2143  // logic.
2144 
2145  // Create a corresponding block in the SelectionOp/LoopOp's region for each
2146  // block in this loop construct.
2147  OpBuilder builder(body);
2148  for (auto *block : constructBlocks) {
2149  // Create a block and insert it before the selection/loop merge block in the
2150  // SelectionOp/LoopOp's region.
2151  auto *newBlock = builder.createBlock(&body.back());
2152  mapper.map(block, newBlock);
2153  LLVM_DEBUG(logger.startLine() << "[cf] cloned block " << newBlock
2154  << " from block " << block << "\n");
2155  if (!isFnEntryBlock(block)) {
2156  for (BlockArgument blockArg : block->getArguments()) {
2157  auto newArg =
2158  newBlock->addArgument(blockArg.getType(), blockArg.getLoc());
2159  mapper.map(blockArg, newArg);
2160  LLVM_DEBUG(logger.startLine() << "[cf] remapped block argument "
2161  << blockArg << " to " << newArg << "\n");
2162  }
2163  } else {
2164  LLVM_DEBUG(logger.startLine()
2165  << "[cf] block " << block << " is a function entry block\n");
2166  }
2167 
2168  for (auto &op : *block)
2169  newBlock->push_back(op.clone(mapper));
2170  }
2171 
2172  // Go through all ops and remap the operands.
2173  auto remapOperands = [&](Operation *op) {
2174  for (auto &operand : op->getOpOperands())
2175  if (Value mappedOp = mapper.lookupOrNull(operand.get()))
2176  operand.set(mappedOp);
2177  for (auto &succOp : op->getBlockOperands())
2178  if (Block *mappedOp = mapper.lookupOrNull(succOp.get()))
2179  succOp.set(mappedOp);
2180  };
2181  for (auto &block : body)
2182  block.walk(remapOperands);
2183 
2184  // We have created the SelectionOp/LoopOp and "moved" all blocks belonging to
2185  // the selection/loop construct into its region. Next we need to fix the
2186  // connections between this new SelectionOp/LoopOp with existing blocks.
2187 
2188  // All existing incoming branches should go to the merge block, where the
2189  // SelectionOp/LoopOp resides right now.
2190  headerBlock->replaceAllUsesWith(mergeBlock);
2191 
2192  LLVM_DEBUG({
2193  logger.startLine() << "[cf] after cloning and fixing references:\n";
2194  headerBlock->getParentOp()->print(logger.getOStream());
2195  logger.startLine() << "\n";
2196  });
2197 
2198  if (isLoop) {
2199  if (!mergeBlock->args_empty()) {
2200  return mergeBlock->getParentOp()->emitError(
2201  "OpPhi in loop merge block unsupported");
2202  }
2203 
2204  // The loop header block may have block arguments. Since now we place the
2205  // loop op inside the old merge block, we need to make sure the old merge
2206  // block has the same block argument list.
2207  for (BlockArgument blockArg : headerBlock->getArguments())
2208  mergeBlock->addArgument(blockArg.getType(), blockArg.getLoc());
2209 
2210  // If the loop header block has block arguments, make sure the spirv.Branch
2211  // op matches.
2212  SmallVector<Value, 4> blockArgs;
2213  if (!headerBlock->args_empty())
2214  blockArgs = {mergeBlock->args_begin(), mergeBlock->args_end()};
2215 
2216  // The loop entry block should have a unconditional branch jumping to the
2217  // loop header block.
2218  builder.setInsertionPointToEnd(&body.front());
2219  spirv::BranchOp::create(builder, location, mapper.lookupOrNull(headerBlock),
2220  ArrayRef<Value>(blockArgs));
2221  }
2222 
2223  // Values defined inside the selection region that need to be yielded outside
2224  // the region.
2225  SmallVector<Value> valuesToYield;
2226  // Outside uses of values that were sunk into the selection region. Those uses
2227  // will be replaced with values returned by the SelectionOp.
2228  SmallVector<Value> outsideUses;
2229 
2230  // Move block arguments of the original block (`mergeBlock`) into the merge
2231  // block inside the selection (`body.back()`). Values produced by block
2232  // arguments will be yielded by the selection region. We do not update uses or
2233  // erase original block arguments yet. It will be done later in the code.
2234  //
2235  // Code below is not executed for loops as it would interfere with the logic
2236  // above. Currently block arguments in the merge block are not supported, but
2237  // instead, the code above copies those arguments from the header block into
2238  // the merge block. As such, running the code would yield those copied
2239  // arguments that is most likely not a desired behaviour. This may need to be
2240  // revisited in the future.
2241  if (!isLoop)
2242  for (BlockArgument blockArg : mergeBlock->getArguments()) {
2243  // Create new block arguments in the last block ("merge block") of the
2244  // selection region. We create one argument for each argument in
2245  // `mergeBlock`. This new value will need to be yielded, and the original
2246  // value replaced, so add them to appropriate vectors.
2247  body.back().addArgument(blockArg.getType(), blockArg.getLoc());
2248  valuesToYield.push_back(body.back().getArguments().back());
2249  outsideUses.push_back(blockArg);
2250  }
2251 
2252  // All the blocks cloned into the SelectionOp/LoopOp's region can now be
2253  // cleaned up.
2254  LLVM_DEBUG(logger.startLine() << "[cf] cleaning up blocks after clone\n");
2255  // First we need to drop all operands' references inside all blocks. This is
2256  // needed because we can have blocks referencing SSA values from one another.
2257  for (auto *block : constructBlocks)
2258  block->dropAllReferences();
2259 
2260  // All internal uses should be removed from original blocks by now, so
2261  // whatever is left is an outside use and will need to be yielded from
2262  // the newly created selection / loop region.
2263  for (Block *block : constructBlocks) {
2264  for (Operation &op : *block) {
2265  if (!op.use_empty())
2266  for (Value result : op.getResults()) {
2267  valuesToYield.push_back(mapper.lookupOrNull(result));
2268  outsideUses.push_back(result);
2269  }
2270  }
2271  for (BlockArgument &arg : block->getArguments()) {
2272  if (!arg.use_empty()) {
2273  valuesToYield.push_back(mapper.lookupOrNull(arg));
2274  outsideUses.push_back(arg);
2275  }
2276  }
2277  }
2278 
2279  assert(valuesToYield.size() == outsideUses.size());
2280 
2281  // If we need to yield any values from the selection / loop region we will
2282  // take care of it here.
2283  if (!valuesToYield.empty()) {
2284  LLVM_DEBUG(logger.startLine()
2285  << "[cf] yielding values from the selection / loop region\n");
2286 
2287  // Update `mlir.merge` with values to be yield.
2288  auto mergeOps = body.back().getOps<spirv::MergeOp>();
2289  Operation *merge = llvm::getSingleElement(mergeOps);
2290  assert(merge);
2291  merge->setOperands(valuesToYield);
2292 
2293  // MLIR does not allow changing the number of results of an operation, so
2294  // we create a new SelectionOp / LoopOp with required list of results and
2295  // move the region from the initial SelectionOp / LoopOp. The initial
2296  // operation is then removed. Since we move the region to the new op all
2297  // links between blocks and remapping we have previously done should be
2298  // preserved.
2299  builder.setInsertionPoint(&mergeBlock->front());
2300 
2301  Operation *newOp = nullptr;
2302 
2303  if (isLoop)
2304  newOp = spirv::LoopOp::create(builder, location,
2305  TypeRange(ValueRange(outsideUses)),
2306  static_cast<spirv::LoopControl>(control));
2307  else
2308  newOp = spirv::SelectionOp::create(
2309  builder, location, TypeRange(ValueRange(outsideUses)),
2310  static_cast<spirv::SelectionControl>(control));
2311 
2312  newOp->getRegion(0).takeBody(body);
2313 
2314  // Remove initial op and swap the pointer to the newly created one.
2315  op->erase();
2316  op = newOp;
2317 
2318  // Update all outside uses to use results of the SelectionOp / LoopOp and
2319  // remove block arguments from the original merge block.
2320  for (unsigned i = 0, e = outsideUses.size(); i != e; ++i)
2321  outsideUses[i].replaceAllUsesWith(op->getResult(i));
2322 
2323  // We do not support block arguments in loop merge block. Also running this
2324  // function with loop would break some of the loop specific code above
2325  // dealing with block arguments.
2326  if (!isLoop)
2327  mergeBlock->eraseArguments(0, mergeBlock->getNumArguments());
2328  }
2329 
2330  // Check that whether some op in the to-be-erased blocks still has uses. Those
2331  // uses come from blocks that won't be sinked into the SelectionOp/LoopOp's
2332  // region. We cannot handle such cases given that once a value is sinked into
2333  // the SelectionOp/LoopOp's region, there is no escape for it.
2334  for (auto *block : constructBlocks) {
2335  for (Operation &op : *block)
2336  if (!op.use_empty())
2337  return op.emitOpError("failed control flow structurization: value has "
2338  "uses outside of the "
2339  "enclosing selection/loop construct");
2340  for (BlockArgument &arg : block->getArguments())
2341  if (!arg.use_empty())
2342  return emitError(arg.getLoc(), "failed control flow structurization: "
2343  "block argument has uses outside of the "
2344  "enclosing selection/loop construct");
2345  }
2346 
2347  // Then erase all old blocks.
2348  for (auto *block : constructBlocks) {
2349  // We've cloned all blocks belonging to this construct into the structured
2350  // control flow op's region. Among these blocks, some may compose another
2351  // selection/loop. If so, they will be recorded within blockMergeInfo.
2352  // We need to update the pointers there to the newly remapped ones so we can
2353  // continue structurizing them later.
2354  //
2355  // We need to walk each block as constructBlocks do not include blocks
2356  // internal to ops already structured within those blocks. It is not
2357  // fully clear to me why the mergeInfo of blocks (yet to be structured)
2358  // inside already structured selections/loops get invalidated and needs
2359  // updating, however the following example code can cause a crash (depending
2360  // on the structuring order), when the most inner selection is being
2361  // structured after the outer selection and loop have been already
2362  // structured:
2363  //
2364  // spirv.mlir.for {
2365  // // ...
2366  // spirv.mlir.selection {
2367  // // ..
2368  // // A selection region that hasn't been yet structured!
2369  // // ..
2370  // }
2371  // // ...
2372  // }
2373  //
2374  // If the loop gets structured after the outer selection, but before the
2375  // inner selection. Moving the already structured selection inside the loop
2376  // will invalidate the mergeInfo of the region that is not yet structured.
2377  // Just going over constructBlocks will not check and updated header blocks
2378  // inside the already structured selection region. Walking block fixes that.
2379  //
2380  // TODO: If structuring was done in a fixed order starting with inner
2381  // most constructs this most likely not be an issue and the whole code
2382  // section could be removed. However, with the current non-deterministic
2383  // order this is not possible.
2384  //
2385  // TODO: The asserts in the following assumes input SPIR-V blob forms
2386  // correctly nested selection/loop constructs. We should relax this and
2387  // support error cases better.
2388  auto updateMergeInfo = [&](Block *block) -> WalkResult {
2389  auto it = blockMergeInfo.find(block);
2390  if (it != blockMergeInfo.end()) {
2391  // Use the original location for nested selection/loop ops.
2392  Location loc = it->second.loc;
2393 
2394  Block *newHeader = mapper.lookupOrNull(block);
2395  if (!newHeader)
2396  return emitError(loc, "failed control flow structurization: nested "
2397  "loop header block should be remapped!");
2398 
2399  Block *newContinue = it->second.continueBlock;
2400  if (newContinue) {
2401  newContinue = mapper.lookupOrNull(newContinue);
2402  if (!newContinue)
2403  return emitError(loc, "failed control flow structurization: nested "
2404  "loop continue block should be remapped!");
2405  }
2406 
2407  Block *newMerge = it->second.mergeBlock;
2408  if (Block *mappedTo = mapper.lookupOrNull(newMerge))
2409  newMerge = mappedTo;
2410 
2411  // The iterator should be erased before adding a new entry into
2412  // blockMergeInfo to avoid iterator invalidation.
2413  blockMergeInfo.erase(it);
2414  blockMergeInfo.try_emplace(newHeader, loc, it->second.control, newMerge,
2415  newContinue);
2416  }
2417 
2418  return WalkResult::advance();
2419  };
2420 
2421  if (block->walk(updateMergeInfo).wasInterrupted())
2422  return failure();
2423 
2424  // The structured selection/loop's entry block does not have arguments.
2425  // If the function's header block is also part of the structured control
2426  // flow, we cannot just simply erase it because it may contain arguments
2427  // matching the function signature and used by the cloned blocks.
2428  if (isFnEntryBlock(block)) {
2429  LLVM_DEBUG(logger.startLine() << "[cf] changing entry block " << block
2430  << " to only contain a spirv.Branch op\n");
2431  // Still keep the function entry block for the potential block arguments,
2432  // but replace all ops inside with a branch to the merge block.
2433  block->clear();
2434  builder.setInsertionPointToEnd(block);
2435  spirv::BranchOp::create(builder, location, mergeBlock);
2436  } else {
2437  LLVM_DEBUG(logger.startLine() << "[cf] erasing block " << block << "\n");
2438  block->erase();
2439  }
2440  }
2441 
2442  LLVM_DEBUG(logger.startLine()
2443  << "[cf] after structurizing construct with header block "
2444  << headerBlock << ":\n"
2445  << *op << "\n");
2446 
2447  return success();
2448 }
2449 
2450 LogicalResult spirv::Deserializer::wireUpBlockArgument() {
2451  LLVM_DEBUG({
2452  logger.startLine()
2453  << "//----- [phi] start wiring up block arguments -----//\n";
2454  logger.indent();
2455  });
2456 
2457  OpBuilder::InsertionGuard guard(opBuilder);
2458 
2459  for (const auto &info : blockPhiInfo) {
2460  Block *block = info.first.first;
2461  Block *target = info.first.second;
2462  const BlockPhiInfo &phiInfo = info.second;
2463  LLVM_DEBUG({
2464  logger.startLine() << "[phi] block " << block << "\n";
2465  logger.startLine() << "[phi] before creating block argument:\n";
2466  block->getParentOp()->print(logger.getOStream());
2467  logger.startLine() << "\n";
2468  });
2469 
2470  // Set insertion point to before this block's terminator early because we
2471  // may materialize ops via getValue() call.
2472  auto *op = block->getTerminator();
2473  opBuilder.setInsertionPoint(op);
2474 
2475  SmallVector<Value, 4> blockArgs;
2476  blockArgs.reserve(phiInfo.size());
2477  for (uint32_t valueId : phiInfo) {
2478  if (Value value = getValue(valueId)) {
2479  blockArgs.push_back(value);
2480  LLVM_DEBUG(logger.startLine() << "[phi] block argument " << value
2481  << " id = " << valueId << "\n");
2482  } else {
2483  return emitError(unknownLoc, "OpPhi references undefined value!");
2484  }
2485  }
2486 
2487  if (auto branchOp = dyn_cast<spirv::BranchOp>(op)) {
2488  // Replace the previous branch op with a new one with block arguments.
2489  spirv::BranchOp::create(opBuilder, branchOp.getLoc(),
2490  branchOp.getTarget(), blockArgs);
2491  branchOp.erase();
2492  } else if (auto branchCondOp = dyn_cast<spirv::BranchConditionalOp>(op)) {
2493  assert((branchCondOp.getTrueBlock() == target ||
2494  branchCondOp.getFalseBlock() == target) &&
2495  "expected target to be either the true or false target");
2496  if (target == branchCondOp.getTrueTarget())
2497  spirv::BranchConditionalOp::create(
2498  opBuilder, branchCondOp.getLoc(), branchCondOp.getCondition(),
2499  blockArgs, branchCondOp.getFalseBlockArguments(),
2500  branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueTarget(),
2501  branchCondOp.getFalseTarget());
2502  else
2503  spirv::BranchConditionalOp::create(
2504  opBuilder, branchCondOp.getLoc(), branchCondOp.getCondition(),
2505  branchCondOp.getTrueBlockArguments(), blockArgs,
2506  branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueBlock(),
2507  branchCondOp.getFalseBlock());
2508 
2509  branchCondOp.erase();
2510  } else {
2511  return emitError(unknownLoc, "unimplemented terminator for Phi creation");
2512  }
2513 
2514  LLVM_DEBUG({
2515  logger.startLine() << "[phi] after creating block argument:\n";
2516  block->getParentOp()->print(logger.getOStream());
2517  logger.startLine() << "\n";
2518  });
2519  }
2520  blockPhiInfo.clear();
2521 
2522  LLVM_DEBUG({
2523  logger.unindent();
2524  logger.startLine()
2525  << "//--- [phi] completed wiring up block arguments ---//\n";
2526  });
2527  return success();
2528 }
2529 
2530 LogicalResult spirv::Deserializer::splitConditionalBlocks() {
2531  // Create a copy, so we can modify keys in the original.
2532  BlockMergeInfoMap blockMergeInfoCopy = blockMergeInfo;
2533  for (auto it = blockMergeInfoCopy.begin(), e = blockMergeInfoCopy.end();
2534  it != e; ++it) {
2535  auto &[block, mergeInfo] = *it;
2536 
2537  // Skip processing loop regions. For loop regions continueBlock is non-null.
2538  if (mergeInfo.continueBlock)
2539  continue;
2540 
2541  if (!block->mightHaveTerminator())
2542  continue;
2543 
2544  Operation *terminator = block->getTerminator();
2545  assert(terminator);
2546 
2547  if (!isa<spirv::BranchConditionalOp>(terminator))
2548  continue;
2549 
2550  // Check if the current header block is a merge block of another construct.
2551  bool splitHeaderMergeBlock = false;
2552  for (const auto &[_, mergeInfo] : blockMergeInfo) {
2553  if (mergeInfo.mergeBlock == block)
2554  splitHeaderMergeBlock = true;
2555  }
2556 
2557  // Do not split a block that only contains a conditional branch, unless it
2558  // is also a merge block of another construct - in that case we want to
2559  // split the block. We do not want two constructs to share header / merge
2560  // block.
2561  if (!llvm::hasSingleElement(*block) || splitHeaderMergeBlock) {
2562  Block *newBlock = block->splitBlock(terminator);
2563  OpBuilder builder(block, block->end());
2564  spirv::BranchOp::create(builder, block->getParent()->getLoc(), newBlock);
2565 
2566  // After splitting we need to update the map to use the new block as a
2567  // header.
2568  blockMergeInfo.erase(block);
2569  blockMergeInfo.try_emplace(newBlock, mergeInfo);
2570  }
2571  }
2572 
2573  return success();
2574 }
2575 
2576 LogicalResult spirv::Deserializer::structurizeControlFlow() {
2577  if (!options.enableControlFlowStructurization) {
2578  LLVM_DEBUG(
2579  {
2580  logger.startLine()
2581  << "//----- [cf] skip structurizing control flow -----//\n";
2582  logger.indent();
2583  });
2584  return success();
2585  }
2586 
2587  LLVM_DEBUG({
2588  logger.startLine()
2589  << "//----- [cf] start structurizing control flow -----//\n";
2590  logger.indent();
2591  });
2592 
2593  LLVM_DEBUG({
2594  logger.startLine() << "[cf] split conditional blocks\n";
2595  logger.startLine() << "\n";
2596  });
2597 
2598  if (failed(splitConditionalBlocks())) {
2599  return failure();
2600  }
2601 
2602  // TODO: This loop is non-deterministic. Iteration order may vary between runs
2603  // for the same shader as the key to the map is a pointer. See:
2604  // https://github.com/llvm/llvm-project/issues/128547
2605  while (!blockMergeInfo.empty()) {
2606  Block *headerBlock = blockMergeInfo.begin()->first;
2607  BlockMergeInfo mergeInfo = blockMergeInfo.begin()->second;
2608 
2609  LLVM_DEBUG({
2610  logger.startLine() << "[cf] header block " << headerBlock << ":\n";
2611  headerBlock->print(logger.getOStream());
2612  logger.startLine() << "\n";
2613  });
2614 
2615  auto *mergeBlock = mergeInfo.mergeBlock;
2616  assert(mergeBlock && "merge block cannot be nullptr");
2617  if (mergeInfo.continueBlock && !mergeBlock->args_empty())
2618  return emitError(unknownLoc, "OpPhi in loop merge block unimplemented");
2619  LLVM_DEBUG({
2620  logger.startLine() << "[cf] merge block " << mergeBlock << ":\n";
2621  mergeBlock->print(logger.getOStream());
2622  logger.startLine() << "\n";
2623  });
2624 
2625  auto *continueBlock = mergeInfo.continueBlock;
2626  LLVM_DEBUG(if (continueBlock) {
2627  logger.startLine() << "[cf] continue block " << continueBlock << ":\n";
2628  continueBlock->print(logger.getOStream());
2629  logger.startLine() << "\n";
2630  });
2631  // Erase this case before calling into structurizer, who will update
2632  // blockMergeInfo.
2633  blockMergeInfo.erase(blockMergeInfo.begin());
2634  ControlFlowStructurizer structurizer(mergeInfo.loc, mergeInfo.control,
2635  blockMergeInfo, headerBlock,
2636  mergeBlock, continueBlock
2637 #ifndef NDEBUG
2638  ,
2639  logger
2640 #endif
2641  );
2642  if (failed(structurizer.structurize()))
2643  return failure();
2644  }
2645 
2646  LLVM_DEBUG({
2647  logger.unindent();
2648  logger.startLine()
2649  << "//--- [cf] completed structurizing control flow ---//\n";
2650  });
2651  return success();
2652 }
2653 
2654 //===----------------------------------------------------------------------===//
2655 // Debug
2656 //===----------------------------------------------------------------------===//
2657 
2658 Location spirv::Deserializer::createFileLineColLoc(OpBuilder opBuilder) {
2659  if (!debugLine)
2660  return unknownLoc;
2661 
2662  auto fileName = debugInfoMap.lookup(debugLine->fileID).str();
2663  if (fileName.empty())
2664  fileName = "<unknown>";
2665  return FileLineColLoc::get(opBuilder.getStringAttr(fileName), debugLine->line,
2666  debugLine->column);
2667 }
2668 
2669 LogicalResult
2670 spirv::Deserializer::processDebugLine(ArrayRef<uint32_t> operands) {
2671  // According to SPIR-V spec:
2672  // "This location information applies to the instructions physically
2673  // following this instruction, up to the first occurrence of any of the
2674  // following: the next end of block, the next OpLine instruction, or the next
2675  // OpNoLine instruction."
2676  if (operands.size() != 3)
2677  return emitError(unknownLoc, "OpLine must have 3 operands");
2678  debugLine = DebugLine{operands[0], operands[1], operands[2]};
2679  return success();
2680 }
2681 
2682 void spirv::Deserializer::clearDebugLine() { debugLine = std::nullopt; }
2683 
2684 LogicalResult
2685 spirv::Deserializer::processDebugString(ArrayRef<uint32_t> operands) {
2686  if (operands.size() < 2)
2687  return emitError(unknownLoc, "OpString needs at least 2 operands");
2688 
2689  if (!debugInfoMap.lookup(operands[0]).empty())
2690  return emitError(unknownLoc,
2691  "duplicate debug string found for result <id> ")
2692  << operands[0];
2693 
2694  unsigned wordIndex = 1;
2695  StringRef debugString = decodeStringLiteral(operands, wordIndex);
2696  if (wordIndex != operands.size())
2697  return emitError(unknownLoc,
2698  "unexpected trailing words in OpString instruction");
2699 
2700  debugInfoMap[operands[0]] = debugString;
2701  return success();
2702 }
static bool isLoop(Operation *op)
Returns true if the given operation represents a loop by testing whether it implements the LoopLikeOp...
static bool isFnEntryBlock(Block *block)
Returns true if the given block is a function entry block.
#define MIN_VERSION_CASE(v)
LogicalResult deserializeCacheControlDecoration(Location loc, OpBuilder &opBuilder, DenseMap< uint32_t, NamedAttrList > &decorations, ArrayRef< uint32_t > words, StringAttr symbol, StringRef decorationName, StringRef cacheControlKind)
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
int64_t rows
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:309
Location getLoc() const
Return the location for this argument.
Definition: Value.h:324
Block represents an ordered list of Operations.
Definition: Block.h:33
bool empty()
Definition: Block.h:148
void erase()
Unlink this Block from its parent region and delete it.
Definition: Block.cpp:66
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Definition: Block.cpp:308
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:27
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:244
void print(raw_ostream &os)
bool mightHaveTerminator()
Check whether this block might have a terminator.
Definition: Block.cpp:250
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
iterator end()
Definition: Block.h:144
iterator begin()
Definition: Block.h:143
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Definition: Block.cpp:36
void push_back(Operation *op)
Definition: Block.h:149
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:31
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:257
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:261
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition: Builders.h:96
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static FileLineColLoc get(StringAttr filename, unsigned line, unsigned column)
Definition: Location.cpp:157
A symbol reference with a reference path containing a single element.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
auto lookupOrNull(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:58
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:164
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool use_empty()
Returns true if this operation has no uses.
Definition: Operation.h:852
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Definition: Operation.cpp:718
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Definition: Operation.cpp:66
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:686
MutableArrayRef< BlockOperand > getBlockOperands()
Definition: Operation.h:695
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:383
result_range getResults()
Definition: Operation.h:415
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:672
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:538
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator end()
Definition: Region.h:56
Block & back()
Definition: Region.h:64
BlockListType & getBlocks()
Definition: Region.h:45
Location getLoc()
Return a location for this region.
Definition: Region.cpp:31
BlockListType::iterator iterator
Definition: Region.h:52
void takeBody(Region &other)
Takes body of another region (that region will have no body after this operation completes).
Definition: Region.h:241
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:116
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: WalkResult.h:29
static WalkResult advance()
Definition: WalkResult.h:47
static ArrayType get(Type elementType, unsigned elementCount)
Definition: SPIRVTypes.cpp:50
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
Definition: SPIRVTypes.cpp:255
LogicalResult deserialize()
Deserializes the remembered SPIR-V binary module.
Deserializer(ArrayRef< uint32_t > binary, MLIRContext *context, const DeserializationOptions &options)
Creates a deserializer for the given SPIR-V binary module.
OwningOpRef< spirv::ModuleOp > collect()
Collects the final SPIR-V ModuleOp.
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
Definition: SPIRVTypes.h:170
static MatrixType get(Type columnType, uint32_t columnCount)
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:447
static RuntimeArrayType get(Type elementType)
Definition: SPIRVTypes.cpp:504
static SampledImageType get(Type imageType)
Definition: SPIRVTypes.cpp:793
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Construct a literal StructType with at least one member.
static TensorArmType get(ArrayRef< int64_t > shape, Type elementType)
static VerCapExtAttr get(Version version, ArrayRef< Capability > capabilities, ArrayRef< Extension > extensions, MLIRContext *context)
Gets a VerCapExtAttr instance.
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition: CallGraph.h:229
constexpr uint32_t kMagicNumber
SPIR-V magic number.
StringRef decodeStringLiteral(ArrayRef< uint32_t > words, unsigned &wordIndex)
Decodes a string literal in words starting at wordIndex.
DenseMap< Block *, BlockMergeInfo > BlockMergeInfoMap
Map from a selection/loop's header block to its merge (and continue) target.
Definition: Deserializer.h:61
constexpr unsigned kHeaderWordCount
SPIR-V binary header word count.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
static std::string debugString(T &&op)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This represents an operation in an abstracted form, suitable for use with the builder APIs.