MLIR  21.0.0git
Deserializer.cpp
Go to the documentation of this file.
1 //===- Deserializer.cpp - MLIR SPIR-V Deserializer ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the SPIR-V binary to MLIR SPIR-V module deserializer.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Deserializer.h"
14 
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/IRMapping.h"
21 #include "mlir/IR/Location.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/Sequence.h"
25 #include "llvm/ADT/SmallVector.h"
26 #include "llvm/ADT/StringExtras.h"
27 #include "llvm/ADT/bit.h"
28 #include "llvm/Support/Debug.h"
29 #include "llvm/Support/SaveAndRestore.h"
30 #include "llvm/Support/raw_ostream.h"
31 #include <optional>
32 
33 using namespace mlir;
34 
35 #define DEBUG_TYPE "spirv-deserialization"
36 
37 //===----------------------------------------------------------------------===//
38 // Utility Functions
39 //===----------------------------------------------------------------------===//
40 
41 /// Returns true if the given `block` is a function entry block.
42 static inline bool isFnEntryBlock(Block *block) {
43  return block->isEntryBlock() &&
44  isa_and_nonnull<spirv::FuncOp>(block->getParentOp());
45 }
46 
47 //===----------------------------------------------------------------------===//
48 // Deserializer Method Definitions
49 //===----------------------------------------------------------------------===//
50 
52  MLIRContext *context,
54  : binary(binary), context(context), unknownLoc(UnknownLoc::get(context)),
55  module(createModuleOp()), opBuilder(module->getRegion()), options(options)
56 #ifndef NDEBUG
57  ,
58  logger(llvm::dbgs())
59 #endif
60 {
61 }
62 
64  LLVM_DEBUG({
65  logger.resetIndent();
66  logger.startLine()
67  << "//+++---------- start deserialization ----------+++//\n";
68  });
69 
70  if (failed(processHeader()))
71  return failure();
72 
73  spirv::Opcode opcode = spirv::Opcode::OpNop;
74  ArrayRef<uint32_t> operands;
75  auto binarySize = binary.size();
76  while (curOffset < binarySize) {
77  // Slice the next instruction out and populate `opcode` and `operands`.
78  // Internally this also updates `curOffset`.
79  if (failed(sliceInstruction(opcode, operands)))
80  return failure();
81 
82  if (failed(processInstruction(opcode, operands)))
83  return failure();
84  }
85 
86  assert(curOffset == binarySize &&
87  "deserializer should never index beyond the binary end");
88 
89  for (auto &deferred : deferredInstructions) {
90  if (failed(processInstruction(deferred.first, deferred.second, false))) {
91  return failure();
92  }
93  }
94 
95  attachVCETriple();
96 
97  LLVM_DEBUG(logger.startLine()
98  << "//+++-------- completed deserialization --------+++//\n");
99  return success();
100 }
101 
103  return std::move(module);
104 }
105 
106 //===----------------------------------------------------------------------===//
107 // Module structure
108 //===----------------------------------------------------------------------===//
109 
110 OwningOpRef<spirv::ModuleOp> spirv::Deserializer::createModuleOp() {
111  OpBuilder builder(context);
112  OperationState state(unknownLoc, spirv::ModuleOp::getOperationName());
113  spirv::ModuleOp::build(builder, state);
114  return cast<spirv::ModuleOp>(Operation::create(state));
115 }
116 
117 LogicalResult spirv::Deserializer::processHeader() {
118  if (binary.size() < spirv::kHeaderWordCount)
119  return emitError(unknownLoc,
120  "SPIR-V binary module must have a 5-word header");
121 
122  if (binary[0] != spirv::kMagicNumber)
123  return emitError(unknownLoc, "incorrect magic number");
124 
125  // Version number bytes: 0 | major number | minor number | 0
126  uint32_t majorVersion = (binary[1] << 8) >> 24;
127  uint32_t minorVersion = (binary[1] << 16) >> 24;
128  if (majorVersion == 1) {
129  switch (minorVersion) {
130 #define MIN_VERSION_CASE(v) \
131  case v: \
132  version = spirv::Version::V_1_##v; \
133  break
134 
135  MIN_VERSION_CASE(0);
136  MIN_VERSION_CASE(1);
137  MIN_VERSION_CASE(2);
138  MIN_VERSION_CASE(3);
139  MIN_VERSION_CASE(4);
140  MIN_VERSION_CASE(5);
141 #undef MIN_VERSION_CASE
142  default:
143  return emitError(unknownLoc, "unsupported SPIR-V minor version: ")
144  << minorVersion;
145  }
146  } else {
147  return emitError(unknownLoc, "unsupported SPIR-V major version: ")
148  << majorVersion;
149  }
150 
151  // TODO: generator number, bound, schema
152  curOffset = spirv::kHeaderWordCount;
153  return success();
154 }
155 
156 LogicalResult
157 spirv::Deserializer::processCapability(ArrayRef<uint32_t> operands) {
158  if (operands.size() != 1)
159  return emitError(unknownLoc, "OpCapability must have one parameter");
160 
161  auto cap = spirv::symbolizeCapability(operands[0]);
162  if (!cap)
163  return emitError(unknownLoc, "unknown capability: ") << operands[0];
164 
165  capabilities.insert(*cap);
166  return success();
167 }
168 
169 LogicalResult spirv::Deserializer::processExtension(ArrayRef<uint32_t> words) {
170  if (words.empty()) {
171  return emitError(
172  unknownLoc,
173  "OpExtension must have a literal string for the extension name");
174  }
175 
176  unsigned wordIndex = 0;
177  StringRef extName = decodeStringLiteral(words, wordIndex);
178  if (wordIndex != words.size())
179  return emitError(unknownLoc,
180  "unexpected trailing words in OpExtension instruction");
181  auto ext = spirv::symbolizeExtension(extName);
182  if (!ext)
183  return emitError(unknownLoc, "unknown extension: ") << extName;
184 
185  extensions.insert(*ext);
186  return success();
187 }
188 
189 LogicalResult
190 spirv::Deserializer::processExtInstImport(ArrayRef<uint32_t> words) {
191  if (words.size() < 2) {
192  return emitError(unknownLoc,
193  "OpExtInstImport must have a result <id> and a literal "
194  "string for the extended instruction set name");
195  }
196 
197  unsigned wordIndex = 1;
198  extendedInstSets[words[0]] = decodeStringLiteral(words, wordIndex);
199  if (wordIndex != words.size()) {
200  return emitError(unknownLoc,
201  "unexpected trailing words in OpExtInstImport");
202  }
203  return success();
204 }
205 
206 void spirv::Deserializer::attachVCETriple() {
207  (*module)->setAttr(
208  spirv::ModuleOp::getVCETripleAttrName(),
209  spirv::VerCapExtAttr::get(version, capabilities.getArrayRef(),
210  extensions.getArrayRef(), context));
211 }
212 
213 LogicalResult
214 spirv::Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
215  if (operands.size() != 2)
216  return emitError(unknownLoc, "OpMemoryModel must have two operands");
217 
218  (*module)->setAttr(
219  module->getAddressingModelAttrName(),
220  opBuilder.getAttr<spirv::AddressingModelAttr>(
221  static_cast<spirv::AddressingModel>(operands.front())));
222 
223  (*module)->setAttr(module->getMemoryModelAttrName(),
224  opBuilder.getAttr<spirv::MemoryModelAttr>(
225  static_cast<spirv::MemoryModel>(operands.back())));
226 
227  return success();
228 }
229 
230 template <typename AttrTy, typename EnumAttrTy, typename EnumTy>
232  Location loc, OpBuilder &opBuilder,
234  StringAttr symbol, StringRef decorationName, StringRef cacheControlKind) {
235  if (words.size() != 4) {
236  return emitError(loc, "OpDecoration with ")
237  << decorationName << "needs a cache control integer literal and a "
238  << cacheControlKind << " cache control literal";
239  }
240  unsigned cacheLevel = words[2];
241  auto cacheControlAttr = static_cast<EnumTy>(words[3]);
242  auto value = opBuilder.getAttr<AttrTy>(cacheLevel, cacheControlAttr);
244  if (auto attrList =
245  llvm::dyn_cast_or_null<ArrayAttr>(decorations[words[0]].get(symbol)))
246  llvm::append_range(attrs, attrList);
247  attrs.push_back(value);
248  decorations[words[0]].set(symbol, opBuilder.getArrayAttr(attrs));
249  return success();
250 }
251 
252 LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
253  // TODO: This function should also be auto-generated. For now, since only a
254  // few decorations are processed/handled in a meaningful manner, going with a
255  // manual implementation.
256  if (words.size() < 2) {
257  return emitError(
258  unknownLoc, "OpDecorate must have at least result <id> and Decoration");
259  }
260  auto decorationName =
261  stringifyDecoration(static_cast<spirv::Decoration>(words[1]));
262  if (decorationName.empty()) {
263  return emitError(unknownLoc, "invalid Decoration code : ") << words[1];
264  }
265  auto symbol = getSymbolDecoration(decorationName);
266  switch (static_cast<spirv::Decoration>(words[1])) {
267  case spirv::Decoration::FPFastMathMode:
268  if (words.size() != 3) {
269  return emitError(unknownLoc, "OpDecorate with ")
270  << decorationName << " needs a single integer literal";
271  }
272  decorations[words[0]].set(
273  symbol, FPFastMathModeAttr::get(opBuilder.getContext(),
274  static_cast<FPFastMathMode>(words[2])));
275  break;
276  case spirv::Decoration::FPRoundingMode:
277  if (words.size() != 3) {
278  return emitError(unknownLoc, "OpDecorate with ")
279  << decorationName << " needs a single integer literal";
280  }
281  decorations[words[0]].set(
282  symbol, FPRoundingModeAttr::get(opBuilder.getContext(),
283  static_cast<FPRoundingMode>(words[2])));
284  break;
285  case spirv::Decoration::DescriptorSet:
286  case spirv::Decoration::Binding:
287  if (words.size() != 3) {
288  return emitError(unknownLoc, "OpDecorate with ")
289  << decorationName << " needs a single integer literal";
290  }
291  decorations[words[0]].set(
292  symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));
293  break;
294  case spirv::Decoration::BuiltIn:
295  if (words.size() != 3) {
296  return emitError(unknownLoc, "OpDecorate with ")
297  << decorationName << " needs a single integer literal";
298  }
299  decorations[words[0]].set(
300  symbol, opBuilder.getStringAttr(
301  stringifyBuiltIn(static_cast<spirv::BuiltIn>(words[2]))));
302  break;
303  case spirv::Decoration::ArrayStride:
304  if (words.size() != 3) {
305  return emitError(unknownLoc, "OpDecorate with ")
306  << decorationName << " needs a single integer literal";
307  }
308  typeDecorations[words[0]] = words[2];
309  break;
310  case spirv::Decoration::LinkageAttributes: {
311  if (words.size() < 4) {
312  return emitError(unknownLoc, "OpDecorate with ")
313  << decorationName
314  << " needs at least 1 string and 1 integer literal";
315  }
316  // LinkageAttributes has two parameters ["linkageName", linkageType]
317  // e.g., OpDecorate %imported_func LinkageAttributes "outside.func" Import
318  // "linkageName" is a stringliteral encoded as uint32_t,
319  // hence the size of name is variable length which results in words.size()
320  // being variable length, words.size() = 3 + strlen(name)/4 + 1 or
321  // 3 + ceildiv(strlen(name), 4).
322  unsigned wordIndex = 2;
323  auto linkageName = spirv::decodeStringLiteral(words, wordIndex).str();
324  auto linkageTypeAttr = opBuilder.getAttr<::mlir::spirv::LinkageTypeAttr>(
325  static_cast<::mlir::spirv::LinkageType>(words[wordIndex++]));
326  auto linkageAttr = opBuilder.getAttr<::mlir::spirv::LinkageAttributesAttr>(
327  StringAttr::get(context, linkageName), linkageTypeAttr);
328  decorations[words[0]].set(symbol, llvm::dyn_cast<Attribute>(linkageAttr));
329  break;
330  }
331  case spirv::Decoration::Aliased:
332  case spirv::Decoration::AliasedPointer:
333  case spirv::Decoration::Block:
334  case spirv::Decoration::BufferBlock:
335  case spirv::Decoration::Flat:
336  case spirv::Decoration::NonReadable:
337  case spirv::Decoration::NonWritable:
338  case spirv::Decoration::NoPerspective:
339  case spirv::Decoration::NoSignedWrap:
340  case spirv::Decoration::NoUnsignedWrap:
341  case spirv::Decoration::RelaxedPrecision:
342  case spirv::Decoration::Restrict:
343  case spirv::Decoration::RestrictPointer:
344  case spirv::Decoration::NoContraction:
345  case spirv::Decoration::Constant:
346  if (words.size() != 2) {
347  return emitError(unknownLoc, "OpDecoration with ")
348  << decorationName << "needs a single target <id>";
349  }
350  // Block decoration does not affect spirv.struct type, but is still stored
351  // for verification.
352  // TODO: Update StructType to contain this information since
353  // it is needed for many validation rules.
354  decorations[words[0]].set(symbol, opBuilder.getUnitAttr());
355  break;
356  case spirv::Decoration::Location:
357  case spirv::Decoration::SpecId:
358  if (words.size() != 3) {
359  return emitError(unknownLoc, "OpDecoration with ")
360  << decorationName << "needs a single integer literal";
361  }
362  decorations[words[0]].set(
363  symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));
364  break;
365  case spirv::Decoration::CacheControlLoadINTEL: {
366  LogicalResult res = deserializeCacheControlDecoration<
367  CacheControlLoadINTELAttr, LoadCacheControlAttr, LoadCacheControl>(
368  unknownLoc, opBuilder, decorations, words, symbol, decorationName,
369  "load");
370  if (failed(res))
371  return res;
372  break;
373  }
374  case spirv::Decoration::CacheControlStoreINTEL: {
375  LogicalResult res = deserializeCacheControlDecoration<
376  CacheControlStoreINTELAttr, StoreCacheControlAttr, StoreCacheControl>(
377  unknownLoc, opBuilder, decorations, words, symbol, decorationName,
378  "store");
379  if (failed(res))
380  return res;
381  break;
382  }
383  default:
384  return emitError(unknownLoc, "unhandled Decoration : '") << decorationName;
385  }
386  return success();
387 }
388 
389 LogicalResult
390 spirv::Deserializer::processMemberDecoration(ArrayRef<uint32_t> words) {
391  // The binary layout of OpMemberDecorate is different comparing to OpDecorate
392  if (words.size() < 3) {
393  return emitError(unknownLoc,
394  "OpMemberDecorate must have at least 3 operands");
395  }
396 
397  auto decoration = static_cast<spirv::Decoration>(words[2]);
398  if (decoration == spirv::Decoration::Offset && words.size() != 4) {
399  return emitError(unknownLoc,
400  " missing offset specification in OpMemberDecorate with "
401  "Offset decoration");
402  }
403  ArrayRef<uint32_t> decorationOperands;
404  if (words.size() > 3) {
405  decorationOperands = words.slice(3);
406  }
407  memberDecorationMap[words[0]][words[1]][decoration] = decorationOperands;
408  return success();
409 }
410 
411 LogicalResult spirv::Deserializer::processMemberName(ArrayRef<uint32_t> words) {
412  if (words.size() < 3) {
413  return emitError(unknownLoc, "OpMemberName must have at least 3 operands");
414  }
415  unsigned wordIndex = 2;
416  auto name = decodeStringLiteral(words, wordIndex);
417  if (wordIndex != words.size()) {
418  return emitError(unknownLoc,
419  "unexpected trailing words in OpMemberName instruction");
420  }
421  memberNameMap[words[0]][words[1]] = name;
422  return success();
423 }
424 
425 LogicalResult spirv::Deserializer::setFunctionArgAttrs(
426  uint32_t argID, SmallVectorImpl<Attribute> &argAttrs, size_t argIndex) {
427  if (!decorations.contains(argID)) {
428  argAttrs[argIndex] = DictionaryAttr::get(context, {});
429  return success();
430  }
431 
432  spirv::DecorationAttr foundDecorationAttr;
433  for (NamedAttribute decAttr : decorations[argID]) {
434  for (auto decoration :
435  {spirv::Decoration::Aliased, spirv::Decoration::Restrict,
436  spirv::Decoration::AliasedPointer,
437  spirv::Decoration::RestrictPointer}) {
438 
439  if (decAttr.getName() !=
440  getSymbolDecoration(stringifyDecoration(decoration)))
441  continue;
442 
443  if (foundDecorationAttr)
444  return emitError(unknownLoc,
445  "more than one Aliased/Restrict decorations for "
446  "function argument with result <id> ")
447  << argID;
448 
449  foundDecorationAttr = spirv::DecorationAttr::get(context, decoration);
450  break;
451  }
452 
453  if (decAttr.getName() == getSymbolDecoration(stringifyDecoration(
454  spirv::Decoration::RelaxedPrecision))) {
455  // TODO: Current implementation supports only one decoration per function
456  // parameter so RelaxedPrecision cannot be applied at the same time as,
457  // for example, Aliased/Restrict/etc. This should be relaxed to allow any
458  // combination of decoration allowed by the spec to be supported.
459  if (foundDecorationAttr)
460  return emitError(unknownLoc, "already found a decoration for function "
461  "argument with result <id> ")
462  << argID;
463 
464  foundDecorationAttr = spirv::DecorationAttr::get(
465  context, spirv::Decoration::RelaxedPrecision);
466  }
467  }
468 
469  if (!foundDecorationAttr)
470  return emitError(unknownLoc, "unimplemented decoration support for "
471  "function argument with result <id> ")
472  << argID;
473 
474  NamedAttribute attr(StringAttr::get(context, spirv::DecorationAttr::name),
475  foundDecorationAttr);
476  argAttrs[argIndex] = DictionaryAttr::get(context, attr);
477  return success();
478 }
479 
480 LogicalResult
481 spirv::Deserializer::processFunction(ArrayRef<uint32_t> operands) {
482  if (curFunction) {
483  return emitError(unknownLoc, "found function inside function");
484  }
485 
486  // Get the result type
487  if (operands.size() != 4) {
488  return emitError(unknownLoc, "OpFunction must have 4 parameters");
489  }
490  Type resultType = getType(operands[0]);
491  if (!resultType) {
492  return emitError(unknownLoc, "undefined result type from <id> ")
493  << operands[0];
494  }
495 
496  uint32_t fnID = operands[1];
497  if (funcMap.count(fnID)) {
498  return emitError(unknownLoc, "duplicate function definition/declaration");
499  }
500 
501  auto fnControl = spirv::symbolizeFunctionControl(operands[2]);
502  if (!fnControl) {
503  return emitError(unknownLoc, "unknown Function Control: ") << operands[2];
504  }
505 
506  Type fnType = getType(operands[3]);
507  if (!fnType || !isa<FunctionType>(fnType)) {
508  return emitError(unknownLoc, "unknown function type from <id> ")
509  << operands[3];
510  }
511  auto functionType = cast<FunctionType>(fnType);
512 
513  if ((isVoidType(resultType) && functionType.getNumResults() != 0) ||
514  (functionType.getNumResults() == 1 &&
515  functionType.getResult(0) != resultType)) {
516  return emitError(unknownLoc, "mismatch in function type ")
517  << functionType << " and return type " << resultType << " specified";
518  }
519 
520  std::string fnName = getFunctionSymbol(fnID);
521  auto funcOp = opBuilder.create<spirv::FuncOp>(
522  unknownLoc, fnName, functionType, fnControl.value());
523  // Processing other function attributes.
524  if (decorations.count(fnID)) {
525  for (auto attr : decorations[fnID].getAttrs()) {
526  funcOp->setAttr(attr.getName(), attr.getValue());
527  }
528  }
529  curFunction = funcMap[fnID] = funcOp;
530  auto *entryBlock = funcOp.addEntryBlock();
531  LLVM_DEBUG({
532  logger.startLine()
533  << "//===-------------------------------------------===//\n";
534  logger.startLine() << "[fn] name: " << fnName << "\n";
535  logger.startLine() << "[fn] type: " << fnType << "\n";
536  logger.startLine() << "[fn] ID: " << fnID << "\n";
537  logger.startLine() << "[fn] entry block: " << entryBlock << "\n";
538  logger.indent();
539  });
540 
541  SmallVector<Attribute> argAttrs;
542  argAttrs.resize(functionType.getNumInputs());
543 
544  // Parse the op argument instructions
545  if (functionType.getNumInputs()) {
546  for (size_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
547  auto argType = functionType.getInput(i);
548  spirv::Opcode opcode = spirv::Opcode::OpNop;
549  ArrayRef<uint32_t> operands;
550  if (failed(sliceInstruction(opcode, operands,
551  spirv::Opcode::OpFunctionParameter))) {
552  return failure();
553  }
554  if (opcode != spirv::Opcode::OpFunctionParameter) {
555  return emitError(
556  unknownLoc,
557  "missing OpFunctionParameter instruction for argument ")
558  << i;
559  }
560  if (operands.size() != 2) {
561  return emitError(
562  unknownLoc,
563  "expected result type and result <id> for OpFunctionParameter");
564  }
565  auto argDefinedType = getType(operands[0]);
566  if (!argDefinedType || argDefinedType != argType) {
567  return emitError(unknownLoc,
568  "mismatch in argument type between function type "
569  "definition ")
570  << functionType << " and argument type definition "
571  << argDefinedType << " at argument " << i;
572  }
573  if (getValue(operands[1])) {
574  return emitError(unknownLoc, "duplicate definition of result <id> ")
575  << operands[1];
576  }
577  if (failed(setFunctionArgAttrs(operands[1], argAttrs, i))) {
578  return failure();
579  }
580 
581  auto argValue = funcOp.getArgument(i);
582  valueMap[operands[1]] = argValue;
583  }
584  }
585 
586  if (llvm::any_of(argAttrs, [](Attribute attr) {
587  auto argAttr = cast<DictionaryAttr>(attr);
588  return !argAttr.empty();
589  }))
590  funcOp.setArgAttrsAttr(ArrayAttr::get(context, argAttrs));
591 
592  // entryBlock is needed to access the arguments, Once that is done, we can
593  // erase the block for functions with 'Import' LinkageAttributes, since these
594  // are essentially function declarations, so they have no body.
595  auto linkageAttr = funcOp.getLinkageAttributes();
596  auto hasImportLinkage =
597  linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
598  spirv::LinkageType::Import);
599  if (hasImportLinkage)
600  funcOp.eraseBody();
601 
602  // RAII guard to reset the insertion point to the module's region after
603  // deserializing the body of this function.
604  OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
605 
606  spirv::Opcode opcode = spirv::Opcode::OpNop;
607  ArrayRef<uint32_t> instOperands;
608 
609  // Special handling for the entry block. We need to make sure it starts with
610  // an OpLabel instruction. The entry block takes the same parameters as the
611  // function. All other blocks do not take any parameter. We have already
612  // created the entry block, here we need to register it to the correct label
613  // <id>.
614  if (failed(sliceInstruction(opcode, instOperands,
615  spirv::Opcode::OpFunctionEnd))) {
616  return failure();
617  }
618  if (opcode == spirv::Opcode::OpFunctionEnd) {
619  return processFunctionEnd(instOperands);
620  }
621  if (opcode != spirv::Opcode::OpLabel) {
622  return emitError(unknownLoc, "a basic block must start with OpLabel");
623  }
624  if (instOperands.size() != 1) {
625  return emitError(unknownLoc, "OpLabel should only have result <id>");
626  }
627  blockMap[instOperands[0]] = entryBlock;
628  if (failed(processLabel(instOperands))) {
629  return failure();
630  }
631 
632  // Then process all the other instructions in the function until we hit
633  // OpFunctionEnd.
634  while (succeeded(sliceInstruction(opcode, instOperands,
635  spirv::Opcode::OpFunctionEnd)) &&
636  opcode != spirv::Opcode::OpFunctionEnd) {
637  if (failed(processInstruction(opcode, instOperands))) {
638  return failure();
639  }
640  }
641  if (opcode != spirv::Opcode::OpFunctionEnd) {
642  return failure();
643  }
644 
645  return processFunctionEnd(instOperands);
646 }
647 
648 LogicalResult
649 spirv::Deserializer::processFunctionEnd(ArrayRef<uint32_t> operands) {
650  // Process OpFunctionEnd.
651  if (!operands.empty()) {
652  return emitError(unknownLoc, "unexpected operands for OpFunctionEnd");
653  }
654 
655  // Wire up block arguments from OpPhi instructions.
656  // Put all structured control flow in spirv.mlir.selection/spirv.mlir.loop
657  // ops.
658  if (failed(wireUpBlockArgument()) || failed(structurizeControlFlow())) {
659  return failure();
660  }
661 
662  curBlock = nullptr;
663  curFunction = std::nullopt;
664 
665  LLVM_DEBUG({
666  logger.unindent();
667  logger.startLine()
668  << "//===-------------------------------------------===//\n";
669  });
670  return success();
671 }
672 
673 std::optional<std::pair<Attribute, Type>>
674 spirv::Deserializer::getConstant(uint32_t id) {
675  auto constIt = constantMap.find(id);
676  if (constIt == constantMap.end())
677  return std::nullopt;
678  return constIt->getSecond();
679 }
680 
681 std::optional<spirv::SpecConstOperationMaterializationInfo>
682 spirv::Deserializer::getSpecConstantOperation(uint32_t id) {
683  auto constIt = specConstOperationMap.find(id);
684  if (constIt == specConstOperationMap.end())
685  return std::nullopt;
686  return constIt->getSecond();
687 }
688 
689 std::string spirv::Deserializer::getFunctionSymbol(uint32_t id) {
690  auto funcName = nameMap.lookup(id).str();
691  if (funcName.empty()) {
692  funcName = "spirv_fn_" + std::to_string(id);
693  }
694  return funcName;
695 }
696 
697 std::string spirv::Deserializer::getSpecConstantSymbol(uint32_t id) {
698  auto constName = nameMap.lookup(id).str();
699  if (constName.empty()) {
700  constName = "spirv_spec_const_" + std::to_string(id);
701  }
702  return constName;
703 }
704 
705 spirv::SpecConstantOp
706 spirv::Deserializer::createSpecConstant(Location loc, uint32_t resultID,
707  TypedAttr defaultValue) {
708  auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
709  auto op = opBuilder.create<spirv::SpecConstantOp>(unknownLoc, symName,
710  defaultValue);
711  if (decorations.count(resultID)) {
712  for (auto attr : decorations[resultID].getAttrs())
713  op->setAttr(attr.getName(), attr.getValue());
714  }
715  specConstMap[resultID] = op;
716  return op;
717 }
718 
719 LogicalResult
720 spirv::Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
721  unsigned wordIndex = 0;
722  if (operands.size() < 3) {
723  return emitError(
724  unknownLoc,
725  "OpVariable needs at least 3 operands, type, <id> and storage class");
726  }
727 
728  // Result Type.
729  auto type = getType(operands[wordIndex]);
730  if (!type) {
731  return emitError(unknownLoc, "unknown result type <id> : ")
732  << operands[wordIndex];
733  }
734  auto ptrType = dyn_cast<spirv::PointerType>(type);
735  if (!ptrType) {
736  return emitError(unknownLoc,
737  "expected a result type <id> to be a spirv.ptr, found : ")
738  << type;
739  }
740  wordIndex++;
741 
742  // Result <id>.
743  auto variableID = operands[wordIndex];
744  auto variableName = nameMap.lookup(variableID).str();
745  if (variableName.empty()) {
746  variableName = "spirv_var_" + std::to_string(variableID);
747  }
748  wordIndex++;
749 
750  // Storage class.
751  auto storageClass = static_cast<spirv::StorageClass>(operands[wordIndex]);
752  if (ptrType.getStorageClass() != storageClass) {
753  return emitError(unknownLoc, "mismatch in storage class of pointer type ")
754  << type << " and that specified in OpVariable instruction : "
755  << stringifyStorageClass(storageClass);
756  }
757  wordIndex++;
758 
759  // Initializer.
760  FlatSymbolRefAttr initializer = nullptr;
761 
762  if (wordIndex < operands.size()) {
763  Operation *op = nullptr;
764 
765  if (auto initOp = getGlobalVariable(operands[wordIndex]))
766  op = initOp;
767  else if (auto initOp = getSpecConstant(operands[wordIndex]))
768  op = initOp;
769  else if (auto initOp = getSpecConstantComposite(operands[wordIndex]))
770  op = initOp;
771  else
772  return emitError(unknownLoc, "unknown <id> ")
773  << operands[wordIndex] << "used as initializer";
774 
775  initializer = SymbolRefAttr::get(op);
776  wordIndex++;
777  }
778  if (wordIndex != operands.size()) {
779  return emitError(unknownLoc,
780  "found more operands than expected when deserializing "
781  "OpVariable instruction, only ")
782  << wordIndex << " of " << operands.size() << " processed";
783  }
784  auto loc = createFileLineColLoc(opBuilder);
785  auto varOp = opBuilder.create<spirv::GlobalVariableOp>(
786  loc, TypeAttr::get(type), opBuilder.getStringAttr(variableName),
787  initializer);
788 
789  // Decorations.
790  if (decorations.count(variableID)) {
791  for (auto attr : decorations[variableID].getAttrs())
792  varOp->setAttr(attr.getName(), attr.getValue());
793  }
794  globalVariableMap[variableID] = varOp;
795  return success();
796 }
797 
798 IntegerAttr spirv::Deserializer::getConstantInt(uint32_t id) {
799  auto constInfo = getConstant(id);
800  if (!constInfo) {
801  return nullptr;
802  }
803  return dyn_cast<IntegerAttr>(constInfo->first);
804 }
805 
806 LogicalResult spirv::Deserializer::processName(ArrayRef<uint32_t> operands) {
807  if (operands.size() < 2) {
808  return emitError(unknownLoc, "OpName needs at least 2 operands");
809  }
810  if (!nameMap.lookup(operands[0]).empty()) {
811  return emitError(unknownLoc, "duplicate name found for result <id> ")
812  << operands[0];
813  }
814  unsigned wordIndex = 1;
815  StringRef name = decodeStringLiteral(operands, wordIndex);
816  if (wordIndex != operands.size()) {
817  return emitError(unknownLoc,
818  "unexpected trailing words in OpName instruction");
819  }
820  nameMap[operands[0]] = name;
821  return success();
822 }
823 
824 //===----------------------------------------------------------------------===//
825 // Type
826 //===----------------------------------------------------------------------===//
827 
828 LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
829  ArrayRef<uint32_t> operands) {
830  if (operands.empty()) {
831  return emitError(unknownLoc, "type instruction with opcode ")
832  << spirv::stringifyOpcode(opcode) << " needs at least one <id>";
833  }
834 
835  /// TODO: Types might be forward declared in some instructions and need to be
836  /// handled appropriately.
837  if (typeMap.count(operands[0])) {
838  return emitError(unknownLoc, "duplicate definition for result <id> ")
839  << operands[0];
840  }
841 
842  switch (opcode) {
843  case spirv::Opcode::OpTypeVoid:
844  if (operands.size() != 1)
845  return emitError(unknownLoc, "OpTypeVoid must have no parameters");
846  typeMap[operands[0]] = opBuilder.getNoneType();
847  break;
848  case spirv::Opcode::OpTypeBool:
849  if (operands.size() != 1)
850  return emitError(unknownLoc, "OpTypeBool must have no parameters");
851  typeMap[operands[0]] = opBuilder.getI1Type();
852  break;
853  case spirv::Opcode::OpTypeInt: {
854  if (operands.size() != 3)
855  return emitError(
856  unknownLoc, "OpTypeInt must have bitwidth and signedness parameters");
857 
858  // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics
859  // to preserve or validate.
860  // 0 indicates unsigned, or no signedness semantics
861  // 1 indicates signed semantics."
862  //
863  // So we cannot differentiate signless and unsigned integers; always use
864  // signless semantics for such cases.
865  auto sign = operands[2] == 1 ? IntegerType::SignednessSemantics::Signed
866  : IntegerType::SignednessSemantics::Signless;
867  typeMap[operands[0]] = IntegerType::get(context, operands[1], sign);
868  } break;
869  case spirv::Opcode::OpTypeFloat: {
870  if (operands.size() != 2)
871  return emitError(unknownLoc, "OpTypeFloat must have bitwidth parameter");
872 
873  Type floatTy;
874  switch (operands[1]) {
875  case 16:
876  floatTy = opBuilder.getF16Type();
877  break;
878  case 32:
879  floatTy = opBuilder.getF32Type();
880  break;
881  case 64:
882  floatTy = opBuilder.getF64Type();
883  break;
884  default:
885  return emitError(unknownLoc, "unsupported OpTypeFloat bitwidth: ")
886  << operands[1];
887  }
888  typeMap[operands[0]] = floatTy;
889  } break;
890  case spirv::Opcode::OpTypeVector: {
891  if (operands.size() != 3) {
892  return emitError(
893  unknownLoc,
894  "OpTypeVector must have element type and count parameters");
895  }
896  Type elementTy = getType(operands[1]);
897  if (!elementTy) {
898  return emitError(unknownLoc, "OpTypeVector references undefined <id> ")
899  << operands[1];
900  }
901  typeMap[operands[0]] = VectorType::get({operands[2]}, elementTy);
902  } break;
903  case spirv::Opcode::OpTypePointer: {
904  return processOpTypePointer(operands);
905  } break;
906  case spirv::Opcode::OpTypeArray:
907  return processArrayType(operands);
908  case spirv::Opcode::OpTypeCooperativeMatrixKHR:
909  return processCooperativeMatrixTypeKHR(operands);
910  case spirv::Opcode::OpTypeFunction:
911  return processFunctionType(operands);
912  case spirv::Opcode::OpTypeImage:
913  return processImageType(operands);
914  case spirv::Opcode::OpTypeSampledImage:
915  return processSampledImageType(operands);
916  case spirv::Opcode::OpTypeRuntimeArray:
917  return processRuntimeArrayType(operands);
918  case spirv::Opcode::OpTypeStruct:
919  return processStructType(operands);
920  case spirv::Opcode::OpTypeMatrix:
921  return processMatrixType(operands);
922  default:
923  return emitError(unknownLoc, "unhandled type instruction");
924  }
925  return success();
926 }
927 
928 LogicalResult
929 spirv::Deserializer::processOpTypePointer(ArrayRef<uint32_t> operands) {
930  if (operands.size() != 3)
931  return emitError(unknownLoc, "OpTypePointer must have two parameters");
932 
933  auto pointeeType = getType(operands[2]);
934  if (!pointeeType)
935  return emitError(unknownLoc, "unknown OpTypePointer pointee type <id> ")
936  << operands[2];
937 
938  uint32_t typePointerID = operands[0];
939  auto storageClass = static_cast<spirv::StorageClass>(operands[1]);
940  typeMap[typePointerID] = spirv::PointerType::get(pointeeType, storageClass);
941 
942  for (auto *deferredStructIt = std::begin(deferredStructTypesInfos);
943  deferredStructIt != std::end(deferredStructTypesInfos);) {
944  for (auto *unresolvedMemberIt =
945  std::begin(deferredStructIt->unresolvedMemberTypes);
946  unresolvedMemberIt !=
947  std::end(deferredStructIt->unresolvedMemberTypes);) {
948  if (unresolvedMemberIt->first == typePointerID) {
949  // The newly constructed pointer type can resolve one of the
950  // deferred struct type members; update the memberTypes list and
951  // clean the unresolvedMemberTypes list accordingly.
952  deferredStructIt->memberTypes[unresolvedMemberIt->second] =
953  typeMap[typePointerID];
954  unresolvedMemberIt =
955  deferredStructIt->unresolvedMemberTypes.erase(unresolvedMemberIt);
956  } else {
957  ++unresolvedMemberIt;
958  }
959  }
960 
961  if (deferredStructIt->unresolvedMemberTypes.empty()) {
962  // All deferred struct type members are now resolved, set the struct body.
963  auto structType = deferredStructIt->deferredStructType;
964 
965  assert(structType && "expected a spirv::StructType");
966  assert(structType.isIdentified() && "expected an indentified struct");
967 
968  if (failed(structType.trySetBody(
969  deferredStructIt->memberTypes, deferredStructIt->offsetInfo,
970  deferredStructIt->memberDecorationsInfo)))
971  return failure();
972 
973  deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt);
974  } else {
975  ++deferredStructIt;
976  }
977  }
978 
979  return success();
980 }
981 
982 LogicalResult
983 spirv::Deserializer::processArrayType(ArrayRef<uint32_t> operands) {
984  if (operands.size() != 3) {
985  return emitError(unknownLoc,
986  "OpTypeArray must have element type and count parameters");
987  }
988 
989  Type elementTy = getType(operands[1]);
990  if (!elementTy) {
991  return emitError(unknownLoc, "OpTypeArray references undefined <id> ")
992  << operands[1];
993  }
994 
995  unsigned count = 0;
996  // TODO: The count can also come frome a specialization constant.
997  auto countInfo = getConstant(operands[2]);
998  if (!countInfo) {
999  return emitError(unknownLoc, "OpTypeArray count <id> ")
1000  << operands[2] << "can only come from normal constant right now";
1001  }
1002 
1003  if (auto intVal = dyn_cast<IntegerAttr>(countInfo->first)) {
1004  count = intVal.getValue().getZExtValue();
1005  } else {
1006  return emitError(unknownLoc, "OpTypeArray count must come from a "
1007  "scalar integer constant instruction");
1008  }
1009 
1010  typeMap[operands[0]] = spirv::ArrayType::get(
1011  elementTy, count, typeDecorations.lookup(operands[0]));
1012  return success();
1013 }
1014 
1015 LogicalResult
1016 spirv::Deserializer::processFunctionType(ArrayRef<uint32_t> operands) {
1017  assert(!operands.empty() && "No operands for processing function type");
1018  if (operands.size() == 1) {
1019  return emitError(unknownLoc, "missing return type for OpTypeFunction");
1020  }
1021  auto returnType = getType(operands[1]);
1022  if (!returnType) {
1023  return emitError(unknownLoc, "unknown return type in OpTypeFunction");
1024  }
1025  SmallVector<Type, 1> argTypes;
1026  for (size_t i = 2, e = operands.size(); i < e; ++i) {
1027  auto ty = getType(operands[i]);
1028  if (!ty) {
1029  return emitError(unknownLoc, "unknown argument type in OpTypeFunction");
1030  }
1031  argTypes.push_back(ty);
1032  }
1033  ArrayRef<Type> returnTypes;
1034  if (!isVoidType(returnType)) {
1035  returnTypes = llvm::ArrayRef(returnType);
1036  }
1037  typeMap[operands[0]] = FunctionType::get(context, argTypes, returnTypes);
1038  return success();
1039 }
1040 
1041 LogicalResult spirv::Deserializer::processCooperativeMatrixTypeKHR(
1042  ArrayRef<uint32_t> operands) {
1043  if (operands.size() != 6) {
1044  return emitError(unknownLoc,
1045  "OpTypeCooperativeMatrixKHR must have element type, "
1046  "scope, row and column parameters, and use");
1047  }
1048 
1049  Type elementTy = getType(operands[1]);
1050  if (!elementTy) {
1051  return emitError(unknownLoc,
1052  "OpTypeCooperativeMatrixKHR references undefined <id> ")
1053  << operands[1];
1054  }
1055 
1056  std::optional<spirv::Scope> scope =
1057  spirv::symbolizeScope(getConstantInt(operands[2]).getInt());
1058  if (!scope) {
1059  return emitError(
1060  unknownLoc,
1061  "OpTypeCooperativeMatrixKHR references undefined scope <id> ")
1062  << operands[2];
1063  }
1064 
1065  IntegerAttr rowsAttr = getConstantInt(operands[3]);
1066  IntegerAttr columnsAttr = getConstantInt(operands[4]);
1067  IntegerAttr useAttr = getConstantInt(operands[5]);
1068 
1069  if (!rowsAttr)
1070  return emitError(unknownLoc, "OpTypeCooperativeMatrixKHR `Rows` references "
1071  "undefined constant <id> ")
1072  << operands[3];
1073 
1074  if (!columnsAttr)
1075  return emitError(unknownLoc, "OpTypeCooperativeMatrixKHR `Columns` "
1076  "references undefined constant <id> ")
1077  << operands[4];
1078 
1079  if (!useAttr)
1080  return emitError(unknownLoc, "OpTypeCooperativeMatrixKHR `Use` references "
1081  "undefined constant <id> ")
1082  << operands[5];
1083 
1084  unsigned rows = rowsAttr.getInt();
1085  unsigned columns = columnsAttr.getInt();
1086 
1087  std::optional<spirv::CooperativeMatrixUseKHR> use =
1088  spirv::symbolizeCooperativeMatrixUseKHR(useAttr.getInt());
1089  if (!use) {
1090  return emitError(
1091  unknownLoc,
1092  "OpTypeCooperativeMatrixKHR references undefined use <id> ")
1093  << operands[5];
1094  }
1095 
1096  typeMap[operands[0]] =
1097  spirv::CooperativeMatrixType::get(elementTy, rows, columns, *scope, *use);
1098  return success();
1099 }
1100 
1101 LogicalResult
1102 spirv::Deserializer::processRuntimeArrayType(ArrayRef<uint32_t> operands) {
1103  if (operands.size() != 2) {
1104  return emitError(unknownLoc, "OpTypeRuntimeArray must have two operands");
1105  }
1106  Type memberType = getType(operands[1]);
1107  if (!memberType) {
1108  return emitError(unknownLoc,
1109  "OpTypeRuntimeArray references undefined <id> ")
1110  << operands[1];
1111  }
1112  typeMap[operands[0]] = spirv::RuntimeArrayType::get(
1113  memberType, typeDecorations.lookup(operands[0]));
1114  return success();
1115 }
1116 
1117 LogicalResult
1118 spirv::Deserializer::processStructType(ArrayRef<uint32_t> operands) {
1119  // TODO: Find a way to handle identified structs when debug info is stripped.
1120 
1121  if (operands.empty()) {
1122  return emitError(unknownLoc, "OpTypeStruct must have at least result <id>");
1123  }
1124 
1125  if (operands.size() == 1) {
1126  // Handle empty struct.
1127  typeMap[operands[0]] =
1128  spirv::StructType::getEmpty(context, nameMap.lookup(operands[0]).str());
1129  return success();
1130  }
1131 
1132  // First element is operand ID, second element is member index in the struct.
1133  SmallVector<std::pair<uint32_t, unsigned>, 0> unresolvedMemberTypes;
1134  SmallVector<Type, 4> memberTypes;
1135 
1136  for (auto op : llvm::drop_begin(operands, 1)) {
1137  Type memberType = getType(op);
1138  bool typeForwardPtr = (typeForwardPointerIDs.count(op) != 0);
1139 
1140  if (!memberType && !typeForwardPtr)
1141  return emitError(unknownLoc, "OpTypeStruct references undefined <id> ")
1142  << op;
1143 
1144  if (!memberType)
1145  unresolvedMemberTypes.emplace_back(op, memberTypes.size());
1146 
1147  memberTypes.push_back(memberType);
1148  }
1149 
1152  if (memberDecorationMap.count(operands[0])) {
1153  auto &allMemberDecorations = memberDecorationMap[operands[0]];
1154  for (auto memberIndex : llvm::seq<uint32_t>(0, memberTypes.size())) {
1155  if (allMemberDecorations.count(memberIndex)) {
1156  for (auto &memberDecoration : allMemberDecorations[memberIndex]) {
1157  // Check for offset.
1158  if (memberDecoration.first == spirv::Decoration::Offset) {
1159  // If offset info is empty, resize to the number of members;
1160  if (offsetInfo.empty()) {
1161  offsetInfo.resize(memberTypes.size());
1162  }
1163  offsetInfo[memberIndex] = memberDecoration.second[0];
1164  } else {
1165  if (!memberDecoration.second.empty()) {
1166  memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/1,
1167  memberDecoration.first,
1168  memberDecoration.second[0]);
1169  } else {
1170  memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/0,
1171  memberDecoration.first, 0);
1172  }
1173  }
1174  }
1175  }
1176  }
1177  }
1178 
1179  uint32_t structID = operands[0];
1180  std::string structIdentifier = nameMap.lookup(structID).str();
1181 
1182  if (structIdentifier.empty()) {
1183  assert(unresolvedMemberTypes.empty() &&
1184  "didn't expect unresolved member types");
1185  typeMap[structID] =
1186  spirv::StructType::get(memberTypes, offsetInfo, memberDecorationsInfo);
1187  } else {
1188  auto structTy = spirv::StructType::getIdentified(context, structIdentifier);
1189  typeMap[structID] = structTy;
1190 
1191  if (!unresolvedMemberTypes.empty())
1192  deferredStructTypesInfos.push_back({structTy, unresolvedMemberTypes,
1193  memberTypes, offsetInfo,
1194  memberDecorationsInfo});
1195  else if (failed(structTy.trySetBody(memberTypes, offsetInfo,
1196  memberDecorationsInfo)))
1197  return failure();
1198  }
1199 
1200  // TODO: Update StructType to have member name as attribute as
1201  // well.
1202  return success();
1203 }
1204 
1205 LogicalResult
1206 spirv::Deserializer::processMatrixType(ArrayRef<uint32_t> operands) {
1207  if (operands.size() != 3) {
1208  // Three operands are needed: result_id, column_type, and column_count
1209  return emitError(unknownLoc, "OpTypeMatrix must have 3 operands"
1210  " (result_id, column_type, and column_count)");
1211  }
1212  // Matrix columns must be of vector type
1213  Type elementTy = getType(operands[1]);
1214  if (!elementTy) {
1215  return emitError(unknownLoc,
1216  "OpTypeMatrix references undefined column type.")
1217  << operands[1];
1218  }
1219 
1220  uint32_t colsCount = operands[2];
1221  typeMap[operands[0]] = spirv::MatrixType::get(elementTy, colsCount);
1222  return success();
1223 }
1224 
1225 LogicalResult
1226 spirv::Deserializer::processTypeForwardPointer(ArrayRef<uint32_t> operands) {
1227  if (operands.size() != 2)
1228  return emitError(unknownLoc,
1229  "OpTypeForwardPointer instruction must have two operands");
1230 
1231  typeForwardPointerIDs.insert(operands[0]);
1232  // TODO: Use the 2nd operand (Storage Class) to validate the OpTypePointer
1233  // instruction that defines the actual type.
1234 
1235  return success();
1236 }
1237 
1238 LogicalResult
1239 spirv::Deserializer::processImageType(ArrayRef<uint32_t> operands) {
1240  // TODO: Add support for Access Qualifier.
1241  if (operands.size() != 8)
1242  return emitError(
1243  unknownLoc,
1244  "OpTypeImage with non-eight operands are not supported yet");
1245 
1246  Type elementTy = getType(operands[1]);
1247  if (!elementTy)
1248  return emitError(unknownLoc, "OpTypeImage references undefined <id>: ")
1249  << operands[1];
1250 
1251  auto dim = spirv::symbolizeDim(operands[2]);
1252  if (!dim)
1253  return emitError(unknownLoc, "unknown Dim for OpTypeImage: ")
1254  << operands[2];
1255 
1256  auto depthInfo = spirv::symbolizeImageDepthInfo(operands[3]);
1257  if (!depthInfo)
1258  return emitError(unknownLoc, "unknown Depth for OpTypeImage: ")
1259  << operands[3];
1260 
1261  auto arrayedInfo = spirv::symbolizeImageArrayedInfo(operands[4]);
1262  if (!arrayedInfo)
1263  return emitError(unknownLoc, "unknown Arrayed for OpTypeImage: ")
1264  << operands[4];
1265 
1266  auto samplingInfo = spirv::symbolizeImageSamplingInfo(operands[5]);
1267  if (!samplingInfo)
1268  return emitError(unknownLoc, "unknown MS for OpTypeImage: ") << operands[5];
1269 
1270  auto samplerUseInfo = spirv::symbolizeImageSamplerUseInfo(operands[6]);
1271  if (!samplerUseInfo)
1272  return emitError(unknownLoc, "unknown Sampled for OpTypeImage: ")
1273  << operands[6];
1274 
1275  auto format = spirv::symbolizeImageFormat(operands[7]);
1276  if (!format)
1277  return emitError(unknownLoc, "unknown Format for OpTypeImage: ")
1278  << operands[7];
1279 
1280  typeMap[operands[0]] = spirv::ImageType::get(
1281  elementTy, dim.value(), depthInfo.value(), arrayedInfo.value(),
1282  samplingInfo.value(), samplerUseInfo.value(), format.value());
1283  return success();
1284 }
1285 
1286 LogicalResult
1287 spirv::Deserializer::processSampledImageType(ArrayRef<uint32_t> operands) {
1288  if (operands.size() != 2)
1289  return emitError(unknownLoc, "OpTypeSampledImage must have two operands");
1290 
1291  Type elementTy = getType(operands[1]);
1292  if (!elementTy)
1293  return emitError(unknownLoc,
1294  "OpTypeSampledImage references undefined <id>: ")
1295  << operands[1];
1296 
1297  typeMap[operands[0]] = spirv::SampledImageType::get(elementTy);
1298  return success();
1299 }
1300 
1301 //===----------------------------------------------------------------------===//
1302 // Constant
1303 //===----------------------------------------------------------------------===//
1304 
1305 LogicalResult spirv::Deserializer::processConstant(ArrayRef<uint32_t> operands,
1306  bool isSpec) {
1307  StringRef opname = isSpec ? "OpSpecConstant" : "OpConstant";
1308 
1309  if (operands.size() < 2) {
1310  return emitError(unknownLoc)
1311  << opname << " must have type <id> and result <id>";
1312  }
1313  if (operands.size() < 3) {
1314  return emitError(unknownLoc)
1315  << opname << " must have at least 1 more parameter";
1316  }
1317 
1318  Type resultType = getType(operands[0]);
1319  if (!resultType) {
1320  return emitError(unknownLoc, "undefined result type from <id> ")
1321  << operands[0];
1322  }
1323 
1324  auto checkOperandSizeForBitwidth = [&](unsigned bitwidth) -> LogicalResult {
1325  if (bitwidth == 64) {
1326  if (operands.size() == 4) {
1327  return success();
1328  }
1329  return emitError(unknownLoc)
1330  << opname << " should have 2 parameters for 64-bit values";
1331  }
1332  if (bitwidth <= 32) {
1333  if (operands.size() == 3) {
1334  return success();
1335  }
1336 
1337  return emitError(unknownLoc)
1338  << opname
1339  << " should have 1 parameter for values with no more than 32 bits";
1340  }
1341  return emitError(unknownLoc, "unsupported OpConstant bitwidth: ")
1342  << bitwidth;
1343  };
1344 
1345  auto resultID = operands[1];
1346 
1347  if (auto intType = dyn_cast<IntegerType>(resultType)) {
1348  auto bitwidth = intType.getWidth();
1349  if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1350  return failure();
1351  }
1352 
1353  APInt value;
1354  if (bitwidth == 64) {
1355  // 64-bit integers are represented with two SPIR-V words. According to
1356  // SPIR-V spec: "When the type’s bit width is larger than one word, the
1357  // literal’s low-order words appear first."
1358  struct DoubleWord {
1359  uint32_t word1;
1360  uint32_t word2;
1361  } words = {operands[2], operands[3]};
1362  value = APInt(64, llvm::bit_cast<uint64_t>(words), /*isSigned=*/true);
1363  } else if (bitwidth <= 32) {
1364  value = APInt(bitwidth, operands[2], /*isSigned=*/true,
1365  /*implicitTrunc=*/true);
1366  }
1367 
1368  auto attr = opBuilder.getIntegerAttr(intType, value);
1369 
1370  if (isSpec) {
1371  createSpecConstant(unknownLoc, resultID, attr);
1372  } else {
1373  // For normal constants, we just record the attribute (and its type) for
1374  // later materialization at use sites.
1375  constantMap.try_emplace(resultID, attr, intType);
1376  }
1377 
1378  return success();
1379  }
1380 
1381  if (auto floatType = dyn_cast<FloatType>(resultType)) {
1382  auto bitwidth = floatType.getWidth();
1383  if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1384  return failure();
1385  }
1386 
1387  APFloat value(0.f);
1388  if (floatType.isF64()) {
1389  // Double values are represented with two SPIR-V words. According to
1390  // SPIR-V spec: "When the type’s bit width is larger than one word, the
1391  // literal’s low-order words appear first."
1392  struct DoubleWord {
1393  uint32_t word1;
1394  uint32_t word2;
1395  } words = {operands[2], operands[3]};
1396  value = APFloat(llvm::bit_cast<double>(words));
1397  } else if (floatType.isF32()) {
1398  value = APFloat(llvm::bit_cast<float>(operands[2]));
1399  } else if (floatType.isF16()) {
1400  APInt data(16, operands[2]);
1401  value = APFloat(APFloat::IEEEhalf(), data);
1402  }
1403 
1404  auto attr = opBuilder.getFloatAttr(floatType, value);
1405  if (isSpec) {
1406  createSpecConstant(unknownLoc, resultID, attr);
1407  } else {
1408  // For normal constants, we just record the attribute (and its type) for
1409  // later materialization at use sites.
1410  constantMap.try_emplace(resultID, attr, floatType);
1411  }
1412 
1413  return success();
1414  }
1415 
1416  return emitError(unknownLoc, "OpConstant can only generate values of "
1417  "scalar integer or floating-point type");
1418 }
1419 
1420 LogicalResult spirv::Deserializer::processConstantBool(
1421  bool isTrue, ArrayRef<uint32_t> operands, bool isSpec) {
1422  if (operands.size() != 2) {
1423  return emitError(unknownLoc, "Op")
1424  << (isSpec ? "Spec" : "") << "Constant"
1425  << (isTrue ? "True" : "False")
1426  << " must have type <id> and result <id>";
1427  }
1428 
1429  auto attr = opBuilder.getBoolAttr(isTrue);
1430  auto resultID = operands[1];
1431  if (isSpec) {
1432  createSpecConstant(unknownLoc, resultID, attr);
1433  } else {
1434  // For normal constants, we just record the attribute (and its type) for
1435  // later materialization at use sites.
1436  constantMap.try_emplace(resultID, attr, opBuilder.getI1Type());
1437  }
1438 
1439  return success();
1440 }
1441 
1442 LogicalResult
1443 spirv::Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
1444  if (operands.size() < 2) {
1445  return emitError(unknownLoc,
1446  "OpConstantComposite must have type <id> and result <id>");
1447  }
1448  if (operands.size() < 3) {
1449  return emitError(unknownLoc,
1450  "OpConstantComposite must have at least 1 parameter");
1451  }
1452 
1453  Type resultType = getType(operands[0]);
1454  if (!resultType) {
1455  return emitError(unknownLoc, "undefined result type from <id> ")
1456  << operands[0];
1457  }
1458 
1459  SmallVector<Attribute, 4> elements;
1460  elements.reserve(operands.size() - 2);
1461  for (unsigned i = 2, e = operands.size(); i < e; ++i) {
1462  auto elementInfo = getConstant(operands[i]);
1463  if (!elementInfo) {
1464  return emitError(unknownLoc, "OpConstantComposite component <id> ")
1465  << operands[i] << " must come from a normal constant";
1466  }
1467  elements.push_back(elementInfo->first);
1468  }
1469 
1470  auto resultID = operands[1];
1471  if (auto shapedType = dyn_cast<ShapedType>(resultType)) {
1472  auto attr = DenseElementsAttr::get(shapedType, elements);
1473  // For normal constants, we just record the attribute (and its type) for
1474  // later materialization at use sites.
1475  constantMap.try_emplace(resultID, attr, shapedType);
1476  } else if (auto arrayType = dyn_cast<spirv::ArrayType>(resultType)) {
1477  auto attr = opBuilder.getArrayAttr(elements);
1478  constantMap.try_emplace(resultID, attr, resultType);
1479  } else {
1480  return emitError(unknownLoc, "unsupported OpConstantComposite type: ")
1481  << resultType;
1482  }
1483 
1484  return success();
1485 }
1486 
1487 LogicalResult
1488 spirv::Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) {
1489  if (operands.size() < 2) {
1490  return emitError(unknownLoc,
1491  "OpConstantComposite must have type <id> and result <id>");
1492  }
1493  if (operands.size() < 3) {
1494  return emitError(unknownLoc,
1495  "OpConstantComposite must have at least 1 parameter");
1496  }
1497 
1498  Type resultType = getType(operands[0]);
1499  if (!resultType) {
1500  return emitError(unknownLoc, "undefined result type from <id> ")
1501  << operands[0];
1502  }
1503 
1504  auto resultID = operands[1];
1505  auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
1506 
1507  SmallVector<Attribute, 4> elements;
1508  elements.reserve(operands.size() - 2);
1509  for (unsigned i = 2, e = operands.size(); i < e; ++i) {
1510  auto elementInfo = getSpecConstant(operands[i]);
1511  elements.push_back(SymbolRefAttr::get(elementInfo));
1512  }
1513 
1514  auto op = opBuilder.create<spirv::SpecConstantCompositeOp>(
1515  unknownLoc, TypeAttr::get(resultType), symName,
1516  opBuilder.getArrayAttr(elements));
1517  specConstCompositeMap[resultID] = op;
1518 
1519  return success();
1520 }
1521 
1522 LogicalResult
1523 spirv::Deserializer::processSpecConstantOperation(ArrayRef<uint32_t> operands) {
1524  if (operands.size() < 3)
1525  return emitError(unknownLoc, "OpConstantOperation must have type <id>, "
1526  "result <id>, and operand opcode");
1527 
1528  uint32_t resultTypeID = operands[0];
1529 
1530  if (!getType(resultTypeID))
1531  return emitError(unknownLoc, "undefined result type from <id> ")
1532  << resultTypeID;
1533 
1534  uint32_t resultID = operands[1];
1535  spirv::Opcode enclosedOpcode = static_cast<spirv::Opcode>(operands[2]);
1536  auto emplaceResult = specConstOperationMap.try_emplace(
1537  resultID,
1538  SpecConstOperationMaterializationInfo{
1539  enclosedOpcode, resultTypeID,
1540  SmallVector<uint32_t>{operands.begin() + 3, operands.end()}});
1541 
1542  if (!emplaceResult.second)
1543  return emitError(unknownLoc, "value with <id>: ")
1544  << resultID << " is probably defined before.";
1545 
1546  return success();
1547 }
1548 
1549 Value spirv::Deserializer::materializeSpecConstantOperation(
1550  uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID,
1551  ArrayRef<uint32_t> enclosedOpOperands) {
1552 
1553  Type resultType = getType(resultTypeID);
1554 
1555  // Instructions wrapped by OpSpecConstantOp need an ID for their
1556  // Deserializer::processOp<op_name>(...) to emit the corresponding SPIR-V
1557  // dialect wrapped op. For that purpose, a new value map is created and "fake"
1558  // ID in that map is assigned to the result of the enclosed instruction. Note
1559  // that there is no need to update this fake ID since we only need to
1560  // reference the created Value for the enclosed op from the spv::YieldOp
1561  // created later in this method (both of which are the only values in their
1562  // region: the SpecConstantOperation's region). If we encounter another
1563  // SpecConstantOperation in the module, we simply re-use the fake ID since the
1564  // previous Value assigned to it isn't visible in the current scope anyway.
1565  DenseMap<uint32_t, Value> newValueMap;
1566  llvm::SaveAndRestore valueMapGuard(valueMap, newValueMap);
1567  constexpr uint32_t fakeID = static_cast<uint32_t>(-3);
1568 
1569  SmallVector<uint32_t, 4> enclosedOpResultTypeAndOperands;
1570  enclosedOpResultTypeAndOperands.push_back(resultTypeID);
1571  enclosedOpResultTypeAndOperands.push_back(fakeID);
1572  enclosedOpResultTypeAndOperands.append(enclosedOpOperands.begin(),
1573  enclosedOpOperands.end());
1574 
1575  // Process enclosed instruction before creating the enclosing
1576  // specConstantOperation (and its region). This way, references to constants,
1577  // global variables, and spec constants will be materialized outside the new
1578  // op's region. For more info, see Deserializer::getValue's implementation.
1579  if (failed(
1580  processInstruction(enclosedOpcode, enclosedOpResultTypeAndOperands)))
1581  return Value();
1582 
1583  // Since the enclosed op is emitted in the current block, split it in a
1584  // separate new block.
1585  Block *enclosedBlock = curBlock->splitBlock(&curBlock->back());
1586 
1587  auto loc = createFileLineColLoc(opBuilder);
1588  auto specConstOperationOp =
1589  opBuilder.create<spirv::SpecConstantOperationOp>(loc, resultType);
1590 
1591  Region &body = specConstOperationOp.getBody();
1592  // Move the new block into SpecConstantOperation's body.
1593  body.getBlocks().splice(body.end(), curBlock->getParent()->getBlocks(),
1594  Region::iterator(enclosedBlock));
1595  Block &block = body.back();
1596 
1597  // RAII guard to reset the insertion point to the module's region after
1598  // deserializing the body of the specConstantOperation.
1599  OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
1600  opBuilder.setInsertionPointToEnd(&block);
1601 
1602  opBuilder.create<spirv::YieldOp>(loc, block.front().getResult(0));
1603  return specConstOperationOp.getResult();
1604 }
1605 
1606 LogicalResult
1607 spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
1608  if (operands.size() != 2) {
1609  return emitError(unknownLoc,
1610  "OpConstantNull must have type <id> and result <id>");
1611  }
1612 
1613  Type resultType = getType(operands[0]);
1614  if (!resultType) {
1615  return emitError(unknownLoc, "undefined result type from <id> ")
1616  << operands[0];
1617  }
1618 
1619  auto resultID = operands[1];
1620  if (resultType.isIntOrFloat() || isa<VectorType>(resultType)) {
1621  auto attr = opBuilder.getZeroAttr(resultType);
1622  // For normal constants, we just record the attribute (and its type) for
1623  // later materialization at use sites.
1624  constantMap.try_emplace(resultID, attr, resultType);
1625  return success();
1626  }
1627 
1628  return emitError(unknownLoc, "unsupported OpConstantNull type: ")
1629  << resultType;
1630 }
1631 
1632 //===----------------------------------------------------------------------===//
1633 // Control flow
1634 //===----------------------------------------------------------------------===//
1635 
1636 Block *spirv::Deserializer::getOrCreateBlock(uint32_t id) {
1637  if (auto *block = getBlock(id)) {
1638  LLVM_DEBUG(logger.startLine() << "[block] got exiting block for id = " << id
1639  << " @ " << block << "\n");
1640  return block;
1641  }
1642 
1643  // We don't know where this block will be placed finally (in a
1644  // spirv.mlir.selection or spirv.mlir.loop or function). Create it into the
1645  // function for now and sort out the proper place later.
1646  auto *block = curFunction->addBlock();
1647  LLVM_DEBUG(logger.startLine() << "[block] created block for id = " << id
1648  << " @ " << block << "\n");
1649  return blockMap[id] = block;
1650 }
1651 
1652 LogicalResult spirv::Deserializer::processBranch(ArrayRef<uint32_t> operands) {
1653  if (!curBlock) {
1654  return emitError(unknownLoc, "OpBranch must appear inside a block");
1655  }
1656 
1657  if (operands.size() != 1) {
1658  return emitError(unknownLoc, "OpBranch must take exactly one target label");
1659  }
1660 
1661  auto *target = getOrCreateBlock(operands[0]);
1662  auto loc = createFileLineColLoc(opBuilder);
1663  // The preceding instruction for the OpBranch instruction could be an
1664  // OpLoopMerge or an OpSelectionMerge instruction, in this case they will have
1665  // the same OpLine information.
1666  opBuilder.create<spirv::BranchOp>(loc, target);
1667 
1668  clearDebugLine();
1669  return success();
1670 }
1671 
1672 LogicalResult
1673 spirv::Deserializer::processBranchConditional(ArrayRef<uint32_t> operands) {
1674  if (!curBlock) {
1675  return emitError(unknownLoc,
1676  "OpBranchConditional must appear inside a block");
1677  }
1678 
1679  if (operands.size() != 3 && operands.size() != 5) {
1680  return emitError(unknownLoc,
1681  "OpBranchConditional must have condition, true label, "
1682  "false label, and optionally two branch weights");
1683  }
1684 
1685  auto condition = getValue(operands[0]);
1686  auto *trueBlock = getOrCreateBlock(operands[1]);
1687  auto *falseBlock = getOrCreateBlock(operands[2]);
1688 
1689  std::optional<std::pair<uint32_t, uint32_t>> weights;
1690  if (operands.size() == 5) {
1691  weights = std::make_pair(operands[3], operands[4]);
1692  }
1693  // The preceding instruction for the OpBranchConditional instruction could be
1694  // an OpSelectionMerge instruction, in this case they will have the same
1695  // OpLine information.
1696  auto loc = createFileLineColLoc(opBuilder);
1697  opBuilder.create<spirv::BranchConditionalOp>(
1698  loc, condition, trueBlock,
1699  /*trueArguments=*/ArrayRef<Value>(), falseBlock,
1700  /*falseArguments=*/ArrayRef<Value>(), weights);
1701 
1702  clearDebugLine();
1703  return success();
1704 }
1705 
1706 LogicalResult spirv::Deserializer::processLabel(ArrayRef<uint32_t> operands) {
1707  if (!curFunction) {
1708  return emitError(unknownLoc, "OpLabel must appear inside a function");
1709  }
1710 
1711  if (operands.size() != 1) {
1712  return emitError(unknownLoc, "OpLabel should only have result <id>");
1713  }
1714 
1715  auto labelID = operands[0];
1716  // We may have forward declared this block.
1717  auto *block = getOrCreateBlock(labelID);
1718  LLVM_DEBUG(logger.startLine()
1719  << "[block] populating block " << block << "\n");
1720  // If we have seen this block, make sure it was just a forward declaration.
1721  assert(block->empty() && "re-deserialize the same block!");
1722 
1723  opBuilder.setInsertionPointToStart(block);
1724  blockMap[labelID] = curBlock = block;
1725 
1726  return success();
1727 }
1728 
1729 LogicalResult
1730 spirv::Deserializer::processSelectionMerge(ArrayRef<uint32_t> operands) {
1731  if (!curBlock) {
1732  return emitError(unknownLoc, "OpSelectionMerge must appear in a block");
1733  }
1734 
1735  if (operands.size() < 2) {
1736  return emitError(
1737  unknownLoc,
1738  "OpSelectionMerge must specify merge target and selection control");
1739  }
1740 
1741  auto *mergeBlock = getOrCreateBlock(operands[0]);
1742  auto loc = createFileLineColLoc(opBuilder);
1743  auto selectionControl = operands[1];
1744 
1745  if (!blockMergeInfo.try_emplace(curBlock, loc, selectionControl, mergeBlock)
1746  .second) {
1747  return emitError(
1748  unknownLoc,
1749  "a block cannot have more than one OpSelectionMerge instruction");
1750  }
1751 
1752  return success();
1753 }
1754 
1755 LogicalResult
1756 spirv::Deserializer::processLoopMerge(ArrayRef<uint32_t> operands) {
1757  if (!curBlock) {
1758  return emitError(unknownLoc, "OpLoopMerge must appear in a block");
1759  }
1760 
1761  if (operands.size() < 3) {
1762  return emitError(unknownLoc, "OpLoopMerge must specify merge target, "
1763  "continue target and loop control");
1764  }
1765 
1766  auto *mergeBlock = getOrCreateBlock(operands[0]);
1767  auto *continueBlock = getOrCreateBlock(operands[1]);
1768  auto loc = createFileLineColLoc(opBuilder);
1769  uint32_t loopControl = operands[2];
1770 
1771  if (!blockMergeInfo
1772  .try_emplace(curBlock, loc, loopControl, mergeBlock, continueBlock)
1773  .second) {
1774  return emitError(
1775  unknownLoc,
1776  "a block cannot have more than one OpLoopMerge instruction");
1777  }
1778 
1779  return success();
1780 }
1781 
1782 LogicalResult spirv::Deserializer::processPhi(ArrayRef<uint32_t> operands) {
1783  if (!curBlock) {
1784  return emitError(unknownLoc, "OpPhi must appear in a block");
1785  }
1786 
1787  if (operands.size() < 4) {
1788  return emitError(unknownLoc, "OpPhi must specify result type, result <id>, "
1789  "and variable-parent pairs");
1790  }
1791 
1792  // Create a block argument for this OpPhi instruction.
1793  Type blockArgType = getType(operands[0]);
1794  BlockArgument blockArg = curBlock->addArgument(blockArgType, unknownLoc);
1795  valueMap[operands[1]] = blockArg;
1796  LLVM_DEBUG(logger.startLine()
1797  << "[phi] created block argument " << blockArg
1798  << " id = " << operands[1] << " of type " << blockArgType << "\n");
1799 
1800  // For each (value, predecessor) pair, insert the value to the predecessor's
1801  // blockPhiInfo entry so later we can fix the block argument there.
1802  for (unsigned i = 2, e = operands.size(); i < e; i += 2) {
1803  uint32_t value = operands[i];
1804  Block *predecessor = getOrCreateBlock(operands[i + 1]);
1805  std::pair<Block *, Block *> predecessorTargetPair{predecessor, curBlock};
1806  blockPhiInfo[predecessorTargetPair].push_back(value);
1807  LLVM_DEBUG(logger.startLine() << "[phi] predecessor @ " << predecessor
1808  << " with arg id = " << value << "\n");
1809  }
1810 
1811  return success();
1812 }
1813 
1814 namespace {
1815 /// A class for putting all blocks in a structured selection/loop in a
1816 /// spirv.mlir.selection/spirv.mlir.loop op.
1817 class ControlFlowStructurizer {
1818 public:
1819 #ifndef NDEBUG
1820  ControlFlowStructurizer(Location loc, uint32_t control,
1821  spirv::BlockMergeInfoMap &mergeInfo, Block *header,
1822  Block *merge, Block *cont,
1823  llvm::ScopedPrinter &logger)
1824  : location(loc), control(control), blockMergeInfo(mergeInfo),
1825  headerBlock(header), mergeBlock(merge), continueBlock(cont),
1826  logger(logger) {}
1827 #else
1828  ControlFlowStructurizer(Location loc, uint32_t control,
1829  spirv::BlockMergeInfoMap &mergeInfo, Block *header,
1830  Block *merge, Block *cont)
1831  : location(loc), control(control), blockMergeInfo(mergeInfo),
1832  headerBlock(header), mergeBlock(merge), continueBlock(cont) {}
1833 #endif
1834 
1835  /// Structurizes the loop at the given `headerBlock`.
1836  ///
1837  /// This method will create an spirv.mlir.loop op in the `mergeBlock` and move
1838  /// all blocks in the structured loop into the spirv.mlir.loop's region. All
1839  /// branches to the `headerBlock` will be redirected to the `mergeBlock`. This
1840  /// method will also update `mergeInfo` by remapping all blocks inside to the
1841  /// newly cloned ones inside structured control flow op's regions.
1842  LogicalResult structurize();
1843 
1844 private:
1845  /// Creates a new spirv.mlir.selection op at the beginning of the
1846  /// `mergeBlock`.
1847  spirv::SelectionOp createSelectionOp(uint32_t selectionControl);
1848 
1849  /// Creates a new spirv.mlir.loop op at the beginning of the `mergeBlock`.
1850  spirv::LoopOp createLoopOp(uint32_t loopControl);
1851 
1852  /// Collects all blocks reachable from `headerBlock` except `mergeBlock`.
1853  void collectBlocksInConstruct();
1854 
1855  Location location;
1856  uint32_t control;
1857 
1858  spirv::BlockMergeInfoMap &blockMergeInfo;
1859 
1860  Block *headerBlock;
1861  Block *mergeBlock;
1862  Block *continueBlock; // nullptr for spirv.mlir.selection
1863 
1864  SetVector<Block *> constructBlocks;
1865 
1866 #ifndef NDEBUG
1867  /// A logger used to emit information during the deserialzation process.
1868  llvm::ScopedPrinter &logger;
1869 #endif
1870 };
1871 } // namespace
1872 
1873 spirv::SelectionOp
1874 ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) {
1875  // Create a builder and set the insertion point to the beginning of the
1876  // merge block so that the newly created SelectionOp will be inserted there.
1877  OpBuilder builder(&mergeBlock->front());
1878 
1879  auto control = static_cast<spirv::SelectionControl>(selectionControl);
1880  auto selectionOp = builder.create<spirv::SelectionOp>(location, control);
1881  selectionOp.addMergeBlock(builder);
1882 
1883  return selectionOp;
1884 }
1885 
1886 spirv::LoopOp ControlFlowStructurizer::createLoopOp(uint32_t loopControl) {
1887  // Create a builder and set the insertion point to the beginning of the
1888  // merge block so that the newly created LoopOp will be inserted there.
1889  OpBuilder builder(&mergeBlock->front());
1890 
1891  auto control = static_cast<spirv::LoopControl>(loopControl);
1892  auto loopOp = builder.create<spirv::LoopOp>(location, control);
1893  loopOp.addEntryAndMergeBlock(builder);
1894 
1895  return loopOp;
1896 }
1897 
1898 void ControlFlowStructurizer::collectBlocksInConstruct() {
1899  assert(constructBlocks.empty() && "expected empty constructBlocks");
1900 
1901  // Put the header block in the work list first.
1902  constructBlocks.insert(headerBlock);
1903 
1904  // For each item in the work list, add its successors excluding the merge
1905  // block.
1906  for (unsigned i = 0; i < constructBlocks.size(); ++i) {
1907  for (auto *successor : constructBlocks[i]->getSuccessors())
1908  if (successor != mergeBlock)
1909  constructBlocks.insert(successor);
1910  }
1911 }
1912 
1913 LogicalResult ControlFlowStructurizer::structurize() {
1914  Operation *op = nullptr;
1915  bool isLoop = continueBlock != nullptr;
1916  if (isLoop) {
1917  if (auto loopOp = createLoopOp(control))
1918  op = loopOp.getOperation();
1919  } else {
1920  if (auto selectionOp = createSelectionOp(control))
1921  op = selectionOp.getOperation();
1922  }
1923  if (!op)
1924  return failure();
1925  Region &body = op->getRegion(0);
1926 
1927  IRMapping mapper;
1928  // All references to the old merge block should be directed to the
1929  // selection/loop merge block in the SelectionOp/LoopOp's region.
1930  mapper.map(mergeBlock, &body.back());
1931 
1932  collectBlocksInConstruct();
1933 
1934  // We've identified all blocks belonging to the selection/loop's region. Now
1935  // need to "move" them into the selection/loop. Instead of really moving the
1936  // blocks, in the following we copy them and remap all values and branches.
1937  // This is because:
1938  // * Inserting a block into a region requires the block not in any region
1939  // before. But selections/loops can nest so we can create selection/loop ops
1940  // in a nested manner, which means some blocks may already be in a
1941  // selection/loop region when to be moved again.
1942  // * It's much trickier to fix up the branches into and out of the loop's
1943  // region: we need to treat not-moved blocks and moved blocks differently:
1944  // Not-moved blocks jumping to the loop header block need to jump to the
1945  // merge point containing the new loop op but not the loop continue block's
1946  // back edge. Moved blocks jumping out of the loop need to jump to the
1947  // merge block inside the loop region but not other not-moved blocks.
1948  // We cannot use replaceAllUsesWith clearly and it's harder to follow the
1949  // logic.
1950 
1951  // Create a corresponding block in the SelectionOp/LoopOp's region for each
1952  // block in this loop construct.
1953  OpBuilder builder(body);
1954  for (auto *block : constructBlocks) {
1955  // Create a block and insert it before the selection/loop merge block in the
1956  // SelectionOp/LoopOp's region.
1957  auto *newBlock = builder.createBlock(&body.back());
1958  mapper.map(block, newBlock);
1959  LLVM_DEBUG(logger.startLine() << "[cf] cloned block " << newBlock
1960  << " from block " << block << "\n");
1961  if (!isFnEntryBlock(block)) {
1962  for (BlockArgument blockArg : block->getArguments()) {
1963  auto newArg =
1964  newBlock->addArgument(blockArg.getType(), blockArg.getLoc());
1965  mapper.map(blockArg, newArg);
1966  LLVM_DEBUG(logger.startLine() << "[cf] remapped block argument "
1967  << blockArg << " to " << newArg << "\n");
1968  }
1969  } else {
1970  LLVM_DEBUG(logger.startLine()
1971  << "[cf] block " << block << " is a function entry block\n");
1972  }
1973 
1974  for (auto &op : *block)
1975  newBlock->push_back(op.clone(mapper));
1976  }
1977 
1978  // Go through all ops and remap the operands.
1979  auto remapOperands = [&](Operation *op) {
1980  for (auto &operand : op->getOpOperands())
1981  if (Value mappedOp = mapper.lookupOrNull(operand.get()))
1982  operand.set(mappedOp);
1983  for (auto &succOp : op->getBlockOperands())
1984  if (Block *mappedOp = mapper.lookupOrNull(succOp.get()))
1985  succOp.set(mappedOp);
1986  };
1987  for (auto &block : body)
1988  block.walk(remapOperands);
1989 
1990  // We have created the SelectionOp/LoopOp and "moved" all blocks belonging to
1991  // the selection/loop construct into its region. Next we need to fix the
1992  // connections between this new SelectionOp/LoopOp with existing blocks.
1993 
1994  // All existing incoming branches should go to the merge block, where the
1995  // SelectionOp/LoopOp resides right now.
1996  headerBlock->replaceAllUsesWith(mergeBlock);
1997 
1998  LLVM_DEBUG({
1999  logger.startLine() << "[cf] after cloning and fixing references:\n";
2000  headerBlock->getParentOp()->print(logger.getOStream());
2001  logger.startLine() << "\n";
2002  });
2003 
2004  if (isLoop) {
2005  if (!mergeBlock->args_empty()) {
2006  return mergeBlock->getParentOp()->emitError(
2007  "OpPhi in loop merge block unsupported");
2008  }
2009 
2010  // The loop header block may have block arguments. Since now we place the
2011  // loop op inside the old merge block, we need to make sure the old merge
2012  // block has the same block argument list.
2013  for (BlockArgument blockArg : headerBlock->getArguments())
2014  mergeBlock->addArgument(blockArg.getType(), blockArg.getLoc());
2015 
2016  // If the loop header block has block arguments, make sure the spirv.Branch
2017  // op matches.
2018  SmallVector<Value, 4> blockArgs;
2019  if (!headerBlock->args_empty())
2020  blockArgs = {mergeBlock->args_begin(), mergeBlock->args_end()};
2021 
2022  // The loop entry block should have a unconditional branch jumping to the
2023  // loop header block.
2024  builder.setInsertionPointToEnd(&body.front());
2025  builder.create<spirv::BranchOp>(location, mapper.lookupOrNull(headerBlock),
2026  ArrayRef<Value>(blockArgs));
2027  }
2028 
2029  // Values defined inside the selection region that need to be yielded outside
2030  // the region.
2031  SmallVector<Value> valuesToYield;
2032  // Outside uses of values that were sunk into the selection region. Those uses
2033  // will be replaced with values returned by the SelectionOp.
2034  SmallVector<Value> outsideUses;
2035 
2036  // Move block arguments of the original block (`mergeBlock`) into the merge
2037  // block inside the selection (`body.back()`). Values produced by block
2038  // arguments will be yielded by the selection region. We do not update uses or
2039  // erase original block arguments yet. It will be done later in the code.
2040  //
2041  // Code below is not executed for loops as it would interfere with the logic
2042  // above. Currently block arguments in the merge block are not supported, but
2043  // instead, the code above copies those arguments from the header block into
2044  // the merge block. As such, running the code would yield those copied
2045  // arguments that is most likely not a desired behaviour. This may need to be
2046  // revisited in the future.
2047  if (!isLoop)
2048  for (BlockArgument blockArg : mergeBlock->getArguments()) {
2049  // Create new block arguments in the last block ("merge block") of the
2050  // selection region. We create one argument for each argument in
2051  // `mergeBlock`. This new value will need to be yielded, and the original
2052  // value replaced, so add them to appropriate vectors.
2053  body.back().addArgument(blockArg.getType(), blockArg.getLoc());
2054  valuesToYield.push_back(body.back().getArguments().back());
2055  outsideUses.push_back(blockArg);
2056  }
2057 
2058  // All the blocks cloned into the SelectionOp/LoopOp's region can now be
2059  // cleaned up.
2060  LLVM_DEBUG(logger.startLine() << "[cf] cleaning up blocks after clone\n");
2061  // First we need to drop all operands' references inside all blocks. This is
2062  // needed because we can have blocks referencing SSA values from one another.
2063  for (auto *block : constructBlocks)
2064  block->dropAllReferences();
2065 
2066  // All internal uses should be removed from original blocks by now, so
2067  // whatever is left is an outside use and will need to be yielded from
2068  // the newly created selection / loop region.
2069  for (Block *block : constructBlocks) {
2070  for (Operation &op : *block) {
2071  if (!op.use_empty())
2072  for (Value result : op.getResults()) {
2073  valuesToYield.push_back(mapper.lookupOrNull(result));
2074  outsideUses.push_back(result);
2075  }
2076  }
2077  for (BlockArgument &arg : block->getArguments()) {
2078  if (!arg.use_empty()) {
2079  valuesToYield.push_back(mapper.lookupOrNull(arg));
2080  outsideUses.push_back(arg);
2081  }
2082  }
2083  }
2084 
2085  assert(valuesToYield.size() == outsideUses.size());
2086 
2087  // If we need to yield any values from the selection / loop region we will
2088  // take care of it here.
2089  if (!valuesToYield.empty()) {
2090  LLVM_DEBUG(logger.startLine()
2091  << "[cf] yielding values from the selection / loop region\n");
2092 
2093  // Update `mlir.merge` with values to be yield.
2094  auto mergeOps = body.back().getOps<spirv::MergeOp>();
2095  Operation *merge = llvm::getSingleElement(mergeOps);
2096  assert(merge);
2097  merge->setOperands(valuesToYield);
2098 
2099  // MLIR does not allow changing the number of results of an operation, so
2100  // we create a new SelectionOp / LoopOp with required list of results and
2101  // move the region from the initial SelectionOp / LoopOp. The initial
2102  // operation is then removed. Since we move the region to the new op all
2103  // links between blocks and remapping we have previously done should be
2104  // preserved.
2105  builder.setInsertionPoint(&mergeBlock->front());
2106 
2107  Operation *newOp = nullptr;
2108 
2109  if (isLoop)
2110  newOp = builder.create<spirv::LoopOp>(
2111  location, TypeRange(ValueRange(outsideUses)),
2112  static_cast<spirv::LoopControl>(control));
2113  else
2114  newOp = builder.create<spirv::SelectionOp>(
2115  location, TypeRange(ValueRange(outsideUses)),
2116  static_cast<spirv::SelectionControl>(control));
2117 
2118  newOp->getRegion(0).takeBody(body);
2119 
2120  // Remove initial op and swap the pointer to the newly created one.
2121  op->erase();
2122  op = newOp;
2123 
2124  // Update all outside uses to use results of the SelectionOp / LoopOp and
2125  // remove block arguments from the original merge block.
2126  for (unsigned i = 0, e = outsideUses.size(); i != e; ++i)
2127  outsideUses[i].replaceAllUsesWith(op->getResult(i));
2128 
2129  // We do not support block arguments in loop merge block. Also running this
2130  // function with loop would break some of the loop specific code above
2131  // dealing with block arguments.
2132  if (!isLoop)
2133  mergeBlock->eraseArguments(0, mergeBlock->getNumArguments());
2134  }
2135 
2136  // Check that whether some op in the to-be-erased blocks still has uses. Those
2137  // uses come from blocks that won't be sinked into the SelectionOp/LoopOp's
2138  // region. We cannot handle such cases given that once a value is sinked into
2139  // the SelectionOp/LoopOp's region, there is no escape for it.
2140  for (auto *block : constructBlocks) {
2141  for (Operation &op : *block)
2142  if (!op.use_empty())
2143  return op.emitOpError("failed control flow structurization: value has "
2144  "uses outside of the "
2145  "enclosing selection/loop construct");
2146  for (BlockArgument &arg : block->getArguments())
2147  if (!arg.use_empty())
2148  return emitError(arg.getLoc(), "failed control flow structurization: "
2149  "block argument has uses outside of the "
2150  "enclosing selection/loop construct");
2151  }
2152 
2153  // Then erase all old blocks.
2154  for (auto *block : constructBlocks) {
2155  // We've cloned all blocks belonging to this construct into the structured
2156  // control flow op's region. Among these blocks, some may compose another
2157  // selection/loop. If so, they will be recorded within blockMergeInfo.
2158  // We need to update the pointers there to the newly remapped ones so we can
2159  // continue structurizing them later.
2160  //
2161  // We need to walk each block as constructBlocks do not include blocks
2162  // internal to ops already structured within those blocks. It is not
2163  // fully clear to me why the mergeInfo of blocks (yet to be structured)
2164  // inside already structured selections/loops get invalidated and needs
2165  // updating, however the following example code can cause a crash (depending
2166  // on the structuring order), when the most inner selection is being
2167  // structured after the outer selection and loop have been already
2168  // structured:
2169  //
2170  // spirv.mlir.for {
2171  // // ...
2172  // spirv.mlir.selection {
2173  // // ..
2174  // // A selection region that hasn't been yet structured!
2175  // // ..
2176  // }
2177  // // ...
2178  // }
2179  //
2180  // If the loop gets structured after the outer selection, but before the
2181  // inner selection. Moving the already structured selection inside the loop
2182  // will invalidate the mergeInfo of the region that is not yet structured.
2183  // Just going over constructBlocks will not check and updated header blocks
2184  // inside the already structured selection region. Walking block fixes that.
2185  //
2186  // TODO: If structuring was done in a fixed order starting with inner
2187  // most constructs this most likely not be an issue and the whole code
2188  // section could be removed. However, with the current non-deterministic
2189  // order this is not possible.
2190  //
2191  // TODO: The asserts in the following assumes input SPIR-V blob forms
2192  // correctly nested selection/loop constructs. We should relax this and
2193  // support error cases better.
2194  auto updateMergeInfo = [&](Block *block) -> WalkResult {
2195  auto it = blockMergeInfo.find(block);
2196  if (it != blockMergeInfo.end()) {
2197  // Use the original location for nested selection/loop ops.
2198  Location loc = it->second.loc;
2199 
2200  Block *newHeader = mapper.lookupOrNull(block);
2201  if (!newHeader)
2202  return emitError(loc, "failed control flow structurization: nested "
2203  "loop header block should be remapped!");
2204 
2205  Block *newContinue = it->second.continueBlock;
2206  if (newContinue) {
2207  newContinue = mapper.lookupOrNull(newContinue);
2208  if (!newContinue)
2209  return emitError(loc, "failed control flow structurization: nested "
2210  "loop continue block should be remapped!");
2211  }
2212 
2213  Block *newMerge = it->second.mergeBlock;
2214  if (Block *mappedTo = mapper.lookupOrNull(newMerge))
2215  newMerge = mappedTo;
2216 
2217  // The iterator should be erased before adding a new entry into
2218  // blockMergeInfo to avoid iterator invalidation.
2219  blockMergeInfo.erase(it);
2220  blockMergeInfo.try_emplace(newHeader, loc, it->second.control, newMerge,
2221  newContinue);
2222  }
2223 
2224  return WalkResult::advance();
2225  };
2226 
2227  if (block->walk(updateMergeInfo).wasInterrupted())
2228  return failure();
2229 
2230  // The structured selection/loop's entry block does not have arguments.
2231  // If the function's header block is also part of the structured control
2232  // flow, we cannot just simply erase it because it may contain arguments
2233  // matching the function signature and used by the cloned blocks.
2234  if (isFnEntryBlock(block)) {
2235  LLVM_DEBUG(logger.startLine() << "[cf] changing entry block " << block
2236  << " to only contain a spirv.Branch op\n");
2237  // Still keep the function entry block for the potential block arguments,
2238  // but replace all ops inside with a branch to the merge block.
2239  block->clear();
2240  builder.setInsertionPointToEnd(block);
2241  builder.create<spirv::BranchOp>(location, mergeBlock);
2242  } else {
2243  LLVM_DEBUG(logger.startLine() << "[cf] erasing block " << block << "\n");
2244  block->erase();
2245  }
2246  }
2247 
2248  LLVM_DEBUG(logger.startLine()
2249  << "[cf] after structurizing construct with header block "
2250  << headerBlock << ":\n"
2251  << *op << "\n");
2252 
2253  return success();
2254 }
2255 
2256 LogicalResult spirv::Deserializer::wireUpBlockArgument() {
2257  LLVM_DEBUG({
2258  logger.startLine()
2259  << "//----- [phi] start wiring up block arguments -----//\n";
2260  logger.indent();
2261  });
2262 
2263  OpBuilder::InsertionGuard guard(opBuilder);
2264 
2265  for (const auto &info : blockPhiInfo) {
2266  Block *block = info.first.first;
2267  Block *target = info.first.second;
2268  const BlockPhiInfo &phiInfo = info.second;
2269  LLVM_DEBUG({
2270  logger.startLine() << "[phi] block " << block << "\n";
2271  logger.startLine() << "[phi] before creating block argument:\n";
2272  block->getParentOp()->print(logger.getOStream());
2273  logger.startLine() << "\n";
2274  });
2275 
2276  // Set insertion point to before this block's terminator early because we
2277  // may materialize ops via getValue() call.
2278  auto *op = block->getTerminator();
2279  opBuilder.setInsertionPoint(op);
2280 
2281  SmallVector<Value, 4> blockArgs;
2282  blockArgs.reserve(phiInfo.size());
2283  for (uint32_t valueId : phiInfo) {
2284  if (Value value = getValue(valueId)) {
2285  blockArgs.push_back(value);
2286  LLVM_DEBUG(logger.startLine() << "[phi] block argument " << value
2287  << " id = " << valueId << "\n");
2288  } else {
2289  return emitError(unknownLoc, "OpPhi references undefined value!");
2290  }
2291  }
2292 
2293  if (auto branchOp = dyn_cast<spirv::BranchOp>(op)) {
2294  // Replace the previous branch op with a new one with block arguments.
2295  opBuilder.create<spirv::BranchOp>(branchOp.getLoc(), branchOp.getTarget(),
2296  blockArgs);
2297  branchOp.erase();
2298  } else if (auto branchCondOp = dyn_cast<spirv::BranchConditionalOp>(op)) {
2299  assert((branchCondOp.getTrueBlock() == target ||
2300  branchCondOp.getFalseBlock() == target) &&
2301  "expected target to be either the true or false target");
2302  if (target == branchCondOp.getTrueTarget())
2303  opBuilder.create<spirv::BranchConditionalOp>(
2304  branchCondOp.getLoc(), branchCondOp.getCondition(), blockArgs,
2305  branchCondOp.getFalseBlockArguments(),
2306  branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueTarget(),
2307  branchCondOp.getFalseTarget());
2308  else
2309  opBuilder.create<spirv::BranchConditionalOp>(
2310  branchCondOp.getLoc(), branchCondOp.getCondition(),
2311  branchCondOp.getTrueBlockArguments(), blockArgs,
2312  branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueBlock(),
2313  branchCondOp.getFalseBlock());
2314 
2315  branchCondOp.erase();
2316  } else {
2317  return emitError(unknownLoc, "unimplemented terminator for Phi creation");
2318  }
2319 
2320  LLVM_DEBUG({
2321  logger.startLine() << "[phi] after creating block argument:\n";
2322  block->getParentOp()->print(logger.getOStream());
2323  logger.startLine() << "\n";
2324  });
2325  }
2326  blockPhiInfo.clear();
2327 
2328  LLVM_DEBUG({
2329  logger.unindent();
2330  logger.startLine()
2331  << "//--- [phi] completed wiring up block arguments ---//\n";
2332  });
2333  return success();
2334 }
2335 
2336 LogicalResult spirv::Deserializer::splitConditionalBlocks() {
2337  // Create a copy, so we can modify keys in the original.
2338  BlockMergeInfoMap blockMergeInfoCopy = blockMergeInfo;
2339  for (auto it = blockMergeInfoCopy.begin(), e = blockMergeInfoCopy.end();
2340  it != e; ++it) {
2341  auto &[block, mergeInfo] = *it;
2342 
2343  // Skip processing loop regions. For loop regions continueBlock is non-null.
2344  if (mergeInfo.continueBlock)
2345  continue;
2346 
2347  if (!block->mightHaveTerminator())
2348  continue;
2349 
2350  Operation *terminator = block->getTerminator();
2351  assert(terminator);
2352 
2353  if (!isa<spirv::BranchConditionalOp>(terminator))
2354  continue;
2355 
2356  // Check if the current header block is a merge block of another construct.
2357  bool splitHeaderMergeBlock = false;
2358  for (const auto &[_, mergeInfo] : blockMergeInfo) {
2359  if (mergeInfo.mergeBlock == block)
2360  splitHeaderMergeBlock = true;
2361  }
2362 
2363  // Do not split a block that only contains a conditional branch, unless it
2364  // is also a merge block of another construct - in that case we want to
2365  // split the block. We do not want two constructs to share header / merge
2366  // block.
2367  if (!llvm::hasSingleElement(*block) || splitHeaderMergeBlock) {
2368  Block *newBlock = block->splitBlock(terminator);
2369  OpBuilder builder(block, block->end());
2370  builder.create<spirv::BranchOp>(block->getParent()->getLoc(), newBlock);
2371 
2372  // After splitting we need to update the map to use the new block as a
2373  // header.
2374  blockMergeInfo.erase(block);
2375  blockMergeInfo.try_emplace(newBlock, mergeInfo);
2376  }
2377  }
2378 
2379  return success();
2380 }
2381 
2382 LogicalResult spirv::Deserializer::structurizeControlFlow() {
2383  if (!options.enableControlFlowStructurization) {
2384  LLVM_DEBUG(
2385  {
2386  logger.startLine()
2387  << "//----- [cf] skip structurizing control flow -----//\n";
2388  logger.indent();
2389  });
2390  return success();
2391  }
2392 
2393  LLVM_DEBUG({
2394  logger.startLine()
2395  << "//----- [cf] start structurizing control flow -----//\n";
2396  logger.indent();
2397  });
2398 
2399  LLVM_DEBUG({
2400  logger.startLine() << "[cf] split conditional blocks\n";
2401  logger.startLine() << "\n";
2402  });
2403 
2404  if (failed(splitConditionalBlocks())) {
2405  return failure();
2406  }
2407 
2408  // TODO: This loop is non-deterministic. Iteration order may vary between runs
2409  // for the same shader as the key to the map is a pointer. See:
2410  // https://github.com/llvm/llvm-project/issues/128547
2411  while (!blockMergeInfo.empty()) {
2412  Block *headerBlock = blockMergeInfo.begin()->first;
2413  BlockMergeInfo mergeInfo = blockMergeInfo.begin()->second;
2414 
2415  LLVM_DEBUG({
2416  logger.startLine() << "[cf] header block " << headerBlock << ":\n";
2417  headerBlock->print(logger.getOStream());
2418  logger.startLine() << "\n";
2419  });
2420 
2421  auto *mergeBlock = mergeInfo.mergeBlock;
2422  assert(mergeBlock && "merge block cannot be nullptr");
2423  if (mergeInfo.continueBlock && !mergeBlock->args_empty())
2424  return emitError(unknownLoc, "OpPhi in loop merge block unimplemented");
2425  LLVM_DEBUG({
2426  logger.startLine() << "[cf] merge block " << mergeBlock << ":\n";
2427  mergeBlock->print(logger.getOStream());
2428  logger.startLine() << "\n";
2429  });
2430 
2431  auto *continueBlock = mergeInfo.continueBlock;
2432  LLVM_DEBUG(if (continueBlock) {
2433  logger.startLine() << "[cf] continue block " << continueBlock << ":\n";
2434  continueBlock->print(logger.getOStream());
2435  logger.startLine() << "\n";
2436  });
2437  // Erase this case before calling into structurizer, who will update
2438  // blockMergeInfo.
2439  blockMergeInfo.erase(blockMergeInfo.begin());
2440  ControlFlowStructurizer structurizer(mergeInfo.loc, mergeInfo.control,
2441  blockMergeInfo, headerBlock,
2442  mergeBlock, continueBlock
2443 #ifndef NDEBUG
2444  ,
2445  logger
2446 #endif
2447  );
2448  if (failed(structurizer.structurize()))
2449  return failure();
2450  }
2451 
2452  LLVM_DEBUG({
2453  logger.unindent();
2454  logger.startLine()
2455  << "//--- [cf] completed structurizing control flow ---//\n";
2456  });
2457  return success();
2458 }
2459 
2460 //===----------------------------------------------------------------------===//
2461 // Debug
2462 //===----------------------------------------------------------------------===//
2463 
2464 Location spirv::Deserializer::createFileLineColLoc(OpBuilder opBuilder) {
2465  if (!debugLine)
2466  return unknownLoc;
2467 
2468  auto fileName = debugInfoMap.lookup(debugLine->fileID).str();
2469  if (fileName.empty())
2470  fileName = "<unknown>";
2471  return FileLineColLoc::get(opBuilder.getStringAttr(fileName), debugLine->line,
2472  debugLine->column);
2473 }
2474 
2475 LogicalResult
2476 spirv::Deserializer::processDebugLine(ArrayRef<uint32_t> operands) {
2477  // According to SPIR-V spec:
2478  // "This location information applies to the instructions physically
2479  // following this instruction, up to the first occurrence of any of the
2480  // following: the next end of block, the next OpLine instruction, or the next
2481  // OpNoLine instruction."
2482  if (operands.size() != 3)
2483  return emitError(unknownLoc, "OpLine must have 3 operands");
2484  debugLine = DebugLine{operands[0], operands[1], operands[2]};
2485  return success();
2486 }
2487 
2488 void spirv::Deserializer::clearDebugLine() { debugLine = std::nullopt; }
2489 
2490 LogicalResult
2491 spirv::Deserializer::processDebugString(ArrayRef<uint32_t> operands) {
2492  if (operands.size() < 2)
2493  return emitError(unknownLoc, "OpString needs at least 2 operands");
2494 
2495  if (!debugInfoMap.lookup(operands[0]).empty())
2496  return emitError(unknownLoc,
2497  "duplicate debug string found for result <id> ")
2498  << operands[0];
2499 
2500  unsigned wordIndex = 1;
2501  StringRef debugString = decodeStringLiteral(operands, wordIndex);
2502  if (wordIndex != operands.size())
2503  return emitError(unknownLoc,
2504  "unexpected trailing words in OpString instruction");
2505 
2506  debugInfoMap[operands[0]] = debugString;
2507  return success();
2508 }
static bool isLoop(Operation *op)
Returns true if the given operation represents a loop by testing whether it implements the LoopLikeOp...
static bool isFnEntryBlock(Block *block)
Returns true if the given block is a function entry block.
#define MIN_VERSION_CASE(v)
LogicalResult deserializeCacheControlDecoration(Location loc, OpBuilder &opBuilder, DenseMap< uint32_t, NamedAttrList > &decorations, ArrayRef< uint32_t > words, StringAttr symbol, StringRef decorationName, StringRef cacheControlKind)
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
int64_t rows
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:309
Location getLoc() const
Return the location for this argument.
Definition: Value.h:324
Block represents an ordered list of Operations.
Definition: Block.h:33
bool empty()
Definition: Block.h:148
Operation & back()
Definition: Block.h:152
void erase()
Unlink this Block from its parent region and delete it.
Definition: Block.cpp:68
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Definition: Block.cpp:310
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:29
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
void print(raw_ostream &os)
bool mightHaveTerminator()
Check whether this block might have a terminator.
Definition: Block.cpp:252
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
iterator end()
Definition: Block.h:144
iterator begin()
Definition: Block.h:143
bool isEntryBlock()
Return true if this block is the entry block in the parent region.
Definition: Block.cpp:38
void push_back(Operation *op)
Definition: Block.h:149
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:33
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:260
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:264
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition: Builders.h:96
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static FileLineColLoc get(StringAttr filename, unsigned line, unsigned column)
Definition: Location.cpp:161
A symbol reference with a reference path containing a single element.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
auto lookupOrNull(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:58
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:164
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool use_empty()
Returns true if this operation has no uses.
Definition: Operation.h:852
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation using the map that is provided (leaving them alone if no entry is present).
Definition: Operation.cpp:719
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Definition: Operation.cpp:67
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:686
MutableArrayRef< BlockOperand > getBlockOperands()
Definition: Operation.h:695
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:383
result_range getResults()
Definition: Operation.h:415
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:673
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:539
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Block & back()
Definition: Region.h:64
BlockListType & getBlocks()
Definition: Region.h:45
Location getLoc()
Return a location for this region.
Definition: Region.cpp:31
BlockListType::iterator iterator
Definition: Region.h:52
void takeBody(Region &other)
Takes body of another region (that region will have no body after this operation completes).
Definition: Region.h:241
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:116
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: Visitors.h:33
static WalkResult advance()
Definition: Visitors.h:51
static ArrayType get(Type elementType, unsigned elementCount)
Definition: SPIRVTypes.cpp:52
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
Definition: SPIRVTypes.cpp:235
LogicalResult deserialize()
Deserializes the remembered SPIR-V binary module.
Deserializer(ArrayRef< uint32_t > binary, MLIRContext *context, const DeserializationOptions &options)
Creates a deserializer for the given SPIR-V binary module.
OwningOpRef< spirv::ModuleOp > collect()
Collects the final SPIR-V ModuleOp.
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
Definition: SPIRVTypes.h:168
static MatrixType get(Type columnType, uint32_t columnCount)
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:427
static RuntimeArrayType get(Type elementType)
Definition: SPIRVTypes.cpp:484
static SampledImageType get(Type imageType)
Definition: SPIRVTypes.cpp:753
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
Definition: SPIRVTypes.cpp:995
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Construct a literal StructType with at least one member.
Definition: SPIRVTypes.cpp:982
static VerCapExtAttr get(Version version, ArrayRef< Capability > capabilities, ArrayRef< Extension > extensions, MLIRContext *context)
Gets a VerCapExtAttr instance.
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition: CallGraph.h:229
constexpr uint32_t kMagicNumber
SPIR-V magic number.
StringRef decodeStringLiteral(ArrayRef< uint32_t > words, unsigned &wordIndex)
Decodes a string literal in words starting at wordIndex.
DenseMap< Block *, BlockMergeInfo > BlockMergeInfoMap
Map from a selection/loop's header block to its merge (and continue) target.
Definition: Deserializer.h:61
constexpr unsigned kHeaderWordCount
SPIR-V binary header word count.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
static std::string debugString(T &&op)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction code.
This represents an operation in an abstracted form, suitable for use with the builder APIs.