MLIR  21.0.0git
Deserializer.cpp
Go to the documentation of this file.
1 //===- Deserializer.cpp - MLIR SPIR-V Deserializer ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the SPIR-V binary to MLIR SPIR-V module deserializer.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Deserializer.h"
14 
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/IRMapping.h"
21 #include "mlir/IR/Location.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/Sequence.h"
25 #include "llvm/ADT/SmallVector.h"
26 #include "llvm/ADT/StringExtras.h"
27 #include "llvm/ADT/bit.h"
28 #include "llvm/Support/Debug.h"
29 #include "llvm/Support/SaveAndRestore.h"
30 #include "llvm/Support/raw_ostream.h"
31 #include <optional>
32 
33 using namespace mlir;
34 
35 #define DEBUG_TYPE "spirv-deserialization"
36 
37 //===----------------------------------------------------------------------===//
38 // Utility Functions
39 //===----------------------------------------------------------------------===//
40 
41 /// Returns true if the given `block` is a function entry block.
42 static inline bool isFnEntryBlock(Block *block) {
43  return block->isEntryBlock() &&
44  isa_and_nonnull<spirv::FuncOp>(block->getParentOp());
45 }
46 
47 //===----------------------------------------------------------------------===//
48 // Deserializer Method Definitions
49 //===----------------------------------------------------------------------===//
50 
52  MLIRContext *context,
54  : binary(binary), context(context), unknownLoc(UnknownLoc::get(context)),
55  module(createModuleOp()), opBuilder(module->getRegion()), options(options)
56 #ifndef NDEBUG
57  ,
58  logger(llvm::dbgs())
59 #endif
60 {
61 }
62 
64  LLVM_DEBUG({
65  logger.resetIndent();
66  logger.startLine()
67  << "//+++---------- start deserialization ----------+++//\n";
68  });
69 
70  if (failed(processHeader()))
71  return failure();
72 
73  spirv::Opcode opcode = spirv::Opcode::OpNop;
74  ArrayRef<uint32_t> operands;
75  auto binarySize = binary.size();
76  while (curOffset < binarySize) {
77  // Slice the next instruction out and populate `opcode` and `operands`.
78  // Internally this also updates `curOffset`.
79  if (failed(sliceInstruction(opcode, operands)))
80  return failure();
81 
82  if (failed(processInstruction(opcode, operands)))
83  return failure();
84  }
85 
86  assert(curOffset == binarySize &&
87  "deserializer should never index beyond the binary end");
88 
89  for (auto &deferred : deferredInstructions) {
90  if (failed(processInstruction(deferred.first, deferred.second, false))) {
91  return failure();
92  }
93  }
94 
95  attachVCETriple();
96 
97  LLVM_DEBUG(logger.startLine()
98  << "//+++-------- completed deserialization --------+++//\n");
99  return success();
100 }
101 
103  return std::move(module);
104 }
105 
106 //===----------------------------------------------------------------------===//
107 // Module structure
108 //===----------------------------------------------------------------------===//
109 
110 OwningOpRef<spirv::ModuleOp> spirv::Deserializer::createModuleOp() {
111  OpBuilder builder(context);
112  OperationState state(unknownLoc, spirv::ModuleOp::getOperationName());
113  spirv::ModuleOp::build(builder, state);
114  return cast<spirv::ModuleOp>(Operation::create(state));
115 }
116 
117 LogicalResult spirv::Deserializer::processHeader() {
118  if (binary.size() < spirv::kHeaderWordCount)
119  return emitError(unknownLoc,
120  "SPIR-V binary module must have a 5-word header");
121 
122  if (binary[0] != spirv::kMagicNumber)
123  return emitError(unknownLoc, "incorrect magic number");
124 
125  // Version number bytes: 0 | major number | minor number | 0
126  uint32_t majorVersion = (binary[1] << 8) >> 24;
127  uint32_t minorVersion = (binary[1] << 16) >> 24;
128  if (majorVersion == 1) {
129  switch (minorVersion) {
130 #define MIN_VERSION_CASE(v) \
131  case v: \
132  version = spirv::Version::V_1_##v; \
133  break
134 
135  MIN_VERSION_CASE(0);
136  MIN_VERSION_CASE(1);
137  MIN_VERSION_CASE(2);
138  MIN_VERSION_CASE(3);
139  MIN_VERSION_CASE(4);
140  MIN_VERSION_CASE(5);
141 #undef MIN_VERSION_CASE
142  default:
143  return emitError(unknownLoc, "unsupported SPIR-V minor version: ")
144  << minorVersion;
145  }
146  } else {
147  return emitError(unknownLoc, "unsupported SPIR-V major version: ")
148  << majorVersion;
149  }
150 
151  // TODO: generator number, bound, schema
152  curOffset = spirv::kHeaderWordCount;
153  return success();
154 }
155 
156 LogicalResult
157 spirv::Deserializer::processCapability(ArrayRef<uint32_t> operands) {
158  if (operands.size() != 1)
159  return emitError(unknownLoc, "OpCapability must have one parameter");
160 
161  auto cap = spirv::symbolizeCapability(operands[0]);
162  if (!cap)
163  return emitError(unknownLoc, "unknown capability: ") << operands[0];
164 
165  capabilities.insert(*cap);
166  return success();
167 }
168 
169 LogicalResult spirv::Deserializer::processExtension(ArrayRef<uint32_t> words) {
170  if (words.empty()) {
171  return emitError(
172  unknownLoc,
173  "OpExtension must have a literal string for the extension name");
174  }
175 
176  unsigned wordIndex = 0;
177  StringRef extName = decodeStringLiteral(words, wordIndex);
178  if (wordIndex != words.size())
179  return emitError(unknownLoc,
180  "unexpected trailing words in OpExtension instruction");
181  auto ext = spirv::symbolizeExtension(extName);
182  if (!ext)
183  return emitError(unknownLoc, "unknown extension: ") << extName;
184 
185  extensions.insert(*ext);
186  return success();
187 }
188 
189 LogicalResult
190 spirv::Deserializer::processExtInstImport(ArrayRef<uint32_t> words) {
191  if (words.size() < 2) {
192  return emitError(unknownLoc,
193  "OpExtInstImport must have a result <id> and a literal "
194  "string for the extended instruction set name");
195  }
196 
197  unsigned wordIndex = 1;
198  extendedInstSets[words[0]] = decodeStringLiteral(words, wordIndex);
199  if (wordIndex != words.size()) {
200  return emitError(unknownLoc,
201  "unexpected trailing words in OpExtInstImport");
202  }
203  return success();
204 }
205 
206 void spirv::Deserializer::attachVCETriple() {
207  (*module)->setAttr(
208  spirv::ModuleOp::getVCETripleAttrName(),
209  spirv::VerCapExtAttr::get(version, capabilities.getArrayRef(),
210  extensions.getArrayRef(), context));
211 }
212 
213 LogicalResult
214 spirv::Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
215  if (operands.size() != 2)
216  return emitError(unknownLoc, "OpMemoryModel must have two operands");
217 
218  (*module)->setAttr(
219  module->getAddressingModelAttrName(),
220  opBuilder.getAttr<spirv::AddressingModelAttr>(
221  static_cast<spirv::AddressingModel>(operands.front())));
222 
223  (*module)->setAttr(module->getMemoryModelAttrName(),
224  opBuilder.getAttr<spirv::MemoryModelAttr>(
225  static_cast<spirv::MemoryModel>(operands.back())));
226 
227  return success();
228 }
229 
230 template <typename AttrTy, typename EnumAttrTy, typename EnumTy>
232  Location loc, OpBuilder &opBuilder,
234  StringAttr symbol, StringRef decorationName, StringRef cacheControlKind) {
235  if (words.size() != 4) {
236  return emitError(loc, "OpDecoration with ")
237  << decorationName << "needs a cache control integer literal and a "
238  << cacheControlKind << " cache control literal";
239  }
240  unsigned cacheLevel = words[2];
241  auto cacheControlAttr = static_cast<EnumTy>(words[3]);
242  auto value = opBuilder.getAttr<AttrTy>(cacheLevel, cacheControlAttr);
244  if (auto attrList =
245  llvm::dyn_cast_or_null<ArrayAttr>(decorations[words[0]].get(symbol)))
246  llvm::append_range(attrs, attrList);
247  attrs.push_back(value);
248  decorations[words[0]].set(symbol, opBuilder.getArrayAttr(attrs));
249  return success();
250 }
251 
252 LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
253  // TODO: This function should also be auto-generated. For now, since only a
254  // few decorations are processed/handled in a meaningful manner, going with a
255  // manual implementation.
256  if (words.size() < 2) {
257  return emitError(
258  unknownLoc, "OpDecorate must have at least result <id> and Decoration");
259  }
260  auto decorationName =
261  stringifyDecoration(static_cast<spirv::Decoration>(words[1]));
262  if (decorationName.empty()) {
263  return emitError(unknownLoc, "invalid Decoration code : ") << words[1];
264  }
265  auto symbol = getSymbolDecoration(decorationName);
266  switch (static_cast<spirv::Decoration>(words[1])) {
267  case spirv::Decoration::FPFastMathMode:
268  if (words.size() != 3) {
269  return emitError(unknownLoc, "OpDecorate with ")
270  << decorationName << " needs a single integer literal";
271  }
272  decorations[words[0]].set(
273  symbol, FPFastMathModeAttr::get(opBuilder.getContext(),
274  static_cast<FPFastMathMode>(words[2])));
275  break;
276  case spirv::Decoration::FPRoundingMode:
277  if (words.size() != 3) {
278  return emitError(unknownLoc, "OpDecorate with ")
279  << decorationName << " needs a single integer literal";
280  }
281  decorations[words[0]].set(
282  symbol, FPRoundingModeAttr::get(opBuilder.getContext(),
283  static_cast<FPRoundingMode>(words[2])));
284  break;
285  case spirv::Decoration::DescriptorSet:
286  case spirv::Decoration::Binding:
287  if (words.size() != 3) {
288  return emitError(unknownLoc, "OpDecorate with ")
289  << decorationName << " needs a single integer literal";
290  }
291  decorations[words[0]].set(
292  symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));
293  break;
294  case spirv::Decoration::BuiltIn:
295  if (words.size() != 3) {
296  return emitError(unknownLoc, "OpDecorate with ")
297  << decorationName << " needs a single integer literal";
298  }
299  decorations[words[0]].set(
300  symbol, opBuilder.getStringAttr(
301  stringifyBuiltIn(static_cast<spirv::BuiltIn>(words[2]))));
302  break;
303  case spirv::Decoration::ArrayStride:
304  if (words.size() != 3) {
305  return emitError(unknownLoc, "OpDecorate with ")
306  << decorationName << " needs a single integer literal";
307  }
308  typeDecorations[words[0]] = words[2];
309  break;
310  case spirv::Decoration::LinkageAttributes: {
311  if (words.size() < 4) {
312  return emitError(unknownLoc, "OpDecorate with ")
313  << decorationName
314  << " needs at least 1 string and 1 integer literal";
315  }
316  // LinkageAttributes has two parameters ["linkageName", linkageType]
317  // e.g., OpDecorate %imported_func LinkageAttributes "outside.func" Import
318  // "linkageName" is a stringliteral encoded as uint32_t,
319  // hence the size of name is variable length which results in words.size()
320  // being variable length, words.size() = 3 + strlen(name)/4 + 1 or
321  // 3 + ceildiv(strlen(name), 4).
322  unsigned wordIndex = 2;
323  auto linkageName = spirv::decodeStringLiteral(words, wordIndex).str();
324  auto linkageTypeAttr = opBuilder.getAttr<::mlir::spirv::LinkageTypeAttr>(
325  static_cast<::mlir::spirv::LinkageType>(words[wordIndex++]));
326  auto linkageAttr = opBuilder.getAttr<::mlir::spirv::LinkageAttributesAttr>(
327  StringAttr::get(context, linkageName), linkageTypeAttr);
328  decorations[words[0]].set(symbol, llvm::dyn_cast<Attribute>(linkageAttr));
329  break;
330  }
331  case spirv::Decoration::Aliased:
332  case spirv::Decoration::AliasedPointer:
333  case spirv::Decoration::Block:
334  case spirv::Decoration::BufferBlock:
335  case spirv::Decoration::Flat:
336  case spirv::Decoration::NonReadable:
337  case spirv::Decoration::NonWritable:
338  case spirv::Decoration::NoPerspective:
339  case spirv::Decoration::NoSignedWrap:
340  case spirv::Decoration::NoUnsignedWrap:
341  case spirv::Decoration::RelaxedPrecision:
342  case spirv::Decoration::Restrict:
343  case spirv::Decoration::RestrictPointer:
344  case spirv::Decoration::NoContraction:
345  case spirv::Decoration::Constant:
346  if (words.size() != 2) {
347  return emitError(unknownLoc, "OpDecoration with ")
348  << decorationName << "needs a single target <id>";
349  }
350  // Block decoration does not affect spirv.struct type, but is still stored
351  // for verification.
352  // TODO: Update StructType to contain this information since
353  // it is needed for many validation rules.
354  decorations[words[0]].set(symbol, opBuilder.getUnitAttr());
355  break;
356  case spirv::Decoration::Location:
357  case spirv::Decoration::SpecId:
358  if (words.size() != 3) {
359  return emitError(unknownLoc, "OpDecoration with ")
360  << decorationName << "needs a single integer literal";
361  }
362  decorations[words[0]].set(
363  symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));
364  break;
365  case spirv::Decoration::CacheControlLoadINTEL: {
366  LogicalResult res = deserializeCacheControlDecoration<
367  CacheControlLoadINTELAttr, LoadCacheControlAttr, LoadCacheControl>(
368  unknownLoc, opBuilder, decorations, words, symbol, decorationName,
369  "load");
370  if (failed(res))
371  return res;
372  break;
373  }
374  case spirv::Decoration::CacheControlStoreINTEL: {
375  LogicalResult res = deserializeCacheControlDecoration<
376  CacheControlStoreINTELAttr, StoreCacheControlAttr, StoreCacheControl>(
377  unknownLoc, opBuilder, decorations, words, symbol, decorationName,
378  "store");
379  if (failed(res))
380  return res;
381  break;
382  }
383  default:
384  return emitError(unknownLoc, "unhandled Decoration : '") << decorationName;
385  }
386  return success();
387 }
388 
389 LogicalResult
390 spirv::Deserializer::processMemberDecoration(ArrayRef<uint32_t> words) {
391  // The binary layout of OpMemberDecorate is different comparing to OpDecorate
392  if (words.size() < 3) {
393  return emitError(unknownLoc,
394  "OpMemberDecorate must have at least 3 operands");
395  }
396 
397  auto decoration = static_cast<spirv::Decoration>(words[2]);
398  if (decoration == spirv::Decoration::Offset && words.size() != 4) {
399  return emitError(unknownLoc,
400  " missing offset specification in OpMemberDecorate with "
401  "Offset decoration");
402  }
403  ArrayRef<uint32_t> decorationOperands;
404  if (words.size() > 3) {
405  decorationOperands = words.slice(3);
406  }
407  memberDecorationMap[words[0]][words[1]][decoration] = decorationOperands;
408  return success();
409 }
410 
411 LogicalResult spirv::Deserializer::processMemberName(ArrayRef<uint32_t> words) {
412  if (words.size() < 3) {
413  return emitError(unknownLoc, "OpMemberName must have at least 3 operands");
414  }
415  unsigned wordIndex = 2;
416  auto name = decodeStringLiteral(words, wordIndex);
417  if (wordIndex != words.size()) {
418  return emitError(unknownLoc,
419  "unexpected trailing words in OpMemberName instruction");
420  }
421  memberNameMap[words[0]][words[1]] = name;
422  return success();
423 }
424 
425 LogicalResult spirv::Deserializer::setFunctionArgAttrs(
426  uint32_t argID, SmallVectorImpl<Attribute> &argAttrs, size_t argIndex) {
427  if (!decorations.contains(argID)) {
428  argAttrs[argIndex] = DictionaryAttr::get(context, {});
429  return success();
430  }
431 
432  spirv::DecorationAttr foundDecorationAttr;
433  for (NamedAttribute decAttr : decorations[argID]) {
434  for (auto decoration :
435  {spirv::Decoration::Aliased, spirv::Decoration::Restrict,
436  spirv::Decoration::AliasedPointer,
437  spirv::Decoration::RestrictPointer}) {
438 
439  if (decAttr.getName() !=
440  getSymbolDecoration(stringifyDecoration(decoration)))
441  continue;
442 
443  if (foundDecorationAttr)
444  return emitError(unknownLoc,
445  "more than one Aliased/Restrict decorations for "
446  "function argument with result <id> ")
447  << argID;
448 
449  foundDecorationAttr = spirv::DecorationAttr::get(context, decoration);
450  break;
451  }
452 
453  if (decAttr.getName() == getSymbolDecoration(stringifyDecoration(
454  spirv::Decoration::RelaxedPrecision))) {
455  // TODO: Current implementation supports only one decoration per function
456  // parameter so RelaxedPrecision cannot be applied at the same time as,
457  // for example, Aliased/Restrict/etc. This should be relaxed to allow any
458  // combination of decoration allowed by the spec to be supported.
459  if (foundDecorationAttr)
460  return emitError(unknownLoc, "already found a decoration for function "
461  "argument with result <id> ")
462  << argID;
463 
464  foundDecorationAttr = spirv::DecorationAttr::get(
465  context, spirv::Decoration::RelaxedPrecision);
466  }
467  }
468 
469  if (!foundDecorationAttr)
470  return emitError(unknownLoc, "unimplemented decoration support for "
471  "function argument with result <id> ")
472  << argID;
473 
474  NamedAttribute attr(StringAttr::get(context, spirv::DecorationAttr::name),
475  foundDecorationAttr);
476  argAttrs[argIndex] = DictionaryAttr::get(context, attr);
477  return success();
478 }
479 
480 LogicalResult
481 spirv::Deserializer::processFunction(ArrayRef<uint32_t> operands) {
482  if (curFunction) {
483  return emitError(unknownLoc, "found function inside function");
484  }
485 
486  // Get the result type
487  if (operands.size() != 4) {
488  return emitError(unknownLoc, "OpFunction must have 4 parameters");
489  }
490  Type resultType = getType(operands[0]);
491  if (!resultType) {
492  return emitError(unknownLoc, "undefined result type from <id> ")
493  << operands[0];
494  }
495 
496  uint32_t fnID = operands[1];
497  if (funcMap.count(fnID)) {
498  return emitError(unknownLoc, "duplicate function definition/declaration");
499  }
500 
501  auto fnControl = spirv::symbolizeFunctionControl(operands[2]);
502  if (!fnControl) {
503  return emitError(unknownLoc, "unknown Function Control: ") << operands[2];
504  }
505 
506  Type fnType = getType(operands[3]);
507  if (!fnType || !isa<FunctionType>(fnType)) {
508  return emitError(unknownLoc, "unknown function type from <id> ")
509  << operands[3];
510  }
511  auto functionType = cast<FunctionType>(fnType);
512 
513  if ((isVoidType(resultType) && functionType.getNumResults() != 0) ||
514  (functionType.getNumResults() == 1 &&
515  functionType.getResult(0) != resultType)) {
516  return emitError(unknownLoc, "mismatch in function type ")
517  << functionType << " and return type " << resultType << " specified";
518  }
519 
520  std::string fnName = getFunctionSymbol(fnID);
521  auto funcOp = opBuilder.create<spirv::FuncOp>(
522  unknownLoc, fnName, functionType, fnControl.value());
523  // Processing other function attributes.
524  if (decorations.count(fnID)) {
525  for (auto attr : decorations[fnID].getAttrs()) {
526  funcOp->setAttr(attr.getName(), attr.getValue());
527  }
528  }
529  curFunction = funcMap[fnID] = funcOp;
530  auto *entryBlock = funcOp.addEntryBlock();
531  LLVM_DEBUG({
532  logger.startLine()
533  << "//===-------------------------------------------===//\n";
534  logger.startLine() << "[fn] name: " << fnName << "\n";
535  logger.startLine() << "[fn] type: " << fnType << "\n";
536  logger.startLine() << "[fn] ID: " << fnID << "\n";
537  logger.startLine() << "[fn] entry block: " << entryBlock << "\n";
538  logger.indent();
539  });
540 
541  SmallVector<Attribute> argAttrs;
542  argAttrs.resize(functionType.getNumInputs());
543 
544  // Parse the op argument instructions
545  if (functionType.getNumInputs()) {
546  for (size_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
547  auto argType = functionType.getInput(i);
548  spirv::Opcode opcode = spirv::Opcode::OpNop;
549  ArrayRef<uint32_t> operands;
550  if (failed(sliceInstruction(opcode, operands,
551  spirv::Opcode::OpFunctionParameter))) {
552  return failure();
553  }
554  if (opcode != spirv::Opcode::OpFunctionParameter) {
555  return emitError(
556  unknownLoc,
557  "missing OpFunctionParameter instruction for argument ")
558  << i;
559  }
560  if (operands.size() != 2) {
561  return emitError(
562  unknownLoc,
563  "expected result type and result <id> for OpFunctionParameter");
564  }
565  auto argDefinedType = getType(operands[0]);
566  if (!argDefinedType || argDefinedType != argType) {
567  return emitError(unknownLoc,
568  "mismatch in argument type between function type "
569  "definition ")
570  << functionType << " and argument type definition "
571  << argDefinedType << " at argument " << i;
572  }
573  if (getValue(operands[1])) {
574  return emitError(unknownLoc, "duplicate definition of result <id> ")
575  << operands[1];
576  }
577  if (failed(setFunctionArgAttrs(operands[1], argAttrs, i))) {
578  return failure();
579  }
580 
581  auto argValue = funcOp.getArgument(i);
582  valueMap[operands[1]] = argValue;
583  }
584  }
585 
586  if (llvm::any_of(argAttrs, [](Attribute attr) {
587  auto argAttr = cast<DictionaryAttr>(attr);
588  return !argAttr.empty();
589  }))
590  funcOp.setArgAttrsAttr(ArrayAttr::get(context, argAttrs));
591 
592  // entryBlock is needed to access the arguments, Once that is done, we can
593  // erase the block for functions with 'Import' LinkageAttributes, since these
594  // are essentially function declarations, so they have no body.
595  auto linkageAttr = funcOp.getLinkageAttributes();
596  auto hasImportLinkage =
597  linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
598  spirv::LinkageType::Import);
599  if (hasImportLinkage)
600  funcOp.eraseBody();
601 
602  // RAII guard to reset the insertion point to the module's region after
603  // deserializing the body of this function.
604  OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
605 
606  spirv::Opcode opcode = spirv::Opcode::OpNop;
607  ArrayRef<uint32_t> instOperands;
608 
609  // Special handling for the entry block. We need to make sure it starts with
610  // an OpLabel instruction. The entry block takes the same parameters as the
611  // function. All other blocks do not take any parameter. We have already
612  // created the entry block, here we need to register it to the correct label
613  // <id>.
614  if (failed(sliceInstruction(opcode, instOperands,
615  spirv::Opcode::OpFunctionEnd))) {
616  return failure();
617  }
618  if (opcode == spirv::Opcode::OpFunctionEnd) {
619  return processFunctionEnd(instOperands);
620  }
621  if (opcode != spirv::Opcode::OpLabel) {
622  return emitError(unknownLoc, "a basic block must start with OpLabel");
623  }
624  if (instOperands.size() != 1) {
625  return emitError(unknownLoc, "OpLabel should only have result <id>");
626  }
627  blockMap[instOperands[0]] = entryBlock;
628  if (failed(processLabel(instOperands))) {
629  return failure();
630  }
631 
632  // Then process all the other instructions in the function until we hit
633  // OpFunctionEnd.
634  while (succeeded(sliceInstruction(opcode, instOperands,
635  spirv::Opcode::OpFunctionEnd)) &&
636  opcode != spirv::Opcode::OpFunctionEnd) {
637  if (failed(processInstruction(opcode, instOperands))) {
638  return failure();
639  }
640  }
641  if (opcode != spirv::Opcode::OpFunctionEnd) {
642  return failure();
643  }
644 
645  return processFunctionEnd(instOperands);
646 }
647 
648 LogicalResult
649 spirv::Deserializer::processFunctionEnd(ArrayRef<uint32_t> operands) {
650  // Process OpFunctionEnd.
651  if (!operands.empty()) {
652  return emitError(unknownLoc, "unexpected operands for OpFunctionEnd");
653  }
654 
655  // Wire up block arguments from OpPhi instructions.
656  // Put all structured control flow in spirv.mlir.selection/spirv.mlir.loop
657  // ops.
658  if (failed(wireUpBlockArgument()) || failed(structurizeControlFlow())) {
659  return failure();
660  }
661 
662  curBlock = nullptr;
663  curFunction = std::nullopt;
664 
665  LLVM_DEBUG({
666  logger.unindent();
667  logger.startLine()
668  << "//===-------------------------------------------===//\n";
669  });
670  return success();
671 }
672 
673 std::optional<std::pair<Attribute, Type>>
674 spirv::Deserializer::getConstant(uint32_t id) {
675  auto constIt = constantMap.find(id);
676  if (constIt == constantMap.end())
677  return std::nullopt;
678  return constIt->getSecond();
679 }
680 
681 std::optional<spirv::SpecConstOperationMaterializationInfo>
682 spirv::Deserializer::getSpecConstantOperation(uint32_t id) {
683  auto constIt = specConstOperationMap.find(id);
684  if (constIt == specConstOperationMap.end())
685  return std::nullopt;
686  return constIt->getSecond();
687 }
688 
689 std::string spirv::Deserializer::getFunctionSymbol(uint32_t id) {
690  auto funcName = nameMap.lookup(id).str();
691  if (funcName.empty()) {
692  funcName = "spirv_fn_" + std::to_string(id);
693  }
694  return funcName;
695 }
696 
697 std::string spirv::Deserializer::getSpecConstantSymbol(uint32_t id) {
698  auto constName = nameMap.lookup(id).str();
699  if (constName.empty()) {
700  constName = "spirv_spec_const_" + std::to_string(id);
701  }
702  return constName;
703 }
704 
705 spirv::SpecConstantOp
706 spirv::Deserializer::createSpecConstant(Location loc, uint32_t resultID,
707  TypedAttr defaultValue) {
708  auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
709  auto op = opBuilder.create<spirv::SpecConstantOp>(unknownLoc, symName,
710  defaultValue);
711  if (decorations.count(resultID)) {
712  for (auto attr : decorations[resultID].getAttrs())
713  op->setAttr(attr.getName(), attr.getValue());
714  }
715  specConstMap[resultID] = op;
716  return op;
717 }
718 
719 LogicalResult
720 spirv::Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
721  unsigned wordIndex = 0;
722  if (operands.size() < 3) {
723  return emitError(
724  unknownLoc,
725  "OpVariable needs at least 3 operands, type, <id> and storage class");
726  }
727 
728  // Result Type.
729  auto type = getType(operands[wordIndex]);
730  if (!type) {
731  return emitError(unknownLoc, "unknown result type <id> : ")
732  << operands[wordIndex];
733  }
734  auto ptrType = dyn_cast<spirv::PointerType>(type);
735  if (!ptrType) {
736  return emitError(unknownLoc,
737  "expected a result type <id> to be a spirv.ptr, found : ")
738  << type;
739  }
740  wordIndex++;
741 
742  // Result <id>.
743  auto variableID = operands[wordIndex];
744  auto variableName = nameMap.lookup(variableID).str();
745  if (variableName.empty()) {
746  variableName = "spirv_var_" + std::to_string(variableID);
747  }
748  wordIndex++;
749 
750  // Storage class.
751  auto storageClass = static_cast<spirv::StorageClass>(operands[wordIndex]);
752  if (ptrType.getStorageClass() != storageClass) {
753  return emitError(unknownLoc, "mismatch in storage class of pointer type ")
754  << type << " and that specified in OpVariable instruction : "
755  << stringifyStorageClass(storageClass);
756  }
757  wordIndex++;
758 
759  // Initializer.
760  FlatSymbolRefAttr initializer = nullptr;
761 
762  if (wordIndex < operands.size()) {
763  Operation *op = nullptr;
764 
765  if (auto initOp = getGlobalVariable(operands[wordIndex]))
766  op = initOp;
767  else if (auto initOp = getSpecConstant(operands[wordIndex]))
768  op = initOp;
769  else if (auto initOp = getSpecConstantComposite(operands[wordIndex]))
770  op = initOp;
771  else
772  return emitError(unknownLoc, "unknown <id> ")
773  << operands[wordIndex] << "used as initializer";
774 
775  initializer = SymbolRefAttr::get(op);
776  wordIndex++;
777  }
778  if (wordIndex != operands.size()) {
779  return emitError(unknownLoc,
780  "found more operands than expected when deserializing "
781  "OpVariable instruction, only ")
782  << wordIndex << " of " << operands.size() << " processed";
783  }
784  auto loc = createFileLineColLoc(opBuilder);
785  auto varOp = opBuilder.create<spirv::GlobalVariableOp>(
786  loc, TypeAttr::get(type), opBuilder.getStringAttr(variableName),
787  initializer);
788 
789  // Decorations.
790  if (decorations.count(variableID)) {
791  for (auto attr : decorations[variableID].getAttrs())
792  varOp->setAttr(attr.getName(), attr.getValue());
793  }
794  globalVariableMap[variableID] = varOp;
795  return success();
796 }
797 
798 IntegerAttr spirv::Deserializer::getConstantInt(uint32_t id) {
799  auto constInfo = getConstant(id);
800  if (!constInfo) {
801  return nullptr;
802  }
803  return dyn_cast<IntegerAttr>(constInfo->first);
804 }
805 
806 LogicalResult spirv::Deserializer::processName(ArrayRef<uint32_t> operands) {
807  if (operands.size() < 2) {
808  return emitError(unknownLoc, "OpName needs at least 2 operands");
809  }
810  if (!nameMap.lookup(operands[0]).empty()) {
811  return emitError(unknownLoc, "duplicate name found for result <id> ")
812  << operands[0];
813  }
814  unsigned wordIndex = 1;
815  StringRef name = decodeStringLiteral(operands, wordIndex);
816  if (wordIndex != operands.size()) {
817  return emitError(unknownLoc,
818  "unexpected trailing words in OpName instruction");
819  }
820  nameMap[operands[0]] = name;
821  return success();
822 }
823 
824 //===----------------------------------------------------------------------===//
825 // Type
826 //===----------------------------------------------------------------------===//
827 
828 LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
829  ArrayRef<uint32_t> operands) {
830  if (operands.empty()) {
831  return emitError(unknownLoc, "type instruction with opcode ")
832  << spirv::stringifyOpcode(opcode) << " needs at least one <id>";
833  }
834 
835  /// TODO: Types might be forward declared in some instructions and need to be
836  /// handled appropriately.
837  if (typeMap.count(operands[0])) {
838  return emitError(unknownLoc, "duplicate definition for result <id> ")
839  << operands[0];
840  }
841 
842  switch (opcode) {
843  case spirv::Opcode::OpTypeVoid:
844  if (operands.size() != 1)
845  return emitError(unknownLoc, "OpTypeVoid must have no parameters");
846  typeMap[operands[0]] = opBuilder.getNoneType();
847  break;
848  case spirv::Opcode::OpTypeBool:
849  if (operands.size() != 1)
850  return emitError(unknownLoc, "OpTypeBool must have no parameters");
851  typeMap[operands[0]] = opBuilder.getI1Type();
852  break;
853  case spirv::Opcode::OpTypeInt: {
854  if (operands.size() != 3)
855  return emitError(
856  unknownLoc, "OpTypeInt must have bitwidth and signedness parameters");
857 
858  // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics
859  // to preserve or validate.
860  // 0 indicates unsigned, or no signedness semantics
861  // 1 indicates signed semantics."
862  //
863  // So we cannot differentiate signless and unsigned integers; always use
864  // signless semantics for such cases.
865  auto sign = operands[2] == 1 ? IntegerType::SignednessSemantics::Signed
866  : IntegerType::SignednessSemantics::Signless;
867  typeMap[operands[0]] = IntegerType::get(context, operands[1], sign);
868  } break;
869  case spirv::Opcode::OpTypeFloat: {
870  if (operands.size() != 2 && operands.size() != 3)
871  return emitError(unknownLoc,
872  "OpTypeFloat expects either 2 operands (type, bitwidth) "
873  "or 3 operands (type, bitwidth, encoding), but got ")
874  << operands.size();
875  uint32_t bitWidth = operands[1];
876 
877  Type floatTy;
878  switch (bitWidth) {
879  case 16:
880  floatTy = opBuilder.getF16Type();
881  break;
882  case 32:
883  floatTy = opBuilder.getF32Type();
884  break;
885  case 64:
886  floatTy = opBuilder.getF64Type();
887  break;
888  default:
889  return emitError(unknownLoc, "unsupported OpTypeFloat bitwidth: ")
890  << bitWidth;
891  }
892 
893  if (operands.size() == 3) {
894  if (spirv::FPEncoding(operands[2]) != spirv::FPEncoding::BFloat16KHR)
895  return emitError(unknownLoc, "unsupported OpTypeFloat FP encoding: ")
896  << operands[2];
897  if (bitWidth != 16)
898  return emitError(unknownLoc,
899  "invalid OpTypeFloat bitwidth for bfloat16 encoding: ")
900  << bitWidth << " (expected 16)";
901  floatTy = opBuilder.getBF16Type();
902  }
903 
904  typeMap[operands[0]] = floatTy;
905  } break;
906  case spirv::Opcode::OpTypeVector: {
907  if (operands.size() != 3) {
908  return emitError(
909  unknownLoc,
910  "OpTypeVector must have element type and count parameters");
911  }
912  Type elementTy = getType(operands[1]);
913  if (!elementTy) {
914  return emitError(unknownLoc, "OpTypeVector references undefined <id> ")
915  << operands[1];
916  }
917  typeMap[operands[0]] = VectorType::get({operands[2]}, elementTy);
918  } break;
919  case spirv::Opcode::OpTypePointer: {
920  return processOpTypePointer(operands);
921  } break;
922  case spirv::Opcode::OpTypeArray:
923  return processArrayType(operands);
924  case spirv::Opcode::OpTypeCooperativeMatrixKHR:
925  return processCooperativeMatrixTypeKHR(operands);
926  case spirv::Opcode::OpTypeFunction:
927  return processFunctionType(operands);
928  case spirv::Opcode::OpTypeImage:
929  return processImageType(operands);
930  case spirv::Opcode::OpTypeSampledImage:
931  return processSampledImageType(operands);
932  case spirv::Opcode::OpTypeRuntimeArray:
933  return processRuntimeArrayType(operands);
934  case spirv::Opcode::OpTypeStruct:
935  return processStructType(operands);
936  case spirv::Opcode::OpTypeMatrix:
937  return processMatrixType(operands);
938  default:
939  return emitError(unknownLoc, "unhandled type instruction");
940  }
941  return success();
942 }
943 
944 LogicalResult
945 spirv::Deserializer::processOpTypePointer(ArrayRef<uint32_t> operands) {
946  if (operands.size() != 3)
947  return emitError(unknownLoc, "OpTypePointer must have two parameters");
948 
949  auto pointeeType = getType(operands[2]);
950  if (!pointeeType)
951  return emitError(unknownLoc, "unknown OpTypePointer pointee type <id> ")
952  << operands[2];
953 
954  uint32_t typePointerID = operands[0];
955  auto storageClass = static_cast<spirv::StorageClass>(operands[1]);
956  typeMap[typePointerID] = spirv::PointerType::get(pointeeType, storageClass);
957 
958  for (auto *deferredStructIt = std::begin(deferredStructTypesInfos);
959  deferredStructIt != std::end(deferredStructTypesInfos);) {
960  for (auto *unresolvedMemberIt =
961  std::begin(deferredStructIt->unresolvedMemberTypes);
962  unresolvedMemberIt !=
963  std::end(deferredStructIt->unresolvedMemberTypes);) {
964  if (unresolvedMemberIt->first == typePointerID) {
965  // The newly constructed pointer type can resolve one of the
966  // deferred struct type members; update the memberTypes list and
967  // clean the unresolvedMemberTypes list accordingly.
968  deferredStructIt->memberTypes[unresolvedMemberIt->second] =
969  typeMap[typePointerID];
970  unresolvedMemberIt =
971  deferredStructIt->unresolvedMemberTypes.erase(unresolvedMemberIt);
972  } else {
973  ++unresolvedMemberIt;
974  }
975  }
976 
977  if (deferredStructIt->unresolvedMemberTypes.empty()) {
978  // All deferred struct type members are now resolved, set the struct body.
979  auto structType = deferredStructIt->deferredStructType;
980 
981  assert(structType && "expected a spirv::StructType");
982  assert(structType.isIdentified() && "expected an indentified struct");
983 
984  if (failed(structType.trySetBody(
985  deferredStructIt->memberTypes, deferredStructIt->offsetInfo,
986  deferredStructIt->memberDecorationsInfo)))
987  return failure();
988 
989  deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt);
990  } else {
991  ++deferredStructIt;
992  }
993  }
994 
995  return success();
996 }
997 
998 LogicalResult
999 spirv::Deserializer::processArrayType(ArrayRef<uint32_t> operands) {
1000  if (operands.size() != 3) {
1001  return emitError(unknownLoc,
1002  "OpTypeArray must have element type and count parameters");
1003  }
1004 
1005  Type elementTy = getType(operands[1]);
1006  if (!elementTy) {
1007  return emitError(unknownLoc, "OpTypeArray references undefined <id> ")
1008  << operands[1];
1009  }
1010 
1011  unsigned count = 0;
1012  // TODO: The count can also come frome a specialization constant.
1013  auto countInfo = getConstant(operands[2]);
1014  if (!countInfo) {
1015  return emitError(unknownLoc, "OpTypeArray count <id> ")
1016  << operands[2] << "can only come from normal constant right now";
1017  }
1018 
1019  if (auto intVal = dyn_cast<IntegerAttr>(countInfo->first)) {
1020  count = intVal.getValue().getZExtValue();
1021  } else {
1022  return emitError(unknownLoc, "OpTypeArray count must come from a "
1023  "scalar integer constant instruction");
1024  }
1025 
1026  typeMap[operands[0]] = spirv::ArrayType::get(
1027  elementTy, count, typeDecorations.lookup(operands[0]));
1028  return success();
1029 }
1030 
1031 LogicalResult
1032 spirv::Deserializer::processFunctionType(ArrayRef<uint32_t> operands) {
1033  assert(!operands.empty() && "No operands for processing function type");
1034  if (operands.size() == 1) {
1035  return emitError(unknownLoc, "missing return type for OpTypeFunction");
1036  }
1037  auto returnType = getType(operands[1]);
1038  if (!returnType) {
1039  return emitError(unknownLoc, "unknown return type in OpTypeFunction");
1040  }
1041  SmallVector<Type, 1> argTypes;
1042  for (size_t i = 2, e = operands.size(); i < e; ++i) {
1043  auto ty = getType(operands[i]);
1044  if (!ty) {
1045  return emitError(unknownLoc, "unknown argument type in OpTypeFunction");
1046  }
1047  argTypes.push_back(ty);
1048  }
1049  ArrayRef<Type> returnTypes;
1050  if (!isVoidType(returnType)) {
1051  returnTypes = llvm::ArrayRef(returnType);
1052  }
1053  typeMap[operands[0]] = FunctionType::get(context, argTypes, returnTypes);
1054  return success();
1055 }
1056 
1057 LogicalResult spirv::Deserializer::processCooperativeMatrixTypeKHR(
1058  ArrayRef<uint32_t> operands) {
1059  if (operands.size() != 6) {
1060  return emitError(unknownLoc,
1061  "OpTypeCooperativeMatrixKHR must have element type, "
1062  "scope, row and column parameters, and use");
1063  }
1064 
1065  Type elementTy = getType(operands[1]);
1066  if (!elementTy) {
1067  return emitError(unknownLoc,
1068  "OpTypeCooperativeMatrixKHR references undefined <id> ")
1069  << operands[1];
1070  }
1071 
1072  std::optional<spirv::Scope> scope =
1073  spirv::symbolizeScope(getConstantInt(operands[2]).getInt());
1074  if (!scope) {
1075  return emitError(
1076  unknownLoc,
1077  "OpTypeCooperativeMatrixKHR references undefined scope <id> ")
1078  << operands[2];
1079  }
1080 
1081  IntegerAttr rowsAttr = getConstantInt(operands[3]);
1082  IntegerAttr columnsAttr = getConstantInt(operands[4]);
1083  IntegerAttr useAttr = getConstantInt(operands[5]);
1084 
1085  if (!rowsAttr)
1086  return emitError(unknownLoc, "OpTypeCooperativeMatrixKHR `Rows` references "
1087  "undefined constant <id> ")
1088  << operands[3];
1089 
1090  if (!columnsAttr)
1091  return emitError(unknownLoc, "OpTypeCooperativeMatrixKHR `Columns` "
1092  "references undefined constant <id> ")
1093  << operands[4];
1094 
1095  if (!useAttr)
1096  return emitError(unknownLoc, "OpTypeCooperativeMatrixKHR `Use` references "
1097  "undefined constant <id> ")
1098  << operands[5];
1099 
1100  unsigned rows = rowsAttr.getInt();
1101  unsigned columns = columnsAttr.getInt();
1102 
1103  std::optional<spirv::CooperativeMatrixUseKHR> use =
1104  spirv::symbolizeCooperativeMatrixUseKHR(useAttr.getInt());
1105  if (!use) {
1106  return emitError(
1107  unknownLoc,
1108  "OpTypeCooperativeMatrixKHR references undefined use <id> ")
1109  << operands[5];
1110  }
1111 
1112  typeMap[operands[0]] =
1113  spirv::CooperativeMatrixType::get(elementTy, rows, columns, *scope, *use);
1114  return success();
1115 }
1116 
1117 LogicalResult
1118 spirv::Deserializer::processRuntimeArrayType(ArrayRef<uint32_t> operands) {
1119  if (operands.size() != 2) {
1120  return emitError(unknownLoc, "OpTypeRuntimeArray must have two operands");
1121  }
1122  Type memberType = getType(operands[1]);
1123  if (!memberType) {
1124  return emitError(unknownLoc,
1125  "OpTypeRuntimeArray references undefined <id> ")
1126  << operands[1];
1127  }
1128  typeMap[operands[0]] = spirv::RuntimeArrayType::get(
1129  memberType, typeDecorations.lookup(operands[0]));
1130  return success();
1131 }
1132 
1133 LogicalResult
1134 spirv::Deserializer::processStructType(ArrayRef<uint32_t> operands) {
1135  // TODO: Find a way to handle identified structs when debug info is stripped.
1136 
1137  if (operands.empty()) {
1138  return emitError(unknownLoc, "OpTypeStruct must have at least result <id>");
1139  }
1140 
1141  if (operands.size() == 1) {
1142  // Handle empty struct.
1143  typeMap[operands[0]] =
1144  spirv::StructType::getEmpty(context, nameMap.lookup(operands[0]).str());
1145  return success();
1146  }
1147 
1148  // First element is operand ID, second element is member index in the struct.
1149  SmallVector<std::pair<uint32_t, unsigned>, 0> unresolvedMemberTypes;
1150  SmallVector<Type, 4> memberTypes;
1151 
1152  for (auto op : llvm::drop_begin(operands, 1)) {
1153  Type memberType = getType(op);
1154  bool typeForwardPtr = (typeForwardPointerIDs.count(op) != 0);
1155 
1156  if (!memberType && !typeForwardPtr)
1157  return emitError(unknownLoc, "OpTypeStruct references undefined <id> ")
1158  << op;
1159 
1160  if (!memberType)
1161  unresolvedMemberTypes.emplace_back(op, memberTypes.size());
1162 
1163  memberTypes.push_back(memberType);
1164  }
1165 
1168  if (memberDecorationMap.count(operands[0])) {
1169  auto &allMemberDecorations = memberDecorationMap[operands[0]];
1170  for (auto memberIndex : llvm::seq<uint32_t>(0, memberTypes.size())) {
1171  if (allMemberDecorations.count(memberIndex)) {
1172  for (auto &memberDecoration : allMemberDecorations[memberIndex]) {
1173  // Check for offset.
1174  if (memberDecoration.first == spirv::Decoration::Offset) {
1175  // If offset info is empty, resize to the number of members;
1176  if (offsetInfo.empty()) {
1177  offsetInfo.resize(memberTypes.size());
1178  }
1179  offsetInfo[memberIndex] = memberDecoration.second[0];
1180  } else {
1181  if (!memberDecoration.second.empty()) {
1182  memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/1,
1183  memberDecoration.first,
1184  memberDecoration.second[0]);
1185  } else {
1186  memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/0,
1187  memberDecoration.first, 0);
1188  }
1189  }
1190  }
1191  }
1192  }
1193  }
1194 
1195  uint32_t structID = operands[0];
1196  std::string structIdentifier = nameMap.lookup(structID).str();
1197 
1198  if (structIdentifier.empty()) {
1199  assert(unresolvedMemberTypes.empty() &&
1200  "didn't expect unresolved member types");
1201  typeMap[structID] =
1202  spirv::StructType::get(memberTypes, offsetInfo, memberDecorationsInfo);
1203  } else {
1204  auto structTy = spirv::StructType::getIdentified(context, structIdentifier);
1205  typeMap[structID] = structTy;
1206 
1207  if (!unresolvedMemberTypes.empty())
1208  deferredStructTypesInfos.push_back({structTy, unresolvedMemberTypes,
1209  memberTypes, offsetInfo,
1210  memberDecorationsInfo});
1211  else if (failed(structTy.trySetBody(memberTypes, offsetInfo,
1212  memberDecorationsInfo)))
1213  return failure();
1214  }
1215 
1216  // TODO: Update StructType to have member name as attribute as
1217  // well.
1218  return success();
1219 }
1220 
1221 LogicalResult
1222 spirv::Deserializer::processMatrixType(ArrayRef<uint32_t> operands) {
1223  if (operands.size() != 3) {
1224  // Three operands are needed: result_id, column_type, and column_count
1225  return emitError(unknownLoc, "OpTypeMatrix must have 3 operands"
1226  " (result_id, column_type, and column_count)");
1227  }
1228  // Matrix columns must be of vector type
1229  Type elementTy = getType(operands[1]);
1230  if (!elementTy) {
1231  return emitError(unknownLoc,
1232  "OpTypeMatrix references undefined column type.")
1233  << operands[1];
1234  }
1235 
1236  uint32_t colsCount = operands[2];
1237  typeMap[operands[0]] = spirv::MatrixType::get(elementTy, colsCount);
1238  return success();
1239 }
1240 
1241 LogicalResult
1242 spirv::Deserializer::processTypeForwardPointer(ArrayRef<uint32_t> operands) {
1243  if (operands.size() != 2)
1244  return emitError(unknownLoc,
1245  "OpTypeForwardPointer instruction must have two operands");
1246 
1247  typeForwardPointerIDs.insert(operands[0]);
1248  // TODO: Use the 2nd operand (Storage Class) to validate the OpTypePointer
1249  // instruction that defines the actual type.
1250 
1251  return success();
1252 }
1253 
1254 LogicalResult
1255 spirv::Deserializer::processImageType(ArrayRef<uint32_t> operands) {
1256  // TODO: Add support for Access Qualifier.
1257  if (operands.size() != 8)
1258  return emitError(
1259  unknownLoc,
1260  "OpTypeImage with non-eight operands are not supported yet");
1261 
1262  Type elementTy = getType(operands[1]);
1263  if (!elementTy)
1264  return emitError(unknownLoc, "OpTypeImage references undefined <id>: ")
1265  << operands[1];
1266 
1267  auto dim = spirv::symbolizeDim(operands[2]);
1268  if (!dim)
1269  return emitError(unknownLoc, "unknown Dim for OpTypeImage: ")
1270  << operands[2];
1271 
1272  auto depthInfo = spirv::symbolizeImageDepthInfo(operands[3]);
1273  if (!depthInfo)
1274  return emitError(unknownLoc, "unknown Depth for OpTypeImage: ")
1275  << operands[3];
1276 
1277  auto arrayedInfo = spirv::symbolizeImageArrayedInfo(operands[4]);
1278  if (!arrayedInfo)
1279  return emitError(unknownLoc, "unknown Arrayed for OpTypeImage: ")
1280  << operands[4];
1281 
1282  auto samplingInfo = spirv::symbolizeImageSamplingInfo(operands[5]);
1283  if (!samplingInfo)
1284  return emitError(unknownLoc, "unknown MS for OpTypeImage: ") << operands[5];
1285 
1286  auto samplerUseInfo = spirv::symbolizeImageSamplerUseInfo(operands[6]);
1287  if (!samplerUseInfo)
1288  return emitError(unknownLoc, "unknown Sampled for OpTypeImage: ")
1289  << operands[6];
1290 
1291  auto format = spirv::symbolizeImageFormat(operands[7]);
1292  if (!format)
1293  return emitError(unknownLoc, "unknown Format for OpTypeImage: ")
1294  << operands[7];
1295 
1296  typeMap[operands[0]] = spirv::ImageType::get(
1297  elementTy, dim.value(), depthInfo.value(), arrayedInfo.value(),
1298  samplingInfo.value(), samplerUseInfo.value(), format.value());
1299  return success();
1300 }
1301 
1302 LogicalResult
1303 spirv::Deserializer::processSampledImageType(ArrayRef<uint32_t> operands) {
1304  if (operands.size() != 2)
1305  return emitError(unknownLoc, "OpTypeSampledImage must have two operands");
1306 
1307  Type elementTy = getType(operands[1]);
1308  if (!elementTy)
1309  return emitError(unknownLoc,
1310  "OpTypeSampledImage references undefined <id>: ")
1311  << operands[1];
1312 
1313  typeMap[operands[0]] = spirv::SampledImageType::get(elementTy);
1314  return success();
1315 }
1316 
1317 //===----------------------------------------------------------------------===//
1318 // Constant
1319 //===----------------------------------------------------------------------===//
1320 
1321 LogicalResult spirv::Deserializer::processConstant(ArrayRef<uint32_t> operands,
1322  bool isSpec) {
1323  StringRef opname = isSpec ? "OpSpecConstant" : "OpConstant";
1324 
1325  if (operands.size() < 2) {
1326  return emitError(unknownLoc)
1327  << opname << " must have type <id> and result <id>";
1328  }
1329  if (operands.size() < 3) {
1330  return emitError(unknownLoc)
1331  << opname << " must have at least 1 more parameter";
1332  }
1333 
1334  Type resultType = getType(operands[0]);
1335  if (!resultType) {
1336  return emitError(unknownLoc, "undefined result type from <id> ")
1337  << operands[0];
1338  }
1339 
1340  auto checkOperandSizeForBitwidth = [&](unsigned bitwidth) -> LogicalResult {
1341  if (bitwidth == 64) {
1342  if (operands.size() == 4) {
1343  return success();
1344  }
1345  return emitError(unknownLoc)
1346  << opname << " should have 2 parameters for 64-bit values";
1347  }
1348  if (bitwidth <= 32) {
1349  if (operands.size() == 3) {
1350  return success();
1351  }
1352 
1353  return emitError(unknownLoc)
1354  << opname
1355  << " should have 1 parameter for values with no more than 32 bits";
1356  }
1357  return emitError(unknownLoc, "unsupported OpConstant bitwidth: ")
1358  << bitwidth;
1359  };
1360 
1361  auto resultID = operands[1];
1362 
1363  if (auto intType = dyn_cast<IntegerType>(resultType)) {
1364  auto bitwidth = intType.getWidth();
1365  if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1366  return failure();
1367  }
1368 
1369  APInt value;
1370  if (bitwidth == 64) {
1371  // 64-bit integers are represented with two SPIR-V words. According to
1372  // SPIR-V spec: "When the type’s bit width is larger than one word, the
1373  // literal’s low-order words appear first."
1374  struct DoubleWord {
1375  uint32_t word1;
1376  uint32_t word2;
1377  } words = {operands[2], operands[3]};
1378  value = APInt(64, llvm::bit_cast<uint64_t>(words), /*isSigned=*/true);
1379  } else if (bitwidth <= 32) {
1380  value = APInt(bitwidth, operands[2], /*isSigned=*/true,
1381  /*implicitTrunc=*/true);
1382  }
1383 
1384  auto attr = opBuilder.getIntegerAttr(intType, value);
1385 
1386  if (isSpec) {
1387  createSpecConstant(unknownLoc, resultID, attr);
1388  } else {
1389  // For normal constants, we just record the attribute (and its type) for
1390  // later materialization at use sites.
1391  constantMap.try_emplace(resultID, attr, intType);
1392  }
1393 
1394  return success();
1395  }
1396 
1397  if (auto floatType = dyn_cast<FloatType>(resultType)) {
1398  auto bitwidth = floatType.getWidth();
1399  if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1400  return failure();
1401  }
1402 
1403  APFloat value(0.f);
1404  if (floatType.isF64()) {
1405  // Double values are represented with two SPIR-V words. According to
1406  // SPIR-V spec: "When the type’s bit width is larger than one word, the
1407  // literal’s low-order words appear first."
1408  struct DoubleWord {
1409  uint32_t word1;
1410  uint32_t word2;
1411  } words = {operands[2], operands[3]};
1412  value = APFloat(llvm::bit_cast<double>(words));
1413  } else if (floatType.isF32()) {
1414  value = APFloat(llvm::bit_cast<float>(operands[2]));
1415  } else if (floatType.isF16()) {
1416  APInt data(16, operands[2]);
1417  value = APFloat(APFloat::IEEEhalf(), data);
1418  } else if (floatType.isBF16()) {
1419  APInt data(16, operands[2]);
1420  value = APFloat(APFloat::BFloat(), data);
1421  }
1422 
1423  auto attr = opBuilder.getFloatAttr(floatType, value);
1424  if (isSpec) {
1425  createSpecConstant(unknownLoc, resultID, attr);
1426  } else {
1427  // For normal constants, we just record the attribute (and its type) for
1428  // later materialization at use sites.
1429  constantMap.try_emplace(resultID, attr, floatType);
1430  }
1431 
1432  return success();
1433  }
1434 
1435  return emitError(unknownLoc, "OpConstant can only generate values of "
1436  "scalar integer or floating-point type");
1437 }
1438 
1439 LogicalResult spirv::Deserializer::processConstantBool(
1440  bool isTrue, ArrayRef<uint32_t> operands, bool isSpec) {
1441  if (operands.size() != 2) {
1442  return emitError(unknownLoc, "Op")
1443  << (isSpec ? "Spec" : "") << "Constant"
1444  << (isTrue ? "True" : "False")
1445  << " must have type <id> and result <id>";
1446  }
1447 
1448  auto attr = opBuilder.getBoolAttr(isTrue);
1449  auto resultID = operands[1];
1450  if (isSpec) {
1451  createSpecConstant(unknownLoc, resultID, attr);
1452  } else {
1453  // For normal constants, we just record the attribute (and its type) for
1454  // later materialization at use sites.
1455  constantMap.try_emplace(resultID, attr, opBuilder.getI1Type());
1456  }
1457 
1458  return success();
1459 }
1460 
1461 LogicalResult
1462 spirv::Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
1463  if (operands.size() < 2) {
1464  return emitError(unknownLoc,
1465  "OpConstantComposite must have type <id> and result <id>");
1466  }
1467  if (operands.size() < 3) {
1468  return emitError(unknownLoc,
1469  "OpConstantComposite must have at least 1 parameter");
1470  }
1471 
1472  Type resultType = getType(operands[0]);
1473  if (!resultType) {
1474  return emitError(unknownLoc, "undefined result type from <id> ")
1475  << operands[0];
1476  }
1477 
1478  SmallVector<Attribute, 4> elements;
1479  elements.reserve(operands.size() - 2);
1480  for (unsigned i = 2, e = operands.size(); i < e; ++i) {
1481  auto elementInfo = getConstant(operands[i]);
1482  if (!elementInfo) {
1483  return emitError(unknownLoc, "OpConstantComposite component <id> ")
1484  << operands[i] << " must come from a normal constant";
1485  }
1486  elements.push_back(elementInfo->first);
1487  }
1488 
1489  auto resultID = operands[1];
1490  if (auto shapedType = dyn_cast<ShapedType>(resultType)) {
1491  auto attr = DenseElementsAttr::get(shapedType, elements);
1492  // For normal constants, we just record the attribute (and its type) for
1493  // later materialization at use sites.
1494  constantMap.try_emplace(resultID, attr, shapedType);
1495  } else if (auto arrayType = dyn_cast<spirv::ArrayType>(resultType)) {
1496  auto attr = opBuilder.getArrayAttr(elements);
1497  constantMap.try_emplace(resultID, attr, resultType);
1498  } else {
1499  return emitError(unknownLoc, "unsupported OpConstantComposite type: ")
1500  << resultType;
1501  }
1502 
1503  return success();
1504 }
1505 
1506 LogicalResult
1507 spirv::Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) {
1508  if (operands.size() < 2) {
1509  return emitError(unknownLoc,
1510  "OpConstantComposite must have type <id> and result <id>");
1511  }
1512  if (operands.size() < 3) {
1513  return emitError(unknownLoc,
1514  "OpConstantComposite must have at least 1 parameter");
1515  }
1516 
1517  Type resultType = getType(operands[0]);
1518  if (!resultType) {
1519  return emitError(unknownLoc, "undefined result type from <id> ")
1520  << operands[0];
1521  }
1522 
1523  auto resultID = operands[1];
1524  auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
1525 
1526  SmallVector<Attribute, 4> elements;
1527  elements.reserve(operands.size() - 2);
1528  for (unsigned i = 2, e = operands.size(); i < e; ++i) {
1529  auto elementInfo = getSpecConstant(operands[i]);
1530  elements.push_back(SymbolRefAttr::get(elementInfo));
1531  }
1532 
1533  auto op = opBuilder.create<spirv::SpecConstantCompositeOp>(
1534  unknownLoc, TypeAttr::get(resultType), symName,
1535  opBuilder.getArrayAttr(elements));
1536  specConstCompositeMap[resultID] = op;
1537 
1538  return success();
1539 }
1540 
1541 LogicalResult
1542 spirv::Deserializer::processSpecConstantOperation(ArrayRef<uint32_t> operands) {
1543  if (operands.size() < 3)
1544  return emitError(unknownLoc, "OpConstantOperation must have type <id>, "
1545  "result <id>, and operand opcode");
1546 
1547  uint32_t resultTypeID = operands[0];
1548 
1549  if (!getType(resultTypeID))
1550  return emitError(unknownLoc, "undefined result type from <id> ")
1551  << resultTypeID;
1552 
1553  uint32_t resultID = operands[1];
1554  spirv::Opcode enclosedOpcode = static_cast<spirv::Opcode>(operands[2]);
1555  auto emplaceResult = specConstOperationMap.try_emplace(
1556  resultID,
1557  SpecConstOperationMaterializationInfo{
1558  enclosedOpcode, resultTypeID,
1559  SmallVector<uint32_t>{operands.begin() + 3, operands.end()}});
1560 
1561  if (!emplaceResult.second)
1562  return emitError(unknownLoc, "value with <id>: ")
1563  << resultID << " is probably defined before.";
1564 
1565  return success();
1566 }
1567 
1568 Value spirv::Deserializer::materializeSpecConstantOperation(
1569  uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID,
1570  ArrayRef<uint32_t> enclosedOpOperands) {
1571 
1572  Type resultType = getType(resultTypeID);
1573 
1574  // Instructions wrapped by OpSpecConstantOp need an ID for their
1575  // Deserializer::processOp<op_name>(...) to emit the corresponding SPIR-V
1576  // dialect wrapped op. For that purpose, a new value map is created and "fake"
1577  // ID in that map is assigned to the result of the enclosed instruction. Note
1578  // that there is no need to update this fake ID since we only need to
1579  // reference the created Value for the enclosed op from the spv::YieldOp
1580  // created later in this method (both of which are the only values in their
1581  // region: the SpecConstantOperation's region). If we encounter another
1582  // SpecConstantOperation in the module, we simply re-use the fake ID since the
1583  // previous Value assigned to it isn't visible in the current scope anyway.
1584  DenseMap<uint32_t, Value> newValueMap;
1585  llvm::SaveAndRestore valueMapGuard(valueMap, newValueMap);
1586  constexpr uint32_t fakeID = static_cast<uint32_t>(-3);
1587 
1588  SmallVector<uint32_t, 4> enclosedOpResultTypeAndOperands;
1589  enclosedOpResultTypeAndOperands.push_back(resultTypeID);
1590  enclosedOpResultTypeAndOperands.push_back(fakeID);
1591  enclosedOpResultTypeAndOperands.append(enclosedOpOperands.begin(),
1592  enclosedOpOperands.end());
1593 
1594  // Process enclosed instruction before creating the enclosing
1595  // specConstantOperation (and its region). This way, references to constants,
1596  // global variables, and spec constants will be materialized outside the new
1597  // op's region. For more info, see Deserializer::getValue's implementation.
1598  if (failed(
1599  processInstruction(enclosedOpcode, enclosedOpResultTypeAndOperands)))
1600  return Value();
1601 
1602  // Since the enclosed op is emitted in the current block, split it in a
1603  // separate new block.
1604  Block *enclosedBlock = curBlock->splitBlock(&curBlock->back());
1605 
1606  auto loc = createFileLineColLoc(opBuilder);
1607  auto specConstOperationOp =
1608  opBuilder.create<spirv::SpecConstantOperationOp>(loc, resultType);
1609 
1610  Region &body = specConstOperationOp.getBody();
1611  // Move the new block into SpecConstantOperation's body.
1612  body.getBlocks().splice(body.end(), curBlock->getParent()->getBlocks(),
1613  Region::iterator(enclosedBlock));
1614  Block &block = body.back();
1615 
1616  // RAII guard to reset the insertion point to the module's region after
1617  // deserializing the body of the specConstantOperation.
1618  OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
1619  opBuilder.setInsertionPointToEnd(&block);
1620 
1621  opBuilder.create<spirv::YieldOp>(loc, block.front().getResult(0));
1622  return specConstOperationOp.getResult();
1623 }
1624 
1625 LogicalResult
1626 spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
1627  if (operands.size() != 2) {
1628  return emitError(unknownLoc,
1629  "OpConstantNull must have type <id> and result <id>");
1630  }
1631 
1632  Type resultType = getType(operands[0]);
1633  if (!resultType) {
1634  return emitError(unknownLoc, "undefined result type from <id> ")
1635  << operands[0];
1636  }
1637 
1638  auto resultID = operands[1];
1639  if (resultType.isIntOrFloat() || isa<VectorType>(resultType)) {
1640  auto attr = opBuilder.getZeroAttr(resultType);
1641  // For normal constants, we just record the attribute (and its type) for
1642  // later materialization at use sites.
1643  constantMap.try_emplace(resultID, attr, resultType);
1644  return success();
1645  }
1646 
1647  return emitError(unknownLoc, "unsupported OpConstantNull type: ")
1648  << resultType;
1649 }
1650 
1651 //===----------------------------------------------------------------------===//
1652 // Control flow
1653 //===----------------------------------------------------------------------===//
1654 
1655 Block *spirv::Deserializer::getOrCreateBlock(uint32_t id) {
1656  if (auto *block = getBlock(id)) {
1657  LLVM_DEBUG(logger.startLine() << "[block] got exiting block for id = " << id
1658  << " @ " << block << "\n");
1659  return block;
1660  }
1661 
1662  // We don't know where this block will be placed finally (in a
1663  // spirv.mlir.selection or spirv.mlir.loop or function). Create it into the
1664  // function for now and sort out the proper place later.
1665  auto *block = curFunction->addBlock();
1666  LLVM_DEBUG(logger.startLine() << "[block] created block for id = " << id
1667  << " @ " << block << "\n");
1668  return blockMap[id] = block;
1669 }
1670 
1671 LogicalResult spirv::Deserializer::processBranch(ArrayRef<uint32_t> operands) {
1672  if (!curBlock) {
1673  return emitError(unknownLoc, "OpBranch must appear inside a block");
1674  }
1675 
1676  if (operands.size() != 1) {
1677  return emitError(unknownLoc, "OpBranch must take exactly one target label");
1678  }
1679 
1680  auto *target = getOrCreateBlock(operands[0]);
1681  auto loc = createFileLineColLoc(opBuilder);
1682  // The preceding instruction for the OpBranch instruction could be an
1683  // OpLoopMerge or an OpSelectionMerge instruction, in this case they will have
1684  // the same OpLine information.
1685  opBuilder.create<spirv::BranchOp>(loc, target);
1686 
1687  clearDebugLine();
1688  return success();
1689 }
1690 
1691 LogicalResult
1692 spirv::Deserializer::processBranchConditional(ArrayRef<uint32_t> operands) {
1693  if (!curBlock) {
1694  return emitError(unknownLoc,
1695  "OpBranchConditional must appear inside a block");
1696  }
1697 
1698  if (operands.size() != 3 && operands.size() != 5) {
1699  return emitError(unknownLoc,
1700  "OpBranchConditional must have condition, true label, "
1701  "false label, and optionally two branch weights");
1702  }
1703 
1704  auto condition = getValue(operands[0]);
1705  auto *trueBlock = getOrCreateBlock(operands[1]);
1706  auto *falseBlock = getOrCreateBlock(operands[2]);
1707 
1708  std::optional<std::pair<uint32_t, uint32_t>> weights;
1709  if (operands.size() == 5) {
1710  weights = std::make_pair(operands[3], operands[4]);
1711  }
1712  // The preceding instruction for the OpBranchConditional instruction could be
1713  // an OpSelectionMerge instruction, in this case they will have the same
1714  // OpLine information.
1715  auto loc = createFileLineColLoc(opBuilder);
1716  opBuilder.create<spirv::BranchConditionalOp>(
1717  loc, condition, trueBlock,
1718  /*trueArguments=*/ArrayRef<Value>(), falseBlock,
1719  /*falseArguments=*/ArrayRef<Value>(), weights);
1720 
1721  clearDebugLine();
1722  return success();
1723 }
1724 
1725 LogicalResult spirv::Deserializer::processLabel(ArrayRef<uint32_t> operands) {
1726  if (!curFunction) {
1727  return emitError(unknownLoc, "OpLabel must appear inside a function");
1728  }
1729 
1730  if (operands.size() != 1) {
1731  return emitError(unknownLoc, "OpLabel should only have result <id>");
1732  }
1733 
1734  auto labelID = operands[0];
1735  // We may have forward declared this block.
1736  auto *block = getOrCreateBlock(labelID);
1737  LLVM_DEBUG(logger.startLine()
1738  << "[block] populating block " << block << "\n");
1739  // If we have seen this block, make sure it was just a forward declaration.
1740  assert(block->empty() && "re-deserialize the same block!");
1741 
1742  opBuilder.setInsertionPointToStart(block);
1743  blockMap[labelID] = curBlock = block;
1744 
1745  return success();
1746 }
1747 
1748 LogicalResult
1749 spirv::Deserializer::processSelectionMerge(ArrayRef<uint32_t> operands) {
1750  if (!curBlock) {
1751  return emitError(unknownLoc, "OpSelectionMerge must appear in a block");
1752  }
1753 
1754  if (operands.size() < 2) {
1755  return emitError(
1756  unknownLoc,
1757  "OpSelectionMerge must specify merge target and selection control");
1758  }
1759 
1760  auto *mergeBlock = getOrCreateBlock(operands[0]);
1761  auto loc = createFileLineColLoc(opBuilder);
1762  auto selectionControl = operands[1];
1763 
1764  if (!blockMergeInfo.try_emplace(curBlock, loc, selectionControl, mergeBlock)
1765  .second) {
1766  return emitError(
1767  unknownLoc,
1768  "a block cannot have more than one OpSelectionMerge instruction");
1769  }
1770 
1771  return success();
1772 }
1773 
1774 LogicalResult
1775 spirv::Deserializer::processLoopMerge(ArrayRef<uint32_t> operands) {
1776  if (!curBlock) {
1777  return emitError(unknownLoc, "OpLoopMerge must appear in a block");
1778  }
1779 
1780  if (operands.size() < 3) {
1781  return emitError(unknownLoc, "OpLoopMerge must specify merge target, "
1782  "continue target and loop control");
1783  }
1784 
1785  auto *mergeBlock = getOrCreateBlock(operands[0]);
1786  auto *continueBlock = getOrCreateBlock(operands[1]);
1787  auto loc = createFileLineColLoc(opBuilder);
1788  uint32_t loopControl = operands[2];
1789 
1790  if (!blockMergeInfo
1791  .try_emplace(curBlock, loc, loopControl, mergeBlock, continueBlock)
1792  .second) {
1793  return emitError(
1794  unknownLoc,
1795  "a block cannot have more than one OpLoopMerge instruction");
1796  }
1797 
1798  return success();
1799 }
1800 
1801 LogicalResult spirv::Deserializer::processPhi(ArrayRef<uint32_t> operands) {
1802  if (!curBlock) {
1803  return emitError(unknownLoc, "OpPhi must appear in a block");
1804  }
1805 
1806  if (operands.size() < 4) {
1807  return emitError(unknownLoc, "OpPhi must specify result type, result <id>, "
1808  "and variable-parent pairs");
1809  }
1810 
1811  // Create a block argument for this OpPhi instruction.
1812  Type blockArgType = getType(operands[0]);
1813  BlockArgument blockArg = curBlock->addArgument(blockArgType, unknownLoc);
1814  valueMap[operands[1]] = blockArg;
1815  LLVM_DEBUG(logger.startLine()
1816  << "[phi] created block argument " << blockArg
1817  << " id = " << operands[1] << " of type " << blockArgType << "\n");
1818 
1819  // For each (value, predecessor) pair, insert the value to the predecessor's
1820  // blockPhiInfo entry so later we can fix the block argument there.
1821  for (unsigned i = 2, e = operands.size(); i < e; i += 2) {
1822  uint32_t value = operands[i];
1823  Block *predecessor = getOrCreateBlock(operands[i + 1]);
1824  std::pair<Block *, Block *> predecessorTargetPair{predecessor, curBlock};
1825  blockPhiInfo[predecessorTargetPair].push_back(value);
1826  LLVM_DEBUG(logger.startLine() << "[phi] predecessor @ " << predecessor
1827  << " with arg id = " << value << "\n");
1828  }
1829 
1830  return success();
1831 }
1832 
1833 namespace {
1834 /// A class for putting all blocks in a structured selection/loop in a
1835 /// spirv.mlir.selection/spirv.mlir.loop op.
1836 class ControlFlowStructurizer {
1837 public:
1838 #ifndef NDEBUG
1839  ControlFlowStructurizer(Location loc, uint32_t control,
1840  spirv::BlockMergeInfoMap &mergeInfo, Block *header,
1841  Block *merge, Block *cont,
1842  llvm::ScopedPrinter &logger)
1843  : location(loc), control(control), blockMergeInfo(mergeInfo),
1844  headerBlock(header), mergeBlock(merge), continueBlock(cont),
1845  logger(logger) {}
1846 #else
1847  ControlFlowStructurizer(Location loc, uint32_t control,
1848  spirv::BlockMergeInfoMap &mergeInfo, Block *header,
1849  Block *merge, Block *cont)
1850  : location(loc), control(control), blockMergeInfo(mergeInfo),
1851  headerBlock(header), mergeBlock(merge), continueBlock(cont) {}
1852 #endif
1853 
1854  /// Structurizes the loop at the given `headerBlock`.
1855  ///
1856  /// This method will create an spirv.mlir.loop op in the `mergeBlock` and move
1857  /// all blocks in the structured loop into the spirv.mlir.loop's region. All
1858  /// branches to the `headerBlock` will be redirected to the `mergeBlock`. This
1859  /// method will also update `mergeInfo` by remapping all blocks inside to the
1860  /// newly cloned ones inside structured control flow op's regions.
1861  LogicalResult structurize();
1862 
1863 private:
1864  /// Creates a new spirv.mlir.selection op at the beginning of the
1865  /// `mergeBlock`.
1866  spirv::SelectionOp createSelectionOp(uint32_t selectionControl);
1867 
1868  /// Creates a new spirv.mlir.loop op at the beginning of the `mergeBlock`.
1869  spirv::LoopOp createLoopOp(uint32_t loopControl);
1870 
1871  /// Collects all blocks reachable from `headerBlock` except `mergeBlock`.
1872  void collectBlocksInConstruct();
1873 
1874  Location location;
1875  uint32_t control;
1876 
1877  spirv::BlockMergeInfoMap &blockMergeInfo;
1878 
1879  Block *headerBlock;
1880  Block *mergeBlock;
1881  Block *continueBlock; // nullptr for spirv.mlir.selection
1882 
1883  SetVector<Block *> constructBlocks;
1884 
1885 #ifndef NDEBUG
1886  /// A logger used to emit information during the deserialzation process.
1887  llvm::ScopedPrinter &logger;
1888 #endif
1889 };
1890 } // namespace
1891 
1892 spirv::SelectionOp
1893 ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) {
1894  // Create a builder and set the insertion point to the beginning of the
1895  // merge block so that the newly created SelectionOp will be inserted there.
1896  OpBuilder builder(&mergeBlock->front());
1897 
1898  auto control = static_cast<spirv::SelectionControl>(selectionControl);
1899  auto selectionOp = builder.create<spirv::SelectionOp>(location, control);
1900  selectionOp.addMergeBlock(builder);
1901 
1902  return selectionOp;
1903 }
1904 
1905 spirv::LoopOp ControlFlowStructurizer::createLoopOp(uint32_t loopControl) {
1906  // Create a builder and set the insertion point to the beginning of the
1907  // merge block so that the newly created LoopOp will be inserted there.
1908  OpBuilder builder(&mergeBlock->front());
1909 
1910  auto control = static_cast<spirv::LoopControl>(loopControl);
1911  auto loopOp = builder.create<spirv::LoopOp>(location, control);
1912  loopOp.addEntryAndMergeBlock(builder);
1913 
1914  return loopOp;
1915 }
1916 
1917 void ControlFlowStructurizer::collectBlocksInConstruct() {
1918  assert(constructBlocks.empty() && "expected empty constructBlocks");
1919 
1920  // Put the header block in the work list first.
1921  constructBlocks.insert(headerBlock);
1922 
1923  // For each item in the work list, add its successors excluding the merge
1924  // block.
1925  for (unsigned i = 0; i < constructBlocks.size(); ++i) {
1926  for (auto *successor : constructBlocks[i]->getSuccessors())
1927  if (successor != mergeBlock)
1928  constructBlocks.insert(successor);
1929  }
1930 }
1931 
1932 LogicalResult ControlFlowStructurizer::structurize() {
1933  Operation *op = nullptr;
1934  bool isLoop = continueBlock != nullptr;
1935  if (isLoop) {
1936  if (auto loopOp = createLoopOp(control))
1937  op = loopOp.getOperation();
1938  } else {
1939  if (auto selectionOp = createSelectionOp(control))
1940  op = selectionOp.getOperation();
1941  }
1942  if (!op)
1943  return failure();
1944  Region &body = op->getRegion(0);
1945 
1946  IRMapping mapper;
1947  // All references to the old merge block should be directed to the
1948  // selection/loop merge block in the SelectionOp/LoopOp's region.
1949  mapper.map(mergeBlock, &body.back());
1950 
1951  collectBlocksInConstruct();
1952 
1953  // We've identified all blocks belonging to the selection/loop's region. Now
1954  // need to "move" them into the selection/loop. Instead of really moving the
1955  // blocks, in the following we copy them and remap all values and branches.
1956  // This is because:
1957  // * Inserting a block into a region requires the block not in any region
1958  // before. But selections/loops can nest so we can create selection/loop ops
1959  // in a nested manner, which means some blocks may already be in a
1960  // selection/loop region when to be moved again.
1961  // * It's much trickier to fix up the branches into and out of the loop's
1962  // region: we need to treat not-moved blocks and moved blocks differently:
1963  // Not-moved blocks jumping to the loop header block need to jump to the
1964  // merge point containing the new loop op but not the loop continue block's
1965  // back edge. Moved blocks jumping out of the loop need to jump to the
1966  // merge block inside the loop region but not other not-moved blocks.
1967  // We cannot use replaceAllUsesWith clearly and it's harder to follow the
1968  // logic.
1969 
1970  // Create a corresponding block in the SelectionOp/LoopOp's region for each
1971  // block in this loop construct.
1972  OpBuilder builder(body);
1973  for (auto *block : constructBlocks) {
1974  // Create a block and insert it before the selection/loop merge block in the
1975  // SelectionOp/LoopOp's region.
1976  auto *newBlock = builder.createBlock(&body.back());
1977  mapper.map(block, newBlock);
1978  LLVM_DEBUG(logger.startLine() << "[cf] cloned block " << newBlock
1979  << " from block " << block << "\n");
1980  if (!isFnEntryBlock(block)) {
1981  for (BlockArgument blockArg : block->getArguments()) {
1982  auto newArg =
1983  newBlock->addArgument(blockArg.getType(), blockArg.getLoc());
1984  mapper.map(blockArg, newArg);
1985  LLVM_DEBUG(logger.startLine() << "[cf] remapped block argument "
1986  << blockArg << " to " << newArg << "\n");
1987  }
1988  } else {
1989  LLVM_DEBUG(logger.startLine()
1990  << "[cf] block " << block << " is a function entry block\n");
1991  }
1992 
1993  for (auto &op : *block)
1994  newBlock->push_back(op.clone(mapper));
1995  }
1996 
1997  // Go through all ops and remap the operands.
1998  auto remapOperands = [&](Operation *op) {
1999  for (auto &operand : op->getOpOperands())
2000  if (Value mappedOp = mapper.lookupOrNull(operand.get()))
2001  operand.set(mappedOp);
2002  for (auto &succOp : op->getBlockOperands())
2003  if (Block *mappedOp = mapper.lookupOrNull(succOp.get()))
2004  succOp.set(mappedOp);
2005  };
2006  for (auto &block : body)
2007  block.walk(remapOperands);
2008 
2009  // We have created the SelectionOp/LoopOp and "moved" all blocks belonging to
2010  // the selection/loop construct into its region. Next we need to fix the
2011  // connections between this new SelectionOp/LoopOp with existing blocks.
2012 
2013  // All existing incoming branches should go to the merge block, where the
2014  // SelectionOp/LoopOp resides right now.
2015  headerBlock->replaceAllUsesWith(mergeBlock);
2016 
2017  LLVM_DEBUG({
2018  logger.startLine() << "[cf] after cloning and fixing references:\n";
2019  headerBlock->getParentOp()->print(logger.getOStream());
2020  logger.startLine() << "\n";
2021  });
2022 
2023  if (isLoop) {
2024  if (!mergeBlock->args_empty()) {
2025  return mergeBlock->getParentOp()->emitError(
2026  "OpPhi in loop merge block unsupported");
2027  }
2028 
2029  // The loop header block may have block arguments. Since now we place the
2030  // loop op inside the old merge block, we need to make sure the old merge
2031  // block has the same block argument list.
2032  for (BlockArgument blockArg : headerBlock->getArguments())
2033  mergeBlock->addArgument(blockArg.getType(), blockArg.getLoc());
2034 
2035  // If the loop header block has block arguments, make sure the spirv.Branch
2036  // op matches.
2037  SmallVector<Value, 4> blockArgs;
2038  if (!headerBlock->args_empty())
2039  blockArgs = {mergeBlock->args_begin(), mergeBlock->args_end()};
2040 
2041  // The loop entry block should have a unconditional branch jumping to the
2042  // loop header block.
2043  builder.setInsertionPointToEnd(&body.front());
2044  builder.create<spirv::BranchOp>(location, mapper.lookupOrNull(headerBlock),
2045  ArrayRef<Value>(blockArgs));
2046  }
2047 
2048  // Values defined inside the selection region that need to be yielded outside
2049  // the region.
2050  SmallVector<Value> valuesToYield;
2051  // Outside uses of values that were sunk into the selection region. Those uses
2052  // will be replaced with values returned by the SelectionOp.
2053  SmallVector<Value> outsideUses;
2054 
2055  // Move block arguments of the original block (`mergeBlock`) into the merge
2056  // block inside the selection (`body.back()`). Values produced by block
2057  // arguments will be yielded by the selection region. We do not update uses or
2058  // erase original block arguments yet. It will be done later in the code.
2059  //
2060  // Code below is not executed for loops as it would interfere with the logic
2061  // above. Currently block arguments in the merge block are not supported, but
2062  // instead, the code above copies those arguments from the header block into
2063  // the merge block. As such, running the code would yield those copied
2064  // arguments that is most likely not a desired behaviour. This may need to be
2065  // revisited in the future.
2066  if (!isLoop)
2067  for (BlockArgument blockArg : mergeBlock->getArguments()) {
2068  // Create new block arguments in the last block ("merge block") of the
2069  // selection region. We create one argument for each argument in
2070  // `mergeBlock`. This new value will need to be yielded, and the original
2071  // value replaced, so add them to appropriate vectors.
2072  body.back().addArgument(blockArg.getType(), blockArg.getLoc());
2073  valuesToYield.push_back(body.back().getArguments().back());
2074  outsideUses.push_back(blockArg);
2075  }
2076 
2077  // All the blocks cloned into the SelectionOp/LoopOp's region can now be
2078  // cleaned up.
2079  LLVM_DEBUG(logger.startLine() << "[cf] cleaning up blocks after clone\n");
2080  // First we need to drop all operands' references inside all blocks. This is
2081  // needed because we can have blocks referencing SSA values from one another.
2082  for (auto *block : constructBlocks)
2083  block->dropAllReferences();
2084 
2085  // All internal uses should be removed from original blocks by now, so
2086  // whatever is left is an outside use and will need to be yielded from
2087  // the newly created selection / loop region.
2088  for (Block *block : constructBlocks) {
2089  for (Operation &op : *block) {
2090  if (!op.use_empty())
2091  for (Value result : op.getResults()) {
2092  valuesToYield.push_back(mapper.lookupOrNull(result));
2093  outsideUses.push_back(result);
2094  }
2095  }
2096  for (BlockArgument &arg : block->getArguments()) {
2097  if (!arg.use_empty()) {
2098  valuesToYield.push_back(mapper.lookupOrNull(arg));
2099  outsideUses.push_back(arg);
2100  }
2101  }
2102  }
2103 
2104  assert(valuesToYield.size() == outsideUses.size());
2105 
2106  // If we need to yield any values from the selection / loop region we will
2107  // take care of it here.
2108  if (!valuesToYield.empty()) {
2109  LLVM_DEBUG(logger.startLine()
2110  << "[cf] yielding values from the selection / loop region\n");
2111 
2112  // Update `mlir.merge` with values to be yield.
2113  auto mergeOps = body.back().getOps<spirv::MergeOp>();
2114  Operation *merge = llvm::getSingleElement(mergeOps);
2115  assert(merge);
2116  merge->setOperands(valuesToYield);
2117 
2118  // MLIR does not allow changing the number of results of an operation, so
2119  // we create a new SelectionOp / LoopOp with required list of results and
2120  // move the region from the initial SelectionOp / LoopOp. The initial
2121  // operation is then removed. Since we move the region to the new op all
2122  // links between blocks and remapping we have previously done should be
2123  // preserved.
2124  builder.setInsertionPoint(&mergeBlock->front());
2125 
2126  Operation *newOp = nullptr;
2127 
2128  if (isLoop)
2129  newOp = builder.create<spirv::LoopOp>(
2130  location, TypeRange(ValueRange(outsideUses)),
2131  static_cast<spirv::LoopControl>(control));
2132  else
2133  newOp = builder.create<spirv::SelectionOp>(
2134  location, TypeRange(ValueRange(outsideUses)),
2135  static_cast<spirv::SelectionControl>(control));
2136 
2137  newOp->getRegion(0).takeBody(body);
2138 
2139  // Remove initial op and swap the pointer to the newly created one.
2140  op->erase();
2141  op = newOp;
2142 
2143  // Update all outside uses to use results of the SelectionOp / LoopOp and
2144  // remove block arguments from the original merge block.
2145  for (unsigned i = 0, e = outsideUses.size(); i != e; ++i)
2146  outsideUses[i].replaceAllUsesWith(op->getResult(i));
2147 
2148  // We do not support block arguments in loop merge block. Also running this
2149  // function with loop would break some of the loop specific code above
2150  // dealing with block arguments.
2151  if (!isLoop)
2152  mergeBlock->eraseArguments(0, mergeBlock->getNumArguments());
2153  }
2154 
2155  // Check that whether some op in the to-be-erased blocks still has uses. Those
2156  // uses come from blocks that won't be sinked into the SelectionOp/LoopOp's
2157  // region. We cannot handle such cases given that once a value is sinked into
2158  // the SelectionOp/LoopOp's region, there is no escape for it.
2159  for (auto *block : constructBlocks) {
2160  for (Operation &op : *block)
2161  if (!op.use_empty())
2162  return op.emitOpError("failed control flow structurization: value has "
2163  "uses outside of the "
2164  "enclosing selection/loop construct");
2165  for (BlockArgument &arg : block->getArguments())
2166  if (!arg.use_empty())
2167  return emitError(arg.getLoc(), "failed control flow structurization: "
2168  "block argument has uses outside of the "
2169  "enclosing selection/loop construct");
2170  }
2171 
2172  // Then erase all old blocks.
2173  for (auto *block : constructBlocks) {
2174  // We've cloned all blocks belonging to this construct into the structured
2175  // control flow op's region. Among these blocks, some may compose another
2176  // selection/loop. If so, they will be recorded within blockMergeInfo.
2177  // We need to update the pointers there to the newly remapped ones so we can
2178  // continue structurizing them later.
2179  //
2180  // We need to walk each block as constructBlocks do not include blocks
2181  // internal to ops already structured within those blocks. It is not
2182  // fully clear to me why the mergeInfo of blocks (yet to be structured)
2183  // inside already structured selections/loops get invalidated and needs
2184  // updating, however the following example code can cause a crash (depending
2185  // on the structuring order), when the most inner selection is being
2186  // structured after the outer selection and loop have been already
2187  // structured:
2188  //
2189  // spirv.mlir.for {
2190  // // ...
2191  // spirv.mlir.selection {
2192  // // ..
2193  // // A selection region that hasn't been yet structured!
2194  // // ..
2195  // }
2196  // // ...
2197  // }
2198  //
2199  // If the loop gets structured after the outer selection, but before the
2200  // inner selection. Moving the already structured selection inside the loop
2201  // will invalidate the mergeInfo of the region that is not yet structured.
2202  // Just going over constructBlocks will not check and updated header blocks
2203  // inside the already structured selection region. Walking block fixes that.
2204  //
2205  // TODO: If structuring was done in a fixed order starting with inner
2206  // most constructs this most likely not be an issue and the whole code
2207  // section could be removed. However, with the current non-deterministic
2208  // order this is not possible.
2209  //
2210  // TODO: The asserts in the following assumes input SPIR-V blob forms
2211  // correctly nested selection/loop constructs. We should relax this and
2212  // support error cases better.
2213  auto updateMergeInfo = [&](Block *block) -> WalkResult {
2214  auto it = blockMergeInfo.find(block);
2215  if (it != blockMergeInfo.end()) {
2216  // Use the original location for nested selection/loop ops.
2217  Location loc = it->second.loc;
2218 
2219  Block *newHeader = mapper.lookupOrNull(block);
2220  if (!newHeader)
2221  return emitError(loc, "failed control flow structurization: nested "
2222  "loop header block should be remapped!");
2223 
2224  Block *newContinue = it->second.continueBlock;
2225  if (newContinue) {
2226  newContinue = mapper.lookupOrNull(newContinue);
2227  if (!newContinue)
2228  return emitError(loc, "failed control flow structurization: nested "
2229  "loop continue block should be remapped!");
2230  }
2231 
2232  Block *newMerge = it->second.mergeBlock;
2233  if (Block *mappedTo = mapper.lookupOrNull(newMerge))
2234  newMerge = mappedTo;
2235 
2236  // The iterator should be erased before adding a new entry into
2237  // blockMergeInfo to avoid iterator invalidation.
2238  blockMergeInfo.erase(it);
2239  blockMergeInfo.try_emplace(newHeader, loc, it->second.control, newMerge,
2240  newContinue);
2241  }
2242 
2243  return WalkResult::advance();
2244  };
2245 
2246  if (block->walk(updateMergeInfo).wasInterrupted())
2247  return failure();
2248 
2249  // The structured selection/loop's entry block does not have arguments.
2250  // If the function's header block is also part of the structured control
2251  // flow, we cannot just simply erase it because it may contain arguments
2252  // matching the function signature and used by the cloned blocks.
2253  if (isFnEntryBlock(block)) {
2254  LLVM_DEBUG(logger.startLine() << "[cf] changing entry block " << block
2255  << " to only contain a spirv.Branch op\n");
2256  // Still keep the function entry block for the potential block arguments,
2257  // but replace all ops inside with a branch to the merge block.
2258  block->clear();
2259  builder.setInsertionPointToEnd(block);
2260  builder.create<spirv::BranchOp>(location, mergeBlock);
2261  } else {
2262  LLVM_DEBUG(logger.startLine() << "[cf] erasing block " << block << "\n");
2263  block->erase();
2264  }
2265  }
2266 
2267  LLVM_DEBUG(logger.startLine()
2268  << "[cf] after structurizing construct with header block "
2269  << headerBlock << ":\n"
2270  << *op << "\n");
2271 
2272  return success();
2273 }
2274 
2275 LogicalResult spirv::Deserializer::wireUpBlockArgument() {
2276  LLVM_DEBUG({
2277  logger.startLine()
2278  << "//----- [phi] start wiring up block arguments -----//\n";
2279  logger.indent();
2280  });
2281 
2282  OpBuilder::InsertionGuard guard(opBuilder);
2283 
2284  for (const auto &info : blockPhiInfo) {
2285  Block *block = info.first.first;
2286  Block *target = info.first.second;
2287  const BlockPhiInfo &phiInfo = info.second;
2288  LLVM_DEBUG({
2289  logger.startLine() << "[phi] block " << block << "\n";
2290  logger.startLine() << "[phi] before creating block argument:\n";
2291  block->getParentOp()->print(logger.getOStream());
2292  logger.startLine() << "\n";
2293  });
2294 
2295  // Set insertion point to before this block's terminator early because we
2296  // may materialize ops via getValue() call.
2297  auto *op = block->getTerminator();
2298  opBuilder.setInsertionPoint(op);
2299 
2300  SmallVector<Value, 4> blockArgs;
2301  blockArgs.reserve(phiInfo.size());
2302  for (uint32_t valueId : phiInfo) {
2303  if (Value value = getValue(valueId)) {
2304  blockArgs.push_back(value);
2305  LLVM_DEBUG(logger.startLine() << "[phi] block argument " << value
2306  << " id = " << valueId << "\n");
2307  } else {
2308  return emitError(unknownLoc, "OpPhi references undefined value!");
2309  }
2310  }
2311 
2312  if (auto branchOp = dyn_cast<spirv::BranchOp>(op)) {
2313  // Replace the previous branch op with a new one with block arguments.
2314  opBuilder.create<spirv::BranchOp>(branchOp.getLoc(), branchOp.getTarget(),
2315  blockArgs);
2316  branchOp.erase();
2317  } else if (auto branchCondOp = dyn_cast<spirv::BranchConditionalOp>(op)) {
2318  assert((branchCondOp.getTrueBlock() == target ||
2319  branchCondOp.getFalseBlock() == target) &&
2320  "expected target to be either the true or false target");
2321  if (target == branchCondOp.getTrueTarget())
2322  opBuilder.create<spirv::BranchConditionalOp>(
2323  branchCondOp.getLoc(), branchCondOp.getCondition(), blockArgs,
2324  branchCondOp.getFalseBlockArguments(),
2325  branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueTarget(),
2326  branchCondOp.getFalseTarget());
2327  else
2328  opBuilder.create<spirv::BranchConditionalOp>(
2329  branchCondOp.getLoc(), branchCondOp.getCondition(),
2330  branchCondOp.getTrueBlockArguments(), blockArgs,
2331  branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueBlock(),
2332  branchCondOp.getFalseBlock());
2333 
2334  branchCondOp.erase();
2335  } else {
2336  return emitError(unknownLoc, "unimplemented terminator for Phi creation");
2337  }
2338 
2339  LLVM_DEBUG({
2340  logger.startLine() << "[phi] after creating block argument:\n";
2341  block->getParentOp()->print(logger.getOStream());
2342  logger.startLine() << "\n";
2343  });
2344  }
2345  blockPhiInfo.clear();
2346 
2347  LLVM_DEBUG({
2348  logger.unindent();
2349  logger.startLine()
2350  << "//--- [phi] completed wiring up block arguments ---//\n";
2351  });
2352  return success();
2353 }
2354 
2355 LogicalResult spirv::Deserializer::splitConditionalBlocks() {
2356  // Create a copy, so we can modify keys in the original.
2357  BlockMergeInfoMap blockMergeInfoCopy = blockMergeInfo;
2358  for (auto it = blockMergeInfoCopy.begin(), e = blockMergeInfoCopy.end();
2359  it != e; ++it) {
2360  auto &[block, mergeInfo] = *it;
2361 
2362  // Skip processing loop regions. For loop regions continueBlock is non-null.
2363  if (mergeInfo.continueBlock)
2364  continue;
2365 
2366  if (!block->mightHaveTerminator())
2367  continue;
2368 
2369  Operation *terminator = block->getTerminator();
2370  assert(terminator);
2371 
2372  if (!isa<spirv::BranchConditionalOp>(terminator))
2373  continue;
2374 
2375  // Check if the current header block is a merge block of another construct.
2376  bool splitHeaderMergeBlock = false;
2377  for (const auto &[_, mergeInfo] : blockMergeInfo) {
2378  if (mergeInfo.mergeBlock == block)
2379  splitHeaderMergeBlock = true;
2380  }
2381 
2382  // Do not split a block that only contains a conditional branch, unless it
2383  // is also a merge block of another construct - in that case we want to
2384  // split the block. We do not want two constructs to share header / merge
2385  // block.
2386  if (!llvm::hasSingleElement(*block) || splitHeaderMergeBlock) {
2387  Block *newBlock = block->splitBlock(terminator);
2388  OpBuilder builder(block, block->end());
2389  builder.create<spirv::BranchOp>(block->getParent()->getLoc(), newBlock);
2390 
2391  // After splitting we need to update the map to use the new block as a
2392  // header.
2393  blockMergeInfo.erase(block);
2394  blockMergeInfo.try_emplace(newBlock, mergeInfo);
2395  }
2396  }
2397 
2398  return success();
2399 }
2400 
2401 LogicalResult spirv::Deserializer::structurizeControlFlow() {
2402  if (!options.enableControlFlowStructurization) {
2403  LLVM_DEBUG(
2404  {
2405  logger.startLine()
2406  << "//----- [cf] skip structurizing control flow -----//\n";
2407  logger.indent();
2408  });
2409  return success();
2410  }
2411 
2412  LLVM_DEBUG({
2413  logger.startLine()
2414  << "//----- [cf] start structurizing control flow -----//\n";
2415  logger.indent();
2416  });
2417 
2418  LLVM_DEBUG({
2419  logger.startLine() << "[cf] split conditional blocks\n";
2420  logger.startLine() << "\n";
2421  });
2422 
2423  if (failed(splitConditionalBlocks())) {
2424  return failure();
2425  }
2426 
2427  // TODO: This loop is non-deterministic. Iteration order may vary between runs
2428  // for the same shader as the key to the map is a pointer. See:
2429  // https://github.com/llvm/llvm-project/issues/128547
2430  while (!blockMergeInfo.empty()) {
2431  Block *headerBlock = blockMergeInfo.begin()->first;
2432  BlockMergeInfo mergeInfo = blockMergeInfo.begin()->second;
2433 
2434  LLVM_DEBUG({
2435  logger.startLine() << "[cf] header block " << headerBlock << ":\n";
2436  headerBlock->print(logger.getOStream());
2437  logger.startLine() << "\n";
2438  });
2439 
2440  auto *mergeBlock = mergeInfo.mergeBlock;
2441  assert(mergeBlock && "merge block cannot be nullptr");
2442  if (mergeInfo.continueBlock && !mergeBlock->args_empty())
2443  return emitError(unknownLoc, "OpPhi in loop merge block unimplemented");
2444  LLVM_DEBUG({
2445  logger.startLine() << "[cf] merge block " << mergeBlock << ":\n";
2446  mergeBlock->print(logger.getOStream());
2447  logger.startLine() << "\n";
2448  });
2449 
2450  auto *continueBlock = mergeInfo.continueBlock;
2451  LLVM_DEBUG(if (continueBlock) {
2452  logger.startLine() << "[cf] continue block " << continueBlock << ":\n";
2453  continueBlock->print(logger.getOStream());
2454  logger.startLine() << "\n";
2455  });
2456  // Erase this case before calling into structurizer, who will update
2457  // blockMergeInfo.
2458  blockMergeInfo.erase(blockMergeInfo.begin());
2459  ControlFlowStructurizer structurizer(mergeInfo.loc, mergeInfo.control,
2460  blockMergeInfo, headerBlock,
2461  mergeBlock, continueBlock
2462 #ifndef NDEBUG
2463  ,
2464  logger
2465 #endif
2466  );
2467  if (failed(structurizer.structurize()))
2468  return failure();
2469  }
2470 
2471  LLVM_DEBUG({
2472  logger.unindent();
2473  logger.startLine()
2474  << "//--- [cf] completed structurizing control flow ---//\n";
2475  });
2476  return success();
2477 }
2478 
2479 //===----------------------------------------------------------------------===//
2480 // Debug
2481 //===----------------------------------------------------------------------===//
2482 
2483 Location spirv::Deserializer::createFileLineColLoc(OpBuilder opBuilder) {
2484  if (!debugLine)
2485  return unknownLoc;
2486 
2487  auto fileName = debugInfoMap.lookup(debugLine->fileID).str();
2488  if (fileName.empty())
2489  fileName = "<unknown>";
2490  return FileLineColLoc::get(opBuilder.getStringAttr(fileName), debugLine->line,
2491  debugLine->column);
2492 }
2493 
2494 LogicalResult
2495 spirv::Deserializer::processDebugLine(ArrayRef<uint32_t> operands) {
2496  // According to SPIR-V spec:
2497  // "This location information applies to the instructions physically
2498  // following this instruction, up to the first occurrence of any of the
2499  // following: the next end of block, the next OpLine instruction, or the next
2500  // OpNoLine instruction."
2501  if (operands.size() != 3)
2502  return emitError(unknownLoc, "OpLine must have 3 operands");
2503  debugLine = DebugLine{operands[0], operands[1], operands[2]};
2504  return success();
2505 }
2506 
2507 void spirv::Deserializer::clearDebugLine() { debugLine = std::nullopt; }
2508 
2509 LogicalResult
2510 spirv::Deserializer::processDebugString(ArrayRef<uint32_t> operands) {
2511  if (operands.size() < 2)
2512  return emitError(unknownLoc, "OpString needs at least 2 operands");
2513 
2514  if (!debugInfoMap.lookup(operands[0]).empty())
2515  return emitError(unknownLoc,
2516  "duplicate debug string found for result <id> ")
2517  << operands[0];
2518 
2519  unsigned wordIndex = 1;
2520  StringRef debugString = decodeStringLiteral(operands, wordIndex);
2521  if (wordIndex != operands.size())
2522  return emitError(unknownLoc,
2523  "unexpected trailing words in OpString instruction");
2524 
2525  debugInfoMap[operands[0]] = debugString;
2526  return success();
2527 }
static bool isLoop(Operation *op)
Returns true if the given operation represents a loop by testing whether it implements the LoopLikeOp...
static bool isFnEntryBlock(Block *block)
Returns true if the given block is a function entry block.
#define MIN_VERSION_CASE(v)
LogicalResult deserializeCacheControlDecoration(Location loc, OpBuilder &opBuilder, DenseMap< uint32_t, NamedAttrList > &decorations, ArrayRef< uint32_t > words, StringAttr symbol, StringRef decorationName, StringRef cacheControlKind)
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
int64_t rows
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:309
Location getLoc() const
Return the location for this argument.
Definition: Value.h:324
Block represents an ordered list of Operations.
Definition: Block.h:33
bool empty()
Definition: Block.h:148
Operation & back()
Definition: Block.h:152
void erase()
Unlink this Block from its parent region and delete it.
Definition: Block.cpp:68
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Definition: Block.cpp:310
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:29
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
void print(raw_ostream &os)
bool mightHaveTerminator()
Check whether this block might have a terminator.
Definition: Block.cpp:252
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
iterator end()
Definition: Block.h:144
iterator begin()
Definition: Block.h:143
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Definition: Block.cpp:38
void push_back(Operation *op)
Definition: Block.h:149
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:33
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:260
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:264
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition: Builders.h:96
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static FileLineColLoc get(StringAttr filename, unsigned line, unsigned column)
Definition: Location.cpp:161
A symbol reference with a reference path containing a single element.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
auto lookupOrNull(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:58
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:164
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool use_empty()
Returns true if this operation has no uses.
Definition: Operation.h:852
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Definition: Operation.cpp:719
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Definition: Operation.cpp:67
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:686
MutableArrayRef< BlockOperand > getBlockOperands()
Definition: Operation.h:695
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:383
result_range getResults()
Definition: Operation.h:415
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:673
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:539
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Block & back()
Definition: Region.h:64
BlockListType & getBlocks()
Definition: Region.h:45
Location getLoc()
Return a location for this region.
Definition: Region.cpp:31
BlockListType::iterator iterator
Definition: Region.h:52
void takeBody(Region &other)
Takes body of another region (that region will have no body after this operation completes).
Definition: Region.h:241
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:116
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: Visitors.h:33
static WalkResult advance()
Definition: Visitors.h:51
static ArrayType get(Type elementType, unsigned elementCount)
Definition: SPIRVTypes.cpp:52
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
Definition: SPIRVTypes.cpp:235
LogicalResult deserialize()
Deserializes the remembered SPIR-V binary module.
Deserializer(ArrayRef< uint32_t > binary, MLIRContext *context, const DeserializationOptions &options)
Creates a deserializer for the given SPIR-V binary module.
OwningOpRef< spirv::ModuleOp > collect()
Collects the final SPIR-V ModuleOp.
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
Definition: SPIRVTypes.h:168
static MatrixType get(Type columnType, uint32_t columnCount)
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:427
static RuntimeArrayType get(Type elementType)
Definition: SPIRVTypes.cpp:484
static SampledImageType get(Type imageType)
Definition: SPIRVTypes.cpp:767
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Construct a literal StructType with at least one member.
Definition: SPIRVTypes.cpp:996
static VerCapExtAttr get(Version version, ArrayRef< Capability > capabilities, ArrayRef< Extension > extensions, MLIRContext *context)
Gets a VerCapExtAttr instance.
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition: CallGraph.h:229
constexpr uint32_t kMagicNumber
SPIR-V magic number.
StringRef decodeStringLiteral(ArrayRef< uint32_t > words, unsigned &wordIndex)
Decodes a string literal in words starting at wordIndex.
DenseMap< Block *, BlockMergeInfo > BlockMergeInfoMap
Map from a selection/loop's header block to its merge (and continue) target.
Definition: Deserializer.h:61
constexpr unsigned kHeaderWordCount
SPIR-V binary header word count.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
static std::string debugString(T &&op)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This represents an operation in an abstracted form, suitable for use with the builder APIs.