MLIR  20.0.0git
Deserializer.cpp
Go to the documentation of this file.
1 //===- Deserializer.cpp - MLIR SPIR-V Deserializer ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the SPIR-V binary to MLIR SPIR-V module deserializer.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Deserializer.h"
14 
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/IRMapping.h"
21 #include "mlir/IR/Location.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/Sequence.h"
25 #include "llvm/ADT/SmallVector.h"
26 #include "llvm/ADT/StringExtras.h"
27 #include "llvm/ADT/bit.h"
28 #include "llvm/Support/Debug.h"
29 #include "llvm/Support/SaveAndRestore.h"
30 #include "llvm/Support/raw_ostream.h"
31 #include <optional>
32 
33 using namespace mlir;
34 
35 #define DEBUG_TYPE "spirv-deserialization"
36 
37 //===----------------------------------------------------------------------===//
38 // Utility Functions
39 //===----------------------------------------------------------------------===//
40 
41 /// Returns true if the given `block` is a function entry block.
42 static inline bool isFnEntryBlock(Block *block) {
43  return block->isEntryBlock() &&
44  isa_and_nonnull<spirv::FuncOp>(block->getParentOp());
45 }
46 
47 //===----------------------------------------------------------------------===//
48 // Deserializer Method Definitions
49 //===----------------------------------------------------------------------===//
50 
52  MLIRContext *context)
53  : binary(binary), context(context), unknownLoc(UnknownLoc::get(context)),
54  module(createModuleOp()), opBuilder(module->getRegion())
55 #ifndef NDEBUG
56  ,
57  logger(llvm::dbgs())
58 #endif
59 {
60 }
61 
63  LLVM_DEBUG({
64  logger.resetIndent();
65  logger.startLine()
66  << "//+++---------- start deserialization ----------+++//\n";
67  });
68 
69  if (failed(processHeader()))
70  return failure();
71 
72  spirv::Opcode opcode = spirv::Opcode::OpNop;
73  ArrayRef<uint32_t> operands;
74  auto binarySize = binary.size();
75  while (curOffset < binarySize) {
76  // Slice the next instruction out and populate `opcode` and `operands`.
77  // Internally this also updates `curOffset`.
78  if (failed(sliceInstruction(opcode, operands)))
79  return failure();
80 
81  if (failed(processInstruction(opcode, operands)))
82  return failure();
83  }
84 
85  assert(curOffset == binarySize &&
86  "deserializer should never index beyond the binary end");
87 
88  for (auto &deferred : deferredInstructions) {
89  if (failed(processInstruction(deferred.first, deferred.second, false))) {
90  return failure();
91  }
92  }
93 
94  attachVCETriple();
95 
96  LLVM_DEBUG(logger.startLine()
97  << "//+++-------- completed deserialization --------+++//\n");
98  return success();
99 }
100 
102  return std::move(module);
103 }
104 
105 //===----------------------------------------------------------------------===//
106 // Module structure
107 //===----------------------------------------------------------------------===//
108 
109 OwningOpRef<spirv::ModuleOp> spirv::Deserializer::createModuleOp() {
110  OpBuilder builder(context);
111  OperationState state(unknownLoc, spirv::ModuleOp::getOperationName());
112  spirv::ModuleOp::build(builder, state);
113  return cast<spirv::ModuleOp>(Operation::create(state));
114 }
115 
116 LogicalResult spirv::Deserializer::processHeader() {
117  if (binary.size() < spirv::kHeaderWordCount)
118  return emitError(unknownLoc,
119  "SPIR-V binary module must have a 5-word header");
120 
121  if (binary[0] != spirv::kMagicNumber)
122  return emitError(unknownLoc, "incorrect magic number");
123 
124  // Version number bytes: 0 | major number | minor number | 0
125  uint32_t majorVersion = (binary[1] << 8) >> 24;
126  uint32_t minorVersion = (binary[1] << 16) >> 24;
127  if (majorVersion == 1) {
128  switch (minorVersion) {
129 #define MIN_VERSION_CASE(v) \
130  case v: \
131  version = spirv::Version::V_1_##v; \
132  break
133 
134  MIN_VERSION_CASE(0);
135  MIN_VERSION_CASE(1);
136  MIN_VERSION_CASE(2);
137  MIN_VERSION_CASE(3);
138  MIN_VERSION_CASE(4);
139  MIN_VERSION_CASE(5);
140 #undef MIN_VERSION_CASE
141  default:
142  return emitError(unknownLoc, "unsupported SPIR-V minor version: ")
143  << minorVersion;
144  }
145  } else {
146  return emitError(unknownLoc, "unsupported SPIR-V major version: ")
147  << majorVersion;
148  }
149 
150  // TODO: generator number, bound, schema
151  curOffset = spirv::kHeaderWordCount;
152  return success();
153 }
154 
155 LogicalResult
156 spirv::Deserializer::processCapability(ArrayRef<uint32_t> operands) {
157  if (operands.size() != 1)
158  return emitError(unknownLoc, "OpMemoryModel must have one parameter");
159 
160  auto cap = spirv::symbolizeCapability(operands[0]);
161  if (!cap)
162  return emitError(unknownLoc, "unknown capability: ") << operands[0];
163 
164  capabilities.insert(*cap);
165  return success();
166 }
167 
168 LogicalResult spirv::Deserializer::processExtension(ArrayRef<uint32_t> words) {
169  if (words.empty()) {
170  return emitError(
171  unknownLoc,
172  "OpExtension must have a literal string for the extension name");
173  }
174 
175  unsigned wordIndex = 0;
176  StringRef extName = decodeStringLiteral(words, wordIndex);
177  if (wordIndex != words.size())
178  return emitError(unknownLoc,
179  "unexpected trailing words in OpExtension instruction");
180  auto ext = spirv::symbolizeExtension(extName);
181  if (!ext)
182  return emitError(unknownLoc, "unknown extension: ") << extName;
183 
184  extensions.insert(*ext);
185  return success();
186 }
187 
188 LogicalResult
189 spirv::Deserializer::processExtInstImport(ArrayRef<uint32_t> words) {
190  if (words.size() < 2) {
191  return emitError(unknownLoc,
192  "OpExtInstImport must have a result <id> and a literal "
193  "string for the extended instruction set name");
194  }
195 
196  unsigned wordIndex = 1;
197  extendedInstSets[words[0]] = decodeStringLiteral(words, wordIndex);
198  if (wordIndex != words.size()) {
199  return emitError(unknownLoc,
200  "unexpected trailing words in OpExtInstImport");
201  }
202  return success();
203 }
204 
205 void spirv::Deserializer::attachVCETriple() {
206  (*module)->setAttr(
207  spirv::ModuleOp::getVCETripleAttrName(),
208  spirv::VerCapExtAttr::get(version, capabilities.getArrayRef(),
209  extensions.getArrayRef(), context));
210 }
211 
212 LogicalResult
213 spirv::Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
214  if (operands.size() != 2)
215  return emitError(unknownLoc, "OpMemoryModel must have two operands");
216 
217  (*module)->setAttr(
218  module->getAddressingModelAttrName(),
219  opBuilder.getAttr<spirv::AddressingModelAttr>(
220  static_cast<spirv::AddressingModel>(operands.front())));
221 
222  (*module)->setAttr(module->getMemoryModelAttrName(),
223  opBuilder.getAttr<spirv::MemoryModelAttr>(
224  static_cast<spirv::MemoryModel>(operands.back())));
225 
226  return success();
227 }
228 
229 template <typename AttrTy, typename EnumAttrTy, typename EnumTy>
231  Location loc, OpBuilder &opBuilder,
233  StringAttr symbol, StringRef decorationName, StringRef cacheControlKind) {
234  if (words.size() != 4) {
235  return emitError(loc, "OpDecoration with ")
236  << decorationName << "needs a cache control integer literal and a "
237  << cacheControlKind << " cache control literal";
238  }
239  unsigned cacheLevel = words[2];
240  auto cacheControlAttr = static_cast<EnumTy>(words[3]);
241  auto value = opBuilder.getAttr<AttrTy>(cacheLevel, cacheControlAttr);
243  if (auto attrList =
244  llvm::dyn_cast_or_null<ArrayAttr>(decorations[words[0]].get(symbol)))
245  llvm::append_range(attrs, attrList);
246  attrs.push_back(value);
247  decorations[words[0]].set(symbol, opBuilder.getArrayAttr(attrs));
248  return success();
249 }
250 
251 LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
252  // TODO: This function should also be auto-generated. For now, since only a
253  // few decorations are processed/handled in a meaningful manner, going with a
254  // manual implementation.
255  if (words.size() < 2) {
256  return emitError(
257  unknownLoc, "OpDecorate must have at least result <id> and Decoration");
258  }
259  auto decorationName =
260  stringifyDecoration(static_cast<spirv::Decoration>(words[1]));
261  if (decorationName.empty()) {
262  return emitError(unknownLoc, "invalid Decoration code : ") << words[1];
263  }
264  auto symbol = getSymbolDecoration(decorationName);
265  switch (static_cast<spirv::Decoration>(words[1])) {
266  case spirv::Decoration::FPFastMathMode:
267  if (words.size() != 3) {
268  return emitError(unknownLoc, "OpDecorate with ")
269  << decorationName << " needs a single integer literal";
270  }
271  decorations[words[0]].set(
272  symbol, FPFastMathModeAttr::get(opBuilder.getContext(),
273  static_cast<FPFastMathMode>(words[2])));
274  break;
275  case spirv::Decoration::FPRoundingMode:
276  if (words.size() != 3) {
277  return emitError(unknownLoc, "OpDecorate with ")
278  << decorationName << " needs a single integer literal";
279  }
280  decorations[words[0]].set(
281  symbol, FPRoundingModeAttr::get(opBuilder.getContext(),
282  static_cast<FPRoundingMode>(words[2])));
283  break;
284  case spirv::Decoration::DescriptorSet:
285  case spirv::Decoration::Binding:
286  if (words.size() != 3) {
287  return emitError(unknownLoc, "OpDecorate with ")
288  << decorationName << " needs a single integer literal";
289  }
290  decorations[words[0]].set(
291  symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));
292  break;
293  case spirv::Decoration::BuiltIn:
294  if (words.size() != 3) {
295  return emitError(unknownLoc, "OpDecorate with ")
296  << decorationName << " needs a single integer literal";
297  }
298  decorations[words[0]].set(
299  symbol, opBuilder.getStringAttr(
300  stringifyBuiltIn(static_cast<spirv::BuiltIn>(words[2]))));
301  break;
302  case spirv::Decoration::ArrayStride:
303  if (words.size() != 3) {
304  return emitError(unknownLoc, "OpDecorate with ")
305  << decorationName << " needs a single integer literal";
306  }
307  typeDecorations[words[0]] = words[2];
308  break;
309  case spirv::Decoration::LinkageAttributes: {
310  if (words.size() < 4) {
311  return emitError(unknownLoc, "OpDecorate with ")
312  << decorationName
313  << " needs at least 1 string and 1 integer literal";
314  }
315  // LinkageAttributes has two parameters ["linkageName", linkageType]
316  // e.g., OpDecorate %imported_func LinkageAttributes "outside.func" Import
317  // "linkageName" is a stringliteral encoded as uint32_t,
318  // hence the size of name is variable length which results in words.size()
319  // being variable length, words.size() = 3 + strlen(name)/4 + 1 or
320  // 3 + ceildiv(strlen(name), 4).
321  unsigned wordIndex = 2;
322  auto linkageName = spirv::decodeStringLiteral(words, wordIndex).str();
323  auto linkageTypeAttr = opBuilder.getAttr<::mlir::spirv::LinkageTypeAttr>(
324  static_cast<::mlir::spirv::LinkageType>(words[wordIndex++]));
325  auto linkageAttr = opBuilder.getAttr<::mlir::spirv::LinkageAttributesAttr>(
326  StringAttr::get(context, linkageName), linkageTypeAttr);
327  decorations[words[0]].set(symbol, llvm::dyn_cast<Attribute>(linkageAttr));
328  break;
329  }
330  case spirv::Decoration::Aliased:
331  case spirv::Decoration::AliasedPointer:
332  case spirv::Decoration::Block:
333  case spirv::Decoration::BufferBlock:
334  case spirv::Decoration::Flat:
335  case spirv::Decoration::NonReadable:
336  case spirv::Decoration::NonWritable:
337  case spirv::Decoration::NoPerspective:
338  case spirv::Decoration::NoSignedWrap:
339  case spirv::Decoration::NoUnsignedWrap:
340  case spirv::Decoration::RelaxedPrecision:
341  case spirv::Decoration::Restrict:
342  case spirv::Decoration::RestrictPointer:
343  case spirv::Decoration::NoContraction:
344  case spirv::Decoration::Constant:
345  if (words.size() != 2) {
346  return emitError(unknownLoc, "OpDecoration with ")
347  << decorationName << "needs a single target <id>";
348  }
349  // Block decoration does not affect spirv.struct type, but is still stored
350  // for verification.
351  // TODO: Update StructType to contain this information since
352  // it is needed for many validation rules.
353  decorations[words[0]].set(symbol, opBuilder.getUnitAttr());
354  break;
355  case spirv::Decoration::Location:
356  case spirv::Decoration::SpecId:
357  if (words.size() != 3) {
358  return emitError(unknownLoc, "OpDecoration with ")
359  << decorationName << "needs a single integer literal";
360  }
361  decorations[words[0]].set(
362  symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));
363  break;
364  case spirv::Decoration::CacheControlLoadINTEL: {
365  LogicalResult res = deserializeCacheControlDecoration<
366  CacheControlLoadINTELAttr, LoadCacheControlAttr, LoadCacheControl>(
367  unknownLoc, opBuilder, decorations, words, symbol, decorationName,
368  "load");
369  if (failed(res))
370  return res;
371  break;
372  }
373  case spirv::Decoration::CacheControlStoreINTEL: {
374  LogicalResult res = deserializeCacheControlDecoration<
375  CacheControlStoreINTELAttr, StoreCacheControlAttr, StoreCacheControl>(
376  unknownLoc, opBuilder, decorations, words, symbol, decorationName,
377  "store");
378  if (failed(res))
379  return res;
380  break;
381  }
382  default:
383  return emitError(unknownLoc, "unhandled Decoration : '") << decorationName;
384  }
385  return success();
386 }
387 
388 LogicalResult
389 spirv::Deserializer::processMemberDecoration(ArrayRef<uint32_t> words) {
390  // The binary layout of OpMemberDecorate is different comparing to OpDecorate
391  if (words.size() < 3) {
392  return emitError(unknownLoc,
393  "OpMemberDecorate must have at least 3 operands");
394  }
395 
396  auto decoration = static_cast<spirv::Decoration>(words[2]);
397  if (decoration == spirv::Decoration::Offset && words.size() != 4) {
398  return emitError(unknownLoc,
399  " missing offset specification in OpMemberDecorate with "
400  "Offset decoration");
401  }
402  ArrayRef<uint32_t> decorationOperands;
403  if (words.size() > 3) {
404  decorationOperands = words.slice(3);
405  }
406  memberDecorationMap[words[0]][words[1]][decoration] = decorationOperands;
407  return success();
408 }
409 
410 LogicalResult spirv::Deserializer::processMemberName(ArrayRef<uint32_t> words) {
411  if (words.size() < 3) {
412  return emitError(unknownLoc, "OpMemberName must have at least 3 operands");
413  }
414  unsigned wordIndex = 2;
415  auto name = decodeStringLiteral(words, wordIndex);
416  if (wordIndex != words.size()) {
417  return emitError(unknownLoc,
418  "unexpected trailing words in OpMemberName instruction");
419  }
420  memberNameMap[words[0]][words[1]] = name;
421  return success();
422 }
423 
424 LogicalResult spirv::Deserializer::setFunctionArgAttrs(
425  uint32_t argID, SmallVectorImpl<Attribute> &argAttrs, size_t argIndex) {
426  if (!decorations.contains(argID)) {
427  argAttrs[argIndex] = DictionaryAttr::get(context, {});
428  return success();
429  }
430 
431  spirv::DecorationAttr foundDecorationAttr;
432  for (NamedAttribute decAttr : decorations[argID]) {
433  for (auto decoration :
434  {spirv::Decoration::Aliased, spirv::Decoration::Restrict,
435  spirv::Decoration::AliasedPointer,
436  spirv::Decoration::RestrictPointer}) {
437 
438  if (decAttr.getName() !=
439  getSymbolDecoration(stringifyDecoration(decoration)))
440  continue;
441 
442  if (foundDecorationAttr)
443  return emitError(unknownLoc,
444  "more than one Aliased/Restrict decorations for "
445  "function argument with result <id> ")
446  << argID;
447 
448  foundDecorationAttr = spirv::DecorationAttr::get(context, decoration);
449  break;
450  }
451  }
452 
453  if (!foundDecorationAttr)
454  return emitError(unknownLoc, "unimplemented decoration support for "
455  "function argument with result <id> ")
456  << argID;
457 
458  NamedAttribute attr(StringAttr::get(context, spirv::DecorationAttr::name),
459  foundDecorationAttr);
460  argAttrs[argIndex] = DictionaryAttr::get(context, attr);
461  return success();
462 }
463 
464 LogicalResult
465 spirv::Deserializer::processFunction(ArrayRef<uint32_t> operands) {
466  if (curFunction) {
467  return emitError(unknownLoc, "found function inside function");
468  }
469 
470  // Get the result type
471  if (operands.size() != 4) {
472  return emitError(unknownLoc, "OpFunction must have 4 parameters");
473  }
474  Type resultType = getType(operands[0]);
475  if (!resultType) {
476  return emitError(unknownLoc, "undefined result type from <id> ")
477  << operands[0];
478  }
479 
480  uint32_t fnID = operands[1];
481  if (funcMap.count(fnID)) {
482  return emitError(unknownLoc, "duplicate function definition/declaration");
483  }
484 
485  auto fnControl = spirv::symbolizeFunctionControl(operands[2]);
486  if (!fnControl) {
487  return emitError(unknownLoc, "unknown Function Control: ") << operands[2];
488  }
489 
490  Type fnType = getType(operands[3]);
491  if (!fnType || !isa<FunctionType>(fnType)) {
492  return emitError(unknownLoc, "unknown function type from <id> ")
493  << operands[3];
494  }
495  auto functionType = cast<FunctionType>(fnType);
496 
497  if ((isVoidType(resultType) && functionType.getNumResults() != 0) ||
498  (functionType.getNumResults() == 1 &&
499  functionType.getResult(0) != resultType)) {
500  return emitError(unknownLoc, "mismatch in function type ")
501  << functionType << " and return type " << resultType << " specified";
502  }
503 
504  std::string fnName = getFunctionSymbol(fnID);
505  auto funcOp = opBuilder.create<spirv::FuncOp>(
506  unknownLoc, fnName, functionType, fnControl.value());
507  // Processing other function attributes.
508  if (decorations.count(fnID)) {
509  for (auto attr : decorations[fnID].getAttrs()) {
510  funcOp->setAttr(attr.getName(), attr.getValue());
511  }
512  }
513  curFunction = funcMap[fnID] = funcOp;
514  auto *entryBlock = funcOp.addEntryBlock();
515  LLVM_DEBUG({
516  logger.startLine()
517  << "//===-------------------------------------------===//\n";
518  logger.startLine() << "[fn] name: " << fnName << "\n";
519  logger.startLine() << "[fn] type: " << fnType << "\n";
520  logger.startLine() << "[fn] ID: " << fnID << "\n";
521  logger.startLine() << "[fn] entry block: " << entryBlock << "\n";
522  logger.indent();
523  });
524 
525  SmallVector<Attribute> argAttrs;
526  argAttrs.resize(functionType.getNumInputs());
527 
528  // Parse the op argument instructions
529  if (functionType.getNumInputs()) {
530  for (size_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
531  auto argType = functionType.getInput(i);
532  spirv::Opcode opcode = spirv::Opcode::OpNop;
533  ArrayRef<uint32_t> operands;
534  if (failed(sliceInstruction(opcode, operands,
535  spirv::Opcode::OpFunctionParameter))) {
536  return failure();
537  }
538  if (opcode != spirv::Opcode::OpFunctionParameter) {
539  return emitError(
540  unknownLoc,
541  "missing OpFunctionParameter instruction for argument ")
542  << i;
543  }
544  if (operands.size() != 2) {
545  return emitError(
546  unknownLoc,
547  "expected result type and result <id> for OpFunctionParameter");
548  }
549  auto argDefinedType = getType(operands[0]);
550  if (!argDefinedType || argDefinedType != argType) {
551  return emitError(unknownLoc,
552  "mismatch in argument type between function type "
553  "definition ")
554  << functionType << " and argument type definition "
555  << argDefinedType << " at argument " << i;
556  }
557  if (getValue(operands[1])) {
558  return emitError(unknownLoc, "duplicate definition of result <id> ")
559  << operands[1];
560  }
561  if (failed(setFunctionArgAttrs(operands[1], argAttrs, i))) {
562  return failure();
563  }
564 
565  auto argValue = funcOp.getArgument(i);
566  valueMap[operands[1]] = argValue;
567  }
568  }
569 
570  if (llvm::any_of(argAttrs, [](Attribute attr) {
571  auto argAttr = cast<DictionaryAttr>(attr);
572  return !argAttr.empty();
573  }))
574  funcOp.setArgAttrsAttr(ArrayAttr::get(context, argAttrs));
575 
576  // entryBlock is needed to access the arguments, Once that is done, we can
577  // erase the block for functions with 'Import' LinkageAttributes, since these
578  // are essentially function declarations, so they have no body.
579  auto linkageAttr = funcOp.getLinkageAttributes();
580  auto hasImportLinkage =
581  linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
582  spirv::LinkageType::Import);
583  if (hasImportLinkage)
584  funcOp.eraseBody();
585 
586  // RAII guard to reset the insertion point to the module's region after
587  // deserializing the body of this function.
588  OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
589 
590  spirv::Opcode opcode = spirv::Opcode::OpNop;
591  ArrayRef<uint32_t> instOperands;
592 
593  // Special handling for the entry block. We need to make sure it starts with
594  // an OpLabel instruction. The entry block takes the same parameters as the
595  // function. All other blocks do not take any parameter. We have already
596  // created the entry block, here we need to register it to the correct label
597  // <id>.
598  if (failed(sliceInstruction(opcode, instOperands,
599  spirv::Opcode::OpFunctionEnd))) {
600  return failure();
601  }
602  if (opcode == spirv::Opcode::OpFunctionEnd) {
603  return processFunctionEnd(instOperands);
604  }
605  if (opcode != spirv::Opcode::OpLabel) {
606  return emitError(unknownLoc, "a basic block must start with OpLabel");
607  }
608  if (instOperands.size() != 1) {
609  return emitError(unknownLoc, "OpLabel should only have result <id>");
610  }
611  blockMap[instOperands[0]] = entryBlock;
612  if (failed(processLabel(instOperands))) {
613  return failure();
614  }
615 
616  // Then process all the other instructions in the function until we hit
617  // OpFunctionEnd.
618  while (succeeded(sliceInstruction(opcode, instOperands,
619  spirv::Opcode::OpFunctionEnd)) &&
620  opcode != spirv::Opcode::OpFunctionEnd) {
621  if (failed(processInstruction(opcode, instOperands))) {
622  return failure();
623  }
624  }
625  if (opcode != spirv::Opcode::OpFunctionEnd) {
626  return failure();
627  }
628 
629  return processFunctionEnd(instOperands);
630 }
631 
632 LogicalResult
633 spirv::Deserializer::processFunctionEnd(ArrayRef<uint32_t> operands) {
634  // Process OpFunctionEnd.
635  if (!operands.empty()) {
636  return emitError(unknownLoc, "unexpected operands for OpFunctionEnd");
637  }
638 
639  // Wire up block arguments from OpPhi instructions.
640  // Put all structured control flow in spirv.mlir.selection/spirv.mlir.loop
641  // ops.
642  if (failed(wireUpBlockArgument()) || failed(structurizeControlFlow())) {
643  return failure();
644  }
645 
646  curBlock = nullptr;
647  curFunction = std::nullopt;
648 
649  LLVM_DEBUG({
650  logger.unindent();
651  logger.startLine()
652  << "//===-------------------------------------------===//\n";
653  });
654  return success();
655 }
656 
657 std::optional<std::pair<Attribute, Type>>
658 spirv::Deserializer::getConstant(uint32_t id) {
659  auto constIt = constantMap.find(id);
660  if (constIt == constantMap.end())
661  return std::nullopt;
662  return constIt->getSecond();
663 }
664 
665 std::optional<spirv::SpecConstOperationMaterializationInfo>
666 spirv::Deserializer::getSpecConstantOperation(uint32_t id) {
667  auto constIt = specConstOperationMap.find(id);
668  if (constIt == specConstOperationMap.end())
669  return std::nullopt;
670  return constIt->getSecond();
671 }
672 
673 std::string spirv::Deserializer::getFunctionSymbol(uint32_t id) {
674  auto funcName = nameMap.lookup(id).str();
675  if (funcName.empty()) {
676  funcName = "spirv_fn_" + std::to_string(id);
677  }
678  return funcName;
679 }
680 
681 std::string spirv::Deserializer::getSpecConstantSymbol(uint32_t id) {
682  auto constName = nameMap.lookup(id).str();
683  if (constName.empty()) {
684  constName = "spirv_spec_const_" + std::to_string(id);
685  }
686  return constName;
687 }
688 
689 spirv::SpecConstantOp
690 spirv::Deserializer::createSpecConstant(Location loc, uint32_t resultID,
691  TypedAttr defaultValue) {
692  auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
693  auto op = opBuilder.create<spirv::SpecConstantOp>(unknownLoc, symName,
694  defaultValue);
695  if (decorations.count(resultID)) {
696  for (auto attr : decorations[resultID].getAttrs())
697  op->setAttr(attr.getName(), attr.getValue());
698  }
699  specConstMap[resultID] = op;
700  return op;
701 }
702 
703 LogicalResult
704 spirv::Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
705  unsigned wordIndex = 0;
706  if (operands.size() < 3) {
707  return emitError(
708  unknownLoc,
709  "OpVariable needs at least 3 operands, type, <id> and storage class");
710  }
711 
712  // Result Type.
713  auto type = getType(operands[wordIndex]);
714  if (!type) {
715  return emitError(unknownLoc, "unknown result type <id> : ")
716  << operands[wordIndex];
717  }
718  auto ptrType = dyn_cast<spirv::PointerType>(type);
719  if (!ptrType) {
720  return emitError(unknownLoc,
721  "expected a result type <id> to be a spirv.ptr, found : ")
722  << type;
723  }
724  wordIndex++;
725 
726  // Result <id>.
727  auto variableID = operands[wordIndex];
728  auto variableName = nameMap.lookup(variableID).str();
729  if (variableName.empty()) {
730  variableName = "spirv_var_" + std::to_string(variableID);
731  }
732  wordIndex++;
733 
734  // Storage class.
735  auto storageClass = static_cast<spirv::StorageClass>(operands[wordIndex]);
736  if (ptrType.getStorageClass() != storageClass) {
737  return emitError(unknownLoc, "mismatch in storage class of pointer type ")
738  << type << " and that specified in OpVariable instruction : "
739  << stringifyStorageClass(storageClass);
740  }
741  wordIndex++;
742 
743  // Initializer.
744  FlatSymbolRefAttr initializer = nullptr;
745 
746  if (wordIndex < operands.size()) {
747  Operation *op = nullptr;
748 
749  if (auto initOp = getGlobalVariable(operands[wordIndex]))
750  op = initOp;
751  else if (auto initOp = getSpecConstant(operands[wordIndex]))
752  op = initOp;
753  else if (auto initOp = getSpecConstantComposite(operands[wordIndex]))
754  op = initOp;
755  else
756  return emitError(unknownLoc, "unknown <id> ")
757  << operands[wordIndex] << "used as initializer";
758 
759  initializer = SymbolRefAttr::get(op);
760  wordIndex++;
761  }
762  if (wordIndex != operands.size()) {
763  return emitError(unknownLoc,
764  "found more operands than expected when deserializing "
765  "OpVariable instruction, only ")
766  << wordIndex << " of " << operands.size() << " processed";
767  }
768  auto loc = createFileLineColLoc(opBuilder);
769  auto varOp = opBuilder.create<spirv::GlobalVariableOp>(
770  loc, TypeAttr::get(type), opBuilder.getStringAttr(variableName),
771  initializer);
772 
773  // Decorations.
774  if (decorations.count(variableID)) {
775  for (auto attr : decorations[variableID].getAttrs())
776  varOp->setAttr(attr.getName(), attr.getValue());
777  }
778  globalVariableMap[variableID] = varOp;
779  return success();
780 }
781 
782 IntegerAttr spirv::Deserializer::getConstantInt(uint32_t id) {
783  auto constInfo = getConstant(id);
784  if (!constInfo) {
785  return nullptr;
786  }
787  return dyn_cast<IntegerAttr>(constInfo->first);
788 }
789 
790 LogicalResult spirv::Deserializer::processName(ArrayRef<uint32_t> operands) {
791  if (operands.size() < 2) {
792  return emitError(unknownLoc, "OpName needs at least 2 operands");
793  }
794  if (!nameMap.lookup(operands[0]).empty()) {
795  return emitError(unknownLoc, "duplicate name found for result <id> ")
796  << operands[0];
797  }
798  unsigned wordIndex = 1;
799  StringRef name = decodeStringLiteral(operands, wordIndex);
800  if (wordIndex != operands.size()) {
801  return emitError(unknownLoc,
802  "unexpected trailing words in OpName instruction");
803  }
804  nameMap[operands[0]] = name;
805  return success();
806 }
807 
808 //===----------------------------------------------------------------------===//
809 // Type
810 //===----------------------------------------------------------------------===//
811 
812 LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
813  ArrayRef<uint32_t> operands) {
814  if (operands.empty()) {
815  return emitError(unknownLoc, "type instruction with opcode ")
816  << spirv::stringifyOpcode(opcode) << " needs at least one <id>";
817  }
818 
819  /// TODO: Types might be forward declared in some instructions and need to be
820  /// handled appropriately.
821  if (typeMap.count(operands[0])) {
822  return emitError(unknownLoc, "duplicate definition for result <id> ")
823  << operands[0];
824  }
825 
826  switch (opcode) {
827  case spirv::Opcode::OpTypeVoid:
828  if (operands.size() != 1)
829  return emitError(unknownLoc, "OpTypeVoid must have no parameters");
830  typeMap[operands[0]] = opBuilder.getNoneType();
831  break;
832  case spirv::Opcode::OpTypeBool:
833  if (operands.size() != 1)
834  return emitError(unknownLoc, "OpTypeBool must have no parameters");
835  typeMap[operands[0]] = opBuilder.getI1Type();
836  break;
837  case spirv::Opcode::OpTypeInt: {
838  if (operands.size() != 3)
839  return emitError(
840  unknownLoc, "OpTypeInt must have bitwidth and signedness parameters");
841 
842  // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics
843  // to preserve or validate.
844  // 0 indicates unsigned, or no signedness semantics
845  // 1 indicates signed semantics."
846  //
847  // So we cannot differentiate signless and unsigned integers; always use
848  // signless semantics for such cases.
849  auto sign = operands[2] == 1 ? IntegerType::SignednessSemantics::Signed
850  : IntegerType::SignednessSemantics::Signless;
851  typeMap[operands[0]] = IntegerType::get(context, operands[1], sign);
852  } break;
853  case spirv::Opcode::OpTypeFloat: {
854  if (operands.size() != 2)
855  return emitError(unknownLoc, "OpTypeFloat must have bitwidth parameter");
856 
857  Type floatTy;
858  switch (operands[1]) {
859  case 16:
860  floatTy = opBuilder.getF16Type();
861  break;
862  case 32:
863  floatTy = opBuilder.getF32Type();
864  break;
865  case 64:
866  floatTy = opBuilder.getF64Type();
867  break;
868  default:
869  return emitError(unknownLoc, "unsupported OpTypeFloat bitwidth: ")
870  << operands[1];
871  }
872  typeMap[operands[0]] = floatTy;
873  } break;
874  case spirv::Opcode::OpTypeVector: {
875  if (operands.size() != 3) {
876  return emitError(
877  unknownLoc,
878  "OpTypeVector must have element type and count parameters");
879  }
880  Type elementTy = getType(operands[1]);
881  if (!elementTy) {
882  return emitError(unknownLoc, "OpTypeVector references undefined <id> ")
883  << operands[1];
884  }
885  typeMap[operands[0]] = VectorType::get({operands[2]}, elementTy);
886  } break;
887  case spirv::Opcode::OpTypePointer: {
888  return processOpTypePointer(operands);
889  } break;
890  case spirv::Opcode::OpTypeArray:
891  return processArrayType(operands);
892  case spirv::Opcode::OpTypeCooperativeMatrixKHR:
893  return processCooperativeMatrixTypeKHR(operands);
894  case spirv::Opcode::OpTypeFunction:
895  return processFunctionType(operands);
896  case spirv::Opcode::OpTypeImage:
897  return processImageType(operands);
898  case spirv::Opcode::OpTypeSampledImage:
899  return processSampledImageType(operands);
900  case spirv::Opcode::OpTypeRuntimeArray:
901  return processRuntimeArrayType(operands);
902  case spirv::Opcode::OpTypeStruct:
903  return processStructType(operands);
904  case spirv::Opcode::OpTypeMatrix:
905  return processMatrixType(operands);
906  default:
907  return emitError(unknownLoc, "unhandled type instruction");
908  }
909  return success();
910 }
911 
912 LogicalResult
913 spirv::Deserializer::processOpTypePointer(ArrayRef<uint32_t> operands) {
914  if (operands.size() != 3)
915  return emitError(unknownLoc, "OpTypePointer must have two parameters");
916 
917  auto pointeeType = getType(operands[2]);
918  if (!pointeeType)
919  return emitError(unknownLoc, "unknown OpTypePointer pointee type <id> ")
920  << operands[2];
921 
922  uint32_t typePointerID = operands[0];
923  auto storageClass = static_cast<spirv::StorageClass>(operands[1]);
924  typeMap[typePointerID] = spirv::PointerType::get(pointeeType, storageClass);
925 
926  for (auto *deferredStructIt = std::begin(deferredStructTypesInfos);
927  deferredStructIt != std::end(deferredStructTypesInfos);) {
928  for (auto *unresolvedMemberIt =
929  std::begin(deferredStructIt->unresolvedMemberTypes);
930  unresolvedMemberIt !=
931  std::end(deferredStructIt->unresolvedMemberTypes);) {
932  if (unresolvedMemberIt->first == typePointerID) {
933  // The newly constructed pointer type can resolve one of the
934  // deferred struct type members; update the memberTypes list and
935  // clean the unresolvedMemberTypes list accordingly.
936  deferredStructIt->memberTypes[unresolvedMemberIt->second] =
937  typeMap[typePointerID];
938  unresolvedMemberIt =
939  deferredStructIt->unresolvedMemberTypes.erase(unresolvedMemberIt);
940  } else {
941  ++unresolvedMemberIt;
942  }
943  }
944 
945  if (deferredStructIt->unresolvedMemberTypes.empty()) {
946  // All deferred struct type members are now resolved, set the struct body.
947  auto structType = deferredStructIt->deferredStructType;
948 
949  assert(structType && "expected a spirv::StructType");
950  assert(structType.isIdentified() && "expected an indentified struct");
951 
952  if (failed(structType.trySetBody(
953  deferredStructIt->memberTypes, deferredStructIt->offsetInfo,
954  deferredStructIt->memberDecorationsInfo)))
955  return failure();
956 
957  deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt);
958  } else {
959  ++deferredStructIt;
960  }
961  }
962 
963  return success();
964 }
965 
966 LogicalResult
967 spirv::Deserializer::processArrayType(ArrayRef<uint32_t> operands) {
968  if (operands.size() != 3) {
969  return emitError(unknownLoc,
970  "OpTypeArray must have element type and count parameters");
971  }
972 
973  Type elementTy = getType(operands[1]);
974  if (!elementTy) {
975  return emitError(unknownLoc, "OpTypeArray references undefined <id> ")
976  << operands[1];
977  }
978 
979  unsigned count = 0;
980  // TODO: The count can also come frome a specialization constant.
981  auto countInfo = getConstant(operands[2]);
982  if (!countInfo) {
983  return emitError(unknownLoc, "OpTypeArray count <id> ")
984  << operands[2] << "can only come from normal constant right now";
985  }
986 
987  if (auto intVal = dyn_cast<IntegerAttr>(countInfo->first)) {
988  count = intVal.getValue().getZExtValue();
989  } else {
990  return emitError(unknownLoc, "OpTypeArray count must come from a "
991  "scalar integer constant instruction");
992  }
993 
994  typeMap[operands[0]] = spirv::ArrayType::get(
995  elementTy, count, typeDecorations.lookup(operands[0]));
996  return success();
997 }
998 
999 LogicalResult
1000 spirv::Deserializer::processFunctionType(ArrayRef<uint32_t> operands) {
1001  assert(!operands.empty() && "No operands for processing function type");
1002  if (operands.size() == 1) {
1003  return emitError(unknownLoc, "missing return type for OpTypeFunction");
1004  }
1005  auto returnType = getType(operands[1]);
1006  if (!returnType) {
1007  return emitError(unknownLoc, "unknown return type in OpTypeFunction");
1008  }
1009  SmallVector<Type, 1> argTypes;
1010  for (size_t i = 2, e = operands.size(); i < e; ++i) {
1011  auto ty = getType(operands[i]);
1012  if (!ty) {
1013  return emitError(unknownLoc, "unknown argument type in OpTypeFunction");
1014  }
1015  argTypes.push_back(ty);
1016  }
1017  ArrayRef<Type> returnTypes;
1018  if (!isVoidType(returnType)) {
1019  returnTypes = llvm::ArrayRef(returnType);
1020  }
1021  typeMap[operands[0]] = FunctionType::get(context, argTypes, returnTypes);
1022  return success();
1023 }
1024 
1025 LogicalResult spirv::Deserializer::processCooperativeMatrixTypeKHR(
1026  ArrayRef<uint32_t> operands) {
1027  if (operands.size() != 6) {
1028  return emitError(unknownLoc,
1029  "OpTypeCooperativeMatrixKHR must have element type, "
1030  "scope, row and column parameters, and use");
1031  }
1032 
1033  Type elementTy = getType(operands[1]);
1034  if (!elementTy) {
1035  return emitError(unknownLoc,
1036  "OpTypeCooperativeMatrixKHR references undefined <id> ")
1037  << operands[1];
1038  }
1039 
1040  std::optional<spirv::Scope> scope =
1041  spirv::symbolizeScope(getConstantInt(operands[2]).getInt());
1042  if (!scope) {
1043  return emitError(
1044  unknownLoc,
1045  "OpTypeCooperativeMatrixKHR references undefined scope <id> ")
1046  << operands[2];
1047  }
1048 
1049  unsigned rows = getConstantInt(operands[3]).getInt();
1050  unsigned columns = getConstantInt(operands[4]).getInt();
1051 
1052  std::optional<spirv::CooperativeMatrixUseKHR> use =
1053  spirv::symbolizeCooperativeMatrixUseKHR(
1054  getConstantInt(operands[5]).getInt());
1055  if (!use) {
1056  return emitError(
1057  unknownLoc,
1058  "OpTypeCooperativeMatrixKHR references undefined use <id> ")
1059  << operands[5];
1060  }
1061 
1062  typeMap[operands[0]] =
1063  spirv::CooperativeMatrixType::get(elementTy, rows, columns, *scope, *use);
1064  return success();
1065 }
1066 
1067 LogicalResult
1068 spirv::Deserializer::processRuntimeArrayType(ArrayRef<uint32_t> operands) {
1069  if (operands.size() != 2) {
1070  return emitError(unknownLoc, "OpTypeRuntimeArray must have two operands");
1071  }
1072  Type memberType = getType(operands[1]);
1073  if (!memberType) {
1074  return emitError(unknownLoc,
1075  "OpTypeRuntimeArray references undefined <id> ")
1076  << operands[1];
1077  }
1078  typeMap[operands[0]] = spirv::RuntimeArrayType::get(
1079  memberType, typeDecorations.lookup(operands[0]));
1080  return success();
1081 }
1082 
1083 LogicalResult
1084 spirv::Deserializer::processStructType(ArrayRef<uint32_t> operands) {
1085  // TODO: Find a way to handle identified structs when debug info is stripped.
1086 
1087  if (operands.empty()) {
1088  return emitError(unknownLoc, "OpTypeStruct must have at least result <id>");
1089  }
1090 
1091  if (operands.size() == 1) {
1092  // Handle empty struct.
1093  typeMap[operands[0]] =
1094  spirv::StructType::getEmpty(context, nameMap.lookup(operands[0]).str());
1095  return success();
1096  }
1097 
1098  // First element is operand ID, second element is member index in the struct.
1099  SmallVector<std::pair<uint32_t, unsigned>, 0> unresolvedMemberTypes;
1100  SmallVector<Type, 4> memberTypes;
1101 
1102  for (auto op : llvm::drop_begin(operands, 1)) {
1103  Type memberType = getType(op);
1104  bool typeForwardPtr = (typeForwardPointerIDs.count(op) != 0);
1105 
1106  if (!memberType && !typeForwardPtr)
1107  return emitError(unknownLoc, "OpTypeStruct references undefined <id> ")
1108  << op;
1109 
1110  if (!memberType)
1111  unresolvedMemberTypes.emplace_back(op, memberTypes.size());
1112 
1113  memberTypes.push_back(memberType);
1114  }
1115 
1118  if (memberDecorationMap.count(operands[0])) {
1119  auto &allMemberDecorations = memberDecorationMap[operands[0]];
1120  for (auto memberIndex : llvm::seq<uint32_t>(0, memberTypes.size())) {
1121  if (allMemberDecorations.count(memberIndex)) {
1122  for (auto &memberDecoration : allMemberDecorations[memberIndex]) {
1123  // Check for offset.
1124  if (memberDecoration.first == spirv::Decoration::Offset) {
1125  // If offset info is empty, resize to the number of members;
1126  if (offsetInfo.empty()) {
1127  offsetInfo.resize(memberTypes.size());
1128  }
1129  offsetInfo[memberIndex] = memberDecoration.second[0];
1130  } else {
1131  if (!memberDecoration.second.empty()) {
1132  memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/1,
1133  memberDecoration.first,
1134  memberDecoration.second[0]);
1135  } else {
1136  memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/0,
1137  memberDecoration.first, 0);
1138  }
1139  }
1140  }
1141  }
1142  }
1143  }
1144 
1145  uint32_t structID = operands[0];
1146  std::string structIdentifier = nameMap.lookup(structID).str();
1147 
1148  if (structIdentifier.empty()) {
1149  assert(unresolvedMemberTypes.empty() &&
1150  "didn't expect unresolved member types");
1151  typeMap[structID] =
1152  spirv::StructType::get(memberTypes, offsetInfo, memberDecorationsInfo);
1153  } else {
1154  auto structTy = spirv::StructType::getIdentified(context, structIdentifier);
1155  typeMap[structID] = structTy;
1156 
1157  if (!unresolvedMemberTypes.empty())
1158  deferredStructTypesInfos.push_back({structTy, unresolvedMemberTypes,
1159  memberTypes, offsetInfo,
1160  memberDecorationsInfo});
1161  else if (failed(structTy.trySetBody(memberTypes, offsetInfo,
1162  memberDecorationsInfo)))
1163  return failure();
1164  }
1165 
1166  // TODO: Update StructType to have member name as attribute as
1167  // well.
1168  return success();
1169 }
1170 
1171 LogicalResult
1172 spirv::Deserializer::processMatrixType(ArrayRef<uint32_t> operands) {
1173  if (operands.size() != 3) {
1174  // Three operands are needed: result_id, column_type, and column_count
1175  return emitError(unknownLoc, "OpTypeMatrix must have 3 operands"
1176  " (result_id, column_type, and column_count)");
1177  }
1178  // Matrix columns must be of vector type
1179  Type elementTy = getType(operands[1]);
1180  if (!elementTy) {
1181  return emitError(unknownLoc,
1182  "OpTypeMatrix references undefined column type.")
1183  << operands[1];
1184  }
1185 
1186  uint32_t colsCount = operands[2];
1187  typeMap[operands[0]] = spirv::MatrixType::get(elementTy, colsCount);
1188  return success();
1189 }
1190 
1191 LogicalResult
1192 spirv::Deserializer::processTypeForwardPointer(ArrayRef<uint32_t> operands) {
1193  if (operands.size() != 2)
1194  return emitError(unknownLoc,
1195  "OpTypeForwardPointer instruction must have two operands");
1196 
1197  typeForwardPointerIDs.insert(operands[0]);
1198  // TODO: Use the 2nd operand (Storage Class) to validate the OpTypePointer
1199  // instruction that defines the actual type.
1200 
1201  return success();
1202 }
1203 
1204 LogicalResult
1205 spirv::Deserializer::processImageType(ArrayRef<uint32_t> operands) {
1206  // TODO: Add support for Access Qualifier.
1207  if (operands.size() != 8)
1208  return emitError(
1209  unknownLoc,
1210  "OpTypeImage with non-eight operands are not supported yet");
1211 
1212  Type elementTy = getType(operands[1]);
1213  if (!elementTy)
1214  return emitError(unknownLoc, "OpTypeImage references undefined <id>: ")
1215  << operands[1];
1216 
1217  auto dim = spirv::symbolizeDim(operands[2]);
1218  if (!dim)
1219  return emitError(unknownLoc, "unknown Dim for OpTypeImage: ")
1220  << operands[2];
1221 
1222  auto depthInfo = spirv::symbolizeImageDepthInfo(operands[3]);
1223  if (!depthInfo)
1224  return emitError(unknownLoc, "unknown Depth for OpTypeImage: ")
1225  << operands[3];
1226 
1227  auto arrayedInfo = spirv::symbolizeImageArrayedInfo(operands[4]);
1228  if (!arrayedInfo)
1229  return emitError(unknownLoc, "unknown Arrayed for OpTypeImage: ")
1230  << operands[4];
1231 
1232  auto samplingInfo = spirv::symbolizeImageSamplingInfo(operands[5]);
1233  if (!samplingInfo)
1234  return emitError(unknownLoc, "unknown MS for OpTypeImage: ") << operands[5];
1235 
1236  auto samplerUseInfo = spirv::symbolizeImageSamplerUseInfo(operands[6]);
1237  if (!samplerUseInfo)
1238  return emitError(unknownLoc, "unknown Sampled for OpTypeImage: ")
1239  << operands[6];
1240 
1241  auto format = spirv::symbolizeImageFormat(operands[7]);
1242  if (!format)
1243  return emitError(unknownLoc, "unknown Format for OpTypeImage: ")
1244  << operands[7];
1245 
1246  typeMap[operands[0]] = spirv::ImageType::get(
1247  elementTy, dim.value(), depthInfo.value(), arrayedInfo.value(),
1248  samplingInfo.value(), samplerUseInfo.value(), format.value());
1249  return success();
1250 }
1251 
1252 LogicalResult
1253 spirv::Deserializer::processSampledImageType(ArrayRef<uint32_t> operands) {
1254  if (operands.size() != 2)
1255  return emitError(unknownLoc, "OpTypeSampledImage must have two operands");
1256 
1257  Type elementTy = getType(operands[1]);
1258  if (!elementTy)
1259  return emitError(unknownLoc,
1260  "OpTypeSampledImage references undefined <id>: ")
1261  << operands[1];
1262 
1263  typeMap[operands[0]] = spirv::SampledImageType::get(elementTy);
1264  return success();
1265 }
1266 
1267 //===----------------------------------------------------------------------===//
1268 // Constant
1269 //===----------------------------------------------------------------------===//
1270 
1271 LogicalResult spirv::Deserializer::processConstant(ArrayRef<uint32_t> operands,
1272  bool isSpec) {
1273  StringRef opname = isSpec ? "OpSpecConstant" : "OpConstant";
1274 
1275  if (operands.size() < 2) {
1276  return emitError(unknownLoc)
1277  << opname << " must have type <id> and result <id>";
1278  }
1279  if (operands.size() < 3) {
1280  return emitError(unknownLoc)
1281  << opname << " must have at least 1 more parameter";
1282  }
1283 
1284  Type resultType = getType(operands[0]);
1285  if (!resultType) {
1286  return emitError(unknownLoc, "undefined result type from <id> ")
1287  << operands[0];
1288  }
1289 
1290  auto checkOperandSizeForBitwidth = [&](unsigned bitwidth) -> LogicalResult {
1291  if (bitwidth == 64) {
1292  if (operands.size() == 4) {
1293  return success();
1294  }
1295  return emitError(unknownLoc)
1296  << opname << " should have 2 parameters for 64-bit values";
1297  }
1298  if (bitwidth <= 32) {
1299  if (operands.size() == 3) {
1300  return success();
1301  }
1302 
1303  return emitError(unknownLoc)
1304  << opname
1305  << " should have 1 parameter for values with no more than 32 bits";
1306  }
1307  return emitError(unknownLoc, "unsupported OpConstant bitwidth: ")
1308  << bitwidth;
1309  };
1310 
1311  auto resultID = operands[1];
1312 
1313  if (auto intType = dyn_cast<IntegerType>(resultType)) {
1314  auto bitwidth = intType.getWidth();
1315  if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1316  return failure();
1317  }
1318 
1319  APInt value;
1320  if (bitwidth == 64) {
1321  // 64-bit integers are represented with two SPIR-V words. According to
1322  // SPIR-V spec: "When the type’s bit width is larger than one word, the
1323  // literal’s low-order words appear first."
1324  struct DoubleWord {
1325  uint32_t word1;
1326  uint32_t word2;
1327  } words = {operands[2], operands[3]};
1328  value = APInt(64, llvm::bit_cast<uint64_t>(words), /*isSigned=*/true);
1329  } else if (bitwidth <= 32) {
1330  value = APInt(bitwidth, operands[2], /*isSigned=*/true,
1331  /*implicitTrunc=*/true);
1332  }
1333 
1334  auto attr = opBuilder.getIntegerAttr(intType, value);
1335 
1336  if (isSpec) {
1337  createSpecConstant(unknownLoc, resultID, attr);
1338  } else {
1339  // For normal constants, we just record the attribute (and its type) for
1340  // later materialization at use sites.
1341  constantMap.try_emplace(resultID, attr, intType);
1342  }
1343 
1344  return success();
1345  }
1346 
1347  if (auto floatType = dyn_cast<FloatType>(resultType)) {
1348  auto bitwidth = floatType.getWidth();
1349  if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1350  return failure();
1351  }
1352 
1353  APFloat value(0.f);
1354  if (floatType.isF64()) {
1355  // Double values are represented with two SPIR-V words. According to
1356  // SPIR-V spec: "When the type’s bit width is larger than one word, the
1357  // literal’s low-order words appear first."
1358  struct DoubleWord {
1359  uint32_t word1;
1360  uint32_t word2;
1361  } words = {operands[2], operands[3]};
1362  value = APFloat(llvm::bit_cast<double>(words));
1363  } else if (floatType.isF32()) {
1364  value = APFloat(llvm::bit_cast<float>(operands[2]));
1365  } else if (floatType.isF16()) {
1366  APInt data(16, operands[2]);
1367  value = APFloat(APFloat::IEEEhalf(), data);
1368  }
1369 
1370  auto attr = opBuilder.getFloatAttr(floatType, value);
1371  if (isSpec) {
1372  createSpecConstant(unknownLoc, resultID, attr);
1373  } else {
1374  // For normal constants, we just record the attribute (and its type) for
1375  // later materialization at use sites.
1376  constantMap.try_emplace(resultID, attr, floatType);
1377  }
1378 
1379  return success();
1380  }
1381 
1382  return emitError(unknownLoc, "OpConstant can only generate values of "
1383  "scalar integer or floating-point type");
1384 }
1385 
1386 LogicalResult spirv::Deserializer::processConstantBool(
1387  bool isTrue, ArrayRef<uint32_t> operands, bool isSpec) {
1388  if (operands.size() != 2) {
1389  return emitError(unknownLoc, "Op")
1390  << (isSpec ? "Spec" : "") << "Constant"
1391  << (isTrue ? "True" : "False")
1392  << " must have type <id> and result <id>";
1393  }
1394 
1395  auto attr = opBuilder.getBoolAttr(isTrue);
1396  auto resultID = operands[1];
1397  if (isSpec) {
1398  createSpecConstant(unknownLoc, resultID, attr);
1399  } else {
1400  // For normal constants, we just record the attribute (and its type) for
1401  // later materialization at use sites.
1402  constantMap.try_emplace(resultID, attr, opBuilder.getI1Type());
1403  }
1404 
1405  return success();
1406 }
1407 
1408 LogicalResult
1409 spirv::Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
1410  if (operands.size() < 2) {
1411  return emitError(unknownLoc,
1412  "OpConstantComposite must have type <id> and result <id>");
1413  }
1414  if (operands.size() < 3) {
1415  return emitError(unknownLoc,
1416  "OpConstantComposite must have at least 1 parameter");
1417  }
1418 
1419  Type resultType = getType(operands[0]);
1420  if (!resultType) {
1421  return emitError(unknownLoc, "undefined result type from <id> ")
1422  << operands[0];
1423  }
1424 
1425  SmallVector<Attribute, 4> elements;
1426  elements.reserve(operands.size() - 2);
1427  for (unsigned i = 2, e = operands.size(); i < e; ++i) {
1428  auto elementInfo = getConstant(operands[i]);
1429  if (!elementInfo) {
1430  return emitError(unknownLoc, "OpConstantComposite component <id> ")
1431  << operands[i] << " must come from a normal constant";
1432  }
1433  elements.push_back(elementInfo->first);
1434  }
1435 
1436  auto resultID = operands[1];
1437  if (auto vectorType = dyn_cast<VectorType>(resultType)) {
1438  auto attr = DenseElementsAttr::get(vectorType, elements);
1439  // For normal constants, we just record the attribute (and its type) for
1440  // later materialization at use sites.
1441  constantMap.try_emplace(resultID, attr, resultType);
1442  } else if (auto arrayType = dyn_cast<spirv::ArrayType>(resultType)) {
1443  auto attr = opBuilder.getArrayAttr(elements);
1444  constantMap.try_emplace(resultID, attr, resultType);
1445  } else {
1446  return emitError(unknownLoc, "unsupported OpConstantComposite type: ")
1447  << resultType;
1448  }
1449 
1450  return success();
1451 }
1452 
1453 LogicalResult
1454 spirv::Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) {
1455  if (operands.size() < 2) {
1456  return emitError(unknownLoc,
1457  "OpConstantComposite must have type <id> and result <id>");
1458  }
1459  if (operands.size() < 3) {
1460  return emitError(unknownLoc,
1461  "OpConstantComposite must have at least 1 parameter");
1462  }
1463 
1464  Type resultType = getType(operands[0]);
1465  if (!resultType) {
1466  return emitError(unknownLoc, "undefined result type from <id> ")
1467  << operands[0];
1468  }
1469 
1470  auto resultID = operands[1];
1471  auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
1472 
1473  SmallVector<Attribute, 4> elements;
1474  elements.reserve(operands.size() - 2);
1475  for (unsigned i = 2, e = operands.size(); i < e; ++i) {
1476  auto elementInfo = getSpecConstant(operands[i]);
1477  elements.push_back(SymbolRefAttr::get(elementInfo));
1478  }
1479 
1480  auto op = opBuilder.create<spirv::SpecConstantCompositeOp>(
1481  unknownLoc, TypeAttr::get(resultType), symName,
1482  opBuilder.getArrayAttr(elements));
1483  specConstCompositeMap[resultID] = op;
1484 
1485  return success();
1486 }
1487 
1488 LogicalResult
1489 spirv::Deserializer::processSpecConstantOperation(ArrayRef<uint32_t> operands) {
1490  if (operands.size() < 3)
1491  return emitError(unknownLoc, "OpConstantOperation must have type <id>, "
1492  "result <id>, and operand opcode");
1493 
1494  uint32_t resultTypeID = operands[0];
1495 
1496  if (!getType(resultTypeID))
1497  return emitError(unknownLoc, "undefined result type from <id> ")
1498  << resultTypeID;
1499 
1500  uint32_t resultID = operands[1];
1501  spirv::Opcode enclosedOpcode = static_cast<spirv::Opcode>(operands[2]);
1502  auto emplaceResult = specConstOperationMap.try_emplace(
1503  resultID,
1504  SpecConstOperationMaterializationInfo{
1505  enclosedOpcode, resultTypeID,
1506  SmallVector<uint32_t>{operands.begin() + 3, operands.end()}});
1507 
1508  if (!emplaceResult.second)
1509  return emitError(unknownLoc, "value with <id>: ")
1510  << resultID << " is probably defined before.";
1511 
1512  return success();
1513 }
1514 
/// Materializes a recorded OpSpecConstantOp as a spirv.SpecConstantOperation
/// op and returns that op's result.
///
/// The steps are:
/// 1. Deserialize the enclosed instruction with processInstruction(), giving
///    it a placeholder result <id>.
/// 2. Split the op it produced off into a fresh block.
/// 3. Move that block into the region of a newly created
///    spirv::SpecConstantOperationOp.
/// 4. Terminate the region with a spirv::YieldOp of the op's first result.
///
/// \param resultID           <id> of the OpSpecConstantOp. NOTE(review): not
///                           referenced in this body.
/// \param enclosedOpcode     Opcode of the wrapped instruction.
/// \param resultTypeID       Result type <id> of the wrapped instruction; it is
///                           also the result type of the created op.
/// \param enclosedOpOperands Operands of the wrapped instruction, excluding its
///                           result type and result <id>.
/// \returns The created op's result, or a null Value if processing the
///          enclosed instruction failed.
Value spirv::Deserializer::materializeSpecConstantOperation(
    uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID,
    ArrayRef<uint32_t> enclosedOpOperands) {

  Type resultType = getType(resultTypeID);

  // Instructions wrapped by OpSpecConstantOp need an ID for their
  // Deserializer::processOp<op_name>(...) to emit the corresponding SPIR-V
  // dialect wrapped op. For that purpose, a new value map is created and a
  // "fake" ID in that map is assigned to the result of the enclosed
  // instruction.
  //
  // There is no need to update this fake ID: the created Value only needs to
  // be referenced from the spv::YieldOp created later in this method. Both are
  // the only values in their region (the SpecConstantOperation's region).
  //
  // If we encounter another SpecConstantOperation in the module, we simply
  // re-use the fake ID, since the previous Value assigned to it isn't visible
  // in the current scope anyway.
  DenseMap<uint32_t, Value> newValueMap;
  // Swaps in the empty map for the duration of this call and restores the
  // original `valueMap` on return.
  llvm::SaveAndRestore valueMapGuard(valueMap, newValueMap);
  // Placeholder <id>. NOTE(review): presumably chosen so it does not collide
  // with real <id>s; confirm.
  constexpr uint32_t fakeID = static_cast<uint32_t>(-3);

  // Rebuild the full operand list the enclosed instruction expects:
  // | result type <id> | result <id> (fake) | operands... |.
  SmallVector<uint32_t, 4> enclosedOpResultTypeAndOperands;
  enclosedOpResultTypeAndOperands.push_back(resultTypeID);
  enclosedOpResultTypeAndOperands.push_back(fakeID);
  enclosedOpResultTypeAndOperands.append(enclosedOpOperands.begin(),
                                         enclosedOpOperands.end());

  // Process enclosed instruction before creating the enclosing
  // specConstantOperation (and its region). This way, references to constants,
  // global variables, and spec constants will be materialized outside the new
  // op's region. For more info, see Deserializer::getValue's implementation.
  if (failed(
          processInstruction(enclosedOpcode, enclosedOpResultTypeAndOperands)))
    return Value();

  // Since the enclosed op is emitted in the current block, split it in a
  // separate new block.
  // NOTE(review): this assumes processInstruction() appended the enclosed op
  // as the last op of `curBlock`; confirm.
  Block *enclosedBlock = curBlock->splitBlock(&curBlock->back());

  auto loc = createFileLineColLoc(opBuilder);
  auto specConstOperationOp =
      opBuilder.create<spirv::SpecConstantOperationOp>(loc, resultType);

  Region &body = specConstOperationOp.getBody();
  // Move the new block into SpecConstantOperation's body.
  body.getBlocks().splice(body.end(), curBlock->getParent()->getBlocks(),
                          Region::iterator(enclosedBlock));
  Block &block = body.back();

  // RAII guard that restores the builder's previous insertion point once the
  // yield below has been created.
  OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
  opBuilder.setInsertionPointToEnd(&block);

  // NOTE(review): assumes the enclosed op (the block's first op) has at least
  // one result.
  opBuilder.create<spirv::YieldOp>(loc, block.front().getResult(0));
  return specConstOperationOp.getResult();
}
1571 
1572 LogicalResult
1573 spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
1574  if (operands.size() != 2) {
1575  return emitError(unknownLoc,
1576  "OpConstantNull must have type <id> and result <id>");
1577  }
1578 
1579  Type resultType = getType(operands[0]);
1580  if (!resultType) {
1581  return emitError(unknownLoc, "undefined result type from <id> ")
1582  << operands[0];
1583  }
1584 
1585  auto resultID = operands[1];
1586  if (resultType.isIntOrFloat() || isa<VectorType>(resultType)) {
1587  auto attr = opBuilder.getZeroAttr(resultType);
1588  // For normal constants, we just record the attribute (and its type) for
1589  // later materialization at use sites.
1590  constantMap.try_emplace(resultID, attr, resultType);
1591  return success();
1592  }
1593 
1594  return emitError(unknownLoc, "unsupported OpConstantNull type: ")
1595  << resultType;
1596 }
1597 
1598 //===----------------------------------------------------------------------===//
1599 // Control flow
1600 //===----------------------------------------------------------------------===//
1601 
1602 Block *spirv::Deserializer::getOrCreateBlock(uint32_t id) {
1603  if (auto *block = getBlock(id)) {
1604  LLVM_DEBUG(logger.startLine() << "[block] got exiting block for id = " << id
1605  << " @ " << block << "\n");
1606  return block;
1607  }
1608 
1609  // We don't know where this block will be placed finally (in a
1610  // spirv.mlir.selection or spirv.mlir.loop or function). Create it into the
1611  // function for now and sort out the proper place later.
1612  auto *block = curFunction->addBlock();
1613  LLVM_DEBUG(logger.startLine() << "[block] created block for id = " << id
1614  << " @ " << block << "\n");
1615  return blockMap[id] = block;
1616 }
1617 
1618 LogicalResult spirv::Deserializer::processBranch(ArrayRef<uint32_t> operands) {
1619  if (!curBlock) {
1620  return emitError(unknownLoc, "OpBranch must appear inside a block");
1621  }
1622 
1623  if (operands.size() != 1) {
1624  return emitError(unknownLoc, "OpBranch must take exactly one target label");
1625  }
1626 
1627  auto *target = getOrCreateBlock(operands[0]);
1628  auto loc = createFileLineColLoc(opBuilder);
1629  // The preceding instruction for the OpBranch instruction could be an
1630  // OpLoopMerge or an OpSelectionMerge instruction, in this case they will have
1631  // the same OpLine information.
1632  opBuilder.create<spirv::BranchOp>(loc, target);
1633 
1634  clearDebugLine();
1635  return success();
1636 }
1637 
1638 LogicalResult
1639 spirv::Deserializer::processBranchConditional(ArrayRef<uint32_t> operands) {
1640  if (!curBlock) {
1641  return emitError(unknownLoc,
1642  "OpBranchConditional must appear inside a block");
1643  }
1644 
1645  if (operands.size() != 3 && operands.size() != 5) {
1646  return emitError(unknownLoc,
1647  "OpBranchConditional must have condition, true label, "
1648  "false label, and optionally two branch weights");
1649  }
1650 
1651  auto condition = getValue(operands[0]);
1652  auto *trueBlock = getOrCreateBlock(operands[1]);
1653  auto *falseBlock = getOrCreateBlock(operands[2]);
1654 
1655  std::optional<std::pair<uint32_t, uint32_t>> weights;
1656  if (operands.size() == 5) {
1657  weights = std::make_pair(operands[3], operands[4]);
1658  }
1659  // The preceding instruction for the OpBranchConditional instruction could be
1660  // an OpSelectionMerge instruction, in this case they will have the same
1661  // OpLine information.
1662  auto loc = createFileLineColLoc(opBuilder);
1663  opBuilder.create<spirv::BranchConditionalOp>(
1664  loc, condition, trueBlock,
1665  /*trueArguments=*/ArrayRef<Value>(), falseBlock,
1666  /*falseArguments=*/ArrayRef<Value>(), weights);
1667 
1668  clearDebugLine();
1669  return success();
1670 }
1671 
1672 LogicalResult spirv::Deserializer::processLabel(ArrayRef<uint32_t> operands) {
1673  if (!curFunction) {
1674  return emitError(unknownLoc, "OpLabel must appear inside a function");
1675  }
1676 
1677  if (operands.size() != 1) {
1678  return emitError(unknownLoc, "OpLabel should only have result <id>");
1679  }
1680 
1681  auto labelID = operands[0];
1682  // We may have forward declared this block.
1683  auto *block = getOrCreateBlock(labelID);
1684  LLVM_DEBUG(logger.startLine()
1685  << "[block] populating block " << block << "\n");
1686  // If we have seen this block, make sure it was just a forward declaration.
1687  assert(block->empty() && "re-deserialize the same block!");
1688 
1689  opBuilder.setInsertionPointToStart(block);
1690  blockMap[labelID] = curBlock = block;
1691 
1692  return success();
1693 }
1694 
1695 LogicalResult
1696 spirv::Deserializer::processSelectionMerge(ArrayRef<uint32_t> operands) {
1697  if (!curBlock) {
1698  return emitError(unknownLoc, "OpSelectionMerge must appear in a block");
1699  }
1700 
1701  if (operands.size() < 2) {
1702  return emitError(
1703  unknownLoc,
1704  "OpSelectionMerge must specify merge target and selection control");
1705  }
1706 
1707  auto *mergeBlock = getOrCreateBlock(operands[0]);
1708  auto loc = createFileLineColLoc(opBuilder);
1709  auto selectionControl = operands[1];
1710 
1711  if (!blockMergeInfo.try_emplace(curBlock, loc, selectionControl, mergeBlock)
1712  .second) {
1713  return emitError(
1714  unknownLoc,
1715  "a block cannot have more than one OpSelectionMerge instruction");
1716  }
1717 
1718  return success();
1719 }
1720 
1721 LogicalResult
1722 spirv::Deserializer::processLoopMerge(ArrayRef<uint32_t> operands) {
1723  if (!curBlock) {
1724  return emitError(unknownLoc, "OpLoopMerge must appear in a block");
1725  }
1726 
1727  if (operands.size() < 3) {
1728  return emitError(unknownLoc, "OpLoopMerge must specify merge target, "
1729  "continue target and loop control");
1730  }
1731 
1732  auto *mergeBlock = getOrCreateBlock(operands[0]);
1733  auto *continueBlock = getOrCreateBlock(operands[1]);
1734  auto loc = createFileLineColLoc(opBuilder);
1735  uint32_t loopControl = operands[2];
1736 
1737  if (!blockMergeInfo
1738  .try_emplace(curBlock, loc, loopControl, mergeBlock, continueBlock)
1739  .second) {
1740  return emitError(
1741  unknownLoc,
1742  "a block cannot have more than one OpLoopMerge instruction");
1743  }
1744 
1745  return success();
1746 }
1747 
1748 LogicalResult spirv::Deserializer::processPhi(ArrayRef<uint32_t> operands) {
1749  if (!curBlock) {
1750  return emitError(unknownLoc, "OpPhi must appear in a block");
1751  }
1752 
1753  if (operands.size() < 4) {
1754  return emitError(unknownLoc, "OpPhi must specify result type, result <id>, "
1755  "and variable-parent pairs");
1756  }
1757 
1758  // Create a block argument for this OpPhi instruction.
1759  Type blockArgType = getType(operands[0]);
1760  BlockArgument blockArg = curBlock->addArgument(blockArgType, unknownLoc);
1761  valueMap[operands[1]] = blockArg;
1762  LLVM_DEBUG(logger.startLine()
1763  << "[phi] created block argument " << blockArg
1764  << " id = " << operands[1] << " of type " << blockArgType << "\n");
1765 
1766  // For each (value, predecessor) pair, insert the value to the predecessor's
1767  // blockPhiInfo entry so later we can fix the block argument there.
1768  for (unsigned i = 2, e = operands.size(); i < e; i += 2) {
1769  uint32_t value = operands[i];
1770  Block *predecessor = getOrCreateBlock(operands[i + 1]);
1771  std::pair<Block *, Block *> predecessorTargetPair{predecessor, curBlock};
1772  blockPhiInfo[predecessorTargetPair].push_back(value);
1773  LLVM_DEBUG(logger.startLine() << "[phi] predecessor @ " << predecessor
1774  << " with arg id = " << value << "\n");
1775  }
1776 
1777  return success();
1778 }
1779 
1780 namespace {
1781 /// A class for putting all blocks in a structured selection/loop in a
1782 /// spirv.mlir.selection/spirv.mlir.loop op.
1783 class ControlFlowStructurizer {
1784 public:
1785 #ifndef NDEBUG
1786  ControlFlowStructurizer(Location loc, uint32_t control,
1787  spirv::BlockMergeInfoMap &mergeInfo, Block *header,
1788  Block *merge, Block *cont,
1789  llvm::ScopedPrinter &logger)
1790  : location(loc), control(control), blockMergeInfo(mergeInfo),
1791  headerBlock(header), mergeBlock(merge), continueBlock(cont),
1792  logger(logger) {}
1793 #else
1794  ControlFlowStructurizer(Location loc, uint32_t control,
1795  spirv::BlockMergeInfoMap &mergeInfo, Block *header,
1796  Block *merge, Block *cont)
1797  : location(loc), control(control), blockMergeInfo(mergeInfo),
1798  headerBlock(header), mergeBlock(merge), continueBlock(cont) {}
1799 #endif
1800 
1801  /// Structurizes the loop at the given `headerBlock`.
1802  ///
1803  /// This method will create an spirv.mlir.loop op in the `mergeBlock` and move
1804  /// all blocks in the structured loop into the spirv.mlir.loop's region. All
1805  /// branches to the `headerBlock` will be redirected to the `mergeBlock`. This
1806  /// method will also update `mergeInfo` by remapping all blocks inside to the
1807  /// newly cloned ones inside structured control flow op's regions.
1808  LogicalResult structurize();
1809 
1810 private:
1811  /// Creates a new spirv.mlir.selection op at the beginning of the
1812  /// `mergeBlock`.
1813  spirv::SelectionOp createSelectionOp(uint32_t selectionControl);
1814 
1815  /// Creates a new spirv.mlir.loop op at the beginning of the `mergeBlock`.
1816  spirv::LoopOp createLoopOp(uint32_t loopControl);
1817 
1818  /// Collects all blocks reachable from `headerBlock` except `mergeBlock`.
1819  void collectBlocksInConstruct();
1820 
1821  Location location;
1822  uint32_t control;
1823 
1824  spirv::BlockMergeInfoMap &blockMergeInfo;
1825 
1826  Block *headerBlock;
1827  Block *mergeBlock;
1828  Block *continueBlock; // nullptr for spirv.mlir.selection
1829 
1830  SetVector<Block *> constructBlocks;
1831 
1832 #ifndef NDEBUG
1833  /// A logger used to emit information during the deserialzation process.
1834  llvm::ScopedPrinter &logger;
1835 #endif
1836 };
1837 } // namespace
1838 
1839 spirv::SelectionOp
1840 ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) {
1841  // Create a builder and set the insertion point to the beginning of the
1842  // merge block so that the newly created SelectionOp will be inserted there.
1843  OpBuilder builder(&mergeBlock->front());
1844 
1845  auto control = static_cast<spirv::SelectionControl>(selectionControl);
1846  auto selectionOp = builder.create<spirv::SelectionOp>(location, control);
1847  selectionOp.addMergeBlock(builder);
1848 
1849  return selectionOp;
1850 }
1851 
1852 spirv::LoopOp ControlFlowStructurizer::createLoopOp(uint32_t loopControl) {
1853  // Create a builder and set the insertion point to the beginning of the
1854  // merge block so that the newly created LoopOp will be inserted there.
1855  OpBuilder builder(&mergeBlock->front());
1856 
1857  auto control = static_cast<spirv::LoopControl>(loopControl);
1858  auto loopOp = builder.create<spirv::LoopOp>(location, control);
1859  loopOp.addEntryAndMergeBlock(builder);
1860 
1861  return loopOp;
1862 }
1863 
1864 void ControlFlowStructurizer::collectBlocksInConstruct() {
1865  assert(constructBlocks.empty() && "expected empty constructBlocks");
1866 
1867  // Put the header block in the work list first.
1868  constructBlocks.insert(headerBlock);
1869 
1870  // For each item in the work list, add its successors excluding the merge
1871  // block.
1872  for (unsigned i = 0; i < constructBlocks.size(); ++i) {
1873  for (auto *successor : constructBlocks[i]->getSuccessors())
1874  if (successor != mergeBlock)
1875  constructBlocks.insert(successor);
1876  }
1877 }
1878 
1879 LogicalResult ControlFlowStructurizer::structurize() {
1880  Operation *op = nullptr;
1881  bool isLoop = continueBlock != nullptr;
1882  if (isLoop) {
1883  if (auto loopOp = createLoopOp(control))
1884  op = loopOp.getOperation();
1885  } else {
1886  if (auto selectionOp = createSelectionOp(control))
1887  op = selectionOp.getOperation();
1888  }
1889  if (!op)
1890  return failure();
1891  Region &body = op->getRegion(0);
1892 
1893  IRMapping mapper;
1894  // All references to the old merge block should be directed to the
1895  // selection/loop merge block in the SelectionOp/LoopOp's region.
1896  mapper.map(mergeBlock, &body.back());
1897 
1898  collectBlocksInConstruct();
1899 
1900  // We've identified all blocks belonging to the selection/loop's region. Now
1901  // need to "move" them into the selection/loop. Instead of really moving the
1902  // blocks, in the following we copy them and remap all values and branches.
1903  // This is because:
1904  // * Inserting a block into a region requires the block not in any region
1905  // before. But selections/loops can nest so we can create selection/loop ops
1906  // in a nested manner, which means some blocks may already be in a
1907  // selection/loop region when to be moved again.
1908  // * It's much trickier to fix up the branches into and out of the loop's
1909  // region: we need to treat not-moved blocks and moved blocks differently:
1910  // Not-moved blocks jumping to the loop header block need to jump to the
1911  // merge point containing the new loop op but not the loop continue block's
1912  // back edge. Moved blocks jumping out of the loop need to jump to the
1913  // merge block inside the loop region but not other not-moved blocks.
1914  // We cannot use replaceAllUsesWith clearly and it's harder to follow the
1915  // logic.
1916 
1917  // Create a corresponding block in the SelectionOp/LoopOp's region for each
1918  // block in this loop construct.
1919  OpBuilder builder(body);
1920  for (auto *block : constructBlocks) {
1921  // Create a block and insert it before the selection/loop merge block in the
1922  // SelectionOp/LoopOp's region.
1923  auto *newBlock = builder.createBlock(&body.back());
1924  mapper.map(block, newBlock);
1925  LLVM_DEBUG(logger.startLine() << "[cf] cloned block " << newBlock
1926  << " from block " << block << "\n");
1927  if (!isFnEntryBlock(block)) {
1928  for (BlockArgument blockArg : block->getArguments()) {
1929  auto newArg =
1930  newBlock->addArgument(blockArg.getType(), blockArg.getLoc());
1931  mapper.map(blockArg, newArg);
1932  LLVM_DEBUG(logger.startLine() << "[cf] remapped block argument "
1933  << blockArg << " to " << newArg << "\n");
1934  }
1935  } else {
1936  LLVM_DEBUG(logger.startLine()
1937  << "[cf] block " << block << " is a function entry block\n");
1938  }
1939 
1940  for (auto &op : *block)
1941  newBlock->push_back(op.clone(mapper));
1942  }
1943 
1944  // Go through all ops and remap the operands.
1945  auto remapOperands = [&](Operation *op) {
1946  for (auto &operand : op->getOpOperands())
1947  if (Value mappedOp = mapper.lookupOrNull(operand.get()))
1948  operand.set(mappedOp);
1949  for (auto &succOp : op->getBlockOperands())
1950  if (Block *mappedOp = mapper.lookupOrNull(succOp.get()))
1951  succOp.set(mappedOp);
1952  };
1953  for (auto &block : body)
1954  block.walk(remapOperands);
1955 
1956  // We have created the SelectionOp/LoopOp and "moved" all blocks belonging to
1957  // the selection/loop construct into its region. Next we need to fix the
1958  // connections between this new SelectionOp/LoopOp with existing blocks.
1959 
1960  // All existing incoming branches should go to the merge block, where the
1961  // SelectionOp/LoopOp resides right now.
1962  headerBlock->replaceAllUsesWith(mergeBlock);
1963 
1964  LLVM_DEBUG({
1965  logger.startLine() << "[cf] after cloning and fixing references:\n";
1966  headerBlock->getParentOp()->print(logger.getOStream());
1967  logger.startLine() << "\n";
1968  });
1969 
1970  if (isLoop) {
1971  if (!mergeBlock->args_empty()) {
1972  return mergeBlock->getParentOp()->emitError(
1973  "OpPhi in loop merge block unsupported");
1974  }
1975 
1976  // The loop header block may have block arguments. Since now we place the
1977  // loop op inside the old merge block, we need to make sure the old merge
1978  // block has the same block argument list.
1979  for (BlockArgument blockArg : headerBlock->getArguments())
1980  mergeBlock->addArgument(blockArg.getType(), blockArg.getLoc());
1981 
1982  // If the loop header block has block arguments, make sure the spirv.Branch
1983  // op matches.
1984  SmallVector<Value, 4> blockArgs;
1985  if (!headerBlock->args_empty())
1986  blockArgs = {mergeBlock->args_begin(), mergeBlock->args_end()};
1987 
1988  // The loop entry block should have a unconditional branch jumping to the
1989  // loop header block.
1990  builder.setInsertionPointToEnd(&body.front());
1991  builder.create<spirv::BranchOp>(location, mapper.lookupOrNull(headerBlock),
1992  ArrayRef<Value>(blockArgs));
1993  }
1994 
1995  // All the blocks cloned into the SelectionOp/LoopOp's region can now be
1996  // cleaned up.
1997  LLVM_DEBUG(logger.startLine() << "[cf] cleaning up blocks after clone\n");
1998  // First we need to drop all operands' references inside all blocks. This is
1999  // needed because we can have blocks referencing SSA values from one another.
2000  for (auto *block : constructBlocks)
2001  block->dropAllReferences();
2002 
2003  // Check that whether some op in the to-be-erased blocks still has uses. Those
2004  // uses come from blocks that won't be sinked into the SelectionOp/LoopOp's
2005  // region. We cannot handle such cases given that once a value is sinked into
2006  // the SelectionOp/LoopOp's region, there is no escape for it:
2007  // SelectionOp/LooOp does not support yield values right now.
2008  for (auto *block : constructBlocks) {
2009  for (Operation &op : *block)
2010  if (!op.use_empty())
2011  return op.emitOpError(
2012  "failed control flow structurization: it has uses outside of the "
2013  "enclosing selection/loop construct");
2014  }
2015 
2016  // Then erase all old blocks.
2017  for (auto *block : constructBlocks) {
2018  // We've cloned all blocks belonging to this construct into the structured
2019  // control flow op's region. Among these blocks, some may compose another
2020  // selection/loop. If so, they will be recorded within blockMergeInfo.
2021  // We need to update the pointers there to the newly remapped ones so we can
2022  // continue structurizing them later.
2023  // TODO: The asserts in the following assumes input SPIR-V blob forms
2024  // correctly nested selection/loop constructs. We should relax this and
2025  // support error cases better.
2026  auto it = blockMergeInfo.find(block);
2027  if (it != blockMergeInfo.end()) {
2028  // Use the original location for nested selection/loop ops.
2029  Location loc = it->second.loc;
2030 
2031  Block *newHeader = mapper.lookupOrNull(block);
2032  if (!newHeader)
2033  return emitError(loc, "failed control flow structurization: nested "
2034  "loop header block should be remapped!");
2035 
2036  Block *newContinue = it->second.continueBlock;
2037  if (newContinue) {
2038  newContinue = mapper.lookupOrNull(newContinue);
2039  if (!newContinue)
2040  return emitError(loc, "failed control flow structurization: nested "
2041  "loop continue block should be remapped!");
2042  }
2043 
2044  Block *newMerge = it->second.mergeBlock;
2045  if (Block *mappedTo = mapper.lookupOrNull(newMerge))
2046  newMerge = mappedTo;
2047 
2048  // The iterator should be erased before adding a new entry into
2049  // blockMergeInfo to avoid iterator invalidation.
2050  blockMergeInfo.erase(it);
2051  blockMergeInfo.try_emplace(newHeader, loc, it->second.control, newMerge,
2052  newContinue);
2053  }
2054 
2055  // The structured selection/loop's entry block does not have arguments.
2056  // If the function's header block is also part of the structured control
2057  // flow, we cannot just simply erase it because it may contain arguments
2058  // matching the function signature and used by the cloned blocks.
2059  if (isFnEntryBlock(block)) {
2060  LLVM_DEBUG(logger.startLine() << "[cf] changing entry block " << block
2061  << " to only contain a spirv.Branch op\n");
2062  // Still keep the function entry block for the potential block arguments,
2063  // but replace all ops inside with a branch to the merge block.
2064  block->clear();
2065  builder.setInsertionPointToEnd(block);
2066  builder.create<spirv::BranchOp>(location, mergeBlock);
2067  } else {
2068  LLVM_DEBUG(logger.startLine() << "[cf] erasing block " << block << "\n");
2069  block->erase();
2070  }
2071  }
2072 
2073  LLVM_DEBUG(logger.startLine()
2074  << "[cf] after structurizing construct with header block "
2075  << headerBlock << ":\n"
2076  << *op << "\n");
2077 
2078  return success();
2079 }
2080 
2081 LogicalResult spirv::Deserializer::wireUpBlockArgument() {
2082  LLVM_DEBUG({
2083  logger.startLine()
2084  << "//----- [phi] start wiring up block arguments -----//\n";
2085  logger.indent();
2086  });
2087 
2088  OpBuilder::InsertionGuard guard(opBuilder);
2089 
2090  for (const auto &info : blockPhiInfo) {
2091  Block *block = info.first.first;
2092  Block *target = info.first.second;
2093  const BlockPhiInfo &phiInfo = info.second;
2094  LLVM_DEBUG({
2095  logger.startLine() << "[phi] block " << block << "\n";
2096  logger.startLine() << "[phi] before creating block argument:\n";
2097  block->getParentOp()->print(logger.getOStream());
2098  logger.startLine() << "\n";
2099  });
2100 
2101  // Set insertion point to before this block's terminator early because we
2102  // may materialize ops via getValue() call.
2103  auto *op = block->getTerminator();
2104  opBuilder.setInsertionPoint(op);
2105 
2106  SmallVector<Value, 4> blockArgs;
2107  blockArgs.reserve(phiInfo.size());
2108  for (uint32_t valueId : phiInfo) {
2109  if (Value value = getValue(valueId)) {
2110  blockArgs.push_back(value);
2111  LLVM_DEBUG(logger.startLine() << "[phi] block argument " << value
2112  << " id = " << valueId << "\n");
2113  } else {
2114  return emitError(unknownLoc, "OpPhi references undefined value!");
2115  }
2116  }
2117 
2118  if (auto branchOp = dyn_cast<spirv::BranchOp>(op)) {
2119  // Replace the previous branch op with a new one with block arguments.
2120  opBuilder.create<spirv::BranchOp>(branchOp.getLoc(), branchOp.getTarget(),
2121  blockArgs);
2122  branchOp.erase();
2123  } else if (auto branchCondOp = dyn_cast<spirv::BranchConditionalOp>(op)) {
2124  assert((branchCondOp.getTrueBlock() == target ||
2125  branchCondOp.getFalseBlock() == target) &&
2126  "expected target to be either the true or false target");
2127  if (target == branchCondOp.getTrueTarget())
2128  opBuilder.create<spirv::BranchConditionalOp>(
2129  branchCondOp.getLoc(), branchCondOp.getCondition(), blockArgs,
2130  branchCondOp.getFalseBlockArguments(),
2131  branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueTarget(),
2132  branchCondOp.getFalseTarget());
2133  else
2134  opBuilder.create<spirv::BranchConditionalOp>(
2135  branchCondOp.getLoc(), branchCondOp.getCondition(),
2136  branchCondOp.getTrueBlockArguments(), blockArgs,
2137  branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueBlock(),
2138  branchCondOp.getFalseBlock());
2139 
2140  branchCondOp.erase();
2141  } else {
2142  return emitError(unknownLoc, "unimplemented terminator for Phi creation");
2143  }
2144 
2145  LLVM_DEBUG({
2146  logger.startLine() << "[phi] after creating block argument:\n";
2147  block->getParentOp()->print(logger.getOStream());
2148  logger.startLine() << "\n";
2149  });
2150  }
2151  blockPhiInfo.clear();
2152 
2153  LLVM_DEBUG({
2154  logger.unindent();
2155  logger.startLine()
2156  << "//--- [phi] completed wiring up block arguments ---//\n";
2157  });
2158  return success();
2159 }
2160 
2161 LogicalResult spirv::Deserializer::structurizeControlFlow() {
2162  LLVM_DEBUG({
2163  logger.startLine()
2164  << "//----- [cf] start structurizing control flow -----//\n";
2165  logger.indent();
2166  });
2167 
2168  while (!blockMergeInfo.empty()) {
2169  Block *headerBlock = blockMergeInfo.begin()->first;
2170  BlockMergeInfo mergeInfo = blockMergeInfo.begin()->second;
2171 
2172  LLVM_DEBUG({
2173  logger.startLine() << "[cf] header block " << headerBlock << ":\n";
2174  headerBlock->print(logger.getOStream());
2175  logger.startLine() << "\n";
2176  });
2177 
2178  auto *mergeBlock = mergeInfo.mergeBlock;
2179  assert(mergeBlock && "merge block cannot be nullptr");
2180  if (!mergeBlock->args_empty())
2181  return emitError(unknownLoc, "OpPhi in loop merge block unimplemented");
2182  LLVM_DEBUG({
2183  logger.startLine() << "[cf] merge block " << mergeBlock << ":\n";
2184  mergeBlock->print(logger.getOStream());
2185  logger.startLine() << "\n";
2186  });
2187 
2188  auto *continueBlock = mergeInfo.continueBlock;
2189  LLVM_DEBUG(if (continueBlock) {
2190  logger.startLine() << "[cf] continue block " << continueBlock << ":\n";
2191  continueBlock->print(logger.getOStream());
2192  logger.startLine() << "\n";
2193  });
2194  // Erase this case before calling into structurizer, who will update
2195  // blockMergeInfo.
2196  blockMergeInfo.erase(blockMergeInfo.begin());
2197  ControlFlowStructurizer structurizer(mergeInfo.loc, mergeInfo.control,
2198  blockMergeInfo, headerBlock,
2199  mergeBlock, continueBlock
2200 #ifndef NDEBUG
2201  ,
2202  logger
2203 #endif
2204  );
2205  if (failed(structurizer.structurize()))
2206  return failure();
2207  }
2208 
2209  LLVM_DEBUG({
2210  logger.unindent();
2211  logger.startLine()
2212  << "//--- [cf] completed structurizing control flow ---//\n";
2213  });
2214  return success();
2215 }
2216 
2217 //===----------------------------------------------------------------------===//
2218 // Debug
2219 //===----------------------------------------------------------------------===//
2220 
2221 Location spirv::Deserializer::createFileLineColLoc(OpBuilder opBuilder) {
2222  if (!debugLine)
2223  return unknownLoc;
2224 
2225  auto fileName = debugInfoMap.lookup(debugLine->fileID).str();
2226  if (fileName.empty())
2227  fileName = "<unknown>";
2228  return FileLineColLoc::get(opBuilder.getStringAttr(fileName), debugLine->line,
2229  debugLine->column);
2230 }
2231 
2232 LogicalResult
2233 spirv::Deserializer::processDebugLine(ArrayRef<uint32_t> operands) {
2234  // According to SPIR-V spec:
2235  // "This location information applies to the instructions physically
2236  // following this instruction, up to the first occurrence of any of the
2237  // following: the next end of block, the next OpLine instruction, or the next
2238  // OpNoLine instruction."
2239  if (operands.size() != 3)
2240  return emitError(unknownLoc, "OpLine must have 3 operands");
2241  debugLine = DebugLine{operands[0], operands[1], operands[2]};
2242  return success();
2243 }
2244 
2245 void spirv::Deserializer::clearDebugLine() { debugLine = std::nullopt; }
2246 
2247 LogicalResult
2248 spirv::Deserializer::processDebugString(ArrayRef<uint32_t> operands) {
2249  if (operands.size() < 2)
2250  return emitError(unknownLoc, "OpString needs at least 2 operands");
2251 
2252  if (!debugInfoMap.lookup(operands[0]).empty())
2253  return emitError(unknownLoc,
2254  "duplicate debug string found for result <id> ")
2255  << operands[0];
2256 
2257  unsigned wordIndex = 1;
2258  StringRef debugString = decodeStringLiteral(operands, wordIndex);
2259  if (wordIndex != operands.size())
2260  return emitError(unknownLoc,
2261  "unexpected trailing words in OpString instruction");
2262 
2263  debugInfoMap[operands[0]] = debugString;
2264  return success();
2265 }
static bool isLoop(Operation *op)
Returns true if the given operation represents a loop by testing whether it implements the LoopLikeOp...
static bool isFnEntryBlock(Block *block)
Returns true if the given block is a function entry block.
#define MIN_VERSION_CASE(v)
LogicalResult deserializeCacheControlDecoration(Location loc, OpBuilder &opBuilder, DenseMap< uint32_t, NamedAttrList > &decorations, ArrayRef< uint32_t > words, StringAttr symbol, StringRef decorationName, StringRef cacheControlKind)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
int64_t rows
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
Location getLoc() const
Return the location for this argument.
Definition: Value.h:334
Block represents an ordered list of Operations.
Definition: Block.h:33
bool empty()
Definition: Block.h:148
Operation & back()
Definition: Block.h:152
void erase()
Unlink this Block from its parent region and delete it.
Definition: Block.cpp:68
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Definition: Block.cpp:310
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
void print(raw_ostream &os)
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
iterator begin()
Definition: Block.h:143
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Definition: Block.cpp:38
void push_back(Operation *op)
Definition: Block.h:149
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:33
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:302
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:306
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition: Builders.h:107
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static FileLineColLoc get(StringAttr filename, unsigned line, unsigned column)
Definition: Location.cpp:160
A symbol reference with a reference path containing a single element.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
auto lookupOrNull(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:58
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:207
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:357
This class helps build Operations.
Definition: Builders.h:216
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool use_empty()
Returns true if this operation has no uses.
Definition: Operation.h:848
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Definition: Operation.cpp:717
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Definition: Operation.cpp:67
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
MutableArrayRef< BlockOperand > getBlockOperands()
Definition: Operation.h:691
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:378
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Block & back()
Definition: Region.h:64
BlockListType & getBlocks()
Definition: Region.h:45
BlockListType::iterator iterator
Definition: Region.h:52
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:127
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
static ArrayType get(Type elementType, unsigned elementCount)
Definition: SPIRVTypes.cpp:52
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
Definition: SPIRVTypes.cpp:222
LogicalResult deserialize()
Deserializes the remembered SPIR-V binary module.
Deserializer(ArrayRef< uint32_t > binary, MLIRContext *context)
Creates a deserializer for the given SPIR-V binary module.
OwningOpRef< spirv::ModuleOp > collect()
Collects the final SPIR-V ModuleOp.
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
Definition: SPIRVTypes.h:168
static MatrixType get(Type columnType, uint32_t columnCount)
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:406
static RuntimeArrayType get(Type elementType)
Definition: SPIRVTypes.cpp:463
static SampledImageType get(Type imageType)
Definition: SPIRVTypes.cpp:732
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
Definition: SPIRVTypes.cpp:974
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
Definition: SPIRVTypes.cpp:984
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Construct a literal StructType with at least one member.
Definition: SPIRVTypes.cpp:961
static VerCapExtAttr get(Version version, ArrayRef< Capability > capabilities, ArrayRef< Extension > extensions, MLIRContext *context)
Gets a VerCapExtAttr instance.
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition: CallGraph.h:229
constexpr uint32_t kMagicNumber
SPIR-V magic number.
StringRef decodeStringLiteral(ArrayRef< uint32_t > words, unsigned &wordIndex)
Decodes a string literal in words starting at wordIndex.
constexpr unsigned kHeaderWordCount
SPIR-V binary header word count.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
static std::string debugString(T &&op)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This represents an operation in an abstracted form, suitable for use with the builder APIs.