MLIR  21.0.0git
Deserializer.cpp
Go to the documentation of this file.
1 //===- Deserializer.cpp - MLIR SPIR-V Deserializer ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the SPIR-V binary to MLIR SPIR-V module deserializer.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Deserializer.h"
14 
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/IRMapping.h"
21 #include "mlir/IR/Location.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/Sequence.h"
25 #include "llvm/ADT/SmallVector.h"
26 #include "llvm/ADT/StringExtras.h"
27 #include "llvm/ADT/bit.h"
28 #include "llvm/Support/Debug.h"
29 #include "llvm/Support/SaveAndRestore.h"
30 #include "llvm/Support/raw_ostream.h"
31 #include <optional>
32 
33 using namespace mlir;
34 
35 #define DEBUG_TYPE "spirv-deserialization"
36 
37 //===----------------------------------------------------------------------===//
38 // Utility Functions
39 //===----------------------------------------------------------------------===//
40 
41 /// Returns true if the given `block` is a function entry block.
42 static inline bool isFnEntryBlock(Block *block) {
43  return block->isEntryBlock() &&
44  isa_and_nonnull<spirv::FuncOp>(block->getParentOp());
45 }
46 
47 //===----------------------------------------------------------------------===//
48 // Deserializer Method Definitions
49 //===----------------------------------------------------------------------===//
50 
52  MLIRContext *context,
54  : binary(binary), context(context), unknownLoc(UnknownLoc::get(context)),
55  module(createModuleOp()), opBuilder(module->getRegion()), options(options)
56 #ifndef NDEBUG
57  ,
58  logger(llvm::dbgs())
59 #endif
60 {
61 }
62 
64  LLVM_DEBUG({
65  logger.resetIndent();
66  logger.startLine()
67  << "//+++---------- start deserialization ----------+++//\n";
68  });
69 
70  if (failed(processHeader()))
71  return failure();
72 
73  spirv::Opcode opcode = spirv::Opcode::OpNop;
74  ArrayRef<uint32_t> operands;
75  auto binarySize = binary.size();
76  while (curOffset < binarySize) {
77  // Slice the next instruction out and populate `opcode` and `operands`.
78  // Internally this also updates `curOffset`.
79  if (failed(sliceInstruction(opcode, operands)))
80  return failure();
81 
82  if (failed(processInstruction(opcode, operands)))
83  return failure();
84  }
85 
86  assert(curOffset == binarySize &&
87  "deserializer should never index beyond the binary end");
88 
89  for (auto &deferred : deferredInstructions) {
90  if (failed(processInstruction(deferred.first, deferred.second, false))) {
91  return failure();
92  }
93  }
94 
95  attachVCETriple();
96 
97  LLVM_DEBUG(logger.startLine()
98  << "//+++-------- completed deserialization --------+++//\n");
99  return success();
100 }
101 
103  return std::move(module);
104 }
105 
106 //===----------------------------------------------------------------------===//
107 // Module structure
108 //===----------------------------------------------------------------------===//
109 
110 OwningOpRef<spirv::ModuleOp> spirv::Deserializer::createModuleOp() {
111  OpBuilder builder(context);
112  OperationState state(unknownLoc, spirv::ModuleOp::getOperationName());
113  spirv::ModuleOp::build(builder, state);
114  return cast<spirv::ModuleOp>(Operation::create(state));
115 }
116 
117 LogicalResult spirv::Deserializer::processHeader() {
118  if (binary.size() < spirv::kHeaderWordCount)
119  return emitError(unknownLoc,
120  "SPIR-V binary module must have a 5-word header");
121 
122  if (binary[0] != spirv::kMagicNumber)
123  return emitError(unknownLoc, "incorrect magic number");
124 
125  // Version number bytes: 0 | major number | minor number | 0
126  uint32_t majorVersion = (binary[1] << 8) >> 24;
127  uint32_t minorVersion = (binary[1] << 16) >> 24;
128  if (majorVersion == 1) {
129  switch (minorVersion) {
130 #define MIN_VERSION_CASE(v) \
131  case v: \
132  version = spirv::Version::V_1_##v; \
133  break
134 
135  MIN_VERSION_CASE(0);
136  MIN_VERSION_CASE(1);
137  MIN_VERSION_CASE(2);
138  MIN_VERSION_CASE(3);
139  MIN_VERSION_CASE(4);
140  MIN_VERSION_CASE(5);
141 #undef MIN_VERSION_CASE
142  default:
143  return emitError(unknownLoc, "unsupported SPIR-V minor version: ")
144  << minorVersion;
145  }
146  } else {
147  return emitError(unknownLoc, "unsupported SPIR-V major version: ")
148  << majorVersion;
149  }
150 
151  // TODO: generator number, bound, schema
152  curOffset = spirv::kHeaderWordCount;
153  return success();
154 }
155 
156 LogicalResult
157 spirv::Deserializer::processCapability(ArrayRef<uint32_t> operands) {
158  if (operands.size() != 1)
159  return emitError(unknownLoc, "OpCapability must have one parameter");
160 
161  auto cap = spirv::symbolizeCapability(operands[0]);
162  if (!cap)
163  return emitError(unknownLoc, "unknown capability: ") << operands[0];
164 
165  capabilities.insert(*cap);
166  return success();
167 }
168 
169 LogicalResult spirv::Deserializer::processExtension(ArrayRef<uint32_t> words) {
170  if (words.empty()) {
171  return emitError(
172  unknownLoc,
173  "OpExtension must have a literal string for the extension name");
174  }
175 
176  unsigned wordIndex = 0;
177  StringRef extName = decodeStringLiteral(words, wordIndex);
178  if (wordIndex != words.size())
179  return emitError(unknownLoc,
180  "unexpected trailing words in OpExtension instruction");
181  auto ext = spirv::symbolizeExtension(extName);
182  if (!ext)
183  return emitError(unknownLoc, "unknown extension: ") << extName;
184 
185  extensions.insert(*ext);
186  return success();
187 }
188 
189 LogicalResult
190 spirv::Deserializer::processExtInstImport(ArrayRef<uint32_t> words) {
191  if (words.size() < 2) {
192  return emitError(unknownLoc,
193  "OpExtInstImport must have a result <id> and a literal "
194  "string for the extended instruction set name");
195  }
196 
197  unsigned wordIndex = 1;
198  extendedInstSets[words[0]] = decodeStringLiteral(words, wordIndex);
199  if (wordIndex != words.size()) {
200  return emitError(unknownLoc,
201  "unexpected trailing words in OpExtInstImport");
202  }
203  return success();
204 }
205 
206 void spirv::Deserializer::attachVCETriple() {
207  (*module)->setAttr(
208  spirv::ModuleOp::getVCETripleAttrName(),
209  spirv::VerCapExtAttr::get(version, capabilities.getArrayRef(),
210  extensions.getArrayRef(), context));
211 }
212 
213 LogicalResult
214 spirv::Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
215  if (operands.size() != 2)
216  return emitError(unknownLoc, "OpMemoryModel must have two operands");
217 
218  (*module)->setAttr(
219  module->getAddressingModelAttrName(),
220  opBuilder.getAttr<spirv::AddressingModelAttr>(
221  static_cast<spirv::AddressingModel>(operands.front())));
222 
223  (*module)->setAttr(module->getMemoryModelAttrName(),
224  opBuilder.getAttr<spirv::MemoryModelAttr>(
225  static_cast<spirv::MemoryModel>(operands.back())));
226 
227  return success();
228 }
229 
230 template <typename AttrTy, typename EnumAttrTy, typename EnumTy>
232  Location loc, OpBuilder &opBuilder,
234  StringAttr symbol, StringRef decorationName, StringRef cacheControlKind) {
235  if (words.size() != 4) {
236  return emitError(loc, "OpDecoration with ")
237  << decorationName << "needs a cache control integer literal and a "
238  << cacheControlKind << " cache control literal";
239  }
240  unsigned cacheLevel = words[2];
241  auto cacheControlAttr = static_cast<EnumTy>(words[3]);
242  auto value = opBuilder.getAttr<AttrTy>(cacheLevel, cacheControlAttr);
244  if (auto attrList =
245  llvm::dyn_cast_or_null<ArrayAttr>(decorations[words[0]].get(symbol)))
246  llvm::append_range(attrs, attrList);
247  attrs.push_back(value);
248  decorations[words[0]].set(symbol, opBuilder.getArrayAttr(attrs));
249  return success();
250 }
251 
252 LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
253  // TODO: This function should also be auto-generated. For now, since only a
254  // few decorations are processed/handled in a meaningful manner, going with a
255  // manual implementation.
256  if (words.size() < 2) {
257  return emitError(
258  unknownLoc, "OpDecorate must have at least result <id> and Decoration");
259  }
260  auto decorationName =
261  stringifyDecoration(static_cast<spirv::Decoration>(words[1]));
262  if (decorationName.empty()) {
263  return emitError(unknownLoc, "invalid Decoration code : ") << words[1];
264  }
265  auto symbol = getSymbolDecoration(decorationName);
266  switch (static_cast<spirv::Decoration>(words[1])) {
267  case spirv::Decoration::FPFastMathMode:
268  if (words.size() != 3) {
269  return emitError(unknownLoc, "OpDecorate with ")
270  << decorationName << " needs a single integer literal";
271  }
272  decorations[words[0]].set(
273  symbol, FPFastMathModeAttr::get(opBuilder.getContext(),
274  static_cast<FPFastMathMode>(words[2])));
275  break;
276  case spirv::Decoration::FPRoundingMode:
277  if (words.size() != 3) {
278  return emitError(unknownLoc, "OpDecorate with ")
279  << decorationName << " needs a single integer literal";
280  }
281  decorations[words[0]].set(
282  symbol, FPRoundingModeAttr::get(opBuilder.getContext(),
283  static_cast<FPRoundingMode>(words[2])));
284  break;
285  case spirv::Decoration::DescriptorSet:
286  case spirv::Decoration::Binding:
287  if (words.size() != 3) {
288  return emitError(unknownLoc, "OpDecorate with ")
289  << decorationName << " needs a single integer literal";
290  }
291  decorations[words[0]].set(
292  symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));
293  break;
294  case spirv::Decoration::BuiltIn:
295  if (words.size() != 3) {
296  return emitError(unknownLoc, "OpDecorate with ")
297  << decorationName << " needs a single integer literal";
298  }
299  decorations[words[0]].set(
300  symbol, opBuilder.getStringAttr(
301  stringifyBuiltIn(static_cast<spirv::BuiltIn>(words[2]))));
302  break;
303  case spirv::Decoration::ArrayStride:
304  if (words.size() != 3) {
305  return emitError(unknownLoc, "OpDecorate with ")
306  << decorationName << " needs a single integer literal";
307  }
308  typeDecorations[words[0]] = words[2];
309  break;
310  case spirv::Decoration::LinkageAttributes: {
311  if (words.size() < 4) {
312  return emitError(unknownLoc, "OpDecorate with ")
313  << decorationName
314  << " needs at least 1 string and 1 integer literal";
315  }
316  // LinkageAttributes has two parameters ["linkageName", linkageType]
317  // e.g., OpDecorate %imported_func LinkageAttributes "outside.func" Import
318  // "linkageName" is a stringliteral encoded as uint32_t,
319  // hence the size of name is variable length which results in words.size()
320  // being variable length, words.size() = 3 + strlen(name)/4 + 1 or
321  // 3 + ceildiv(strlen(name), 4).
322  unsigned wordIndex = 2;
323  auto linkageName = spirv::decodeStringLiteral(words, wordIndex).str();
324  auto linkageTypeAttr = opBuilder.getAttr<::mlir::spirv::LinkageTypeAttr>(
325  static_cast<::mlir::spirv::LinkageType>(words[wordIndex++]));
326  auto linkageAttr = opBuilder.getAttr<::mlir::spirv::LinkageAttributesAttr>(
327  StringAttr::get(context, linkageName), linkageTypeAttr);
328  decorations[words[0]].set(symbol, llvm::dyn_cast<Attribute>(linkageAttr));
329  break;
330  }
331  case spirv::Decoration::Aliased:
332  case spirv::Decoration::AliasedPointer:
333  case spirv::Decoration::Block:
334  case spirv::Decoration::BufferBlock:
335  case spirv::Decoration::Flat:
336  case spirv::Decoration::NonReadable:
337  case spirv::Decoration::NonWritable:
338  case spirv::Decoration::NoPerspective:
339  case spirv::Decoration::NoSignedWrap:
340  case spirv::Decoration::NoUnsignedWrap:
341  case spirv::Decoration::RelaxedPrecision:
342  case spirv::Decoration::Restrict:
343  case spirv::Decoration::RestrictPointer:
344  case spirv::Decoration::NoContraction:
345  case spirv::Decoration::Constant:
346  if (words.size() != 2) {
347  return emitError(unknownLoc, "OpDecoration with ")
348  << decorationName << "needs a single target <id>";
349  }
350  // Block decoration does not affect spirv.struct type, but is still stored
351  // for verification.
352  // TODO: Update StructType to contain this information since
353  // it is needed for many validation rules.
354  decorations[words[0]].set(symbol, opBuilder.getUnitAttr());
355  break;
356  case spirv::Decoration::Location:
357  case spirv::Decoration::SpecId:
358  if (words.size() != 3) {
359  return emitError(unknownLoc, "OpDecoration with ")
360  << decorationName << "needs a single integer literal";
361  }
362  decorations[words[0]].set(
363  symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));
364  break;
365  case spirv::Decoration::CacheControlLoadINTEL: {
366  LogicalResult res = deserializeCacheControlDecoration<
367  CacheControlLoadINTELAttr, LoadCacheControlAttr, LoadCacheControl>(
368  unknownLoc, opBuilder, decorations, words, symbol, decorationName,
369  "load");
370  if (failed(res))
371  return res;
372  break;
373  }
374  case spirv::Decoration::CacheControlStoreINTEL: {
375  LogicalResult res = deserializeCacheControlDecoration<
376  CacheControlStoreINTELAttr, StoreCacheControlAttr, StoreCacheControl>(
377  unknownLoc, opBuilder, decorations, words, symbol, decorationName,
378  "store");
379  if (failed(res))
380  return res;
381  break;
382  }
383  default:
384  return emitError(unknownLoc, "unhandled Decoration : '") << decorationName;
385  }
386  return success();
387 }
388 
389 LogicalResult
390 spirv::Deserializer::processMemberDecoration(ArrayRef<uint32_t> words) {
391  // The binary layout of OpMemberDecorate is different comparing to OpDecorate
392  if (words.size() < 3) {
393  return emitError(unknownLoc,
394  "OpMemberDecorate must have at least 3 operands");
395  }
396 
397  auto decoration = static_cast<spirv::Decoration>(words[2]);
398  if (decoration == spirv::Decoration::Offset && words.size() != 4) {
399  return emitError(unknownLoc,
400  " missing offset specification in OpMemberDecorate with "
401  "Offset decoration");
402  }
403  ArrayRef<uint32_t> decorationOperands;
404  if (words.size() > 3) {
405  decorationOperands = words.slice(3);
406  }
407  memberDecorationMap[words[0]][words[1]][decoration] = decorationOperands;
408  return success();
409 }
410 
411 LogicalResult spirv::Deserializer::processMemberName(ArrayRef<uint32_t> words) {
412  if (words.size() < 3) {
413  return emitError(unknownLoc, "OpMemberName must have at least 3 operands");
414  }
415  unsigned wordIndex = 2;
416  auto name = decodeStringLiteral(words, wordIndex);
417  if (wordIndex != words.size()) {
418  return emitError(unknownLoc,
419  "unexpected trailing words in OpMemberName instruction");
420  }
421  memberNameMap[words[0]][words[1]] = name;
422  return success();
423 }
424 
425 LogicalResult spirv::Deserializer::setFunctionArgAttrs(
426  uint32_t argID, SmallVectorImpl<Attribute> &argAttrs, size_t argIndex) {
427  if (!decorations.contains(argID)) {
428  argAttrs[argIndex] = DictionaryAttr::get(context, {});
429  return success();
430  }
431 
432  spirv::DecorationAttr foundDecorationAttr;
433  for (NamedAttribute decAttr : decorations[argID]) {
434  for (auto decoration :
435  {spirv::Decoration::Aliased, spirv::Decoration::Restrict,
436  spirv::Decoration::AliasedPointer,
437  spirv::Decoration::RestrictPointer}) {
438 
439  if (decAttr.getName() !=
440  getSymbolDecoration(stringifyDecoration(decoration)))
441  continue;
442 
443  if (foundDecorationAttr)
444  return emitError(unknownLoc,
445  "more than one Aliased/Restrict decorations for "
446  "function argument with result <id> ")
447  << argID;
448 
449  foundDecorationAttr = spirv::DecorationAttr::get(context, decoration);
450  break;
451  }
452 
453  if (decAttr.getName() == getSymbolDecoration(stringifyDecoration(
454  spirv::Decoration::RelaxedPrecision))) {
455  // TODO: Current implementation supports only one decoration per function
456  // parameter so RelaxedPrecision cannot be applied at the same time as,
457  // for example, Aliased/Restrict/etc. This should be relaxed to allow any
458  // combination of decoration allowed by the spec to be supported.
459  if (foundDecorationAttr)
460  return emitError(unknownLoc, "already found a decoration for function "
461  "argument with result <id> ")
462  << argID;
463 
464  foundDecorationAttr = spirv::DecorationAttr::get(
465  context, spirv::Decoration::RelaxedPrecision);
466  }
467  }
468 
469  if (!foundDecorationAttr)
470  return emitError(unknownLoc, "unimplemented decoration support for "
471  "function argument with result <id> ")
472  << argID;
473 
474  NamedAttribute attr(StringAttr::get(context, spirv::DecorationAttr::name),
475  foundDecorationAttr);
476  argAttrs[argIndex] = DictionaryAttr::get(context, attr);
477  return success();
478 }
479 
480 LogicalResult
481 spirv::Deserializer::processFunction(ArrayRef<uint32_t> operands) {
482  if (curFunction) {
483  return emitError(unknownLoc, "found function inside function");
484  }
485 
486  // Get the result type
487  if (operands.size() != 4) {
488  return emitError(unknownLoc, "OpFunction must have 4 parameters");
489  }
490  Type resultType = getType(operands[0]);
491  if (!resultType) {
492  return emitError(unknownLoc, "undefined result type from <id> ")
493  << operands[0];
494  }
495 
496  uint32_t fnID = operands[1];
497  if (funcMap.count(fnID)) {
498  return emitError(unknownLoc, "duplicate function definition/declaration");
499  }
500 
501  auto fnControl = spirv::symbolizeFunctionControl(operands[2]);
502  if (!fnControl) {
503  return emitError(unknownLoc, "unknown Function Control: ") << operands[2];
504  }
505 
506  Type fnType = getType(operands[3]);
507  if (!fnType || !isa<FunctionType>(fnType)) {
508  return emitError(unknownLoc, "unknown function type from <id> ")
509  << operands[3];
510  }
511  auto functionType = cast<FunctionType>(fnType);
512 
513  if ((isVoidType(resultType) && functionType.getNumResults() != 0) ||
514  (functionType.getNumResults() == 1 &&
515  functionType.getResult(0) != resultType)) {
516  return emitError(unknownLoc, "mismatch in function type ")
517  << functionType << " and return type " << resultType << " specified";
518  }
519 
520  std::string fnName = getFunctionSymbol(fnID);
521  auto funcOp = opBuilder.create<spirv::FuncOp>(
522  unknownLoc, fnName, functionType, fnControl.value());
523  // Processing other function attributes.
524  if (decorations.count(fnID)) {
525  for (auto attr : decorations[fnID].getAttrs()) {
526  funcOp->setAttr(attr.getName(), attr.getValue());
527  }
528  }
529  curFunction = funcMap[fnID] = funcOp;
530  auto *entryBlock = funcOp.addEntryBlock();
531  LLVM_DEBUG({
532  logger.startLine()
533  << "//===-------------------------------------------===//\n";
534  logger.startLine() << "[fn] name: " << fnName << "\n";
535  logger.startLine() << "[fn] type: " << fnType << "\n";
536  logger.startLine() << "[fn] ID: " << fnID << "\n";
537  logger.startLine() << "[fn] entry block: " << entryBlock << "\n";
538  logger.indent();
539  });
540 
541  SmallVector<Attribute> argAttrs;
542  argAttrs.resize(functionType.getNumInputs());
543 
544  // Parse the op argument instructions
545  if (functionType.getNumInputs()) {
546  for (size_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
547  auto argType = functionType.getInput(i);
548  spirv::Opcode opcode = spirv::Opcode::OpNop;
549  ArrayRef<uint32_t> operands;
550  if (failed(sliceInstruction(opcode, operands,
551  spirv::Opcode::OpFunctionParameter))) {
552  return failure();
553  }
554  if (opcode != spirv::Opcode::OpFunctionParameter) {
555  return emitError(
556  unknownLoc,
557  "missing OpFunctionParameter instruction for argument ")
558  << i;
559  }
560  if (operands.size() != 2) {
561  return emitError(
562  unknownLoc,
563  "expected result type and result <id> for OpFunctionParameter");
564  }
565  auto argDefinedType = getType(operands[0]);
566  if (!argDefinedType || argDefinedType != argType) {
567  return emitError(unknownLoc,
568  "mismatch in argument type between function type "
569  "definition ")
570  << functionType << " and argument type definition "
571  << argDefinedType << " at argument " << i;
572  }
573  if (getValue(operands[1])) {
574  return emitError(unknownLoc, "duplicate definition of result <id> ")
575  << operands[1];
576  }
577  if (failed(setFunctionArgAttrs(operands[1], argAttrs, i))) {
578  return failure();
579  }
580 
581  auto argValue = funcOp.getArgument(i);
582  valueMap[operands[1]] = argValue;
583  }
584  }
585 
586  if (llvm::any_of(argAttrs, [](Attribute attr) {
587  auto argAttr = cast<DictionaryAttr>(attr);
588  return !argAttr.empty();
589  }))
590  funcOp.setArgAttrsAttr(ArrayAttr::get(context, argAttrs));
591 
592  // entryBlock is needed to access the arguments, Once that is done, we can
593  // erase the block for functions with 'Import' LinkageAttributes, since these
594  // are essentially function declarations, so they have no body.
595  auto linkageAttr = funcOp.getLinkageAttributes();
596  auto hasImportLinkage =
597  linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
598  spirv::LinkageType::Import);
599  if (hasImportLinkage)
600  funcOp.eraseBody();
601 
602  // RAII guard to reset the insertion point to the module's region after
603  // deserializing the body of this function.
604  OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
605 
606  spirv::Opcode opcode = spirv::Opcode::OpNop;
607  ArrayRef<uint32_t> instOperands;
608 
609  // Special handling for the entry block. We need to make sure it starts with
610  // an OpLabel instruction. The entry block takes the same parameters as the
611  // function. All other blocks do not take any parameter. We have already
612  // created the entry block, here we need to register it to the correct label
613  // <id>.
614  if (failed(sliceInstruction(opcode, instOperands,
615  spirv::Opcode::OpFunctionEnd))) {
616  return failure();
617  }
618  if (opcode == spirv::Opcode::OpFunctionEnd) {
619  return processFunctionEnd(instOperands);
620  }
621  if (opcode != spirv::Opcode::OpLabel) {
622  return emitError(unknownLoc, "a basic block must start with OpLabel");
623  }
624  if (instOperands.size() != 1) {
625  return emitError(unknownLoc, "OpLabel should only have result <id>");
626  }
627  blockMap[instOperands[0]] = entryBlock;
628  if (failed(processLabel(instOperands))) {
629  return failure();
630  }
631 
632  // Then process all the other instructions in the function until we hit
633  // OpFunctionEnd.
634  while (succeeded(sliceInstruction(opcode, instOperands,
635  spirv::Opcode::OpFunctionEnd)) &&
636  opcode != spirv::Opcode::OpFunctionEnd) {
637  if (failed(processInstruction(opcode, instOperands))) {
638  return failure();
639  }
640  }
641  if (opcode != spirv::Opcode::OpFunctionEnd) {
642  return failure();
643  }
644 
645  return processFunctionEnd(instOperands);
646 }
647 
648 LogicalResult
649 spirv::Deserializer::processFunctionEnd(ArrayRef<uint32_t> operands) {
650  // Process OpFunctionEnd.
651  if (!operands.empty()) {
652  return emitError(unknownLoc, "unexpected operands for OpFunctionEnd");
653  }
654 
655  // Wire up block arguments from OpPhi instructions.
656  // Put all structured control flow in spirv.mlir.selection/spirv.mlir.loop
657  // ops.
658  if (failed(wireUpBlockArgument()) || failed(structurizeControlFlow())) {
659  return failure();
660  }
661 
662  curBlock = nullptr;
663  curFunction = std::nullopt;
664 
665  LLVM_DEBUG({
666  logger.unindent();
667  logger.startLine()
668  << "//===-------------------------------------------===//\n";
669  });
670  return success();
671 }
672 
673 std::optional<std::pair<Attribute, Type>>
674 spirv::Deserializer::getConstant(uint32_t id) {
675  auto constIt = constantMap.find(id);
676  if (constIt == constantMap.end())
677  return std::nullopt;
678  return constIt->getSecond();
679 }
680 
681 std::optional<spirv::SpecConstOperationMaterializationInfo>
682 spirv::Deserializer::getSpecConstantOperation(uint32_t id) {
683  auto constIt = specConstOperationMap.find(id);
684  if (constIt == specConstOperationMap.end())
685  return std::nullopt;
686  return constIt->getSecond();
687 }
688 
689 std::string spirv::Deserializer::getFunctionSymbol(uint32_t id) {
690  auto funcName = nameMap.lookup(id).str();
691  if (funcName.empty()) {
692  funcName = "spirv_fn_" + std::to_string(id);
693  }
694  return funcName;
695 }
696 
697 std::string spirv::Deserializer::getSpecConstantSymbol(uint32_t id) {
698  auto constName = nameMap.lookup(id).str();
699  if (constName.empty()) {
700  constName = "spirv_spec_const_" + std::to_string(id);
701  }
702  return constName;
703 }
704 
705 spirv::SpecConstantOp
706 spirv::Deserializer::createSpecConstant(Location loc, uint32_t resultID,
707  TypedAttr defaultValue) {
708  auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
709  auto op = opBuilder.create<spirv::SpecConstantOp>(unknownLoc, symName,
710  defaultValue);
711  if (decorations.count(resultID)) {
712  for (auto attr : decorations[resultID].getAttrs())
713  op->setAttr(attr.getName(), attr.getValue());
714  }
715  specConstMap[resultID] = op;
716  return op;
717 }
718 
719 LogicalResult
720 spirv::Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
721  unsigned wordIndex = 0;
722  if (operands.size() < 3) {
723  return emitError(
724  unknownLoc,
725  "OpVariable needs at least 3 operands, type, <id> and storage class");
726  }
727 
728  // Result Type.
729  auto type = getType(operands[wordIndex]);
730  if (!type) {
731  return emitError(unknownLoc, "unknown result type <id> : ")
732  << operands[wordIndex];
733  }
734  auto ptrType = dyn_cast<spirv::PointerType>(type);
735  if (!ptrType) {
736  return emitError(unknownLoc,
737  "expected a result type <id> to be a spirv.ptr, found : ")
738  << type;
739  }
740  wordIndex++;
741 
742  // Result <id>.
743  auto variableID = operands[wordIndex];
744  auto variableName = nameMap.lookup(variableID).str();
745  if (variableName.empty()) {
746  variableName = "spirv_var_" + std::to_string(variableID);
747  }
748  wordIndex++;
749 
750  // Storage class.
751  auto storageClass = static_cast<spirv::StorageClass>(operands[wordIndex]);
752  if (ptrType.getStorageClass() != storageClass) {
753  return emitError(unknownLoc, "mismatch in storage class of pointer type ")
754  << type << " and that specified in OpVariable instruction : "
755  << stringifyStorageClass(storageClass);
756  }
757  wordIndex++;
758 
759  // Initializer.
760  FlatSymbolRefAttr initializer = nullptr;
761 
762  if (wordIndex < operands.size()) {
763  Operation *op = nullptr;
764 
765  if (auto initOp = getGlobalVariable(operands[wordIndex]))
766  op = initOp;
767  else if (auto initOp = getSpecConstant(operands[wordIndex]))
768  op = initOp;
769  else if (auto initOp = getSpecConstantComposite(operands[wordIndex]))
770  op = initOp;
771  else
772  return emitError(unknownLoc, "unknown <id> ")
773  << operands[wordIndex] << "used as initializer";
774 
775  initializer = SymbolRefAttr::get(op);
776  wordIndex++;
777  }
778  if (wordIndex != operands.size()) {
779  return emitError(unknownLoc,
780  "found more operands than expected when deserializing "
781  "OpVariable instruction, only ")
782  << wordIndex << " of " << operands.size() << " processed";
783  }
784  auto loc = createFileLineColLoc(opBuilder);
785  auto varOp = opBuilder.create<spirv::GlobalVariableOp>(
786  loc, TypeAttr::get(type), opBuilder.getStringAttr(variableName),
787  initializer);
788 
789  // Decorations.
790  if (decorations.count(variableID)) {
791  for (auto attr : decorations[variableID].getAttrs())
792  varOp->setAttr(attr.getName(), attr.getValue());
793  }
794  globalVariableMap[variableID] = varOp;
795  return success();
796 }
797 
798 IntegerAttr spirv::Deserializer::getConstantInt(uint32_t id) {
799  auto constInfo = getConstant(id);
800  if (!constInfo) {
801  return nullptr;
802  }
803  return dyn_cast<IntegerAttr>(constInfo->first);
804 }
805 
806 LogicalResult spirv::Deserializer::processName(ArrayRef<uint32_t> operands) {
807  if (operands.size() < 2) {
808  return emitError(unknownLoc, "OpName needs at least 2 operands");
809  }
810  if (!nameMap.lookup(operands[0]).empty()) {
811  return emitError(unknownLoc, "duplicate name found for result <id> ")
812  << operands[0];
813  }
814  unsigned wordIndex = 1;
815  StringRef name = decodeStringLiteral(operands, wordIndex);
816  if (wordIndex != operands.size()) {
817  return emitError(unknownLoc,
818  "unexpected trailing words in OpName instruction");
819  }
820  nameMap[operands[0]] = name;
821  return success();
822 }
823 
824 //===----------------------------------------------------------------------===//
825 // Type
826 //===----------------------------------------------------------------------===//
827 
828 LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
829  ArrayRef<uint32_t> operands) {
830  if (operands.empty()) {
831  return emitError(unknownLoc, "type instruction with opcode ")
832  << spirv::stringifyOpcode(opcode) << " needs at least one <id>";
833  }
834 
835  /// TODO: Types might be forward declared in some instructions and need to be
836  /// handled appropriately.
837  if (typeMap.count(operands[0])) {
838  return emitError(unknownLoc, "duplicate definition for result <id> ")
839  << operands[0];
840  }
841 
842  switch (opcode) {
843  case spirv::Opcode::OpTypeVoid:
844  if (operands.size() != 1)
845  return emitError(unknownLoc, "OpTypeVoid must have no parameters");
846  typeMap[operands[0]] = opBuilder.getNoneType();
847  break;
848  case spirv::Opcode::OpTypeBool:
849  if (operands.size() != 1)
850  return emitError(unknownLoc, "OpTypeBool must have no parameters");
851  typeMap[operands[0]] = opBuilder.getI1Type();
852  break;
853  case spirv::Opcode::OpTypeInt: {
854  if (operands.size() != 3)
855  return emitError(
856  unknownLoc, "OpTypeInt must have bitwidth and signedness parameters");
857 
858  // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics
859  // to preserve or validate.
860  // 0 indicates unsigned, or no signedness semantics
861  // 1 indicates signed semantics."
862  //
863  // So we cannot differentiate signless and unsigned integers; always use
864  // signless semantics for such cases.
865  auto sign = operands[2] == 1 ? IntegerType::SignednessSemantics::Signed
866  : IntegerType::SignednessSemantics::Signless;
867  typeMap[operands[0]] = IntegerType::get(context, operands[1], sign);
868  } break;
869  case spirv::Opcode::OpTypeFloat: {
870  if (operands.size() != 2 && operands.size() != 3)
871  return emitError(unknownLoc,
872  "OpTypeFloat expects either 2 operands (type, bitwidth) "
873  "or 3 operands (type, bitwidth, encoding), but got ")
874  << operands.size();
875  uint32_t bitWidth = operands[1];
876 
877  Type floatTy;
878  switch (bitWidth) {
879  case 16:
880  floatTy = opBuilder.getF16Type();
881  break;
882  case 32:
883  floatTy = opBuilder.getF32Type();
884  break;
885  case 64:
886  floatTy = opBuilder.getF64Type();
887  break;
888  default:
889  return emitError(unknownLoc, "unsupported OpTypeFloat bitwidth: ")
890  << bitWidth;
891  }
892 
893  if (operands.size() == 3) {
894  if (spirv::FPEncoding(operands[2]) != spirv::FPEncoding::BFloat16KHR)
895  return emitError(unknownLoc, "unsupported OpTypeFloat FP encoding: ")
896  << operands[2];
897  if (bitWidth != 16)
898  return emitError(unknownLoc,
899  "invalid OpTypeFloat bitwidth for bfloat16 encoding: ")
900  << bitWidth << " (expected 16)";
901  floatTy = opBuilder.getBF16Type();
902  }
903 
904  typeMap[operands[0]] = floatTy;
905  } break;
906  case spirv::Opcode::OpTypeVector: {
907  if (operands.size() != 3) {
908  return emitError(
909  unknownLoc,
910  "OpTypeVector must have element type and count parameters");
911  }
912  Type elementTy = getType(operands[1]);
913  if (!elementTy) {
914  return emitError(unknownLoc, "OpTypeVector references undefined <id> ")
915  << operands[1];
916  }
917  typeMap[operands[0]] = VectorType::get({operands[2]}, elementTy);
918  } break;
919  case spirv::Opcode::OpTypePointer: {
920  return processOpTypePointer(operands);
921  } break;
922  case spirv::Opcode::OpTypeArray:
923  return processArrayType(operands);
924  case spirv::Opcode::OpTypeCooperativeMatrixKHR:
925  return processCooperativeMatrixTypeKHR(operands);
926  case spirv::Opcode::OpTypeFunction:
927  return processFunctionType(operands);
928  case spirv::Opcode::OpTypeImage:
929  return processImageType(operands);
930  case spirv::Opcode::OpTypeSampledImage:
931  return processSampledImageType(operands);
932  case spirv::Opcode::OpTypeRuntimeArray:
933  return processRuntimeArrayType(operands);
934  case spirv::Opcode::OpTypeStruct:
935  return processStructType(operands);
936  case spirv::Opcode::OpTypeMatrix:
937  return processMatrixType(operands);
938  case spirv::Opcode::OpTypeTensorARM:
939  return processTensorARMType(operands);
940  default:
941  return emitError(unknownLoc, "unhandled type instruction");
942  }
943  return success();
944 }
945 
946 LogicalResult
947 spirv::Deserializer::processOpTypePointer(ArrayRef<uint32_t> operands) {
948  if (operands.size() != 3)
949  return emitError(unknownLoc, "OpTypePointer must have two parameters");
950 
951  auto pointeeType = getType(operands[2]);
952  if (!pointeeType)
953  return emitError(unknownLoc, "unknown OpTypePointer pointee type <id> ")
954  << operands[2];
955 
956  uint32_t typePointerID = operands[0];
957  auto storageClass = static_cast<spirv::StorageClass>(operands[1]);
958  typeMap[typePointerID] = spirv::PointerType::get(pointeeType, storageClass);
959 
960  for (auto *deferredStructIt = std::begin(deferredStructTypesInfos);
961  deferredStructIt != std::end(deferredStructTypesInfos);) {
962  for (auto *unresolvedMemberIt =
963  std::begin(deferredStructIt->unresolvedMemberTypes);
964  unresolvedMemberIt !=
965  std::end(deferredStructIt->unresolvedMemberTypes);) {
966  if (unresolvedMemberIt->first == typePointerID) {
967  // The newly constructed pointer type can resolve one of the
968  // deferred struct type members; update the memberTypes list and
969  // clean the unresolvedMemberTypes list accordingly.
970  deferredStructIt->memberTypes[unresolvedMemberIt->second] =
971  typeMap[typePointerID];
972  unresolvedMemberIt =
973  deferredStructIt->unresolvedMemberTypes.erase(unresolvedMemberIt);
974  } else {
975  ++unresolvedMemberIt;
976  }
977  }
978 
979  if (deferredStructIt->unresolvedMemberTypes.empty()) {
980  // All deferred struct type members are now resolved, set the struct body.
981  auto structType = deferredStructIt->deferredStructType;
982 
983  assert(structType && "expected a spirv::StructType");
984  assert(structType.isIdentified() && "expected an indentified struct");
985 
986  if (failed(structType.trySetBody(
987  deferredStructIt->memberTypes, deferredStructIt->offsetInfo,
988  deferredStructIt->memberDecorationsInfo)))
989  return failure();
990 
991  deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt);
992  } else {
993  ++deferredStructIt;
994  }
995  }
996 
997  return success();
998 }
999 
1000 LogicalResult
1001 spirv::Deserializer::processArrayType(ArrayRef<uint32_t> operands) {
1002  if (operands.size() != 3) {
1003  return emitError(unknownLoc,
1004  "OpTypeArray must have element type and count parameters");
1005  }
1006 
1007  Type elementTy = getType(operands[1]);
1008  if (!elementTy) {
1009  return emitError(unknownLoc, "OpTypeArray references undefined <id> ")
1010  << operands[1];
1011  }
1012 
1013  unsigned count = 0;
1014  // TODO: The count can also come frome a specialization constant.
1015  auto countInfo = getConstant(operands[2]);
1016  if (!countInfo) {
1017  return emitError(unknownLoc, "OpTypeArray count <id> ")
1018  << operands[2] << "can only come from normal constant right now";
1019  }
1020 
1021  if (auto intVal = dyn_cast<IntegerAttr>(countInfo->first)) {
1022  count = intVal.getValue().getZExtValue();
1023  } else {
1024  return emitError(unknownLoc, "OpTypeArray count must come from a "
1025  "scalar integer constant instruction");
1026  }
1027 
1028  typeMap[operands[0]] = spirv::ArrayType::get(
1029  elementTy, count, typeDecorations.lookup(operands[0]));
1030  return success();
1031 }
1032 
1033 LogicalResult
1034 spirv::Deserializer::processFunctionType(ArrayRef<uint32_t> operands) {
1035  assert(!operands.empty() && "No operands for processing function type");
1036  if (operands.size() == 1) {
1037  return emitError(unknownLoc, "missing return type for OpTypeFunction");
1038  }
1039  auto returnType = getType(operands[1]);
1040  if (!returnType) {
1041  return emitError(unknownLoc, "unknown return type in OpTypeFunction");
1042  }
1043  SmallVector<Type, 1> argTypes;
1044  for (size_t i = 2, e = operands.size(); i < e; ++i) {
1045  auto ty = getType(operands[i]);
1046  if (!ty) {
1047  return emitError(unknownLoc, "unknown argument type in OpTypeFunction");
1048  }
1049  argTypes.push_back(ty);
1050  }
1051  ArrayRef<Type> returnTypes;
1052  if (!isVoidType(returnType)) {
1053  returnTypes = llvm::ArrayRef(returnType);
1054  }
1055  typeMap[operands[0]] = FunctionType::get(context, argTypes, returnTypes);
1056  return success();
1057 }
1058 
1059 LogicalResult spirv::Deserializer::processCooperativeMatrixTypeKHR(
1060  ArrayRef<uint32_t> operands) {
1061  if (operands.size() != 6) {
1062  return emitError(unknownLoc,
1063  "OpTypeCooperativeMatrixKHR must have element type, "
1064  "scope, row and column parameters, and use");
1065  }
1066 
1067  Type elementTy = getType(operands[1]);
1068  if (!elementTy) {
1069  return emitError(unknownLoc,
1070  "OpTypeCooperativeMatrixKHR references undefined <id> ")
1071  << operands[1];
1072  }
1073 
1074  std::optional<spirv::Scope> scope =
1075  spirv::symbolizeScope(getConstantInt(operands[2]).getInt());
1076  if (!scope) {
1077  return emitError(
1078  unknownLoc,
1079  "OpTypeCooperativeMatrixKHR references undefined scope <id> ")
1080  << operands[2];
1081  }
1082 
1083  IntegerAttr rowsAttr = getConstantInt(operands[3]);
1084  IntegerAttr columnsAttr = getConstantInt(operands[4]);
1085  IntegerAttr useAttr = getConstantInt(operands[5]);
1086 
1087  if (!rowsAttr)
1088  return emitError(unknownLoc, "OpTypeCooperativeMatrixKHR `Rows` references "
1089  "undefined constant <id> ")
1090  << operands[3];
1091 
1092  if (!columnsAttr)
1093  return emitError(unknownLoc, "OpTypeCooperativeMatrixKHR `Columns` "
1094  "references undefined constant <id> ")
1095  << operands[4];
1096 
1097  if (!useAttr)
1098  return emitError(unknownLoc, "OpTypeCooperativeMatrixKHR `Use` references "
1099  "undefined constant <id> ")
1100  << operands[5];
1101 
1102  unsigned rows = rowsAttr.getInt();
1103  unsigned columns = columnsAttr.getInt();
1104 
1105  std::optional<spirv::CooperativeMatrixUseKHR> use =
1106  spirv::symbolizeCooperativeMatrixUseKHR(useAttr.getInt());
1107  if (!use) {
1108  return emitError(
1109  unknownLoc,
1110  "OpTypeCooperativeMatrixKHR references undefined use <id> ")
1111  << operands[5];
1112  }
1113 
1114  typeMap[operands[0]] =
1115  spirv::CooperativeMatrixType::get(elementTy, rows, columns, *scope, *use);
1116  return success();
1117 }
1118 
1119 LogicalResult
1120 spirv::Deserializer::processRuntimeArrayType(ArrayRef<uint32_t> operands) {
1121  if (operands.size() != 2) {
1122  return emitError(unknownLoc, "OpTypeRuntimeArray must have two operands");
1123  }
1124  Type memberType = getType(operands[1]);
1125  if (!memberType) {
1126  return emitError(unknownLoc,
1127  "OpTypeRuntimeArray references undefined <id> ")
1128  << operands[1];
1129  }
1130  typeMap[operands[0]] = spirv::RuntimeArrayType::get(
1131  memberType, typeDecorations.lookup(operands[0]));
1132  return success();
1133 }
1134 
1135 LogicalResult
1136 spirv::Deserializer::processStructType(ArrayRef<uint32_t> operands) {
1137  // TODO: Find a way to handle identified structs when debug info is stripped.
1138 
1139  if (operands.empty()) {
1140  return emitError(unknownLoc, "OpTypeStruct must have at least result <id>");
1141  }
1142 
1143  if (operands.size() == 1) {
1144  // Handle empty struct.
1145  typeMap[operands[0]] =
1146  spirv::StructType::getEmpty(context, nameMap.lookup(operands[0]).str());
1147  return success();
1148  }
1149 
1150  // First element is operand ID, second element is member index in the struct.
1151  SmallVector<std::pair<uint32_t, unsigned>, 0> unresolvedMemberTypes;
1152  SmallVector<Type, 4> memberTypes;
1153 
1154  for (auto op : llvm::drop_begin(operands, 1)) {
1155  Type memberType = getType(op);
1156  bool typeForwardPtr = (typeForwardPointerIDs.count(op) != 0);
1157 
1158  if (!memberType && !typeForwardPtr)
1159  return emitError(unknownLoc, "OpTypeStruct references undefined <id> ")
1160  << op;
1161 
1162  if (!memberType)
1163  unresolvedMemberTypes.emplace_back(op, memberTypes.size());
1164 
1165  memberTypes.push_back(memberType);
1166  }
1167 
1170  if (memberDecorationMap.count(operands[0])) {
1171  auto &allMemberDecorations = memberDecorationMap[operands[0]];
1172  for (auto memberIndex : llvm::seq<uint32_t>(0, memberTypes.size())) {
1173  if (allMemberDecorations.count(memberIndex)) {
1174  for (auto &memberDecoration : allMemberDecorations[memberIndex]) {
1175  // Check for offset.
1176  if (memberDecoration.first == spirv::Decoration::Offset) {
1177  // If offset info is empty, resize to the number of members;
1178  if (offsetInfo.empty()) {
1179  offsetInfo.resize(memberTypes.size());
1180  }
1181  offsetInfo[memberIndex] = memberDecoration.second[0];
1182  } else {
1183  if (!memberDecoration.second.empty()) {
1184  memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/1,
1185  memberDecoration.first,
1186  memberDecoration.second[0]);
1187  } else {
1188  memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/0,
1189  memberDecoration.first, 0);
1190  }
1191  }
1192  }
1193  }
1194  }
1195  }
1196 
1197  uint32_t structID = operands[0];
1198  std::string structIdentifier = nameMap.lookup(structID).str();
1199 
1200  if (structIdentifier.empty()) {
1201  assert(unresolvedMemberTypes.empty() &&
1202  "didn't expect unresolved member types");
1203  typeMap[structID] =
1204  spirv::StructType::get(memberTypes, offsetInfo, memberDecorationsInfo);
1205  } else {
1206  auto structTy = spirv::StructType::getIdentified(context, structIdentifier);
1207  typeMap[structID] = structTy;
1208 
1209  if (!unresolvedMemberTypes.empty())
1210  deferredStructTypesInfos.push_back({structTy, unresolvedMemberTypes,
1211  memberTypes, offsetInfo,
1212  memberDecorationsInfo});
1213  else if (failed(structTy.trySetBody(memberTypes, offsetInfo,
1214  memberDecorationsInfo)))
1215  return failure();
1216  }
1217 
1218  // TODO: Update StructType to have member name as attribute as
1219  // well.
1220  return success();
1221 }
1222 
1223 LogicalResult
1224 spirv::Deserializer::processMatrixType(ArrayRef<uint32_t> operands) {
1225  if (operands.size() != 3) {
1226  // Three operands are needed: result_id, column_type, and column_count
1227  return emitError(unknownLoc, "OpTypeMatrix must have 3 operands"
1228  " (result_id, column_type, and column_count)");
1229  }
1230  // Matrix columns must be of vector type
1231  Type elementTy = getType(operands[1]);
1232  if (!elementTy) {
1233  return emitError(unknownLoc,
1234  "OpTypeMatrix references undefined column type.")
1235  << operands[1];
1236  }
1237 
1238  uint32_t colsCount = operands[2];
1239  typeMap[operands[0]] = spirv::MatrixType::get(elementTy, colsCount);
1240  return success();
1241 }
1242 
1243 LogicalResult
1244 spirv::Deserializer::processTensorARMType(ArrayRef<uint32_t> operands) {
1245  unsigned size = operands.size();
1246  if (size < 2 || size > 4)
1247  return emitError(unknownLoc, "OpTypeTensorARM must have 2-4 operands "
1248  "(result_id, element_type, (rank), (shape)) ")
1249  << size;
1250 
1251  Type elementTy = getType(operands[1]);
1252  if (!elementTy)
1253  return emitError(unknownLoc,
1254  "OpTypeTensorARM references undefined element type ")
1255  << operands[1];
1256 
1257  if (size == 2) {
1258  typeMap[operands[0]] = TensorArmType::get({}, elementTy);
1259  return success();
1260  }
1261 
1262  IntegerAttr rankAttr = getConstantInt(operands[2]);
1263  if (!rankAttr)
1264  return emitError(unknownLoc, "OpTypeTensorARM rank must come from a "
1265  "scalar integer constant instruction");
1266  unsigned rank = rankAttr.getValue().getZExtValue();
1267  if (size == 3) {
1268  SmallVector<int64_t, 4> shape(rank, ShapedType::kDynamic);
1269  typeMap[operands[0]] = TensorArmType::get(shape, elementTy);
1270  return success();
1271  }
1272 
1273  std::optional<std::pair<Attribute, Type>> shapeInfo =
1274  getConstant(operands[3]);
1275  if (!shapeInfo)
1276  return emitError(unknownLoc, "OpTypeTensorARM shape must come from a "
1277  "constant instruction of type OpTypeArray");
1278 
1279  ArrayAttr shapeArrayAttr = llvm::dyn_cast<ArrayAttr>(shapeInfo->first);
1281  for (auto dimAttr : shapeArrayAttr.getValue()) {
1282  auto dimIntAttr = llvm::dyn_cast<IntegerAttr>(dimAttr);
1283  if (!dimIntAttr)
1284  return emitError(unknownLoc, "OpTypeTensorARM shape has an invalid "
1285  "dimension size");
1286  shape.push_back(dimIntAttr.getValue().getSExtValue());
1287  }
1288  typeMap[operands[0]] = TensorArmType::get(shape, elementTy);
1289  return success();
1290 }
1291 
1292 LogicalResult
1293 spirv::Deserializer::processTypeForwardPointer(ArrayRef<uint32_t> operands) {
1294  if (operands.size() != 2)
1295  return emitError(unknownLoc,
1296  "OpTypeForwardPointer instruction must have two operands");
1297 
1298  typeForwardPointerIDs.insert(operands[0]);
1299  // TODO: Use the 2nd operand (Storage Class) to validate the OpTypePointer
1300  // instruction that defines the actual type.
1301 
1302  return success();
1303 }
1304 
1305 LogicalResult
1306 spirv::Deserializer::processImageType(ArrayRef<uint32_t> operands) {
1307  // TODO: Add support for Access Qualifier.
1308  if (operands.size() != 8)
1309  return emitError(
1310  unknownLoc,
1311  "OpTypeImage with non-eight operands are not supported yet");
1312 
1313  Type elementTy = getType(operands[1]);
1314  if (!elementTy)
1315  return emitError(unknownLoc, "OpTypeImage references undefined <id>: ")
1316  << operands[1];
1317 
1318  auto dim = spirv::symbolizeDim(operands[2]);
1319  if (!dim)
1320  return emitError(unknownLoc, "unknown Dim for OpTypeImage: ")
1321  << operands[2];
1322 
1323  auto depthInfo = spirv::symbolizeImageDepthInfo(operands[3]);
1324  if (!depthInfo)
1325  return emitError(unknownLoc, "unknown Depth for OpTypeImage: ")
1326  << operands[3];
1327 
1328  auto arrayedInfo = spirv::symbolizeImageArrayedInfo(operands[4]);
1329  if (!arrayedInfo)
1330  return emitError(unknownLoc, "unknown Arrayed for OpTypeImage: ")
1331  << operands[4];
1332 
1333  auto samplingInfo = spirv::symbolizeImageSamplingInfo(operands[5]);
1334  if (!samplingInfo)
1335  return emitError(unknownLoc, "unknown MS for OpTypeImage: ") << operands[5];
1336 
1337  auto samplerUseInfo = spirv::symbolizeImageSamplerUseInfo(operands[6]);
1338  if (!samplerUseInfo)
1339  return emitError(unknownLoc, "unknown Sampled for OpTypeImage: ")
1340  << operands[6];
1341 
1342  auto format = spirv::symbolizeImageFormat(operands[7]);
1343  if (!format)
1344  return emitError(unknownLoc, "unknown Format for OpTypeImage: ")
1345  << operands[7];
1346 
1347  typeMap[operands[0]] = spirv::ImageType::get(
1348  elementTy, dim.value(), depthInfo.value(), arrayedInfo.value(),
1349  samplingInfo.value(), samplerUseInfo.value(), format.value());
1350  return success();
1351 }
1352 
1353 LogicalResult
1354 spirv::Deserializer::processSampledImageType(ArrayRef<uint32_t> operands) {
1355  if (operands.size() != 2)
1356  return emitError(unknownLoc, "OpTypeSampledImage must have two operands");
1357 
1358  Type elementTy = getType(operands[1]);
1359  if (!elementTy)
1360  return emitError(unknownLoc,
1361  "OpTypeSampledImage references undefined <id>: ")
1362  << operands[1];
1363 
1364  typeMap[operands[0]] = spirv::SampledImageType::get(elementTy);
1365  return success();
1366 }
1367 
1368 //===----------------------------------------------------------------------===//
1369 // Constant
1370 //===----------------------------------------------------------------------===//
1371 
1372 LogicalResult spirv::Deserializer::processConstant(ArrayRef<uint32_t> operands,
1373  bool isSpec) {
1374  StringRef opname = isSpec ? "OpSpecConstant" : "OpConstant";
1375 
1376  if (operands.size() < 2) {
1377  return emitError(unknownLoc)
1378  << opname << " must have type <id> and result <id>";
1379  }
1380  if (operands.size() < 3) {
1381  return emitError(unknownLoc)
1382  << opname << " must have at least 1 more parameter";
1383  }
1384 
1385  Type resultType = getType(operands[0]);
1386  if (!resultType) {
1387  return emitError(unknownLoc, "undefined result type from <id> ")
1388  << operands[0];
1389  }
1390 
1391  auto checkOperandSizeForBitwidth = [&](unsigned bitwidth) -> LogicalResult {
1392  if (bitwidth == 64) {
1393  if (operands.size() == 4) {
1394  return success();
1395  }
1396  return emitError(unknownLoc)
1397  << opname << " should have 2 parameters for 64-bit values";
1398  }
1399  if (bitwidth <= 32) {
1400  if (operands.size() == 3) {
1401  return success();
1402  }
1403 
1404  return emitError(unknownLoc)
1405  << opname
1406  << " should have 1 parameter for values with no more than 32 bits";
1407  }
1408  return emitError(unknownLoc, "unsupported OpConstant bitwidth: ")
1409  << bitwidth;
1410  };
1411 
1412  auto resultID = operands[1];
1413 
1414  if (auto intType = dyn_cast<IntegerType>(resultType)) {
1415  auto bitwidth = intType.getWidth();
1416  if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1417  return failure();
1418  }
1419 
1420  APInt value;
1421  if (bitwidth == 64) {
1422  // 64-bit integers are represented with two SPIR-V words. According to
1423  // SPIR-V spec: "When the type’s bit width is larger than one word, the
1424  // literal’s low-order words appear first."
1425  struct DoubleWord {
1426  uint32_t word1;
1427  uint32_t word2;
1428  } words = {operands[2], operands[3]};
1429  value = APInt(64, llvm::bit_cast<uint64_t>(words), /*isSigned=*/true);
1430  } else if (bitwidth <= 32) {
1431  value = APInt(bitwidth, operands[2], /*isSigned=*/true,
1432  /*implicitTrunc=*/true);
1433  }
1434 
1435  auto attr = opBuilder.getIntegerAttr(intType, value);
1436 
1437  if (isSpec) {
1438  createSpecConstant(unknownLoc, resultID, attr);
1439  } else {
1440  // For normal constants, we just record the attribute (and its type) for
1441  // later materialization at use sites.
1442  constantMap.try_emplace(resultID, attr, intType);
1443  }
1444 
1445  return success();
1446  }
1447 
1448  if (auto floatType = dyn_cast<FloatType>(resultType)) {
1449  auto bitwidth = floatType.getWidth();
1450  if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1451  return failure();
1452  }
1453 
1454  APFloat value(0.f);
1455  if (floatType.isF64()) {
1456  // Double values are represented with two SPIR-V words. According to
1457  // SPIR-V spec: "When the type’s bit width is larger than one word, the
1458  // literal’s low-order words appear first."
1459  struct DoubleWord {
1460  uint32_t word1;
1461  uint32_t word2;
1462  } words = {operands[2], operands[3]};
1463  value = APFloat(llvm::bit_cast<double>(words));
1464  } else if (floatType.isF32()) {
1465  value = APFloat(llvm::bit_cast<float>(operands[2]));
1466  } else if (floatType.isF16()) {
1467  APInt data(16, operands[2]);
1468  value = APFloat(APFloat::IEEEhalf(), data);
1469  } else if (floatType.isBF16()) {
1470  APInt data(16, operands[2]);
1471  value = APFloat(APFloat::BFloat(), data);
1472  }
1473 
1474  auto attr = opBuilder.getFloatAttr(floatType, value);
1475  if (isSpec) {
1476  createSpecConstant(unknownLoc, resultID, attr);
1477  } else {
1478  // For normal constants, we just record the attribute (and its type) for
1479  // later materialization at use sites.
1480  constantMap.try_emplace(resultID, attr, floatType);
1481  }
1482 
1483  return success();
1484  }
1485 
1486  return emitError(unknownLoc, "OpConstant can only generate values of "
1487  "scalar integer or floating-point type");
1488 }
1489 
1490 LogicalResult spirv::Deserializer::processConstantBool(
1491  bool isTrue, ArrayRef<uint32_t> operands, bool isSpec) {
1492  if (operands.size() != 2) {
1493  return emitError(unknownLoc, "Op")
1494  << (isSpec ? "Spec" : "") << "Constant"
1495  << (isTrue ? "True" : "False")
1496  << " must have type <id> and result <id>";
1497  }
1498 
1499  auto attr = opBuilder.getBoolAttr(isTrue);
1500  auto resultID = operands[1];
1501  if (isSpec) {
1502  createSpecConstant(unknownLoc, resultID, attr);
1503  } else {
1504  // For normal constants, we just record the attribute (and its type) for
1505  // later materialization at use sites.
1506  constantMap.try_emplace(resultID, attr, opBuilder.getI1Type());
1507  }
1508 
1509  return success();
1510 }
1511 
1512 LogicalResult
1513 spirv::Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
1514  if (operands.size() < 2) {
1515  return emitError(unknownLoc,
1516  "OpConstantComposite must have type <id> and result <id>");
1517  }
1518  if (operands.size() < 3) {
1519  return emitError(unknownLoc,
1520  "OpConstantComposite must have at least 1 parameter");
1521  }
1522 
1523  Type resultType = getType(operands[0]);
1524  if (!resultType) {
1525  return emitError(unknownLoc, "undefined result type from <id> ")
1526  << operands[0];
1527  }
1528 
1529  SmallVector<Attribute, 4> elements;
1530  elements.reserve(operands.size() - 2);
1531  for (unsigned i = 2, e = operands.size(); i < e; ++i) {
1532  auto elementInfo = getConstant(operands[i]);
1533  if (!elementInfo) {
1534  return emitError(unknownLoc, "OpConstantComposite component <id> ")
1535  << operands[i] << " must come from a normal constant";
1536  }
1537  elements.push_back(elementInfo->first);
1538  }
1539 
1540  auto resultID = operands[1];
1541  if (auto shapedType = dyn_cast<ShapedType>(resultType)) {
1542  auto attr = DenseElementsAttr::get(shapedType, elements);
1543  // For normal constants, we just record the attribute (and its type) for
1544  // later materialization at use sites.
1545  constantMap.try_emplace(resultID, attr, shapedType);
1546  } else if (auto arrayType = dyn_cast<spirv::ArrayType>(resultType)) {
1547  auto attr = opBuilder.getArrayAttr(elements);
1548  constantMap.try_emplace(resultID, attr, resultType);
1549  } else {
1550  return emitError(unknownLoc, "unsupported OpConstantComposite type: ")
1551  << resultType;
1552  }
1553 
1554  return success();
1555 }
1556 
1557 LogicalResult
1558 spirv::Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) {
1559  if (operands.size() < 2) {
1560  return emitError(unknownLoc,
1561  "OpConstantComposite must have type <id> and result <id>");
1562  }
1563  if (operands.size() < 3) {
1564  return emitError(unknownLoc,
1565  "OpConstantComposite must have at least 1 parameter");
1566  }
1567 
1568  Type resultType = getType(operands[0]);
1569  if (!resultType) {
1570  return emitError(unknownLoc, "undefined result type from <id> ")
1571  << operands[0];
1572  }
1573 
1574  auto resultID = operands[1];
1575  auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
1576 
1577  SmallVector<Attribute, 4> elements;
1578  elements.reserve(operands.size() - 2);
1579  for (unsigned i = 2, e = operands.size(); i < e; ++i) {
1580  auto elementInfo = getSpecConstant(operands[i]);
1581  elements.push_back(SymbolRefAttr::get(elementInfo));
1582  }
1583 
1584  auto op = opBuilder.create<spirv::SpecConstantCompositeOp>(
1585  unknownLoc, TypeAttr::get(resultType), symName,
1586  opBuilder.getArrayAttr(elements));
1587  specConstCompositeMap[resultID] = op;
1588 
1589  return success();
1590 }
1591 
1592 LogicalResult
1593 spirv::Deserializer::processSpecConstantOperation(ArrayRef<uint32_t> operands) {
1594  if (operands.size() < 3)
1595  return emitError(unknownLoc, "OpConstantOperation must have type <id>, "
1596  "result <id>, and operand opcode");
1597 
1598  uint32_t resultTypeID = operands[0];
1599 
1600  if (!getType(resultTypeID))
1601  return emitError(unknownLoc, "undefined result type from <id> ")
1602  << resultTypeID;
1603 
1604  uint32_t resultID = operands[1];
1605  spirv::Opcode enclosedOpcode = static_cast<spirv::Opcode>(operands[2]);
1606  auto emplaceResult = specConstOperationMap.try_emplace(
1607  resultID,
1608  SpecConstOperationMaterializationInfo{
1609  enclosedOpcode, resultTypeID,
1610  SmallVector<uint32_t>{operands.begin() + 3, operands.end()}});
1611 
1612  if (!emplaceResult.second)
1613  return emitError(unknownLoc, "value with <id>: ")
1614  << resultID << " is probably defined before.";
1615 
1616  return success();
1617 }
1618 
1619 Value spirv::Deserializer::materializeSpecConstantOperation(
1620  uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID,
1621  ArrayRef<uint32_t> enclosedOpOperands) {
1622 
1623  Type resultType = getType(resultTypeID);
1624 
1625  // Instructions wrapped by OpSpecConstantOp need an ID for their
1626  // Deserializer::processOp<op_name>(...) to emit the corresponding SPIR-V
1627  // dialect wrapped op. For that purpose, a new value map is created and "fake"
1628  // ID in that map is assigned to the result of the enclosed instruction. Note
1629  // that there is no need to update this fake ID since we only need to
1630  // reference the created Value for the enclosed op from the spv::YieldOp
1631  // created later in this method (both of which are the only values in their
1632  // region: the SpecConstantOperation's region). If we encounter another
1633  // SpecConstantOperation in the module, we simply re-use the fake ID since the
1634  // previous Value assigned to it isn't visible in the current scope anyway.
1635  DenseMap<uint32_t, Value> newValueMap;
1636  llvm::SaveAndRestore valueMapGuard(valueMap, newValueMap);
1637  constexpr uint32_t fakeID = static_cast<uint32_t>(-3);
1638 
1639  SmallVector<uint32_t, 4> enclosedOpResultTypeAndOperands;
1640  enclosedOpResultTypeAndOperands.push_back(resultTypeID);
1641  enclosedOpResultTypeAndOperands.push_back(fakeID);
1642  enclosedOpResultTypeAndOperands.append(enclosedOpOperands.begin(),
1643  enclosedOpOperands.end());
1644 
1645  // Process enclosed instruction before creating the enclosing
1646  // specConstantOperation (and its region). This way, references to constants,
1647  // global variables, and spec constants will be materialized outside the new
1648  // op's region. For more info, see Deserializer::getValue's implementation.
1649  if (failed(
1650  processInstruction(enclosedOpcode, enclosedOpResultTypeAndOperands)))
1651  return Value();
1652 
1653  // Since the enclosed op is emitted in the current block, split it in a
1654  // separate new block.
1655  Block *enclosedBlock = curBlock->splitBlock(&curBlock->back());
1656 
1657  auto loc = createFileLineColLoc(opBuilder);
1658  auto specConstOperationOp =
1659  opBuilder.create<spirv::SpecConstantOperationOp>(loc, resultType);
1660 
1661  Region &body = specConstOperationOp.getBody();
1662  // Move the new block into SpecConstantOperation's body.
1663  body.getBlocks().splice(body.end(), curBlock->getParent()->getBlocks(),
1664  Region::iterator(enclosedBlock));
1665  Block &block = body.back();
1666 
1667  // RAII guard to reset the insertion point to the module's region after
1668  // deserializing the body of the specConstantOperation.
1669  OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
1670  opBuilder.setInsertionPointToEnd(&block);
1671 
1672  opBuilder.create<spirv::YieldOp>(loc, block.front().getResult(0));
1673  return specConstOperationOp.getResult();
1674 }
1675 
1676 LogicalResult
1677 spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
1678  if (operands.size() != 2) {
1679  return emitError(unknownLoc,
1680  "OpConstantNull must have type <id> and result <id>");
1681  }
1682 
1683  Type resultType = getType(operands[0]);
1684  if (!resultType) {
1685  return emitError(unknownLoc, "undefined result type from <id> ")
1686  << operands[0];
1687  }
1688 
1689  auto resultID = operands[1];
1690  if (resultType.isIntOrFloat() || isa<VectorType>(resultType)) {
1691  auto attr = opBuilder.getZeroAttr(resultType);
1692  // For normal constants, we just record the attribute (and its type) for
1693  // later materialization at use sites.
1694  constantMap.try_emplace(resultID, attr, resultType);
1695  return success();
1696  }
1697 
1698  return emitError(unknownLoc, "unsupported OpConstantNull type: ")
1699  << resultType;
1700 }
1701 
1702 //===----------------------------------------------------------------------===//
1703 // Control flow
1704 //===----------------------------------------------------------------------===//
1705 
1706 Block *spirv::Deserializer::getOrCreateBlock(uint32_t id) {
1707  if (auto *block = getBlock(id)) {
1708  LLVM_DEBUG(logger.startLine() << "[block] got exiting block for id = " << id
1709  << " @ " << block << "\n");
1710  return block;
1711  }
1712 
1713  // We don't know where this block will be placed finally (in a
1714  // spirv.mlir.selection or spirv.mlir.loop or function). Create it into the
1715  // function for now and sort out the proper place later.
1716  auto *block = curFunction->addBlock();
1717  LLVM_DEBUG(logger.startLine() << "[block] created block for id = " << id
1718  << " @ " << block << "\n");
1719  return blockMap[id] = block;
1720 }
1721 
1722 LogicalResult spirv::Deserializer::processBranch(ArrayRef<uint32_t> operands) {
1723  if (!curBlock) {
1724  return emitError(unknownLoc, "OpBranch must appear inside a block");
1725  }
1726 
1727  if (operands.size() != 1) {
1728  return emitError(unknownLoc, "OpBranch must take exactly one target label");
1729  }
1730 
1731  auto *target = getOrCreateBlock(operands[0]);
1732  auto loc = createFileLineColLoc(opBuilder);
1733  // The preceding instruction for the OpBranch instruction could be an
1734  // OpLoopMerge or an OpSelectionMerge instruction, in this case they will have
1735  // the same OpLine information.
1736  opBuilder.create<spirv::BranchOp>(loc, target);
1737 
1738  clearDebugLine();
1739  return success();
1740 }
1741 
1742 LogicalResult
1743 spirv::Deserializer::processBranchConditional(ArrayRef<uint32_t> operands) {
1744  if (!curBlock) {
1745  return emitError(unknownLoc,
1746  "OpBranchConditional must appear inside a block");
1747  }
1748 
1749  if (operands.size() != 3 && operands.size() != 5) {
1750  return emitError(unknownLoc,
1751  "OpBranchConditional must have condition, true label, "
1752  "false label, and optionally two branch weights");
1753  }
1754 
1755  auto condition = getValue(operands[0]);
1756  auto *trueBlock = getOrCreateBlock(operands[1]);
1757  auto *falseBlock = getOrCreateBlock(operands[2]);
1758 
1759  std::optional<std::pair<uint32_t, uint32_t>> weights;
1760  if (operands.size() == 5) {
1761  weights = std::make_pair(operands[3], operands[4]);
1762  }
1763  // The preceding instruction for the OpBranchConditional instruction could be
1764  // an OpSelectionMerge instruction, in this case they will have the same
1765  // OpLine information.
1766  auto loc = createFileLineColLoc(opBuilder);
1767  opBuilder.create<spirv::BranchConditionalOp>(
1768  loc, condition, trueBlock,
1769  /*trueArguments=*/ArrayRef<Value>(), falseBlock,
1770  /*falseArguments=*/ArrayRef<Value>(), weights);
1771 
1772  clearDebugLine();
1773  return success();
1774 }
1775 
1776 LogicalResult spirv::Deserializer::processLabel(ArrayRef<uint32_t> operands) {
1777  if (!curFunction) {
1778  return emitError(unknownLoc, "OpLabel must appear inside a function");
1779  }
1780 
1781  if (operands.size() != 1) {
1782  return emitError(unknownLoc, "OpLabel should only have result <id>");
1783  }
1784 
1785  auto labelID = operands[0];
1786  // We may have forward declared this block.
1787  auto *block = getOrCreateBlock(labelID);
1788  LLVM_DEBUG(logger.startLine()
1789  << "[block] populating block " << block << "\n");
1790  // If we have seen this block, make sure it was just a forward declaration.
1791  assert(block->empty() && "re-deserialize the same block!");
1792 
1793  opBuilder.setInsertionPointToStart(block);
1794  blockMap[labelID] = curBlock = block;
1795 
1796  return success();
1797 }
1798 
1799 LogicalResult
1800 spirv::Deserializer::processSelectionMerge(ArrayRef<uint32_t> operands) {
1801  if (!curBlock) {
1802  return emitError(unknownLoc, "OpSelectionMerge must appear in a block");
1803  }
1804 
1805  if (operands.size() < 2) {
1806  return emitError(
1807  unknownLoc,
1808  "OpSelectionMerge must specify merge target and selection control");
1809  }
1810 
1811  auto *mergeBlock = getOrCreateBlock(operands[0]);
1812  auto loc = createFileLineColLoc(opBuilder);
1813  auto selectionControl = operands[1];
1814 
1815  if (!blockMergeInfo.try_emplace(curBlock, loc, selectionControl, mergeBlock)
1816  .second) {
1817  return emitError(
1818  unknownLoc,
1819  "a block cannot have more than one OpSelectionMerge instruction");
1820  }
1821 
1822  return success();
1823 }
1824 
1825 LogicalResult
1826 spirv::Deserializer::processLoopMerge(ArrayRef<uint32_t> operands) {
1827  if (!curBlock) {
1828  return emitError(unknownLoc, "OpLoopMerge must appear in a block");
1829  }
1830 
1831  if (operands.size() < 3) {
1832  return emitError(unknownLoc, "OpLoopMerge must specify merge target, "
1833  "continue target and loop control");
1834  }
1835 
1836  auto *mergeBlock = getOrCreateBlock(operands[0]);
1837  auto *continueBlock = getOrCreateBlock(operands[1]);
1838  auto loc = createFileLineColLoc(opBuilder);
1839  uint32_t loopControl = operands[2];
1840 
1841  if (!blockMergeInfo
1842  .try_emplace(curBlock, loc, loopControl, mergeBlock, continueBlock)
1843  .second) {
1844  return emitError(
1845  unknownLoc,
1846  "a block cannot have more than one OpLoopMerge instruction");
1847  }
1848 
1849  return success();
1850 }
1851 
1852 LogicalResult spirv::Deserializer::processPhi(ArrayRef<uint32_t> operands) {
1853  if (!curBlock) {
1854  return emitError(unknownLoc, "OpPhi must appear in a block");
1855  }
1856 
1857  if (operands.size() < 4) {
1858  return emitError(unknownLoc, "OpPhi must specify result type, result <id>, "
1859  "and variable-parent pairs");
1860  }
1861 
1862  // Create a block argument for this OpPhi instruction.
1863  Type blockArgType = getType(operands[0]);
1864  BlockArgument blockArg = curBlock->addArgument(blockArgType, unknownLoc);
1865  valueMap[operands[1]] = blockArg;
1866  LLVM_DEBUG(logger.startLine()
1867  << "[phi] created block argument " << blockArg
1868  << " id = " << operands[1] << " of type " << blockArgType << "\n");
1869 
1870  // For each (value, predecessor) pair, insert the value to the predecessor's
1871  // blockPhiInfo entry so later we can fix the block argument there.
1872  for (unsigned i = 2, e = operands.size(); i < e; i += 2) {
1873  uint32_t value = operands[i];
1874  Block *predecessor = getOrCreateBlock(operands[i + 1]);
1875  std::pair<Block *, Block *> predecessorTargetPair{predecessor, curBlock};
1876  blockPhiInfo[predecessorTargetPair].push_back(value);
1877  LLVM_DEBUG(logger.startLine() << "[phi] predecessor @ " << predecessor
1878  << " with arg id = " << value << "\n");
1879  }
1880 
1881  return success();
1882 }
1883 
1884 namespace {
1885 /// A class for putting all blocks in a structured selection/loop in a
1886 /// spirv.mlir.selection/spirv.mlir.loop op.
1887 class ControlFlowStructurizer {
1888 public:
1889 #ifndef NDEBUG
1890  ControlFlowStructurizer(Location loc, uint32_t control,
1891  spirv::BlockMergeInfoMap &mergeInfo, Block *header,
1892  Block *merge, Block *cont,
1893  llvm::ScopedPrinter &logger)
1894  : location(loc), control(control), blockMergeInfo(mergeInfo),
1895  headerBlock(header), mergeBlock(merge), continueBlock(cont),
1896  logger(logger) {}
1897 #else
1898  ControlFlowStructurizer(Location loc, uint32_t control,
1899  spirv::BlockMergeInfoMap &mergeInfo, Block *header,
1900  Block *merge, Block *cont)
1901  : location(loc), control(control), blockMergeInfo(mergeInfo),
1902  headerBlock(header), mergeBlock(merge), continueBlock(cont) {}
1903 #endif
1904 
1905  /// Structurizes the loop at the given `headerBlock`.
1906  ///
1907  /// This method will create an spirv.mlir.loop op in the `mergeBlock` and move
1908  /// all blocks in the structured loop into the spirv.mlir.loop's region. All
1909  /// branches to the `headerBlock` will be redirected to the `mergeBlock`. This
1910  /// method will also update `mergeInfo` by remapping all blocks inside to the
1911  /// newly cloned ones inside structured control flow op's regions.
1912  LogicalResult structurize();
1913 
1914 private:
1915  /// Creates a new spirv.mlir.selection op at the beginning of the
1916  /// `mergeBlock`.
1917  spirv::SelectionOp createSelectionOp(uint32_t selectionControl);
1918 
1919  /// Creates a new spirv.mlir.loop op at the beginning of the `mergeBlock`.
1920  spirv::LoopOp createLoopOp(uint32_t loopControl);
1921 
1922  /// Collects all blocks reachable from `headerBlock` except `mergeBlock`.
1923  void collectBlocksInConstruct();
1924 
1925  Location location;
1926  uint32_t control;
1927 
1928  spirv::BlockMergeInfoMap &blockMergeInfo;
1929 
1930  Block *headerBlock;
1931  Block *mergeBlock;
1932  Block *continueBlock; // nullptr for spirv.mlir.selection
1933 
1934  SetVector<Block *> constructBlocks;
1935 
1936 #ifndef NDEBUG
1937  /// A logger used to emit information during the deserialzation process.
1938  llvm::ScopedPrinter &logger;
1939 #endif
1940 };
1941 } // namespace
1942 
1943 spirv::SelectionOp
1944 ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) {
1945  // Create a builder and set the insertion point to the beginning of the
1946  // merge block so that the newly created SelectionOp will be inserted there.
1947  OpBuilder builder(&mergeBlock->front());
1948 
1949  auto control = static_cast<spirv::SelectionControl>(selectionControl);
1950  auto selectionOp = builder.create<spirv::SelectionOp>(location, control);
1951  selectionOp.addMergeBlock(builder);
1952 
1953  return selectionOp;
1954 }
1955 
1956 spirv::LoopOp ControlFlowStructurizer::createLoopOp(uint32_t loopControl) {
1957  // Create a builder and set the insertion point to the beginning of the
1958  // merge block so that the newly created LoopOp will be inserted there.
1959  OpBuilder builder(&mergeBlock->front());
1960 
1961  auto control = static_cast<spirv::LoopControl>(loopControl);
1962  auto loopOp = builder.create<spirv::LoopOp>(location, control);
1963  loopOp.addEntryAndMergeBlock(builder);
1964 
1965  return loopOp;
1966 }
1967 
1968 void ControlFlowStructurizer::collectBlocksInConstruct() {
1969  assert(constructBlocks.empty() && "expected empty constructBlocks");
1970 
1971  // Put the header block in the work list first.
1972  constructBlocks.insert(headerBlock);
1973 
1974  // For each item in the work list, add its successors excluding the merge
1975  // block.
1976  for (unsigned i = 0; i < constructBlocks.size(); ++i) {
1977  for (auto *successor : constructBlocks[i]->getSuccessors())
1978  if (successor != mergeBlock)
1979  constructBlocks.insert(successor);
1980  }
1981 }
1982 
1983 LogicalResult ControlFlowStructurizer::structurize() {
1984  Operation *op = nullptr;
1985  bool isLoop = continueBlock != nullptr;
1986  if (isLoop) {
1987  if (auto loopOp = createLoopOp(control))
1988  op = loopOp.getOperation();
1989  } else {
1990  if (auto selectionOp = createSelectionOp(control))
1991  op = selectionOp.getOperation();
1992  }
1993  if (!op)
1994  return failure();
1995  Region &body = op->getRegion(0);
1996 
1997  IRMapping mapper;
1998  // All references to the old merge block should be directed to the
1999  // selection/loop merge block in the SelectionOp/LoopOp's region.
2000  mapper.map(mergeBlock, &body.back());
2001 
2002  collectBlocksInConstruct();
2003 
2004  // We've identified all blocks belonging to the selection/loop's region. Now
2005  // need to "move" them into the selection/loop. Instead of really moving the
2006  // blocks, in the following we copy them and remap all values and branches.
2007  // This is because:
2008  // * Inserting a block into a region requires the block not in any region
2009  // before. But selections/loops can nest so we can create selection/loop ops
2010  // in a nested manner, which means some blocks may already be in a
2011  // selection/loop region when to be moved again.
2012  // * It's much trickier to fix up the branches into and out of the loop's
2013  // region: we need to treat not-moved blocks and moved blocks differently:
2014  // Not-moved blocks jumping to the loop header block need to jump to the
2015  // merge point containing the new loop op but not the loop continue block's
2016  // back edge. Moved blocks jumping out of the loop need to jump to the
2017  // merge block inside the loop region but not other not-moved blocks.
2018  // We cannot use replaceAllUsesWith clearly and it's harder to follow the
2019  // logic.
2020 
2021  // Create a corresponding block in the SelectionOp/LoopOp's region for each
2022  // block in this loop construct.
2023  OpBuilder builder(body);
2024  for (auto *block : constructBlocks) {
2025  // Create a block and insert it before the selection/loop merge block in the
2026  // SelectionOp/LoopOp's region.
2027  auto *newBlock = builder.createBlock(&body.back());
2028  mapper.map(block, newBlock);
2029  LLVM_DEBUG(logger.startLine() << "[cf] cloned block " << newBlock
2030  << " from block " << block << "\n");
2031  if (!isFnEntryBlock(block)) {
2032  for (BlockArgument blockArg : block->getArguments()) {
2033  auto newArg =
2034  newBlock->addArgument(blockArg.getType(), blockArg.getLoc());
2035  mapper.map(blockArg, newArg);
2036  LLVM_DEBUG(logger.startLine() << "[cf] remapped block argument "
2037  << blockArg << " to " << newArg << "\n");
2038  }
2039  } else {
2040  LLVM_DEBUG(logger.startLine()
2041  << "[cf] block " << block << " is a function entry block\n");
2042  }
2043 
2044  for (auto &op : *block)
2045  newBlock->push_back(op.clone(mapper));
2046  }
2047 
2048  // Go through all ops and remap the operands.
2049  auto remapOperands = [&](Operation *op) {
2050  for (auto &operand : op->getOpOperands())
2051  if (Value mappedOp = mapper.lookupOrNull(operand.get()))
2052  operand.set(mappedOp);
2053  for (auto &succOp : op->getBlockOperands())
2054  if (Block *mappedOp = mapper.lookupOrNull(succOp.get()))
2055  succOp.set(mappedOp);
2056  };
2057  for (auto &block : body)
2058  block.walk(remapOperands);
2059 
2060  // We have created the SelectionOp/LoopOp and "moved" all blocks belonging to
2061  // the selection/loop construct into its region. Next we need to fix the
2062  // connections between this new SelectionOp/LoopOp with existing blocks.
2063 
2064  // All existing incoming branches should go to the merge block, where the
2065  // SelectionOp/LoopOp resides right now.
2066  headerBlock->replaceAllUsesWith(mergeBlock);
2067 
2068  LLVM_DEBUG({
2069  logger.startLine() << "[cf] after cloning and fixing references:\n";
2070  headerBlock->getParentOp()->print(logger.getOStream());
2071  logger.startLine() << "\n";
2072  });
2073 
2074  if (isLoop) {
2075  if (!mergeBlock->args_empty()) {
2076  return mergeBlock->getParentOp()->emitError(
2077  "OpPhi in loop merge block unsupported");
2078  }
2079 
2080  // The loop header block may have block arguments. Since now we place the
2081  // loop op inside the old merge block, we need to make sure the old merge
2082  // block has the same block argument list.
2083  for (BlockArgument blockArg : headerBlock->getArguments())
2084  mergeBlock->addArgument(blockArg.getType(), blockArg.getLoc());
2085 
2086  // If the loop header block has block arguments, make sure the spirv.Branch
2087  // op matches.
2088  SmallVector<Value, 4> blockArgs;
2089  if (!headerBlock->args_empty())
2090  blockArgs = {mergeBlock->args_begin(), mergeBlock->args_end()};
2091 
2092  // The loop entry block should have a unconditional branch jumping to the
2093  // loop header block.
2094  builder.setInsertionPointToEnd(&body.front());
2095  builder.create<spirv::BranchOp>(location, mapper.lookupOrNull(headerBlock),
2096  ArrayRef<Value>(blockArgs));
2097  }
2098 
2099  // Values defined inside the selection region that need to be yielded outside
2100  // the region.
2101  SmallVector<Value> valuesToYield;
2102  // Outside uses of values that were sunk into the selection region. Those uses
2103  // will be replaced with values returned by the SelectionOp.
2104  SmallVector<Value> outsideUses;
2105 
2106  // Move block arguments of the original block (`mergeBlock`) into the merge
2107  // block inside the selection (`body.back()`). Values produced by block
2108  // arguments will be yielded by the selection region. We do not update uses or
2109  // erase original block arguments yet. It will be done later in the code.
2110  //
2111  // Code below is not executed for loops as it would interfere with the logic
2112  // above. Currently block arguments in the merge block are not supported, but
2113  // instead, the code above copies those arguments from the header block into
2114  // the merge block. As such, running the code would yield those copied
2115  // arguments that is most likely not a desired behaviour. This may need to be
2116  // revisited in the future.
2117  if (!isLoop)
2118  for (BlockArgument blockArg : mergeBlock->getArguments()) {
2119  // Create new block arguments in the last block ("merge block") of the
2120  // selection region. We create one argument for each argument in
2121  // `mergeBlock`. This new value will need to be yielded, and the original
2122  // value replaced, so add them to appropriate vectors.
2123  body.back().addArgument(blockArg.getType(), blockArg.getLoc());
2124  valuesToYield.push_back(body.back().getArguments().back());
2125  outsideUses.push_back(blockArg);
2126  }
2127 
2128  // All the blocks cloned into the SelectionOp/LoopOp's region can now be
2129  // cleaned up.
2130  LLVM_DEBUG(logger.startLine() << "[cf] cleaning up blocks after clone\n");
2131  // First we need to drop all operands' references inside all blocks. This is
2132  // needed because we can have blocks referencing SSA values from one another.
2133  for (auto *block : constructBlocks)
2134  block->dropAllReferences();
2135 
2136  // All internal uses should be removed from original blocks by now, so
2137  // whatever is left is an outside use and will need to be yielded from
2138  // the newly created selection / loop region.
2139  for (Block *block : constructBlocks) {
2140  for (Operation &op : *block) {
2141  if (!op.use_empty())
2142  for (Value result : op.getResults()) {
2143  valuesToYield.push_back(mapper.lookupOrNull(result));
2144  outsideUses.push_back(result);
2145  }
2146  }
2147  for (BlockArgument &arg : block->getArguments()) {
2148  if (!arg.use_empty()) {
2149  valuesToYield.push_back(mapper.lookupOrNull(arg));
2150  outsideUses.push_back(arg);
2151  }
2152  }
2153  }
2154 
2155  assert(valuesToYield.size() == outsideUses.size());
2156 
2157  // If we need to yield any values from the selection / loop region we will
2158  // take care of it here.
2159  if (!valuesToYield.empty()) {
2160  LLVM_DEBUG(logger.startLine()
2161  << "[cf] yielding values from the selection / loop region\n");
2162 
2163  // Update `mlir.merge` with values to be yield.
2164  auto mergeOps = body.back().getOps<spirv::MergeOp>();
2165  Operation *merge = llvm::getSingleElement(mergeOps);
2166  assert(merge);
2167  merge->setOperands(valuesToYield);
2168 
2169  // MLIR does not allow changing the number of results of an operation, so
2170  // we create a new SelectionOp / LoopOp with required list of results and
2171  // move the region from the initial SelectionOp / LoopOp. The initial
2172  // operation is then removed. Since we move the region to the new op all
2173  // links between blocks and remapping we have previously done should be
2174  // preserved.
2175  builder.setInsertionPoint(&mergeBlock->front());
2176 
2177  Operation *newOp = nullptr;
2178 
2179  if (isLoop)
2180  newOp = builder.create<spirv::LoopOp>(
2181  location, TypeRange(ValueRange(outsideUses)),
2182  static_cast<spirv::LoopControl>(control));
2183  else
2184  newOp = builder.create<spirv::SelectionOp>(
2185  location, TypeRange(ValueRange(outsideUses)),
2186  static_cast<spirv::SelectionControl>(control));
2187 
2188  newOp->getRegion(0).takeBody(body);
2189 
2190  // Remove initial op and swap the pointer to the newly created one.
2191  op->erase();
2192  op = newOp;
2193 
2194  // Update all outside uses to use results of the SelectionOp / LoopOp and
2195  // remove block arguments from the original merge block.
2196  for (unsigned i = 0, e = outsideUses.size(); i != e; ++i)
2197  outsideUses[i].replaceAllUsesWith(op->getResult(i));
2198 
2199  // We do not support block arguments in loop merge block. Also running this
2200  // function with loop would break some of the loop specific code above
2201  // dealing with block arguments.
2202  if (!isLoop)
2203  mergeBlock->eraseArguments(0, mergeBlock->getNumArguments());
2204  }
2205 
2206  // Check that whether some op in the to-be-erased blocks still has uses. Those
2207  // uses come from blocks that won't be sinked into the SelectionOp/LoopOp's
2208  // region. We cannot handle such cases given that once a value is sinked into
2209  // the SelectionOp/LoopOp's region, there is no escape for it.
2210  for (auto *block : constructBlocks) {
2211  for (Operation &op : *block)
2212  if (!op.use_empty())
2213  return op.emitOpError("failed control flow structurization: value has "
2214  "uses outside of the "
2215  "enclosing selection/loop construct");
2216  for (BlockArgument &arg : block->getArguments())
2217  if (!arg.use_empty())
2218  return emitError(arg.getLoc(), "failed control flow structurization: "
2219  "block argument has uses outside of the "
2220  "enclosing selection/loop construct");
2221  }
2222 
2223  // Then erase all old blocks.
2224  for (auto *block : constructBlocks) {
2225  // We've cloned all blocks belonging to this construct into the structured
2226  // control flow op's region. Among these blocks, some may compose another
2227  // selection/loop. If so, they will be recorded within blockMergeInfo.
2228  // We need to update the pointers there to the newly remapped ones so we can
2229  // continue structurizing them later.
2230  //
2231  // We need to walk each block as constructBlocks do not include blocks
2232  // internal to ops already structured within those blocks. It is not
2233  // fully clear to me why the mergeInfo of blocks (yet to be structured)
2234  // inside already structured selections/loops get invalidated and needs
2235  // updating, however the following example code can cause a crash (depending
2236  // on the structuring order), when the most inner selection is being
2237  // structured after the outer selection and loop have been already
2238  // structured:
2239  //
2240  // spirv.mlir.for {
2241  // // ...
2242  // spirv.mlir.selection {
2243  // // ..
2244  // // A selection region that hasn't been yet structured!
2245  // // ..
2246  // }
2247  // // ...
2248  // }
2249  //
2250  // If the loop gets structured after the outer selection, but before the
2251  // inner selection. Moving the already structured selection inside the loop
2252  // will invalidate the mergeInfo of the region that is not yet structured.
2253  // Just going over constructBlocks will not check and updated header blocks
2254  // inside the already structured selection region. Walking block fixes that.
2255  //
2256  // TODO: If structuring was done in a fixed order starting with inner
2257  // most constructs this most likely not be an issue and the whole code
2258  // section could be removed. However, with the current non-deterministic
2259  // order this is not possible.
2260  //
2261  // TODO: The asserts in the following assumes input SPIR-V blob forms
2262  // correctly nested selection/loop constructs. We should relax this and
2263  // support error cases better.
2264  auto updateMergeInfo = [&](Block *block) -> WalkResult {
2265  auto it = blockMergeInfo.find(block);
2266  if (it != blockMergeInfo.end()) {
2267  // Use the original location for nested selection/loop ops.
2268  Location loc = it->second.loc;
2269 
2270  Block *newHeader = mapper.lookupOrNull(block);
2271  if (!newHeader)
2272  return emitError(loc, "failed control flow structurization: nested "
2273  "loop header block should be remapped!");
2274 
2275  Block *newContinue = it->second.continueBlock;
2276  if (newContinue) {
2277  newContinue = mapper.lookupOrNull(newContinue);
2278  if (!newContinue)
2279  return emitError(loc, "failed control flow structurization: nested "
2280  "loop continue block should be remapped!");
2281  }
2282 
2283  Block *newMerge = it->second.mergeBlock;
2284  if (Block *mappedTo = mapper.lookupOrNull(newMerge))
2285  newMerge = mappedTo;
2286 
2287  // The iterator should be erased before adding a new entry into
2288  // blockMergeInfo to avoid iterator invalidation.
2289  blockMergeInfo.erase(it);
2290  blockMergeInfo.try_emplace(newHeader, loc, it->second.control, newMerge,
2291  newContinue);
2292  }
2293 
2294  return WalkResult::advance();
2295  };
2296 
2297  if (block->walk(updateMergeInfo).wasInterrupted())
2298  return failure();
2299 
2300  // The structured selection/loop's entry block does not have arguments.
2301  // If the function's header block is also part of the structured control
2302  // flow, we cannot just simply erase it because it may contain arguments
2303  // matching the function signature and used by the cloned blocks.
2304  if (isFnEntryBlock(block)) {
2305  LLVM_DEBUG(logger.startLine() << "[cf] changing entry block " << block
2306  << " to only contain a spirv.Branch op\n");
2307  // Still keep the function entry block for the potential block arguments,
2308  // but replace all ops inside with a branch to the merge block.
2309  block->clear();
2310  builder.setInsertionPointToEnd(block);
2311  builder.create<spirv::BranchOp>(location, mergeBlock);
2312  } else {
2313  LLVM_DEBUG(logger.startLine() << "[cf] erasing block " << block << "\n");
2314  block->erase();
2315  }
2316  }
2317 
2318  LLVM_DEBUG(logger.startLine()
2319  << "[cf] after structurizing construct with header block "
2320  << headerBlock << ":\n"
2321  << *op << "\n");
2322 
2323  return success();
2324 }
2325 
2326 LogicalResult spirv::Deserializer::wireUpBlockArgument() {
2327  LLVM_DEBUG({
2328  logger.startLine()
2329  << "//----- [phi] start wiring up block arguments -----//\n";
2330  logger.indent();
2331  });
2332 
2333  OpBuilder::InsertionGuard guard(opBuilder);
2334 
2335  for (const auto &info : blockPhiInfo) {
2336  Block *block = info.first.first;
2337  Block *target = info.first.second;
2338  const BlockPhiInfo &phiInfo = info.second;
2339  LLVM_DEBUG({
2340  logger.startLine() << "[phi] block " << block << "\n";
2341  logger.startLine() << "[phi] before creating block argument:\n";
2342  block->getParentOp()->print(logger.getOStream());
2343  logger.startLine() << "\n";
2344  });
2345 
2346  // Set insertion point to before this block's terminator early because we
2347  // may materialize ops via getValue() call.
2348  auto *op = block->getTerminator();
2349  opBuilder.setInsertionPoint(op);
2350 
2351  SmallVector<Value, 4> blockArgs;
2352  blockArgs.reserve(phiInfo.size());
2353  for (uint32_t valueId : phiInfo) {
2354  if (Value value = getValue(valueId)) {
2355  blockArgs.push_back(value);
2356  LLVM_DEBUG(logger.startLine() << "[phi] block argument " << value
2357  << " id = " << valueId << "\n");
2358  } else {
2359  return emitError(unknownLoc, "OpPhi references undefined value!");
2360  }
2361  }
2362 
2363  if (auto branchOp = dyn_cast<spirv::BranchOp>(op)) {
2364  // Replace the previous branch op with a new one with block arguments.
2365  opBuilder.create<spirv::BranchOp>(branchOp.getLoc(), branchOp.getTarget(),
2366  blockArgs);
2367  branchOp.erase();
2368  } else if (auto branchCondOp = dyn_cast<spirv::BranchConditionalOp>(op)) {
2369  assert((branchCondOp.getTrueBlock() == target ||
2370  branchCondOp.getFalseBlock() == target) &&
2371  "expected target to be either the true or false target");
2372  if (target == branchCondOp.getTrueTarget())
2373  opBuilder.create<spirv::BranchConditionalOp>(
2374  branchCondOp.getLoc(), branchCondOp.getCondition(), blockArgs,
2375  branchCondOp.getFalseBlockArguments(),
2376  branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueTarget(),
2377  branchCondOp.getFalseTarget());
2378  else
2379  opBuilder.create<spirv::BranchConditionalOp>(
2380  branchCondOp.getLoc(), branchCondOp.getCondition(),
2381  branchCondOp.getTrueBlockArguments(), blockArgs,
2382  branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueBlock(),
2383  branchCondOp.getFalseBlock());
2384 
2385  branchCondOp.erase();
2386  } else {
2387  return emitError(unknownLoc, "unimplemented terminator for Phi creation");
2388  }
2389 
2390  LLVM_DEBUG({
2391  logger.startLine() << "[phi] after creating block argument:\n";
2392  block->getParentOp()->print(logger.getOStream());
2393  logger.startLine() << "\n";
2394  });
2395  }
2396  blockPhiInfo.clear();
2397 
2398  LLVM_DEBUG({
2399  logger.unindent();
2400  logger.startLine()
2401  << "//--- [phi] completed wiring up block arguments ---//\n";
2402  });
2403  return success();
2404 }
2405 
2406 LogicalResult spirv::Deserializer::splitConditionalBlocks() {
2407  // Create a copy, so we can modify keys in the original.
2408  BlockMergeInfoMap blockMergeInfoCopy = blockMergeInfo;
2409  for (auto it = blockMergeInfoCopy.begin(), e = blockMergeInfoCopy.end();
2410  it != e; ++it) {
2411  auto &[block, mergeInfo] = *it;
2412 
2413  // Skip processing loop regions. For loop regions continueBlock is non-null.
2414  if (mergeInfo.continueBlock)
2415  continue;
2416 
2417  if (!block->mightHaveTerminator())
2418  continue;
2419 
2420  Operation *terminator = block->getTerminator();
2421  assert(terminator);
2422 
2423  if (!isa<spirv::BranchConditionalOp>(terminator))
2424  continue;
2425 
2426  // Check if the current header block is a merge block of another construct.
2427  bool splitHeaderMergeBlock = false;
2428  for (const auto &[_, mergeInfo] : blockMergeInfo) {
2429  if (mergeInfo.mergeBlock == block)
2430  splitHeaderMergeBlock = true;
2431  }
2432 
2433  // Do not split a block that only contains a conditional branch, unless it
2434  // is also a merge block of another construct - in that case we want to
2435  // split the block. We do not want two constructs to share header / merge
2436  // block.
2437  if (!llvm::hasSingleElement(*block) || splitHeaderMergeBlock) {
2438  Block *newBlock = block->splitBlock(terminator);
2439  OpBuilder builder(block, block->end());
2440  builder.create<spirv::BranchOp>(block->getParent()->getLoc(), newBlock);
2441 
2442  // After splitting we need to update the map to use the new block as a
2443  // header.
2444  blockMergeInfo.erase(block);
2445  blockMergeInfo.try_emplace(newBlock, mergeInfo);
2446  }
2447  }
2448 
2449  return success();
2450 }
2451 
2452 LogicalResult spirv::Deserializer::structurizeControlFlow() {
2453  if (!options.enableControlFlowStructurization) {
2454  LLVM_DEBUG(
2455  {
2456  logger.startLine()
2457  << "//----- [cf] skip structurizing control flow -----//\n";
2458  logger.indent();
2459  });
2460  return success();
2461  }
2462 
2463  LLVM_DEBUG({
2464  logger.startLine()
2465  << "//----- [cf] start structurizing control flow -----//\n";
2466  logger.indent();
2467  });
2468 
2469  LLVM_DEBUG({
2470  logger.startLine() << "[cf] split conditional blocks\n";
2471  logger.startLine() << "\n";
2472  });
2473 
2474  if (failed(splitConditionalBlocks())) {
2475  return failure();
2476  }
2477 
2478  // TODO: This loop is non-deterministic. Iteration order may vary between runs
2479  // for the same shader as the key to the map is a pointer. See:
2480  // https://github.com/llvm/llvm-project/issues/128547
2481  while (!blockMergeInfo.empty()) {
2482  Block *headerBlock = blockMergeInfo.begin()->first;
2483  BlockMergeInfo mergeInfo = blockMergeInfo.begin()->second;
2484 
2485  LLVM_DEBUG({
2486  logger.startLine() << "[cf] header block " << headerBlock << ":\n";
2487  headerBlock->print(logger.getOStream());
2488  logger.startLine() << "\n";
2489  });
2490 
2491  auto *mergeBlock = mergeInfo.mergeBlock;
2492  assert(mergeBlock && "merge block cannot be nullptr");
2493  if (mergeInfo.continueBlock && !mergeBlock->args_empty())
2494  return emitError(unknownLoc, "OpPhi in loop merge block unimplemented");
2495  LLVM_DEBUG({
2496  logger.startLine() << "[cf] merge block " << mergeBlock << ":\n";
2497  mergeBlock->print(logger.getOStream());
2498  logger.startLine() << "\n";
2499  });
2500 
2501  auto *continueBlock = mergeInfo.continueBlock;
2502  LLVM_DEBUG(if (continueBlock) {
2503  logger.startLine() << "[cf] continue block " << continueBlock << ":\n";
2504  continueBlock->print(logger.getOStream());
2505  logger.startLine() << "\n";
2506  });
2507  // Erase this case before calling into structurizer, who will update
2508  // blockMergeInfo.
2509  blockMergeInfo.erase(blockMergeInfo.begin());
2510  ControlFlowStructurizer structurizer(mergeInfo.loc, mergeInfo.control,
2511  blockMergeInfo, headerBlock,
2512  mergeBlock, continueBlock
2513 #ifndef NDEBUG
2514  ,
2515  logger
2516 #endif
2517  );
2518  if (failed(structurizer.structurize()))
2519  return failure();
2520  }
2521 
2522  LLVM_DEBUG({
2523  logger.unindent();
2524  logger.startLine()
2525  << "//--- [cf] completed structurizing control flow ---//\n";
2526  });
2527  return success();
2528 }
2529 
2530 //===----------------------------------------------------------------------===//
2531 // Debug
2532 //===----------------------------------------------------------------------===//
2533 
2534 Location spirv::Deserializer::createFileLineColLoc(OpBuilder opBuilder) {
2535  if (!debugLine)
2536  return unknownLoc;
2537 
2538  auto fileName = debugInfoMap.lookup(debugLine->fileID).str();
2539  if (fileName.empty())
2540  fileName = "<unknown>";
2541  return FileLineColLoc::get(opBuilder.getStringAttr(fileName), debugLine->line,
2542  debugLine->column);
2543 }
2544 
2545 LogicalResult
2546 spirv::Deserializer::processDebugLine(ArrayRef<uint32_t> operands) {
2547  // According to SPIR-V spec:
2548  // "This location information applies to the instructions physically
2549  // following this instruction, up to the first occurrence of any of the
2550  // following: the next end of block, the next OpLine instruction, or the next
2551  // OpNoLine instruction."
2552  if (operands.size() != 3)
2553  return emitError(unknownLoc, "OpLine must have 3 operands");
2554  debugLine = DebugLine{operands[0], operands[1], operands[2]};
2555  return success();
2556 }
2557 
2558 void spirv::Deserializer::clearDebugLine() { debugLine = std::nullopt; }
2559 
2560 LogicalResult
2561 spirv::Deserializer::processDebugString(ArrayRef<uint32_t> operands) {
2562  if (operands.size() < 2)
2563  return emitError(unknownLoc, "OpString needs at least 2 operands");
2564 
2565  if (!debugInfoMap.lookup(operands[0]).empty())
2566  return emitError(unknownLoc,
2567  "duplicate debug string found for result <id> ")
2568  << operands[0];
2569 
2570  unsigned wordIndex = 1;
2571  StringRef debugString = decodeStringLiteral(operands, wordIndex);
2572  if (wordIndex != operands.size())
2573  return emitError(unknownLoc,
2574  "unexpected trailing words in OpString instruction");
2575 
2576  debugInfoMap[operands[0]] = debugString;
2577  return success();
2578 }
static bool isLoop(Operation *op)
Returns true if the given operation represents a loop by testing whether it implements the LoopLikeOp...
static bool isFnEntryBlock(Block *block)
Returns true if the given block is a function entry block.
#define MIN_VERSION_CASE(v)
LogicalResult deserializeCacheControlDecoration(Location loc, OpBuilder &opBuilder, DenseMap< uint32_t, NamedAttrList > &decorations, ArrayRef< uint32_t > words, StringAttr symbol, StringRef decorationName, StringRef cacheControlKind)
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
int64_t rows
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:309
Location getLoc() const
Return the location for this argument.
Definition: Value.h:324
Block represents an ordered list of Operations.
Definition: Block.h:33
bool empty()
Definition: Block.h:148
Operation & back()
Definition: Block.h:152
void erase()
Unlink this Block from its parent region and delete it.
Definition: Block.cpp:66
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Definition: Block.cpp:308
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:27
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:244
void print(raw_ostream &os)
bool mightHaveTerminator()
Check whether this block might have a terminator.
Definition: Block.cpp:250
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
iterator end()
Definition: Block.h:144
iterator begin()
Definition: Block.h:143
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Definition: Block.cpp:36
void push_back(Operation *op)
Definition: Block.h:149
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:31
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:257
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:261
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition: Builders.h:96
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static FileLineColLoc get(StringAttr filename, unsigned line, unsigned column)
Definition: Location.cpp:161
A symbol reference with a reference path containing a single element.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
auto lookupOrNull(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:58
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:164
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool use_empty()
Returns true if this operation has no uses.
Definition: Operation.h:852
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Definition: Operation.cpp:718
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Definition: Operation.cpp:66
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:686
MutableArrayRef< BlockOperand > getBlockOperands()
Definition: Operation.h:695
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:383
result_range getResults()
Definition: Operation.h:415
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:672
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:538
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Block & back()
Definition: Region.h:64
BlockListType & getBlocks()
Definition: Region.h:45
Location getLoc()
Return a location for this region.
Definition: Region.cpp:31
BlockListType::iterator iterator
Definition: Region.h:52
void takeBody(Region &other)
Takes body of another region (that region will have no body after this operation completes).
Definition: Region.h:241
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:116
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: WalkResult.h:29
static WalkResult advance()
Definition: WalkResult.h:47
static ArrayType get(Type elementType, unsigned elementCount)
Definition: SPIRVTypes.cpp:52
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
Definition: SPIRVTypes.cpp:257
LogicalResult deserialize()
Deserializes the remembered SPIR-V binary module.
Deserializer(ArrayRef< uint32_t > binary, MLIRContext *context, const DeserializationOptions &options)
Creates a deserializer for the given SPIR-V binary module.
OwningOpRef< spirv::ModuleOp > collect()
Collects the final SPIR-V ModuleOp.
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
Definition: SPIRVTypes.h:170
static MatrixType get(Type columnType, uint32_t columnCount)
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:449
static RuntimeArrayType get(Type elementType)
Definition: SPIRVTypes.cpp:506
static SampledImageType get(Type imageType)
Definition: SPIRVTypes.cpp:795
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Construct a literal StructType with at least one member.
static TensorArmType get(ArrayRef< int64_t > shape, Type elementType)
static VerCapExtAttr get(Version version, ArrayRef< Capability > capabilities, ArrayRef< Extension > extensions, MLIRContext *context)
Gets a VerCapExtAttr instance.
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition: CallGraph.h:229
constexpr uint32_t kMagicNumber
SPIR-V magic number.
StringRef decodeStringLiteral(ArrayRef< uint32_t > words, unsigned &wordIndex)
Decodes a string literal in words starting at wordIndex.
DenseMap< Block *, BlockMergeInfo > BlockMergeInfoMap
Map from a selection/loop's header block to its merge (and continue) target.
Definition: Deserializer.h:61
constexpr unsigned kHeaderWordCount
SPIR-V binary header word count.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
static std::string debugString(T &&op)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This represents an operation in an abstracted form, suitable for use with the builder APIs.